KleidiAI Coverage Report


Directory: ./
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%

            Coverage   Exec / Excl / Total
Lines:       97.3%     145 /  36  / 185
Functions:   88.9%       8 /   0  /   9
Branches:    96.2%      25 /  72  /  98
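Note: percentages are computed against the non-excluded totals, e.g. line coverage is 145 / (185 - 36) ≈ 97.3% and branch coverage is 25 / (98 - 72) ≈ 96.2%. The single uncovered function is kai_get_n_step_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon, visible below as the only function with no execution count.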

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // nrx8 => this function accepts generic nr values, but the input is expected to have a block depth of 8.
8 // Block depth is calculated as kr / sr. The values of these parameters are defined in the matmul ukernel.
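// For example, with kr = 16 and sr = 2 (illustrative values only; the actual ones
// come from the matmul ukernel interface), the block depth is 16 / 2 = 8.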
9
10 #if !defined(__aarch64__) && !defined(_M_ARM64)
11 #error This file must be compiled for AArch64.
12 #else // Architectural features check.
13
14 #include "kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h"
15
16 #include <arm_neon.h>
17 #include <stddef.h>
18 #include <stdint.h>
19 #include <string.h>
20
21 #include "kai/kai_common.h"
22
23 static const size_t kai_num_bytes_sum_rhs = sizeof(float);
24 static const size_t kai_num_bytes_bias = sizeof(float);
25 static const size_t kai_nr_multiple_of = 4;
26 static const size_t kai_bl_multiple_of = 32;
27
28 1280 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
29 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
30 1280 return kai_roundup(k, bl) / bl;
31 }
32
33 960 inline static size_t kai_get_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) {
34 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
35 960 return (bl / 2) + num_bytes_multiplier_rhs;
36 }
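// Worked example (assuming bf16 scales, i.e. num_bytes_multiplier_rhs = 2): a block of
// bl = 32 four-bit values packs into 32 / 2 = 16 data bytes, so each block occupies
// 16 + 2 = 18 bytes.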
37
38 320 inline static size_t kai_get_rhs_packed_offset_end_of_all_blocks(
39 size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) {
40 KAI_ASSERT((bl % kr) == 0);
41 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
42 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
43
44 320 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
45 320 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
46
47 640 return (nr * num_bytes_per_block * num_blocks_per_row);
48 320 }
49
50 size_t kai_get_n_step_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(size_t nr) {
51 return nr;
52 }
53
54 320 size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
55 size_t n_idx, //
56 size_t rhs_stride) {
57 320 return n_idx * rhs_stride;
58 }
59
60 640 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
61 size_t k, //
62 size_t nr, //
63 size_t kr, //
64 size_t sr, //
65 size_t bl, //
66 enum kai_datatype scale_dt) {
67 KAI_ASSERT((k % bl) == 0);
68 KAI_ASSERT((bl % kr) == 0);
69 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
70 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
71 KAI_ASSERT(scale_dt == kai_dt_bf16);
72
73 640 KAI_UNUSED(kr);
74 640 KAI_UNUSED(sr);
75
76 640 const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
77 640 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
78 640 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
79
80 1280 return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
81 640 }
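// Worked example (illustrative values): k = 256, bl = 32, nr = 4 with bf16 scales gives
// 256 / 32 = 8 blocks of 18 bytes each per row, so the packed stride is
// 4 * ((18 * 8) + 4 + 4) = 608 bytes (block data and scales, plus one sum and one bias
// float per row).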
82
83 320 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
84 size_t n_idx, //
85 size_t k, //
86 size_t nr, //
87 size_t kr, //
88 size_t sr, //
89 size_t bl, //
90 enum kai_datatype scale_dt) {
91 KAI_ASSERT((n_idx % nr) == 0);
92 KAI_ASSERT((k % bl) == 0);
93 KAI_ASSERT((bl % kr) == 0);
94 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
95 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
96 KAI_ASSERT(scale_dt == kai_dt_bf16);
97
98 640 return (n_idx / nr) *
99 320 kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
100 }
101
102 320 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
103 size_t n, //
104 size_t k, //
105 size_t nr, //
106 size_t kr, //
107 size_t sr, //
108 size_t bl, //
109 enum kai_datatype scale_dt) {
110 KAI_ASSERT((k % bl) == 0);
111 KAI_ASSERT((bl % kr) == 0);
112 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
113 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
114 KAI_ASSERT(scale_dt == kai_dt_bf16);
115
116 320 const size_t num_rows = kai_roundup(n, nr) / nr;
117
118 640 return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
119 320 }
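// Continuing the example above: n = 100, nr = 4 gives kai_roundup(100, 4) / 4 = 25 packed
// row groups, so the total packed size is 25 * 608 = 15200 bytes.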
120
121 320 void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
122 size_t num_groups, //
123 size_t n, //
124 size_t k, //
125 size_t nr, //
126 size_t kr, //
127 size_t sr, //
128 size_t bl, //
129 const uint8_t* rhs, //
130 size_t rhs_stride, //
131 const float* bias, //
132 const void* scale, //
133 size_t scale_stride, //
134 void* rhs_packed, //
135 size_t extra_bytes, //
136 const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) {
137 KAI_ASSERT(num_groups == 1);
138 KAI_ASSERT(extra_bytes == 0);
139 KAI_ASSERT(rhs != NULL);
140 KAI_ASSERT(scale != NULL);
141 KAI_ASSERT(rhs_packed != NULL);
142 KAI_ASSERT(params != NULL);
143 KAI_ASSERT(params->rhs_zero_point == 8);
144 KAI_ASSERT(params->lhs_zero_point == 1);
145
146 KAI_ASSERT((k % bl) == 0);
147 KAI_ASSERT((bl % kr) == 0);
148 KAI_ASSERT((kr % sr) == 0);
149 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
150 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
151 KAI_ASSERT(params->scale_dt == kai_dt_bf16);
152
153 // Note: the input matrix (rhs) is expected to have
154 // "k" columns and "n" rows (NxK).
155 320 const enum kai_datatype scale_dt = params->scale_dt;
156 320 const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
157 640 const size_t rhs_packed_offset_end_of_all_blocks =
158 320 kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs);
159 320 const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl);
160 320 const size_t num_bytes_per_block_k = bl / 2;
161 320 const size_t dst_num_rows = kai_roundup(n, nr);
162 320 const size_t block_length_in_bytes = kr / sr;
163 KAI_ASSERT(block_length_in_bytes == 8);
164
165 320 uint8_t* dst_row = (uint8_t*)rhs_packed;
166
2/2
✓ Branch 0 taken 320 times.
✓ Branch 1 taken 3778 times.
167 4098 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) {
168 3778 float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);
169
170 // Initialize the RHS reduction sums to zero
171 3778 memset(sums, 0, nr * kai_num_bytes_sum_rhs);
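// (These scaled row sums are presumably consumed by the matmul ukernel to fold in the
// LHS zero-point correction; this is an assumption based on the qsi4c32p kernel family
// and is not stated in this file.)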
172
173 // Iterate over the quantized blocks
2/2
✓ Branch 0 taken 27564 times.
✓ Branch 1 taken 3778 times.
174 31342 for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) {
175 // Store the scales after packing all K values in the block
176 27564 uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr;
177 27564 const uint8_t* scale_ptr = (const uint8_t*)scale + dst_qblock_idx * num_bytes_multiplier_rhs;
178
2/2
✓ Branch 0 taken 138480 times.
✓ Branch 1 taken 27564 times.
179 166044 for (size_t i = 0; i < nr; ++i) {
2/2
✓ Branch 0 taken 134368 times.
✓ Branch 1 taken 4112 times.
180 138480 const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
181 138480 const void* src_scales_ptr = scale_ptr + src_row_idx * scale_stride;
182 138480 void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs;
183
184 138480 memcpy(
185 69240 dst_scales_ptr, //
186 69240 src_scales_ptr, //
187 69240 num_bytes_multiplier_rhs); //
188 138480 }
189
190 27564 size_t k0_idx_i = dst_qblock_idx * bl;
191 27564 const uint8x8_t top_mask = vdup_n_u8(0xF0);
192 27564 const uint8x8_t bottom_mask = vdup_n_u8(0x0F);
193 27564 const uint8x8_t zero_point_conversion_mask = vdup_n_u8(0x88);
194
2/2
✓ Branch 0 taken 27564 times.
✓ Branch 1 taken 27564 times.
195 55128 for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) {
2/2
✓ Branch 0 taken 34620 times.
✓ Branch 1 taken 27564 times.
196 62184 for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 4) {
197 // Clamp the indices to avoid out-of-bound reads
2/2
✓ Branch 0 taken 34052 times.
✓ Branch 1 taken 568 times.
198 34620 const size_t n0_idx = KAI_MIN(dst_row_idx + nr_idx, n - 1);
2/2
✓ Branch 0 taken 34052 times.
✓ Branch 1 taken 568 times.
199 34620 const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
2/2
✓ Branch 0 taken 33716 times.
✓ Branch 1 taken 904 times.
200 34620 const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
2/2
✓ Branch 0 taken 32548 times.
✓ Branch 1 taken 2072 times.
201 34620 const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);
202
203 34620 const float d0 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 0]);
204 34620 const float d1 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 1]);
205 34620 const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]);
206 34620 const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]);
207
208 // Initialize partial sum taking new zero-point (8) into account
209 34620 int32_t partial_sum0 = -(32 * 8);
210 34620 int32_t partial_sum1 = -(32 * 8);
211 34620 int32_t partial_sum2 = -(32 * 8);
212 34620 int32_t partial_sum3 = -(32 * 8);
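// (Each iteration below consumes 16 bytes, i.e. 32 four-bit values stored with
// zero-point 8, so starting the accumulators at -(32 * 8) turns the unsigned sums
// into sums of the underlying signed values.)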
213
214 34620 const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx);
215
216 34620 const uint8x8_t vld0_0 = vld1_u8(src_block_base + n0_idx * rhs_stride);
217 34620 const uint8x8_t vld0_1 = vld1_u8(src_block_base + n0_idx * rhs_stride + 8);
218 34620 const uint8x8_t vld1_0 = vld1_u8(src_block_base + n1_idx * rhs_stride);
219 34620 const uint8x8_t vld1_1 = vld1_u8(src_block_base + n1_idx * rhs_stride + 8);
220 34620 const uint8x8_t vld2_0 = vld1_u8(src_block_base + n2_idx * rhs_stride);
221 34620 const uint8x8_t vld2_1 = vld1_u8(src_block_base + n2_idx * rhs_stride + 8);
222 34620 const uint8x8_t vld3_0 = vld1_u8(src_block_base + n3_idx * rhs_stride);
223 34620 const uint8x8_t vld3_1 = vld1_u8(src_block_base + n3_idx * rhs_stride + 8);
224
225 34620 const uint8x8_t vld0_s1s = vand_u8(vld0_0, bottom_mask);
226 34620 const uint8x8_t vld0_s0s = vshr_n_u8(vld0_0, 4);
227 34620 const uint8x8_t vld0_s17s = vshl_n_u8(vld0_1, 4);
228 34620 const uint8x8_t vld0_s16s = vand_u8(vld0_1, top_mask);
229
230 69240 const uint8x8_t vld0_s16s0s_lower =
231 34620 vorr_u8(vzip1_u8(vld0_s1s, vld0_s0s), vzip1_u8(vld0_s17s, vld0_s16s));
232 69240 const uint8x8_t vld0_s16s0s_upper =
233 34620 vorr_u8(vzip2_u8(vld0_s1s, vld0_s0s), vzip2_u8(vld0_s17s, vld0_s16s));
234
235 34620 const uint8x8_t vld1_s1s = vand_u8(vld1_0, bottom_mask);
236 34620 const uint8x8_t vld1_s0s = vshr_n_u8(vld1_0, 4);
237 34620 const uint8x8_t vld1_s17s = vshl_n_u8(vld1_1, 4);
238 34620 const uint8x8_t vld1_s16s = vand_u8(vld1_1, top_mask);
239
240 69240 const uint8x8_t vld1_s16s0s_lower =
241 34620 vorr_u8(vzip1_u8(vld1_s1s, vld1_s0s), vzip1_u8(vld1_s17s, vld1_s16s));
242 69240 const uint8x8_t vld1_s16s0s_upper =
243 34620 vorr_u8(vzip2_u8(vld1_s1s, vld1_s0s), vzip2_u8(vld1_s17s, vld1_s16s));
244
245 34620 const uint8x8_t vld2_s1s = vand_u8(vld2_0, bottom_mask);
246 34620 const uint8x8_t vld2_s0s = vshr_n_u8(vld2_0, 4);
247 34620 const uint8x8_t vld2_s17s = vshl_n_u8(vld2_1, 4);
248 34620 const uint8x8_t vld2_s16s = vand_u8(vld2_1, top_mask);
249
250 69240 const uint8x8_t vld2_s16s0s_lower =
251 34620 vorr_u8(vzip1_u8(vld2_s1s, vld2_s0s), vzip1_u8(vld2_s17s, vld2_s16s));
252 69240 const uint8x8_t vld2_s16s0s_upper =
253 34620 vorr_u8(vzip2_u8(vld2_s1s, vld2_s0s), vzip2_u8(vld2_s17s, vld2_s16s));
254
255 34620 const uint8x8_t vld3_s1s = vand_u8(vld3_0, bottom_mask);
256 34620 const uint8x8_t vld3_s0s = vshr_n_u8(vld3_0, 4);
257 34620 const uint8x8_t vld3_s17s = vshl_n_u8(vld3_1, 4);
258 34620 const uint8x8_t vld3_s16s = vand_u8(vld3_1, top_mask);
259
260 69240 const uint8x8_t vld3_s16s0s_lower =
261 34620 vorr_u8(vzip1_u8(vld3_s1s, vld3_s0s), vzip1_u8(vld3_s17s, vld3_s16s));
262 69240 const uint8x8_t vld3_s16s0s_upper =
263 34620 vorr_u8(vzip2_u8(vld3_s1s, vld3_s0s), vzip2_u8(vld3_s17s, vld3_s16s));
264
265 // Convert to unsigned int4 and store repacked values
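// (XOR with 0x88 flips the top bit of each nibble: for a 4-bit value this maps the
// two's-complement encoding of v to the zero-point-8 encoding v + 8, and vice versa.)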
266 34620 vst1_u8((uint8_t*)dst_row, veor_u8(vld0_s16s0s_lower, zero_point_conversion_mask));
267 34620 vst1_u8((uint8_t*)dst_row + 8, veor_u8(vld1_s16s0s_lower, zero_point_conversion_mask));
268 34620 vst1_u8((uint8_t*)dst_row + 16, veor_u8(vld2_s16s0s_lower, zero_point_conversion_mask));
269 34620 vst1_u8((uint8_t*)dst_row + 24, veor_u8(vld3_s16s0s_lower, zero_point_conversion_mask));
270
271 34620 vst1_u8(
272 (uint8_t*)dst_row + (nr * block_length_in_bytes),
273 veor_u8(vld0_s16s0s_upper, zero_point_conversion_mask));
274 34620 vst1_u8(
275 (uint8_t*)dst_row + (nr * block_length_in_bytes) + 8,
276 veor_u8(vld1_s16s0s_upper, zero_point_conversion_mask));
277 34620 vst1_u8(
278 (uint8_t*)dst_row + (nr * block_length_in_bytes) + 16,
279 veor_u8(vld2_s16s0s_upper, zero_point_conversion_mask));
280 34620 vst1_u8(
281 (uint8_t*)dst_row + (nr * block_length_in_bytes) + 24,
282 veor_u8(vld3_s16s0s_upper, zero_point_conversion_mask));
283
284 // Calculate and store row sums
285 34620 partial_sum0 += (int32_t)vaddlvq_u16(vaddl_u8(
286 34620 vadd_u8(vld0_s1s, vand_u8(vld0_1, bottom_mask)), vadd_u8(vld0_s0s, vshr_n_u8(vld0_1, 4))));
287 34620 partial_sum1 += (int32_t)vaddlvq_u16(vaddl_u8(
288 34620 vadd_u8(vld1_s1s, vand_u8(vld1_1, bottom_mask)), vadd_u8(vld1_s0s, vshr_n_u8(vld1_1, 4))));
289 34620 partial_sum2 += (int32_t)vaddlvq_u16(vaddl_u8(
290 34620 vadd_u8(vld2_s1s, vand_u8(vld2_1, bottom_mask)), vadd_u8(vld2_s0s, vshr_n_u8(vld2_1, 4))));
291 34620 partial_sum3 += (int32_t)vaddlvq_u16(vaddl_u8(
292 34620 vadd_u8(vld3_s1s, vand_u8(vld3_1, bottom_mask)), vadd_u8(vld3_s0s, vshr_n_u8(vld3_1, 4))));
293
294 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
295 34620 sums[nr_idx + 0] += (float)partial_sum0 * d0;
296 34620 sums[nr_idx + 1] += (float)partial_sum1 * d1;
297 34620 sums[nr_idx + 2] += (float)partial_sum2 * d2;
298 34620 sums[nr_idx + 3] += (float)partial_sum3 * d3;
299 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
300
301 34620 dst_row += block_length_in_bytes * 4;
302 34620 }
303 // Skip to end of qblock
304 27564 dst_row += nr * block_length_in_bytes;
305 27564 }
306
307 // Move the pointer after scales
308 27564 dst_row += num_bytes_multiplier_rhs * nr;
309 27564 }
310
311 // Move the pointer after the row sum
312 3778 dst_row += kai_num_bytes_sum_rhs * nr;
313
314 // Set the bias
1/2
✓ Branch 0 taken 3778 times.
✗ Branch 1 not taken.
315 3778 if (bias == NULL) {
316 memset(dst_row, 0, nr * kai_num_bytes_bias);
317 } else {
2/2
✓ Branch 0 taken 18632 times.
✓ Branch 1 taken 3778 times.
318 22410 for (size_t i = 0; i < nr; ++i) {
319 // Clamp the row index to avoid out-of-bound reads
2/2
✓ Branch 0 taken 17888 times.
✓ Branch 1 taken 744 times.
320 18632 const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
321 18632 ((float*)dst_row)[i] = bias[src_row_idx];
322 18632 }
323 }
324
325 // Move the pointer after the bias
326 3778 dst_row += kai_num_bytes_bias * nr;
327 3778 }
328 320 }
329 #endif // Architectural features check.
330
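For reference, a minimal sketch of how this packing routine can be driven end to end. The getter and run signatures, parameter values, and data-type checks mirror the listing above; the problem size, the scale layout (row-major n x (k / bl) bf16 values), and the derived strides are illustrative assumptions.

    #include <stdint.h>
    #include <stdlib.h>

    #include "kai/kai_common.h"
    #include "kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h"

    // Illustrative problem size: an N x K int4 RHS matrix, quantized in blocks of BL.
    enum { N = 64, K = 256, NR = 4, KR = 16, SR = 2, BL = 32 };

    static void pack_rhs_example(
        const uint8_t* rhs,     // N x (K / 2) bytes of packed 4-bit values
        const uint16_t* scales, // assumed layout: N x (K / BL) bf16 multipliers
        const float* bias) {    // N floats, or NULL
        struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
        params.lhs_zero_point = 1;      // required by the KAI_ASSERT checks above
        params.rhs_zero_point = 8;      // required by the KAI_ASSERT checks above
        params.scale_dt = kai_dt_bf16;  // the only scale type this kernel accepts

        const size_t rhs_stride = K / 2;  // bytes per source row (two values per byte)
        const size_t scale_stride = (K / BL) * sizeof(uint16_t);

        const size_t packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
            N, K, NR, KR, SR, BL, kai_dt_bf16);
        void* rhs_packed = malloc(packed_size);
        if (rhs_packed == NULL) {
            return;
        }

        kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
            /* num_groups */ 1, N, K, NR, KR, SR, BL, rhs, rhs_stride, bias, scales,
            scale_stride, rhs_packed, /* extra_bytes */ 0, &params);

        // ... hand rhs_packed to the matching qsi4c32p matmul ukernel, then free it.
        free(rhs_packed);
    }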