KleidiAI Coverage Report


Directory: ./
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%

            Coverage    Exec / Excl / Total
Lines:      94.9%        130 /   35 /   172
Functions:  88.9%          8 /    0 /     9
Branches:   95.0%         19 /   70 /    90
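(The percentages correspond to Exec / (Total − Excl): 130 / 137 ≈ 94.9% for lines, 8 / 9 ≈ 88.9% for functions, and 19 / 20 = 95.0% for branches.)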

kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64.
9 #else // Architectural features check.
10
11 #include "kai_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon.h"
12
13 #include <arm_neon.h>
14 #include <stddef.h>
15 #include <stdint.h>
16 #include <string.h>
17
18 #include "kai/kai_common.h"
19
20 // nrx4 => this function can take in generic nr values but the input is expected to have a block depth of 4.
21 // Block depth is calculated as kr / sr. The values of these parameters are defined in the matmul ukernel.
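// For example, kr = 16 and sr = 4 give a block depth of 16 / 4 = 4, which is the value this packing
// routine assumes (see the KAI_ASSERT on block_length in the kai_run_* function below).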
22
23 static const size_t kai_num_bytes_sum_rhs = sizeof(float);
24 static const size_t kai_num_bytes_bias = sizeof(float);
25 static const size_t kai_nr_multiple_of = 4;
26 static const size_t kai_bl_multiple_of = 32;
27
28 2320 static size_t kai_get_num_blocks_per_row(const size_t k, const size_t bl) {
29 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
30 2320 return kai_roundup(k, bl) / bl;
31 }
32
33 2274 static size_t kai_get_num_bytes_per_block(const size_t bl, const size_t num_bytes_multiplier_rhs) {
34 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
35 2274 return (bl / 2) + num_bytes_multiplier_rhs;
36 }
37
38 46 static size_t kai_get_rhs_packed_offset_end_of_all_blocks(
39 // clang-format off
40 const size_t k,
41 const size_t nr,
42 const size_t kr,
43 const size_t bl,
44 const size_t num_bytes_multiplier_rhs) {
45 // clang-format on
46 KAI_ASSERT((bl % kr) == 0);
47 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
48 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
49
50 46 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
51 46 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
52
53 92 return (nr * num_bytes_per_block * num_blocks_per_row);
54 46 }
55
56 size_t kai_get_n_step_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(const size_t nr) {
57 return nr;
58 }
59
60 46 size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
61 const size_t n_idx, //
62 const size_t rhs_stride) {
63 46 KAI_UNUSED(rhs_stride);
64 KAI_ASSERT((n_idx % 2) == 0);
65
66 46 return (n_idx / 2) * sizeof(int8_t);
67 }
68
69 2228 size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
70 // clang-format off
71 const size_t k,
72 const size_t nr,
73 const size_t kr,
74 const size_t sr,
75 const size_t bl,
76 const enum kai_datatype scale_dt) {
77 // clang-format on
78 KAI_ASSERT((k % bl) == 0);
79 KAI_ASSERT((bl % kr) == 0);
80 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
81 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
82 KAI_ASSERT(scale_dt == kai_dt_bf16);
83
84 2228 KAI_UNUSED(kr);
85 2228 KAI_UNUSED(sr);
86
87 2228 const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
88 2228 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
89 2228 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
90
91 4456 return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
92 2228 }
93
94 // clang-format off
95 2182 size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
96 const size_t n_idx,
97 const size_t k,
98 const size_t nr,
99 const size_t kr,
100 const size_t sr,
101 const size_t bl,
102 const enum kai_datatype scale_dt) {
103 // clang-format on
104 KAI_ASSERT((n_idx % nr) == 0);
105 KAI_ASSERT((k % bl) == 0);
106 KAI_ASSERT((bl % kr) == 0);
107 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
108 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
109 KAI_ASSERT(scale_dt == kai_dt_bf16);
110
111 4364 return (n_idx / nr) *
112 2182 kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
113 }
114
115 // clang-format off
116 46 size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
117 const size_t n, //
118 const size_t k, //
119 const size_t nr, //
120 const size_t kr, //
121 const size_t sr, //
122 const size_t bl, //
123 const enum kai_datatype scale_dt) {
124 // clang-format on
125 KAI_ASSERT((k % bl) == 0);
126 KAI_ASSERT((bl % kr) == 0);
127 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
128 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
129 KAI_ASSERT(scale_dt == kai_dt_bf16);
130
131 46 const size_t num_rows = kai_roundup(n, nr) / nr;
132
133 138 return num_rows *
134 46 kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt);
135 46 }
136
137 46 void kai_run_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
138 // clang-format off
139 const size_t num_groups,
140 const size_t n,
141 const size_t k,
142 const size_t nr,
143 const size_t kr,
144 const size_t sr,
145 const size_t bl,
146 const uint8_t* rhs,
147 const size_t rhs_stride,
148 const float* bias,
149 const void* scale,
150 const size_t scale_stride,
151 void* rhs_packed,
152 const size_t extra_bytes,
153 const struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params* params) {
154 // clang-format on
155 46 KAI_UNUSED(num_groups);
156 46 KAI_UNUSED(extra_bytes);
157 KAI_ASSERT(rhs != NULL);
158 KAI_ASSERT(scale != NULL);
159 KAI_ASSERT(rhs_packed != NULL);
160 KAI_ASSERT(params != NULL);
161 KAI_ASSERT(params->rhs_zero_point == 8);
162 KAI_ASSERT(params->lhs_zero_point == 1);
163
164 KAI_ASSERT((k % bl) == 0);
165 KAI_ASSERT((bl % kr) == 0);
166 KAI_ASSERT((kr % sr) == 0);
167 KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
168 KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
169 KAI_ASSERT(params->scale_dt == kai_dt_bf16);
170
171 // Note: The input matrix (rhs) is expected to have
172 // "k" rows and "n" columns (kxn).
173 46 const size_t block_length = kr / sr;
174 KAI_ASSERT(block_length == 4);
175 46 const enum kai_datatype scale_dt = params->scale_dt;
176 46 const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
177 92 const size_t rhs_packed_offset_end_of_all_blocks =
178 46 kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs);
179 46 const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl);
180 46 const size_t num_bytes_per_block_k = bl / 2;
181 46 const size_t dst_num_rows = kai_roundup(n, nr);
182 46 const size_t block_length_in_bytes = block_length / 2;
183
184 46 const int8x8_t rhs_zero_point = vdup_n_s8(8);
185 46 const uint8x8_t low_mask = vdup_n_u8(0x0F);
186 46 const size_t num_bytes_processed = 2;
187
188 46 uint8_t* dst_row = (uint8_t*)rhs_packed;
189
190
2/2
✓ Branch 0 taken 46 times.
✓ Branch 1 taken 77 times.
123 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) {
191 77 float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);
192
193 // Initialize the RHS reduction sums to zero
194 77 memset(sums, 0, nr * kai_num_bytes_sum_rhs);
195
196 // Iterate over the quantized blocks
197
2/2
✓ Branch 0 taken 614 times.
✓ Branch 1 taken 77 times.
691 for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) {
198 // Store the scales after packing all K values in the block
199 614 uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr;
200 614 const uint8_t* scale_ptr = (const uint8_t*)scale + dst_qblock_idx * num_bytes_multiplier_rhs;
201
202
2/2
✓ Branch 0 taken 39296 times.
✓ Branch 1 taken 614 times.
39910 for (size_t i = 0; i < nr; ++i) {
203
2/2
✓ Branch 0 taken 34626 times.
✓ Branch 1 taken 4670 times.
39296 const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
204 39296 const void* src_scales_ptr = scale_ptr + src_row_idx * scale_stride;
205 39296 void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs;
206
207 39296 memcpy(
208 dst_scales_ptr, //
209 src_scales_ptr, //
210 num_bytes_multiplier_rhs); //
211 39296 }
212
213 614 size_t k0_idx_i = dst_qblock_idx * bl;
214
215
2/2
✓ Branch 0 taken 5472 times.
✓ Branch 1 taken 614 times.
6086 for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += num_bytes_processed) {
216
2/2
✓ Branch 0 taken 21888 times.
✓ Branch 1 taken 5472 times.
27360 for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 16) {
217 // Clamp the indices to avoid out-of-bound reads
218
2/2
✓ Branch 0 taken 19712 times.
✓ Branch 1 taken 2176 times.
21888 const size_t n0_idx = KAI_MIN(dst_row_idx + nr_idx, n - 1);
219
220 // Load scales and convert to float
221 #if defined(__ARM_FEATURE_BF16)
222
223 21888 const bfloat16_t* rhs_bf16_scale = (const bfloat16_t*)rhs_packed_scale + nr_idx + 0;
224 21888 const bfloat16x4x4_t vd_bf16 = vld1_bf16_x4(rhs_bf16_scale);
225 21888 const float32x4_t vd_0 = vcvt_f32_bf16(vd_bf16.val[0]);
226 21888 const float32x4_t vd_1 = vcvt_f32_bf16(vd_bf16.val[1]);
227 21888 const float32x4_t vd_2 = vcvt_f32_bf16(vd_bf16.val[2]);
228 21888 const float32x4_t vd_3 = vcvt_f32_bf16(vd_bf16.val[3]);
229 #else
230 // Portable BF16 -> F32 conversion using integer NEON: (u16 << 16) reinterpret as f32
231 const uint16_t* bf16_ptr = ((const uint16_t*)rhs_packed_scale) + nr_idx;
232 const uint16x4_t vbf0 = vld1_u16(bf16_ptr + 0);
233 const uint16x4_t vbf1 = vld1_u16(bf16_ptr + 4);
234 const uint16x4_t vbf2 = vld1_u16(bf16_ptr + 8);
235 const uint16x4_t vbf3 = vld1_u16(bf16_ptr + 12);
236 const uint32x4_t vbf0_u32 = vshlq_n_u32(vmovl_u16(vbf0), 16);
237 const uint32x4_t vbf1_u32 = vshlq_n_u32(vmovl_u16(vbf1), 16);
238 const uint32x4_t vbf2_u32 = vshlq_n_u32(vmovl_u16(vbf2), 16);
239 const uint32x4_t vbf3_u32 = vshlq_n_u32(vmovl_u16(vbf3), 16);
240 const float32x4_t vd_0 = vreinterpretq_f32_u32(vbf0_u32);
241 const float32x4_t vd_1 = vreinterpretq_f32_u32(vbf1_u32);
242 const float32x4_t vd_2 = vreinterpretq_f32_u32(vbf2_u32);
243 const float32x4_t vd_3 = vreinterpretq_f32_u32(vbf3_u32);
244 #endif
245
246 21888 const uint8_t* src_block_base = rhs + n0_idx / 2;
247 21888 const size_t k_idx = k0_idx_i + dst_byte_idx * 2;
248 21888 const uint8x8_t vsrc0_0 = vld1_u8(src_block_base + ((k_idx)*rhs_stride));
249 21888 const uint8x8_t vsrc1_0 = vld1_u8(src_block_base + ((k_idx + 1) * rhs_stride));
250 21888 const uint8x8_t vsrc2_0 = vld1_u8(src_block_base + ((k_idx + 2) * rhs_stride));
251 21888 const uint8x8_t vsrc3_0 = vld1_u8(src_block_base + ((k_idx + 3) * rhs_stride));
252
253 // Get the lower and higher nibble and apply zero-points
254 21888 const int8x8_t vsrc0_lo = vsub_s8(vreinterpret_s8_u8(vand_u8(vsrc0_0, low_mask)), rhs_zero_point);
255 21888 const int8x8_t vsrc0_hi = vsub_s8(vreinterpret_s8_u8(vshr_n_u8(vsrc0_0, 4)), rhs_zero_point);
256 21888 const int8x8_t vsrc1_lo = vsub_s8(vreinterpret_s8_u8(vand_u8(vsrc1_0, low_mask)), rhs_zero_point);
257 21888 const int8x8_t vsrc1_hi = vsub_s8(vreinterpret_s8_u8(vshr_n_u8(vsrc1_0, 4)), rhs_zero_point);
258 21888 const int8x8_t vsrc2_lo = vsub_s8(vreinterpret_s8_u8(vand_u8(vsrc2_0, low_mask)), rhs_zero_point);
259 21888 const int8x8_t vsrc2_hi = vsub_s8(vreinterpret_s8_u8(vshr_n_u8(vsrc2_0, 4)), rhs_zero_point);
260 21888 const int8x8_t vsrc3_lo = vsub_s8(vreinterpret_s8_u8(vand_u8(vsrc3_0, low_mask)), rhs_zero_point);
261 21888 const int8x8_t vsrc3_hi = vsub_s8(vreinterpret_s8_u8(vshr_n_u8(vsrc3_0, 4)), rhs_zero_point);
262
263 // Calculate and store row sums
264 21888 const int16x8_t vsum_lo = vaddl_s8(vadd_s8(vsrc0_lo, vsrc1_lo), vadd_s8(vsrc2_lo, vsrc3_lo));
265 21888 const int16x8_t vsum_hi = vaddl_s8(vadd_s8(vsrc0_hi, vsrc1_hi), vadd_s8(vsrc2_hi, vsrc3_hi));
266
267 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
268 21888 const float16x8_t vsum_0 = vcvtq_f16_s16(vzip1q_s16(vsum_lo, vsum_hi));
269 21888 const float16x8_t vsum_1 = vcvtq_f16_s16(vzip2q_s16(vsum_lo, vsum_hi));
270
271 21888 const float32x4_t vpartialsum_0 = vcvt_f32_f16(vget_low_f16(vsum_0));
272 21888 const float32x4_t vpartialsum_1 = vcvt_high_f32_f16(vsum_0);
273 21888 const float32x4_t vpartialsum_2 = vcvt_f32_f16(vget_low_f16(vsum_1));
274 21888 const float32x4_t vpartialsum_3 = vcvt_high_f32_f16(vsum_1);
275 #else
276 // Portable int16 -> f32 path without FP16 vector arithmetic
277 const int16x8_t _zip0 = vzip1q_s16(vsum_lo, vsum_hi);
278 const int16x8_t _zip1 = vzip2q_s16(vsum_lo, vsum_hi);
279 const int32x4_t i0_lo = vmovl_s16(vget_low_s16(_zip0));
280 const int32x4_t i0_hi = vmovl_s16(vget_high_s16(_zip0));
281 const int32x4_t i1_lo = vmovl_s16(vget_low_s16(_zip1));
282 const int32x4_t i1_hi = vmovl_s16(vget_high_s16(_zip1));
283 const float32x4_t vpartialsum_0 = vcvtq_f32_s32(i0_lo);
284 const float32x4_t vpartialsum_1 = vcvtq_f32_s32(i0_hi);
285 const float32x4_t vpartialsum_2 = vcvtq_f32_s32(i1_lo);
286 const float32x4_t vpartialsum_3 = vcvtq_f32_s32(i1_hi);
287 #endif
288
289 21888 float32x4_t vsum_f32_0 = vld1q_f32(sums + nr_idx);
290 21888 float32x4_t vsum_f32_1 = vld1q_f32(sums + nr_idx + 4);
291 21888 float32x4_t vsum_f32_2 = vld1q_f32(sums + nr_idx + 8);
292 21888 float32x4_t vsum_f32_3 = vld1q_f32(sums + nr_idx + 12);
293
294 21888 vsum_f32_0 = vfmaq_f32(vsum_f32_0, vpartialsum_0, vd_0);
295 21888 vsum_f32_1 = vfmaq_f32(vsum_f32_1, vpartialsum_1, vd_1);
296 21888 vsum_f32_2 = vfmaq_f32(vsum_f32_2, vpartialsum_2, vd_2);
297 21888 vsum_f32_3 = vfmaq_f32(vsum_f32_3, vpartialsum_3, vd_3);
298
299 21888 vst1q_f32(sums + nr_idx, vsum_f32_0);
300 21888 vst1q_f32(sums + nr_idx + 4, vsum_f32_1);
301 21888 vst1q_f32(sums + nr_idx + 8, vsum_f32_2);
302 21888 vst1q_f32(sums + nr_idx + 12, vsum_f32_3);
303
304 43776 const uint8x8_t vdst_u8_0 = vorr_u8(
305 21888 vand_u8(vreinterpret_u8_s8(vsrc0_lo), low_mask), vshl_n_u8(vreinterpret_u8_s8(vsrc1_lo), 4));
306 43776 const uint8x8_t vdst_u8_1 = vorr_u8(
307 21888 vand_u8(vreinterpret_u8_s8(vsrc2_lo), low_mask), vshl_n_u8(vreinterpret_u8_s8(vsrc3_lo), 4));
308 43776 const uint8x8_t vdst_u8_2 = vorr_u8(
309 21888 vand_u8(vreinterpret_u8_s8(vsrc0_hi), low_mask), vshl_n_u8(vreinterpret_u8_s8(vsrc1_hi), 4));
310 43776 const uint8x8_t vdst_u8_3 = vorr_u8(
311 21888 vand_u8(vreinterpret_u8_s8(vsrc2_hi), low_mask), vshl_n_u8(vreinterpret_u8_s8(vsrc3_hi), 4));
312
313 43776 const uint16x8_t vdst_u16_even = vreinterpretq_u16_u8(
314 21888 vcombine_u8(vzip1_u8(vdst_u8_0, vdst_u8_1), vzip2_u8(vdst_u8_0, vdst_u8_1)));
315 43776 const uint16x8_t vdst_u16_odd = vreinterpretq_u16_u8(
316 21888 vcombine_u8(vzip1_u8(vdst_u8_2, vdst_u8_3), vzip2_u8(vdst_u8_2, vdst_u8_3)));
317
318 21888 const uint16x8_t vdst_0 = vzip1q_u16(vdst_u16_even, vdst_u16_odd);
319 21888 const uint16x8_t vdst_1 = vzip2q_u16(vdst_u16_even, vdst_u16_odd);
320
321 21888 vst1q_u16((uint16_t*)dst_row, vdst_0);
322 21888 vst1q_u16((uint16_t*)(dst_row + 8 * block_length_in_bytes), vdst_1);
323
324 21888 dst_row += (16 * block_length_in_bytes);
325 21888 }
326 5472 }
327 // Move the pointer after scales
328 614 dst_row += num_bytes_multiplier_rhs * nr;
329 614 }
330
331 // Move the pointer after the row sum
332 77 dst_row += kai_num_bytes_sum_rhs * nr;
333
334 // Set the bias
335
1/2
✓ Branch 0 taken 77 times.
✗ Branch 1 not taken.
77 if (bias == NULL) {
336 memset(dst_row, 0, nr * kai_num_bytes_bias);
337 } else {
338
2/2
✓ Branch 0 taken 4928 times.
✓ Branch 1 taken 77 times.
5005 for (size_t i = 0; i < nr; ++i) {
339 // Clamp the row index to avoid out-of-bound reads
340
2/2
✓ Branch 0 taken 4258 times.
✓ Branch 1 taken 670 times.
4928 const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
341 4928 ((float*)dst_row)[i] = bias[src_row_idx];
342 4928 }
343 }
344 // Move the pointer after the bias
345 77 dst_row += kai_num_bytes_bias * nr;
346 77 }
347 46 }
348 #endif // Architectural features check.
349
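A minimal calling sketch for the packing routine covered above. The dimensions (n, k, nr, kr, sr, bl), strides, and buffer contents below are illustrative assumptions chosen to satisfy the asserts in kai_run_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon; real values come from the matmul ukernel and the quantized model data.

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include "kai/kai_common.h"
#include "kai_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon.h"

int main(void) {
    // Illustrative dimensions: k % bl == 0, bl % kr == 0, kr % sr == 0, nr % 4 == 0, kr / sr == 4.
    const size_t n = 32, k = 64;
    const size_t nr = 16, kr = 16, sr = 4;
    const size_t bl = 32;

    const size_t num_blocks_per_row = k / bl;
    const size_t rhs_stride = n / 2;                                    // kxn RHS, two 4-bit values per byte
    const size_t scale_stride = num_blocks_per_row * sizeof(uint16_t);  // one bf16 scale per block, per column

    // Placeholder buffers; in practice they hold the quantized weights, bf16 scales and biases.
    uint8_t* rhs = calloc(k * rhs_stride, sizeof(uint8_t));
    uint16_t* scale = calloc(n * num_blocks_per_row, sizeof(uint16_t));
    float* bias = calloc(n, sizeof(float));

    const size_t packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
        n, k, nr, kr, sr, bl, kai_dt_bf16);
    void* rhs_packed = malloc(packed_size);

    struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params;
    memset(&params, 0, sizeof(params));
    params.lhs_zero_point = 1;  // values required by the asserts in kai_run_*
    params.rhs_zero_point = 8;
    params.scale_dt = kai_dt_bf16;

    kai_run_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
        /* num_groups */ 1, n, k, nr, kr, sr, bl, rhs, rhs_stride, bias, scale, scale_stride, rhs_packed,
        /* extra_bytes */ 0, &params);

    free(rhs);
    free(scale);
    free(bias);
    free(rhs_packed);
    return 0;
}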