KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 96.6% 112 / 0 / 116
Functions: 66.7% 8 / 0 / 12
Branches: 59.3% 64 / 0 / 108

test/nextgen/reference/quantize.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/nextgen/reference/quantize.hpp"
8
9 #include <algorithm>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <tuple>
14 #include <type_traits>
15 #include <utility>
16
17 #include "test/common/assert.hpp"
18 #include "test/common/buffer.hpp"
19 #include "test/common/data_type.hpp"
20 #include "test/common/int4.hpp"
21 #include "test/common/memory.hpp"
22 #include "test/common/numeric_limits.hpp"
23 #include "test/common/round.hpp"
24 #include "test/common/span.hpp"
25 #include "test/nextgen/functions/round.hpp"
26
27 namespace kai::test {
28
29 namespace {
30
31 template <typename FpData, typename QData, typename QZp, RoundMode ZP_ROUND_MODE>
32 15452 std::tuple<FpData, FpData, QZp> get_scale_zero_point_from_range(FpData min_value, FpData max_value) {
33 15452 const FpData q_min = numeric_lowest<QData>;
34 15452 const FpData q_max = numeric_highest<QData>;
35
36
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 15452 times.
15452 if (min_value > 0) {
37 15452 min_value = 0;
38 15452 }
39
40
1/2
✓ Branch 0 taken 15452 times.
✗ Branch 1 not taken.
15452 if (max_value < 0) {
41 max_value = 0;
42 }
43
44 // The reason for computing the inverted scale first is to make it bit-perfect with quantized packing
45 // micro-kernels. If those micro-kernels don't do it this way anymore, it makes more sense to calculate
46 // the scale directly.
47
1/2
✓ Branch 0 taken 15452 times.
✗ Branch 1 not taken.
15452 const FpData inv_scale = max_value != min_value ? (q_max - q_min) / (max_value - min_value) : 1.0F;
48 15452 const FpData scale = 1.0F / inv_scale;
49
50 15452 const FpData scaled_min = min_value / scale;
51 15452 const FpData scaled_max = max_value / scale;
52
53
1/2
✓ Branch 0 taken 15452 times.
✗ Branch 1 not taken.
15452 const FpData zero_point_f = -(scaled_min + q_min) < scaled_max + q_max ? scaled_min - q_min : scaled_max - q_max;
54 15452 const QZp zero_point = -static_cast<QZp>(round<FpData, ZP_ROUND_MODE>(zero_point_f));
55
56 15452 return {scale, inv_scale, zero_point};
57 15452 }
58
59 template <typename FpData, typename QData>
60 15923 std::tuple<FpData, FpData> get_scale_from_max_abs(FpData max_abs) {
61 15923 const FpData scale = max_abs / static_cast<FpData>((1 << (size_in_bits<QData> - 1)) - 1);
62 15923 const FpData inv_scale = static_cast<FpData>(1) / scale;
63
64 15923 return {scale, inv_scale};
65 15923 }
66
67 template <typename FpData, typename QData, RoundMode QDATA_ROUND_MODE>
68 1272221 QData quantize_symmetric(FpData value, FpData inv_scale) {
69 1272221 int32_t quantized_value = round<FpData, QDATA_ROUND_MODE>(value * inv_scale);
70
71 if (is_unsigned<QData>) {
72 1272221 quantized_value += 1 << (size_in_bits<QData> - 1);
73 }
74
75 1272221 return static_cast<QData>(std::clamp<int32_t>(quantized_value, numeric_lowest<QData>, numeric_highest<QData>));
76 1272221 }
77
78 template <typename FpData, typename QData, typename QZp, RoundMode QDATA_ROUND_MODE>
79 1274204 [[nodiscard]] QData quantize_asymmetric(FpData value, FpData inv_scale, QZp zero_point) {
80 1274204 const QZp quantized_value = static_cast<QZp>(round<FpData, QDATA_ROUND_MODE>(value * inv_scale)) + zero_point;
81 2548408 return static_cast<QData>(std::clamp<QZp>(quantized_value, numeric_lowest<QData>, numeric_highest<QData>));
82 1274204 }
83
84 template <
85 typename FpData, typename QData, typename QScale, typename QZp, RoundMode QDATA_ROUND_MODE,
86 RoundMode QZP_ROUND_MODE>
87 200 [[nodiscard]] std::tuple<Buffer, Buffer, Buffer> dynamic_asymmetric_quantize_linear(
88 size_t height, size_t width, size_t block_height, size_t block_width, Span<const std::byte> fp_data) {
89 200 const size_t num_block_rows = round_up_division(height, block_height);
90 200 const size_t num_block_cols = round_up_division(width, block_width);
91
92 200 Buffer qdata(height * round_up_division(width * size_in_bits<QData>, 8), 0);
93
1/2
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
200 Buffer qscale(num_block_rows * num_block_cols * size_in_bits<QScale> / 8, 0);
94 static_assert(size_in_bits<QScale> % 8 == 0);
95
1/2
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
200 Buffer qzp(num_block_rows * num_block_cols * size_in_bits<QZp> / 8);
96
97
2/2
✓ Branch 0 taken 15452 times.
✓ Branch 1 taken 200 times.
15652 for (size_t block_row = 0; block_row < num_block_rows; ++block_row) {
98
2/2
✓ Branch 0 taken 15452 times.
✓ Branch 1 taken 15452 times.
30904 for (size_t block_col = 0; block_col < num_block_cols; ++block_col) {
99 15452 const size_t block_idx = block_row * num_block_cols + block_col;
100 15452 const size_t start_row = block_row * block_height;
101 15452 const size_t start_col = block_col * block_width;
102
1/2
✓ Branch 0 taken 15452 times.
✗ Branch 1 not taken.
15452 const size_t size_row = std::min(block_height, height - start_row);
103
1/2
✓ Branch 0 taken 15452 times.
✗ Branch 1 not taken.
15452 const size_t size_col = std::min(block_width, width - start_col);
104
105 // Finds the value range.
106 15452 FpData min_value = numeric_highest<FpData>;
107 15452 FpData max_value = numeric_lowest<FpData>;
108
109
2/2
✓ Branch 0 taken 15452 times.
✓ Branch 1 taken 15452 times.
30904 for (size_t row = 0; row < size_row; ++row) {
110
2/2
✓ Branch 0 taken 1274204 times.
✓ Branch 1 taken 15452 times.
1289656 for (size_t col = 0; col < size_col; ++col) {
111
1/2
✓ Branch 0 taken 1274204 times.
✗ Branch 1 not taken.
1274204 const FpData value = read_2d<FpData>(fp_data, width, start_row + row, start_col + col);
112
1/2
✓ Branch 0 taken 1274204 times.
✗ Branch 1 not taken.
1274204 min_value = std::min(min_value, value);
113
1/2
✓ Branch 0 taken 1274204 times.
✗ Branch 1 not taken.
1274204 max_value = std::max(max_value, value);
114 1274204 }
115 15452 }
116
117 // Computes the quantization information.
118 2579312 const auto [qscale_value, inv_qscale_value, qzp_value] =
119
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 15452 times.
15452 get_scale_zero_point_from_range<FpData, QData, QZp, QZP_ROUND_MODE>(min_value, max_value);
120
121
3/6
✓ Branch 0 taken 15452 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 15452 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 15452 times.
✗ Branch 5 not taken.
15452 write_array<QScale>(qscale, block_idx, qscale_value);
122
3/6
✓ Branch 0 taken 15452 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 15452 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 15452 times.
✗ Branch 5 not taken.
15452 write_array<QZp>(qzp, block_idx, qzp_value);
123
124 // Quantizes the data.
125
2/2
✓ Branch 0 taken 15452 times.
✓ Branch 1 taken 15452 times.
30904 for (size_t row = 0; row < size_row; ++row) {
126
2/2
✓ Branch 0 taken 1274204 times.
✓ Branch 1 taken 15452 times.
1289656 for (size_t col = 0; col < size_col; ++col) {
127
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1274204 times.
1274204 const FpData value = read_2d<FpData>(fp_data, width, start_row + row, start_col + col);
128 1274204 const QData qvalue =
129
3/6
✓ Branch 0 taken 1274204 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1274204 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1274204 times.
✗ Branch 5 not taken.
3822612 quantize_asymmetric<FpData, QData, QZp, QDATA_ROUND_MODE>(value, inv_qscale_value, qzp_value);
130
2/4
✓ Branch 0 taken 1274204 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1274204 times.
✗ Branch 3 not taken.
1274204 write_2d<QData>(qdata, width, start_row + row, start_col + col, qvalue);
131 1274204 }
132 15452 }
133 15452 }
134 15452 }
135
136 200 return {std::move(qdata), std::move(qscale), std::move(qzp)};
137 200 }
138
139 template <typename FpData, typename QData, typename QScale, RoundMode QDATA_ROUND_MODE>
140 200 [[nodiscard]] std::tuple<Buffer, Buffer, Buffer> dynamic_symmetric_quantize_linear(
141 size_t height, size_t width, size_t block_height, size_t block_width, Span<const std::byte> fp_data) {
142 200 const size_t num_block_rows = round_up_division(height, block_height);
143 200 const size_t num_block_cols = round_up_division(width, block_width);
144
145 200 Buffer qdata(height * round_up_division(width * size_in_bits<QData>, 8), 0);
146
1/2
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
200 Buffer qscale(num_block_rows * num_block_cols * size_in_bits<QScale> / 8, 0);
147 static_assert(size_in_bits<QScale> % 8 == 0);
148
149
2/2
✓ Branch 0 taken 15923 times.
✓ Branch 1 taken 200 times.
16123 for (size_t block_row = 0; block_row < num_block_rows; ++block_row) {
150
2/2
✓ Branch 0 taken 15923 times.
✓ Branch 1 taken 15923 times.
31846 for (size_t block_col = 0; block_col < num_block_cols; ++block_col) {
151 15923 const size_t block_idx = block_row * num_block_cols + block_col;
152 15923 const size_t start_row = block_row * block_height;
153 15923 const size_t start_col = block_col * block_width;
154
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 15923 times.
15923 const size_t size_row = std::min(block_height, height - start_row);
155
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 15923 times.
15923 const size_t size_col = std::min(block_width, width - start_col);
156
157 // Finds the value range.
158 15923 FpData max_abs = numeric_lowest<FpData>;
159
160
2/2
✓ Branch 0 taken 15923 times.
✓ Branch 1 taken 15923 times.
31846 for (size_t row = 0; row < size_row; ++row) {
161
2/2
✓ Branch 0 taken 15923 times.
✓ Branch 1 taken 1272221 times.
1288144 for (size_t col = 0; col < size_col; ++col) {
162
1/2
✓ Branch 0 taken 1272221 times.
✗ Branch 1 not taken.
1272221 const FpData value = read_2d<FpData>(fp_data, width, start_row + row, start_col + col);
163
1/2
✓ Branch 0 taken 1272221 times.
✗ Branch 1 not taken.
1272221 max_abs = std::max(max_abs, std::abs(value));
164 1272221 }
165 15923 }
166
167 // Computes the quantization information.
168
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 15923 times.
1288144 const auto [qscale_value, inv_qscale_value] = get_scale_from_max_abs<FpData, QData>(max_abs);
169
3/6
✓ Branch 0 taken 15923 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 15923 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 15923 times.
✗ Branch 5 not taken.
15923 write_array<QScale>(qscale, block_idx, qscale_value);
170
171 // Quantizes the data.
172
2/2
✓ Branch 0 taken 15923 times.
✓ Branch 1 taken 15923 times.
31846 for (size_t row = 0; row < size_row; ++row) {
173
2/2
✓ Branch 0 taken 1272221 times.
✓ Branch 1 taken 15923 times.
1288144 for (size_t col = 0; col < size_col; ++col) {
174
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1272221 times.
1272221 const FpData value = read_2d<FpData>(fp_data, width, start_row + row, start_col + col);
175
2/4
✓ Branch 0 taken 1272221 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1272221 times.
✗ Branch 3 not taken.
2544442 const QData qvalue = quantize_symmetric<FpData, QData, QDATA_ROUND_MODE>(value, inv_qscale_value);
176
2/4
✓ Branch 0 taken 1272221 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1272221 times.
✗ Branch 3 not taken.
1272221 write_2d<QData>(qdata, width, start_row + row, start_col + col, qvalue);
177 1272221 }
178 15923 }
179 15923 }
180 15923 }
181
182 200 return {std::move(qdata), std::move(qscale), Buffer()};
183 200 }
184
185 } // namespace
186
187 200 DynamicQuantizeLinearFn make_dynamic_asymmetric_quantize_linear(
188 DataType fp_dtype, DataType qdata_dtype, DataType qscale_dtype, DataType qzp_dtype, RoundMode qdata_round_mode,
189 RoundMode qzp_round_mode) {
190 200 const auto params =
191 200 std::make_tuple(fp_dtype, qdata_dtype, qscale_dtype, qzp_dtype, qdata_round_mode, qzp_round_mode);
192
193
1/2
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
200 if (params ==
194 400 std::make_tuple(
195 200 DataType::FP32, DataType::I8, DataType::FP32, DataType::I32, RoundMode::TIE_AWAY, RoundMode::CURRENT)) {
196 200 return dynamic_asymmetric_quantize_linear<
197 float, int8_t, float, int32_t, RoundMode::TIE_AWAY, RoundMode::CURRENT>;
198 }
199
200 KAI_TEST_ERROR("Not implemented.");
201 200 }
202
203 200 DynamicQuantizeLinearFn make_dynamic_symmetric_quantize_linear(
204 DataType fp_dtype, DataType qdata_dtype, DataType qscale_dtype, RoundMode qdata_round_mode) {
205 200 const auto params = std::make_tuple(fp_dtype, qdata_dtype, qscale_dtype, qdata_round_mode);
206
207
1/2
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
200 if (params == std::make_tuple(DataType::FP32, DataType::U4, DataType::FP32, RoundMode::CURRENT)) {
208 200 return dynamic_symmetric_quantize_linear<float, UInt4, float, RoundMode::CURRENT>;
209 }
210
211 KAI_TEST_ERROR("Not implemented.");
212 200 }
213
214 } // namespace kai::test
215