KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.c
Date: 2025-10-20 13:18:31
            Coverage   Exec   Excl   Total
Lines:         97.6%    162     36     202
Functions:     88.9%      8      0       9
Branches:      96.2%     25     72      98

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // nrx4 => this function can take in generic nr values but the input is expected to have a block depth of 4.
8 // Block depth is calculated as kr / sr. The values of these parameters are defined in the matmul ukernel.
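// Illustration (hypothetical values, not fixed by this file): a matmul ukernel that
// reports kr = 8 and sr = 2 gives a block depth of kr / sr = 4, which is the "x4"
// layout this packing routine expects; kr = 16 with sr = 4 would satisfy it as well.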
9
10 #if !defined(__aarch64__) && !defined(_M_ARM64)
11 #error This file must be compiled for AArch64.
12 #else // Architectural features check.
13
14 #include "kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h"
15
16 #include <arm_neon.h>
17 #include <stddef.h>
18 #include <stdint.h>
19 #include <string.h>
20
21 #include "kai/kai_common.h"
22
23 static const size_t kai_num_bytes_sum_rhs = sizeof(float);
24 static const size_t kai_num_bytes_bias = sizeof(float);
25 static const size_t kai_nr_multiple_of = 4;
26 static const size_t kai_bl_multiple_of = 32;
27
28 864 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
29 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
30 864 return kai_roundup(k, bl) / bl;
31 }
32
33 648 inline static size_t kai_get_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) {
34 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
35 648 return (bl / 2) + num_bytes_multiplier_rhs;
36 }
37
38 216 inline static size_t kai_get_rhs_packed_offset_end_of_all_blocks(
39 size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) {
40 KAI_ASSERT((bl % kr) == 0);
41 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
42 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
43
44 216 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
45 216 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
46
47 432 return (nr * num_bytes_per_block * num_blocks_per_row);
48 216 }
49
50 size_t kai_get_n_step_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(size_t nr) {
51 return nr;
52 }
53
54 216 size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
55 size_t n_idx, //
56 size_t rhs_stride) {
57 216 return n_idx * rhs_stride;
58 }
59
60 432 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
61 size_t k, //
62 size_t nr, //
63 size_t kr, //
64 size_t sr, //
65 size_t bl, //
66 enum kai_datatype scale_dt) {
67 KAI_ASSERT((k % bl) == 0);
68 KAI_ASSERT((bl % kr) == 0);
69 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
70 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
71 KAI_ASSERT(scale_dt == kai_dt_bf16);
72
73 432 KAI_UNUSED(kr);
74 432 KAI_UNUSED(sr);
75
76 432 const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
77 432 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
78 432 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
79
80 864 return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
81 432 }
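// Worked example (illustrative sizes, not mandated by the kernel): with k = 256,
// bl = 32 and a bf16 scale (2 bytes), num_blocks_per_row = 256 / 32 = 8 and
// num_bytes_per_block = 32 / 2 + 2 = 18, so for nr = 4 the packed stride is
// 4 * ((18 * 8) + 4 + 4) = 608 bytes per group of nr output rows.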
82
83 216 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
84 size_t n_idx, //
85 size_t k, //
86 size_t nr, //
87 size_t kr, //
88 size_t sr, //
89 size_t bl, //
90 enum kai_datatype scale_dt) {
91 KAI_ASSERT((n_idx % nr) == 0);
92 KAI_ASSERT((k % bl) == 0);
93 KAI_ASSERT((bl % kr) == 0);
94 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
95 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
96 KAI_ASSERT(scale_dt == kai_dt_bf16);
97
98 432 return (n_idx / nr) *
99 216 kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
100 }
101
102 216 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
103 size_t n, //
104 size_t k, //
105 size_t nr, //
106 size_t kr, //
107 size_t sr, //
108 size_t bl, //
109 enum kai_datatype scale_dt) {
110 KAI_ASSERT((k % bl) == 0);
111 KAI_ASSERT((bl % kr) == 0);
112 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
113 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
114 KAI_ASSERT(scale_dt == kai_dt_bf16);
115
116 216 const size_t num_rows = kai_roundup(n, nr) / nr;
117
118 432 return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
119 216 }
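// Worked example (continuing the illustrative sizes above): for n = 10 and nr = 4,
// num_rows = kai_roundup(10, 4) / 4 = 3, so the packed buffer would need
// 3 * 608 = 1824 bytes with the k = 256, bl = 32, bf16-scale stride computed above.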
120
121 216 void kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
122 size_t num_groups, //
123 size_t n, //
124 size_t k, //
125 size_t nr, //
126 size_t kr, //
127 size_t sr, //
128 size_t bl, //
129 const uint8_t* rhs, //
130 size_t rhs_stride, //
131 const float* bias, //
132 const void* scale, //
133 size_t scale_stride, //
134 void* rhs_packed, //
135 size_t extra_bytes, //
136 const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) {
137 KAI_ASSERT(num_groups == 1);
138 KAI_ASSERT(extra_bytes == 0);
139 KAI_ASSERT(rhs != NULL);
140 KAI_ASSERT(scale != NULL);
141 KAI_ASSERT(rhs_packed != NULL);
142 KAI_ASSERT(params != NULL);
143 KAI_ASSERT(params->rhs_zero_point == 8);
144 KAI_ASSERT(params->lhs_zero_point == 1);
145
146 KAI_ASSERT((k % bl) == 0);
147 KAI_ASSERT((bl % kr) == 0);
148 KAI_ASSERT((kr % sr) == 0);
149 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
150 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
151 KAI_ASSERT(params->scale_dt == kai_dt_bf16);
152
153 // Note: The input matrix (rhs) is expected to have
154 // "k" columns and "n" rows (NxK).
155 216 const enum kai_datatype scale_dt = params->scale_dt;
156 216 const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
157 432 const size_t rhs_packed_offset_end_of_all_blocks =
158 216 kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs);
159 216 const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl);
160 216 const size_t num_bytes_per_block_k = bl / 2;
161 216 const size_t dst_num_rows = kai_roundup(n, nr);
162 216 const size_t block_length_in_bytes = kr / sr;
163 KAI_ASSERT(block_length_in_bytes == 4);
164
165 216 uint8_t* dst_row = (uint8_t*)rhs_packed;
166
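// Packed layout per group of nr destination rows (as produced by the loop below):
//   for each of the num_qblocks_per_row quantized blocks:
//     nr * (bl / 2) bytes of interleaved signed 4-bit data, followed by nr bf16 scales;
//   then nr float32 row sums and nr float32 bias values.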
167
2/2
✓ Branch 0 taken 216 times.
✓ Branch 1 taken 1312 times.
1528 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) {
168 1312 float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);
169
170 // Initialize the RHS reduction sums to zero
171 1312 memset(sums, 0, nr * kai_num_bytes_sum_rhs);
172
173 // Iterate over the quantized blocks
174
2/2
✓ Branch 0 taken 5796 times.
✓ Branch 1 taken 1312 times.
7108 for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) {
175 // Store the scales at the position after the block's packed K values
176 5796 uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr;
177 5796 const uint8_t* scale_ptr = (const uint8_t*)scale + dst_qblock_idx * num_bytes_multiplier_rhs;
178
179
2/2
✓ Branch 0 taken 31080 times.
✓ Branch 1 taken 5796 times.
36876 for (size_t i = 0; i < nr; ++i) {
180
2/2
✓ Branch 0 taken 29472 times.
✓ Branch 1 taken 1608 times.
31080 const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
181 31080 const void* src_scales_ptr = scale_ptr + src_row_idx * scale_stride;
182 31080 void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs;
183
184 31080 memcpy(
185 dst_scales_ptr, //
186 src_scales_ptr, //
187 num_bytes_multiplier_rhs); //
188 31080 }
189
190 5796 size_t k0_idx_i = dst_qblock_idx * bl;
191 5796 const uint8x8_t top_mask = vdup_n_u8(0xF0);
192 5796 const uint8x8_t bottom_mask = vdup_n_u8(0x0F);
193 5796 const uint32x2_t zero_point_conversion_mask = vdup_n_u32(0x88888888);
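// XOR-ing each 4-bit lane with 0x8 maps an unsigned nibble u (zero point 8) to the
// two's-complement encoding of u - 8, i.e. it converts qsu4 values to qsi4 in place.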
194
195
2/2
✓ Branch 0 taken 7728 times.
✓ Branch 1 taken 5796 times.
13524 for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) {
196
2/2
✓ Branch 0 taken 10360 times.
✓ Branch 1 taken 7728 times.
18088 for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 4) {
197 // Clamp the indices to avoid out-of-bound reads
198
2/2
✓ Branch 0 taken 9952 times.
✓ Branch 1 taken 408 times.
10360 const size_t n0_idx = KAI_MIN(dst_row_idx + nr_idx, n - 1);
199
2/2
✓ Branch 0 taken 9952 times.
✓ Branch 1 taken 408 times.
10360 const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
200
2/2
✓ Branch 0 taken 9856 times.
✓ Branch 1 taken 504 times.
10360 const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
201
2/2
✓ Branch 0 taken 9536 times.
✓ Branch 1 taken 824 times.
10360 const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);
202
203 10360 const float d0 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 0]);
204 10360 const float d1 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 1]);
205 10360 const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]);
206 10360 const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]);
207
208 // Initialize partial sum taking new zero-point (8) into account
209 10360 int32_t partial_sum0 = -(32 * 8);
210 10360 int32_t partial_sum1 = -(32 * 8);
211 10360 int32_t partial_sum2 = -(32 * 8);
212 10360 int32_t partial_sum3 = -(32 * 8);
213
214 10360 const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx);
215
216 10360 const uint8x8_t vld0_0 = vld1_u8(src_block_base + n0_idx * rhs_stride);
217 10360 const uint8x8_t vld0_1 = vld1_u8(src_block_base + n0_idx * rhs_stride + 8);
218 10360 const uint8x8_t vld1_0 = vld1_u8(src_block_base + n1_idx * rhs_stride);
219 10360 const uint8x8_t vld1_1 = vld1_u8(src_block_base + n1_idx * rhs_stride + 8);
220 10360 const uint8x8_t vld2_0 = vld1_u8(src_block_base + n2_idx * rhs_stride);
221 10360 const uint8x8_t vld2_1 = vld1_u8(src_block_base + n2_idx * rhs_stride + 8);
222 10360 const uint8x8_t vld3_0 = vld1_u8(src_block_base + n3_idx * rhs_stride);
223 10360 const uint8x8_t vld3_1 = vld1_u8(src_block_base + n3_idx * rhs_stride + 8);
224
225 // Reorder blocks to give correct packing
226 10360 const uint8x8_t vld0_0_lower = vand_u8(vld0_0, bottom_mask);
227 10360 const uint8x8_t vld0_1_lower = vshl_n_u8(vld0_1, 4);
228 10360 const uint8x8_t vld0_0_upper = vshr_n_u8(vld0_0, 4);
229 10360 const uint8x8_t vld0_1_upper = vand_u8(vld0_1, top_mask);
230 20720 const uint8x8_t vstr0_04 =
231 10360 vorr_u8(vzip1_u8(vld0_0_lower, vld0_0_upper), vzip1_u8(vld0_1_lower, vld0_1_upper));
232 20720 const uint8x8_t vstr0_46 =
233 10360 vorr_u8(vzip2_u8(vld0_0_lower, vld0_0_upper), vzip2_u8(vld0_1_lower, vld0_1_upper));
234
235 10360 const uint8x8_t vld1_0_lower = vand_u8(vld1_0, bottom_mask);
236 10360 const uint8x8_t vld1_1_lower = vshl_n_u8(vld1_1, 4);
237 10360 const uint8x8_t vld1_0_upper = vshr_n_u8(vld1_0, 4);
238 10360 const uint8x8_t vld1_1_upper = vand_u8(vld1_1, top_mask);
239 20720 const uint8x8_t vstr0_04_1 =
240 10360 vorr_u8(vzip1_u8(vld1_0_lower, vld1_0_upper), vzip1_u8(vld1_1_lower, vld1_1_upper));
241 20720 const uint8x8_t vstr0_46_1 =
242 10360 vorr_u8(vzip2_u8(vld1_0_lower, vld1_0_upper), vzip2_u8(vld1_1_lower, vld1_1_upper));
243
244 10360 const uint8x8_t vld2_0_lower = vand_u8(vld2_0, bottom_mask);
245 10360 const uint8x8_t vld2_1_lower = vshl_n_u8(vld2_1, 4);
246 10360 const uint8x8_t vld2_0_upper = vshr_n_u8(vld2_0, 4);
247 10360 const uint8x8_t vld2_1_upper = vand_u8(vld2_1, top_mask);
248 20720 const uint8x8_t vstr0_15 =
249 10360 vorr_u8(vzip1_u8(vld2_0_lower, vld2_0_upper), vzip1_u8(vld2_1_lower, vld2_1_upper));
250 20720 const uint8x8_t vstr0_57 =
251 10360 vorr_u8(vzip2_u8(vld2_0_lower, vld2_0_upper), vzip2_u8(vld2_1_lower, vld2_1_upper));
252
253 10360 const uint8x8_t vld3_0_lower = vand_u8(vld3_0, bottom_mask);
254 10360 const uint8x8_t vld3_1_lower = vshl_n_u8(vld3_1, 4);
255 10360 const uint8x8_t vld3_0_upper = vshr_n_u8(vld3_0, 4);
256 10360 const uint8x8_t vld3_1_upper = vand_u8(vld3_1, top_mask);
257 20720 const uint8x8_t vstr0_15_1 =
258 10360 vorr_u8(vzip1_u8(vld3_0_lower, vld3_0_upper), vzip1_u8(vld3_1_lower, vld3_1_upper));
259 20720 const uint8x8_t vstr0_57_1 =
260 10360 vorr_u8(vzip2_u8(vld3_0_lower, vld3_0_upper), vzip2_u8(vld3_1_lower, vld3_1_upper));
261
262 20720 const uint32x2_t vstr0_0 =
263 10360 vzip1_u32(vreinterpret_u32_u8(vstr0_04), vreinterpret_u32_u8(vstr0_04_1));
264 20720 const uint32x2_t vstr0_4 =
265 10360 vzip1_u32(vreinterpret_u32_u8(vstr0_46), vreinterpret_u32_u8(vstr0_46_1));
266 20720 const uint32x2_t vstr0_2 =
267 10360 vzip2_u32(vreinterpret_u32_u8(vstr0_04), vreinterpret_u32_u8(vstr0_04_1));
268 20720 const uint32x2_t vstr0_6 =
269 10360 vzip2_u32(vreinterpret_u32_u8(vstr0_46), vreinterpret_u32_u8(vstr0_46_1));
270 20720 const uint32x2_t vstr0_1 =
271 10360 vzip1_u32(vreinterpret_u32_u8(vstr0_15), vreinterpret_u32_u8(vstr0_15_1));
272 20720 const uint32x2_t vstr0_5 =
273 10360 vzip1_u32(vreinterpret_u32_u8(vstr0_57), vreinterpret_u32_u8(vstr0_57_1));
274 20720 const uint32x2_t vstr0_3 =
275 10360 vzip2_u32(vreinterpret_u32_u8(vstr0_15), vreinterpret_u32_u8(vstr0_15_1));
276 20720 const uint32x2_t vstr0_7 =
277 10360 vzip2_u32(vreinterpret_u32_u8(vstr0_57), vreinterpret_u32_u8(vstr0_57_1));
278
279 // Convert to signed int4 and store repacked values
280 10360 vst1_u32((uint32_t*)dst_row + 0, veor_u32(vstr0_0, zero_point_conversion_mask));
281 10360 vst1_u32((uint32_t*)dst_row + 2, veor_u32(vstr0_1, zero_point_conversion_mask));
282
283 10360 vst1_u32(
284 (uint32_t*)(dst_row + nr * block_length_in_bytes) + 0,
285 veor_u32(vstr0_2, zero_point_conversion_mask));
286 10360 vst1_u32(
287 (uint32_t*)(dst_row + nr * block_length_in_bytes) + 2,
288 veor_u32(vstr0_3, zero_point_conversion_mask));
289
290 10360 vst1_u32(
291 (uint32_t*)(dst_row + (2 * nr * block_length_in_bytes)) + 0,
292 veor_u32(vstr0_4, zero_point_conversion_mask));
293 10360 vst1_u32(
294 (uint32_t*)(dst_row + (2 * nr * block_length_in_bytes)) + 2,
295 veor_u32(vstr0_5, zero_point_conversion_mask));
296
297 10360 vst1_u32(
298 (uint32_t*)(dst_row + (3 * nr * block_length_in_bytes)) + 0,
299 veor_u32(vstr0_6, zero_point_conversion_mask));
300 10360 vst1_u32(
301 (uint32_t*)(dst_row + (3 * nr * block_length_in_bytes)) + 2,
302 veor_u32(vstr0_7, zero_point_conversion_mask));
303
304 // Calculate and store row sums
305 10360 partial_sum0 += (int32_t)vaddlvq_u16(vaddl_u8(
306 10360 vadd_u8(vld0_0_lower, vand_u8(vld0_1, bottom_mask)),
307 10360 vadd_u8(vld0_0_upper, vshr_n_u8(vld0_1, 4))));
308 10360 partial_sum1 += (int32_t)vaddlvq_u16(vaddl_u8(
309 10360 vadd_u8(vld1_0_lower, vand_u8(vld1_1, bottom_mask)),
310 10360 vadd_u8(vld1_0_upper, vshr_n_u8(vld1_1, 4))));
311 10360 partial_sum2 += (int32_t)vaddlvq_u16(vaddl_u8(
312 10360 vadd_u8(vld2_0_lower, vand_u8(vld2_1, bottom_mask)),
313 10360 vadd_u8(vld2_0_upper, vshr_n_u8(vld2_1, 4))));
314 10360 partial_sum3 += (int32_t)vaddlvq_u16(vaddl_u8(
315 10360 vadd_u8(vld3_0_lower, vand_u8(vld3_1, bottom_mask)),
316 10360 vadd_u8(vld3_0_upper, vshr_n_u8(vld3_1, 4))));
317
318 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
319 10360 sums[nr_idx + 0] += (float)partial_sum0 * d0;
320 10360 sums[nr_idx + 1] += (float)partial_sum1 * d1;
321 10360 sums[nr_idx + 2] += (float)partial_sum2 * d2;
322 10360 sums[nr_idx + 3] += (float)partial_sum3 * d3;
323 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
324
325 10360 dst_row += (4 * block_length_in_bytes);
326 10360 }
327 // Skip to end of qblock
328 7728 dst_row += 3 * nr * block_length_in_bytes;
329 7728 }
330
331 // Move the pointer after scales
332 5796 dst_row += num_bytes_multiplier_rhs * nr;
333 5796 }
334
335 // Move the pointer after the row sum
336 1312 dst_row += kai_num_bytes_sum_rhs * nr;
337
338 // Set the bias
339
339
1/2
✓ Branch 0 taken 1312 times.
✗ Branch 1 not taken.
1312 if (bias == NULL) {
340 memset(dst_row, 0, nr * kai_num_bytes_bias);
341 } else {
342
2/2
✓ Branch 0 taken 7072 times.
✓ Branch 1 taken 1312 times.
8384 for (size_t i = 0; i < nr; ++i) {
343 // Clamp the row index to avoid out-of-bound reads
344
2/2
✓ Branch 0 taken 6400 times.
✓ Branch 1 taken 672 times.
7072 const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
345 7072 ((float*)dst_row)[i] = bias[src_row_idx];
346 7072 }
347 }
348
349 // Move the pointer after the bias
350 1312 dst_row += kai_num_bytes_bias * nr;
351 1312 }
352 216 }
353 #endif // Architectural features check.
354
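// ----------------------------------------------------------------------------
// The sketch below is not part of the covered source file. It is a minimal,
// hypothetical example of driving the packing entry points listed above; the
// sizes (n, k, nr, kr, sr, bl) are illustrative values chosen to satisfy the
// KAI_ASSERT checks, and real callers take them from the matching matmul
// ukernel interface. Include paths are assumed to match the KleidiAI tree.
// ----------------------------------------------------------------------------
#include <stdint.h>
#include <stdlib.h>

#include "kai/kai_common.h"
#include "kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h"

int main(void) {
    const size_t n = 8, k = 64, nr = 4, kr = 8, sr = 2, bl = 32;   // kr / sr == 4
    const size_t rhs_stride = k / 2;                               // two 4-bit values per byte
    const size_t scale_stride = (k / bl) * sizeof(uint16_t);       // one bf16 scale per block

    uint8_t* rhs = calloc(n, rhs_stride);                          // qsu4c32 data (zeros here)
    uint16_t* scales = calloc(n * (k / bl), sizeof(uint16_t));     // bf16 scales (zeros here)
    float* bias = calloc(n, sizeof(float));                        // may also be NULL

    struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
    params.lhs_zero_point = 1;                                     // required by the kernel's asserts
    params.rhs_zero_point = 8;
    params.scale_dt = kai_dt_bf16;

    const size_t packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
        n, k, nr, kr, sr, bl, kai_dt_bf16);
    void* rhs_packed = malloc(packed_size);

    kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
        1, n, k, nr, kr, sr, bl, rhs, rhs_stride, bias, scales, scale_stride, rhs_packed, 0, &params);

    free(rhs_packed);
    free(bias);
    free(scales);
    free(rhs);
    return 0;
}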