KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 98.1% 359 / 5 / 371
Functions: 85.7% 6 / 0 / 7
Branches: 79.8% 67 / 6 / 90

kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if (!defined(__aarch64__) && !defined(_M_ARM64))
7 #error This file must be compiled for AArch64.
8 #else // Architectural features check.
9
10 #include "kai_lhs_quant_pack_qai8dxp_bf16_neon.h"
11
12 #include <arm_neon.h>
13 #endif
14 #include <float.h>
15 #include <math.h>
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include "kai/kai_common.h"
20
21 static const size_t kai_num_bytes_per_multiplier = sizeof(float);
22 static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
23
24 19776 inline static size_t kai_k_roundedup(size_t k) {
25 // Round up k to be a multiple of 32.
26 static const size_t kai_k_multiple_of = 32;
27 19776 return kai_roundup(k, kai_k_multiple_of);
28 }
29
30 14880 inline static size_t kai_lhs_packed_stride(size_t k, size_t mr) {
31 14880 const size_t k_internal = kai_k_roundedup(k);
32
33 KAI_ASSERT((k_internal % 2) == 0);
34
35 29760 return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
36 14880 }
37
38 size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16_neon(size_t mr) {
39 return mr;
40 }
41
42 4896 size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(size_t m_idx, size_t lhs_stride) {
43 4896 return m_idx * lhs_stride;
44 }
45
46 5088 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(
47 size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
48 5088 KAI_UNUSED(kr);
49 5088 KAI_UNUSED(sr);
50 // It always points to the beginning of the row
51 5088 return (m_idx / mr) * kai_lhs_packed_stride(k, mr);
52 }
53
54 4896 size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
55 4896 KAI_UNUSED(kr);
56 4896 KAI_UNUSED(sr);
57 4896 const size_t num_rows = kai_roundup(m, mr) / mr;
58
59 9792 return num_rows * kai_lhs_packed_stride(k, mr);
60 4896 }
61
62 // Note: The lhs parameter type has been changed from float* to void*.
63 // The bfloat16 values (packed in 16 bits) will be converted to float32.
64 4896 void kai_run_lhs_quant_pack_qai8dxp_bf16_neon(
65 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs,
66 size_t lhs_stride, void* restrict lhs_packed) {
67 KAI_ASSERT((kr % sr) == 0);
68
69
1/2
✓ Branch 0 taken 4896 times.
✗ Branch 1 not taken.
4896 if (m == 0) {
70 return;
71 }
72
73 // Now lhs is assumed to contain bfloat16 values encoded in uint16_t.
74 4896 const uint16_t* src_ptr = (uint16_t const*)lhs;
75
76 4896 const size_t dst_stride = kai_lhs_packed_stride(k, mr);
77 4896 const size_t k_internal = kai_k_roundedup(k);
78 4896 const int32_t k_block_len = (int32_t)(kr / sr);
79 KAI_ASSERT(k_block_len == 8);
80
81 4896 const int32_t num_blocks_k = (int32_t)(k / k_block_len);
82 4896 const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len);
83
84 4896 size_t row_idx = 0;
85
86
2/2
✓ Branch 0 taken 2448 times.
✓ Branch 1 taken 2448 times.
4896 if (mr == 4) {
87
2/2
✓ Branch 0 taken 11352 times.
✓ Branch 1 taken 2448 times.
13800 for (; row_idx + 3 < m; row_idx += 4) {
88 11352 float max0 = -FLT_MAX;
89 11352 float min0 = FLT_MAX;
90 11352 float max1 = -FLT_MAX;
91 11352 float min1 = FLT_MAX;
92 11352 float max2 = -FLT_MAX;
93 11352 float min2 = FLT_MAX;
94 11352 float max3 = -FLT_MAX;
95 11352 float min3 = FLT_MAX;
96
97 // Find min/max for each channel
98 11352 int32_t k_idx = 0;
99 11352 float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
100 11352 float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);
101 11352 float32x4_t vmax1 = vmax0;
102 11352 float32x4_t vmin1 = vmin0;
103 11352 float32x4_t vmax2 = vmax0;
104 11352 float32x4_t vmin2 = vmin0;
105 11352 float32x4_t vmax3 = vmax0;
106 11352 float32x4_t vmin3 = vmin0;
107 11352 const uint16x8_t zero = vdupq_n_u16(0);
108 // Process 8 bfloat16 values per iteration.
109
2/2
✓ Branch 0 taken 105264 times.
✓ Branch 1 taken 11352 times.
116616 for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
110 // Load eight bfloat16 values.
111 105264 const uint16x8_t bf16_vec_0 = vld1q_u16(src_ptr + k_idx);
112 105264 const uint16x8_t bf16_vec_1 = vld1q_u16(src_ptr + k_idx + (lhs_stride / sizeof(uint16_t)));
113 105264 const uint16x8_t bf16_vec_2 = vld1q_u16(src_ptr + k_idx + (2 * (lhs_stride / sizeof(uint16_t))));
114 105264 const uint16x8_t bf16_vec_3 = vld1q_u16(src_ptr + k_idx + (3 * (lhs_stride / sizeof(uint16_t))));
115
116 105264 const uint16x8_t bf16_vec1_0 = vzip1q_u16(zero, bf16_vec_0);
117 105264 const uint16x8_t bf16_vec2_0 = vzip2q_u16(zero, bf16_vec_0);
118 105264 const uint16x8_t bf16_vec1_1 = vzip1q_u16(zero, bf16_vec_1);
119 105264 const uint16x8_t bf16_vec2_1 = vzip2q_u16(zero, bf16_vec_1);
120 105264 const uint16x8_t bf16_vec1_2 = vzip1q_u16(zero, bf16_vec_2);
121 105264 const uint16x8_t bf16_vec2_2 = vzip2q_u16(zero, bf16_vec_2);
122 105264 const uint16x8_t bf16_vec1_3 = vzip1q_u16(zero, bf16_vec_3);
123 105264 const uint16x8_t bf16_vec2_3 = vzip2q_u16(zero, bf16_vec_3);
124
125 105264 const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1_0);
126 105264 const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2_0);
127 105264 const float32x4_t src1_0 = vreinterpretq_f32_u16(bf16_vec1_1);
128 105264 const float32x4_t src1_1 = vreinterpretq_f32_u16(bf16_vec2_1);
129 105264 const float32x4_t src2_0 = vreinterpretq_f32_u16(bf16_vec1_2);
130 105264 const float32x4_t src2_1 = vreinterpretq_f32_u16(bf16_vec2_2);
131 105264 const float32x4_t src3_0 = vreinterpretq_f32_u16(bf16_vec1_3);
132 105264 const float32x4_t src3_1 = vreinterpretq_f32_u16(bf16_vec2_3);
133
134 // Calculate the maximum
135 105264 vmax0 = vmaxq_f32(src0_0, vmax0);
136 105264 vmax0 = vmaxq_f32(vmax0, src0_1);
137 105264 vmax1 = vmaxq_f32(src1_0, vmax1);
138 105264 vmax1 = vmaxq_f32(vmax1, src1_1);
139 105264 vmax2 = vmaxq_f32(src2_0, vmax2);
140 105264 vmax2 = vmaxq_f32(vmax2, src2_1);
141 105264 vmax3 = vmaxq_f32(src3_0, vmax3);
142 105264 vmax3 = vmaxq_f32(vmax3, src3_1);
143
144 // Calculate the minimum
145 105264 vmin0 = vminq_f32(src0_0, vmin0);
146 105264 vmin0 = vminq_f32(vmin0, src0_1);
147 105264 vmin1 = vminq_f32(src1_0, vmin1);
148 105264 vmin1 = vminq_f32(vmin1, src1_1);
149 105264 vmin2 = vminq_f32(src2_0, vmin2);
150 105264 vmin2 = vminq_f32(vmin2, src2_1);
151 105264 vmin3 = vminq_f32(src3_0, vmin3);
152 105264 vmin3 = vminq_f32(vmin3, src3_1);
153 105264 }
154 // Get the max/min scalar values.
155 11352 max0 = vmaxvq_f32(vmax0);
156 11352 min0 = vminvq_f32(vmin0);
157 11352 max1 = vmaxvq_f32(vmax1);
158 11352 min1 = vminvq_f32(vmin1);
159 11352 max2 = vmaxvq_f32(vmax2);
160 11352 min2 = vminvq_f32(vmin2);
161 11352 max3 = vmaxvq_f32(vmax3);
162 11352 min3 = vminvq_f32(vmin3);
163 // Process leftover elements with a scalar loop.
164
2/2
✓ Branch 0 taken 21168 times.
✓ Branch 1 taken 11352 times.
32520 for (; k_idx < (int32_t)k; ++k_idx) {
165 21168 const float src0 = kai_cast_f32_bf16(*(src_ptr + k_idx));
166 21168 max0 = fmaxf(src0, max0);
167 21168 min0 = fminf(src0, min0);
168 21168 const float src1 = kai_cast_f32_bf16(*(src_ptr + k_idx + (lhs_stride / sizeof(uint16_t))));
169 21168 max1 = fmaxf(src1, max1);
170 21168 min1 = fminf(src1, min1);
171 21168 const float src2 = kai_cast_f32_bf16(*(src_ptr + k_idx + (2 * (lhs_stride / sizeof(uint16_t)))));
172 21168 max2 = fmaxf(src2, max2);
173 21168 min2 = fminf(src2, min2);
174 21168 const float src3 = kai_cast_f32_bf16(*(src_ptr + k_idx + (3 * (lhs_stride / sizeof(uint16_t)))));
175 21168 max3 = fmaxf(src3, max3);
176 21168 min3 = fminf(src3, min3);
177 21168 }
178
179 // Maximum/minimum int8 values
180 11352 const float qmin = (float)INT8_MIN;
181 11352 const float qmax = (float)INT8_MAX;
182
183 11352 const float rmin0 = fminf(0.0F, min0);
184 11352 const float rmax0 = fmaxf(0.0F, max0);
185
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 11352 times.
11352 const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
186 11352 const float rmin1 = fminf(0.0F, min1);
187 11352 const float rmax1 = fmaxf(0.0F, max1);
188
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 11352 times.
11352 const float scale1 = rmin1 == rmax1 ? 1.F : (qmax - qmin) / (rmax1 - rmin1);
189 11352 const float rmin2 = fminf(0.0F, min2);
190 11352 const float rmax2 = fmaxf(0.0F, max2);
191
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 11352 times.
11352 const float scale2 = rmin2 == rmax2 ? 1.F : (qmax - qmin) / (rmax2 - rmin2);
192 11352 const float rmin3 = fminf(0.0F, min3);
193 11352 const float rmax3 = fmaxf(0.0F, max3);
194
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 11352 times.
11352 const float scale3 = rmin3 == rmax3 ? 1.F : (qmax - qmin) / (rmax3 - rmin3);
195
196 // Reciprocal to quantize
197
1/2
✓ Branch 0 taken 11352 times.
✗ Branch 1 not taken.
11352 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
198
1/2
✓ Branch 0 taken 11352 times.
✗ Branch 1 not taken.
11352 const float recip_scale1 = scale1 ? 1.0F / scale1 : 0.0F;
199
1/2
✓ Branch 0 taken 11352 times.
✗ Branch 1 not taken.
11352 const float recip_scale2 = scale2 ? 1.0F / scale2 : 0.0F;
200
1/2
✓ Branch 0 taken 11352 times.
✗ Branch 1 not taken.
11352 const float recip_scale3 = scale3 ? 1.0F / scale3 : 0.0F;
201
202 11352 const float descaled_min0 = rmin0 * scale0;
203 11352 const float descaled_max0 = rmax0 * scale0;
204 11352 const float descaled_min1 = rmin1 * scale1;
205 11352 const float descaled_max1 = rmax1 * scale1;
206 11352 const float descaled_min2 = rmin2 * scale2;
207 11352 const float descaled_max2 = rmax2 * scale2;
208 11352 const float descaled_min3 = rmin3 * scale3;
209 11352 const float descaled_max3 = rmax3 * scale3;
210
211 11352 const float zero_point_from_min_error0 = qmin + descaled_min0;
212 11352 const float zero_point_from_max_error0 = qmax + descaled_max0;
213 11352 const float zero_point_from_min_error1 = qmin + descaled_min1;
214 11352 const float zero_point_from_max_error1 = qmax + descaled_max1;
215 11352 const float zero_point_from_min_error2 = qmin + descaled_min2;
216 11352 const float zero_point_from_max_error2 = qmax + descaled_max2;
217 11352 const float zero_point_from_min_error3 = qmin + descaled_min3;
218 11352 const float zero_point_from_max_error3 = qmax + descaled_max3;
219
220
1/2
✓ Branch 0 taken 11352 times.
✗ Branch 1 not taken.
11352 float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0
221 : qmax - descaled_max0;
222
1/2
✓ Branch 0 taken 11352 times.
✗ Branch 1 not taken.
11352 float zero_point1 = (zero_point_from_min_error1 + zero_point_from_max_error1 > 0) ? qmin - descaled_min1
223 : qmax - descaled_max1;
224
1/2
✓ Branch 0 taken 11352 times.
✗ Branch 1 not taken.
11352 float zero_point2 = (zero_point_from_min_error2 + zero_point_from_max_error2 > 0) ? qmin - descaled_min2
225 : qmax - descaled_max2;
226
1/2
✓ Branch 0 taken 11352 times.
✗ Branch 1 not taken.
11352 float zero_point3 = (zero_point_from_min_error3 + zero_point_from_max_error3 > 0) ? qmin - descaled_min3
227 : qmax - descaled_max3;
228
229 11352 zero_point0 = fmaxf(zero_point0, qmin);
230 11352 zero_point0 = fminf(zero_point0, qmax);
231 11352 zero_point1 = fmaxf(zero_point1, qmin);
232 11352 zero_point1 = fminf(zero_point1, qmax);
233 11352 zero_point2 = fmaxf(zero_point2, qmin);
234 11352 zero_point2 = fminf(zero_point2, qmax);
235 11352 zero_point3 = fmaxf(zero_point3, qmin);
236 11352 zero_point3 = fminf(zero_point3, qmax);
237
238 // Round to nearest integer
239 11352 const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
240 11352 const int32_t nudged_zero_point1 = (int32_t)rintf(zero_point1);
241 11352 const int32_t nudged_zero_point2 = (int32_t)rintf(zero_point2);
242 11352 const int32_t nudged_zero_point3 = (int32_t)rintf(zero_point3);
243
244 11352 const size_t dst_x = ((row_idx + m_idx_start) % mr);
245
246 11352 uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len);
247
248 // Quantize the channels
249 11352 int32_t block_idx = 0;
250
251
2/2
✓ Branch 0 taken 105264 times.
✓ Branch 1 taken 11352 times.
116616 for (; block_idx < num_blocks_k; ++block_idx) {
252 // Clamp at the last valid k-index
253 105264 const int32_t k_idx_start = block_idx * k_block_len;
254
255 // Load eight bfloat16 values and convert them to float32.
256 105264 const uint16x8_t bf16_vec_0 = vld1q_u16(src_ptr + k_idx_start);
257 105264 const uint16x8_t bf16_vec_1 = vld1q_u16(src_ptr + k_idx_start + (lhs_stride / sizeof(uint16_t)));
258 105264 const uint16x8_t bf16_vec_2 = vld1q_u16(src_ptr + k_idx_start + (2 * (lhs_stride / sizeof(uint16_t))));
259 105264 const uint16x8_t bf16_vec_3 = vld1q_u16(src_ptr + k_idx_start + (3 * (lhs_stride / sizeof(uint16_t))));
260 105264 const uint16x8_t bf16_vec1_0 = vzip1q_u16(zero, bf16_vec_0);
261 105264 const uint16x8_t bf16_vec2_0 = vzip2q_u16(zero, bf16_vec_0);
262 105264 const uint16x8_t bf16_vec1_1 = vzip1q_u16(zero, bf16_vec_1);
263 105264 const uint16x8_t bf16_vec2_1 = vzip2q_u16(zero, bf16_vec_1);
264 105264 const uint16x8_t bf16_vec1_2 = vzip1q_u16(zero, bf16_vec_2);
265 105264 const uint16x8_t bf16_vec2_2 = vzip2q_u16(zero, bf16_vec_2);
266 105264 const uint16x8_t bf16_vec1_3 = vzip1q_u16(zero, bf16_vec_3);
267 105264 const uint16x8_t bf16_vec2_3 = vzip2q_u16(zero, bf16_vec_3);
268 105264 const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1_0);
269 105264 const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2_0);
270 105264 const float32x4_t src1_0 = vreinterpretq_f32_u16(bf16_vec1_1);
271 105264 const float32x4_t src1_1 = vreinterpretq_f32_u16(bf16_vec2_1);
272 105264 const float32x4_t src2_0 = vreinterpretq_f32_u16(bf16_vec1_2);
273 105264 const float32x4_t src2_1 = vreinterpretq_f32_u16(bf16_vec2_2);
274 105264 const float32x4_t src3_0 = vreinterpretq_f32_u16(bf16_vec1_3);
275 105264 const float32x4_t src3_1 = vreinterpretq_f32_u16(bf16_vec2_3);
276
277 // Scale the values.
278 105264 const int16x4_t v0_0 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src0_0, scale0)));
279 105264 const int16x4_t v1_0 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src0_1, scale0)));
280 105264 const int16x4_t v0_1 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src1_0, scale1)));
281 105264 const int16x4_t v1_1 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src1_1, scale1)));
282 105264 const int16x4_t v0_2 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src2_0, scale2)));
283 105264 const int16x4_t v1_2 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src2_1, scale2)));
284 105264 const int16x4_t v0_3 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src3_0, scale3)));
285 105264 const int16x4_t v1_3 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src3_1, scale3)));
286
287 105264 int16x8_t v0_s16 = vcombine_s16(v0_0, v1_0);
288 105264 int16x8_t v1_s16 = vcombine_s16(v0_1, v1_1);
289 105264 int16x8_t v2_s16 = vcombine_s16(v0_2, v1_2);
290 105264 int16x8_t v3_s16 = vcombine_s16(v0_3, v1_3);
291
292 // Add zero points.
293 105264 const int16x8_t vnzp0 = vdupq_n_s16((int16_t)nudged_zero_point0);
294 105264 const int16x8_t vnzp1 = vdupq_n_s16((int16_t)nudged_zero_point1);
295 105264 const int16x8_t vnzp2 = vdupq_n_s16((int16_t)nudged_zero_point2);
296 105264 const int16x8_t vnzp3 = vdupq_n_s16((int16_t)nudged_zero_point3);
297
298 105264 v0_s16 = vaddq_s16(v0_s16, vnzp0);
299 105264 v0_s16 = vmaxq_s16(v0_s16, vdupq_n_s16(INT8_MIN));
300 105264 v0_s16 = vminq_s16(v0_s16, vdupq_n_s16(INT8_MAX));
301 105264 v1_s16 = vaddq_s16(v1_s16, vnzp1);
302 105264 v1_s16 = vmaxq_s16(v1_s16, vdupq_n_s16(INT8_MIN));
303 105264 v1_s16 = vminq_s16(v1_s16, vdupq_n_s16(INT8_MAX));
304 105264 v2_s16 = vaddq_s16(v2_s16, vnzp2);
305 105264 v2_s16 = vmaxq_s16(v2_s16, vdupq_n_s16(INT8_MIN));
306 105264 v2_s16 = vminq_s16(v2_s16, vdupq_n_s16(INT8_MAX));
307 105264 v3_s16 = vaddq_s16(v3_s16, vnzp3);
308 105264 v3_s16 = vmaxq_s16(v3_s16, vdupq_n_s16(INT8_MIN));
309 105264 v3_s16 = vminq_s16(v3_s16, vdupq_n_s16(INT8_MAX));
310
311 105264 const int8x8_t v0_s8 = vqmovn_s16(v0_s16);
312 105264 const int8x8_t v1_s8 = vqmovn_s16(v1_s16);
313 105264 const int8x8_t v2_s8 = vqmovn_s16(v2_s16);
314 105264 const int8x8_t v3_s8 = vqmovn_s16(v3_s16);
315
316 105264 vst1_s8((int8_t*)(dst_ptr), v0_s8);
317 105264 vst1_s8((int8_t*)(dst_ptr + sizeof(int8x8_t)), v1_s8);
318 105264 vst1_s8((int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)), v2_s8);
319 105264 vst1_s8((int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)), v3_s8);
320 105264 dst_ptr += 4 * sizeof(int8x8_t);
321 105264 }
322
323
2/2
✓ Branch 0 taken 11760 times.
✓ Branch 1 taken 11352 times.
23112 for (; block_idx < num_blocks_k_internal; ++block_idx) {
324 // Left over k
325
2/2
✓ Branch 0 taken 94080 times.
✓ Branch 1 taken 11760 times.
105840 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
326 // Clamp at the last valid k-index.
327
2/2
✓ Branch 0 taken 16464 times.
✓ Branch 1 taken 77616 times.
94080 const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
328
329 94080 const float src0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start));
330 94080 const float src1 = kai_cast_f32_bf16(*(src_ptr + k_idx_start + (lhs_stride / sizeof(uint16_t))));
331 188160 const float src2 =
332 94080 kai_cast_f32_bf16(*(src_ptr + k_idx_start + (2 * (lhs_stride / sizeof(uint16_t)))));
333 188160 const float src3 =
334 94080 kai_cast_f32_bf16(*(src_ptr + k_idx_start + (3 * (lhs_stride / sizeof(uint16_t)))));
335
336 // Scale the value.
337 94080 int32_t v0_s32 = (int32_t)(roundf(src0 * scale0));
338 94080 int32_t v1_s32 = (int32_t)(roundf(src1 * scale1));
339 94080 int32_t v2_s32 = (int32_t)(roundf(src2 * scale2));
340 94080 int32_t v3_s32 = (int32_t)(roundf(src3 * scale3));
341
342 94080 v0_s32 = v0_s32 + nudged_zero_point0;
343
2/2
✓ Branch 0 taken 94052 times.
✓ Branch 1 taken 28 times.
94080 v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
344
2/2
✓ Branch 0 taken 92864 times.
✓ Branch 1 taken 1216 times.
94080 v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
345
346 94080 v1_s32 = v1_s32 + nudged_zero_point1;
347
2/2
✓ Branch 0 taken 93996 times.
✓ Branch 1 taken 84 times.
94080 v1_s32 = KAI_MAX(v1_s32, INT8_MIN);
348
2/2
✓ Branch 0 taken 92244 times.
✓ Branch 1 taken 1836 times.
94080 v1_s32 = KAI_MIN(v1_s32, INT8_MAX);
349
350 94080 v2_s32 = v2_s32 + nudged_zero_point2;
351
2/2
✓ Branch 0 taken 93988 times.
✓ Branch 1 taken 92 times.
94080 v2_s32 = KAI_MAX(v2_s32, INT8_MIN);
352
2/2
✓ Branch 0 taken 92816 times.
✓ Branch 1 taken 1264 times.
94080 v2_s32 = KAI_MIN(v2_s32, INT8_MAX);
353
354 94080 v3_s32 = v3_s32 + nudged_zero_point3;
355
1/2
✓ Branch 0 taken 94080 times.
✗ Branch 1 not taken.
94080 v3_s32 = KAI_MAX(v3_s32, INT8_MIN);
356
2/2
✓ Branch 0 taken 92932 times.
✓ Branch 1 taken 1148 times.
94080 v3_s32 = KAI_MIN(v3_s32, INT8_MAX);
357
358 94080 *(int8_t*)dst_ptr = (int8_t)v0_s32;
359 94080 *(int8_t*)(dst_ptr + sizeof(int8x8_t)) = (int8_t)v1_s32;
360 94080 *(int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)) = (int8_t)v2_s32;
361 94080 *(int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)) = (int8_t)v3_s32;
362
363 94080 dst_ptr += sizeof(int8_t);
364 94080 }
365 11760 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
366 11760 }
367
368 11352 uint8_t* dst_base = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));
369
370 11352 dst_ptr = dst_base + dst_x * kai_num_bytes_per_offset;
371
372 // LHS offset at the beginning of the row.
373 11352 *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
374 11352 *((int32_t*)(dst_ptr + kai_num_bytes_per_offset)) = -nudged_zero_point1;
375 11352 *((int32_t*)(dst_ptr + 2 * kai_num_bytes_per_offset)) = -nudged_zero_point2;
376 11352 *((int32_t*)(dst_ptr + 3 * kai_num_bytes_per_offset)) = -nudged_zero_point3;
377
378 // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier.
379 KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
380
381 11352 dst_ptr += mr * kai_num_bytes_per_offset;
382
383 // Store the scale quantization params.
384 11352 *((float*)(dst_ptr)) = recip_scale0;
385 11352 *((float*)(dst_ptr + kai_num_bytes_per_multiplier)) = recip_scale1;
386 11352 *((float*)(dst_ptr + 2 * kai_num_bytes_per_multiplier)) = recip_scale2;
387 11352 *((float*)(dst_ptr + 3 * kai_num_bytes_per_multiplier)) = recip_scale3;
388
389 // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each).
390 11352 src_ptr += (4 * lhs_stride / sizeof(uint16_t));
391
392 // Move to the next row as we have interleaved all Mr rows.
393 11352 lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
394 11352 }
395 2448 }
396
397
2/2
✓ Branch 0 taken 48240 times.
✓ Branch 1 taken 4896 times.
53136 for (; row_idx < m; ++row_idx) {
398 48240 float max0 = -FLT_MAX;
399 48240 float min0 = FLT_MAX;
400
401 // Find min/max for each channel
402 48240 int32_t k_idx = 0;
403 48240 float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
404 48240 float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);
405 48240 const uint16x8_t zero = vdupq_n_u16(0);
406 // Process 8 bfloat16 values per iteration.
407
2/2
✓ Branch 0 taken 431304 times.
✓ Branch 1 taken 48240 times.
479544 for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
408 // Load eight bfloat16 values.
409 431304 const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx);
410 431304 const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec);
411 431304 const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec);
412 431304 const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1);
413 431304 const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2);
414
415 // Calculate the maximum
416 431304 vmax0 = vmaxq_f32(src0_0, vmax0);
417 431304 vmax0 = vmaxq_f32(vmax0, src0_1);
418
419 // Calculate the minimum
420 431304 vmin0 = vminq_f32(src0_0, vmin0);
421 431304 vmin0 = vminq_f32(vmin0, src0_1);
422 431304 }
423 // Get the max/min scalar values.
424 48240 max0 = vmaxvq_f32(vmax0);
425 48240 min0 = vminvq_f32(vmin0);
426 // Process leftover elements with a scalar loop.
427
2/2
✓ Branch 0 taken 86616 times.
✓ Branch 1 taken 48240 times.
134856 for (; k_idx < (int32_t)k; ++k_idx) {
428 86616 const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx));
429 86616 max0 = fmaxf(src0_0, max0);
430 86616 min0 = fminf(src0_0, min0);
431 86616 }
432
433 // Maximum/minimum int8 values
434 48240 const float qmin = (float)INT8_MIN;
435 48240 const float qmax = (float)INT8_MAX;
436
437 48240 const float rmin0 = fminf(0.0F, min0);
438 48240 const float rmax0 = fmaxf(0.0F, max0);
439
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 48240 times.
48240 const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
440
441 // Reciprocal to quantize
442
1/2
✓ Branch 0 taken 48240 times.
✗ Branch 1 not taken.
48240 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
443
444 48240 const float descaled_min0 = rmin0 * scale0;
445 48240 const float descaled_max0 = rmax0 * scale0;
446
447 48240 const float zero_point_from_min_error0 = qmin + descaled_min0;
448 48240 const float zero_point_from_max_error0 = qmax + descaled_max0;
449
450 96480 float zero_point0 =
451
1/2
✓ Branch 0 taken 48240 times.
✗ Branch 1 not taken.
48240 (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 : qmax - descaled_max0;
452
453 48240 zero_point0 = fmaxf(zero_point0, qmin);
454 48240 zero_point0 = fminf(zero_point0, qmax);
455
456 // Round to nearest integer
457 48240 const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
458
459 48240 const size_t dst_x = ((row_idx + m_idx_start) % mr);
460
461 48240 uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t));
462
463 // Quantize the channels
464 48240 int32_t block_idx = 0;
465
466
2/2
✓ Branch 0 taken 431304 times.
✓ Branch 1 taken 48240 times.
479544 for (; block_idx < num_blocks_k; ++block_idx) {
467 // Clamp at the last valid k-index
468 431304 const int32_t k_idx_start = block_idx * k_block_len;
469
470 // Load eight bfloat16 values and convert them to float32.
471 431304 const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx_start);
472 431304 const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec);
473 431304 const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec);
474 431304 const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1);
475 431304 const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2);
476
477 // Scale the values.
478 431304 const float32x4_t v0_f32 = vmulq_n_f32(src0_0, scale0);
479 431304 const float32x4_t v1_f32 = vmulq_n_f32(src0_1, scale0);
480 431304 const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32);
481 431304 const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32);
482
483 431304 const int16x4_t v0_s16 = vqmovn_s32(v0_s32);
484 431304 const int16x4_t v1_s16 = vqmovn_s32(v1_s32);
485 431304 int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);
486
487 // Add zero points.
488 431304 int16_t nzp_s16 = (int16_t)nudged_zero_point0;
489 431304 int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16);
490 431304 v_s16 = vaddq_s16(v_s16, vnzp_s16);
491 431304 v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
492 431304 v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));
493
494 431304 const int8x8_t v0_s8 = vqmovn_s16(v_s16);
495 431304 vst1_s8((int8_t*)(dst_ptr), v0_s8);
496 431304 dst_ptr += 8 * sizeof(int8_t);
497 431304 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
498 431304 }
499
500
2/2
✓ Branch 0 taken 48120 times.
✓ Branch 1 taken 48240 times.
96360 for (; block_idx < num_blocks_k_internal; ++block_idx) {
501 // Left over k
502
2/2
✓ Branch 0 taken 384960 times.
✓ Branch 1 taken 48120 times.
433080 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
503 // Clamp at the last valid k-index.
504
2/2
✓ Branch 0 taken 67368 times.
✓ Branch 1 taken 317592 times.
384960 const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
505
506 384960 const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start));
507
508 // Scale the value.
509 384960 int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0));
510
511 384960 v0_s32 = v0_s32 + nudged_zero_point0;
512
2/2
✓ Branch 0 taken 384756 times.
✓ Branch 1 taken 204 times.
384960 v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
513
2/2
✓ Branch 0 taken 379420 times.
✓ Branch 1 taken 5540 times.
384960 v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
514
515 384960 *((int8_t*)(dst_ptr)) = (int8_t)v0_s32;
516 384960 dst_ptr += sizeof(int8_t);
517 384960 }
518 48120 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
519 48120 }
520
521 48240 dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));
522
523 48240 dst_ptr += dst_x * kai_num_bytes_per_offset;
524
525 // LHS offset at the beginning of the row.
526 48240 *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
527
528 // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier.
529 KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
530
531 48240 dst_ptr += mr * kai_num_bytes_per_offset;
532
533 // Store the scale quantization params.
534 48240 *((float*)(dst_ptr)) = recip_scale0;
535
536 // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each).
537 48240 src_ptr += (lhs_stride / sizeof(uint16_t));
538
539 // Move to the next row if we have interleaved all Mr rows.
540
2/2
✓ Branch 0 taken 2304 times.
✓ Branch 1 taken 45936 times.
48240 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
541 45936 lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
542 45936 }
543 48240 }
544 4896 }
545