KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 88.4% 167 / 9 / 198
Functions: 88.1% 37 / 0 / 42
Branches: 56.8% 146 / 20 / 277

test/reference/quantize.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/quantize.hpp"
8
9 #include <algorithm>
10 #include <cmath>
11 #include <cstddef>
12 #include <cstdint>
13 #include <tuple>
14
15 #include "test/common/bfloat16.hpp"
16 #include "test/common/buffer.hpp"
17 #include "test/common/int4.hpp"
18 #include "test/common/memory.hpp"
19 #include "test/common/numeric_limits.hpp"
20 #include "test/common/round.hpp"
21 #include "test/common/type_traits.hpp"
22 #include "test/reference/cast.hpp"
23
24 namespace kai::test {
25
26 namespace {
27
28 template <typename FloatData, typename IntData, typename ZeroPoint>
29 3166992 std::tuple<FloatData, ZeroPoint> get_scale_zero_point_from_range(FloatData min_value, FloatData max_value) {
30 3166992 const FloatData q_min = numeric_lowest<IntData>;
31 3166992 const FloatData q_max = numeric_highest<IntData>;
32
33
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 561816 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2605176 times.
3166992 if (min_value > 0) {
34 3166992 min_value = 0;
35 3166992 }
36
37
2/4
✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2605176 times.
✗ Branch 3 not taken.
3166992 if (max_value < 0) {
38 max_value = 0;
39 }
40
41 // The reason for computing the inverted scale first is to make it bit-perfect with quantized packing
42 // micro-kernels. If those micro-kernels don't do it this way anymore, it makes more sense to calculate
43 // the scale directly.
44
2/4
✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2605176 times.
✗ Branch 3 not taken.
3166992 const FloatData inv_scale = max_value != min_value ? (q_max - q_min) / (max_value - min_value) : 1.0F;
45 3166992 const FloatData scale = 1.0F / inv_scale;
46
47 3166992 const FloatData scaled_min = min_value / scale;
48 3166992 const FloatData scaled_max = max_value / scale;
49
50
2/4
✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2605176 times.
✗ Branch 3 not taken.
3166992 const FloatData zero_point_f = -(scaled_min + q_min) < scaled_max + q_max ? scaled_min - q_min : scaled_max - q_max;
51 3166992 const ZeroPoint zero_point = -round_to_nearest_even<ZeroPoint>(zero_point_f);
52
53 3166992 return {scale, zero_point};
54 3166992 }
55
56 /// Quantized a float value to an integer datatype using a provided scale.
57 ///
58 /// @tparam IntType Quantized integer datatype.
59 ///
60 /// @param[in] float The value to quantize
61 /// @param[in] scale The scale used to quantize the provided float value.
62 ///
63 /// @return The quantized data matrix, the quantization scale matrix and the quantization zero point matrix.
64 template <typename IntType>
65 99249221 IntType quantize_symmetric(float value, float scale) {
66
3/6
✓ Branch 0 taken 60763 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 76177560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 23010898 times.
✗ Branch 5 not taken.
99249221 const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F;
67 99249221 auto qsi32 = round_to_nearest_even_i32(value * inv_scale);
68
69 if (is_unsigned<IntType>) {
70 qsi32 += 1 << (size_in_bits<IntType> - 1);
71 }
72
73 122320882 return static_cast<IntType>(std::clamp<int32_t>(qsi32, numeric_lowest<IntType>, numeric_highest<IntType>));
74 99249221 }
75
76 /// Computes the quantization information using symmetric per-block quantization method.
77 ///
78 /// The input matrix is divided into quantization blocks of the same size.
79 ///
80 /// The height of the block does not affect the behavior of this function hence it is omitted
81 /// from the function arguments and the figures below.
82 ///
83 /// ```
84 /// Quantization blocks -------+
85 /// | |
86 /// | |
87 /// v v
88 /// +-----------------+-----------------+----- ...
89 /// | f00 f01 f02 f03 | f04 f05 f06 f07 | ........
90 /// | f10 f11 f12 f13 | f14 f15 f16 f17 | ........
91 /// | f20 f21 f22 f23 | f24 f25 f26 f27 | ........
92 /// | f30 f31 f32 f33 | f34 f35 f36 f37 | ........
93 /// | ............... | ............... | ........
94 /// : ............... : ............... : ........
95 /// ```
96 ///
97 /// Each row of the quantization block is quantized individually.
98 ///
99 /// ```
100 /// Floating-point data Scale
101 /// +-----------------+ +-----+
102 /// | f00 f01 f02 f03 | -------> | s00 |
103 /// | f10 f11 f12 f13 | -------> | s10 |
104 /// | f20 f21 f22 f23 | -------> | s20 |
105 /// | f30 f31 f32 f33 | -------> | s30 |
106 /// | ............... | | ... |
107 /// : ............... : : ... :
108 /// ```
109 ///
110 /// The computed quantization scale matrix:
111 ///
112 /// ```
113 /// +-----+-----+-- ...
114 /// | s00 | s01 | .....
115 /// | s10 | s11 | .....
116 /// | s20 | s21 | .....
117 /// | s30 | s31 | .....
118 /// | ... | ... | .....
119 /// : ... : ... : .....
120 /// ```
121 ///
122 /// @tparam SrcType The data type of the input data (must be floating-point).
123 /// @tparam DstType The data type of the output data (must be integer).
124 /// @tparam ScaleType The data type of the quantization scales (must be floating-point).
125 ///
126 /// @param[in] src The input matrix.
127 /// @param[in] height The number of rows.
128 /// @param[in] width The number of columns.
129 /// @param[in] quant_width The number of columns of the quantization block.
130 ///
131 /// @return The quantization scale matrix.
132 template <typename SrcType, typename DstType, typename ScaleType>
133 38937 Buffer compute_symmetric_per_block_quantization_info(const void* src, size_t height, size_t width, size_t quant_width) {
134 static_assert(is_floating_point<SrcType>);
135 static_assert(is_integral<DstType>);
136 static_assert(is_floating_point<ScaleType>);
137
138 KAI_ASSUME_ALWAYS(quant_width != 0);
139
140 38937 const auto num_quant_packets_x = round_up_division(width, quant_width);
141
142 38937 const auto scales_bytes = height * num_quant_packets_x * sizeof(ScaleType);
143 38937 Buffer scales(scales_bytes);
144
145 38937 const auto* src_ptr = reinterpret_cast<const SrcType*>(src);
146
147
4/6
✓ Branch 0 taken 968873 times.
✓ Branch 1 taken 11534 times.
✓ Branch 2 taken 347657 times.
✓ Branch 3 taken 27403 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
1355467 for (size_t y = 0; y < height; ++y) {
148
4/6
✓ Branch 0 taken 968873 times.
✓ Branch 1 taken 1850246 times.
✓ Branch 2 taken 395173 times.
✓ Branch 3 taken 432469 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
3646761 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
149 // Computes the quantization scale.
150 2330231 SrcType max_abs = 0;
151
152
4/6
✓ Branch 0 taken 76177560 times.
✓ Branch 1 taken 1850246 times.
✓ Branch 2 taken 23010898 times.
✓ Branch 3 taken 479985 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
101518689 for (size_t x_element = 0; x_element < quant_width; ++x_element) {
153 99188458 const auto x = x_quant + x_element;
154
155
2/6
✗ Branch 0 not taken.
✓ Branch 1 taken 76177560 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 23010898 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
99188458 if (x < width) {
156
2/6
✗ Branch 0 not taken.
✓ Branch 1 taken 76177560 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 17005565 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
99188458 max_abs = std::max<SrcType>(max_abs, std::abs(src_ptr[y * width + x]));
157 99188458 }
158 99188458 }
159
160 4660462 const auto scale =
161 2330231 max_abs / static_cast<SrcType>((static_cast<uint64_t>(1) << (size_in_bits<DstType> - 1)) - 1);
162
163 // Stores the scales.
164 2330231 write_array<ScaleType>(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale);
165 2330231 }
166 1316530 }
167
168 38937 return scales;
169 38937 }
170
171 /// Dynamically quantizes each block of the matrix using symmetric quantization method.
172 ///
173 /// The quantization information is calculated using
174 /// @ref compute_symmetric_per_block_quantization_info function.
175 /// The floating-point data is then quantized using
176 /// @ref quantize_symmetric_per_block function.
177 ///
178 /// To retain highest quantization accuracy, the data is quantized using the quantization scale
179 /// with the same data type as the input data.
180 /// After that the quantization scale can be stored in the buffer using `ScaleType` data type
181 /// which might have lowest precision than the input data type.
182 ///
183 /// @tparam SrcType The data type of the input data (must be floating-point).
184 /// @tparam DstType The data type of the output data (must be integer).
185 /// @tparam ScaleType The data type of the quantization scales (must be floating-point).
186 ///
187 /// @param[in] src The input matrix.
188 /// @param[in] height The number of rows.
189 /// @param[in] width The number of columns.
190 /// @param[in] quant_width The number of columns of the quantization block.
191 ///
192 /// @return The quantized data matrix and the quantization scale matrix.
193 template <typename SrcType, typename DstType, typename ScaleType>
194 38937 std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic(
195 const void* src, size_t height, size_t width, size_t quant_width) {
196 38937 auto scales_src_type =
197 38937 compute_symmetric_per_block_quantization_info<SrcType, DstType, SrcType>(src, height, width, quant_width);
198
10/24
✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9480 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 9480 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1134 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1134 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 920 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 920 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 26483 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 26483 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
77874 auto data = quantize_symmetric_per_block<SrcType, DstType, SrcType>(
199 38937 src, scales_src_type.data(), height, width, quant_width);
200
201 if constexpr (std::is_same_v<ScaleType, SrcType>) {
202 35963 return {std::move(data), std::move(scales_src_type)};
203 } else {
204 2974 auto scales =
205
3/6
✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1134 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 920 times.
✗ Branch 5 not taken.
2974 cast<ScaleType, SrcType>(scales_src_type.data(), scales_src_type.size() * 8 / size_in_bits<SrcType>);
206
207 2974 return {std::move(data), std::move(scales)};
208 2974 }
209 38937 }
210
211 /// Dynamically quantizes each block of the matrix using symmetric quantization method.
212 ///
213 /// @param[in] src The input matrix.
214 /// @param[in] src_type The data type of the input data (must be FP32).
215 /// @param[in] height The number of rows.
216 /// @param[in] width The number of columns.
217 /// @param[in] qinfo The quantization information.
218 ///
219 /// @return The quantized data matrix and the quantization scale matrix.
220 38937 std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic(
221 const void* src, DataType src_type, size_t height, size_t width, const QuantizationInfo& qinfo) {
222 // Fail fast for datatypes that must be fixed.
223 KAI_ASSUME_ALWAYS(src_type == DataType::FP32);
224
225
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 11534 times.
✓ Branch 2 taken 27403 times.
✗ Branch 3 not taken.
38937 switch (qinfo.dst_type) {
226 case DataType::QSI4:
227
3/4
✓ Branch 0 taken 9480 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
✓ Branch 3 taken 1134 times.
11534 switch (qinfo.scale_type) {
228 case DataType::FP16:
229 920 return quantize_symmetric_per_block_dynamic<float, Int4, Float16>(
230 920 src, height, width, qinfo.quant_width);
231 case DataType::FP32:
232 9480 return quantize_symmetric_per_block_dynamic<float, Int4, float>(
233 9480 src, height, width, qinfo.quant_width);
234 case DataType::BF16:
235 1134 return quantize_symmetric_per_block_dynamic<float, Int4, BFloat16<>>(
236 1134 src, height, width, qinfo.quant_width);
237 default:
238 break;
239 }
240 break;
241 case DataType::QSI8:
242
2/3
✓ Branch 0 taken 26483 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
27403 switch (qinfo.scale_type) {
243 case DataType::FP16:
244 920 return quantize_symmetric_per_block_dynamic<float, int8_t, Float16>(
245 920 src, height, width, qinfo.quant_width);
246 case DataType::FP32:
247 26483 return quantize_symmetric_per_block_dynamic<float, int8_t, float>(
248 26483 src, height, width, qinfo.quant_width);
249 default:
250 break;
251 }
252 break;
253 case DataType::I32:
254 if (qinfo.scale_type == DataType::FP32) {
255 return quantize_symmetric_per_block_dynamic<float, int32_t, float>(
256 src, height, width, qinfo.quant_width);
257 }
258 break;
259 default:
260 break;
261 }
262 KAI_ERROR("Unsupported combination of data types for symmetric quantization.");
263 38937 }
264
265 /// Dynamically quantizes each block of the matrix using asymmetric quantization method.
266 ///
267 /// The quantization information is calculated using
268 /// @ref compute_asymmetric_per_block_quantization_info function.
269 /// The floating-point data is then quantized using
270 /// @ref quantize_asymmetric_per_block function.
271 ///
272 /// To retain highest quantization accuracy, the data is quantized using the quantization scale
273 /// with the same data type as the input data.
274 /// After that the quantization scale can be stored in the buffer using `ScaleType` data type
275 /// which might have lowest precision than the input data type.
276 ///
277 /// @tparam SrcType The data type of the input data (must be floating-point).
278 /// @tparam DstType The data type of the output data (must be integer).
279 /// @tparam ScaleType The data type of the quantization scales (must be floating-point).
280 /// @tparam ZeroPointType The data type of the quantization zero points (must be integer).
281 ///
282 /// @param[in] src The input matrix.
283 /// @param[in] height The number of rows.
284 /// @param[in] width The number of columns.
285 /// @param[in] quant_width The number of columns of the quantization block.
286 ///
287 /// @return The quantized data matrix, the quantization scale matrix and the quantization zero point matrix.
288 template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType>
289 39233 std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic(
290 const void* src, size_t height, size_t width, size_t quant_width) {
291 /* Calculate the asymmetric quantization information, one scaling per row */
292 156932 auto [scales_src_type, zero_points] =
293 39233 compute_asymmetric_per_block_quantization_info<SrcType, DstType, SrcType, ZeroPointType>(
294 39233 src, height, width, quant_width);
295
296 /* Do the actual quantization */
297
4/12
✓ Branch 0 taken 13529 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13529 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 25704 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 25704 times.
✗ Branch 11 not taken.
78466 auto data = quantize_asymmetric_per_block<SrcType, DstType, SrcType, ZeroPointType>(
298 117699 src, scales_src_type.data(), zero_points.data(), height, width, quant_width);
299
300 if constexpr (std::is_same_v<ScaleType, SrcType>) {
301 39233 return {std::move(data), std::move(scales_src_type), std::move(zero_points)};
302 } else {
303 auto scales =
304 cast<ScaleType, SrcType>(scales_src_type.data(), scales_src_type.size() * 8 / size_in_bits<SrcType>);
305
306 return {std::move(data), std::move(scales), std::move(zero_points)};
307 }
308 39233 }
309
310 /// Dynamically quantizes each block of the matrix using asymmetric quantization method.
311 ///
312 /// @param[in] src The input matrix.
313 /// @param[in] src_type The data type of the input data (must be FP32).
314 /// @param[in] height The number of rows.
315 /// @param[in] width The number of columns.
316 /// @param[in] qinfo The quantization information.
317 ///
318 /// @return The quantized data matrix, the quantization scale matrix and the quantization zero point matrix.
319 39233 std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic(
320 const void* src, DataType src_type, size_t height, size_t width, const QuantizationInfo& qinfo) {
321 // Fail fast for datatypes that must be fixed.
322 KAI_ASSUME_ALWAYS(src_type == DataType::FP32);
323 KAI_ASSUME_ALWAYS(qinfo.zero_point_type == DataType::I32);
324
325
2/3
✗ Branch 0 not taken.
✓ Branch 1 taken 13529 times.
✓ Branch 2 taken 25704 times.
39233 switch (qinfo.dst_type) {
326 case DataType::QAI8:
327
1/3
✗ Branch 0 not taken.
✓ Branch 1 taken 13529 times.
✗ Branch 2 not taken.
13529 switch (qinfo.scale_type) {
328 case DataType::FP32:
329 13529 return quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(
330 13529 src, height, width, qinfo.quant_width);
331 case DataType::BF16:
332 return quantize_asymmetric_per_block_dynamic<float, int8_t, BFloat16<>, int32_t>(
333 src, height, width, qinfo.quant_width);
334 default:
335 break;
336 }
337 break;
338 case DataType::QAI4:
339
1/2
✓ Branch 0 taken 25704 times.
✗ Branch 1 not taken.
25704 switch (qinfo.scale_type) {
340 case DataType::FP32:
341 25704 return quantize_asymmetric_per_block_dynamic<float, Int4, float, int32_t>(
342 25704 src, height, width, qinfo.quant_width);
343 default:
344 break;
345 }
346 break;
347 default:
348 break;
349 }
350 KAI_ERROR("Unsupported combination of destination/scale types for asymmetric quantization.");
351 39233 }
352
353 } // namespace
354
355 template <typename FloatType, typename IntType, typename ZeroPointType>
356 210858462 IntType quantize_asymmetric(FloatType value, FloatType scale, ZeroPointType zero_point) {
357
2/4
✓ Branch 0 taken 104220126 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 106638336 times.
✗ Branch 3 not taken.
210858462 const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F;
358 210858462 auto quantized_value = round_to_nearest_even<ZeroPointType>(value * inv_scale) + zero_point;
359 315078588 return static_cast<IntType>(
360 210858462 std::clamp<ZeroPointType>(quantized_value, numeric_lowest<IntType>, numeric_highest<IntType>));
361 210858462 }
362
363 template int8_t quantize_asymmetric(float value, float scale, int32_t zero_point);
364
365 template <typename SrcType, typename DstType, typename ScaleType>
366 39458 Buffer quantize_symmetric_per_block(
367 const void* src, const void* scales, size_t height, size_t width, size_t quant_width) {
368 static_assert(is_floating_point<SrcType>);
369 static_assert(is_integral<DstType>);
370 static_assert(is_floating_point<ScaleType>);
371
372 39458 const auto num_quant_packets_x = round_up_division(width, quant_width);
373
374 39458 const auto data_bytes = round_up_division(height * width * size_in_bits<DstType>, 8);
375 39458 Buffer data(data_bytes);
376
377 39458 const auto* src_ptr = reinterpret_cast<const SrcType*>(src);
378
379
6/6
✓ Branch 0 taken 60763 times.
✓ Branch 1 taken 521 times.
✓ Branch 2 taken 968873 times.
✓ Branch 3 taken 11534 times.
✓ Branch 4 taken 347657 times.
✓ Branch 5 taken 27403 times.
1416751 for (size_t y = 0; y < height; ++y) {
380
6/6
✓ Branch 0 taken 60763 times.
✓ Branch 1 taken 60763 times.
✓ Branch 2 taken 968873 times.
✓ Branch 3 taken 1850246 times.
✓ Branch 4 taken 479985 times.
✓ Branch 5 taken 347657 times.
3768287 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
381
3/6
✓ Branch 0 taken 60763 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1850246 times.
✓ Branch 4 taken 479985 times.
✗ Branch 5 not taken.
2390994 const auto scale = read_array<ScaleType>(scales, y * num_quant_packets_x + x_quant / quant_width);
382
383 // Quantizes and stores the data.
384
6/6
✓ Branch 0 taken 60763 times.
✓ Branch 1 taken 60763 times.
✓ Branch 2 taken 1850246 times.
✓ Branch 3 taken 76177560 times.
✓ Branch 4 taken 23010898 times.
✓ Branch 5 taken 479985 times.
101640215 for (size_t x_element = 0; x_element < quant_width; ++x_element) {
385 99249221 const auto x = x_quant + x_element;
386
387
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 60763 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 76177560 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 23010898 times.
99249221 if (x < width) {
388
3/6
✓ Branch 0 taken 60763 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 76177560 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 23010898 times.
99249221 const auto quantized = quantize_symmetric<DstType>(src_ptr[y * width + x], scale);
389
3/6
✓ Branch 0 taken 60763 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60763 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 76177560 times.
✗ Branch 5 not taken.
99249221 write_array(data.data(), y * width + x, quantized);
390 99249221 }
391 99249221 }
392 2390994 }
393 1377293 }
394 39458 return data;
395 39458 }
396
397 template Buffer quantize_symmetric_per_block<float, int32_t, float>(
398 const void* src, const void* scales, size_t height, size_t width, size_t quant_width);
399
400 template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType>
401 39754 std::tuple<Buffer, Buffer> compute_asymmetric_per_block_quantization_info(
402 const void* src, size_t height, size_t width, size_t quant_width) {
403 static_assert(is_floating_point<SrcType>);
404 static_assert(is_integral<DstType>);
405 static_assert(is_floating_point<ScaleType>);
406 static_assert(is_integral<ZeroPointType>);
407
408 KAI_ASSUME_ALWAYS(quant_width != 0);
409
410 39754 const auto num_quant_packets_x = round_up_division(width, quant_width);
411
412 39754 const auto scales_bytes = height * num_quant_packets_x * sizeof(ScaleType);
413 39754 Buffer scales(scales_bytes);
414
415 39754 const auto zero_points_bytes = height * num_quant_packets_x * sizeof(ZeroPointType);
416
2/4
✓ Branch 0 taken 14050 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 25704 times.
✗ Branch 3 not taken.
39754 Buffer zero_points(zero_points_bytes);
417
418
4/4
✓ Branch 0 taken 14050 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 514080 times.
✓ Branch 3 taken 1002456 times.
2092402 for (size_t y = 0; y < height; ++y) {
419
4/4
✓ Branch 0 taken 561816 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 2605176 times.
✓ Branch 3 taken 1490832 times.
5219640 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
420 // Computes the quantization scale and zero point.
421 3166992 auto min_value = numeric_highest<SrcType>;
422 3166992 auto max_value = numeric_lowest<SrcType>;
423
424
4/4
✓ Branch 0 taken 104219084 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 106638336 times.
✓ Branch 3 taken 2605176 times.
214024412 for (size_t x_element = 0; x_element < quant_width; ++x_element) {
425 210857420 const auto x = x_quant + x_element;
426
427
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 104219084 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 106638336 times.
210857420 if (x < width) {
428
3/4
✓ Branch 0 taken 104219084 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 71092224 times.
✓ Branch 3 taken 35546112 times.
210857420 const auto value = read_array<SrcType>(src, y * width + x);
429
430
2/4
✓ Branch 0 taken 104219084 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 71092224 times.
✗ Branch 3 not taken.
210857420 min_value = std::min(min_value, value);
431
2/4
✓ Branch 0 taken 104219084 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 71092224 times.
✗ Branch 3 not taken.
210857420 max_value = std::max(max_value, value);
432 210857420 }
433 210857420 }
434
435 3166992 const auto [scale, zero_point] =
436
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 561816 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2605176 times.
3166992 get_scale_zero_point_from_range<SrcType, DstType, ZeroPointType>(min_value, max_value);
437
438 // Stores the scale and zero point.
439
2/4
✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 561816 times.
✗ Branch 3 not taken.
6333984 write_array<ScaleType>(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale);
440
4/8
✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 561816 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2605176 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2605176 times.
6333984 write_array<ZeroPointType>(zero_points.data(), y * num_quant_packets_x + x_quant / quant_width, zero_point);
441 3166992 }
442 2052648 }
443
444 39754 return {std::move(scales), std::move(zero_points)};
445 39754 }
446
447 template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType>
448 39754 Buffer quantize_asymmetric_per_block(
449 const void* src, const void* scales, const void* zero_points, size_t height, size_t width, size_t quant_width) {
450 static_assert(is_floating_point<SrcType>);
451 static_assert(is_integral<DstType>);
452 static_assert(is_floating_point<ScaleType>);
453 static_assert(is_integral<ZeroPointType>);
454
455 39754 const auto num_quant_packets_x = round_up_division(width, quant_width);
456
457 39754 const auto data_bytes = round_up_division(height * width * size_in_bits<DstType>, 8);
458 39754 Buffer data(data_bytes);
459
460
4/4
✓ Branch 0 taken 14050 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 25704 times.
✓ Branch 3 taken 1490832 times.
2092402 for (size_t y = 0; y < height; ++y) {
461
4/4
✓ Branch 0 taken 561816 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 2605176 times.
✓ Branch 3 taken 1490832 times.
5219640 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
462
2/4
✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2605176 times.
✗ Branch 3 not taken.
3166992 const auto scale = read_array<ScaleType>(scales, y * num_quant_packets_x + x_quant / quant_width);
463 5772168 const auto zero_point =
464
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 561816 times.
3166992 read_array<ZeroPointType>(zero_points, y * num_quant_packets_x + x_quant / quant_width);
465
466 // Quantizes and stores the data.
467
4/4
✓ Branch 0 taken 104219084 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 106638336 times.
✓ Branch 3 taken 2605176 times.
214024412 for (size_t x_element = 0; x_element < quant_width; ++x_element) {
468 210857420 const auto x = x_quant + x_element;
469
470
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 104219084 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 106638336 times.
210857420 if (x < width) {
471
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 104219084 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 106638336 times.
210857420 const auto value_f = read_array<SrcType>(src, y * width + x);
472 210857420 const auto value_q =
473
2/4
✓ Branch 0 taken 104219084 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 106638336 times.
✗ Branch 3 not taken.
210857420 quantize_asymmetric<SrcType, DstType, ZeroPointType>(value_f, scale, zero_point);
474
475
2/4
✓ Branch 0 taken 104219084 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 106638336 times.
✗ Branch 3 not taken.
210857420 write_array<DstType>(data.data(), y * width + x, value_q);
476 210857420 }
477 210857420 }
478 3166992 }
479 2052648 }
480
481 39754 return data;
482 39754 }
483
484 78170 std::tuple<Buffer, QuantizationOutputs> quantize_dynamic(
485 const void* src, DataType src_type, size_t height, size_t width, const QuantizationInfo& qinfo) {
486 KAI_ASSUME_ALWAYS(data_type_is_quantized(qinfo.dst_type));
487 78170 Buffer data;
488 78170 QuantizationOutputs qoutputs;
489
3/4
✓ Branch 0 taken 78170 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 39233 times.
✓ Branch 3 taken 38937 times.
78170 if (data_type_is_quantized_asymm(qinfo.dst_type)) {
490 KAI_ASSUME_ALWAYS(qinfo.zero_point_type != DataType::UNKNOWN);
491 39233 std::tie(data, qoutputs.scales, qoutputs.zero_points) =
492
1/2
✓ Branch 0 taken 39233 times.
✗ Branch 1 not taken.
39233 quantize_asymmetric_per_block_dynamic(src, src_type, height, width, qinfo);
493 39233 } else {
494
1/2
✓ Branch 0 taken 38937 times.
✗ Branch 1 not taken.
38937 std::tie(data, qoutputs.scales) = quantize_symmetric_per_block_dynamic(src, src_type, height, width, qinfo);
495 }
496 78170 return {std::move(data), std::move(qoutputs)};
497 78170 }
498 } // namespace kai::test
499