KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 97.3% 217 11 234
Functions: 87.5% 7 0 8
Branches: 88.5% 23 22 48

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
7
8 #include <arm_neon.h>
9 #include <stddef.h>
10 #include <stdint.h>
11
12 #include "kai/kai_common.h"
13
14 static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
15
16 115 inline static size_t kai_num_bytes_per_block(size_t bl) {
17 115 return bl * sizeof(int8_t) + kai_num_bytes_multiplier;
18 }
19
20 115 inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
21 KAI_ASSERT((k % bl) == 0);
22 115 return k / bl;
23 }
24
25 92 inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) {
26 92 KAI_UNUSED(kr);
27 92 return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl);
28 }
29
30 size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t mr) {
31 KAI_UNUSED(mr);
32 return 1;
33 }
34
35 23 size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t m_idx, size_t lhs_stride) {
36 23 return m_idx * lhs_stride;
37 }
38
39 46 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
40 size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
41 KAI_ASSUME((k % 2) == 0);
42 KAI_ASSUME((k % kr) == 0);
43 KAI_ASSUME((k % bl) == 0);
44
45 46 KAI_UNUSED(sr);
46 46 KAI_UNUSED(kr);
47
48 46 return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl);
49 }
50
51 23 size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
52 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
53 KAI_ASSUME((k % 2) == 0);
54 KAI_ASSUME((k % kr) == 0);
55 KAI_ASSUME((k % bl) == 0);
56
57 23 KAI_UNUSED(sr);
58 23 KAI_UNUSED(kr);
59
60 23 const size_t num_rows = kai_roundup(m, mr) / mr;
61
62 46 return num_rows * kai_lhs_packed_stride(k, mr, kr, bl);
63 23 }
64
65 23 void kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
66 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
67 size_t lhs_stride, void* lhs_packed) {
68
1/2
✓ Branch 0 taken 23 times.
✗ Branch 1 not taken.
23 if (m == 0) {
69 return;
70 }
71
72 KAI_ASSUME(bl == 32);
73 KAI_ASSUME(mr == 4);
74 KAI_ASSUME(kr == 16);
75 KAI_ASSUME(sr == 2);
76
77 23 const size_t local_bl = 32;
78 23 const size_t local_mr = 4;
79 23 const size_t local_kr = 16;
80 23 const size_t local_sr = 2;
81 23 const size_t num_rows = m;
82 23 const size_t k_block_len = local_kr / local_sr;
83 23 const size_t lhs_packed_stride = kai_lhs_packed_stride(k, local_mr, local_kr, local_bl);
84 23 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, local_bl);
85 23 const size_t num_bytes_per_block = kai_num_bytes_per_block(local_bl);
86
87 23 size_t row_idx = 0;
88
89 23 const size_t write_mem_increment = 2 * k_block_len * sizeof(int8_t);
90 23 const size_t read_mem_increment = num_blocks_per_row * local_bl * sizeof(int8_t);
91
92
2/2
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 20 times.
23 if (num_rows >= 4) {
93
2/2
✓ Branch 0 taken 144 times.
✓ Branch 1 taken 20 times.
164 for (; row_idx + 4 <= num_rows; row_idx += 4) {
94 144 const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
95
96
2/2
✓ Branch 0 taken 276 times.
✓ Branch 1 taken 144 times.
420 for (size_t b = 0; b < num_blocks_per_row; ++b) {
97 276 const size_t dst_x = ((row_idx + m_idx_start) % local_mr);
98 276 int8_t* dst_ptr = (int8_t*)lhs_packed + (b * local_mr) * num_bytes_per_block;
99
100 276 float abs_max_0 = 0.0F;
101 276 float abs_max_1 = 0.0F;
102 276 float abs_max_2 = 0.0F;
103 276 float abs_max_3 = 0.0F;
104
105 276 float32x4_t v_currentmax_0 = vdupq_n_f32(0);
106 276 float32x4_t v_currentmax_1 = vdupq_n_f32(0);
107 276 float32x4_t v_currentmax_2 = vdupq_n_f32(0);
108 276 float32x4_t v_currentmax_3 = vdupq_n_f32(0);
109
110
2/2
✓ Branch 0 taken 2208 times.
✓ Branch 1 taken 276 times.
2484 for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) {
111 2208 const float32x4_t v_f32_maxvals_0 = vld1q_f32(src_ptr + idx_v);
112 2208 const float32x4_t v_f32_abs_values_0 = vabsq_f32(v_f32_maxvals_0);
113 2208 v_currentmax_0 = vmaxq_f32(v_f32_abs_values_0, v_currentmax_0);
114 2208 const float32x4_t v_f32_maxvals_1 = vld1q_f32(src_ptr + idx_v + read_mem_increment);
115 2208 const float32x4_t v_f32_abs_values_1 = vabsq_f32(v_f32_maxvals_1);
116 2208 v_currentmax_1 = vmaxq_f32(v_f32_abs_values_1, v_currentmax_1);
117 2208 const float32x4_t v_f32_maxvals_2 = vld1q_f32(src_ptr + idx_v + 2 * read_mem_increment);
118 2208 const float32x4_t v_f32_abs_values_2 = vabsq_f32(v_f32_maxvals_2);
119 2208 v_currentmax_2 = vmaxq_f32(v_f32_abs_values_2, v_currentmax_2);
120 2208 const float32x4_t v_f32_maxvals_3 = vld1q_f32(src_ptr + idx_v + 3 * read_mem_increment);
121 2208 const float32x4_t v_f32_abs_values_3 = vabsq_f32(v_f32_maxvals_3);
122 2208 v_currentmax_3 = vmaxq_f32(v_f32_abs_values_3, v_currentmax_3);
123 2208 }
124
125 276 abs_max_0 = vmaxvq_f32(v_currentmax_0);
126 276 abs_max_1 = vmaxvq_f32(v_currentmax_1);
127 276 abs_max_2 = vmaxvq_f32(v_currentmax_2);
128 276 abs_max_3 = vmaxvq_f32(v_currentmax_3);
129
130 276 float32x4_t abs_max_vec = vdupq_n_f32(abs_max_0);
131 276 abs_max_vec = vsetq_lane_f32(abs_max_1, abs_max_vec, 1);
132 276 abs_max_vec = vsetq_lane_f32(abs_max_2, abs_max_vec, 2);
133 276 abs_max_vec = vsetq_lane_f32(abs_max_3, abs_max_vec, 3);
134
135 // Calculate scale and reciprocals
136 276 const float32x4_t scales = vdivq_f32(abs_max_vec, vdupq_n_f32((1 << 7) - 1));
137 276 const uint32x4_t valid_scales = vmvnq_u32(vceqq_f32(scales, vdupq_n_f32(0.0F)));
138 276 const float32x4_t reciprocals = vdivq_f32(vdupq_n_f32(1.0F), scales);
139 276 const float32x4_t rep_scales = vbslq_f32(valid_scales, reciprocals, vdupq_n_f32(0.0F));
140 276 const float16x4_t f16_scales = vcvt_f16_f32(scales);
141
142 276 vst1_u16((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier), vreinterpret_u16_f16(f16_scales));
143
144 276 dst_ptr += local_mr * kai_num_bytes_multiplier;
145
146 276 dst_ptr += dst_x * k_block_len * sizeof(int8_t);
147
148 // Quantize and pack the blocks
149
2/2
✓ Branch 0 taken 552 times.
✓ Branch 1 taken 276 times.
828 for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) {
150 // Row 1 blocks
151 552 const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx);
152 552 const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, vgetq_lane_f32(rep_scales, 0));
153 552 const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1);
154
155 552 const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4);
156 552 const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, vgetq_lane_f32(rep_scales, 0));
157 552 const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2);
158
159 1104 const int16x8_t v_full_i16_block1 =
160 552 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2));
161
162 552 const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8);
163 552 const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, vgetq_lane_f32(rep_scales, 0));
164 552 const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3);
165
166 552 const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12);
167 552 const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, vgetq_lane_f32(rep_scales, 0));
168 552 const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4);
169
170 1104 const int16x8_t v_full_i16_block2 =
171 552 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4));
172
173 // Row 2 blocks
174 552 const float32x4_t v_f32_block5 = vld1q_f32(src_ptr + k_idx + read_mem_increment);
175 552 const float32x4_t v_f32_sblock5 = vmulq_n_f32(v_f32_block5, vgetq_lane_f32(rep_scales, 1));
176 552 const int32x4_t v_i32_block5 = vcvtnq_s32_f32(v_f32_sblock5);
177
178 552 const float32x4_t v_f32_block6 = vld1q_f32(src_ptr + k_idx + 4 + read_mem_increment);
179 552 const float32x4_t v_f32_sblock6 = vmulq_n_f32(v_f32_block6, vgetq_lane_f32(rep_scales, 1));
180 552 const int32x4_t v_i32_block6 = vcvtnq_s32_f32(v_f32_sblock6);
181
182 1104 const int16x8_t v_full_i16_block3 =
183 552 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block5), vreinterpretq_s16_s32(v_i32_block6));
184
185 552 const float32x4_t v_f32_block7 = vld1q_f32(src_ptr + k_idx + 8 + read_mem_increment);
186 552 const float32x4_t v_f32_sblock7 = vmulq_n_f32(v_f32_block7, vgetq_lane_f32(rep_scales, 1));
187 552 const int32x4_t v_i32_block7 = vcvtnq_s32_f32(v_f32_sblock7);
188
189 552 const float32x4_t v_f32_block8 = vld1q_f32(src_ptr + k_idx + 12 + read_mem_increment);
190 552 const float32x4_t v_f32_sblock8 = vmulq_n_f32(v_f32_block8, vgetq_lane_f32(rep_scales, 1));
191 552 const int32x4_t v_i32_block8 = vcvtnq_s32_f32(v_f32_sblock8);
192
193 1104 const int16x8_t v_full_i16_block4 =
194 552 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block7), vreinterpretq_s16_s32(v_i32_block8));
195
196 // Row 3 blocks
197 552 const float32x4_t v_f32_block9 = vld1q_f32(src_ptr + k_idx + 2 * read_mem_increment);
198 552 const float32x4_t v_f32_sblock9 = vmulq_n_f32(v_f32_block9, vgetq_lane_f32(rep_scales, 2));
199 552 const int32x4_t v_i32_block9 = vcvtnq_s32_f32(v_f32_sblock9);
200
201 552 const float32x4_t v_f32_blockA = vld1q_f32(src_ptr + k_idx + 4 + 2 * read_mem_increment);
202 552 const float32x4_t v_f32_sblockA = vmulq_n_f32(v_f32_blockA, vgetq_lane_f32(rep_scales, 2));
203 552 const int32x4_t v_i32_blockA = vcvtnq_s32_f32(v_f32_sblockA);
204
205 1104 const int16x8_t v_full_i16_block5 =
206 552 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block9), vreinterpretq_s16_s32(v_i32_blockA));
207
208 552 const float32x4_t v_f32_blockB = vld1q_f32(src_ptr + k_idx + 8 + 2 * read_mem_increment);
209 552 const float32x4_t v_f32_sblockB = vmulq_n_f32(v_f32_blockB, vgetq_lane_f32(rep_scales, 2));
210 552 const int32x4_t v_i32_blockB = vcvtnq_s32_f32(v_f32_sblockB);
211
212 552 const float32x4_t v_f32_blockC = vld1q_f32(src_ptr + k_idx + 12 + 2 * read_mem_increment);
213 552 const float32x4_t v_f32_sblockC = vmulq_n_f32(v_f32_blockC, vgetq_lane_f32(rep_scales, 2));
214 552 const int32x4_t v_i32_blockC = vcvtnq_s32_f32(v_f32_sblockC);
215
216 1104 const int16x8_t v_full_i16_block6 =
217 552 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockB), vreinterpretq_s16_s32(v_i32_blockC));
218
219 // Row 4 blocks
220 552 const float32x4_t v_f32_blockD = vld1q_f32(src_ptr + k_idx + 3 * read_mem_increment);
221 552 const float32x4_t v_f32_sblockD = vmulq_n_f32(v_f32_blockD, vgetq_lane_f32(rep_scales, 3));
222 552 const int32x4_t v_i32_blockD = vcvtnq_s32_f32(v_f32_sblockD);
223
224 552 const float32x4_t v_f32_blockE = vld1q_f32(src_ptr + k_idx + 4 + 3 * read_mem_increment);
225 552 const float32x4_t v_f32_sblockE = vmulq_n_f32(v_f32_blockE, vgetq_lane_f32(rep_scales, 3));
226 552 const int32x4_t v_i32_blockE = vcvtnq_s32_f32(v_f32_sblockE);
227
228 1104 const int16x8_t v_full_i16_block7 =
229 552 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockD), vreinterpretq_s16_s32(v_i32_blockE));
230
231 552 const float32x4_t v_f32_blockF = vld1q_f32(src_ptr + k_idx + 8 + 3 * read_mem_increment);
232 552 const float32x4_t v_f32_sblockF = vmulq_n_f32(v_f32_blockF, vgetq_lane_f32(rep_scales, 3));
233 552 const int32x4_t v_i32_blockF = vcvtnq_s32_f32(v_f32_sblockF);
234
235 552 const float32x4_t v_f32_block0 = vld1q_f32(src_ptr + k_idx + 12 + 3 * read_mem_increment);
236 552 const float32x4_t v_f32_sblock0 = vmulq_n_f32(v_f32_block0, vgetq_lane_f32(rep_scales, 3));
237 552 const int32x4_t v_i32_block0 = vcvtnq_s32_f32(v_f32_sblock0);
238
239 1104 const int16x8_t v_full_i16_block8 =
240 552 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockF), vreinterpretq_s16_s32(v_i32_block0));
241
242 1104 const int8x16_t v_i8_block1_3 =
243 552 vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block3));
244 552 vst1q_s8(dst_ptr, v_i8_block1_3);
245 552 dst_ptr += write_mem_increment;
246
247 1104 const int8x16_t v_i8_block5_7 =
248 552 vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block5), vreinterpretq_s8_s16(v_full_i16_block7));
249 552 vst1q_s8(dst_ptr, v_i8_block5_7);
250 552 dst_ptr += write_mem_increment;
251
252 1104 const int8x16_t v_i8_block2_4 =
253 552 vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block2), vreinterpretq_s8_s16(v_full_i16_block4));
254 552 vst1q_s8(dst_ptr, v_i8_block2_4);
255 552 dst_ptr += write_mem_increment;
256
257 1104 const int8x16_t v_i8_block6_8 =
258 552 vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block6), vreinterpretq_s8_s16(v_full_i16_block8));
259 552 vst1q_s8(dst_ptr, v_i8_block6_8);
260 552 dst_ptr += write_mem_increment;
261 552 }
262 276 src_ptr += local_bl;
263 276 }
264 144 lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
265 144 }
266 20 }
267
2/2
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 11 times.
23 if (num_rows % 4 != 0) {
268
2/2
✓ Branch 0 taken 19 times.
✓ Branch 1 taken 11 times.
30 for (; row_idx < num_rows; ++row_idx) {
269 19 const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
270
271
2/2
✓ Branch 0 taken 23 times.
✓ Branch 1 taken 19 times.
42 for (size_t b = 0; b < num_blocks_per_row; ++b) {
272 23 float abs_max = 0.0F;
273
274 23 const size_t dst_x = ((row_idx + m_idx_start) % local_mr);
275 23 int8_t* dst_ptr = (int8_t*)lhs_packed + (b * local_mr) * num_bytes_per_block;
276
277 23 float32x4_t v_f32_abs_values;
278 23 float32x4_t v_f32_maxvals;
279 23 float32x4_t v_currentmax = vdupq_n_f32(0);
280
281
2/2
✓ Branch 0 taken 184 times.
✓ Branch 1 taken 23 times.
207 for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) {
282 184 v_f32_maxvals = vld1q_f32(src_ptr + idx_v);
283 184 v_f32_abs_values = vabsq_f32(v_f32_maxvals);
284 184 v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax);
285 184 }
286 23 abs_max = vmaxvq_f32(v_currentmax);
287
288 // Calculate scale and reciprocal
289 23 const float scale = abs_max / ((1 << 7) - 1);
290
1/2
✓ Branch 0 taken 23 times.
✗ Branch 1 not taken.
23 const float rep_scale = scale ? 1.0F / scale : 0.0F;
291
292 23 *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale);
293 23 dst_ptr += local_mr * kai_num_bytes_multiplier;
294
295 23 dst_ptr += dst_x * k_block_len * sizeof(int8_t);
296
297 // Quantize and pack the block
298
2/2
✓ Branch 0 taken 46 times.
✓ Branch 1 taken 23 times.
69 for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) {
299 46 const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx);
300 46 const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, rep_scale);
301 46 const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1);
302
303 46 const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4);
304 46 const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, rep_scale);
305 46 const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2);
306
307 92 const int16x8_t v_full_i16_block1 =
308 46 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2));
309
310 46 const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8);
311 46 const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, rep_scale);
312 46 const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3);
313
314 46 const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12);
315 46 const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, rep_scale);
316 46 const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4);
317
318 92 const int16x8_t v_full_i16_block2 =
319 46 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4));
320
321 92 const int8x16_t v_full_i8_block =
322 46 vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block2));
323
324 46 vst1_s8(dst_ptr, vget_low_s8(v_full_i8_block));
325 46 dst_ptr += 8 * sizeof(int8_t);
326 46 dst_ptr += (local_mr - 1) * k_block_len * sizeof(int8_t);
327
328 46 vst1_s8(dst_ptr, vget_high_s8(v_full_i8_block));
329 46 dst_ptr += 8 * sizeof(int8_t);
330 46 dst_ptr += (local_mr - 1) * k_block_len * sizeof(int8_t);
331 46 }
332 23 src_ptr += local_bl;
333 23 }
334 // Move to the next row if we have interleaved all Mr rows
335
1/2
✓ Branch 0 taken 19 times.
✗ Branch 1 not taken.
19 if ((((row_idx + 1) + m_idx_start) % local_mr) == 0) {
336 lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
337 }
338 19 }
339 11 }
340 23 }
341