KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 90.0% 63 8 78
Functions: 71.4% 5 0 7
Branches: 95.5% 21 16 38

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"
7
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include "kai/kai_common.h"
13
// Byte widths of the per-column trailer fields stored after the int8 data of each packed row.
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);      // int32 reduction sum
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); // float scale (multiplier)
static const size_t kai_num_bytes_bias = sizeof(float);           // float bias
17
// Round k up to the next multiple of 32.
// Both the packed-stride computation and the packing routine in this file reserve
// this many int8 slots per column, so they must agree on the rounding.
// `static inline` (storage class first): the `inline static` ordering is an
// obsolescent form per C11 6.11.5 and triggers -Wold-style-declaration.
static inline size_t kai_k_roundedup(size_t k) {
    // Granularity (in int8 elements) that k is padded to.
    const size_t kai_k_multiple_of = 32;
    return kai_roundup(k, kai_k_multiple_of);
}
23
// The RHS is packed in groups of nr columns, so one step along N is nr columns.
size_t kai_get_n_step_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(size_t nr) {
    const size_t n_step = nr;
    return n_step;
}
27
// Offset of column n_idx inside the unpacked K x N int8 RHS.
// Elements are one byte each and a column index moves along a row, so the offset
// is n_idx itself; the row stride plays no part.
size_t kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(size_t n_idx, size_t rhs_stride) {
    const size_t column_offset = n_idx;
    KAI_UNUSED(rhs_stride);
    return column_offset;
}
32
33 2344 size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(size_t k, size_t nr, size_t kr, size_t sr) {
34 2344 KAI_UNUSED(kr);
35 2344 KAI_UNUSED(sr);
36 2344 const size_t k_internal = kai_k_roundedup(k);
37
38 4688 return nr * (k_internal + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
39 2344 }
40
// Byte offset of the packed row that starts at column n_idx.
// n_idx must be a multiple of nr (asserted).
size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(
    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
    KAI_ASSERT((n_idx % nr) == 0);

    const size_t packed_row_idx = n_idx / nr;
    const size_t packed_row_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
    return packed_row_idx * packed_row_stride;
}
47
// Total bytes required for the packed RHS. n is padded up to a multiple of nr and
// every resulting group of nr columns occupies one packed row.
size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
    const size_t packed_row_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
    const size_t packed_row_count = kai_roundup(n, nr) / nr;
    return packed_row_stride * packed_row_count;
}
53
54 576 void kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(
55 size_t num_groups, //
56 size_t n, //
57 size_t k, //
58 size_t nr, //
59 size_t kr, //
60 size_t sr, //
61 const int8_t* rhs, //
62 const float* bias, //
63 const float* scale, //
64 void* rhs_packed, //
65 size_t extra_bytes, const struct kai_rhs_pack_qsi8cx_params* params) {
66 KAI_ASSERT(num_groups == 1);
67 KAI_ASSERT(extra_bytes == 0);
68 KAI_ASSERT(sr == 1);
69 KAI_ASSERT(rhs != NULL);
70 KAI_ASSERT(scale != NULL);
71 KAI_ASSERT(rhs_packed != NULL);
72 KAI_ASSERT(params != NULL);
73
74 576 const int32_t lhs_zero_point = params->lhs_zero_point;
75 576 const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
76 576 const size_t k_internal = kai_k_roundedup(k);
77 576 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
78 576 const size_t dst_num_bytes_per_row = nr * k_internal;
79 576 const size_t rhs_stride = n;
80
81
2/2
✓ Branch 0 taken 576 times.
✓ Branch 1 taken 5380 times.
5956 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
82 5380 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
83
84 5380 int32_t* sums = (int32_t*)(dst_row + nr * k_internal);
85
86 // Initialize to zero the RHS reduction sums
87 5380 memset(sums, 0, nr * sizeof(int32_t));
88
89
2/2
✓ Branch 0 taken 574592 times.
✓ Branch 1 taken 5380 times.
579972 for (size_t dst_offset = 0; dst_offset < dst_num_bytes_per_row; dst_offset += kr) {
90 574592 const size_t block_idx = dst_offset / kr;
91 574592 const size_t nr_idx = block_idx % nr;
92 574592 const size_t super_block_idx = block_idx / nr;
93
94 574592 const size_t k0_idx = super_block_idx * kr;
95 574592 const size_t n0_idx = dst_row_idx * nr + nr_idx;
96
97 // Clamp the index to avoid out-of-bound reads
98
2/2
✓ Branch 0 taken 442848 times.
✓ Branch 1 taken 131744 times.
574592 const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
99
100 574592 const size_t src_offset = n0_valid_idx;
101
102 574592 int32_t partial_sum = 0;
103
104 // Get the partial reduction sum
105
2/2
✓ Branch 0 taken 2562048 times.
✓ Branch 1 taken 574592 times.
3136640 for (size_t i = 0; i < kr; i++) {
106 2562048 const size_t k0_valid_idx = k0_idx + i;
107 2562048 int8_t v = 0;
108
2/2
✓ Branch 0 taken 987680 times.
✓ Branch 1 taken 1574368 times.
2562048 if (k0_valid_idx < k) {
109 1574368 v = rhs[src_offset + (k0_valid_idx * rhs_stride)];
110 1574368 }
111
112 2562048 ((int8_t*)dst_row)[i] = v;
113 2562048 partial_sum += v;
114 2562048 }
115
116 574592 sums[nr_idx] += partial_sum * lhs_zero_point;
117
118 574592 dst_row += kr;
119 574592 }
120
121 // Adjust the reduction sums
122
2/2
✓ Branch 0 taken 48160 times.
✓ Branch 1 taken 5380 times.
53540 for (size_t i = 0; i < nr; ++i) {
123 48160 dst_row += sizeof(int32_t);
124 48160 }
125
126 // Adjust the scales
127
2/2
✓ Branch 0 taken 48160 times.
✓ Branch 1 taken 5380 times.
53540 for (size_t i = 0; i < nr; ++i) {
128 // Clamp the row index to avoid out-of-bound reads
129
2/2
✓ Branch 0 taken 38016 times.
✓ Branch 1 taken 10144 times.
48160 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
130 48160 *((float*)(dst_row)) = scale[src_row_idx];
131 48160 dst_row += sizeof(float);
132 48160 }
133
134 // Set the bias
135
1/2
✓ Branch 0 taken 5380 times.
✗ Branch 1 not taken.
5380 if (bias == NULL) {
136 memset(dst_row, 0, nr * kai_num_bytes_bias);
137 } else {
138
2/2
✓ Branch 0 taken 48160 times.
✓ Branch 1 taken 5380 times.
53540 for (size_t i = 0; i < nr; ++i) {
139 // Clamp the row index to avoid out-of-bound reads
140
2/2
✓ Branch 0 taken 38016 times.
✓ Branch 1 taken 10144 times.
48160 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
141 48160 ((float*)dst_row)[i] = bias[src_row_idx];
142 48160 }
143 }
144 5380 }
145 576 }
146