KleidiAI Coverage Report


Directory: ./
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%

            Coverage    Exec / Excl / Total
Lines:         95.5%     150 /   34 /  191
Functions:     88.9%       8 /    0 /    9
Branches:      96.2%      25 /   68 /   94

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // nrx4 => this function accepts generic nr values, but the input is expected to have a block depth of 4.
8 // Block depth is calculated as kr / sr. The values of these parameters are defined in the matmul ukernel.
9
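
A quick worked check of the constraint above: block depth = kr / sr, so e.g. kr = 16 with sr = 4, or kr = 8 with sr = 2, both give the required depth of 4 (illustrative pairs, not values taken from any specific ukernel).
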
10 #if !defined(__aarch64__) && !defined(_M_ARM64)
11 #error This file must be compiled for AArch64.
12 #else // Architectural features check.
13
14 #include "kai_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon.h"
15
16 #include <arm_neon.h>
17 #include <stddef.h>
18 #include <stdint.h>
19 #include <string.h>
20
21 #include "kai/kai_common.h"
22
23 static const size_t kai_num_bytes_sum_rhs = sizeof(float);
24 static const size_t kai_num_bytes_bias = sizeof(float);
25 static const size_t kai_nr_multiple_of = 4;
26 static const size_t kai_bl_multiple_of = 32;
27
28 2400 static size_t kai_get_num_blocks_per_row(const size_t k, const size_t bl) {
29 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
30 2400 return kai_roundup(k, bl) / bl;
31 }
32
33 2334 static size_t kai_get_num_bytes_per_block(const size_t bl, const size_t num_bytes_multiplier_rhs) {
34 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
35 2334 return (bl / 2) + num_bytes_multiplier_rhs;
36 }
37
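
A minimal worked sketch of what these two helpers compute, under assumed sizes (k = 256, bl = 32, and a bf16 scale, i.e. 2 bytes per block); the numbers are illustrative, not taken from the report:

```c
#include <stddef.h>
#include <stdio.h>

int main(void) {
    // Assumed sizes; they satisfy the asserts above (k % bl == 0, bl % 32 == 0).
    const size_t k = 256, bl = 32;
    const size_t num_bytes_multiplier_rhs = 2;  // size of one bf16 scale

    // kai_get_num_blocks_per_row: kai_roundup(k, bl) / bl
    const size_t num_blocks_per_row = (k + bl - 1) / bl;                     // -> 8
    // kai_get_num_bytes_per_block: bl / 2 nibble bytes plus the scale
    const size_t num_bytes_per_block = (bl / 2) + num_bytes_multiplier_rhs;  // -> 18

    printf("%zu blocks per row, %zu bytes per block\n", num_blocks_per_row, num_bytes_per_block);
    return 0;
}
```
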
38 66 static size_t kai_get_rhs_packed_offset_end_of_all_blocks(
39 // clang-format off
40 const size_t k,
41 const size_t nr,
42 const size_t kr,
43 const size_t bl,
44 const size_t num_bytes_multiplier_rhs) {
45 // clang-format on
46 KAI_ASSERT((bl % kr) == 0);
47 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
48 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
49
50 66 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
51 66 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
52
53 132 return (nr * num_bytes_per_block * num_blocks_per_row);
54 66 }
55
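
This offset marks where the per-block data of one nr-row group ends; the run function below places the nr row sums there, followed by the nr bias values. Reading the packing code, one packed row group is laid out as:

    [ per qblock: nr * (bl / 2) nibble bytes, then nr scales ] * num_blocks_per_row | nr f32 row sums | nr f32 biases
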
56 size_t kai_get_n_step_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(const size_t nr) {
57 return nr;
58 }
59
60 66 size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(const size_t n_idx, const size_t rhs_stride) {
61 66 return n_idx * rhs_stride;
62 }
63
64 2268 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
65 // clang-format off
66 const size_t k,
67 const size_t nr,
68 const size_t kr,
69 const size_t sr,
70 const size_t bl,
71 const enum kai_datatype scale_dt) {
72 // clang-format on
73 KAI_ASSERT((k % bl) == 0);
74 KAI_ASSERT((bl % kr) == 0);
75 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
76 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
77 KAI_ASSERT(scale_dt == kai_dt_bf16);
78
79 2268 KAI_UNUSED(kr);
80 2268 KAI_UNUSED(sr);
81
82 2268 const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
83 2268 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
84 2268 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
85
86 4536 return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
87 2268 }
88
89 // clang-format off
90 2202 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
91 const size_t n_idx,
92 const size_t k,
93 const size_t nr,
94 const size_t kr,
95 const size_t sr,
96 const size_t bl,
97 const enum kai_datatype scale_dt) {
98 // clang-format on
99 KAI_ASSERT((n_idx % nr) == 0);
100 KAI_ASSERT((k % bl) == 0);
101 KAI_ASSERT((bl % kr) == 0);
102 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
103 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
104 KAI_ASSERT(scale_dt == kai_dt_bf16);
105
106 4404 return (n_idx / nr) *
107 2202 kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
108 }
109
110 // clang-format off
111 66 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
112 const size_t n, //
113 const size_t k, //
114 const size_t nr, //
115 const size_t kr, //
116 const size_t sr, //
117 const size_t bl, //
118 const enum kai_datatype scale_dt) {
119 // clang-format on
120 KAI_ASSERT((k % bl) == 0);
121 KAI_ASSERT((bl % kr) == 0);
122 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
123 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
124 KAI_ASSERT(scale_dt == kai_dt_bf16);
125
126 66 const size_t num_rows = kai_roundup(n, nr) / nr;
127
128 198 return num_rows *
129 66 kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
130 66 }
131
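
Putting the stride and size helpers together, a worked sizing example under the same assumed shapes as above plus n = 100 and nr = 8 (illustrative values only):

```c
#include <stddef.h>
#include <stdio.h>

int main(void) {
    // Assumed shapes; they satisfy the asserts (k % bl == 0, nr % 4 == 0).
    const size_t n = 100, k = 256, nr = 8, bl = 32;
    const size_t scale_bytes = 2, sum_bytes = 4, bias_bytes = 4;  // bf16 scale, f32 sum and bias

    const size_t blocks_per_row = k / bl;                 // 8
    const size_t bytes_per_block = bl / 2 + scale_bytes;  // 18

    // Packed stride: one nr-row group = nr * (block data + row sum + bias)
    const size_t stride = nr * (bytes_per_block * blocks_per_row + sum_bytes + bias_bytes);  // 1216

    // Packed size: kai_roundup(n, nr) / nr row groups, each one stride long
    const size_t row_groups = (n + nr - 1) / nr;  // 13
    printf("stride %zu B, total %zu B\n", stride, row_groups * stride);  // 1216, 15808
    return 0;
}
```
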
132 66 void kai_run_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
133 // clang-format off
134 const size_t num_groups,
135 const size_t n,
136 const size_t k,
137 const size_t nr,
138 const size_t kr,
139 const size_t sr,
140 const size_t bl,
141 const uint8_t* rhs,
142 const size_t rhs_stride,
143 const float* bias,
144 const void* scale,
145 const size_t scale_stride,
146 void* rhs_packed,
147 const size_t extra_bytes,
148 const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) {
149 // clang-format on
150 66 KAI_UNUSED(num_groups);
151 66 KAI_UNUSED(extra_bytes);
152 KAI_ASSERT(rhs != NULL);
153 KAI_ASSERT(scale != NULL);
154 KAI_ASSERT(rhs_packed != NULL);
155 KAI_ASSERT(params != NULL);
156 KAI_ASSERT(params->rhs_zero_point == 8);
157 KAI_ASSERT(params->lhs_zero_point == 1);
158
159 KAI_ASSERT((k % bl) == 0);
160 KAI_ASSERT((bl % kr) == 0);
161 KAI_ASSERT((kr % sr) == 0);
162 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
163 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
164 KAI_ASSERT(params->scale_dt == kai_dt_bf16);
165
166 // Note: the input matrix (rhs) is expected to have
167 // "k" columns and "n" rows (NxK layout).
168 66 const size_t block_length = kr / sr;
169 KAI_ASSERT(block_length == 4);
170 66 const enum kai_datatype scale_dt = params->scale_dt;
171 66 const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
172 132 const size_t rhs_packed_offset_end_of_all_blocks =
173 66 kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs);
174 66 const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl);
175 66 const size_t num_bytes_per_block_k = bl / 2;
176 66 const size_t dst_num_rows = kai_roundup(n, nr);
177 66 const size_t block_length_in_bytes = block_length / 2;
178
179 66 const int8x16_t rhs_zero_point = vdupq_n_s8(8);
180 66 const uint8x16_t low_mask = vdupq_n_u8(0x0F);
181 66 const size_t num_bytes_processed = 16;
182
183 66 uint8_t* dst_row = (uint8_t*)rhs_packed;
184
185
2/2
✓ Branch 0 taken 66 times.
✓ Branch 1 taken 188 times.
254 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) {
186 188 float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);
187
188 // Initialize the RHS reduction sums to zero
189 188 memset(sums, 0, nr * kai_num_bytes_sum_rhs);
190
191 // Iterate over the quantized blocks
192
2/2
✓ Branch 0 taken 1680 times.
✓ Branch 1 taken 188 times.
1868 for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) {
193 // The scales are stored after the packed K values of the block
194 1680 uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr;
195 1680 const uint8_t* scale_ptr = (const uint8_t*)scale + dst_qblock_idx * num_bytes_multiplier_rhs;
196
197
2/2
✓ Branch 0 taken 107520 times.
✓ Branch 1 taken 1680 times.
109200 for (size_t i = 0; i < nr; ++i) {
198
2/2
✓ Branch 0 taken 102366 times.
✓ Branch 1 taken 5154 times.
107520 const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
199 107520 const void* src_scales_ptr = scale_ptr + src_row_idx * scale_stride;
200 107520 void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs;
201
202 107520 memcpy(
203 dst_scales_ptr, //
204 src_scales_ptr, //
205 num_bytes_multiplier_rhs); //
206 107520 }
207
208 1680 size_t k0_idx_i = dst_qblock_idx * bl;
209
210
2/2
✓ Branch 0 taken 1750 times.
✓ Branch 1 taken 1680 times.
3430 for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += num_bytes_processed) {
211
2/2
✓ Branch 0 taken 28000 times.
✓ Branch 1 taken 1750 times.
29750 for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 4) {
212 // Clamp the indices to avoid out-of-bound reads
213
2/2
✓ Branch 0 taken 26756 times.
✓ Branch 1 taken 1244 times.
28000 const size_t n0_idx = KAI_MIN(dst_row_idx + nr_idx, n - 1);
214
2/2
✓ Branch 0 taken 26756 times.
✓ Branch 1 taken 1244 times.
28000 const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
215
2/2
✓ Branch 0 taken 26728 times.
✓ Branch 1 taken 1272 times.
28000 const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
216
2/2
✓ Branch 0 taken 26248 times.
✓ Branch 1 taken 1752 times.
28000 const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);
217
218 // Load scales
219 28000 const float d0 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 0]);
220 28000 const float d1 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 1]);
221 28000 const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]);
222 28000 const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]);
223
224 // Initialize the partial sums
225 28000 int32_t partial_sum0 = 0;
226 28000 int32_t partial_sum1 = 0;
227 28000 int32_t partial_sum2 = 0;
228 28000 int32_t partial_sum3 = 0;
229
230 28000 const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx);
231 28000 const uint8x16_t vsrc0_0 = vld1q_u8(src_block_base + n0_idx * rhs_stride);
232 28000 const uint8x16_t vsrc1_0 = vld1q_u8(src_block_base + n1_idx * rhs_stride);
233 28000 const uint8x16_t vsrc2_0 = vld1q_u8(src_block_base + n2_idx * rhs_stride);
234 28000 const uint8x16_t vsrc3_0 = vld1q_u8(src_block_base + n3_idx * rhs_stride);
235
236 // Extract the low and high nibbles and apply the zero-point
237 56000 const int8x16_t vsrc0_0_lo =
238 28000 vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vsrc0_0, low_mask)), rhs_zero_point);
239 28000 const int8x16_t vsrc0_0_hi = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vsrc0_0, 4)), rhs_zero_point);
240 56000 const int8x16_t vsrc1_0_lo =
241 28000 vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vsrc1_0, low_mask)), rhs_zero_point);
242 28000 const int8x16_t vsrc1_0_hi = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vsrc1_0, 4)), rhs_zero_point);
243 56000 const int8x16_t vsrc2_0_lo =
244 28000 vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vsrc2_0, low_mask)), rhs_zero_point);
245 28000 const int8x16_t vsrc2_0_hi = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vsrc2_0, 4)), rhs_zero_point);
246 56000 const int8x16_t vsrc3_0_lo =
247 28000 vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vsrc3_0, low_mask)), rhs_zero_point);
248 28000 const int8x16_t vsrc3_0_hi = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vsrc3_0, 4)), rhs_zero_point);
249
250 // Calculate and store row sums
251 28000 partial_sum0 += vaddlvq_s16(vaddl_s8(
252 28000 vadd_s8(vget_low_s8(vsrc0_0_lo), vget_high_s8(vsrc0_0_lo)),
253 28000 vadd_s8(vget_low_s8(vsrc0_0_hi), vget_high_s8(vsrc0_0_hi))));
254 28000 partial_sum1 += vaddlvq_s16(vaddl_s8(
255 28000 vadd_s8(vget_low_s8(vsrc1_0_lo), vget_high_s8(vsrc1_0_lo)),
256 28000 vadd_s8(vget_low_s8(vsrc1_0_hi), vget_high_s8(vsrc1_0_hi))));
257 28000 partial_sum2 += vaddlvq_s16(vaddl_s8(
258 28000 vadd_s8(vget_low_s8(vsrc2_0_lo), vget_high_s8(vsrc2_0_lo)),
259 28000 vadd_s8(vget_low_s8(vsrc2_0_hi), vget_high_s8(vsrc2_0_hi))));
260 28000 partial_sum3 += vaddlvq_s16(vaddl_s8(
261 28000 vadd_s8(vget_low_s8(vsrc3_0_lo), vget_high_s8(vsrc3_0_lo)),
262 28000 vadd_s8(vget_low_s8(vsrc3_0_hi), vget_high_s8(vsrc3_0_hi))));
263
264 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
265 28000 sums[nr_idx + 0] += (float)partial_sum0 * d0;
266 28000 sums[nr_idx + 1] += (float)partial_sum1 * d1;
267 28000 sums[nr_idx + 2] += (float)partial_sum2 * d2;
268 28000 sums[nr_idx + 3] += (float)partial_sum3 * d3;
269 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
270
271 56000 const uint8x16_t vdst_u8_0 = vorrq_u8(
272 28000 vandq_u8(vreinterpretq_u8_s8(vsrc0_0_lo), low_mask),
273 28000 vshlq_n_u8(vreinterpretq_u8_s8(vsrc0_0_hi), 4));
274 56000 const uint8x16_t vdst_u8_1 = vorrq_u8(
275 28000 vandq_u8(vreinterpretq_u8_s8(vsrc1_0_lo), low_mask),
276 28000 vshlq_n_u8(vreinterpretq_u8_s8(vsrc1_0_hi), 4));
277 56000 const uint8x16_t vdst_u8_2 = vorrq_u8(
278 28000 vandq_u8(vreinterpretq_u8_s8(vsrc2_0_lo), low_mask),
279 28000 vshlq_n_u8(vreinterpretq_u8_s8(vsrc2_0_hi), 4));
280 56000 const uint8x16_t vdst_u8_3 = vorrq_u8(
281 28000 vandq_u8(vreinterpretq_u8_s8(vsrc3_0_lo), low_mask),
282 28000 vshlq_n_u8(vreinterpretq_u8_s8(vsrc3_0_hi), 4));
283
284 // Reorder to interleave the four rows of this group
285 28000 const uint16x8_t vdst_u16_0 = vreinterpretq_u16_u8(vdst_u8_0);
286 28000 const uint16x8_t vdst_u16_1 = vreinterpretq_u16_u8(vdst_u8_1);
287 28000 const uint16x8_t vdst_u16_2 = vreinterpretq_u16_u8(vdst_u8_2);
288 28000 const uint16x8_t vdst_u16_3 = vreinterpretq_u16_u8(vdst_u8_3);
289
290 28000 const uint32x4_t vdst_u32_0 = vreinterpretq_u32_u16(vzip1q_u16(vdst_u16_0, vdst_u16_1));
291 28000 const uint32x4_t vdst_u32_1 = vreinterpretq_u32_u16(vzip1q_u16(vdst_u16_2, vdst_u16_3));
292 28000 const uint32x4_t vdst_u32_2 = vreinterpretq_u32_u16(vzip2q_u16(vdst_u16_0, vdst_u16_1));
293 28000 const uint32x4_t vdst_u32_3 = vreinterpretq_u32_u16(vzip2q_u16(vdst_u16_2, vdst_u16_3));
294
295 28000 const uint32x4_t vdst0_0 = vzip1q_u32(vdst_u32_0, vdst_u32_1);
296 28000 const uint32x4_t vdst1_0 = vzip2q_u32(vdst_u32_0, vdst_u32_1);
297 28000 const uint32x4_t vdst2_0 = vzip1q_u32(vdst_u32_2, vdst_u32_3);
298 28000 const uint32x4_t vdst3_0 = vzip2q_u32(vdst_u32_2, vdst_u32_3);
299
300 // Store packed values
301 28000 vst1_u32((uint32_t*)dst_row, vget_low_u32(vdst0_0));
302 28000 vst1_u32((uint32_t*)(dst_row + nr * block_length_in_bytes), vget_high_u32(vdst0_0));
303 28000 vst1_u32((uint32_t*)(dst_row + (2 * nr * block_length_in_bytes)), vget_low_u32(vdst1_0));
304
305 28000 vst1_u32((uint32_t*)(dst_row + (3 * nr * block_length_in_bytes)), vget_high_u32(vdst1_0));
306 28000 vst1_u32((uint32_t*)(dst_row + (4 * nr * block_length_in_bytes)), vget_low_u32(vdst2_0));
307 28000 vst1_u32((uint32_t*)(dst_row + (5 * nr * block_length_in_bytes)), vget_high_u32(vdst2_0));
308 28000 vst1_u32((uint32_t*)(dst_row + (6 * nr * block_length_in_bytes)), vget_low_u32(vdst3_0));
309 28000 vst1_u32((uint32_t*)(dst_row + (7 * nr * block_length_in_bytes)), vget_high_u32(vdst3_0));
310
311 28000 dst_row += (4 * block_length_in_bytes);
312 28000 }
313 // Skip to end of qblock
314 1750 dst_row += 7 * nr * block_length_in_bytes;
315 1750 }
316
317 // Move the pointer past the scales
318 1680 dst_row += num_bytes_multiplier_rhs * nr;
319 1680 }
320
321 // Move the pointer past the row sums
322 188 dst_row += kai_num_bytes_sum_rhs * nr;
323
324 // Set the bias
325
1/2
✓ Branch 0 taken 188 times.
✗ Branch 1 not taken.
188 if (bias == NULL) {
326 memset(dst_row, 0, nr * kai_num_bytes_bias);
327 } else {
328
2/2
✓ Branch 0 taken 12032 times.
✓ Branch 1 taken 188 times.
12220 for (size_t i = 0; i < nr; ++i) {
329 // Clamp the row index to avoid out-of-bound reads
330
2/2
✓ Branch 0 taken 11180 times.
✓ Branch 1 taken 852 times.
12032 const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
331 12032 ((float*)dst_row)[i] = bias[src_row_idx];
332 12032 }
333 }
334
335 // Move the pointer past the bias
336 188 dst_row += kai_num_bytes_bias * nr;
337 188 }
338 66 }
339 #endif // Architectural features check.
340
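
For orientation, a hedged end-to-end sketch of how a caller might drive this routine. The shapes and strides are assumptions for illustration, notably rhs_stride = k / 2 for 4-bit data packed two per byte and a row-major, one-bf16-scale-per-block layout; the header is assumed to pull in kai/kai_common.h for the params struct and kai_dt_bf16:

```c
#include <stdint.h>
#include <stdlib.h>

#include "kai_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon.h"

int main(void) {
    // Assumed shapes satisfying the kernel's asserts (k % bl, bl % kr, kr % sr, nr % 4).
    const size_t n = 100, k = 256, nr = 8, kr = 16, sr = 4, bl = 32;  // kr / sr == 4

    const size_t rhs_stride = k / 2;           // two 4-bit values per byte (assumption)
    const size_t scale_stride = (k / bl) * 2;  // one bf16 scale per block, row-major (assumption)

    // Hypothetical input buffers; a real caller fills these with quantized data.
    uint8_t* rhs = calloc(n, rhs_stride);
    uint8_t* scale = calloc(n, scale_stride);
    float* bias = calloc(n, sizeof(float));

    struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params = {
        .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_dt_bf16};

    const size_t packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
        n, k, nr, kr, sr, bl, kai_dt_bf16);
    void* rhs_packed = malloc(packed_size);

    kai_run_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
        /*num_groups=*/1, n, k, nr, kr, sr, bl, rhs, rhs_stride, bias, scale, scale_stride,
        rhs_packed, /*extra_bytes=*/0, &params);

    free(rhs);
    free(scale);
    free(bias);
    free(rhs_packed);
    return 0;
}
```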