Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #include "kai_lhs_quant_pack_qai8dxp_f32.h" | ||
7 | |||
8 | #if defined(__aarch64__) | ||
9 | #include <arm_neon.h> | ||
10 | #endif | ||
11 | #include <float.h> | ||
12 | #include <math.h> | ||
13 | #include <stddef.h> | ||
14 | #include <stdint.h> | ||
15 | |||
16 | #include "kai/kai_common.h" | ||
17 | |||
18 | static const size_t kai_num_bytes_per_multiplier = sizeof(float); | ||
19 | static const size_t kai_num_bytes_per_offset = sizeof(int32_t); | ||
20 | |||
21 | 17176 | inline static size_t kai_k_roundedup(size_t k) { | |
22 | // Round up k to be a multiple of 32. | ||
23 | 17176 | size_t kai_k_multiple_of = 32; | |
24 | 34352 | return kai_roundup(k, kai_k_multiple_of); | |
25 | 17176 | } | |
26 | |||
27 | 13236 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) { | |
28 | 13236 | KAI_UNUSED(kr); | |
29 | 13236 | KAI_UNUSED(sr); | |
30 | |||
31 | 13236 | const size_t k_internal = kai_k_roundedup(k); | |
32 | |||
33 | − | KAI_ASSERT((k_internal % 2) == 0); | |
34 | |||
35 | 26472 | return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); | |
36 | 13236 | } | |
37 | |||
38 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr) { | |
39 | ✗ | KAI_UNUSED(mr); | |
40 | ✗ | return 1; | |
41 | } | ||
42 | |||
43 | 3940 | size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride) { | |
44 | 3940 | return m_idx * lhs_stride; | |
45 | } | ||
46 | |||
47 | 5356 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | |
48 | // It always points to the beginning of the row | ||
49 | 5356 | return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr); | |
50 | } | ||
51 | |||
52 | 3940 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
53 | 3940 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
54 | |||
55 | 7880 | return num_rows * kai_lhs_packed_stride(k, mr, kr, sr); | |
56 | 3940 | } | |
57 | |||
58 | 3940 | void kai_run_lhs_quant_pack_qai8dxp_f32( | |
59 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs, | ||
60 | size_t lhs_stride, void* restrict lhs_packed) { | ||
61 | − | KAI_ASSERT((kr % sr) == 0); | |
62 | |||
63 |
1/2✓ Branch 0 taken 3940 times.
✗ Branch 1 not taken.
|
3940 | if (m == 0) { |
64 | ✗ | return; | |
65 | } | ||
66 | |||
67 | 3940 | const size_t num_rows = m; | |
68 | |||
69 | 3940 | const float* src_ptr = lhs; | |
70 | |||
71 | 3940 | const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr); | |
72 | 3940 | const size_t k_internal = kai_k_roundedup(k); | |
73 | 3940 | const int32_t k_block_len = (int32_t)(kr / sr); | |
74 | |||
75 | 3940 | const int32_t num_blocks_k = (int32_t)(k / k_block_len); | |
76 | 3940 | const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len); | |
77 | |||
78 |
2/2✓ Branch 0 taken 80332 times.
✓ Branch 1 taken 3940 times.
|
84272 | for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { |
79 | 80332 | float max0 = -FLT_MAX; | |
80 | 80332 | float min0 = FLT_MAX; | |
81 | |||
82 | // Find min/max for each channel | ||
83 | 80332 | int32_t k_idx = 0; | |
84 | |||
85 | #if defined(__aarch64__) | ||
86 | 80332 | float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); | |
87 | 80332 | float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); | |
88 | |||
89 |
2/2✓ Branch 0 taken 987704 times.
✓ Branch 1 taken 80332 times.
|
1068036 | for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { |
90 | 987704 | const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx); | |
91 | 987704 | const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx); | |
92 | |||
93 | // Calculate the max | ||
94 | 987704 | vmax0 = vmaxq_f32(src0_0, vmax0); | |
95 | 987704 | vmax0 = vmaxq_f32(vmax0, src0_1); | |
96 | |||
97 | // Calculate the min | ||
98 | 987704 | vmin0 = vminq_f32(src0_0, vmin0); | |
99 | 987704 | vmin0 = vminq_f32(vmin0, src0_1); | |
100 | 987704 | } | |
101 | // Get the max/min | ||
102 | 80332 | max0 = vmaxvq_f32(vmax0); | |
103 | 80332 | min0 = vminvq_f32(vmin0); | |
104 | #endif | ||
105 |
2/2✓ Branch 0 taken 76600 times.
✓ Branch 1 taken 80332 times.
|
156932 | for (; k_idx < (int32_t)k; ++k_idx) { |
106 | 76600 | const float src0_0 = *(src_ptr + (size_t)k_idx); | |
107 | 76600 | max0 = fmaxf(src0_0, max0); | |
108 | 76600 | min0 = fminf(src0_0, min0); | |
109 | 76600 | } | |
110 | |||
111 | // Maximum/minimum int8 values | ||
112 | 80332 | const float qmin = (float)INT8_MIN; | |
113 | 80332 | const float qmax = (float)INT8_MAX; | |
114 | |||
115 | 80332 | const float rmin0 = fminf(0.0F, min0); | |
116 | 80332 | const float rmax0 = fmaxf(0.0F, max0); | |
117 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 80332 times.
|
80332 | const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); |
118 | |||
119 | // Reciprocal to quantize | ||
120 |
1/2✓ Branch 0 taken 80332 times.
✗ Branch 1 not taken.
|
80332 | const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; |
121 | |||
122 | 80332 | const float descaled_min0 = rmin0 * scale0; | |
123 | 80332 | const float descaled_max0 = rmax0 * scale0; | |
124 | |||
125 | 80332 | const float zero_point_from_min_error0 = qmin + descaled_min0; | |
126 | 80332 | const float zero_point_from_max_error0 = qmax + descaled_max0; | |
127 | |||
128 | 160664 | float zero_point0 = | |
129 |
1/2✓ Branch 0 taken 80332 times.
✗ Branch 1 not taken.
|
80332 | zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; |
130 | |||
131 | 80332 | zero_point0 = fmaxf(zero_point0, qmin); | |
132 | 80332 | zero_point0 = fminf(zero_point0, qmax); | |
133 | |||
134 | // Round to nearest integer | ||
135 | 80332 | const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); | |
136 | |||
137 | 80332 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
138 | |||
139 | 80332 | uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); | |
140 | |||
141 | // Quantize the channels | ||
142 | 80332 | int32_t block_idx = 0; | |
143 | |||
144 | #if defined(__aarch64__) | ||
145 |
2/2✓ Branch 0 taken 45550 times.
✓ Branch 1 taken 34782 times.
|
80332 | if (k_block_len == 8) { |
146 |
2/2✓ Branch 0 taken 600230 times.
✓ Branch 1 taken 45550 times.
|
645780 | for (; block_idx < num_blocks_k; ++block_idx) { |
147 | // Clamp at the last valid k-index | ||
148 | 600230 | const int32_t k_idx_start = block_idx * k_block_len; | |
149 | |||
150 | 600230 | const float32x4_t src_0 = vld1q_f32(src_ptr + k_idx_start); | |
151 | 600230 | const float32x4_t src_1 = vld1q_f32(src_ptr + k_idx_start + 4); | |
152 | |||
153 | // Scale the values | ||
154 | 600230 | float32x4_t v0_f32 = vmulq_n_f32(src_0, scale0); | |
155 | 600230 | float32x4_t v1_f32 = vmulq_n_f32(src_1, scale0); | |
156 | 600230 | int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); | |
157 | 600230 | int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); | |
158 | |||
159 | 600230 | int16x4_t v0_s16 = vqmovn_s32(v0_s32); | |
160 | 600230 | int16x4_t v1_s16 = vqmovn_s32(v1_s32); | |
161 | 600230 | int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); | |
162 | |||
163 | // Add zero points | ||
164 | 600230 | int16_t nzp_s16 = (int16_t)nudged_zero_point0; | |
165 | 600230 | int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16); | |
166 | 600230 | v_s16 = vaddq_s16(v_s16, vnzp_s16); | |
167 | 600230 | v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); | |
168 | 600230 | v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); | |
169 | |||
170 | 600230 | int8x8_t v0_s8 = vqmovn_s16(v_s16); | |
171 | 600230 | vst1_s8((int8_t*)(dst_ptr), v0_s8); | |
172 | 600230 | dst_ptr += 8 * sizeof(int8_t); | |
173 | 600230 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
174 | 600230 | } | |
175 | 45550 | } else | |
176 | #endif | ||
177 | { | ||
178 |
2/2✓ Branch 0 taken 781226 times.
✓ Branch 1 taken 34782 times.
|
816008 | for (; block_idx < num_blocks_k; ++block_idx) { |
179 |
2/2✓ Branch 0 taken 3124904 times.
✓ Branch 1 taken 781226 times.
|
3906130 | for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
180 | 3124904 | const int32_t k_idx_start = (block_idx * k_block_len) + k_block_idx; | |
181 | |||
182 | 3124904 | const float src0_0 = *(src_ptr + k_idx_start); | |
183 | |||
184 | // Scale the values | ||
185 | 3124904 | int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); | |
186 | |||
187 | 3124904 | v0_s32 = v0_s32 + nudged_zero_point0; | |
188 |
2/2✓ Branch 0 taken 3118356 times.
✓ Branch 1 taken 6548 times.
|
3124904 | v0_s32 = KAI_MAX(v0_s32, INT8_MIN); |
189 |
2/2✓ Branch 0 taken 3084798 times.
✓ Branch 1 taken 40106 times.
|
3124904 | v0_s32 = KAI_MIN(v0_s32, INT8_MAX); |
190 | |||
191 | 3124904 | *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; | |
192 | 3124904 | dst_ptr += sizeof(int8_t); | |
193 | 3124904 | } | |
194 | 781226 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
195 | 781226 | } | |
196 | } | ||
197 | |||
198 |
2/2✓ Branch 0 taken 117664 times.
✓ Branch 1 taken 80332 times.
|
197996 | for (; block_idx < num_blocks_k_internal; ++block_idx) { |
199 | // left over k | ||
200 |
2/2✓ Branch 0 taken 618280 times.
✓ Branch 1 taken 117664 times.
|
735944 | for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
201 | // Clamp at the last valid k-index | ||
202 |
2/2✓ Branch 0 taken 34360 times.
✓ Branch 1 taken 583920 times.
|
618280 | const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); |
203 | |||
204 | 618280 | const float src0_0 = *(src_ptr + k_idx_start); | |
205 | |||
206 | // Scale the values | ||
207 | 618280 | int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); | |
208 | |||
209 | 618280 | v0_s32 = v0_s32 + nudged_zero_point0; | |
210 |
2/2✓ Branch 0 taken 618254 times.
✓ Branch 1 taken 26 times.
|
618280 | v0_s32 = KAI_MAX(v0_s32, INT8_MIN); |
211 |
2/2✓ Branch 0 taken 605826 times.
✓ Branch 1 taken 12454 times.
|
618280 | v0_s32 = KAI_MIN(v0_s32, INT8_MAX); |
212 | |||
213 | 618280 | *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; | |
214 | 618280 | dst_ptr += sizeof(int8_t); | |
215 | 618280 | } | |
216 | 117664 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
217 | 117664 | } | |
218 | |||
219 | 80332 | dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); | |
220 | |||
221 | 80332 | dst_ptr += dst_x * kai_num_bytes_per_offset; | |
222 | |||
223 | // LHS offset at the beginning of the row | ||
224 | 80332 | *((int32_t*)(dst_ptr)) = -nudged_zero_point0; | |
225 | |||
226 | // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier | ||
227 | − | KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); | |
228 | |||
229 | 80332 | dst_ptr += mr * kai_num_bytes_per_offset; | |
230 | |||
231 | // Store the scale quantization params | ||
232 | 80332 | *((float*)(dst_ptr)) = recip_scale0; | |
233 | |||
234 | 80332 | src_ptr += (lhs_stride / sizeof(float)); | |
235 | |||
236 | // Move to the next row if we have interleaved all Mr rows | ||
237 |
2/2✓ Branch 0 taken 37912 times.
✓ Branch 1 taken 42420 times.
|
80332 | if ((((row_idx + 1) + m_idx_start) % mr) == 0) { |
238 | 42420 | lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); | |
239 | 42420 | } | |
240 | 80332 | } | |
241 | 3940 | } | |
242 |