KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 3 / 0 / 3
Functions: 100.0% 3 / 0 / 3
Branches: -% 0 / 0 / 0

test/reference/quantize.hpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #pragma once
8
9 #include <cstddef>
10 #include <cstdint>
11 #include <tuple>
12
13 #include "test/common/buffer.hpp"
14 #include "test/common/data_format.hpp"
15 #include "test/common/data_type.hpp"
16
17 namespace kai::test {
18
19 /// Quantization options consumed by @ref quantize_dynamic (all fields default to 0 / DataType::UNKNOWN).
20 struct QuantizationInfo {
21 size_t quant_width{0}; ///< Number of columns in each quantization block.
22 DataType dst_type{DataType::UNKNOWN}; ///< Data type of the quantized output matrix.
23 DataType scale_type{DataType::UNKNOWN}; ///< Data type of the quantization scales.
24 DataType zero_point_type{
25 DataType::UNKNOWN}; ///< Data type of the quantization zero points (only for asymmetric quantization).
26 };
27
28 /// Quantization parameter buffers returned by @ref quantize_dynamic alongside the quantized data.
29 78170 struct QuantizationOutputs {
30 78170 Buffer scales{}; ///< Quantization scales.
31 78170 Buffer zero_points{}; ///< Quantization zero points (optional; only for asymmetric quantization).
32 };
33
34 template <typename FloatType, typename IntType, typename ZeroPointType>
35 IntType quantize_asymmetric(FloatType value, FloatType scale, ZeroPointType zero_point); ///< Scalar asymmetric quantization of `value` with `scale` and `zero_point`; defined outside this header.
36
37 /// Quantizes each block of the matrix using the symmetric quantization method.
38 ///
39 /// The input matrix is divided into quantization blocks of the same size.
40 ///
41 /// The height of the block does not affect the behavior of this function hence it is omitted
42 /// from the function arguments and the figures below.
43 ///
44 /// The quantization scale matrix can be calculated using
45 /// @ref compute_symmetric_per_block_quantization_info function (not declared in this header).
46 ///
47 /// The input matrix and the quantization scale matrix:
48 ///
49 /// ```
50 /// Floating-point data Scale
51 ///
52 /// Quantization blocks -------+
53 /// | |
54 /// | |
55 /// v v
56 /// +-----------------+-----------------+----- ... +-----+-----+-- ...
57 /// | f00 f01 f02 f03 | f04 f05 f06 f07 | ........ | s00 | s01 | .....
58 /// | f10 f11 f12 f13 | f14 f15 f16 f17 | ........ | s10 | s11 | .....
59 /// | f20 f21 f22 f23 | f24 f25 f26 f27 | ........ | s20 | s21 | .....
60 /// | f30 f31 f32 f33 | f34 f35 f36 f37 | ........ | s30 | s31 | .....
61 /// | ............... | ............... | ........ | ... | ... | .....
62 /// : ............... : ............... : ........ : ... : ... : .....
63 /// ```
64 ///
65 /// Each row of the quantization block is quantized individually using its own scale.
66 ///
67 /// ```
68 /// Floating-point data Scale Quantized data
69 /// +-----------------+ +-----+ +-----------------+
70 /// | f00 f01 f02 f03 | | s00 | -------> | q00 q01 q02 q03 |
71 /// | f10 f11 f12 f13 | | s10 | -------> | q10 q11 q12 q13 |
72 /// | f20 f21 f22 f23 | | s20 | -------> | q20 q21 q22 q23 |
73 /// | f30 f31 f32 f33 | | s30 | -------> | q30 q31 q32 q33 |
74 /// | ............... | | ... | -------> | ............... |
75 /// : ............... : : ... : : ............... :
76 /// ```
77 ///
78 /// The computed quantized data matrix:
79 ///
80 /// ```
81 /// +-----------------+-----------------+----- ...
82 /// | q00 q01 q02 q03 | q04 q05 q06 q07 | ........
83 /// | q10 q11 q12 q13 | q14 q15 q16 q17 | ........
84 /// | q20 q21 q22 q23 | q24 q25 q26 q27 | ........
85 /// | q30 q31 q32 q33 | q34 q35 q36 q37 | ........
86 /// | ............... | ............... | ........
87 /// : ............... : ............... : ........
88 /// ```
89 ///
90 /// @tparam SrcType The data type of the input data (must be floating-point).
91 /// @tparam DstType The data type of the output data (must be integer).
92 /// @tparam ScaleType The data type of the quantization scales (must be floating-point).
93 ///
94 /// @param[in] src The input matrix.
95 /// @param[in] scales The quantization scale matrix.
96 /// @param[in] height The number of rows.
97 /// @param[in] width The number of columns.
98 /// @param[in] quant_width The number of columns of the quantization block.
99 ///
100 /// @return The quantized data matrix.
101 template <typename SrcType, typename DstType, typename ScaleType>
102 Buffer quantize_symmetric_per_block(
103 const void* src, const void* scales, size_t height, size_t width, size_t quant_width);
104
105 /// Computes the quantization information using the asymmetric per-block quantization method.
106 ///
107 /// The input matrix is divided into quantization blocks of the same size.
108 ///
109 /// The height of the block does not affect the behavior of this function hence it is omitted
110 /// from the function arguments and the figures below.
111 ///
112 /// ```
113 /// Quantization blocks -------+
114 /// | |
115 /// | |
116 /// v v
117 /// +-----------------+-----------------+----- ...
118 /// | f00 f01 f02 f03 | f04 f05 f06 f07 | ........
119 /// | f10 f11 f12 f13 | f14 f15 f16 f17 | ........
120 /// | f20 f21 f22 f23 | f24 f25 f26 f27 | ........
121 /// | f30 f31 f32 f33 | f34 f35 f36 f37 | ........
122 /// | ............... | ............... | ........
123 /// : ............... : ............... : ........
124 /// ```
125 ///
126 /// A scale and a zero point are computed for each row of the quantization block individually.
127 ///
128 /// ```
129 /// Floating-point data Scale Zero point
130 /// +-----------------+ +-----+ +-----+
131 /// | f00 f01 f02 f03 | -------> | s00 | | z00 |
132 /// | f10 f11 f12 f13 | -------> | s10 | | z10 |
133 /// | f20 f21 f22 f23 | -------> | s20 | | z20 |
134 /// | f30 f31 f32 f33 | -------> | s30 | | z30 |
135 /// | ............... | | ... | | ... |
136 /// : ............... : : ... : : ... :
137 /// ```
138 ///
139 /// The computed quantization scale and zero point matrices. Each element holds the parameter
140 /// for one row of one quantization block.
141 ///
142 /// Quantization scale matrix:
143 /// ```
144 /// +-----+-----+-- ...
145 /// | s00 | s01 | .....
146 /// | s10 | s11 | .....
147 /// | s20 | s21 | .....
148 /// | s30 | s31 | .....
149 /// | ... | ... | .....
150 /// : ... : ... : .....
151 /// ```
152 ///
153 /// Quantization zero point matrix:
154 /// ```
155 /// +-----+-----+-- ...
156 /// | z00 | z01 | .....
157 /// | z10 | z11 | .....
158 /// | z20 | z21 | .....
159 /// | z30 | z31 | .....
160 /// | ... | ... | .....
161 /// : ... : ... : .....
162 /// ```
163 ///
164 /// @tparam SrcType The data type of the input data (must be floating-point).
165 /// @tparam DstType The data type of the quantized data these parameters target (must be integer).
166 /// @tparam ScaleType The data type of the quantization scales (must be floating-point).
167 /// @tparam ZeroPointType The data type of the quantization zero points (must be integer).
168 ///
169 /// @param[in] src The input matrix.
170 /// @param[in] height The number of rows.
171 /// @param[in] width The number of columns.
172 /// @param[in] quant_width The number of columns of the quantization block.
173 ///
174 /// @return A tuple of the quantization scale matrix and the quantization zero point matrix, in that order.
175 template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType>
176 std::tuple<Buffer, Buffer> compute_asymmetric_per_block_quantization_info(
177 const void* src, size_t height, size_t width, size_t quant_width);
178
179 /// Quantizes each block of the matrix using the asymmetric quantization method.
180 ///
181 /// The input matrix is divided into quantization blocks of the same size.
182 ///
183 /// The height of the block does not affect the behavior of this function hence it is omitted
184 /// from the function arguments and the figures below.
185 ///
186 /// The quantization scale and zero point matrices can be calculated using
187 /// @ref compute_asymmetric_per_block_quantization_info function.
188 ///
189 /// The input matrix, quantization scale matrix and quantization zero point matrix:
190 ///
191 /// ```
192 /// Floating-point data Scale Zero point
193 ///
194 /// Quantization blocks -------+
195 /// | |
196 /// | |
197 /// v v
198 /// +-----------------+-----------------+----- ... +-----+-----+-- ... +-----+-----+-- ...
199 /// | f00 f01 f02 f03 | f04 f05 f06 f07 | ........ | s00 | s01 | ..... | z00 | z01 | .....
200 /// | f10 f11 f12 f13 | f14 f15 f16 f17 | ........ | s10 | s11 | ..... | z10 | z11 | .....
201 /// | f20 f21 f22 f23 | f24 f25 f26 f27 | ........ | s20 | s21 | ..... | z20 | z21 | .....
202 /// | f30 f31 f32 f33 | f34 f35 f36 f37 | ........ | s30 | s31 | ..... | z30 | z31 | .....
203 /// | ............... | ............... | ........ | ... | ... | ..... | ... | ... | .....
204 /// : ............... : ............... : ........ : ... : ... : ..... : ... : ... : .....
205 /// ```
206 ///
207 /// Each row of the quantization block is quantized individually using its own scale and zero point.
208 ///
209 /// ```
210 /// Floating-point data Scale Zero point Quantized data
211 /// +-----------------+ +-----+ +-----+ +-----------------+
212 /// | f00 f01 f02 f03 | | s00 | | z00 | -------> | q00 q01 q02 q03 |
213 /// | f10 f11 f12 f13 | | s10 | | z10 | -------> | q10 q11 q12 q13 |
214 /// | f20 f21 f22 f23 | | s20 | | z20 | -------> | q20 q21 q22 q23 |
215 /// | f30 f31 f32 f33 | | s30 | | z30 | -------> | q30 q31 q32 q33 |
216 /// | ............... | | ... | | ... | -------> | ............... |
217 /// : ............... : : ... : : ... : : ............... :
218 /// ```
219 ///
220 /// The computed quantized data matrix:
221 ///
222 /// ```
223 /// +-----------------+-----------------+----- ...
224 /// | q00 q01 q02 q03 | q04 q05 q06 q07 | ........
225 /// | q10 q11 q12 q13 | q14 q15 q16 q17 | ........
226 /// | q20 q21 q22 q23 | q24 q25 q26 q27 | ........
227 /// | q30 q31 q32 q33 | q34 q35 q36 q37 | ........
228 /// | ............... | ............... | ........
229 /// : ............... : ............... : ........
230 /// ```
231 ///
232 /// @tparam SrcType The data type of the input data (must be floating-point).
233 /// @tparam DstType The data type of the output data (must be integer).
234 /// @tparam ScaleType The data type of the quantization scales (must be floating-point).
235 /// @tparam ZeroPointType The data type of the quantization zero points (must be integer).
236 ///
237 /// @param[in] src The input matrix.
238 /// @param[in] scales The quantization scale matrix.
239 /// @param[in] zero_points The quantization zero point matrix.
240 /// @param[in] height The number of rows.
241 /// @param[in] width The number of columns.
242 /// @param[in] quant_width The number of columns of the quantization block.
243 ///
244 /// @return The quantized data matrix.
245 template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType>
246 Buffer quantize_asymmetric_per_block(
247 const void* src, const void* scales, const void* zero_points, size_t height, size_t width, size_t quant_width);
248
249 /// Quantizes the input matrix using the options in @p qinfo, producing the quantization parameters as well.
250 ///
251 /// @param[in] src The input matrix.
252 /// @param[in] src_type The data type of the input data (must be floating-point).
253 /// @param[in] height The number of rows.
254 /// @param[in] width The number of columns.
255 /// @param[in] qinfo The quantization options: block width and output, scale and zero point data types.
256 ///
257 /// @return A tuple of the quantized data and the QuantizationOutputs holding the scales and, optionally, the zero points.
258 std::tuple<Buffer, QuantizationOutputs> quantize_dynamic(
259 const void* src, DataType src_type, size_t height, size_t width, const QuantizationInfo& qinfo);
260 } // namespace kai::test
261