KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 97.8% 359 5 372
Functions: 85.7% 6 0 7
Branches: 77.4% 65 6 90

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if (!defined(__aarch64__) && !defined(_M_ARM64))
7 #error This file must be compiled for AArch64.
8 #else // Architectural features check.
9
10 #include "kai_lhs_quant_pack_qai8dxp_bf16_neon.h"
11
12 #include <arm_neon.h>
13 #endif
14 #include <float.h>
15 #include <math.h>
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include "kai/kai_common.h"
20
21 static const size_t kai_num_bytes_per_multiplier = sizeof(float);
22 static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
23
24 3840 inline static size_t kai_k_roundedup(size_t k) {
25 // Round up k to be a multiple of 32.
26 static const size_t kai_k_multiple_of = 32;
27 3840 return kai_roundup(k, kai_k_multiple_of);
28 }
29
30 2880 inline static size_t kai_lhs_packed_stride(size_t k, size_t mr) {
31 2880 const size_t k_internal = kai_k_roundedup(k);
32
33 KAI_ASSERT((k_internal % 2) == 0);
34
35 5760 return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
36 2880 }
37
38 size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16_neon(size_t mr) {
39 KAI_UNUSED(mr);
40 return 1;
41 }
42
43 960 size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(size_t m_idx, size_t lhs_stride) {
44 960 return m_idx * lhs_stride;
45 }
46
47 960 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(
48 size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
49 960 KAI_UNUSED(kr);
50 960 KAI_UNUSED(sr);
51 // It always points to the beginning of the row
52 960 return (m_idx / mr) * kai_lhs_packed_stride(k, mr);
53 }
54
55 960 size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
56 960 KAI_UNUSED(kr);
57 960 KAI_UNUSED(sr);
58 960 const size_t num_rows = kai_roundup(m, mr) / mr;
59
60 1920 return num_rows * kai_lhs_packed_stride(k, mr);
61 960 }
62
63 // Note: The lhs parameter type has been changed from float* to void*.
64 // The bfloat16 values (packed in 16 bits) will be converted to float32.
65 960 void kai_run_lhs_quant_pack_qai8dxp_bf16_neon(
66 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs,
67 size_t lhs_stride, void* restrict lhs_packed) {
68 KAI_ASSERT((kr % sr) == 0);
69
70
1/2
✓ Branch 0 taken 960 times.
✗ Branch 1 not taken.
960 if (m == 0) {
71 return;
72 }
73
74 // Now lhs is assumed to contain bfloat16 values encoded in uint16_t.
75 960 const uint16_t* src_ptr = (uint16_t const*)lhs;
76
77 960 const size_t dst_stride = kai_lhs_packed_stride(k, mr);
78 960 const size_t k_internal = kai_k_roundedup(k);
79 960 const int32_t k_block_len = (int32_t)(kr / sr);
80 KAI_ASSERT(k_block_len == 8);
81
82 960 const int32_t num_blocks_k = (int32_t)(k / k_block_len);
83 960 const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len);
84
85 960 size_t row_idx = 0;
86
87
2/2
✓ Branch 0 taken 464 times.
✓ Branch 1 taken 496 times.
960 if (mr == 4) {
88
2/2
✓ Branch 0 taken 2228 times.
✓ Branch 1 taken 496 times.
2724 for (; row_idx + 3 < m; row_idx += 4) {
89 2228 float max0 = -FLT_MAX;
90 2228 float min0 = FLT_MAX;
91 2228 float max1 = -FLT_MAX;
92 2228 float min1 = FLT_MAX;
93 2228 float max2 = -FLT_MAX;
94 2228 float min2 = FLT_MAX;
95 2228 float max3 = -FLT_MAX;
96 2228 float min3 = FLT_MAX;
97
98 // Find min/max for each channel
99 2228 int32_t k_idx = 0;
100 2228 float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
101 2228 float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);
102 2228 float32x4_t vmax1 = vmax0;
103 2228 float32x4_t vmin1 = vmin0;
104 2228 float32x4_t vmax2 = vmax0;
105 2228 float32x4_t vmin2 = vmin0;
106 2228 float32x4_t vmax3 = vmax0;
107 2228 float32x4_t vmin3 = vmin0;
108 2228 const uint16x8_t zero = vdupq_n_u16(0);
109 // Process 8 bfloat16 values per iteration.
110
2/2
✓ Branch 0 taken 25608 times.
✓ Branch 1 taken 2228 times.
27836 for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
111 // Load eight bfloat16 values.
112 25608 const uint16x8_t bf16_vec_0 = vld1q_u16(src_ptr + k_idx);
113 25608 const uint16x8_t bf16_vec_1 = vld1q_u16(src_ptr + k_idx + (lhs_stride / sizeof(uint16_t)));
114 25608 const uint16x8_t bf16_vec_2 = vld1q_u16(src_ptr + k_idx + (2 * (lhs_stride / sizeof(uint16_t))));
115 25608 const uint16x8_t bf16_vec_3 = vld1q_u16(src_ptr + k_idx + (3 * (lhs_stride / sizeof(uint16_t))));
116
117 25608 const uint16x8_t bf16_vec1_0 = vzip1q_u16(zero, bf16_vec_0);
118 25608 const uint16x8_t bf16_vec2_0 = vzip2q_u16(zero, bf16_vec_0);
119 25608 const uint16x8_t bf16_vec1_1 = vzip1q_u16(zero, bf16_vec_1);
120 25608 const uint16x8_t bf16_vec2_1 = vzip2q_u16(zero, bf16_vec_1);
121 25608 const uint16x8_t bf16_vec1_2 = vzip1q_u16(zero, bf16_vec_2);
122 25608 const uint16x8_t bf16_vec2_2 = vzip2q_u16(zero, bf16_vec_2);
123 25608 const uint16x8_t bf16_vec1_3 = vzip1q_u16(zero, bf16_vec_3);
124 25608 const uint16x8_t bf16_vec2_3 = vzip2q_u16(zero, bf16_vec_3);
125
126 25608 const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1_0);
127 25608 const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2_0);
128 25608 const float32x4_t src1_0 = vreinterpretq_f32_u16(bf16_vec1_1);
129 25608 const float32x4_t src1_1 = vreinterpretq_f32_u16(bf16_vec2_1);
130 25608 const float32x4_t src2_0 = vreinterpretq_f32_u16(bf16_vec1_2);
131 25608 const float32x4_t src2_1 = vreinterpretq_f32_u16(bf16_vec2_2);
132 25608 const float32x4_t src3_0 = vreinterpretq_f32_u16(bf16_vec1_3);
133 25608 const float32x4_t src3_1 = vreinterpretq_f32_u16(bf16_vec2_3);
134
135 // Calculate the maximum
136 25608 vmax0 = vmaxq_f32(src0_0, vmax0);
137 25608 vmax0 = vmaxq_f32(vmax0, src0_1);
138 25608 vmax1 = vmaxq_f32(src1_0, vmax1);
139 25608 vmax1 = vmaxq_f32(vmax1, src1_1);
140 25608 vmax2 = vmaxq_f32(src2_0, vmax2);
141 25608 vmax2 = vmaxq_f32(vmax2, src2_1);
142 25608 vmax3 = vmaxq_f32(src3_0, vmax3);
143 25608 vmax3 = vmaxq_f32(vmax3, src3_1);
144
145 // Calculate the minimum
146 25608 vmin0 = vminq_f32(src0_0, vmin0);
147 25608 vmin0 = vminq_f32(vmin0, src0_1);
148 25608 vmin1 = vminq_f32(src1_0, vmin1);
149 25608 vmin1 = vminq_f32(vmin1, src1_1);
150 25608 vmin2 = vminq_f32(src2_0, vmin2);
151 25608 vmin2 = vminq_f32(vmin2, src2_1);
152 25608 vmin3 = vminq_f32(src3_0, vmin3);
153 25608 vmin3 = vminq_f32(vmin3, src3_1);
154 25608 }
155 // Get the max/min scalar values.
156 2228 max0 = vmaxvq_f32(vmax0);
157 2228 min0 = vminvq_f32(vmin0);
158 2228 max1 = vmaxvq_f32(vmax1);
159 2228 min1 = vminvq_f32(vmin1);
160 2228 max2 = vmaxvq_f32(vmax2);
161 2228 min2 = vminvq_f32(vmin2);
162 2228 max3 = vmaxvq_f32(vmax3);
163 2228 min3 = vminvq_f32(vmin3);
164 // Process leftover elements with a scalar loop.
165
2/2
✓ Branch 0 taken 3528 times.
✓ Branch 1 taken 2228 times.
5756 for (; k_idx < (int32_t)k; ++k_idx) {
166 3528 const float src0 = kai_cast_f32_bf16(*(src_ptr + k_idx));
167 3528 max0 = fmaxf(src0, max0);
168 3528 min0 = fminf(src0, min0);
169 3528 const float src1 = kai_cast_f32_bf16(*(src_ptr + k_idx + (lhs_stride / sizeof(uint16_t))));
170 3528 max1 = fmaxf(src1, max1);
171 3528 min1 = fminf(src1, min1);
172 3528 const float src2 = kai_cast_f32_bf16(*(src_ptr + k_idx + (2 * (lhs_stride / sizeof(uint16_t)))));
173 3528 max2 = fmaxf(src2, max2);
174 3528 min2 = fminf(src2, min2);
175 3528 const float src3 = kai_cast_f32_bf16(*(src_ptr + k_idx + (3 * (lhs_stride / sizeof(uint16_t)))));
176 3528 max3 = fmaxf(src3, max3);
177 3528 min3 = fminf(src3, min3);
178 3528 }
179
180 // Maximum/minimum int8 values
181 2228 const float qmin = (float)INT8_MIN;
182 2228 const float qmax = (float)INT8_MAX;
183
184 2228 const float rmin0 = fminf(0.0F, min0);
185 2228 const float rmax0 = fmaxf(0.0F, max0);
186
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2228 times.
2228 const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
187 2228 const float rmin1 = fminf(0.0F, min1);
188 2228 const float rmax1 = fmaxf(0.0F, max1);
189
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2228 times.
2228 const float scale1 = rmin1 == rmax1 ? 1.F : (qmax - qmin) / (rmax1 - rmin1);
190 2228 const float rmin2 = fminf(0.0F, min2);
191 2228 const float rmax2 = fmaxf(0.0F, max2);
192
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2228 times.
2228 const float scale2 = rmin2 == rmax2 ? 1.F : (qmax - qmin) / (rmax2 - rmin2);
193 2228 const float rmin3 = fminf(0.0F, min3);
194 2228 const float rmax3 = fmaxf(0.0F, max3);
195
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2228 times.
2228 const float scale3 = rmin3 == rmax3 ? 1.F : (qmax - qmin) / (rmax3 - rmin3);
196
197 // Reciprocal to quantize
198
1/2
✓ Branch 0 taken 2228 times.
✗ Branch 1 not taken.
2228 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
199
1/2
✓ Branch 0 taken 2228 times.
✗ Branch 1 not taken.
2228 const float recip_scale1 = scale1 ? 1.0F / scale1 : 0.0F;
200
1/2
✓ Branch 0 taken 2228 times.
✗ Branch 1 not taken.
2228 const float recip_scale2 = scale2 ? 1.0F / scale2 : 0.0F;
201
1/2
✓ Branch 0 taken 2228 times.
✗ Branch 1 not taken.
2228 const float recip_scale3 = scale3 ? 1.0F / scale3 : 0.0F;
202
203 2228 const float descaled_min0 = rmin0 * scale0;
204 2228 const float descaled_max0 = rmax0 * scale0;
205 2228 const float descaled_min1 = rmin1 * scale1;
206 2228 const float descaled_max1 = rmax1 * scale1;
207 2228 const float descaled_min2 = rmin2 * scale2;
208 2228 const float descaled_max2 = rmax2 * scale2;
209 2228 const float descaled_min3 = rmin3 * scale3;
210 2228 const float descaled_max3 = rmax3 * scale3;
211
212 2228 const float zero_point_from_min_error0 = qmin + descaled_min0;
213 2228 const float zero_point_from_max_error0 = qmax + descaled_max0;
214 2228 const float zero_point_from_min_error1 = qmin + descaled_min1;
215 2228 const float zero_point_from_max_error1 = qmax + descaled_max1;
216 2228 const float zero_point_from_min_error2 = qmin + descaled_min2;
217 2228 const float zero_point_from_max_error2 = qmax + descaled_max2;
218 2228 const float zero_point_from_min_error3 = qmin + descaled_min3;
219 2228 const float zero_point_from_max_error3 = qmax + descaled_max3;
220
221
1/2
✓ Branch 0 taken 2228 times.
✗ Branch 1 not taken.
2228 float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0
222 : qmax - descaled_max0;
223
1/2
✓ Branch 0 taken 2228 times.
✗ Branch 1 not taken.
2228 float zero_point1 = (zero_point_from_min_error1 + zero_point_from_max_error1 > 0) ? qmin - descaled_min1
224 : qmax - descaled_max1;
225
1/2
✓ Branch 0 taken 2228 times.
✗ Branch 1 not taken.
2228 float zero_point2 = (zero_point_from_min_error2 + zero_point_from_max_error2 > 0) ? qmin - descaled_min2
226 : qmax - descaled_max2;
227
1/2
✓ Branch 0 taken 2228 times.
✗ Branch 1 not taken.
2228 float zero_point3 = (zero_point_from_min_error3 + zero_point_from_max_error3 > 0) ? qmin - descaled_min3
228 : qmax - descaled_max3;
229
230 2228 zero_point0 = fmaxf(zero_point0, qmin);
231 2228 zero_point0 = fminf(zero_point0, qmax);
232 2228 zero_point1 = fmaxf(zero_point1, qmin);
233 2228 zero_point1 = fminf(zero_point1, qmax);
234 2228 zero_point2 = fmaxf(zero_point2, qmin);
235 2228 zero_point2 = fminf(zero_point2, qmax);
236 2228 zero_point3 = fmaxf(zero_point3, qmin);
237 2228 zero_point3 = fminf(zero_point3, qmax);
238
239 // Round to nearest integer
240 2228 const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
241 2228 const int32_t nudged_zero_point1 = (int32_t)rintf(zero_point1);
242 2228 const int32_t nudged_zero_point2 = (int32_t)rintf(zero_point2);
243 2228 const int32_t nudged_zero_point3 = (int32_t)rintf(zero_point3);
244
245 2228 const size_t dst_x = ((row_idx + m_idx_start) % mr);
246
247 2228 uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len);
248
249 // Quantize the channels
250 2228 int32_t block_idx = 0;
251
252
2/2
✓ Branch 0 taken 25608 times.
✓ Branch 1 taken 2228 times.
27836 for (; block_idx < num_blocks_k; ++block_idx) {
253 // Clamp at the last valid k-index
254 25608 const int32_t k_idx_start = block_idx * k_block_len;
255
256 // Load eight bfloat16 values and convert them to float32.
257 25608 const uint16x8_t bf16_vec_0 = vld1q_u16(src_ptr + k_idx_start);
258 25608 const uint16x8_t bf16_vec_1 = vld1q_u16(src_ptr + k_idx_start + (lhs_stride / sizeof(uint16_t)));
259 25608 const uint16x8_t bf16_vec_2 = vld1q_u16(src_ptr + k_idx_start + (2 * (lhs_stride / sizeof(uint16_t))));
260 25608 const uint16x8_t bf16_vec_3 = vld1q_u16(src_ptr + k_idx_start + (3 * (lhs_stride / sizeof(uint16_t))));
261 25608 const uint16x8_t bf16_vec1_0 = vzip1q_u16(zero, bf16_vec_0);
262 25608 const uint16x8_t bf16_vec2_0 = vzip2q_u16(zero, bf16_vec_0);
263 25608 const uint16x8_t bf16_vec1_1 = vzip1q_u16(zero, bf16_vec_1);
264 25608 const uint16x8_t bf16_vec2_1 = vzip2q_u16(zero, bf16_vec_1);
265 25608 const uint16x8_t bf16_vec1_2 = vzip1q_u16(zero, bf16_vec_2);
266 25608 const uint16x8_t bf16_vec2_2 = vzip2q_u16(zero, bf16_vec_2);
267 25608 const uint16x8_t bf16_vec1_3 = vzip1q_u16(zero, bf16_vec_3);
268 25608 const uint16x8_t bf16_vec2_3 = vzip2q_u16(zero, bf16_vec_3);
269 25608 const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1_0);
270 25608 const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2_0);
271 25608 const float32x4_t src1_0 = vreinterpretq_f32_u16(bf16_vec1_1);
272 25608 const float32x4_t src1_1 = vreinterpretq_f32_u16(bf16_vec2_1);
273 25608 const float32x4_t src2_0 = vreinterpretq_f32_u16(bf16_vec1_2);
274 25608 const float32x4_t src2_1 = vreinterpretq_f32_u16(bf16_vec2_2);
275 25608 const float32x4_t src3_0 = vreinterpretq_f32_u16(bf16_vec1_3);
276 25608 const float32x4_t src3_1 = vreinterpretq_f32_u16(bf16_vec2_3);
277
278 // Scale the values.
279 25608 const int16x4_t v0_0 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src0_0, scale0)));
280 25608 const int16x4_t v1_0 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src0_1, scale0)));
281 25608 const int16x4_t v0_1 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src1_0, scale1)));
282 25608 const int16x4_t v1_1 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src1_1, scale1)));
283 25608 const int16x4_t v0_2 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src2_0, scale2)));
284 25608 const int16x4_t v1_2 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src2_1, scale2)));
285 25608 const int16x4_t v0_3 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src3_0, scale3)));
286 25608 const int16x4_t v1_3 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src3_1, scale3)));
287
288 25608 int16x8_t v0_s16 = vcombine_s16(v0_0, v1_0);
289 25608 int16x8_t v1_s16 = vcombine_s16(v0_1, v1_1);
290 25608 int16x8_t v2_s16 = vcombine_s16(v0_2, v1_2);
291 25608 int16x8_t v3_s16 = vcombine_s16(v0_3, v1_3);
292
293 // Add zero points.
294 25608 const int16x8_t vnzp0 = vdupq_n_s16((int16_t)nudged_zero_point0);
295 25608 const int16x8_t vnzp1 = vdupq_n_s16((int16_t)nudged_zero_point1);
296 25608 const int16x8_t vnzp2 = vdupq_n_s16((int16_t)nudged_zero_point2);
297 25608 const int16x8_t vnzp3 = vdupq_n_s16((int16_t)nudged_zero_point3);
298
299 25608 v0_s16 = vaddq_s16(v0_s16, vnzp0);
300 25608 v0_s16 = vmaxq_s16(v0_s16, vdupq_n_s16(INT8_MIN));
301 25608 v0_s16 = vminq_s16(v0_s16, vdupq_n_s16(INT8_MAX));
302 25608 v1_s16 = vaddq_s16(v1_s16, vnzp1);
303 25608 v1_s16 = vmaxq_s16(v1_s16, vdupq_n_s16(INT8_MIN));
304 25608 v1_s16 = vminq_s16(v1_s16, vdupq_n_s16(INT8_MAX));
305 25608 v2_s16 = vaddq_s16(v2_s16, vnzp2);
306 25608 v2_s16 = vmaxq_s16(v2_s16, vdupq_n_s16(INT8_MIN));
307 25608 v2_s16 = vminq_s16(v2_s16, vdupq_n_s16(INT8_MAX));
308 25608 v3_s16 = vaddq_s16(v3_s16, vnzp3);
309 25608 v3_s16 = vmaxq_s16(v3_s16, vdupq_n_s16(INT8_MIN));
310 25608 v3_s16 = vminq_s16(v3_s16, vdupq_n_s16(INT8_MAX));
311
312 25608 const int8x8_t v0_s8 = vqmovn_s16(v0_s16);
313 25608 const int8x8_t v1_s8 = vqmovn_s16(v1_s16);
314 25608 const int8x8_t v2_s8 = vqmovn_s16(v2_s16);
315 25608 const int8x8_t v3_s8 = vqmovn_s16(v3_s16);
316
317 25608 vst1_s8((int8_t*)(dst_ptr), v0_s8);
318 25608 vst1_s8((int8_t*)(dst_ptr + sizeof(int8x8_t)), v1_s8);
319 25608 vst1_s8((int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)), v2_s8);
320 25608 vst1_s8((int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)), v3_s8);
321 25608 dst_ptr += 4 * sizeof(int8x8_t);
322 25608 }
323
324
2/2
✓ Branch 0 taken 1960 times.
✓ Branch 1 taken 2228 times.
4188 for (; block_idx < num_blocks_k_internal; ++block_idx) {
325 // Left over k
326
2/2
✓ Branch 0 taken 15680 times.
✓ Branch 1 taken 1960 times.
17640 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
327 // Clamp at the last valid k-index.
328
2/2
✓ Branch 0 taken 2744 times.
✓ Branch 1 taken 12936 times.
15680 const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
329
330 15680 const float src0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start));
331 15680 const float src1 = kai_cast_f32_bf16(*(src_ptr + k_idx_start + (lhs_stride / sizeof(uint16_t))));
332 31360 const float src2 =
333 15680 kai_cast_f32_bf16(*(src_ptr + k_idx_start + (2 * (lhs_stride / sizeof(uint16_t)))));
334 31360 const float src3 =
335 15680 kai_cast_f32_bf16(*(src_ptr + k_idx_start + (3 * (lhs_stride / sizeof(uint16_t)))));
336
337 // Scale the value.
338 15680 int32_t v0_s32 = (int32_t)(roundf(src0 * scale0));
339 15680 int32_t v1_s32 = (int32_t)(roundf(src1 * scale1));
340 15680 int32_t v2_s32 = (int32_t)(roundf(src2 * scale2));
341 15680 int32_t v3_s32 = (int32_t)(roundf(src3 * scale3));
342
343 15680 v0_s32 = v0_s32 + nudged_zero_point0;
344
2/2
✓ Branch 0 taken 15620 times.
✓ Branch 1 taken 60 times.
15680 v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
345
2/2
✓ Branch 0 taken 15520 times.
✓ Branch 1 taken 160 times.
15680 v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
346
347 15680 v1_s32 = v1_s32 + nudged_zero_point1;
348
1/2
✓ Branch 0 taken 15680 times.
✗ Branch 1 not taken.
15680 v1_s32 = KAI_MAX(v1_s32, INT8_MIN);
349
2/2
✓ Branch 0 taken 15616 times.
✓ Branch 1 taken 64 times.
15680 v1_s32 = KAI_MIN(v1_s32, INT8_MAX);
350
351 15680 v2_s32 = v2_s32 + nudged_zero_point2;
352
1/2
✓ Branch 0 taken 15680 times.
✗ Branch 1 not taken.
15680 v2_s32 = KAI_MAX(v2_s32, INT8_MIN);
353
2/2
✓ Branch 0 taken 15532 times.
✓ Branch 1 taken 148 times.
15680 v2_s32 = KAI_MIN(v2_s32, INT8_MAX);
354
355 15680 v3_s32 = v3_s32 + nudged_zero_point3;
356
1/2
✓ Branch 0 taken 15680 times.
✗ Branch 1 not taken.
15680 v3_s32 = KAI_MAX(v3_s32, INT8_MIN);
357
2/2
✓ Branch 0 taken 15584 times.
✓ Branch 1 taken 96 times.
15680 v3_s32 = KAI_MIN(v3_s32, INT8_MAX);
358
359 15680 *(int8_t*)dst_ptr = (int8_t)v0_s32;
360 15680 *(int8_t*)(dst_ptr + sizeof(int8x8_t)) = (int8_t)v1_s32;
361 15680 *(int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)) = (int8_t)v2_s32;
362 15680 *(int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)) = (int8_t)v3_s32;
363
364 15680 dst_ptr += sizeof(int8_t);
365 15680 }
366 1960 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
367 1960 }
368
369 2228 uint8_t* dst_base = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));
370
371 2228 dst_ptr = dst_base + dst_x * kai_num_bytes_per_offset;
372
373 // LHS offset at the beginning of the row.
374 2228 *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
375 2228 *((int32_t*)(dst_ptr + kai_num_bytes_per_offset)) = -nudged_zero_point1;
376 2228 *((int32_t*)(dst_ptr + 2 * kai_num_bytes_per_offset)) = -nudged_zero_point2;
377 2228 *((int32_t*)(dst_ptr + 3 * kai_num_bytes_per_offset)) = -nudged_zero_point3;
378
379 // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier.
380 KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
381
382 2228 dst_ptr += mr * kai_num_bytes_per_offset;
383
384 // Store the scale quantization params.
385 2228 *((float*)(dst_ptr)) = recip_scale0;
386 2228 *((float*)(dst_ptr + kai_num_bytes_per_multiplier)) = recip_scale1;
387 2228 *((float*)(dst_ptr + 2 * kai_num_bytes_per_multiplier)) = recip_scale2;
388 2228 *((float*)(dst_ptr + 3 * kai_num_bytes_per_multiplier)) = recip_scale3;
389
390 // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each).
391 2228 src_ptr += (4 * lhs_stride / sizeof(uint16_t));
392
393 // Move to the next row as we have interleaved all Mr rows.
394 2228 lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
395 2228 }
396 496 }
397
398
2/2
✓ Branch 0 taken 9500 times.
✓ Branch 1 taken 960 times.
10460 for (; row_idx < m; ++row_idx) {
399 9500 float max0 = -FLT_MAX;
400 9500 float min0 = FLT_MAX;
401
402 // Find min/max for each channel
403 9500 int32_t k_idx = 0;
404 9500 float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
405 9500 float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);
406 9500 const uint16x8_t zero = vdupq_n_u16(0);
407 // Process 8 bfloat16 values per iteration.
408
2/2
✓ Branch 0 taken 106808 times.
✓ Branch 1 taken 9500 times.
116308 for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
409 // Load eight bfloat16 values.
410 106808 const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx);
411 106808 const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec);
412 106808 const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec);
413 106808 const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1);
414 106808 const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2);
415
416 // Calculate the maximum
417 106808 vmax0 = vmaxq_f32(src0_0, vmax0);
418 106808 vmax0 = vmaxq_f32(vmax0, src0_1);
419
420 // Calculate the minimum
421 106808 vmin0 = vminq_f32(src0_0, vmin0);
422 106808 vmin0 = vminq_f32(vmin0, src0_1);
423 106808 }
424 // Get the max/min scalar values.
425 9500 max0 = vmaxvq_f32(vmax0);
426 9500 min0 = vminvq_f32(vmin0);
427 // Process leftover elements with a scalar loop.
428
2/2
✓ Branch 0 taken 14328 times.
✓ Branch 1 taken 9500 times.
23828 for (; k_idx < (int32_t)k; ++k_idx) {
429 14328 const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx));
430 14328 max0 = fmaxf(src0_0, max0);
431 14328 min0 = fminf(src0_0, min0);
432 14328 }
433
434 // Maximum/minimum int8 values
435 9500 const float qmin = (float)INT8_MIN;
436 9500 const float qmax = (float)INT8_MAX;
437
438 9500 const float rmin0 = fminf(0.0F, min0);
439 9500 const float rmax0 = fmaxf(0.0F, max0);
440
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 9500 times.
9500 const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
441
442 // Reciprocal to quantize
443
1/2
✓ Branch 0 taken 9500 times.
✗ Branch 1 not taken.
9500 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
444
445 9500 const float descaled_min0 = rmin0 * scale0;
446 9500 const float descaled_max0 = rmax0 * scale0;
447
448 9500 const float zero_point_from_min_error0 = qmin + descaled_min0;
449 9500 const float zero_point_from_max_error0 = qmax + descaled_max0;
450
451 19000 float zero_point0 =
452
1/2
✓ Branch 0 taken 9500 times.
✗ Branch 1 not taken.
9500 (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 : qmax - descaled_max0;
453
454 9500 zero_point0 = fmaxf(zero_point0, qmin);
455 9500 zero_point0 = fminf(zero_point0, qmax);
456
457 // Round to nearest integer
458 9500 const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
459
460 9500 const size_t dst_x = ((row_idx + m_idx_start) % mr);
461
462 9500 uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t));
463
464 // Quantize the channels
465 9500 int32_t block_idx = 0;
466
467
2/2
✓ Branch 0 taken 106808 times.
✓ Branch 1 taken 9500 times.
116308 for (; block_idx < num_blocks_k; ++block_idx) {
468 // Clamp at the last valid k-index
469 106808 const int32_t k_idx_start = block_idx * k_block_len;
470
471 // Load eight bfloat16 values and convert them to float32.
472 106808 const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx_start);
473 106808 const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec);
474 106808 const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec);
475 106808 const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1);
476 106808 const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2);
477
478 // Scale the values.
479 106808 const float32x4_t v0_f32 = vmulq_n_f32(src0_0, scale0);
480 106808 const float32x4_t v1_f32 = vmulq_n_f32(src0_1, scale0);
481 106808 const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32);
482 106808 const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32);
483
484 106808 const int16x4_t v0_s16 = vqmovn_s32(v0_s32);
485 106808 const int16x4_t v1_s16 = vqmovn_s32(v1_s32);
486 106808 int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);
487
488 // Add zero points.
489 106808 int16_t nzp_s16 = (int16_t)nudged_zero_point0;
490 106808 int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16);
491 106808 v_s16 = vaddq_s16(v_s16, vnzp_s16);
492 106808 v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
493 106808 v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));
494
495 106808 const int8x8_t v0_s8 = vqmovn_s16(v_s16);
496 106808 vst1_s8((int8_t*)(dst_ptr), v0_s8);
497 106808 dst_ptr += 8 * sizeof(int8_t);
498 106808 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
499 106808 }
500
501
2/2
✓ Branch 0 taken 7960 times.
✓ Branch 1 taken 9500 times.
17460 for (; block_idx < num_blocks_k_internal; ++block_idx) {
502 // Left over k
503
2/2
✓ Branch 0 taken 63680 times.
✓ Branch 1 taken 7960 times.
71640 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
504 // Clamp at the last valid k-index.
505
2/2
✓ Branch 0 taken 11144 times.
✓ Branch 1 taken 52536 times.
63680 const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
506
507 63680 const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start));
508
509 // Scale the value.
510 63680 int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0));
511
512 63680 v0_s32 = v0_s32 + nudged_zero_point0;
513
2/2
✓ Branch 0 taken 63620 times.
✓ Branch 1 taken 60 times.
63680 v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
514
2/2
✓ Branch 0 taken 63224 times.
✓ Branch 1 taken 456 times.
63680 v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
515
516 63680 *((int8_t*)(dst_ptr)) = (int8_t)v0_s32;
517 63680 dst_ptr += sizeof(int8_t);
518 63680 }
519 7960 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
520 7960 }
521
522 9500 dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));
523
524 9500 dst_ptr += dst_x * kai_num_bytes_per_offset;
525
526 // LHS offset at the beginning of the row.
527 9500 *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
528
529 // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier.
530 KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
531
532 9500 dst_ptr += mr * kai_num_bytes_per_offset;
533
534 // Store the scale quantization params.
535 9500 *((float*)(dst_ptr)) = recip_scale0;
536
537 // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each).
538 9500 src_ptr += (lhs_stride / sizeof(uint16_t));
539
540 // Move to the next row if we have interleaved all Mr rows.
541
2/2
✓ Branch 0 taken 484 times.
✓ Branch 1 taken 9016 times.
9500 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
542 9016 lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
543 9016 }
544 9500 }
545 960 }
546