Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #if (!defined(__aarch64__) && !defined(_M_ARM64)) | ||
7 | #error This file must be compiled for AArch64. | ||
8 | #else // Architectural features check. | ||
9 | |||
10 | #include "kai_lhs_quant_pack_qai8dxp_bf16_neon.h" | ||
11 | |||
12 | #include <arm_neon.h> | ||
13 | #endif | ||
14 | #include <float.h> | ||
15 | #include <math.h> | ||
16 | #include <stddef.h> | ||
17 | #include <stdint.h> | ||
18 | |||
19 | #include "kai/kai_common.h" | ||
20 | |||
21 | static const size_t kai_num_bytes_per_multiplier = sizeof(float); | ||
22 | static const size_t kai_num_bytes_per_offset = sizeof(int32_t); | ||
23 | |||
24 | 3840 | inline static size_t kai_k_roundedup(size_t k) { | |
25 | // Round up k to be a multiple of 32. | ||
26 | static const size_t kai_k_multiple_of = 32; | ||
27 | 3840 | return kai_roundup(k, kai_k_multiple_of); | |
28 | } | ||
29 | |||
30 | 2880 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr) { | |
31 | 2880 | const size_t k_internal = kai_k_roundedup(k); | |
32 | |||
33 | − | KAI_ASSERT((k_internal % 2) == 0); | |
34 | |||
35 | 5760 | return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); | |
36 | 2880 | } | |
37 | |||
38 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16_neon(size_t mr) { | |
39 | ✗ | KAI_UNUSED(mr); | |
40 | ✗ | return 1; | |
41 | } | ||
42 | |||
43 | 960 | size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(size_t m_idx, size_t lhs_stride) { | |
44 | 960 | return m_idx * lhs_stride; | |
45 | } | ||
46 | |||
47 | 960 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon( | |
48 | size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | ||
49 | 960 | KAI_UNUSED(kr); | |
50 | 960 | KAI_UNUSED(sr); | |
51 | // It always points to the beginning of the row | ||
52 | 960 | return (m_idx / mr) * kai_lhs_packed_stride(k, mr); | |
53 | } | ||
54 | |||
55 | 960 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
56 | 960 | KAI_UNUSED(kr); | |
57 | 960 | KAI_UNUSED(sr); | |
58 | 960 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
59 | |||
60 | 1920 | return num_rows * kai_lhs_packed_stride(k, mr); | |
61 | 960 | } | |
62 | |||
63 | // Note: The lhs parameter type has been changed from float* to void*. | ||
64 | // The bfloat16 values (packed in 16 bits) will be converted to float32. | ||
65 | 960 | void kai_run_lhs_quant_pack_qai8dxp_bf16_neon( | |
66 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs, | ||
67 | size_t lhs_stride, void* restrict lhs_packed) { | ||
68 | − | KAI_ASSERT((kr % sr) == 0); | |
69 | |||
70 | 1/2 ✓ Branch 0 taken 960 times. ✗ Branch 1 not taken. | 960 | if (m == 0) { |
71 | ✗ | return; | |
72 | } | ||
73 | |||
74 | // Now lhs is assumed to contain bfloat16 values encoded in uint16_t. | ||
75 | 960 | const uint16_t* src_ptr = (uint16_t const*)lhs; | |
76 | |||
77 | 960 | const size_t dst_stride = kai_lhs_packed_stride(k, mr); | |
78 | 960 | const size_t k_internal = kai_k_roundedup(k); | |
79 | 960 | const int32_t k_block_len = (int32_t)(kr / sr); | |
80 | − | KAI_ASSERT(k_block_len == 8); | |
81 | |||
82 | 960 | const int32_t num_blocks_k = (int32_t)(k / k_block_len); | |
83 | 960 | const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len); | |
84 | |||
85 | 960 | size_t row_idx = 0; | |
86 | |||
87 | 2/2 ✓ Branch 0 taken 464 times. ✓ Branch 1 taken 496 times. | 960 | if (mr == 4) { |
88 | 2/2 ✓ Branch 0 taken 2228 times. ✓ Branch 1 taken 496 times. | 2724 | for (; row_idx + 3 < m; row_idx += 4) { |
89 | 2228 | float max0 = -FLT_MAX; | |
90 | 2228 | float min0 = FLT_MAX; | |
91 | 2228 | float max1 = -FLT_MAX; | |
92 | 2228 | float min1 = FLT_MAX; | |
93 | 2228 | float max2 = -FLT_MAX; | |
94 | 2228 | float min2 = FLT_MAX; | |
95 | 2228 | float max3 = -FLT_MAX; | |
96 | 2228 | float min3 = FLT_MAX; | |
97 | |||
98 | // Find min/max for each channel | ||
99 | 2228 | int32_t k_idx = 0; | |
100 | 2228 | float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); | |
101 | 2228 | float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); | |
102 | 2228 | float32x4_t vmax1 = vmax0; | |
103 | 2228 | float32x4_t vmin1 = vmin0; | |
104 | 2228 | float32x4_t vmax2 = vmax0; | |
105 | 2228 | float32x4_t vmin2 = vmin0; | |
106 | 2228 | float32x4_t vmax3 = vmax0; | |
107 | 2228 | float32x4_t vmin3 = vmin0; | |
108 | 2228 | const uint16x8_t zero = vdupq_n_u16(0); | |
109 | // Process 8 bfloat16 values per iteration. | ||
110 | 2/2 ✓ Branch 0 taken 25608 times. ✓ Branch 1 taken 2228 times. | 27836 | for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { |
111 | // Load eight bfloat16 values. | ||
112 | 25608 | const uint16x8_t bf16_vec_0 = vld1q_u16(src_ptr + k_idx); | |
113 | 25608 | const uint16x8_t bf16_vec_1 = vld1q_u16(src_ptr + k_idx + (lhs_stride / sizeof(uint16_t))); | |
114 | 25608 | const uint16x8_t bf16_vec_2 = vld1q_u16(src_ptr + k_idx + (2 * (lhs_stride / sizeof(uint16_t)))); | |
115 | 25608 | const uint16x8_t bf16_vec_3 = vld1q_u16(src_ptr + k_idx + (3 * (lhs_stride / sizeof(uint16_t)))); | |
116 | |||
117 | 25608 | const uint16x8_t bf16_vec1_0 = vzip1q_u16(zero, bf16_vec_0); | |
118 | 25608 | const uint16x8_t bf16_vec2_0 = vzip2q_u16(zero, bf16_vec_0); | |
119 | 25608 | const uint16x8_t bf16_vec1_1 = vzip1q_u16(zero, bf16_vec_1); | |
120 | 25608 | const uint16x8_t bf16_vec2_1 = vzip2q_u16(zero, bf16_vec_1); | |
121 | 25608 | const uint16x8_t bf16_vec1_2 = vzip1q_u16(zero, bf16_vec_2); | |
122 | 25608 | const uint16x8_t bf16_vec2_2 = vzip2q_u16(zero, bf16_vec_2); | |
123 | 25608 | const uint16x8_t bf16_vec1_3 = vzip1q_u16(zero, bf16_vec_3); | |
124 | 25608 | const uint16x8_t bf16_vec2_3 = vzip2q_u16(zero, bf16_vec_3); | |
125 | |||
126 | 25608 | const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1_0); | |
127 | 25608 | const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2_0); | |
128 | 25608 | const float32x4_t src1_0 = vreinterpretq_f32_u16(bf16_vec1_1); | |
129 | 25608 | const float32x4_t src1_1 = vreinterpretq_f32_u16(bf16_vec2_1); | |
130 | 25608 | const float32x4_t src2_0 = vreinterpretq_f32_u16(bf16_vec1_2); | |
131 | 25608 | const float32x4_t src2_1 = vreinterpretq_f32_u16(bf16_vec2_2); | |
132 | 25608 | const float32x4_t src3_0 = vreinterpretq_f32_u16(bf16_vec1_3); | |
133 | 25608 | const float32x4_t src3_1 = vreinterpretq_f32_u16(bf16_vec2_3); | |
134 | |||
135 | // Calculate the maximum | ||
136 | 25608 | vmax0 = vmaxq_f32(src0_0, vmax0); | |
137 | 25608 | vmax0 = vmaxq_f32(vmax0, src0_1); | |
138 | 25608 | vmax1 = vmaxq_f32(src1_0, vmax1); | |
139 | 25608 | vmax1 = vmaxq_f32(vmax1, src1_1); | |
140 | 25608 | vmax2 = vmaxq_f32(src2_0, vmax2); | |
141 | 25608 | vmax2 = vmaxq_f32(vmax2, src2_1); | |
142 | 25608 | vmax3 = vmaxq_f32(src3_0, vmax3); | |
143 | 25608 | vmax3 = vmaxq_f32(vmax3, src3_1); | |
144 | |||
145 | // Calculate the minimum | ||
146 | 25608 | vmin0 = vminq_f32(src0_0, vmin0); | |
147 | 25608 | vmin0 = vminq_f32(vmin0, src0_1); | |
148 | 25608 | vmin1 = vminq_f32(src1_0, vmin1); | |
149 | 25608 | vmin1 = vminq_f32(vmin1, src1_1); | |
150 | 25608 | vmin2 = vminq_f32(src2_0, vmin2); | |
151 | 25608 | vmin2 = vminq_f32(vmin2, src2_1); | |
152 | 25608 | vmin3 = vminq_f32(src3_0, vmin3); | |
153 | 25608 | vmin3 = vminq_f32(vmin3, src3_1); | |
154 | 25608 | } | |
155 | // Get the max/min scalar values. | ||
156 | 2228 | max0 = vmaxvq_f32(vmax0); | |
157 | 2228 | min0 = vminvq_f32(vmin0); | |
158 | 2228 | max1 = vmaxvq_f32(vmax1); | |
159 | 2228 | min1 = vminvq_f32(vmin1); | |
160 | 2228 | max2 = vmaxvq_f32(vmax2); | |
161 | 2228 | min2 = vminvq_f32(vmin2); | |
162 | 2228 | max3 = vmaxvq_f32(vmax3); | |
163 | 2228 | min3 = vminvq_f32(vmin3); | |
164 | // Process leftover elements with a scalar loop. | ||
165 | 2/2 ✓ Branch 0 taken 3528 times. ✓ Branch 1 taken 2228 times. | 5756 | for (; k_idx < (int32_t)k; ++k_idx) { |
166 | 3528 | const float src0 = kai_cast_f32_bf16(*(src_ptr + k_idx)); | |
167 | 3528 | max0 = fmaxf(src0, max0); | |
168 | 3528 | min0 = fminf(src0, min0); | |
169 | 3528 | const float src1 = kai_cast_f32_bf16(*(src_ptr + k_idx + (lhs_stride / sizeof(uint16_t)))); | |
170 | 3528 | max1 = fmaxf(src1, max1); | |
171 | 3528 | min1 = fminf(src1, min1); | |
172 | 3528 | const float src2 = kai_cast_f32_bf16(*(src_ptr + k_idx + (2 * (lhs_stride / sizeof(uint16_t))))); | |
173 | 3528 | max2 = fmaxf(src2, max2); | |
174 | 3528 | min2 = fminf(src2, min2); | |
175 | 3528 | const float src3 = kai_cast_f32_bf16(*(src_ptr + k_idx + (3 * (lhs_stride / sizeof(uint16_t))))); | |
176 | 3528 | max3 = fmaxf(src3, max3); | |
177 | 3528 | min3 = fminf(src3, min3); | |
178 | 3528 | } | |
179 | |||
180 | // Maximum/minimum int8 values | ||
181 | 2228 | const float qmin = (float)INT8_MIN; | |
182 | 2228 | const float qmax = (float)INT8_MAX; | |
183 | |||
184 | 2228 | const float rmin0 = fminf(0.0F, min0); | |
185 | 2228 | const float rmax0 = fmaxf(0.0F, max0); | |
186 | 1/2 ✗ Branch 0 not taken. ✓ Branch 1 taken 2228 times. | 2228 | const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); |
187 | 2228 | const float rmin1 = fminf(0.0F, min1); | |
188 | 2228 | const float rmax1 = fmaxf(0.0F, max1); | |
189 | 1/2 ✗ Branch 0 not taken. ✓ Branch 1 taken 2228 times. | 2228 | const float scale1 = rmin1 == rmax1 ? 1.F : (qmax - qmin) / (rmax1 - rmin1); |
190 | 2228 | const float rmin2 = fminf(0.0F, min2); | |
191 | 2228 | const float rmax2 = fmaxf(0.0F, max2); | |
192 | 1/2 ✗ Branch 0 not taken. ✓ Branch 1 taken 2228 times. | 2228 | const float scale2 = rmin2 == rmax2 ? 1.F : (qmax - qmin) / (rmax2 - rmin2); |
193 | 2228 | const float rmin3 = fminf(0.0F, min3); | |
194 | 2228 | const float rmax3 = fmaxf(0.0F, max3); | |
195 | 1/2 ✗ Branch 0 not taken. ✓ Branch 1 taken 2228 times. | 2228 | const float scale3 = rmin3 == rmax3 ? 1.F : (qmax - qmin) / (rmax3 - rmin3); |
196 | |||
197 | // Reciprocal to quantize | ||
198 | 1/2 ✓ Branch 0 taken 2228 times. ✗ Branch 1 not taken. | 2228 | const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; |
199 | 1/2 ✓ Branch 0 taken 2228 times. ✗ Branch 1 not taken. | 2228 | const float recip_scale1 = scale1 ? 1.0F / scale1 : 0.0F; |
200 | 1/2 ✓ Branch 0 taken 2228 times. ✗ Branch 1 not taken. | 2228 | const float recip_scale2 = scale2 ? 1.0F / scale2 : 0.0F; |
201 | 1/2 ✓ Branch 0 taken 2228 times. ✗ Branch 1 not taken. | 2228 | const float recip_scale3 = scale3 ? 1.0F / scale3 : 0.0F; |
202 | |||
203 | 2228 | const float descaled_min0 = rmin0 * scale0; | |
204 | 2228 | const float descaled_max0 = rmax0 * scale0; | |
205 | 2228 | const float descaled_min1 = rmin1 * scale1; | |
206 | 2228 | const float descaled_max1 = rmax1 * scale1; | |
207 | 2228 | const float descaled_min2 = rmin2 * scale2; | |
208 | 2228 | const float descaled_max2 = rmax2 * scale2; | |
209 | 2228 | const float descaled_min3 = rmin3 * scale3; | |
210 | 2228 | const float descaled_max3 = rmax3 * scale3; | |
211 | |||
212 | 2228 | const float zero_point_from_min_error0 = qmin + descaled_min0; | |
213 | 2228 | const float zero_point_from_max_error0 = qmax + descaled_max0; | |
214 | 2228 | const float zero_point_from_min_error1 = qmin + descaled_min1; | |
215 | 2228 | const float zero_point_from_max_error1 = qmax + descaled_max1; | |
216 | 2228 | const float zero_point_from_min_error2 = qmin + descaled_min2; | |
217 | 2228 | const float zero_point_from_max_error2 = qmax + descaled_max2; | |
218 | 2228 | const float zero_point_from_min_error3 = qmin + descaled_min3; | |
219 | 2228 | const float zero_point_from_max_error3 = qmax + descaled_max3; | |
220 | |||
221 | 1/2 ✓ Branch 0 taken 2228 times. ✗ Branch 1 not taken. | 2228 | float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 |
222 | ✗ | : qmax - descaled_max0; | |
223 | 1/2 ✓ Branch 0 taken 2228 times. ✗ Branch 1 not taken. | 2228 | float zero_point1 = (zero_point_from_min_error1 + zero_point_from_max_error1 > 0) ? qmin - descaled_min1 |
224 | ✗ | : qmax - descaled_max1; | |
225 | 1/2 ✓ Branch 0 taken 2228 times. ✗ Branch 1 not taken. | 2228 | float zero_point2 = (zero_point_from_min_error2 + zero_point_from_max_error2 > 0) ? qmin - descaled_min2 |
226 | ✗ | : qmax - descaled_max2; | |
227 | 1/2 ✓ Branch 0 taken 2228 times. ✗ Branch 1 not taken. | 2228 | float zero_point3 = (zero_point_from_min_error3 + zero_point_from_max_error3 > 0) ? qmin - descaled_min3 |
228 | ✗ | : qmax - descaled_max3; | |
229 | |||
230 | 2228 | zero_point0 = fmaxf(zero_point0, qmin); | |
231 | 2228 | zero_point0 = fminf(zero_point0, qmax); | |
232 | 2228 | zero_point1 = fmaxf(zero_point1, qmin); | |
233 | 2228 | zero_point1 = fminf(zero_point1, qmax); | |
234 | 2228 | zero_point2 = fmaxf(zero_point2, qmin); | |
235 | 2228 | zero_point2 = fminf(zero_point2, qmax); | |
236 | 2228 | zero_point3 = fmaxf(zero_point3, qmin); | |
237 | 2228 | zero_point3 = fminf(zero_point3, qmax); | |
238 | |||
239 | // Round to nearest integer | ||
240 | 2228 | const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); | |
241 | 2228 | const int32_t nudged_zero_point1 = (int32_t)rintf(zero_point1); | |
242 | 2228 | const int32_t nudged_zero_point2 = (int32_t)rintf(zero_point2); | |
243 | 2228 | const int32_t nudged_zero_point3 = (int32_t)rintf(zero_point3); | |
244 | |||
245 | 2228 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
246 | |||
247 | 2228 | uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len); | |
248 | |||
249 | // Quantize the channels | ||
250 | 2228 | int32_t block_idx = 0; | |
251 | |||
252 | 2/2 ✓ Branch 0 taken 25608 times. ✓ Branch 1 taken 2228 times. | 27836 | for (; block_idx < num_blocks_k; ++block_idx) { |
253 | // Clamp at the last valid k-index | ||
254 | 25608 | const int32_t k_idx_start = block_idx * k_block_len; | |
255 | |||
256 | // Load eight bfloat16 values and convert them to float32. | ||
257 | 25608 | const uint16x8_t bf16_vec_0 = vld1q_u16(src_ptr + k_idx_start); | |
258 | 25608 | const uint16x8_t bf16_vec_1 = vld1q_u16(src_ptr + k_idx_start + (lhs_stride / sizeof(uint16_t))); | |
259 | 25608 | const uint16x8_t bf16_vec_2 = vld1q_u16(src_ptr + k_idx_start + (2 * (lhs_stride / sizeof(uint16_t)))); | |
260 | 25608 | const uint16x8_t bf16_vec_3 = vld1q_u16(src_ptr + k_idx_start + (3 * (lhs_stride / sizeof(uint16_t)))); | |
261 | 25608 | const uint16x8_t bf16_vec1_0 = vzip1q_u16(zero, bf16_vec_0); | |
262 | 25608 | const uint16x8_t bf16_vec2_0 = vzip2q_u16(zero, bf16_vec_0); | |
263 | 25608 | const uint16x8_t bf16_vec1_1 = vzip1q_u16(zero, bf16_vec_1); | |
264 | 25608 | const uint16x8_t bf16_vec2_1 = vzip2q_u16(zero, bf16_vec_1); | |
265 | 25608 | const uint16x8_t bf16_vec1_2 = vzip1q_u16(zero, bf16_vec_2); | |
266 | 25608 | const uint16x8_t bf16_vec2_2 = vzip2q_u16(zero, bf16_vec_2); | |
267 | 25608 | const uint16x8_t bf16_vec1_3 = vzip1q_u16(zero, bf16_vec_3); | |
268 | 25608 | const uint16x8_t bf16_vec2_3 = vzip2q_u16(zero, bf16_vec_3); | |
269 | 25608 | const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1_0); | |
270 | 25608 | const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2_0); | |
271 | 25608 | const float32x4_t src1_0 = vreinterpretq_f32_u16(bf16_vec1_1); | |
272 | 25608 | const float32x4_t src1_1 = vreinterpretq_f32_u16(bf16_vec2_1); | |
273 | 25608 | const float32x4_t src2_0 = vreinterpretq_f32_u16(bf16_vec1_2); | |
274 | 25608 | const float32x4_t src2_1 = vreinterpretq_f32_u16(bf16_vec2_2); | |
275 | 25608 | const float32x4_t src3_0 = vreinterpretq_f32_u16(bf16_vec1_3); | |
276 | 25608 | const float32x4_t src3_1 = vreinterpretq_f32_u16(bf16_vec2_3); | |
277 | |||
278 | // Scale the values. | ||
279 | 25608 | const int16x4_t v0_0 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src0_0, scale0))); | |
280 | 25608 | const int16x4_t v1_0 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src0_1, scale0))); | |
281 | 25608 | const int16x4_t v0_1 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src1_0, scale1))); | |
282 | 25608 | const int16x4_t v1_1 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src1_1, scale1))); | |
283 | 25608 | const int16x4_t v0_2 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src2_0, scale2))); | |
284 | 25608 | const int16x4_t v1_2 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src2_1, scale2))); | |
285 | 25608 | const int16x4_t v0_3 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src3_0, scale3))); | |
286 | 25608 | const int16x4_t v1_3 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src3_1, scale3))); | |
287 | |||
288 | 25608 | int16x8_t v0_s16 = vcombine_s16(v0_0, v1_0); | |
289 | 25608 | int16x8_t v1_s16 = vcombine_s16(v0_1, v1_1); | |
290 | 25608 | int16x8_t v2_s16 = vcombine_s16(v0_2, v1_2); | |
291 | 25608 | int16x8_t v3_s16 = vcombine_s16(v0_3, v1_3); | |
292 | |||
293 | // Add zero points. | ||
294 | 25608 | const int16x8_t vnzp0 = vdupq_n_s16((int16_t)nudged_zero_point0); | |
295 | 25608 | const int16x8_t vnzp1 = vdupq_n_s16((int16_t)nudged_zero_point1); | |
296 | 25608 | const int16x8_t vnzp2 = vdupq_n_s16((int16_t)nudged_zero_point2); | |
297 | 25608 | const int16x8_t vnzp3 = vdupq_n_s16((int16_t)nudged_zero_point3); | |
298 | |||
299 | 25608 | v0_s16 = vaddq_s16(v0_s16, vnzp0); | |
300 | 25608 | v0_s16 = vmaxq_s16(v0_s16, vdupq_n_s16(INT8_MIN)); | |
301 | 25608 | v0_s16 = vminq_s16(v0_s16, vdupq_n_s16(INT8_MAX)); | |
302 | 25608 | v1_s16 = vaddq_s16(v1_s16, vnzp1); | |
303 | 25608 | v1_s16 = vmaxq_s16(v1_s16, vdupq_n_s16(INT8_MIN)); | |
304 | 25608 | v1_s16 = vminq_s16(v1_s16, vdupq_n_s16(INT8_MAX)); | |
305 | 25608 | v2_s16 = vaddq_s16(v2_s16, vnzp2); | |
306 | 25608 | v2_s16 = vmaxq_s16(v2_s16, vdupq_n_s16(INT8_MIN)); | |
307 | 25608 | v2_s16 = vminq_s16(v2_s16, vdupq_n_s16(INT8_MAX)); | |
308 | 25608 | v3_s16 = vaddq_s16(v3_s16, vnzp3); | |
309 | 25608 | v3_s16 = vmaxq_s16(v3_s16, vdupq_n_s16(INT8_MIN)); | |
310 | 25608 | v3_s16 = vminq_s16(v3_s16, vdupq_n_s16(INT8_MAX)); | |
311 | |||
312 | 25608 | const int8x8_t v0_s8 = vqmovn_s16(v0_s16); | |
313 | 25608 | const int8x8_t v1_s8 = vqmovn_s16(v1_s16); | |
314 | 25608 | const int8x8_t v2_s8 = vqmovn_s16(v2_s16); | |
315 | 25608 | const int8x8_t v3_s8 = vqmovn_s16(v3_s16); | |
316 | |||
317 | 25608 | vst1_s8((int8_t*)(dst_ptr), v0_s8); | |
318 | 25608 | vst1_s8((int8_t*)(dst_ptr + sizeof(int8x8_t)), v1_s8); | |
319 | 25608 | vst1_s8((int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)), v2_s8); | |
320 | 25608 | vst1_s8((int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)), v3_s8); | |
321 | 25608 | dst_ptr += 4 * sizeof(int8x8_t); | |
322 | 25608 | } | |
323 | |||
324 | 2/2 ✓ Branch 0 taken 1960 times. ✓ Branch 1 taken 2228 times. | 4188 | for (; block_idx < num_blocks_k_internal; ++block_idx) { |
325 | // Left over k | ||
326 | 2/2 ✓ Branch 0 taken 15680 times. ✓ Branch 1 taken 1960 times. | 17640 | for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
327 | // Clamp at the last valid k-index. | ||
328 | 2/2 ✓ Branch 0 taken 2744 times. ✓ Branch 1 taken 12936 times. | 15680 | const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); |
329 | |||
330 | 15680 | const float src0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); | |
331 | 15680 | const float src1 = kai_cast_f32_bf16(*(src_ptr + k_idx_start + (lhs_stride / sizeof(uint16_t)))); | |
332 | 31360 | const float src2 = | |
333 | 15680 | kai_cast_f32_bf16(*(src_ptr + k_idx_start + (2 * (lhs_stride / sizeof(uint16_t))))); | |
334 | 31360 | const float src3 = | |
335 | 15680 | kai_cast_f32_bf16(*(src_ptr + k_idx_start + (3 * (lhs_stride / sizeof(uint16_t))))); | |
336 | |||
337 | // Scale the value. | ||
338 | 15680 | int32_t v0_s32 = (int32_t)(roundf(src0 * scale0)); | |
339 | 15680 | int32_t v1_s32 = (int32_t)(roundf(src1 * scale1)); | |
340 | 15680 | int32_t v2_s32 = (int32_t)(roundf(src2 * scale2)); | |
341 | 15680 | int32_t v3_s32 = (int32_t)(roundf(src3 * scale3)); | |
342 | |||
343 | 15680 | v0_s32 = v0_s32 + nudged_zero_point0; | |
344 | 2/2 ✓ Branch 0 taken 15620 times. ✓ Branch 1 taken 60 times. | 15680 | v0_s32 = KAI_MAX(v0_s32, INT8_MIN); |
345 | 2/2 ✓ Branch 0 taken 15520 times. ✓ Branch 1 taken 160 times. | 15680 | v0_s32 = KAI_MIN(v0_s32, INT8_MAX); |
346 | |||
347 | 15680 | v1_s32 = v1_s32 + nudged_zero_point1; | |
348 | 1/2 ✓ Branch 0 taken 15680 times. ✗ Branch 1 not taken. | 15680 | v1_s32 = KAI_MAX(v1_s32, INT8_MIN); |
349 | 2/2 ✓ Branch 0 taken 15616 times. ✓ Branch 1 taken 64 times. | 15680 | v1_s32 = KAI_MIN(v1_s32, INT8_MAX); |
350 | |||
351 | 15680 | v2_s32 = v2_s32 + nudged_zero_point2; | |
352 | 1/2 ✓ Branch 0 taken 15680 times. ✗ Branch 1 not taken. | 15680 | v2_s32 = KAI_MAX(v2_s32, INT8_MIN); |
353 | 2/2 ✓ Branch 0 taken 15532 times. ✓ Branch 1 taken 148 times. | 15680 | v2_s32 = KAI_MIN(v2_s32, INT8_MAX); |
354 | |||
355 | 15680 | v3_s32 = v3_s32 + nudged_zero_point3; | |
356 | 1/2 ✓ Branch 0 taken 15680 times. ✗ Branch 1 not taken. | 15680 | v3_s32 = KAI_MAX(v3_s32, INT8_MIN); |
357 | 2/2 ✓ Branch 0 taken 15584 times. ✓ Branch 1 taken 96 times. | 15680 | v3_s32 = KAI_MIN(v3_s32, INT8_MAX); |
358 | |||
359 | 15680 | *(int8_t*)dst_ptr = (int8_t)v0_s32; | |
360 | 15680 | *(int8_t*)(dst_ptr + sizeof(int8x8_t)) = (int8_t)v1_s32; | |
361 | 15680 | *(int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)) = (int8_t)v2_s32; | |
362 | 15680 | *(int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)) = (int8_t)v3_s32; | |
363 | |||
364 | 15680 | dst_ptr += sizeof(int8_t); | |
365 | 15680 | } | |
366 | 1960 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
367 | 1960 | } | |
368 | |||
369 | 2228 | uint8_t* dst_base = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); | |
370 | |||
371 | 2228 | dst_ptr = dst_base + dst_x * kai_num_bytes_per_offset; | |
372 | |||
373 | // LHS offset at the beginning of the row. | ||
374 | 2228 | *((int32_t*)(dst_ptr)) = -nudged_zero_point0; | |
375 | 2228 | *((int32_t*)(dst_ptr + kai_num_bytes_per_offset)) = -nudged_zero_point1; | |
376 | 2228 | *((int32_t*)(dst_ptr + 2 * kai_num_bytes_per_offset)) = -nudged_zero_point2; | |
377 | 2228 | *((int32_t*)(dst_ptr + 3 * kai_num_bytes_per_offset)) = -nudged_zero_point3; | |
378 | |||
379 | // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. | ||
380 | − | KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); | |
381 | |||
382 | 2228 | dst_ptr += mr * kai_num_bytes_per_offset; | |
383 | |||
384 | // Store the scale quantization params. | ||
385 | 2228 | *((float*)(dst_ptr)) = recip_scale0; | |
386 | 2228 | *((float*)(dst_ptr + kai_num_bytes_per_multiplier)) = recip_scale1; | |
387 | 2228 | *((float*)(dst_ptr + 2 * kai_num_bytes_per_multiplier)) = recip_scale2; | |
388 | 2228 | *((float*)(dst_ptr + 3 * kai_num_bytes_per_multiplier)) = recip_scale3; | |
389 | |||
390 | // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each). | ||
391 | 2228 | src_ptr += (4 * lhs_stride / sizeof(uint16_t)); | |
392 | |||
393 | // Move to the next row as we have interleaved all Mr rows. | ||
394 | 2228 | lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); | |
395 | 2228 | } | |
396 | 496 | } | |
397 | |||
398 | 2/2 ✓ Branch 0 taken 9500 times. ✓ Branch 1 taken 960 times. | 10460 | for (; row_idx < m; ++row_idx) { |
399 | 9500 | float max0 = -FLT_MAX; | |
400 | 9500 | float min0 = FLT_MAX; | |
401 | |||
402 | // Find min/max for each channel | ||
403 | 9500 | int32_t k_idx = 0; | |
404 | 9500 | float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); | |
405 | 9500 | float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); | |
406 | 9500 | const uint16x8_t zero = vdupq_n_u16(0); | |
407 | // Process 8 bfloat16 values per iteration. | ||
408 | 2/2 ✓ Branch 0 taken 106808 times. ✓ Branch 1 taken 9500 times. | 116308 | for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { |
409 | // Load eight bfloat16 values. | ||
410 | 106808 | const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx); | |
411 | 106808 | const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); | |
412 | 106808 | const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); | |
413 | 106808 | const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); | |
414 | 106808 | const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); | |
415 | |||
416 | // Calculate the maximum | ||
417 | 106808 | vmax0 = vmaxq_f32(src0_0, vmax0); | |
418 | 106808 | vmax0 = vmaxq_f32(vmax0, src0_1); | |
419 | |||
420 | // Calculate the minimum | ||
421 | 106808 | vmin0 = vminq_f32(src0_0, vmin0); | |
422 | 106808 | vmin0 = vminq_f32(vmin0, src0_1); | |
423 | 106808 | } | |
424 | // Get the max/min scalar values. | ||
425 | 9500 | max0 = vmaxvq_f32(vmax0); | |
426 | 9500 | min0 = vminvq_f32(vmin0); | |
427 | // Process leftover elements with a scalar loop. | ||
428 | 2/2 ✓ Branch 0 taken 14328 times. ✓ Branch 1 taken 9500 times. | 23828 | for (; k_idx < (int32_t)k; ++k_idx) { |
429 | 14328 | const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx)); | |
430 | 14328 | max0 = fmaxf(src0_0, max0); | |
431 | 14328 | min0 = fminf(src0_0, min0); | |
432 | 14328 | } | |
433 | |||
434 | // Maximum/minimum int8 values | ||
435 | 9500 | const float qmin = (float)INT8_MIN; | |
436 | 9500 | const float qmax = (float)INT8_MAX; | |
437 | |||
438 | 9500 | const float rmin0 = fminf(0.0F, min0); | |
439 | 9500 | const float rmax0 = fmaxf(0.0F, max0); | |
440 | 1/2 ✗ Branch 0 not taken. ✓ Branch 1 taken 9500 times. | 9500 | const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); |
441 | |||
442 | // Reciprocal to quantize | ||
443 | 1/2 ✓ Branch 0 taken 9500 times. ✗ Branch 1 not taken. | 9500 | const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; |
444 | |||
445 | 9500 | const float descaled_min0 = rmin0 * scale0; | |
446 | 9500 | const float descaled_max0 = rmax0 * scale0; | |
447 | |||
448 | 9500 | const float zero_point_from_min_error0 = qmin + descaled_min0; | |
449 | 9500 | const float zero_point_from_max_error0 = qmax + descaled_max0; | |
450 | |||
451 | 19000 | float zero_point0 = | |
452 | 1/2 ✓ Branch 0 taken 9500 times. ✗ Branch 1 not taken. | 9500 | (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 : qmax - descaled_max0; |
453 | |||
454 | 9500 | zero_point0 = fmaxf(zero_point0, qmin); | |
455 | 9500 | zero_point0 = fminf(zero_point0, qmax); | |
456 | |||
457 | // Round to nearest integer | ||
458 | 9500 | const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); | |
459 | |||
460 | 9500 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
461 | |||
462 | 9500 | uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); | |
463 | |||
464 | // Quantize the channels | ||
465 | 9500 | int32_t block_idx = 0; | |
466 | |||
467 | 2/2 ✓ Branch 0 taken 106808 times. ✓ Branch 1 taken 9500 times. | 116308 | for (; block_idx < num_blocks_k; ++block_idx) { |
468 | // Clamp at the last valid k-index | ||
469 | 106808 | const int32_t k_idx_start = block_idx * k_block_len; | |
470 | |||
471 | // Load eight bfloat16 values and convert them to float32. | ||
472 | 106808 | const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx_start); | |
473 | 106808 | const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); | |
474 | 106808 | const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); | |
475 | 106808 | const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); | |
476 | 106808 | const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); | |
477 | |||
478 | // Scale the values. | ||
479 | 106808 | const float32x4_t v0_f32 = vmulq_n_f32(src0_0, scale0); | |
480 | 106808 | const float32x4_t v1_f32 = vmulq_n_f32(src0_1, scale0); | |
481 | 106808 | const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); | |
482 | 106808 | const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); | |
483 | |||
484 | 106808 | const int16x4_t v0_s16 = vqmovn_s32(v0_s32); | |
485 | 106808 | const int16x4_t v1_s16 = vqmovn_s32(v1_s32); | |
486 | 106808 | int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); | |
487 | |||
488 | // Add zero points. | ||
489 | 106808 | int16_t nzp_s16 = (int16_t)nudged_zero_point0; | |
490 | 106808 | int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16); | |
491 | 106808 | v_s16 = vaddq_s16(v_s16, vnzp_s16); | |
492 | 106808 | v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); | |
493 | 106808 | v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); | |
494 | |||
495 | 106808 | const int8x8_t v0_s8 = vqmovn_s16(v_s16); | |
496 | 106808 | vst1_s8((int8_t*)(dst_ptr), v0_s8); | |
497 | 106808 | dst_ptr += 8 * sizeof(int8_t); | |
498 | 106808 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
499 | 106808 | } | |
500 | |||
501 | 2/2 ✓ Branch 0 taken 7960 times. ✓ Branch 1 taken 9500 times. | 17460 | for (; block_idx < num_blocks_k_internal; ++block_idx) { |
502 | // Left over k | ||
503 | 2/2 ✓ Branch 0 taken 63680 times. ✓ Branch 1 taken 7960 times. | 71640 | for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
504 | // Clamp at the last valid k-index. | ||
505 | 2/2 ✓ Branch 0 taken 11144 times. ✓ Branch 1 taken 52536 times. | 63680 | const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); |
506 | |||
507 | 63680 | const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); | |
508 | |||
509 | // Scale the value. | ||
510 | 63680 | int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); | |
511 | |||
512 | 63680 | v0_s32 = v0_s32 + nudged_zero_point0; | |
513 | 2/2 ✓ Branch 0 taken 63620 times. ✓ Branch 1 taken 60 times. | 63680 | v0_s32 = KAI_MAX(v0_s32, INT8_MIN); |
514 | 2/2 ✓ Branch 0 taken 63224 times. ✓ Branch 1 taken 456 times. | 63680 | v0_s32 = KAI_MIN(v0_s32, INT8_MAX); |
515 | |||
516 | 63680 | *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; | |
517 | 63680 | dst_ptr += sizeof(int8_t); | |
518 | 63680 | } | |
519 | 7960 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
520 | 7960 | } | |
521 | |||
522 | 9500 | dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); | |
523 | |||
524 | 9500 | dst_ptr += dst_x * kai_num_bytes_per_offset; | |
525 | |||
526 | // LHS offset at the beginning of the row. | ||
527 | 9500 | *((int32_t*)(dst_ptr)) = -nudged_zero_point0; | |
528 | |||
529 | // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. | ||
530 | − | KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); | |
531 | |||
532 | 9500 | dst_ptr += mr * kai_num_bytes_per_offset; | |
533 | |||
534 | // Store the scale quantization params. | ||
535 | 9500 | *((float*)(dst_ptr)) = recip_scale0; | |
536 | |||
537 | // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each). | ||
538 | 9500 | src_ptr += (lhs_stride / sizeof(uint16_t)); | |
539 | |||
540 | // Move to the next row if we have interleaved all Mr rows. | ||
541 | 2/2 ✓ Branch 0 taken 484 times. ✓ Branch 1 taken 9016 times. | 9500 | if ((((row_idx + 1) + m_idx_start) % mr) == 0) { |
542 | 9016 | lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); | |
543 | 9016 | } | |
544 | 9500 | } | |
545 | 960 | } | |
546 |
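
A minimal sketch, assuming only standard C, of what `kai_lhs_packed_stride` (source lines 30-36) computes: each block of `mr` rows occupies `mr * k_internal` bytes of int8 quantized data plus `mr` int32 offsets and `mr` float32 multipliers, where `k_internal` is `k` rounded up to a multiple of 32. The helper names below are hypothetical, not part of the library API.

```c
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

// Round a up to the next multiple of b (stand-in for kai_roundup in this sketch).
static size_t roundup(size_t a, size_t b) {
    return ((a + b - 1) / b) * b;
}

// Bytes occupied by one packed block of mr rows, mirroring kai_lhs_packed_stride.
static size_t packed_stride(size_t k, size_t mr) {
    const size_t k_internal = roundup(k, 32);
    return mr * (k_internal * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
}

int main(void) {
    // k = 17 rounds up to 32, so an mr = 4 block takes 4 * (32 + 4 + 4) = 160 bytes.
    printf("%zu\n", packed_stride(17, 4));
    return 0;
}
```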
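The min/max and quantization loops (source lines 110-133, 256-276, 410-414, 471-476) widen bfloat16 inputs by zipping them with a zero vector so each 16-bit value lands in the upper half of a 32-bit lane, then reinterpret the lanes as float32. A scalar sketch of the same trick, assuming bfloat16 here is simply the top 16 bits of an IEEE-754 float32 (which the leftover loops also rely on via `kai_cast_f32_bf16`):

```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Scalar equivalent of vzip1q_u16(zero, bf16_vec) + vreinterpretq_f32_u16: place the
// 16 bfloat16 bits in the high half of a 32-bit word and type-pun it to float.
static float bf16_to_f32(uint16_t bf16_bits) {
    const uint32_t f32_bits = (uint32_t)bf16_bits << 16;
    float f;
    memcpy(&f, &f32_bits, sizeof(f));
    return f;
}

int main(void) {
    // 0x3F80 is the bfloat16 encoding of 1.0f.
    printf("%f\n", bf16_to_f32(0x3F80));
    return 0;
}
```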
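Source lines 180-243 (and the single-row variant at lines 434-458) derive per-row asymmetric int8 quantization parameters: the observed value range is widened to include zero, a multiplicative scale maps it onto [-128, 127], and the zero point is taken from the range endpoint indicated by the sign of the combined endpoint error, clamped to the int8 range, and rounded to the nearest integer. A compact scalar restatement of those steps, as a sketch rather than the library API:

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>

// Derive the multiplicative scale and nudged zero point for one row, following the
// same sequence of steps as the kernel. The packed buffer stores 1 / scale.
static void compute_qparams(float min, float max, float* scale, int32_t* zero_point) {
    const float qmin = (float)INT8_MIN;
    const float qmax = (float)INT8_MAX;
    const float rmin = fminf(0.0F, min);
    const float rmax = fmaxf(0.0F, max);
    const float s = (rmin == rmax) ? 1.0F : (qmax - qmin) / (rmax - rmin);
    const float err_min = qmin + rmin * s;
    const float err_max = qmax + rmax * s;
    float zp = (err_min + err_max > 0) ? qmin - rmin * s : qmax - rmax * s;
    zp = fminf(fmaxf(zp, qmin), qmax);  // keep the zero point representable in int8
    *scale = s;
    *zero_point = (int32_t)rintf(zp);
}

int main(void) {
    float scale;
    int32_t zp;
    compute_qparams(-1.0F, 3.0F, &scale, &zp);
    // A value x is then quantized as clamp(round(x * scale) + zp, -128, 127).
    printf("scale=%f zero_point=%d\n", (double)scale, zp);
    return 0;
}
```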
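From source line 245 onwards the quantized rows are interleaved in groups of `k_block_len = kr / sr` bytes, and each packed block ends with the `mr` negated zero points followed by the `mr` reciprocal scales. The helpers below are hypothetical (the names are not the library's); they sketch the byte offsets implied by the pointer arithmetic above for one packed block:

```c
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

// Offset of the quantized data for row `row` (0..mr-1) and k-block `blk` inside one packed block.
static size_t qdata_offset(size_t row, size_t blk, size_t mr, size_t k_block_len) {
    return (blk * mr + row) * k_block_len * sizeof(int8_t);
}

// Offset of the negated zero point for row `row`, stored after all quantized data.
static size_t zero_point_offset(size_t row, size_t mr, size_t k_internal) {
    return mr * k_internal * sizeof(int8_t) + row * sizeof(int32_t);
}

// Offset of the reciprocal scale for row `row`, stored after the zero points.
static size_t scale_offset(size_t row, size_t mr, size_t k_internal) {
    return mr * (k_internal * sizeof(int8_t) + sizeof(int32_t)) + row * sizeof(float);
}

int main(void) {
    // mr = 4, k_internal = 32, k_block_len = 8: row 1 of k-block 2 starts at byte (2*4 + 1)*8 = 72.
    printf("%zu %zu %zu\n", qdata_offset(1, 2, 4, 8), zero_point_offset(1, 4, 32), scale_offset(1, 4, 32));
    return 0;
}
```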