KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 97.7% 217 / 11 / 233
Functions: 87.5% 7 / 0 / 8
Branches: 88.5% 23 / 22 / 48

kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
7
8 #include <arm_neon.h>
9 #include <stddef.h>
10 #include <stdint.h>
11
12 #include "kai/kai_common.h"
13
14 static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
15
16 920 inline static size_t kai_num_bytes_per_block(size_t bl) {
17 920 return bl * sizeof(int8_t) + kai_num_bytes_multiplier;
18 }
19
20 920 inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
21 KAI_ASSERT((k % bl) == 0);
22 920 return k / bl;
23 }
24
25 736 inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) {
26 736 KAI_UNUSED(kr);
27 736 return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl);
28 }
29
30 size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t mr) {
31 return mr;
32 }
33
34 184 size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t m_idx, size_t lhs_stride) {
35 184 return m_idx * lhs_stride;
36 }
37
38 368 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
39 size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
40 KAI_ASSUME((k % 2) == 0);
41 KAI_ASSUME((k % kr) == 0);
42 KAI_ASSUME((k % bl) == 0);
43
44 368 KAI_UNUSED(sr);
45 368 KAI_UNUSED(kr);
46
47 368 return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl);
48 }
49
50 184 size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
51 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
52 KAI_ASSUME((k % 2) == 0);
53 KAI_ASSUME((k % kr) == 0);
54 KAI_ASSUME((k % bl) == 0);
55
56 184 KAI_UNUSED(sr);
57 184 KAI_UNUSED(kr);
58
59 184 const size_t num_rows = kai_roundup(m, mr) / mr;
60
61 368 return num_rows * kai_lhs_packed_stride(k, mr, kr, bl);
62 184 }
63
64 184 void kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
65 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
66 size_t lhs_stride, void* lhs_packed) {
67
1/2
✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
184 if (m == 0) {
68 return;
69 }
70
71 KAI_ASSUME(bl == 32);
72 KAI_ASSUME(mr == 4);
73 KAI_ASSUME(kr == 16);
74 KAI_ASSUME(sr == 2);
75
76 184 const size_t local_bl = 32;
77 184 const size_t local_mr = 4;
78 184 const size_t local_kr = 16;
79 184 const size_t local_sr = 2;
80 184 const size_t num_rows = m;
81 184 const size_t k_block_len = local_kr / local_sr;
82 184 const size_t lhs_packed_stride = kai_lhs_packed_stride(k, local_mr, local_kr, local_bl);
83 184 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, local_bl);
84 184 const size_t num_bytes_per_block = kai_num_bytes_per_block(local_bl);
85
86 184 size_t row_idx = 0;
87
88 184 const size_t write_mem_increment = 2 * k_block_len * sizeof(int8_t);
89 184 const size_t read_mem_increment = num_blocks_per_row * local_bl * sizeof(int8_t);
90
91
2/2
✓ Branch 0 taken 72 times.
✓ Branch 1 taken 112 times.
184 if (num_rows >= 4) {
92
2/2
✓ Branch 0 taken 1032 times.
✓ Branch 1 taken 112 times.
1144 for (; row_idx + 4 <= num_rows; row_idx += 4) {
93 1032 const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
94
95
2/2
✓ Branch 0 taken 2040 times.
✓ Branch 1 taken 1032 times.
3072 for (size_t b = 0; b < num_blocks_per_row; ++b) {
96 2040 const size_t dst_x = ((row_idx + m_idx_start) % local_mr);
97 2040 int8_t* dst_ptr = (int8_t*)lhs_packed + (b * local_mr) * num_bytes_per_block;
98
99 2040 float abs_max_0 = 0.0F;
100 2040 float abs_max_1 = 0.0F;
101 2040 float abs_max_2 = 0.0F;
102 2040 float abs_max_3 = 0.0F;
103
104 2040 float32x4_t v_currentmax_0 = vdupq_n_f32(0);
105 2040 float32x4_t v_currentmax_1 = vdupq_n_f32(0);
106 2040 float32x4_t v_currentmax_2 = vdupq_n_f32(0);
107 2040 float32x4_t v_currentmax_3 = vdupq_n_f32(0);
108
109
2/2
✓ Branch 0 taken 16320 times.
✓ Branch 1 taken 2040 times.
18360 for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) {
110 16320 const float32x4_t v_f32_maxvals_0 = vld1q_f32(src_ptr + idx_v);
111 16320 const float32x4_t v_f32_abs_values_0 = vabsq_f32(v_f32_maxvals_0);
112 16320 v_currentmax_0 = vmaxq_f32(v_f32_abs_values_0, v_currentmax_0);
113 16320 const float32x4_t v_f32_maxvals_1 = vld1q_f32(src_ptr + idx_v + read_mem_increment);
114 16320 const float32x4_t v_f32_abs_values_1 = vabsq_f32(v_f32_maxvals_1);
115 16320 v_currentmax_1 = vmaxq_f32(v_f32_abs_values_1, v_currentmax_1);
116 16320 const float32x4_t v_f32_maxvals_2 = vld1q_f32(src_ptr + idx_v + 2 * read_mem_increment);
117 16320 const float32x4_t v_f32_abs_values_2 = vabsq_f32(v_f32_maxvals_2);
118 16320 v_currentmax_2 = vmaxq_f32(v_f32_abs_values_2, v_currentmax_2);
119 16320 const float32x4_t v_f32_maxvals_3 = vld1q_f32(src_ptr + idx_v + 3 * read_mem_increment);
120 16320 const float32x4_t v_f32_abs_values_3 = vabsq_f32(v_f32_maxvals_3);
121 16320 v_currentmax_3 = vmaxq_f32(v_f32_abs_values_3, v_currentmax_3);
122 16320 }
123
124 2040 abs_max_0 = vmaxvq_f32(v_currentmax_0);
125 2040 abs_max_1 = vmaxvq_f32(v_currentmax_1);
126 2040 abs_max_2 = vmaxvq_f32(v_currentmax_2);
127 2040 abs_max_3 = vmaxvq_f32(v_currentmax_3);
128
129 2040 float32x4_t abs_max_vec = vdupq_n_f32(abs_max_0);
130 2040 abs_max_vec = vsetq_lane_f32(abs_max_1, abs_max_vec, 1);
131 2040 abs_max_vec = vsetq_lane_f32(abs_max_2, abs_max_vec, 2);
132 2040 abs_max_vec = vsetq_lane_f32(abs_max_3, abs_max_vec, 3);
133
134 // Calculate scale and reciprocals
135 2040 const float32x4_t scales = vdivq_f32(abs_max_vec, vdupq_n_f32((1 << 7) - 1));
136 2040 const uint32x4_t valid_scales = vmvnq_u32(vceqq_f32(scales, vdupq_n_f32(0.0F)));
137 2040 const float32x4_t reciprocals = vdivq_f32(vdupq_n_f32(1.0F), scales);
138 2040 const float32x4_t rep_scales = vbslq_f32(valid_scales, reciprocals, vdupq_n_f32(0.0F));
139 2040 const float16x4_t f16_scales = vcvt_f16_f32(scales);
140
141 2040 vst1_u16((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier), vreinterpret_u16_f16(f16_scales));
142
143 2040 dst_ptr += local_mr * kai_num_bytes_multiplier;
144
145 2040 dst_ptr += dst_x * k_block_len * sizeof(int8_t);
146
147 // Quantize and pack the blocks
148
2/2
✓ Branch 0 taken 4080 times.
✓ Branch 1 taken 2040 times.
6120 for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) {
149 // Row 1 blocks
150 4080 const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx);
151 4080 const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, vgetq_lane_f32(rep_scales, 0));
152 4080 const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1);
153
154 4080 const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4);
155 4080 const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, vgetq_lane_f32(rep_scales, 0));
156 4080 const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2);
157
158 8160 const int16x8_t v_full_i16_block1 =
159 4080 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2));
160
161 4080 const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8);
162 4080 const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, vgetq_lane_f32(rep_scales, 0));
163 4080 const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3);
164
165 4080 const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12);
166 4080 const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, vgetq_lane_f32(rep_scales, 0));
167 4080 const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4);
168
169 8160 const int16x8_t v_full_i16_block2 =
170 4080 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4));
171
172 // Row 2 blocks
173 4080 const float32x4_t v_f32_block5 = vld1q_f32(src_ptr + k_idx + read_mem_increment);
174 4080 const float32x4_t v_f32_sblock5 = vmulq_n_f32(v_f32_block5, vgetq_lane_f32(rep_scales, 1));
175 4080 const int32x4_t v_i32_block5 = vcvtnq_s32_f32(v_f32_sblock5);
176
177 4080 const float32x4_t v_f32_block6 = vld1q_f32(src_ptr + k_idx + 4 + read_mem_increment);
178 4080 const float32x4_t v_f32_sblock6 = vmulq_n_f32(v_f32_block6, vgetq_lane_f32(rep_scales, 1));
179 4080 const int32x4_t v_i32_block6 = vcvtnq_s32_f32(v_f32_sblock6);
180
181 8160 const int16x8_t v_full_i16_block3 =
182 4080 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block5), vreinterpretq_s16_s32(v_i32_block6));
183
184 4080 const float32x4_t v_f32_block7 = vld1q_f32(src_ptr + k_idx + 8 + read_mem_increment);
185 4080 const float32x4_t v_f32_sblock7 = vmulq_n_f32(v_f32_block7, vgetq_lane_f32(rep_scales, 1));
186 4080 const int32x4_t v_i32_block7 = vcvtnq_s32_f32(v_f32_sblock7);
187
188 4080 const float32x4_t v_f32_block8 = vld1q_f32(src_ptr + k_idx + 12 + read_mem_increment);
189 4080 const float32x4_t v_f32_sblock8 = vmulq_n_f32(v_f32_block8, vgetq_lane_f32(rep_scales, 1));
190 4080 const int32x4_t v_i32_block8 = vcvtnq_s32_f32(v_f32_sblock8);
191
192 8160 const int16x8_t v_full_i16_block4 =
193 4080 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block7), vreinterpretq_s16_s32(v_i32_block8));
194
195 // Row 3 blocks
196 4080 const float32x4_t v_f32_block9 = vld1q_f32(src_ptr + k_idx + 2 * read_mem_increment);
197 4080 const float32x4_t v_f32_sblock9 = vmulq_n_f32(v_f32_block9, vgetq_lane_f32(rep_scales, 2));
198 4080 const int32x4_t v_i32_block9 = vcvtnq_s32_f32(v_f32_sblock9);
199
200 4080 const float32x4_t v_f32_blockA = vld1q_f32(src_ptr + k_idx + 4 + 2 * read_mem_increment);
201 4080 const float32x4_t v_f32_sblockA = vmulq_n_f32(v_f32_blockA, vgetq_lane_f32(rep_scales, 2));
202 4080 const int32x4_t v_i32_blockA = vcvtnq_s32_f32(v_f32_sblockA);
203
204 8160 const int16x8_t v_full_i16_block5 =
205 4080 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block9), vreinterpretq_s16_s32(v_i32_blockA));
206
207 4080 const float32x4_t v_f32_blockB = vld1q_f32(src_ptr + k_idx + 8 + 2 * read_mem_increment);
208 4080 const float32x4_t v_f32_sblockB = vmulq_n_f32(v_f32_blockB, vgetq_lane_f32(rep_scales, 2));
209 4080 const int32x4_t v_i32_blockB = vcvtnq_s32_f32(v_f32_sblockB);
210
211 4080 const float32x4_t v_f32_blockC = vld1q_f32(src_ptr + k_idx + 12 + 2 * read_mem_increment);
212 4080 const float32x4_t v_f32_sblockC = vmulq_n_f32(v_f32_blockC, vgetq_lane_f32(rep_scales, 2));
213 4080 const int32x4_t v_i32_blockC = vcvtnq_s32_f32(v_f32_sblockC);
214
215 8160 const int16x8_t v_full_i16_block6 =
216 4080 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockB), vreinterpretq_s16_s32(v_i32_blockC));
217
218 // Row 4 blocks
219 4080 const float32x4_t v_f32_blockD = vld1q_f32(src_ptr + k_idx + 3 * read_mem_increment);
220 4080 const float32x4_t v_f32_sblockD = vmulq_n_f32(v_f32_blockD, vgetq_lane_f32(rep_scales, 3));
221 4080 const int32x4_t v_i32_blockD = vcvtnq_s32_f32(v_f32_sblockD);
222
223 4080 const float32x4_t v_f32_blockE = vld1q_f32(src_ptr + k_idx + 4 + 3 * read_mem_increment);
224 4080 const float32x4_t v_f32_sblockE = vmulq_n_f32(v_f32_blockE, vgetq_lane_f32(rep_scales, 3));
225 4080 const int32x4_t v_i32_blockE = vcvtnq_s32_f32(v_f32_sblockE);
226
227 8160 const int16x8_t v_full_i16_block7 =
228 4080 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockD), vreinterpretq_s16_s32(v_i32_blockE));
229
230 4080 const float32x4_t v_f32_blockF = vld1q_f32(src_ptr + k_idx + 8 + 3 * read_mem_increment);
231 4080 const float32x4_t v_f32_sblockF = vmulq_n_f32(v_f32_blockF, vgetq_lane_f32(rep_scales, 3));
232 4080 const int32x4_t v_i32_blockF = vcvtnq_s32_f32(v_f32_sblockF);
233
234 4080 const float32x4_t v_f32_block0 = vld1q_f32(src_ptr + k_idx + 12 + 3 * read_mem_increment);
235 4080 const float32x4_t v_f32_sblock0 = vmulq_n_f32(v_f32_block0, vgetq_lane_f32(rep_scales, 3));
236 4080 const int32x4_t v_i32_block0 = vcvtnq_s32_f32(v_f32_sblock0);
237
238 8160 const int16x8_t v_full_i16_block8 =
239 4080 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockF), vreinterpretq_s16_s32(v_i32_block0));
240
241 8160 const int8x16_t v_i8_block1_3 =
242 4080 vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block3));
243 4080 vst1q_s8(dst_ptr, v_i8_block1_3);
244 4080 dst_ptr += write_mem_increment;
245
246 8160 const int8x16_t v_i8_block5_7 =
247 4080 vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block5), vreinterpretq_s8_s16(v_full_i16_block7));
248 4080 vst1q_s8(dst_ptr, v_i8_block5_7);
249 4080 dst_ptr += write_mem_increment;
250
251 8160 const int8x16_t v_i8_block2_4 =
252 4080 vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block2), vreinterpretq_s8_s16(v_full_i16_block4));
253 4080 vst1q_s8(dst_ptr, v_i8_block2_4);
254 4080 dst_ptr += write_mem_increment;
255
256 8160 const int8x16_t v_i8_block6_8 =
257 4080 vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block6), vreinterpretq_s8_s16(v_full_i16_block8));
258 4080 vst1q_s8(dst_ptr, v_i8_block6_8);
259 4080 dst_ptr += write_mem_increment;
260 4080 }
261 2040 src_ptr += local_bl;
262 2040 }
263 1032 lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
264 1032 }
265 112 }
266
2/2
✓ Branch 0 taken 48 times.
✓ Branch 1 taken 136 times.
184 if (num_rows % 4 != 0) {
267
2/2
✓ Branch 0 taken 152 times.
✓ Branch 1 taken 136 times.
288 for (; row_idx < num_rows; ++row_idx) {
268 152 const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
269
270
2/2
✓ Branch 0 taken 208 times.
✓ Branch 1 taken 152 times.
360 for (size_t b = 0; b < num_blocks_per_row; ++b) {
271 208 float abs_max = 0.0F;
272
273 208 const size_t dst_x = ((row_idx + m_idx_start) % local_mr);
274 208 int8_t* dst_ptr = (int8_t*)lhs_packed + (b * local_mr) * num_bytes_per_block;
275
276 208 float32x4_t v_f32_abs_values;
277 208 float32x4_t v_f32_maxvals;
278 208 float32x4_t v_currentmax = vdupq_n_f32(0);
279
280
2/2
✓ Branch 0 taken 1664 times.
✓ Branch 1 taken 208 times.
1872 for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) {
281 1664 v_f32_maxvals = vld1q_f32(src_ptr + idx_v);
282 1664 v_f32_abs_values = vabsq_f32(v_f32_maxvals);
283 1664 v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax);
284 1664 }
285 208 abs_max = vmaxvq_f32(v_currentmax);
286
287 // Calculate scale and reciprocal
288 208 const float scale = abs_max / ((1 << 7) - 1);
289
1/2
✓ Branch 0 taken 208 times.
✗ Branch 1 not taken.
208 const float rep_scale = scale ? 1.0F / scale : 0.0F;
290
291 208 *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale);
292 208 dst_ptr += local_mr * kai_num_bytes_multiplier;
293
294 208 dst_ptr += dst_x * k_block_len * sizeof(int8_t);
295
296 // Quantize and pack the block
297
2/2
✓ Branch 0 taken 416 times.
✓ Branch 1 taken 208 times.
624 for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) {
298 416 const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx);
299 416 const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, rep_scale);
300 416 const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1);
301
302 416 const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4);
303 416 const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, rep_scale);
304 416 const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2);
305
306 832 const int16x8_t v_full_i16_block1 =
307 416 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2));
308
309 416 const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8);
310 416 const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, rep_scale);
311 416 const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3);
312
313 416 const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12);
314 416 const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, rep_scale);
315 416 const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4);
316
317 832 const int16x8_t v_full_i16_block2 =
318 416 vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4));
319
320 832 const int8x16_t v_full_i8_block =
321 416 vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block2));
322
323 416 vst1_s8(dst_ptr, vget_low_s8(v_full_i8_block));
324 416 dst_ptr += 8 * sizeof(int8_t);
325 416 dst_ptr += (local_mr - 1) * k_block_len * sizeof(int8_t);
326
327 416 vst1_s8(dst_ptr, vget_high_s8(v_full_i8_block));
328 416 dst_ptr += 8 * sizeof(int8_t);
329 416 dst_ptr += (local_mr - 1) * k_block_len * sizeof(int8_t);
330 416 }
331 208 src_ptr += local_bl;
332 208 }
333 // Move to the next row if we have interleaved all Mr rows
334
1/2
✓ Branch 0 taken 152 times.
✗ Branch 1 not taken.
152 if ((((row_idx + 1) + m_idx_start) % local_mr) == 0) {
335 lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
336 }
337 152 }
338 136 }
339 184 }
340