KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 97.5% 318 5 331
Functions: 85.7% 6 0 7
Branches: 75.6% 65 8 94

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || \
8 !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
9 #error This file must be compiled for AArch64, FEAT_FP16.
10 #else // Architectural features check.
11
12 #include "kai_lhs_quant_pack_qai8dxp_f16_neon.h"
13
14 #include <arm_fp16.h>
15 #include <arm_neon.h>
16 #include <float.h>
17 #include <math.h>
18 #include <stddef.h>
19 #include <stdint.h>
20
21 #include "kai/kai_common.h"
22 #define FLT16_MAX 65504.0
23 #define FLT16_MIN (-65504.0F)
24
25 static const size_t kai_num_bytes_per_multiplier = sizeof(float);
26 static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
27
28 4480 inline static size_t kai_k_roundedup(size_t k) {
29 // Round up k to be a multiple of 32.
30 4480 size_t kai_k_multiple_of = 32;
31 8960 return kai_roundup(k, kai_k_multiple_of);
32 4480 }
33
34 3360 inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) {
35 3360 KAI_UNUSED(kr);
36 3360 KAI_UNUSED(sr);
37
38 3360 const size_t k_internal = kai_k_roundedup(k);
39
40 KAI_ASSERT((k_internal % 2) == 0);
41
42 6720 return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
43 3360 }
44
45 size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f16_neon(size_t mr) {
46 KAI_UNUSED(mr);
47 return 1;
48 }
49
50 1120 size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(size_t m_idx, size_t lhs_stride) {
51 1120 return m_idx * lhs_stride;
52 }
53
54 1120 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon(
55 size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
56 // It always points to the beginning of the row
57 1120 return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr);
58 }
59
60 1120 size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
61 1120 const size_t num_rows = kai_roundup(m, mr) / mr;
62
63 2240 return num_rows * kai_lhs_packed_stride(k, mr, kr, sr);
64 1120 }
65
66 1120 void kai_run_lhs_quant_pack_qai8dxp_f16_neon(
67 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs,
68 size_t lhs_stride, void* restrict lhs_packed) {
69 KAI_ASSERT((kr % sr) == 0);
70 KAI_ASSUME((kr / sr == 8) || (kr / sr == 4));
71
72
1/2
✓ Branch 0 taken 1120 times.
✗ Branch 1 not taken.
1120 if (m == 0) {
73 return;
74 }
75
76 1120 const size_t num_rows = m;
77
78 1120 float16_t const* src_ptr = (float16_t const*)lhs;
79
80 1120 const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr);
81 1120 const size_t k_internal = kai_k_roundedup(k);
82 1120 const int32_t k_block_len = (int32_t)(kr / sr);
83
84 1120 const int32_t num_blocks_k = (int32_t)(k / k_block_len);
85 1120 const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len);
86 1120 const size_t lhs_row_length = lhs_stride / sizeof(float16_t);
87
88 1120 const float16x8_t vmax = vdupq_n_f16((float16_t)FLT16_MIN);
89 1120 const float16x8_t vmin = vdupq_n_f16((float16_t)FLT16_MAX);
90
91 // As we load 8-element vectors, limit vectorized loop to avoid reading out-of-bounds
92 1120 const int32_t blocks_lim_k = num_blocks_k - (8 / k_block_len);
93
94 1120 size_t row_idx = 0;
95
96 // Improved performance with 4x loop unrolling where packing parameters allow
97
2/2
✓ Branch 0 taken 232 times.
✓ Branch 1 taken 888 times.
1120 if (mr == 4) {
98
2/2
✓ Branch 0 taken 2496 times.
✓ Branch 1 taken 888 times.
3384 for (; row_idx + 3 < m; row_idx += 4) {
99 // Find min/max for each channel
100 2496 int32_t k_idx = 0;
101 2496 float16x8_t vmax0 = vmax;
102 2496 float16x8_t vmin0 = vmin;
103 2496 float16x8_t vmax1 = vmax;
104 2496 float16x8_t vmin1 = vmin;
105 2496 float16x8_t vmax2 = vmax;
106 2496 float16x8_t vmin2 = vmin;
107 2496 float16x8_t vmax3 = vmax;
108 2496 float16x8_t vmin3 = vmin;
109
110
2/2
✓ Branch 0 taken 15376 times.
✓ Branch 1 taken 2496 times.
17872 for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
111 15376 const float16x8_t src0 = vld1q_f16(src_ptr + k_idx);
112 15376 const float16x8_t src1 = vld1q_f16(src_ptr + k_idx + lhs_row_length);
113 15376 const float16x8_t src2 = vld1q_f16(src_ptr + k_idx + (2 * lhs_row_length));
114 15376 const float16x8_t src3 = vld1q_f16(src_ptr + k_idx + (3 * lhs_row_length));
115
116 15376 vmax0 = vmaxq_f16(src0, vmax0);
117 15376 vmax1 = vmaxq_f16(src1, vmax1);
118 15376 vmax2 = vmaxq_f16(src2, vmax2);
119 15376 vmax3 = vmaxq_f16(src3, vmax3);
120 15376 vmin0 = vminq_f16(src0, vmin0);
121 15376 vmin1 = vminq_f16(src1, vmin1);
122 15376 vmin2 = vminq_f16(src2, vmin2);
123 15376 vmin3 = vminq_f16(src3, vmin3);
124 15376 }
125
126 2496 float16_t max0 = vmaxvq_f16(vmax0);
127 2496 float16_t min0 = vminvq_f16(vmin0);
128 2496 float16_t max1 = vmaxvq_f16(vmax1);
129 2496 float16_t min1 = vminvq_f16(vmin1);
130 2496 float16_t max2 = vmaxvq_f16(vmax2);
131 2496 float16_t min2 = vminvq_f16(vmin2);
132 2496 float16_t max3 = vmaxvq_f16(vmax3);
133 2496 float16_t min3 = vminvq_f16(vmin3);
134 // Process leftover elements with a scalar loop.
135
2/2
✓ Branch 0 taken 2296 times.
✓ Branch 1 taken 2496 times.
4792 for (; k_idx < (int32_t)k; ++k_idx) {
136 2296 const float16_t src0 = *(src_ptr + (size_t)k_idx);
137 2296 max0 = vmaxh_f16(src0, max0);
138 2296 min0 = vminh_f16(src0, min0);
139 2296 const float16_t src1 = *(src_ptr + (size_t)k_idx + lhs_row_length);
140 2296 max1 = vmaxh_f16(src1, max1);
141 2296 min1 = vminh_f16(src1, min1);
142 2296 const float16_t src2 = *(src_ptr + (size_t)k_idx + (2 * lhs_row_length));
143 2296 max2 = vmaxh_f16(src2, max2);
144 2296 min2 = vminh_f16(src2, min2);
145 2296 const float16_t src3 = *(src_ptr + (size_t)k_idx + (3 * lhs_row_length));
146 2296 max3 = vmaxh_f16(src3, max3);
147 2296 min3 = vminh_f16(src3, min3);
148 2296 }
149
150 // Maximum/minimum int8 values
151 2496 const float qmin = (float)INT8_MIN;
152 2496 const float qmax = (float)INT8_MAX;
153
154 2496 const float rmin0 = fminf(0.0F, min0);
155 2496 const float rmax0 = fmaxf(0.0F, max0);
156
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2496 times.
2496 const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
157 2496 const float rmin1 = fminf(0.0F, min1);
158 2496 const float rmax1 = fmaxf(0.0F, max1);
159
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2496 times.
2496 const float scale1 = rmin1 == rmax1 ? 1.F : (qmax - qmin) / (rmax1 - rmin1);
160 2496 const float rmin2 = fminf(0.0F, min2);
161 2496 const float rmax2 = fmaxf(0.0F, max2);
162
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2496 times.
2496 const float scale2 = rmin2 == rmax2 ? 1.F : (qmax - qmin) / (rmax2 - rmin2);
163 2496 const float rmin3 = fminf(0.0F, min3);
164 2496 const float rmax3 = fmaxf(0.0F, max3);
165
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2496 times.
2496 const float scale3 = rmin3 == rmax3 ? 1.F : (qmax - qmin) / (rmax3 - rmin3);
166
167 // Reciprocal to quantize
168
1/2
✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
2496 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
169
1/2
✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
2496 const float recip_scale1 = scale1 ? 1.0F / scale1 : 0.0F;
170
1/2
✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
2496 const float recip_scale2 = scale2 ? 1.0F / scale2 : 0.0F;
171
1/2
✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
2496 const float recip_scale3 = scale3 ? 1.0F / scale3 : 0.0F;
172
173 2496 const float descaled_min0 = rmin0 * scale0;
174 2496 const float descaled_max0 = rmax0 * scale0;
175 2496 const float descaled_min1 = rmin1 * scale1;
176 2496 const float descaled_max1 = rmax1 * scale1;
177 2496 const float descaled_min2 = rmin2 * scale2;
178 2496 const float descaled_max2 = rmax2 * scale2;
179 2496 const float descaled_min3 = rmin3 * scale3;
180 2496 const float descaled_max3 = rmax3 * scale3;
181
182 2496 const float zero_point_from_min_error0 = qmin + descaled_min0;
183 2496 const float zero_point_from_max_error0 = qmax + descaled_max0;
184 2496 const float zero_point_from_min_error1 = qmin + descaled_min1;
185 2496 const float zero_point_from_max_error1 = qmax + descaled_max1;
186 2496 const float zero_point_from_min_error2 = qmin + descaled_min2;
187 2496 const float zero_point_from_max_error2 = qmax + descaled_max2;
188 2496 const float zero_point_from_min_error3 = qmin + descaled_min3;
189 2496 const float zero_point_from_max_error3 = qmax + descaled_max3;
190
191
1/2
✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
2496 float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0
192 : qmax - descaled_max0;
193
1/2
✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
2496 float zero_point1 = (zero_point_from_min_error1 + zero_point_from_max_error1 > 0) ? qmin - descaled_min1
194 : qmax - descaled_max1;
195
1/2
✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
2496 float zero_point2 = (zero_point_from_min_error2 + zero_point_from_max_error2 > 0) ? qmin - descaled_min2
196 : qmax - descaled_max2;
197
1/2
✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
2496 float zero_point3 = (zero_point_from_min_error3 + zero_point_from_max_error3 > 0) ? qmin - descaled_min3
198 : qmax - descaled_max3;
199
200 2496 zero_point0 = fmaxf(zero_point0, qmin);
201 2496 zero_point0 = fminf(zero_point0, qmax);
202 2496 zero_point1 = fmaxf(zero_point1, qmin);
203 2496 zero_point1 = fminf(zero_point1, qmax);
204 2496 zero_point2 = fmaxf(zero_point2, qmin);
205 2496 zero_point2 = fminf(zero_point2, qmax);
206 2496 zero_point3 = fmaxf(zero_point3, qmin);
207 2496 zero_point3 = fminf(zero_point3, qmax);
208
209 // Round to nearest integer
210 2496 const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
211 2496 const int32_t nudged_zero_point1 = (int32_t)rintf(zero_point1);
212 2496 const int32_t nudged_zero_point2 = (int32_t)rintf(zero_point2);
213 2496 const int32_t nudged_zero_point3 = (int32_t)rintf(zero_point3);
214
215 2496 const size_t dst_x = ((row_idx + m_idx_start) % mr);
216
217 2496 uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len);
218
219 // Quantize the channels
220 2496 int32_t block_idx = 0;
221 2496 const int32_t block_incr = 8 / k_block_len;
222
223
2/2
✓ Branch 0 taken 15376 times.
✓ Branch 1 taken 2496 times.
17872 for (; block_idx <= blocks_lim_k; block_idx += block_incr) {
224 // Clamp at the last valid k-index
225 15376 const int32_t k_idx_start = block_idx * k_block_len;
226
227 15376 const float16x8_t src0 = vld1q_f16(src_ptr + k_idx_start);
228 15376 const float16x8_t src1 = vld1q_f16(src_ptr + k_idx_start + lhs_row_length);
229 15376 const float16x8_t src2 = vld1q_f16(src_ptr + k_idx_start + (2 * lhs_row_length));
230 15376 const float16x8_t src3 = vld1q_f16(src_ptr + k_idx_start + (3 * lhs_row_length));
231
232 // Scale the values.
233 15376 const int32x4_t v0_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src0)), scale0));
234 15376 const int32x4_t v0_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src0), scale0));
235 15376 const int32x4_t v1_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src1)), scale1));
236 15376 const int32x4_t v1_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src1), scale1));
237 15376 const int32x4_t v2_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src2)), scale2));
238 15376 const int32x4_t v2_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src2), scale2));
239 15376 const int32x4_t v3_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src3)), scale3));
240 15376 const int32x4_t v3_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src3), scale3));
241
242 15376 const int16x4_t v0_0_s16 = vqmovn_s32(v0_0_s32);
243 15376 const int16x4_t v0_1_s16 = vqmovn_s32(v0_1_s32);
244 15376 const int16x4_t v1_0_s16 = vqmovn_s32(v1_0_s32);
245 15376 const int16x4_t v1_1_s16 = vqmovn_s32(v1_1_s32);
246 15376 const int16x4_t v2_0_s16 = vqmovn_s32(v2_0_s32);
247 15376 const int16x4_t v2_1_s16 = vqmovn_s32(v2_1_s32);
248 15376 const int16x4_t v3_0_s16 = vqmovn_s32(v3_0_s32);
249 15376 const int16x4_t v3_1_s16 = vqmovn_s32(v3_1_s32);
250
251 15376 int16x8_t v0_s16;
252 15376 int16x8_t v1_s16;
253 15376 int16x8_t v2_s16;
254 15376 int16x8_t v3_s16;
255
2/2
✓ Branch 0 taken 7688 times.
✓ Branch 1 taken 7688 times.
15376 if (k_block_len == 8) {
256 7688 v0_s16 = vcombine_s16(v0_0_s16, v0_1_s16);
257 7688 v1_s16 = vcombine_s16(v1_0_s16, v1_1_s16);
258 7688 v2_s16 = vcombine_s16(v2_0_s16, v2_1_s16);
259 7688 v3_s16 = vcombine_s16(v3_0_s16, v3_1_s16);
260 7688 } else { // k_block_len == 4
261 7688 v0_s16 = vcombine_s16(v0_0_s16, v1_0_s16);
262 7688 v1_s16 = vcombine_s16(v2_0_s16, v3_0_s16);
263 7688 v2_s16 = vcombine_s16(v0_1_s16, v1_1_s16);
264 7688 v3_s16 = vcombine_s16(v2_1_s16, v3_1_s16);
265 }
266
267 // Add zero points.
268 15376 const int16x8_t vnzp0 = vdupq_n_s16((int16_t)nudged_zero_point0);
269 15376 const int16x8_t vnzp1 = vdupq_n_s16((int16_t)nudged_zero_point1);
270 15376 const int16x8_t vnzp2 = vdupq_n_s16((int16_t)nudged_zero_point2);
271 15376 const int16x8_t vnzp3 = vdupq_n_s16((int16_t)nudged_zero_point3);
272
273 15376 v0_s16 = vaddq_s16(v0_s16, vnzp0);
274 15376 v0_s16 = vmaxq_s16(v0_s16, vdupq_n_s16(INT8_MIN));
275 15376 v0_s16 = vminq_s16(v0_s16, vdupq_n_s16(INT8_MAX));
276 15376 v1_s16 = vaddq_s16(v1_s16, vnzp1);
277 15376 v1_s16 = vmaxq_s16(v1_s16, vdupq_n_s16(INT8_MIN));
278 15376 v1_s16 = vminq_s16(v1_s16, vdupq_n_s16(INT8_MAX));
279 15376 v2_s16 = vaddq_s16(v2_s16, vnzp2);
280 15376 v2_s16 = vmaxq_s16(v2_s16, vdupq_n_s16(INT8_MIN));
281 15376 v2_s16 = vminq_s16(v2_s16, vdupq_n_s16(INT8_MAX));
282 15376 v3_s16 = vaddq_s16(v3_s16, vnzp3);
283 15376 v3_s16 = vmaxq_s16(v3_s16, vdupq_n_s16(INT8_MIN));
284 15376 v3_s16 = vminq_s16(v3_s16, vdupq_n_s16(INT8_MAX));
285
286 15376 int8x8_t v0_s8 = vqmovn_s16(v0_s16);
287 15376 int8x8_t v1_s8 = vqmovn_s16(v1_s16);
288 15376 int8x8_t v2_s8 = vqmovn_s16(v2_s16);
289 15376 int8x8_t v3_s8 = vqmovn_s16(v3_s16);
290
291 15376 vst1_s8((int8_t*)(dst_ptr), v0_s8);
292 15376 vst1_s8((int8_t*)(dst_ptr + sizeof(int8x8_t)), v1_s8);
293 15376 vst1_s8((int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)), v2_s8);
294 15376 vst1_s8((int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)), v3_s8);
295 15376 dst_ptr += block_incr * mr * k_block_len * sizeof(int8_t);
296 15376 }
297
298
2/2
✓ Branch 0 taken 3192 times.
✓ Branch 1 taken 2496 times.
5688 for (; block_idx < num_blocks_k_internal; ++block_idx) {
299 // left over k
300
2/2
✓ Branch 0 taken 17024 times.
✓ Branch 1 taken 3192 times.
20216 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
301 // Clamp at the last valid k-index.
302
2/2
✓ Branch 0 taken 1680 times.
✓ Branch 1 taken 15344 times.
17024 const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
303
304 17024 const float src0 = (float)(*(src_ptr + k_idx_start));
305 17024 const float src1 = (float)(*(src_ptr + k_idx_start + lhs_row_length));
306 17024 const float src2 = (float)(*(src_ptr + k_idx_start + (2 * lhs_row_length)));
307 17024 const float src3 = (float)(*(src_ptr + k_idx_start + (3 * lhs_row_length)));
308
309 // Scale the value.
310 17024 int32_t d0_s32 = (int32_t)(roundf(src0 * scale0));
311 17024 int32_t d1_s32 = (int32_t)(roundf(src1 * scale1));
312 17024 int32_t d2_s32 = (int32_t)(roundf(src2 * scale2));
313 17024 int32_t d3_s32 = (int32_t)(roundf(src3 * scale3));
314
315 17024 d0_s32 = d0_s32 + nudged_zero_point0;
316
1/2
✓ Branch 0 taken 17024 times.
✗ Branch 1 not taken.
17024 d0_s32 = KAI_MAX(d0_s32, INT8_MIN);
317
2/2
✓ Branch 0 taken 16912 times.
✓ Branch 1 taken 112 times.
17024 d0_s32 = KAI_MIN(d0_s32, INT8_MAX);
318
319 17024 d1_s32 = d1_s32 + nudged_zero_point1;
320
1/2
✓ Branch 0 taken 17024 times.
✗ Branch 1 not taken.
17024 d1_s32 = KAI_MAX(d1_s32, INT8_MIN);
321
2/2
✓ Branch 0 taken 16912 times.
✓ Branch 1 taken 112 times.
17024 d1_s32 = KAI_MIN(d1_s32, INT8_MAX);
322
323 17024 d2_s32 = d2_s32 + nudged_zero_point2;
324
1/2
✓ Branch 0 taken 17024 times.
✗ Branch 1 not taken.
17024 d2_s32 = KAI_MAX(d2_s32, INT8_MIN);
325
2/2
✓ Branch 0 taken 16912 times.
✓ Branch 1 taken 112 times.
17024 d2_s32 = KAI_MIN(d2_s32, INT8_MAX);
326
327 17024 d3_s32 = d3_s32 + nudged_zero_point3;
328
1/2
✓ Branch 0 taken 17024 times.
✗ Branch 1 not taken.
17024 d3_s32 = KAI_MAX(d3_s32, INT8_MIN);
329
2/2
✓ Branch 0 taken 16912 times.
✓ Branch 1 taken 112 times.
17024 d3_s32 = KAI_MIN(d3_s32, INT8_MAX);
330
331 17024 *(int8_t*)dst_ptr = (int8_t)d0_s32;
332 17024 *(int8_t*)(dst_ptr + k_block_len * sizeof(int8_t)) = (int8_t)d1_s32;
333 17024 *(int8_t*)(dst_ptr + 2 * (k_block_len * sizeof(int8_t))) = (int8_t)d2_s32;
334 17024 *(int8_t*)(dst_ptr + 3 * (k_block_len * sizeof(int8_t))) = (int8_t)d3_s32;
335 17024 dst_ptr += sizeof(int8_t);
336 17024 }
337 3192 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
338 3192 }
339
340 2496 uint8_t* dst_base = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));
341
342 2496 dst_ptr = dst_base + dst_x * kai_num_bytes_per_offset;
343
344 // LHS offset at the beginning of the row.
345 2496 *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
346 2496 *((int32_t*)(dst_ptr + kai_num_bytes_per_offset)) = -nudged_zero_point1;
347 2496 *((int32_t*)(dst_ptr + 2 * kai_num_bytes_per_offset)) = -nudged_zero_point2;
348 2496 *((int32_t*)(dst_ptr + 3 * kai_num_bytes_per_offset)) = -nudged_zero_point3;
349
350 // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier.
351 KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
352
353 2496 dst_ptr += mr * kai_num_bytes_per_offset;
354
355 // Store the scale quantization params.
356 2496 *((float*)(dst_ptr)) = recip_scale0;
357 2496 *((float*)(dst_ptr + kai_num_bytes_per_multiplier)) = recip_scale1;
358 2496 *((float*)(dst_ptr + 2 * kai_num_bytes_per_multiplier)) = recip_scale2;
359 2496 *((float*)(dst_ptr + 3 * kai_num_bytes_per_multiplier)) = recip_scale3;
360
361 // Update src_ptr. Note: now lhs contains fp16 values (2 bytes each).
362 2496 src_ptr += (4 * lhs_row_length);
363
364 // Move to the next row as we have interleaved all Mr rows.
365 2496 lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
366 2496 }
367 888 }
368
369
2/2
✓ Branch 0 taken 1168 times.
✓ Branch 1 taken 1120 times.
2288 for (; row_idx < num_rows; ++row_idx) {
370 // Find min/max for each channel
371 1168 int32_t k_idx = 0;
372 1168 float16x8_t vmax0 = vmax;
373 1168 float16x8_t vmin0 = vmin;
374
375
2/2
✓ Branch 0 taken 5936 times.
✓ Branch 1 taken 1168 times.
7104 for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
376 5936 const float16x8_t src0_0 = vld1q_f16(src_ptr + (size_t)k_idx);
377 5936 vmax0 = vmaxq_f16(vmax0, src0_0);
378 5936 vmin0 = vminq_f16(vmin0, src0_0);
379 5936 }
380 // Get the max/min
381 1168 float16_t max0 = vmaxvq_f16(vmax0);
382 1168 float16_t min0 = vminvq_f16(vmin0);
383
384
2/2
✓ Branch 0 taken 2168 times.
✓ Branch 1 taken 1168 times.
3336 for (; k_idx < (int32_t)k; ++k_idx) {
385 2168 const float16_t src0 = *(src_ptr + (size_t)k_idx);
386 2168 max0 = vmaxh_f16(src0, max0);
387 2168 min0 = vminh_f16(src0, min0);
388 2168 }
389
390 // Maximum/minimum int8 values
391 1168 const float qmin = (float)INT8_MIN;
392 1168 const float qmax = (float)INT8_MAX;
393
394 1168 const float rmin0 = fminf(0.0F, min0);
395 1168 const float rmax0 = fmaxf(0.0F, max0);
396
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1168 times.
1168 const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
397
398 // Reciprocal to quantize
399
1/2
✓ Branch 0 taken 1168 times.
✗ Branch 1 not taken.
1168 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
400
401 1168 const float descaled_min0 = rmin0 * scale0;
402 1168 const float descaled_max0 = rmax0 * scale0;
403
404 1168 const float zero_point_from_min_error0 = qmin + descaled_min0;
405 1168 const float zero_point_from_max_error0 = qmax + descaled_max0;
406
407 2336 float zero_point0 =
408
1/2
✓ Branch 0 taken 1168 times.
✗ Branch 1 not taken.
1168 zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0;
409
410 1168 zero_point0 = fmaxf(zero_point0, qmin);
411 1168 zero_point0 = fminf(zero_point0, qmax);
412
413 // Round to nearest integer
414 1168 const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
415
416 1168 const size_t dst_x = ((row_idx + m_idx_start) % mr);
417
418 1168 uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t));
419
420 // Quantize the channels
421 1168 int32_t block_idx = 0;
422
423
2/2
✓ Branch 0 taken 8480 times.
✓ Branch 1 taken 1168 times.
9648 for (; block_idx <= blocks_lim_k; ++block_idx) {
424 8480 const int32_t k_idx_start = block_idx * k_block_len;
425
426 8480 const float16x8_t src_0 = vld1q_f16(src_ptr + k_idx_start);
427
428 // Scale the values
429 8480 const float32x4_t v0_f32 = vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src_0)), scale0);
430 8480 const float32x4_t v1_f32 = vmulq_n_f32(vcvt_high_f32_f16(src_0), scale0);
431 8480 const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32);
432 8480 const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32);
433
434 8480 const int16x4_t v0_s16 = vqmovn_s32(v0_s32);
435 8480 const int16x4_t v1_s16 = vqmovn_s32(v1_s32);
436 8480 int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);
437
438 // Add zero points
439 8480 int16_t nzp_s16 = (int16_t)nudged_zero_point0;
440 8480 int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16);
441 8480 v_s16 = vaddq_s16(v_s16, vnzp_s16);
442 8480 v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
443 8480 v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));
444
445 8480 int8x8_t v_s8 = vqmovn_s16(v_s16);
446 8480 vst1_s8((int8_t*)(dst_ptr), v_s8);
447 8480 dst_ptr += mr * k_block_len * sizeof(int8_t);
448 8480 }
449
450
2/2
✓ Branch 0 taken 2992 times.
✓ Branch 1 taken 1168 times.
4160 for (; block_idx < num_blocks_k_internal; ++block_idx) {
451 // left over k
452
2/2
✓ Branch 0 taken 15392 times.
✓ Branch 1 taken 2992 times.
18384 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
453 // Clamp at the last valid k-index
454
2/2
✓ Branch 0 taken 2988 times.
✓ Branch 1 taken 12404 times.
15392 const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
455
456 15392 const float src0 = (float)(*(src_ptr + k_idx_start));
457
458 // Scale the values
459 15392 int32_t d0_s32 = (int32_t)(roundf(src0 * scale0));
460
461 15392 d0_s32 = d0_s32 + nudged_zero_point0;
462
1/2
✓ Branch 0 taken 15392 times.
✗ Branch 1 not taken.
15392 d0_s32 = KAI_MAX(d0_s32, INT8_MIN);
463
2/2
✓ Branch 0 taken 15308 times.
✓ Branch 1 taken 84 times.
15392 d0_s32 = KAI_MIN(d0_s32, INT8_MAX);
464
465 15392 *((int8_t*)(dst_ptr)) = (int8_t)d0_s32;
466 15392 dst_ptr += sizeof(int8_t);
467 15392 }
468 2992 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
469 2992 }
470
471 1168 dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));
472
473 1168 dst_ptr += dst_x * kai_num_bytes_per_offset;
474
475 // LHS offset at the beginning of the row
476 1168 *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
477
478 // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier
479 KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
480
481 1168 dst_ptr += mr * kai_num_bytes_per_offset;
482
483 // Store the scale quantization params
484 1168 *((float*)(dst_ptr)) = recip_scale0;
485
486 1168 src_ptr += lhs_row_length;
487
488 // Move to the next row if we have interleaved all Mr rows
489
2/2
✓ Branch 0 taken 936 times.
✓ Branch 1 taken 232 times.
1168 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
490 232 lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
491 232 }
492 1168 }
493 1120 }
494
495 #endif // Architectural features check.
496