Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || \ | ||
8 | !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) | ||
9 | #error This file must be compiled for AArch64, FEAT_FP16. | ||
10 | #else // Architectural features check. | ||
11 | |||
12 | #include "kai_lhs_quant_pack_qai8dxp_f16_neon.h" | ||
13 | |||
14 | #include <arm_fp16.h> | ||
15 | #include <arm_neon.h> | ||
16 | #include <float.h> | ||
17 | #include <math.h> | ||
18 | #include <stddef.h> | ||
19 | #include <stdint.h> | ||
20 | |||
21 | #include "kai/kai_common.h" | ||
22 | #define FLT16_MAX 65504.0 | ||
23 | #define FLT16_MIN (-65504.0F) | ||
24 | |||
25 | static const size_t kai_num_bytes_per_multiplier = sizeof(float); | ||
26 | static const size_t kai_num_bytes_per_offset = sizeof(int32_t); | ||
27 | |||
28 | 4480 | inline static size_t kai_k_roundedup(size_t k) { | |
29 | // Round up k to be a multiple of 32. | ||
30 | 4480 | size_t kai_k_multiple_of = 32; | |
31 | 8960 | return kai_roundup(k, kai_k_multiple_of); | |
32 | 4480 | } | |
33 | |||
34 | 3360 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) { | |
35 | 3360 | KAI_UNUSED(kr); | |
36 | 3360 | KAI_UNUSED(sr); | |
37 | |||
38 | 3360 | const size_t k_internal = kai_k_roundedup(k); | |
39 | |||
40 | − | KAI_ASSERT((k_internal % 2) == 0); | |
41 | |||
42 | 6720 | return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); | |
43 | 3360 | } | |
44 | |||
45 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f16_neon(size_t mr) { | |
46 | ✗ | KAI_UNUSED(mr); | |
47 | ✗ | return 1; | |
48 | } | ||
49 | |||
50 | 1120 | size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(size_t m_idx, size_t lhs_stride) { | |
51 | 1120 | return m_idx * lhs_stride; | |
52 | } | ||
53 | |||
54 | 1120 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon( | |
55 | size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | ||
56 | // It always points to the beginning of the row | ||
57 | 1120 | return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr); | |
58 | } | ||
59 | |||
60 | 1120 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
61 | 1120 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
62 | |||
63 | 2240 | return num_rows * kai_lhs_packed_stride(k, mr, kr, sr); | |
64 | 1120 | } | |
65 | |||
66 | 1120 | void kai_run_lhs_quant_pack_qai8dxp_f16_neon( | |
67 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs, | ||
68 | size_t lhs_stride, void* restrict lhs_packed) { | ||
69 | − | KAI_ASSERT((kr % sr) == 0); | |
70 | − | KAI_ASSUME((kr / sr == 8) || (kr / sr == 4)); | |
71 | |||
72 |
1/2✓ Branch 0 taken 1120 times.
✗ Branch 1 not taken.
|
1120 | if (m == 0) { |
73 | ✗ | return; | |
74 | } | ||
75 | |||
76 | 1120 | const size_t num_rows = m; | |
77 | |||
78 | 1120 | float16_t const* src_ptr = (float16_t const*)lhs; | |
79 | |||
80 | 1120 | const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr); | |
81 | 1120 | const size_t k_internal = kai_k_roundedup(k); | |
82 | 1120 | const int32_t k_block_len = (int32_t)(kr / sr); | |
83 | |||
84 | 1120 | const int32_t num_blocks_k = (int32_t)(k / k_block_len); | |
85 | 1120 | const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len); | |
86 | 1120 | const size_t lhs_row_length = lhs_stride / sizeof(float16_t); | |
87 | |||
88 | 1120 | const float16x8_t vmax = vdupq_n_f16((float16_t)FLT16_MIN); | |
89 | 1120 | const float16x8_t vmin = vdupq_n_f16((float16_t)FLT16_MAX); | |
90 | |||
91 | // As we load 8-element vectors, limit vectorized loop to avoid reading out-of-bounds | ||
92 | 1120 | const int32_t blocks_lim_k = num_blocks_k - (8 / k_block_len); | |
93 | |||
94 | 1120 | size_t row_idx = 0; | |
95 | |||
96 | // Improved performance with 4x loop unrolling where packing parameters allow | ||
97 |
2/2✓ Branch 0 taken 232 times.
✓ Branch 1 taken 888 times.
|
1120 | if (mr == 4) { |
98 |
2/2✓ Branch 0 taken 2496 times.
✓ Branch 1 taken 888 times.
|
3384 | for (; row_idx + 3 < m; row_idx += 4) { |
99 | // Find min/max for each channel | ||
100 | 2496 | int32_t k_idx = 0; | |
101 | 2496 | float16x8_t vmax0 = vmax; | |
102 | 2496 | float16x8_t vmin0 = vmin; | |
103 | 2496 | float16x8_t vmax1 = vmax; | |
104 | 2496 | float16x8_t vmin1 = vmin; | |
105 | 2496 | float16x8_t vmax2 = vmax; | |
106 | 2496 | float16x8_t vmin2 = vmin; | |
107 | 2496 | float16x8_t vmax3 = vmax; | |
108 | 2496 | float16x8_t vmin3 = vmin; | |
109 | |||
110 |
2/2✓ Branch 0 taken 15376 times.
✓ Branch 1 taken 2496 times.
|
17872 | for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { |
111 | 15376 | const float16x8_t src0 = vld1q_f16(src_ptr + k_idx); | |
112 | 15376 | const float16x8_t src1 = vld1q_f16(src_ptr + k_idx + lhs_row_length); | |
113 | 15376 | const float16x8_t src2 = vld1q_f16(src_ptr + k_idx + (2 * lhs_row_length)); | |
114 | 15376 | const float16x8_t src3 = vld1q_f16(src_ptr + k_idx + (3 * lhs_row_length)); | |
115 | |||
116 | 15376 | vmax0 = vmaxq_f16(src0, vmax0); | |
117 | 15376 | vmax1 = vmaxq_f16(src1, vmax1); | |
118 | 15376 | vmax2 = vmaxq_f16(src2, vmax2); | |
119 | 15376 | vmax3 = vmaxq_f16(src3, vmax3); | |
120 | 15376 | vmin0 = vminq_f16(src0, vmin0); | |
121 | 15376 | vmin1 = vminq_f16(src1, vmin1); | |
122 | 15376 | vmin2 = vminq_f16(src2, vmin2); | |
123 | 15376 | vmin3 = vminq_f16(src3, vmin3); | |
124 | 15376 | } | |
125 | |||
126 | 2496 | float16_t max0 = vmaxvq_f16(vmax0); | |
127 | 2496 | float16_t min0 = vminvq_f16(vmin0); | |
128 | 2496 | float16_t max1 = vmaxvq_f16(vmax1); | |
129 | 2496 | float16_t min1 = vminvq_f16(vmin1); | |
130 | 2496 | float16_t max2 = vmaxvq_f16(vmax2); | |
131 | 2496 | float16_t min2 = vminvq_f16(vmin2); | |
132 | 2496 | float16_t max3 = vmaxvq_f16(vmax3); | |
133 | 2496 | float16_t min3 = vminvq_f16(vmin3); | |
134 | // Process leftover elements with a scalar loop. | ||
135 |
2/2✓ Branch 0 taken 2296 times.
✓ Branch 1 taken 2496 times.
|
4792 | for (; k_idx < (int32_t)k; ++k_idx) { |
136 | 2296 | const float16_t src0 = *(src_ptr + (size_t)k_idx); | |
137 | 2296 | max0 = vmaxh_f16(src0, max0); | |
138 | 2296 | min0 = vminh_f16(src0, min0); | |
139 | 2296 | const float16_t src1 = *(src_ptr + (size_t)k_idx + lhs_row_length); | |
140 | 2296 | max1 = vmaxh_f16(src1, max1); | |
141 | 2296 | min1 = vminh_f16(src1, min1); | |
142 | 2296 | const float16_t src2 = *(src_ptr + (size_t)k_idx + (2 * lhs_row_length)); | |
143 | 2296 | max2 = vmaxh_f16(src2, max2); | |
144 | 2296 | min2 = vminh_f16(src2, min2); | |
145 | 2296 | const float16_t src3 = *(src_ptr + (size_t)k_idx + (3 * lhs_row_length)); | |
146 | 2296 | max3 = vmaxh_f16(src3, max3); | |
147 | 2296 | min3 = vminh_f16(src3, min3); | |
148 | 2296 | } | |
149 | |||
150 | // Maximum/minimum int8 values | ||
151 | 2496 | const float qmin = (float)INT8_MIN; | |
152 | 2496 | const float qmax = (float)INT8_MAX; | |
153 | |||
154 | 2496 | const float rmin0 = fminf(0.0F, min0); | |
155 | 2496 | const float rmax0 = fmaxf(0.0F, max0); | |
156 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2496 times.
|
2496 | const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); |
157 | 2496 | const float rmin1 = fminf(0.0F, min1); | |
158 | 2496 | const float rmax1 = fmaxf(0.0F, max1); | |
159 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2496 times.
|
2496 | const float scale1 = rmin1 == rmax1 ? 1.F : (qmax - qmin) / (rmax1 - rmin1); |
160 | 2496 | const float rmin2 = fminf(0.0F, min2); | |
161 | 2496 | const float rmax2 = fmaxf(0.0F, max2); | |
162 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2496 times.
|
2496 | const float scale2 = rmin2 == rmax2 ? 1.F : (qmax - qmin) / (rmax2 - rmin2); |
163 | 2496 | const float rmin3 = fminf(0.0F, min3); | |
164 | 2496 | const float rmax3 = fmaxf(0.0F, max3); | |
165 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2496 times.
|
2496 | const float scale3 = rmin3 == rmax3 ? 1.F : (qmax - qmin) / (rmax3 - rmin3); |
166 | |||
167 | // Reciprocal to quantize | ||
168 |
1/2✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
|
2496 | const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; |
169 |
1/2✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
|
2496 | const float recip_scale1 = scale1 ? 1.0F / scale1 : 0.0F; |
170 |
1/2✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
|
2496 | const float recip_scale2 = scale2 ? 1.0F / scale2 : 0.0F; |
171 |
1/2✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
|
2496 | const float recip_scale3 = scale3 ? 1.0F / scale3 : 0.0F; |
172 | |||
173 | 2496 | const float descaled_min0 = rmin0 * scale0; | |
174 | 2496 | const float descaled_max0 = rmax0 * scale0; | |
175 | 2496 | const float descaled_min1 = rmin1 * scale1; | |
176 | 2496 | const float descaled_max1 = rmax1 * scale1; | |
177 | 2496 | const float descaled_min2 = rmin2 * scale2; | |
178 | 2496 | const float descaled_max2 = rmax2 * scale2; | |
179 | 2496 | const float descaled_min3 = rmin3 * scale3; | |
180 | 2496 | const float descaled_max3 = rmax3 * scale3; | |
181 | |||
182 | 2496 | const float zero_point_from_min_error0 = qmin + descaled_min0; | |
183 | 2496 | const float zero_point_from_max_error0 = qmax + descaled_max0; | |
184 | 2496 | const float zero_point_from_min_error1 = qmin + descaled_min1; | |
185 | 2496 | const float zero_point_from_max_error1 = qmax + descaled_max1; | |
186 | 2496 | const float zero_point_from_min_error2 = qmin + descaled_min2; | |
187 | 2496 | const float zero_point_from_max_error2 = qmax + descaled_max2; | |
188 | 2496 | const float zero_point_from_min_error3 = qmin + descaled_min3; | |
189 | 2496 | const float zero_point_from_max_error3 = qmax + descaled_max3; | |
190 | |||
191 |
1/2✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
|
2496 | float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 |
192 | ✗ | : qmax - descaled_max0; | |
193 |
1/2✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
|
2496 | float zero_point1 = (zero_point_from_min_error1 + zero_point_from_max_error1 > 0) ? qmin - descaled_min1 |
194 | ✗ | : qmax - descaled_max1; | |
195 |
1/2✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
|
2496 | float zero_point2 = (zero_point_from_min_error2 + zero_point_from_max_error2 > 0) ? qmin - descaled_min2 |
196 | ✗ | : qmax - descaled_max2; | |
197 |
1/2✓ Branch 0 taken 2496 times.
✗ Branch 1 not taken.
|
2496 | float zero_point3 = (zero_point_from_min_error3 + zero_point_from_max_error3 > 0) ? qmin - descaled_min3 |
198 | ✗ | : qmax - descaled_max3; | |
199 | |||
200 | 2496 | zero_point0 = fmaxf(zero_point0, qmin); | |
201 | 2496 | zero_point0 = fminf(zero_point0, qmax); | |
202 | 2496 | zero_point1 = fmaxf(zero_point1, qmin); | |
203 | 2496 | zero_point1 = fminf(zero_point1, qmax); | |
204 | 2496 | zero_point2 = fmaxf(zero_point2, qmin); | |
205 | 2496 | zero_point2 = fminf(zero_point2, qmax); | |
206 | 2496 | zero_point3 = fmaxf(zero_point3, qmin); | |
207 | 2496 | zero_point3 = fminf(zero_point3, qmax); | |
208 | |||
209 | // Round to nearest integer | ||
210 | 2496 | const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); | |
211 | 2496 | const int32_t nudged_zero_point1 = (int32_t)rintf(zero_point1); | |
212 | 2496 | const int32_t nudged_zero_point2 = (int32_t)rintf(zero_point2); | |
213 | 2496 | const int32_t nudged_zero_point3 = (int32_t)rintf(zero_point3); | |
214 | |||
215 | 2496 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
216 | |||
217 | 2496 | uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len); | |
218 | |||
219 | // Quantize the channels | ||
220 | 2496 | int32_t block_idx = 0; | |
221 | 2496 | const int32_t block_incr = 8 / k_block_len; | |
222 | |||
223 |
2/2✓ Branch 0 taken 15376 times.
✓ Branch 1 taken 2496 times.
|
17872 | for (; block_idx <= blocks_lim_k; block_idx += block_incr) { |
224 | // Clamp at the last valid k-index | ||
225 | 15376 | const int32_t k_idx_start = block_idx * k_block_len; | |
226 | |||
227 | 15376 | const float16x8_t src0 = vld1q_f16(src_ptr + k_idx_start); | |
228 | 15376 | const float16x8_t src1 = vld1q_f16(src_ptr + k_idx_start + lhs_row_length); | |
229 | 15376 | const float16x8_t src2 = vld1q_f16(src_ptr + k_idx_start + (2 * lhs_row_length)); | |
230 | 15376 | const float16x8_t src3 = vld1q_f16(src_ptr + k_idx_start + (3 * lhs_row_length)); | |
231 | |||
232 | // Scale the values. | ||
233 | 15376 | const int32x4_t v0_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src0)), scale0)); | |
234 | 15376 | const int32x4_t v0_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src0), scale0)); | |
235 | 15376 | const int32x4_t v1_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src1)), scale1)); | |
236 | 15376 | const int32x4_t v1_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src1), scale1)); | |
237 | 15376 | const int32x4_t v2_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src2)), scale2)); | |
238 | 15376 | const int32x4_t v2_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src2), scale2)); | |
239 | 15376 | const int32x4_t v3_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src3)), scale3)); | |
240 | 15376 | const int32x4_t v3_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src3), scale3)); | |
241 | |||
242 | 15376 | const int16x4_t v0_0_s16 = vqmovn_s32(v0_0_s32); | |
243 | 15376 | const int16x4_t v0_1_s16 = vqmovn_s32(v0_1_s32); | |
244 | 15376 | const int16x4_t v1_0_s16 = vqmovn_s32(v1_0_s32); | |
245 | 15376 | const int16x4_t v1_1_s16 = vqmovn_s32(v1_1_s32); | |
246 | 15376 | const int16x4_t v2_0_s16 = vqmovn_s32(v2_0_s32); | |
247 | 15376 | const int16x4_t v2_1_s16 = vqmovn_s32(v2_1_s32); | |
248 | 15376 | const int16x4_t v3_0_s16 = vqmovn_s32(v3_0_s32); | |
249 | 15376 | const int16x4_t v3_1_s16 = vqmovn_s32(v3_1_s32); | |
250 | |||
251 | 15376 | int16x8_t v0_s16; | |
252 | 15376 | int16x8_t v1_s16; | |
253 | 15376 | int16x8_t v2_s16; | |
254 | 15376 | int16x8_t v3_s16; | |
255 |
2/2✓ Branch 0 taken 7688 times.
✓ Branch 1 taken 7688 times.
|
15376 | if (k_block_len == 8) { |
256 | 7688 | v0_s16 = vcombine_s16(v0_0_s16, v0_1_s16); | |
257 | 7688 | v1_s16 = vcombine_s16(v1_0_s16, v1_1_s16); | |
258 | 7688 | v2_s16 = vcombine_s16(v2_0_s16, v2_1_s16); | |
259 | 7688 | v3_s16 = vcombine_s16(v3_0_s16, v3_1_s16); | |
260 | 7688 | } else { // k_block_len == 4 | |
261 | 7688 | v0_s16 = vcombine_s16(v0_0_s16, v1_0_s16); | |
262 | 7688 | v1_s16 = vcombine_s16(v2_0_s16, v3_0_s16); | |
263 | 7688 | v2_s16 = vcombine_s16(v0_1_s16, v1_1_s16); | |
264 | 7688 | v3_s16 = vcombine_s16(v2_1_s16, v3_1_s16); | |
265 | } | ||
266 | |||
267 | // Add zero points. | ||
268 | 15376 | const int16x8_t vnzp0 = vdupq_n_s16((int16_t)nudged_zero_point0); | |
269 | 15376 | const int16x8_t vnzp1 = vdupq_n_s16((int16_t)nudged_zero_point1); | |
270 | 15376 | const int16x8_t vnzp2 = vdupq_n_s16((int16_t)nudged_zero_point2); | |
271 | 15376 | const int16x8_t vnzp3 = vdupq_n_s16((int16_t)nudged_zero_point3); | |
272 | |||
273 | 15376 | v0_s16 = vaddq_s16(v0_s16, vnzp0); | |
274 | 15376 | v0_s16 = vmaxq_s16(v0_s16, vdupq_n_s16(INT8_MIN)); | |
275 | 15376 | v0_s16 = vminq_s16(v0_s16, vdupq_n_s16(INT8_MAX)); | |
276 | 15376 | v1_s16 = vaddq_s16(v1_s16, vnzp1); | |
277 | 15376 | v1_s16 = vmaxq_s16(v1_s16, vdupq_n_s16(INT8_MIN)); | |
278 | 15376 | v1_s16 = vminq_s16(v1_s16, vdupq_n_s16(INT8_MAX)); | |
279 | 15376 | v2_s16 = vaddq_s16(v2_s16, vnzp2); | |
280 | 15376 | v2_s16 = vmaxq_s16(v2_s16, vdupq_n_s16(INT8_MIN)); | |
281 | 15376 | v2_s16 = vminq_s16(v2_s16, vdupq_n_s16(INT8_MAX)); | |
282 | 15376 | v3_s16 = vaddq_s16(v3_s16, vnzp3); | |
283 | 15376 | v3_s16 = vmaxq_s16(v3_s16, vdupq_n_s16(INT8_MIN)); | |
284 | 15376 | v3_s16 = vminq_s16(v3_s16, vdupq_n_s16(INT8_MAX)); | |
285 | |||
286 | 15376 | int8x8_t v0_s8 = vqmovn_s16(v0_s16); | |
287 | 15376 | int8x8_t v1_s8 = vqmovn_s16(v1_s16); | |
288 | 15376 | int8x8_t v2_s8 = vqmovn_s16(v2_s16); | |
289 | 15376 | int8x8_t v3_s8 = vqmovn_s16(v3_s16); | |
290 | |||
291 | 15376 | vst1_s8((int8_t*)(dst_ptr), v0_s8); | |
292 | 15376 | vst1_s8((int8_t*)(dst_ptr + sizeof(int8x8_t)), v1_s8); | |
293 | 15376 | vst1_s8((int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)), v2_s8); | |
294 | 15376 | vst1_s8((int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)), v3_s8); | |
295 | 15376 | dst_ptr += block_incr * mr * k_block_len * sizeof(int8_t); | |
296 | 15376 | } | |
297 | |||
298 |
2/2✓ Branch 0 taken 3192 times.
✓ Branch 1 taken 2496 times.
|
5688 | for (; block_idx < num_blocks_k_internal; ++block_idx) { |
299 | // left over k | ||
300 |
2/2✓ Branch 0 taken 17024 times.
✓ Branch 1 taken 3192 times.
|
20216 | for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
301 | // Clamp at the last valid k-index. | ||
302 |
2/2✓ Branch 0 taken 1680 times.
✓ Branch 1 taken 15344 times.
|
17024 | const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); |
303 | |||
304 | 17024 | const float src0 = (float)(*(src_ptr + k_idx_start)); | |
305 | 17024 | const float src1 = (float)(*(src_ptr + k_idx_start + lhs_row_length)); | |
306 | 17024 | const float src2 = (float)(*(src_ptr + k_idx_start + (2 * lhs_row_length))); | |
307 | 17024 | const float src3 = (float)(*(src_ptr + k_idx_start + (3 * lhs_row_length))); | |
308 | |||
309 | // Scale the value. | ||
310 | 17024 | int32_t d0_s32 = (int32_t)(roundf(src0 * scale0)); | |
311 | 17024 | int32_t d1_s32 = (int32_t)(roundf(src1 * scale1)); | |
312 | 17024 | int32_t d2_s32 = (int32_t)(roundf(src2 * scale2)); | |
313 | 17024 | int32_t d3_s32 = (int32_t)(roundf(src3 * scale3)); | |
314 | |||
315 | 17024 | d0_s32 = d0_s32 + nudged_zero_point0; | |
316 |
1/2✓ Branch 0 taken 17024 times.
✗ Branch 1 not taken.
|
17024 | d0_s32 = KAI_MAX(d0_s32, INT8_MIN); |
317 |
2/2✓ Branch 0 taken 16912 times.
✓ Branch 1 taken 112 times.
|
17024 | d0_s32 = KAI_MIN(d0_s32, INT8_MAX); |
318 | |||
319 | 17024 | d1_s32 = d1_s32 + nudged_zero_point1; | |
320 |
1/2✓ Branch 0 taken 17024 times.
✗ Branch 1 not taken.
|
17024 | d1_s32 = KAI_MAX(d1_s32, INT8_MIN); |
321 |
2/2✓ Branch 0 taken 16912 times.
✓ Branch 1 taken 112 times.
|
17024 | d1_s32 = KAI_MIN(d1_s32, INT8_MAX); |
322 | |||
323 | 17024 | d2_s32 = d2_s32 + nudged_zero_point2; | |
324 |
1/2✓ Branch 0 taken 17024 times.
✗ Branch 1 not taken.
|
17024 | d2_s32 = KAI_MAX(d2_s32, INT8_MIN); |
325 |
2/2✓ Branch 0 taken 16912 times.
✓ Branch 1 taken 112 times.
|
17024 | d2_s32 = KAI_MIN(d2_s32, INT8_MAX); |
326 | |||
327 | 17024 | d3_s32 = d3_s32 + nudged_zero_point3; | |
328 |
1/2✓ Branch 0 taken 17024 times.
✗ Branch 1 not taken.
|
17024 | d3_s32 = KAI_MAX(d3_s32, INT8_MIN); |
329 |
2/2✓ Branch 0 taken 16912 times.
✓ Branch 1 taken 112 times.
|
17024 | d3_s32 = KAI_MIN(d3_s32, INT8_MAX); |
330 | |||
331 | 17024 | *(int8_t*)dst_ptr = (int8_t)d0_s32; | |
332 | 17024 | *(int8_t*)(dst_ptr + k_block_len * sizeof(int8_t)) = (int8_t)d1_s32; | |
333 | 17024 | *(int8_t*)(dst_ptr + 2 * (k_block_len * sizeof(int8_t))) = (int8_t)d2_s32; | |
334 | 17024 | *(int8_t*)(dst_ptr + 3 * (k_block_len * sizeof(int8_t))) = (int8_t)d3_s32; | |
335 | 17024 | dst_ptr += sizeof(int8_t); | |
336 | 17024 | } | |
337 | 3192 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
338 | 3192 | } | |
339 | |||
340 | 2496 | uint8_t* dst_base = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); | |
341 | |||
342 | 2496 | dst_ptr = dst_base + dst_x * kai_num_bytes_per_offset; | |
343 | |||
344 | // LHS offset at the beginning of the row. | ||
345 | 2496 | *((int32_t*)(dst_ptr)) = -nudged_zero_point0; | |
346 | 2496 | *((int32_t*)(dst_ptr + kai_num_bytes_per_offset)) = -nudged_zero_point1; | |
347 | 2496 | *((int32_t*)(dst_ptr + 2 * kai_num_bytes_per_offset)) = -nudged_zero_point2; | |
348 | 2496 | *((int32_t*)(dst_ptr + 3 * kai_num_bytes_per_offset)) = -nudged_zero_point3; | |
349 | |||
350 | // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. | ||
351 | − | KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); | |
352 | |||
353 | 2496 | dst_ptr += mr * kai_num_bytes_per_offset; | |
354 | |||
355 | // Store the scale quantization params. | ||
356 | 2496 | *((float*)(dst_ptr)) = recip_scale0; | |
357 | 2496 | *((float*)(dst_ptr + kai_num_bytes_per_multiplier)) = recip_scale1; | |
358 | 2496 | *((float*)(dst_ptr + 2 * kai_num_bytes_per_multiplier)) = recip_scale2; | |
359 | 2496 | *((float*)(dst_ptr + 3 * kai_num_bytes_per_multiplier)) = recip_scale3; | |
360 | |||
361 | // Update src_ptr. Note: now lhs contains fp16 values (2 bytes each). | ||
362 | 2496 | src_ptr += (4 * lhs_row_length); | |
363 | |||
364 | // Move to the next row as we have interleaved all Mr rows. | ||
365 | 2496 | lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); | |
366 | 2496 | } | |
367 | 888 | } | |
368 | |||
369 |
2/2✓ Branch 0 taken 1168 times.
✓ Branch 1 taken 1120 times.
|
2288 | for (; row_idx < num_rows; ++row_idx) { |
370 | // Find min/max for each channel | ||
371 | 1168 | int32_t k_idx = 0; | |
372 | 1168 | float16x8_t vmax0 = vmax; | |
373 | 1168 | float16x8_t vmin0 = vmin; | |
374 | |||
375 |
2/2✓ Branch 0 taken 5936 times.
✓ Branch 1 taken 1168 times.
|
7104 | for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { |
376 | 5936 | const float16x8_t src0_0 = vld1q_f16(src_ptr + (size_t)k_idx); | |
377 | 5936 | vmax0 = vmaxq_f16(vmax0, src0_0); | |
378 | 5936 | vmin0 = vminq_f16(vmin0, src0_0); | |
379 | 5936 | } | |
380 | // Get the max/min | ||
381 | 1168 | float16_t max0 = vmaxvq_f16(vmax0); | |
382 | 1168 | float16_t min0 = vminvq_f16(vmin0); | |
383 | |||
384 |
2/2✓ Branch 0 taken 2168 times.
✓ Branch 1 taken 1168 times.
|
3336 | for (; k_idx < (int32_t)k; ++k_idx) { |
385 | 2168 | const float16_t src0 = *(src_ptr + (size_t)k_idx); | |
386 | 2168 | max0 = vmaxh_f16(src0, max0); | |
387 | 2168 | min0 = vminh_f16(src0, min0); | |
388 | 2168 | } | |
389 | |||
390 | // Maximum/minimum int8 values | ||
391 | 1168 | const float qmin = (float)INT8_MIN; | |
392 | 1168 | const float qmax = (float)INT8_MAX; | |
393 | |||
394 | 1168 | const float rmin0 = fminf(0.0F, min0); | |
395 | 1168 | const float rmax0 = fmaxf(0.0F, max0); | |
396 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1168 times.
|
1168 | const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); |
397 | |||
398 | // Reciprocal to quantize | ||
399 |
1/2✓ Branch 0 taken 1168 times.
✗ Branch 1 not taken.
|
1168 | const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; |
400 | |||
401 | 1168 | const float descaled_min0 = rmin0 * scale0; | |
402 | 1168 | const float descaled_max0 = rmax0 * scale0; | |
403 | |||
404 | 1168 | const float zero_point_from_min_error0 = qmin + descaled_min0; | |
405 | 1168 | const float zero_point_from_max_error0 = qmax + descaled_max0; | |
406 | |||
407 | 2336 | float zero_point0 = | |
408 |
1/2✓ Branch 0 taken 1168 times.
✗ Branch 1 not taken.
|
1168 | zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; |
409 | |||
410 | 1168 | zero_point0 = fmaxf(zero_point0, qmin); | |
411 | 1168 | zero_point0 = fminf(zero_point0, qmax); | |
412 | |||
413 | // Round to nearest integer | ||
414 | 1168 | const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); | |
415 | |||
416 | 1168 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
417 | |||
418 | 1168 | uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); | |
419 | |||
420 | // Quantize the channels | ||
421 | 1168 | int32_t block_idx = 0; | |
422 | |||
423 |
2/2✓ Branch 0 taken 8480 times.
✓ Branch 1 taken 1168 times.
|
9648 | for (; block_idx <= blocks_lim_k; ++block_idx) { |
424 | 8480 | const int32_t k_idx_start = block_idx * k_block_len; | |
425 | |||
426 | 8480 | const float16x8_t src_0 = vld1q_f16(src_ptr + k_idx_start); | |
427 | |||
428 | // Scale the values | ||
429 | 8480 | const float32x4_t v0_f32 = vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src_0)), scale0); | |
430 | 8480 | const float32x4_t v1_f32 = vmulq_n_f32(vcvt_high_f32_f16(src_0), scale0); | |
431 | 8480 | const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); | |
432 | 8480 | const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); | |
433 | |||
434 | 8480 | const int16x4_t v0_s16 = vqmovn_s32(v0_s32); | |
435 | 8480 | const int16x4_t v1_s16 = vqmovn_s32(v1_s32); | |
436 | 8480 | int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); | |
437 | |||
438 | // Add zero points | ||
439 | 8480 | int16_t nzp_s16 = (int16_t)nudged_zero_point0; | |
440 | 8480 | int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16); | |
441 | 8480 | v_s16 = vaddq_s16(v_s16, vnzp_s16); | |
442 | 8480 | v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); | |
443 | 8480 | v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); | |
444 | |||
445 | 8480 | int8x8_t v_s8 = vqmovn_s16(v_s16); | |
446 | 8480 | vst1_s8((int8_t*)(dst_ptr), v_s8); | |
447 | 8480 | dst_ptr += mr * k_block_len * sizeof(int8_t); | |
448 | 8480 | } | |
449 | |||
450 |
2/2✓ Branch 0 taken 2992 times.
✓ Branch 1 taken 1168 times.
|
4160 | for (; block_idx < num_blocks_k_internal; ++block_idx) { |
451 | // left over k | ||
452 |
2/2✓ Branch 0 taken 15392 times.
✓ Branch 1 taken 2992 times.
|
18384 | for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
453 | // Clamp at the last valid k-index | ||
454 |
2/2✓ Branch 0 taken 2988 times.
✓ Branch 1 taken 12404 times.
|
15392 | const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); |
455 | |||
456 | 15392 | const float src0 = (float)(*(src_ptr + k_idx_start)); | |
457 | |||
458 | // Scale the values | ||
459 | 15392 | int32_t d0_s32 = (int32_t)(roundf(src0 * scale0)); | |
460 | |||
461 | 15392 | d0_s32 = d0_s32 + nudged_zero_point0; | |
462 |
1/2✓ Branch 0 taken 15392 times.
✗ Branch 1 not taken.
|
15392 | d0_s32 = KAI_MAX(d0_s32, INT8_MIN); |
463 |
2/2✓ Branch 0 taken 15308 times.
✓ Branch 1 taken 84 times.
|
15392 | d0_s32 = KAI_MIN(d0_s32, INT8_MAX); |
464 | |||
465 | 15392 | *((int8_t*)(dst_ptr)) = (int8_t)d0_s32; | |
466 | 15392 | dst_ptr += sizeof(int8_t); | |
467 | 15392 | } | |
468 | 2992 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
469 | 2992 | } | |
470 | |||
471 | 1168 | dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); | |
472 | |||
473 | 1168 | dst_ptr += dst_x * kai_num_bytes_per_offset; | |
474 | |||
475 | // LHS offset at the beginning of the row | ||
476 | 1168 | *((int32_t*)(dst_ptr)) = -nudged_zero_point0; | |
477 | |||
478 | // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier | ||
479 | − | KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); | |
480 | |||
481 | 1168 | dst_ptr += mr * kai_num_bytes_per_offset; | |
482 | |||
483 | // Store the scale quantization params | ||
484 | 1168 | *((float*)(dst_ptr)) = recip_scale0; | |
485 | |||
486 | 1168 | src_ptr += lhs_row_length; | |
487 | |||
488 | // Move to the next row if we have interleaved all Mr rows | ||
489 |
2/2✓ Branch 0 taken 936 times.
✓ Branch 1 taken 232 times.
|
1168 | if ((((row_idx + 1) + m_idx_start) % mr) == 0) { |
490 | 232 | lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); | |
491 | 232 | } | |
492 | 1168 | } | |
493 | 1120 | } | |
494 | |||
495 | #endif // Architectural features check. | ||
496 |