KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 97.3% 142 36 182
Functions: 88.9% 8 0 9
Branches: 96.2% 25 72 98

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // nrx8 => this function accepts any nr that is a multiple of 4, but the input is expected to have a block depth of 8.
8 // Block depth is calculated as kr / sr. The values of these parameters are defined in the matmul ukernel.
9
10 #if !defined(__aarch64__) && !defined(_M_ARM64)
11 #error This file must be compiled for AArch64.
12 #else // Architectural features check.
13
14 #include "kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h"
15
16 #include <arm_neon.h>
17 #include <stddef.h>
18 #include <stdint.h>
19 #include <string.h>
20
21 #include "kai/kai_common.h"
22
23 static const size_t kai_num_bytes_sum_rhs = sizeof(float);
24 static const size_t kai_num_bytes_bias = sizeof(float);
25 static const size_t kai_nr_multiple_of = 4;
26 static const size_t kai_bl_multiple_of = 32;
27
28 1944 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
29 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
30 1944 return kai_roundup(k, bl) / bl;
31 }
32
33 1458 inline static size_t kai_get_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) {
34 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
35 1458 return (bl / 2) + num_bytes_multiplier_rhs;
36 }
37
38 486 inline static size_t kai_get_rhs_packed_offset_end_of_all_blocks(
39 size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) {
40 KAI_ASSERT((bl % kr) == 0);
41 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
42 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
43
44 486 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
45 486 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
46
47 972 return (nr * num_bytes_per_block * num_blocks_per_row);
48 486 }
49
50 size_t kai_get_n_step_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(size_t nr) {
51 return nr;
52 }
53
54 486 size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
55 size_t n_idx, //
56 size_t rhs_stride) {
57 486 return n_idx * rhs_stride;
58 }
59
60 972 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
61 size_t k, //
62 size_t nr, //
63 size_t kr, //
64 size_t sr, //
65 size_t bl, //
66 enum kai_datatype scale_dt) {
67 KAI_ASSERT((k % bl) == 0);
68 KAI_ASSERT((bl % kr) == 0);
69 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
70 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
71 KAI_ASSERT(scale_dt == kai_dt_bf16);
72
73 972 KAI_UNUSED(kr);
74 972 KAI_UNUSED(sr);
75
76 972 const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
77 972 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
78 972 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
79
80 1944 return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
81 972 }
82
83 486 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
84 size_t n_idx, //
85 size_t k, //
86 size_t nr, //
87 size_t kr, //
88 size_t sr, //
89 size_t bl, //
90 enum kai_datatype scale_dt) {
91 KAI_ASSERT((n_idx % nr) == 0);
92 KAI_ASSERT((k % bl) == 0);
93 KAI_ASSERT((bl % kr) == 0);
94 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
95 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
96 KAI_ASSERT(scale_dt == kai_dt_bf16);
97
98 972 return (n_idx / nr) *
99 486 kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
100 }
101
102 486 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
103 size_t n, //
104 size_t k, //
105 size_t nr, //
106 size_t kr, //
107 size_t sr, //
108 size_t bl, //
109 enum kai_datatype scale_dt) {
110 KAI_ASSERT((k % bl) == 0);
111 KAI_ASSERT((bl % kr) == 0);
112 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
113 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
114 KAI_ASSERT(scale_dt == kai_dt_bf16);
115
116 486 const size_t num_rows = kai_roundup(n, nr) / nr;
117
118 972 return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
119 486 }
120
121 486 void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
122 size_t num_groups, //
123 size_t n, //
124 size_t k, //
125 size_t nr, //
126 size_t kr, //
127 size_t sr, //
128 size_t bl, //
129 const uint8_t* rhs, //
130 size_t rhs_stride, //
131 const float* bias, //
132 const void* scale, //
133 size_t scale_stride, //
134 void* rhs_packed, //
135 size_t extra_bytes, //
136 const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) {
137 KAI_ASSERT(num_groups == 1);
138 KAI_ASSERT(extra_bytes == 0);
139 KAI_ASSERT(rhs != NULL);
140 KAI_ASSERT(scale != NULL);
141 KAI_ASSERT(rhs_packed != NULL);
142 KAI_ASSERT(params != NULL);
143 KAI_ASSERT(params->rhs_zero_point == 8);
144 KAI_ASSERT(params->lhs_zero_point == 1);
145
146 KAI_ASSERT((k % bl) == 0);
147 KAI_ASSERT((bl % kr) == 0);
148 KAI_ASSERT((kr % sr) == 0);
149 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
150 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
151 KAI_ASSERT(params->scale_dt == kai_dt_bf16);
152
153 // Note: The input matrix (rhs) is expected to have
154 // "n" rows and "k" columns (NxK)
155 486 const enum kai_datatype scale_dt = params->scale_dt;
156 486 const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
157 972 const size_t rhs_packed_offset_end_of_all_blocks =
158 486 kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs);
159 486 const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl);
160 486 const size_t num_bytes_per_block_k = bl / 2;
161 486 const size_t dst_num_rows = kai_roundup(n, nr);
162 486 const size_t block_length_in_bytes = kr / sr;
163 KAI_ASSERT(block_length_in_bytes == 8);
164
165 486 uint8_t* dst_row = (uint8_t*)rhs_packed;
166
167
2/2
✓ Branch 0 taken 486 times.
✓ Branch 1 taken 3052 times.
3538 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) {
168 3052 float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);
169
170 // Initialize the RHS reduction sums to zero
171 3052 memset(sums, 0, nr * kai_num_bytes_sum_rhs);
172
173 // Iterate over the quantized blocks
174
2/2
✓ Branch 0 taken 13503 times.
✓ Branch 1 taken 3052 times.
16555 for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) {
175 // The scales are stored right after all the packed K values of the block
176 13503 uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr;
177 13503 const uint8_t* scale_ptr = (const uint8_t*)scale + dst_qblock_idx * num_bytes_multiplier_rhs;
178
179
2/2
✓ Branch 0 taken 69804 times.
✓ Branch 1 taken 13503 times.
83307 for (size_t i = 0; i < nr; ++i) {
180
2/2
✓ Branch 0 taken 66240 times.
✓ Branch 1 taken 3564 times.
69804 const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
181 69804 const void* src_scales_ptr = scale_ptr + src_row_idx * scale_stride;
182 69804 void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs;
183
184 69804 memcpy(
185 dst_scales_ptr, //
186 src_scales_ptr, //
187 num_bytes_multiplier_rhs); //
188 69804 }
189
190 13503 size_t k0_idx_i = dst_qblock_idx * bl;
191 13503 const uint8x8_t top_mask = vdup_n_u8(0xF0);
192 13503 const uint8x8_t bottom_mask = vdup_n_u8(0x0F);
193 13503 const uint8x8_t zero_point_conversion_mask = vdup_n_u8(0x88);
194
195
2/2
✓ Branch 0 taken 18004 times.
✓ Branch 1 taken 13503 times.
31507 for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) {
196
2/2
✓ Branch 0 taken 23268 times.
✓ Branch 1 taken 18004 times.
41272 for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 4) {
197 // Clamp the indices to avoid out-of-bound reads
198
2/2
✓ Branch 0 taken 22368 times.
✓ Branch 1 taken 900 times.
23268 const size_t n0_idx = KAI_MIN(dst_row_idx + nr_idx, n - 1);
199
2/2
✓ Branch 0 taken 22368 times.
✓ Branch 1 taken 900 times.
23268 const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
200
2/2
✓ Branch 0 taken 22152 times.
✓ Branch 1 taken 1116 times.
23268 const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
201
2/2
✓ Branch 0 taken 21432 times.
✓ Branch 1 taken 1836 times.
23268 const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);
202
203 23268 const float d0 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 0]);
204 23268 const float d1 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 1]);
205 23268 const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]);
206 23268 const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]);
207
208 // Initialize each partial sum to -(32 * 8): every 16-byte chunk holds 32 int4 values, each offset by the zero-point (8)
209 23268 int32_t partial_sum0 = -(32 * 8);
210 23268 int32_t partial_sum1 = -(32 * 8);
211 23268 int32_t partial_sum2 = -(32 * 8);
212 23268 int32_t partial_sum3 = -(32 * 8);
213
214 23268 const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx);
215
216 23268 const uint8x8_t vld0_0 = vld1_u8(src_block_base + n0_idx * rhs_stride);
217 23268 const uint8x8_t vld0_1 = vld1_u8(src_block_base + n0_idx * rhs_stride + 8);
218 23268 const uint8x8_t vld1_0 = vld1_u8(src_block_base + n1_idx * rhs_stride);
219 23268 const uint8x8_t vld1_1 = vld1_u8(src_block_base + n1_idx * rhs_stride + 8);
220 23268 const uint8x8_t vld2_0 = vld1_u8(src_block_base + n2_idx * rhs_stride);
221 23268 const uint8x8_t vld2_1 = vld1_u8(src_block_base + n2_idx * rhs_stride + 8);
222 23268 const uint8x8_t vld3_0 = vld1_u8(src_block_base + n3_idx * rhs_stride);
223 23268 const uint8x8_t vld3_1 = vld1_u8(src_block_base + n3_idx * rhs_stride + 8);
224
225 23268 const uint8x8_t vld0_s1s = vand_u8(vld0_0, bottom_mask);
226 23268 const uint8x8_t vld0_s0s = vshr_n_u8(vld0_0, 4);
227 23268 const uint8x8_t vld0_s17s = vshl_n_u8(vld0_1, 4);
228 23268 const uint8x8_t vld0_s16s = vand_u8(vld0_1, top_mask);
229
230 46536 const uint8x8_t vld0_s16s0s_lower =
231 23268 vorr_u8(vzip1_u8(vld0_s1s, vld0_s0s), vzip1_u8(vld0_s17s, vld0_s16s));
232 46536 const uint8x8_t vld0_s16s0s_upper =
233 23268 vorr_u8(vzip2_u8(vld0_s1s, vld0_s0s), vzip2_u8(vld0_s17s, vld0_s16s));
234
235 23268 const uint8x8_t vld1_s1s = vand_u8(vld1_0, bottom_mask);
236 23268 const uint8x8_t vld1_s0s = vshr_n_u8(vld1_0, 4);
237 23268 const uint8x8_t vld1_s17s = vshl_n_u8(vld1_1, 4);
238 23268 const uint8x8_t vld1_s16s = vand_u8(vld1_1, top_mask);
239
240 46536 const uint8x8_t vld1_s16s0s_lower =
241 23268 vorr_u8(vzip1_u8(vld1_s1s, vld1_s0s), vzip1_u8(vld1_s17s, vld1_s16s));
242 46536 const uint8x8_t vld1_s16s0s_upper =
243 23268 vorr_u8(vzip2_u8(vld1_s1s, vld1_s0s), vzip2_u8(vld1_s17s, vld1_s16s));
244
245 23268 const uint8x8_t vld2_s1s = vand_u8(vld2_0, bottom_mask);
246 23268 const uint8x8_t vld2_s0s = vshr_n_u8(vld2_0, 4);
247 23268 const uint8x8_t vld2_s17s = vshl_n_u8(vld2_1, 4);
248 23268 const uint8x8_t vld2_s16s = vand_u8(vld2_1, top_mask);
249
250 46536 const uint8x8_t vld2_s16s0s_lower =
251 23268 vorr_u8(vzip1_u8(vld2_s1s, vld2_s0s), vzip1_u8(vld2_s17s, vld2_s16s));
252 46536 const uint8x8_t vld2_s16s0s_upper =
253 23268 vorr_u8(vzip2_u8(vld2_s1s, vld2_s0s), vzip2_u8(vld2_s17s, vld2_s16s));
254
255 23268 const uint8x8_t vld3_s1s = vand_u8(vld3_0, bottom_mask);
256 23268 const uint8x8_t vld3_s0s = vshr_n_u8(vld3_0, 4);
257 23268 const uint8x8_t vld3_s17s = vshl_n_u8(vld3_1, 4);
258 23268 const uint8x8_t vld3_s16s = vand_u8(vld3_1, top_mask);
259
260 46536 const uint8x8_t vld3_s16s0s_lower =
261 23268 vorr_u8(vzip1_u8(vld3_s1s, vld3_s0s), vzip1_u8(vld3_s17s, vld3_s16s));
262 46536 const uint8x8_t vld3_s16s0s_upper =
263 23268 vorr_u8(vzip2_u8(vld3_s1s, vld3_s0s), vzip2_u8(vld3_s17s, vld3_s16s));
264
265 // XOR with 0x88 flips the top bit of every nibble (zero-point 8 -> two's-complement int4), then store the repacked values
266 23268 vst1_u8((uint8_t*)dst_row, veor_u8(vld0_s16s0s_lower, zero_point_conversion_mask));
267 23268 vst1_u8((uint8_t*)dst_row + 8, veor_u8(vld1_s16s0s_lower, zero_point_conversion_mask));
268 23268 vst1_u8((uint8_t*)dst_row + 16, veor_u8(vld2_s16s0s_lower, zero_point_conversion_mask));
269 23268 vst1_u8((uint8_t*)dst_row + 24, veor_u8(vld3_s16s0s_lower, zero_point_conversion_mask));
270
271 23268 vst1_u8(
272 (uint8_t*)dst_row + (nr * block_length_in_bytes),
273 veor_u8(vld0_s16s0s_upper, zero_point_conversion_mask));
274 23268 vst1_u8(
275 (uint8_t*)dst_row + (nr * block_length_in_bytes) + 8,
276 veor_u8(vld1_s16s0s_upper, zero_point_conversion_mask));
277 23268 vst1_u8(
278 (uint8_t*)dst_row + (nr * block_length_in_bytes) + 16,
279 veor_u8(vld2_s16s0s_upper, zero_point_conversion_mask));
280 23268 vst1_u8(
281 (uint8_t*)dst_row + (nr * block_length_in_bytes) + 24,
282 veor_u8(vld3_s16s0s_upper, zero_point_conversion_mask));
283
284 // Calculate and store row sums
285 23268 partial_sum0 += (int32_t)vaddlvq_u16(vaddl_u8(
286 23268 vadd_u8(vld0_s1s, vand_u8(vld0_1, bottom_mask)), vadd_u8(vld0_s0s, vshr_n_u8(vld0_1, 4))));
287 23268 partial_sum1 += (int32_t)vaddlvq_u16(vaddl_u8(
288 23268 vadd_u8(vld1_s1s, vand_u8(vld1_1, bottom_mask)), vadd_u8(vld1_s0s, vshr_n_u8(vld1_1, 4))));
289 23268 partial_sum2 += (int32_t)vaddlvq_u16(vaddl_u8(
290 23268 vadd_u8(vld2_s1s, vand_u8(vld2_1, bottom_mask)), vadd_u8(vld2_s0s, vshr_n_u8(vld2_1, 4))));
291 23268 partial_sum3 += (int32_t)vaddlvq_u16(vaddl_u8(
292 23268 vadd_u8(vld3_s1s, vand_u8(vld3_1, bottom_mask)), vadd_u8(vld3_s0s, vshr_n_u8(vld3_1, 4))));
293
294 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
295 23268 sums[nr_idx + 0] += (float)partial_sum0 * d0;
296 23268 sums[nr_idx + 1] += (float)partial_sum1 * d1;
297 23268 sums[nr_idx + 2] += (float)partial_sum2 * d2;
298 23268 sums[nr_idx + 3] += (float)partial_sum3 * d3;
299 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
300
301 23268 dst_row += block_length_in_bytes * 4;
302 23268 }
303 // Skip over the upper halves that were stored at offset nr * block_length_in_bytes
304 18004 dst_row += nr * block_length_in_bytes;
305 18004 }
306
307 // Move the pointer after scales
308 13503 dst_row += num_bytes_multiplier_rhs * nr;
309 13503 }
310
311 // Move the pointer after the row sum
312 3052 dst_row += kai_num_bytes_sum_rhs * nr;
313
314 // Set the bias
315
1/2
✓ Branch 0 taken 3052 times.
✗ Branch 1 not taken.
3052 if (bias == NULL) {
316 memset(dst_row, 0, nr * kai_num_bytes_bias);
317 } else {
318
2/2
✓ Branch 0 taken 15856 times.
✓ Branch 1 taken 3052 times.
18908 for (size_t i = 0; i < nr; ++i) {
319 // Clamp the row index to avoid out-of-bound reads
320
2/2
✓ Branch 0 taken 14380 times.
✓ Branch 1 taken 1476 times.
15856 const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
321 15856 ((float*)dst_row)[i] = bias[src_row_idx];
322 15856 }
323 }
324
325 // Move the pointer after the bias
326 3052 dst_row += kai_num_bytes_bias * nr;
327 3052 }
328 486 }
329 #endif // Architectural features check.
330