KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 97.8% 318 / 5 / 330
Functions: 85.7% 6 / 0 / 7
Branches: 80.2% 69 / 8 / 94

kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || \
8 !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
9 #error This file must be compiled for AArch64, FEAT_FP16.
10 #else // Architectural features check.
11
12 #include "kai_lhs_quant_pack_qai8dxp_f16_neon.h"
13
14 #include <arm_fp16.h>
15 #include <arm_neon.h>
16 #include <float.h>
17 #include <math.h>
18 #include <stddef.h>
19 #include <stdint.h>
20
21 #include "kai/kai_common.h"
22 #define FLT16_MAX 65504.0
23 #define FLT16_MIN (-65504.0F)
24
25 static const size_t kai_num_bytes_per_multiplier = sizeof(float);
26 static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
27
28 29568 inline static size_t kai_k_roundedup(size_t k) {
29 // Round up k to be a multiple of 32.
30 29568 size_t kai_k_multiple_of = 32;
31 59136 return kai_roundup(k, kai_k_multiple_of);
32 29568 }
33
34 22176 inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) {
35 22176 KAI_UNUSED(kr);
36 22176 KAI_UNUSED(sr);
37
38 22176 const size_t k_internal = kai_k_roundedup(k);
39
40 KAI_ASSERT((k_internal % 2) == 0);
41
42 44352 return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
43 22176 }
44
45 size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f16_neon(size_t mr) {
46 return mr;
47 }
48
49 7392 size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(size_t m_idx, size_t lhs_stride) {
50 7392 return m_idx * lhs_stride;
51 }
52
53 7392 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon(
54 size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
55 // It always points to the beginning of the row
56 7392 return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr);
57 }
58
59 7392 size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
60 7392 const size_t num_rows = kai_roundup(m, mr) / mr;
61
62 14784 return num_rows * kai_lhs_packed_stride(k, mr, kr, sr);
63 7392 }
64
65 7392 void kai_run_lhs_quant_pack_qai8dxp_f16_neon(
66 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs,
67 size_t lhs_stride, void* restrict lhs_packed) {
68 KAI_ASSERT((kr % sr) == 0);
69 KAI_ASSUME((kr / sr == 8) || (kr / sr == 4));
70
71
1/2
✓ Branch 0 taken 7392 times.
✗ Branch 1 not taken.
7392 if (m == 0) {
72 return;
73 }
74
75 7392 const size_t num_rows = m;
76
77 7392 float16_t const* src_ptr = (float16_t const*)lhs;
78
79 7392 const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr);
80 7392 const size_t k_internal = kai_k_roundedup(k);
81 7392 const int32_t k_block_len = (int32_t)(kr / sr);
82
83 7392 const int32_t num_blocks_k = (int32_t)(k / k_block_len);
84 7392 const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len);
85 7392 const size_t lhs_row_length = lhs_stride / sizeof(float16_t);
86
87 7392 const float16x8_t vmax = vdupq_n_f16((float16_t)FLT16_MIN);
88 7392 const float16x8_t vmin = vdupq_n_f16((float16_t)FLT16_MAX);
89
90 // As we load 8-element vectors, limit vectorized loop to avoid reading out-of-bounds
91 7392 const int32_t blocks_lim_k = num_blocks_k - (8 / k_block_len);
92
93 7392 size_t row_idx = 0;
94
95 // Improved performance with 4x loop unrolling where packing parameters allow
96
2/2
✓ Branch 0 taken 2016 times.
✓ Branch 1 taken 5376 times.
7392 if (mr == 4) {
97
2/2
✓ Branch 0 taken 14976 times.
✓ Branch 1 taken 5376 times.
20352 for (; row_idx + 3 < m; row_idx += 4) {
98 // Find min/max for each channel
99 14976 int32_t k_idx = 0;
100 14976 float16x8_t vmax0 = vmax;
101 14976 float16x8_t vmin0 = vmin;
102 14976 float16x8_t vmax1 = vmax;
103 14976 float16x8_t vmin1 = vmin;
104 14976 float16x8_t vmax2 = vmax;
105 14976 float16x8_t vmin2 = vmin;
106 14976 float16x8_t vmax3 = vmax;
107 14976 float16x8_t vmin3 = vmin;
108
109
2/2
✓ Branch 0 taken 92256 times.
✓ Branch 1 taken 14976 times.
107232 for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
110 92256 const float16x8_t src0 = vld1q_f16(src_ptr + k_idx);
111 92256 const float16x8_t src1 = vld1q_f16(src_ptr + k_idx + lhs_row_length);
112 92256 const float16x8_t src2 = vld1q_f16(src_ptr + k_idx + (2 * lhs_row_length));
113 92256 const float16x8_t src3 = vld1q_f16(src_ptr + k_idx + (3 * lhs_row_length));
114
115 92256 vmax0 = vmaxq_f16(src0, vmax0);
116 92256 vmax1 = vmaxq_f16(src1, vmax1);
117 92256 vmax2 = vmaxq_f16(src2, vmax2);
118 92256 vmax3 = vmaxq_f16(src3, vmax3);
119 92256 vmin0 = vminq_f16(src0, vmin0);
120 92256 vmin1 = vminq_f16(src1, vmin1);
121 92256 vmin2 = vminq_f16(src2, vmin2);
122 92256 vmin3 = vminq_f16(src3, vmin3);
123 92256 }
124
125 14976 float16_t max0 = vmaxvq_f16(vmax0);
126 14976 float16_t min0 = vminvq_f16(vmin0);
127 14976 float16_t max1 = vmaxvq_f16(vmax1);
128 14976 float16_t min1 = vminvq_f16(vmin1);
129 14976 float16_t max2 = vmaxvq_f16(vmax2);
130 14976 float16_t min2 = vminvq_f16(vmin2);
131 14976 float16_t max3 = vmaxvq_f16(vmax3);
132 14976 float16_t min3 = vminvq_f16(vmin3);
133 // Process leftover elements with a scalar loop.
134
2/2
✓ Branch 0 taken 13776 times.
✓ Branch 1 taken 14976 times.
28752 for (; k_idx < (int32_t)k; ++k_idx) {
135 13776 const float16_t src0 = *(src_ptr + (size_t)k_idx);
136 13776 max0 = vmaxh_f16(src0, max0);
137 13776 min0 = vminh_f16(src0, min0);
138 13776 const float16_t src1 = *(src_ptr + (size_t)k_idx + lhs_row_length);
139 13776 max1 = vmaxh_f16(src1, max1);
140 13776 min1 = vminh_f16(src1, min1);
141 13776 const float16_t src2 = *(src_ptr + (size_t)k_idx + (2 * lhs_row_length));
142 13776 max2 = vmaxh_f16(src2, max2);
143 13776 min2 = vminh_f16(src2, min2);
144 13776 const float16_t src3 = *(src_ptr + (size_t)k_idx + (3 * lhs_row_length));
145 13776 max3 = vmaxh_f16(src3, max3);
146 13776 min3 = vminh_f16(src3, min3);
147 13776 }
148
149 // Maximum/minimum int8 values
150 14976 const float qmin = (float)INT8_MIN;
151 14976 const float qmax = (float)INT8_MAX;
152
153 14976 const float rmin0 = fminf(0.0F, min0);
154 14976 const float rmax0 = fmaxf(0.0F, max0);
155
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 14976 times.
14976 const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
156 14976 const float rmin1 = fminf(0.0F, min1);
157 14976 const float rmax1 = fmaxf(0.0F, max1);
158
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 14976 times.
14976 const float scale1 = rmin1 == rmax1 ? 1.F : (qmax - qmin) / (rmax1 - rmin1);
159 14976 const float rmin2 = fminf(0.0F, min2);
160 14976 const float rmax2 = fmaxf(0.0F, max2);
161
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 14976 times.
14976 const float scale2 = rmin2 == rmax2 ? 1.F : (qmax - qmin) / (rmax2 - rmin2);
162 14976 const float rmin3 = fminf(0.0F, min3);
163 14976 const float rmax3 = fmaxf(0.0F, max3);
164
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 14976 times.
14976 const float scale3 = rmin3 == rmax3 ? 1.F : (qmax - qmin) / (rmax3 - rmin3);
165
166 // Reciprocal to quantize
167
1/2
✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
14976 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
168
1/2
✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
14976 const float recip_scale1 = scale1 ? 1.0F / scale1 : 0.0F;
169
1/2
✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
14976 const float recip_scale2 = scale2 ? 1.0F / scale2 : 0.0F;
170
1/2
✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
14976 const float recip_scale3 = scale3 ? 1.0F / scale3 : 0.0F;
171
172 14976 const float descaled_min0 = rmin0 * scale0;
173 14976 const float descaled_max0 = rmax0 * scale0;
174 14976 const float descaled_min1 = rmin1 * scale1;
175 14976 const float descaled_max1 = rmax1 * scale1;
176 14976 const float descaled_min2 = rmin2 * scale2;
177 14976 const float descaled_max2 = rmax2 * scale2;
178 14976 const float descaled_min3 = rmin3 * scale3;
179 14976 const float descaled_max3 = rmax3 * scale3;
180
181 14976 const float zero_point_from_min_error0 = qmin + descaled_min0;
182 14976 const float zero_point_from_max_error0 = qmax + descaled_max0;
183 14976 const float zero_point_from_min_error1 = qmin + descaled_min1;
184 14976 const float zero_point_from_max_error1 = qmax + descaled_max1;
185 14976 const float zero_point_from_min_error2 = qmin + descaled_min2;
186 14976 const float zero_point_from_max_error2 = qmax + descaled_max2;
187 14976 const float zero_point_from_min_error3 = qmin + descaled_min3;
188 14976 const float zero_point_from_max_error3 = qmax + descaled_max3;
189
190
1/2
✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
14976 float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0
191 : qmax - descaled_max0;
192
1/2
✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
14976 float zero_point1 = (zero_point_from_min_error1 + zero_point_from_max_error1 > 0) ? qmin - descaled_min1
193 : qmax - descaled_max1;
194
1/2
✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
14976 float zero_point2 = (zero_point_from_min_error2 + zero_point_from_max_error2 > 0) ? qmin - descaled_min2
195 : qmax - descaled_max2;
196
1/2
✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
14976 float zero_point3 = (zero_point_from_min_error3 + zero_point_from_max_error3 > 0) ? qmin - descaled_min3
197 : qmax - descaled_max3;
198
199 14976 zero_point0 = fmaxf(zero_point0, qmin);
200 14976 zero_point0 = fminf(zero_point0, qmax);
201 14976 zero_point1 = fmaxf(zero_point1, qmin);
202 14976 zero_point1 = fminf(zero_point1, qmax);
203 14976 zero_point2 = fmaxf(zero_point2, qmin);
204 14976 zero_point2 = fminf(zero_point2, qmax);
205 14976 zero_point3 = fmaxf(zero_point3, qmin);
206 14976 zero_point3 = fminf(zero_point3, qmax);
207
208 // Round to nearest integer
209 14976 const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
210 14976 const int32_t nudged_zero_point1 = (int32_t)rintf(zero_point1);
211 14976 const int32_t nudged_zero_point2 = (int32_t)rintf(zero_point2);
212 14976 const int32_t nudged_zero_point3 = (int32_t)rintf(zero_point3);
213
214 14976 const size_t dst_x = ((row_idx + m_idx_start) % mr);
215
216 14976 uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len);
217
218 // Quantize the channels
219 14976 int32_t block_idx = 0;
220 14976 const int32_t block_incr = 8 / k_block_len;
221
222
2/2
✓ Branch 0 taken 92256 times.
✓ Branch 1 taken 14976 times.
107232 for (; block_idx <= blocks_lim_k; block_idx += block_incr) {
223 // Clamp at the last valid k-index
224 92256 const int32_t k_idx_start = block_idx * k_block_len;
225
226 92256 const float16x8_t src0 = vld1q_f16(src_ptr + k_idx_start);
227 92256 const float16x8_t src1 = vld1q_f16(src_ptr + k_idx_start + lhs_row_length);
228 92256 const float16x8_t src2 = vld1q_f16(src_ptr + k_idx_start + (2 * lhs_row_length));
229 92256 const float16x8_t src3 = vld1q_f16(src_ptr + k_idx_start + (3 * lhs_row_length));
230
231 // Scale the values.
232 92256 const int32x4_t v0_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src0)), scale0));
233 92256 const int32x4_t v0_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src0), scale0));
234 92256 const int32x4_t v1_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src1)), scale1));
235 92256 const int32x4_t v1_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src1), scale1));
236 92256 const int32x4_t v2_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src2)), scale2));
237 92256 const int32x4_t v2_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src2), scale2));
238 92256 const int32x4_t v3_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src3)), scale3));
239 92256 const int32x4_t v3_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src3), scale3));
240
241 92256 const int16x4_t v0_0_s16 = vqmovn_s32(v0_0_s32);
242 92256 const int16x4_t v0_1_s16 = vqmovn_s32(v0_1_s32);
243 92256 const int16x4_t v1_0_s16 = vqmovn_s32(v1_0_s32);
244 92256 const int16x4_t v1_1_s16 = vqmovn_s32(v1_1_s32);
245 92256 const int16x4_t v2_0_s16 = vqmovn_s32(v2_0_s32);
246 92256 const int16x4_t v2_1_s16 = vqmovn_s32(v2_1_s32);
247 92256 const int16x4_t v3_0_s16 = vqmovn_s32(v3_0_s32);
248 92256 const int16x4_t v3_1_s16 = vqmovn_s32(v3_1_s32);
249
250 92256 int16x8_t v0_s16;
251 92256 int16x8_t v1_s16;
252 92256 int16x8_t v2_s16;
253 92256 int16x8_t v3_s16;
254
2/2
✓ Branch 0 taken 46128 times.
✓ Branch 1 taken 46128 times.
92256 if (k_block_len == 8) {
255 46128 v0_s16 = vcombine_s16(v0_0_s16, v0_1_s16);
256 46128 v1_s16 = vcombine_s16(v1_0_s16, v1_1_s16);
257 46128 v2_s16 = vcombine_s16(v2_0_s16, v2_1_s16);
258 46128 v3_s16 = vcombine_s16(v3_0_s16, v3_1_s16);
259 46128 } else { // k_block_len == 4
260 46128 v0_s16 = vcombine_s16(v0_0_s16, v1_0_s16);
261 46128 v1_s16 = vcombine_s16(v2_0_s16, v3_0_s16);
262 46128 v2_s16 = vcombine_s16(v0_1_s16, v1_1_s16);
263 46128 v3_s16 = vcombine_s16(v2_1_s16, v3_1_s16);
264 }
265
266 // Add zero points.
267 92256 const int16x8_t vnzp0 = vdupq_n_s16((int16_t)nudged_zero_point0);
268 92256 const int16x8_t vnzp1 = vdupq_n_s16((int16_t)nudged_zero_point1);
269 92256 const int16x8_t vnzp2 = vdupq_n_s16((int16_t)nudged_zero_point2);
270 92256 const int16x8_t vnzp3 = vdupq_n_s16((int16_t)nudged_zero_point3);
271
272 92256 v0_s16 = vaddq_s16(v0_s16, vnzp0);
273 92256 v0_s16 = vmaxq_s16(v0_s16, vdupq_n_s16(INT8_MIN));
274 92256 v0_s16 = vminq_s16(v0_s16, vdupq_n_s16(INT8_MAX));
275 92256 v1_s16 = vaddq_s16(v1_s16, vnzp1);
276 92256 v1_s16 = vmaxq_s16(v1_s16, vdupq_n_s16(INT8_MIN));
277 92256 v1_s16 = vminq_s16(v1_s16, vdupq_n_s16(INT8_MAX));
278 92256 v2_s16 = vaddq_s16(v2_s16, vnzp2);
279 92256 v2_s16 = vmaxq_s16(v2_s16, vdupq_n_s16(INT8_MIN));
280 92256 v2_s16 = vminq_s16(v2_s16, vdupq_n_s16(INT8_MAX));
281 92256 v3_s16 = vaddq_s16(v3_s16, vnzp3);
282 92256 v3_s16 = vmaxq_s16(v3_s16, vdupq_n_s16(INT8_MIN));
283 92256 v3_s16 = vminq_s16(v3_s16, vdupq_n_s16(INT8_MAX));
284
285 92256 int8x8_t v0_s8 = vqmovn_s16(v0_s16);
286 92256 int8x8_t v1_s8 = vqmovn_s16(v1_s16);
287 92256 int8x8_t v2_s8 = vqmovn_s16(v2_s16);
288 92256 int8x8_t v3_s8 = vqmovn_s16(v3_s16);
289
290 92256 vst1_s8((int8_t*)(dst_ptr), v0_s8);
291 92256 vst1_s8((int8_t*)(dst_ptr + sizeof(int8x8_t)), v1_s8);
292 92256 vst1_s8((int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)), v2_s8);
293 92256 vst1_s8((int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)), v3_s8);
294 92256 dst_ptr += block_incr * mr * k_block_len * sizeof(int8_t);
295 92256 }
296
297
2/2
✓ Branch 0 taken 19152 times.
✓ Branch 1 taken 14976 times.
34128 for (; block_idx < num_blocks_k_internal; ++block_idx) {
298 // left over k
299
2/2
✓ Branch 0 taken 102144 times.
✓ Branch 1 taken 19152 times.
121296 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
300 // Clamp at the last valid k-index.
301
2/2
✓ Branch 0 taken 10080 times.
✓ Branch 1 taken 92064 times.
102144 const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
302
303 102144 const float src0 = (float)(*(src_ptr + k_idx_start));
304 102144 const float src1 = (float)(*(src_ptr + k_idx_start + lhs_row_length));
305 102144 const float src2 = (float)(*(src_ptr + k_idx_start + (2 * lhs_row_length)));
306 102144 const float src3 = (float)(*(src_ptr + k_idx_start + (3 * lhs_row_length)));
307
308 // Scale the value.
309 102144 int32_t d0_s32 = (int32_t)(roundf(src0 * scale0));
310 102144 int32_t d1_s32 = (int32_t)(roundf(src1 * scale1));
311 102144 int32_t d2_s32 = (int32_t)(roundf(src2 * scale2));
312 102144 int32_t d3_s32 = (int32_t)(roundf(src3 * scale3));
313
314 102144 d0_s32 = d0_s32 + nudged_zero_point0;
315
1/2
✓ Branch 0 taken 102144 times.
✗ Branch 1 not taken.
102144 d0_s32 = KAI_MAX(d0_s32, INT8_MIN);
316
2/2
✓ Branch 0 taken 101724 times.
✓ Branch 1 taken 420 times.
102144 d0_s32 = KAI_MIN(d0_s32, INT8_MAX);
317
318 102144 d1_s32 = d1_s32 + nudged_zero_point1;
319
2/2
✓ Branch 0 taken 102088 times.
✓ Branch 1 taken 56 times.
102144 d1_s32 = KAI_MAX(d1_s32, INT8_MIN);
320
2/2
✓ Branch 0 taken 101164 times.
✓ Branch 1 taken 980 times.
102144 d1_s32 = KAI_MIN(d1_s32, INT8_MAX);
321
322 102144 d2_s32 = d2_s32 + nudged_zero_point2;
323
2/2
✓ Branch 0 taken 102116 times.
✓ Branch 1 taken 28 times.
102144 d2_s32 = KAI_MAX(d2_s32, INT8_MIN);
324
2/2
✓ Branch 0 taken 101360 times.
✓ Branch 1 taken 784 times.
102144 d2_s32 = KAI_MIN(d2_s32, INT8_MAX);
325
326 102144 d3_s32 = d3_s32 + nudged_zero_point3;
327
2/2
✓ Branch 0 taken 101248 times.
✓ Branch 1 taken 896 times.
102144 d3_s32 = KAI_MAX(d3_s32, INT8_MIN);
328
2/2
✓ Branch 0 taken 100492 times.
✓ Branch 1 taken 1652 times.
102144 d3_s32 = KAI_MIN(d3_s32, INT8_MAX);
329
330 102144 *(int8_t*)dst_ptr = (int8_t)d0_s32;
331 102144 *(int8_t*)(dst_ptr + k_block_len * sizeof(int8_t)) = (int8_t)d1_s32;
332 102144 *(int8_t*)(dst_ptr + 2 * (k_block_len * sizeof(int8_t))) = (int8_t)d2_s32;
333 102144 *(int8_t*)(dst_ptr + 3 * (k_block_len * sizeof(int8_t))) = (int8_t)d3_s32;
334 102144 dst_ptr += sizeof(int8_t);
335 102144 }
336 19152 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
337 19152 }
338
339 14976 uint8_t* dst_base = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));
340
341 14976 dst_ptr = dst_base + dst_x * kai_num_bytes_per_offset;
342
343 // LHS offset at the beginning of the row.
344 14976 *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
345 14976 *((int32_t*)(dst_ptr + kai_num_bytes_per_offset)) = -nudged_zero_point1;
346 14976 *((int32_t*)(dst_ptr + 2 * kai_num_bytes_per_offset)) = -nudged_zero_point2;
347 14976 *((int32_t*)(dst_ptr + 3 * kai_num_bytes_per_offset)) = -nudged_zero_point3;
348
349 // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier.
350 KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
351
352 14976 dst_ptr += mr * kai_num_bytes_per_offset;
353
354 // Store the scale quantization params.
355 14976 *((float*)(dst_ptr)) = recip_scale0;
356 14976 *((float*)(dst_ptr + kai_num_bytes_per_multiplier)) = recip_scale1;
357 14976 *((float*)(dst_ptr + 2 * kai_num_bytes_per_multiplier)) = recip_scale2;
358 14976 *((float*)(dst_ptr + 3 * kai_num_bytes_per_multiplier)) = recip_scale3;
359
360 // Update src_ptr. Note: now lhs contains fp16 values (2 bytes each).
361 14976 src_ptr += (4 * lhs_row_length);
362
363 // Move to the next row as we have interleaved all Mr rows.
364 14976 lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
365 14976 }
366 5376 }
367
368
2/2
✓ Branch 0 taken 7680 times.
✓ Branch 1 taken 7392 times.
15072 for (; row_idx < num_rows; ++row_idx) {
369 // Find min/max for each channel
370 7680 int32_t k_idx = 0;
371 7680 float16x8_t vmax0 = vmax;
372 7680 float16x8_t vmin0 = vmin;
373
374
2/2
✓ Branch 0 taken 38592 times.
✓ Branch 1 taken 7680 times.
46272 for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
375 38592 const float16x8_t src0_0 = vld1q_f16(src_ptr + (size_t)k_idx);
376 38592 vmax0 = vmaxq_f16(vmax0, src0_0);
377 38592 vmin0 = vminq_f16(vmin0, src0_0);
378 38592 }
379 // Get the max/min
380 7680 float16_t max0 = vmaxvq_f16(vmax0);
381 7680 float16_t min0 = vminvq_f16(vmin0);
382
383
2/2
✓ Branch 0 taken 13776 times.
✓ Branch 1 taken 7680 times.
21456 for (; k_idx < (int32_t)k; ++k_idx) {
384 13776 const float16_t src0 = *(src_ptr + (size_t)k_idx);
385 13776 max0 = vmaxh_f16(src0, max0);
386 13776 min0 = vminh_f16(src0, min0);
387 13776 }
388
389 // Maximum/minimum int8 values
390 7680 const float qmin = (float)INT8_MIN;
391 7680 const float qmax = (float)INT8_MAX;
392
393 7680 const float rmin0 = fminf(0.0F, min0);
394 7680 const float rmax0 = fmaxf(0.0F, max0);
395
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 7680 times.
7680 const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
396
397 // Reciprocal to quantize
398
1/2
✓ Branch 0 taken 7680 times.
✗ Branch 1 not taken.
7680 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
399
400 7680 const float descaled_min0 = rmin0 * scale0;
401 7680 const float descaled_max0 = rmax0 * scale0;
402
403 7680 const float zero_point_from_min_error0 = qmin + descaled_min0;
404 7680 const float zero_point_from_max_error0 = qmax + descaled_max0;
405
406 15360 float zero_point0 =
407
1/2
✓ Branch 0 taken 7680 times.
✗ Branch 1 not taken.
7680 zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0;
408
409 7680 zero_point0 = fmaxf(zero_point0, qmin);
410 7680 zero_point0 = fminf(zero_point0, qmax);
411
412 // Round to nearest integer
413 7680 const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
414
415 7680 const size_t dst_x = ((row_idx + m_idx_start) % mr);
416
417 7680 uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t));
418
419 // Quantize the channels
420 7680 int32_t block_idx = 0;
421
422
2/2
✓ Branch 0 taken 55056 times.
✓ Branch 1 taken 7680 times.
62736 for (; block_idx <= blocks_lim_k; ++block_idx) {
423 55056 const int32_t k_idx_start = block_idx * k_block_len;
424
425 55056 const float16x8_t src_0 = vld1q_f16(src_ptr + k_idx_start);
426
427 // Scale the values
428 55056 const float32x4_t v0_f32 = vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src_0)), scale0);
429 55056 const float32x4_t v1_f32 = vmulq_n_f32(vcvt_high_f32_f16(src_0), scale0);
430 55056 const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32);
431 55056 const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32);
432
433 55056 const int16x4_t v0_s16 = vqmovn_s32(v0_s32);
434 55056 const int16x4_t v1_s16 = vqmovn_s32(v1_s32);
435 55056 int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);
436
437 // Add zero points
438 55056 int16_t nzp_s16 = (int16_t)nudged_zero_point0;
439 55056 int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16);
440 55056 v_s16 = vaddq_s16(v_s16, vnzp_s16);
441 55056 v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
442 55056 v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));
443
444 55056 int8x8_t v_s8 = vqmovn_s16(v_s16);
445 55056 vst1_s8((int8_t*)(dst_ptr), v_s8);
446 55056 dst_ptr += mr * k_block_len * sizeof(int8_t);
447 55056 }
448
449
2/2
✓ Branch 0 taken 18960 times.
✓ Branch 1 taken 7680 times.
26640 for (; block_idx < num_blocks_k_internal; ++block_idx) {
450 // left over k
451
2/2
✓ Branch 0 taken 97344 times.
✓ Branch 1 taken 18960 times.
116304 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
452 // Clamp at the last valid k-index
453
2/2
✓ Branch 0 taken 19416 times.
✓ Branch 1 taken 77928 times.
97344 const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
454
455 97344 const float src0 = (float)(*(src_ptr + k_idx_start));
456
457 // Scale the values
458 97344 int32_t d0_s32 = (int32_t)(roundf(src0 * scale0));
459
460 97344 d0_s32 = d0_s32 + nudged_zero_point0;
461
2/2
✓ Branch 0 taken 97316 times.
✓ Branch 1 taken 28 times.
97344 d0_s32 = KAI_MAX(d0_s32, INT8_MIN);
462
2/2
✓ Branch 0 taken 96490 times.
✓ Branch 1 taken 854 times.
97344 d0_s32 = KAI_MIN(d0_s32, INT8_MAX);
463
464 97344 *((int8_t*)(dst_ptr)) = (int8_t)d0_s32;
465 97344 dst_ptr += sizeof(int8_t);
466 97344 }
467 18960 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
468 18960 }
469
470 7680 dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));
471
472 7680 dst_ptr += dst_x * kai_num_bytes_per_offset;
473
474 // LHS offset at the beginning of the row
475 7680 *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
476
477 // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier
478 KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
479
480 7680 dst_ptr += mr * kai_num_bytes_per_offset;
481
482 // Store the scale quantization params
483 7680 *((float*)(dst_ptr)) = recip_scale0;
484
485 7680 src_ptr += lhs_row_length;
486
487 // Move to the next row if we have interleaved all Mr rows
488
2/2
✓ Branch 0 taken 5664 times.
✓ Branch 1 taken 2016 times.
7680 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
489 2016 lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
490 2016 }
491 7680 }
492 7392 }
493
494 #endif // Architectural features check.
495