test/reference/quantize.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/reference/quantize.hpp" | ||
| 8 | |||
| 9 | #include <algorithm> | ||
| 10 | #include <cmath> | ||
| 11 | #include <cstddef> | ||
| 12 | #include <cstdint> | ||
| 13 | #include <tuple> | ||
| 14 | |||
| 15 | #include "test/common/bfloat16.hpp" | ||
| 16 | #include "test/common/buffer.hpp" | ||
| 17 | #include "test/common/int4.hpp" | ||
| 18 | #include "test/common/memory.hpp" | ||
| 19 | #include "test/common/numeric_limits.hpp" | ||
| 20 | #include "test/common/round.hpp" | ||
| 21 | #include "test/common/type_traits.hpp" | ||
| 22 | #include "test/reference/cast.hpp" | ||
| 23 | |||
| 24 | namespace kai::test { | ||
| 25 | |||
| 26 | namespace { | ||
| 27 | |||
| 28 | template <typename FloatData, typename IntData, typename ZeroPoint> | ||
| 29 | 3166992 | std::tuple<FloatData, ZeroPoint> get_scale_zero_point_from_range(FloatData min_value, FloatData max_value) { | |
| 30 | 3166992 | const FloatData q_min = numeric_lowest<IntData>; | |
| 31 | 3166992 | const FloatData q_max = numeric_highest<IntData>; | |
| 32 | |||
| 33 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 561816 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2605176 times.
|
3166992 | if (min_value > 0) { |
| 34 | 3166992 | min_value = 0; | |
| 35 | 3166992 | } | |
| 36 | |||
| 37 |
2/4✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2605176 times.
✗ Branch 3 not taken.
|
3166992 | if (max_value < 0) { |
| 38 | ✗ | max_value = 0; | |
| 39 | ✗ | } | |
| 40 | |||
| 41 | // The reason for computing the inverted scale first is to make it bit-perfect with quantized packing | ||
| 42 | // micro-kernels. If those micro-kernels don't do it this way anymore, it makes more sense to calculate | ||
| 43 | // the scale directly. | ||
| 44 |
2/4✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2605176 times.
✗ Branch 3 not taken.
|
3166992 | const FloatData inv_scale = max_value != min_value ? (q_max - q_min) / (max_value - min_value) : 1.0F; |
| 45 | 3166992 | const FloatData scale = 1.0F / inv_scale; | |
| 46 | |||
| 47 | 3166992 | const FloatData scaled_min = min_value / scale; | |
| 48 | 3166992 | const FloatData scaled_max = max_value / scale; | |
| 49 | |||
| 50 |
2/4✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2605176 times.
✗ Branch 3 not taken.
|
3166992 | const FloatData zero_point_f = -(scaled_min + q_min) < scaled_max + q_max ? scaled_min - q_min : scaled_max - q_max; |
| 51 | 3166992 | const ZeroPoint zero_point = -round_to_nearest_even<ZeroPoint>(zero_point_f); | |
| 52 | |||
| 53 | 3166992 | return {scale, zero_point}; | |
| 54 | 3166992 | } | |
| 55 | |||
| 56 | /// Quantized a float value to an integer datatype using a provided scale. | ||
| 57 | /// | ||
| 58 | /// @tparam IntType Quantized integer datatype. | ||
| 59 | /// | ||
| 60 | /// @param[in] float The value to quantize | ||
| 61 | /// @param[in] scale The scale used to quantize the provided float value. | ||
| 62 | /// | ||
| 63 | /// @return The quantized data matrix, the quantization scale matrix and the quantization zero point matrix. | ||
| 64 | template <typename IntType> | ||
| 65 | 99249221 | IntType quantize_symmetric(float value, float scale) { | |
| 66 |
3/6✓ Branch 0 taken 60763 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 76177560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 23010898 times.
✗ Branch 5 not taken.
|
99249221 | const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F; |
| 67 | 99249221 | auto qsi32 = round_to_nearest_even_i32(value * inv_scale); | |
| 68 | |||
| 69 | if (is_unsigned<IntType>) { | ||
| 70 | qsi32 += 1 << (size_in_bits<IntType> - 1); | ||
| 71 | } | ||
| 72 | |||
| 73 | 122320882 | return static_cast<IntType>(std::clamp<int32_t>(qsi32, numeric_lowest<IntType>, numeric_highest<IntType>)); | |
| 74 | 99249221 | } | |
| 75 | |||
| 76 | /// Computes the quantization information using symmetric per-block quantization method. | ||
| 77 | /// | ||
| 78 | /// The input matrix is divided into quantization blocks of the same size. | ||
| 79 | /// | ||
| 80 | /// The height of the block does not affect the behavior of this function hence it is omitted | ||
| 81 | /// from the function arguments and the figures below. | ||
| 82 | /// | ||
| 83 | /// ``` | ||
| 84 | /// Quantization blocks -------+ | ||
| 85 | /// | | | ||
| 86 | /// | | | ||
| 87 | /// v v | ||
| 88 | /// +-----------------+-----------------+----- ... | ||
| 89 | /// | f00 f01 f02 f03 | f04 f05 f06 f07 | ........ | ||
| 90 | /// | f10 f11 f12 f13 | f14 f15 f16 f17 | ........ | ||
| 91 | /// | f20 f21 f22 f23 | f24 f25 f26 f27 | ........ | ||
| 92 | /// | f30 f31 f32 f33 | f34 f35 f36 f37 | ........ | ||
| 93 | /// | ............... | ............... | ........ | ||
| 94 | /// : ............... : ............... : ........ | ||
| 95 | /// ``` | ||
| 96 | /// | ||
| 97 | /// Each row of the quantization block is quantized individually. | ||
| 98 | /// | ||
| 99 | /// ``` | ||
| 100 | /// Floating-point data Scale | ||
| 101 | /// +-----------------+ +-----+ | ||
| 102 | /// | f00 f01 f02 f03 | -------> | s00 | | ||
| 103 | /// | f10 f11 f12 f13 | -------> | s10 | | ||
| 104 | /// | f20 f21 f22 f23 | -------> | s20 | | ||
| 105 | /// | f30 f31 f32 f33 | -------> | s30 | | ||
| 106 | /// | ............... | | ... | | ||
| 107 | /// : ............... : : ... : | ||
| 108 | /// ``` | ||
| 109 | /// | ||
| 110 | /// The computed quantization scale matrix: | ||
| 111 | /// | ||
| 112 | /// ``` | ||
| 113 | /// +-----+-----+-- ... | ||
| 114 | /// | s00 | s01 | ..... | ||
| 115 | /// | s10 | s11 | ..... | ||
| 116 | /// | s20 | s21 | ..... | ||
| 117 | /// | s30 | s31 | ..... | ||
| 118 | /// | ... | ... | ..... | ||
| 119 | /// : ... : ... : ..... | ||
| 120 | /// ``` | ||
| 121 | /// | ||
| 122 | /// @tparam SrcType The data type of the input data (must be floating-point). | ||
| 123 | /// @tparam DstType The data type of the output data (must be integer). | ||
| 124 | /// @tparam ScaleType The data type of the quantization scales (must be floating-point). | ||
| 125 | /// | ||
| 126 | /// @param[in] src The input matrix. | ||
| 127 | /// @param[in] height The number of rows. | ||
| 128 | /// @param[in] width The number of columns. | ||
| 129 | /// @param[in] quant_width The number of columns of the quantization block. | ||
| 130 | /// | ||
| 131 | /// @return The quantization scale matrix. | ||
| 132 | template <typename SrcType, typename DstType, typename ScaleType> | ||
| 133 | 38937 | Buffer compute_symmetric_per_block_quantization_info(const void* src, size_t height, size_t width, size_t quant_width) { | |
| 134 | static_assert(is_floating_point<SrcType>); | ||
| 135 | static_assert(is_integral<DstType>); | ||
| 136 | static_assert(is_floating_point<ScaleType>); | ||
| 137 | |||
| 138 | − | KAI_ASSUME_ALWAYS(quant_width != 0); | |
| 139 | |||
| 140 | 38937 | const auto num_quant_packets_x = round_up_division(width, quant_width); | |
| 141 | |||
| 142 | 38937 | const auto scales_bytes = height * num_quant_packets_x * sizeof(ScaleType); | |
| 143 | 38937 | Buffer scales(scales_bytes); | |
| 144 | |||
| 145 | 38937 | const auto* src_ptr = reinterpret_cast<const SrcType*>(src); | |
| 146 | |||
| 147 |
4/6✓ Branch 0 taken 968873 times.
✓ Branch 1 taken 11534 times.
✓ Branch 2 taken 347657 times.
✓ Branch 3 taken 27403 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
1355467 | for (size_t y = 0; y < height; ++y) { |
| 148 |
4/6✓ Branch 0 taken 968873 times.
✓ Branch 1 taken 1850246 times.
✓ Branch 2 taken 395173 times.
✓ Branch 3 taken 432469 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
3646761 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { |
| 149 | // Computes the quantization scale. | ||
| 150 | 2330231 | SrcType max_abs = 0; | |
| 151 | |||
| 152 |
4/6✓ Branch 0 taken 76177560 times.
✓ Branch 1 taken 1850246 times.
✓ Branch 2 taken 23010898 times.
✓ Branch 3 taken 479985 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
101518689 | for (size_t x_element = 0; x_element < quant_width; ++x_element) { |
| 153 | 99188458 | const auto x = x_quant + x_element; | |
| 154 | |||
| 155 |
2/6✗ Branch 0 not taken.
✓ Branch 1 taken 76177560 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 23010898 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
99188458 | if (x < width) { |
| 156 |
2/6✗ Branch 0 not taken.
✓ Branch 1 taken 76177560 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 17005565 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
99188458 | max_abs = std::max<SrcType>(max_abs, std::abs(src_ptr[y * width + x])); |
| 157 | 99188458 | } | |
| 158 | 99188458 | } | |
| 159 | |||
| 160 | 4660462 | const auto scale = | |
| 161 | 2330231 | max_abs / static_cast<SrcType>((static_cast<uint64_t>(1) << (size_in_bits<DstType> - 1)) - 1); | |
| 162 | |||
| 163 | // Stores the scales. | ||
| 164 | 2330231 | write_array<ScaleType>(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale); | |
| 165 | 2330231 | } | |
| 166 | 1316530 | } | |
| 167 | |||
| 168 | 38937 | return scales; | |
| 169 | 38937 | } | |
| 170 | |||
| 171 | /// Dynamically quantizes each block of the matrix using symmetric quantization method. | ||
| 172 | /// | ||
| 173 | /// The quantization information is calculated using | ||
| 174 | /// @ref compute_symmetric_per_block_quantization_info function. | ||
| 175 | /// The floating-point data is then quantized using | ||
| 176 | /// @ref quantize_symmetric_per_block function. | ||
| 177 | /// | ||
| 178 | /// To retain highest quantization accuracy, the data is quantized using the quantization scale | ||
| 179 | /// with the same data type as the input data. | ||
| 180 | /// After that the quantization scale can be stored in the buffer using `ScaleType` data type | ||
| 181 | /// which might have lowest precision than the input data type. | ||
| 182 | /// | ||
| 183 | /// @tparam SrcType The data type of the input data (must be floating-point). | ||
| 184 | /// @tparam DstType The data type of the output data (must be integer). | ||
| 185 | /// @tparam ScaleType The data type of the quantization scales (must be floating-point). | ||
| 186 | /// | ||
| 187 | /// @param[in] src The input matrix. | ||
| 188 | /// @param[in] height The number of rows. | ||
| 189 | /// @param[in] width The number of columns. | ||
| 190 | /// @param[in] quant_width The number of columns of the quantization block. | ||
| 191 | /// | ||
| 192 | /// @return The quantized data matrix and the quantization scale matrix. | ||
| 193 | template <typename SrcType, typename DstType, typename ScaleType> | ||
| 194 | 38937 | std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic( | |
| 195 | const void* src, size_t height, size_t width, size_t quant_width) { | ||
| 196 | 38937 | auto scales_src_type = | |
| 197 | 38937 | compute_symmetric_per_block_quantization_info<SrcType, DstType, SrcType>(src, height, width, quant_width); | |
| 198 |
10/24✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9480 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 9480 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1134 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1134 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 920 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 920 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 26483 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 26483 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
|
77874 | auto data = quantize_symmetric_per_block<SrcType, DstType, SrcType>( |
| 199 | 38937 | src, scales_src_type.data(), height, width, quant_width); | |
| 200 | |||
| 201 | if constexpr (std::is_same_v<ScaleType, SrcType>) { | ||
| 202 | 35963 | return {std::move(data), std::move(scales_src_type)}; | |
| 203 | } else { | ||
| 204 | 2974 | auto scales = | |
| 205 |
3/6✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1134 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 920 times.
✗ Branch 5 not taken.
|
2974 | cast<ScaleType, SrcType>(scales_src_type.data(), scales_src_type.size() * 8 / size_in_bits<SrcType>); |
| 206 | |||
| 207 | 2974 | return {std::move(data), std::move(scales)}; | |
| 208 | 2974 | } | |
| 209 | 38937 | } | |
| 210 | |||
| 211 | /// Dynamically quantizes each block of the matrix using symmetric quantization method. | ||
| 212 | /// | ||
| 213 | /// @param[in] src The input matrix. | ||
| 214 | /// @param[in] src_type The data type of the input data (must be FP32). | ||
| 215 | /// @param[in] height The number of rows. | ||
| 216 | /// @param[in] width The number of columns. | ||
| 217 | /// @param[in] qinfo The quantization information. | ||
| 218 | /// | ||
| 219 | /// @return The quantized data matrix and the quantization scale matrix. | ||
| 220 | 38937 | std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic( | |
| 221 | const void* src, DataType src_type, size_t height, size_t width, const QuantizationInfo& qinfo) { | ||
| 222 | // Fail fast for datatypes that must be fixed. | ||
| 223 | − | KAI_ASSUME_ALWAYS(src_type == DataType::FP32); | |
| 224 | |||
| 225 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 11534 times.
✓ Branch 2 taken 27403 times.
✗ Branch 3 not taken.
|
38937 | switch (qinfo.dst_type) { |
| 226 | case DataType::QSI4: | ||
| 227 |
3/4✓ Branch 0 taken 9480 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
✓ Branch 3 taken 1134 times.
|
11534 | switch (qinfo.scale_type) { |
| 228 | case DataType::FP16: | ||
| 229 | 920 | return quantize_symmetric_per_block_dynamic<float, Int4, Float16>( | |
| 230 | 920 | src, height, width, qinfo.quant_width); | |
| 231 | case DataType::FP32: | ||
| 232 | 9480 | return quantize_symmetric_per_block_dynamic<float, Int4, float>( | |
| 233 | 9480 | src, height, width, qinfo.quant_width); | |
| 234 | case DataType::BF16: | ||
| 235 | 1134 | return quantize_symmetric_per_block_dynamic<float, Int4, BFloat16<>>( | |
| 236 | 1134 | src, height, width, qinfo.quant_width); | |
| 237 | default: | ||
| 238 | ✗ | break; | |
| 239 | } | ||
| 240 | ✗ | break; | |
| 241 | case DataType::QSI8: | ||
| 242 |
2/3✓ Branch 0 taken 26483 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
|
27403 | switch (qinfo.scale_type) { |
| 243 | case DataType::FP16: | ||
| 244 | 920 | return quantize_symmetric_per_block_dynamic<float, int8_t, Float16>( | |
| 245 | 920 | src, height, width, qinfo.quant_width); | |
| 246 | case DataType::FP32: | ||
| 247 | 26483 | return quantize_symmetric_per_block_dynamic<float, int8_t, float>( | |
| 248 | 26483 | src, height, width, qinfo.quant_width); | |
| 249 | default: | ||
| 250 | ✗ | break; | |
| 251 | } | ||
| 252 | ✗ | break; | |
| 253 | case DataType::I32: | ||
| 254 | ✗ | if (qinfo.scale_type == DataType::FP32) { | |
| 255 | ✗ | return quantize_symmetric_per_block_dynamic<float, int32_t, float>( | |
| 256 | ✗ | src, height, width, qinfo.quant_width); | |
| 257 | } | ||
| 258 | ✗ | break; | |
| 259 | default: | ||
| 260 | ✗ | break; | |
| 261 | } | ||
| 262 | − | KAI_ERROR("Unsupported combination of data types for symmetric quantization."); | |
| 263 | 38937 | } | |
| 264 | |||
| 265 | /// Dynamically quantizes each block of the matrix using asymmetric quantization method. | ||
| 266 | /// | ||
| 267 | /// The quantization information is calculated using | ||
| 268 | /// @ref compute_asymmetric_per_block_quantization_info function. | ||
| 269 | /// The floating-point data is then quantized using | ||
| 270 | /// @ref quantize_asymmetric_per_block function. | ||
| 271 | /// | ||
| 272 | /// To retain highest quantization accuracy, the data is quantized using the quantization scale | ||
| 273 | /// with the same data type as the input data. | ||
| 274 | /// After that the quantization scale can be stored in the buffer using `ScaleType` data type | ||
| 275 | /// which might have lowest precision than the input data type. | ||
| 276 | /// | ||
| 277 | /// @tparam SrcType The data type of the input data (must be floating-point). | ||
| 278 | /// @tparam DstType The data type of the output data (must be integer). | ||
| 279 | /// @tparam ScaleType The data type of the quantization scales (must be floating-point). | ||
| 280 | /// @tparam ZeroPointType The data type of the quantization zero points (must be integer). | ||
| 281 | /// | ||
| 282 | /// @param[in] src The input matrix. | ||
| 283 | /// @param[in] height The number of rows. | ||
| 284 | /// @param[in] width The number of columns. | ||
| 285 | /// @param[in] quant_width The number of columns of the quantization block. | ||
| 286 | /// | ||
| 287 | /// @return The quantized data matrix, the quantization scale matrix and the quantization zero point matrix. | ||
| 288 | template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType> | ||
| 289 | 39233 | std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic( | |
| 290 | const void* src, size_t height, size_t width, size_t quant_width) { | ||
| 291 | /* Calculate the asymmetric quantization information, one scaling per row */ | ||
| 292 | 156932 | auto [scales_src_type, zero_points] = | |
| 293 | 39233 | compute_asymmetric_per_block_quantization_info<SrcType, DstType, SrcType, ZeroPointType>( | |
| 294 | 39233 | src, height, width, quant_width); | |
| 295 | |||
| 296 | /* Do the actual quantization */ | ||
| 297 |
4/12✓ Branch 0 taken 13529 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13529 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 25704 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 25704 times.
✗ Branch 11 not taken.
|
78466 | auto data = quantize_asymmetric_per_block<SrcType, DstType, SrcType, ZeroPointType>( |
| 298 | 117699 | src, scales_src_type.data(), zero_points.data(), height, width, quant_width); | |
| 299 | |||
| 300 | if constexpr (std::is_same_v<ScaleType, SrcType>) { | ||
| 301 | 39233 | return {std::move(data), std::move(scales_src_type), std::move(zero_points)}; | |
| 302 | } else { | ||
| 303 | ✗ | auto scales = | |
| 304 | ✗ | cast<ScaleType, SrcType>(scales_src_type.data(), scales_src_type.size() * 8 / size_in_bits<SrcType>); | |
| 305 | |||
| 306 | ✗ | return {std::move(data), std::move(scales), std::move(zero_points)}; | |
| 307 | ✗ | } | |
| 308 | 39233 | } | |
| 309 | |||
| 310 | /// Dynamically quantizes each block of the matrix using asymmetric quantization method. | ||
| 311 | /// | ||
| 312 | /// @param[in] src The input matrix. | ||
| 313 | /// @param[in] src_type The data type of the input data (must be FP32). | ||
| 314 | /// @param[in] height The number of rows. | ||
| 315 | /// @param[in] width The number of columns. | ||
| 316 | /// @param[in] qinfo The quantization information. | ||
| 317 | /// | ||
| 318 | /// @return The quantized data matrix, the quantization scale matrix and the quantization zero point matrix. | ||
| 319 | 39233 | std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic( | |
| 320 | const void* src, DataType src_type, size_t height, size_t width, const QuantizationInfo& qinfo) { | ||
| 321 | // Fail fast for datatypes that must be fixed. | ||
| 322 | − | KAI_ASSUME_ALWAYS(src_type == DataType::FP32); | |
| 323 | − | KAI_ASSUME_ALWAYS(qinfo.zero_point_type == DataType::I32); | |
| 324 | |||
| 325 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 13529 times.
✓ Branch 2 taken 25704 times.
|
39233 | switch (qinfo.dst_type) { |
| 326 | case DataType::QAI8: | ||
| 327 |
1/3✗ Branch 0 not taken.
✓ Branch 1 taken 13529 times.
✗ Branch 2 not taken.
|
13529 | switch (qinfo.scale_type) { |
| 328 | case DataType::FP32: | ||
| 329 | 13529 | return quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>( | |
| 330 | 13529 | src, height, width, qinfo.quant_width); | |
| 331 | case DataType::BF16: | ||
| 332 | ✗ | return quantize_asymmetric_per_block_dynamic<float, int8_t, BFloat16<>, int32_t>( | |
| 333 | ✗ | src, height, width, qinfo.quant_width); | |
| 334 | default: | ||
| 335 | ✗ | break; | |
| 336 | } | ||
| 337 | ✗ | break; | |
| 338 | case DataType::QAI4: | ||
| 339 |
1/2✓ Branch 0 taken 25704 times.
✗ Branch 1 not taken.
|
25704 | switch (qinfo.scale_type) { |
| 340 | case DataType::FP32: | ||
| 341 | 25704 | return quantize_asymmetric_per_block_dynamic<float, Int4, float, int32_t>( | |
| 342 | 25704 | src, height, width, qinfo.quant_width); | |
| 343 | default: | ||
| 344 | ✗ | break; | |
| 345 | } | ||
| 346 | ✗ | break; | |
| 347 | default: | ||
| 348 | ✗ | break; | |
| 349 | } | ||
| 350 | − | KAI_ERROR("Unsupported combination of destination/scale types for asymmetric quantization."); | |
| 351 | 39233 | } | |
| 352 | |||
| 353 | } // namespace | ||
| 354 | |||
| 355 | template <typename FloatType, typename IntType, typename ZeroPointType> | ||
| 356 | 210858462 | IntType quantize_asymmetric(FloatType value, FloatType scale, ZeroPointType zero_point) { | |
| 357 |
2/4✓ Branch 0 taken 104220126 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 106638336 times.
✗ Branch 3 not taken.
|
210858462 | const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F; |
| 358 | 210858462 | auto quantized_value = round_to_nearest_even<ZeroPointType>(value * inv_scale) + zero_point; | |
| 359 | 315078588 | return static_cast<IntType>( | |
| 360 | 210858462 | std::clamp<ZeroPointType>(quantized_value, numeric_lowest<IntType>, numeric_highest<IntType>)); | |
| 361 | 210858462 | } | |
| 362 | |||
| 363 | template int8_t quantize_asymmetric(float value, float scale, int32_t zero_point); | ||
| 364 | |||
| 365 | template <typename SrcType, typename DstType, typename ScaleType> | ||
| 366 | 39458 | Buffer quantize_symmetric_per_block( | |
| 367 | const void* src, const void* scales, size_t height, size_t width, size_t quant_width) { | ||
| 368 | static_assert(is_floating_point<SrcType>); | ||
| 369 | static_assert(is_integral<DstType>); | ||
| 370 | static_assert(is_floating_point<ScaleType>); | ||
| 371 | |||
| 372 | 39458 | const auto num_quant_packets_x = round_up_division(width, quant_width); | |
| 373 | |||
| 374 | 39458 | const auto data_bytes = round_up_division(height * width * size_in_bits<DstType>, 8); | |
| 375 | 39458 | Buffer data(data_bytes); | |
| 376 | |||
| 377 | 39458 | const auto* src_ptr = reinterpret_cast<const SrcType*>(src); | |
| 378 | |||
| 379 |
6/6✓ Branch 0 taken 60763 times.
✓ Branch 1 taken 521 times.
✓ Branch 2 taken 968873 times.
✓ Branch 3 taken 11534 times.
✓ Branch 4 taken 347657 times.
✓ Branch 5 taken 27403 times.
|
1416751 | for (size_t y = 0; y < height; ++y) { |
| 380 |
6/6✓ Branch 0 taken 60763 times.
✓ Branch 1 taken 60763 times.
✓ Branch 2 taken 968873 times.
✓ Branch 3 taken 1850246 times.
✓ Branch 4 taken 479985 times.
✓ Branch 5 taken 347657 times.
|
3768287 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { |
| 381 |
3/6✓ Branch 0 taken 60763 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1850246 times.
✓ Branch 4 taken 479985 times.
✗ Branch 5 not taken.
|
2390994 | const auto scale = read_array<ScaleType>(scales, y * num_quant_packets_x + x_quant / quant_width); |
| 382 | |||
| 383 | // Quantizes and stores the data. | ||
| 384 |
6/6✓ Branch 0 taken 60763 times.
✓ Branch 1 taken 60763 times.
✓ Branch 2 taken 1850246 times.
✓ Branch 3 taken 76177560 times.
✓ Branch 4 taken 23010898 times.
✓ Branch 5 taken 479985 times.
|
101640215 | for (size_t x_element = 0; x_element < quant_width; ++x_element) { |
| 385 | 99249221 | const auto x = x_quant + x_element; | |
| 386 | |||
| 387 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 60763 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 76177560 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 23010898 times.
|
99249221 | if (x < width) { |
| 388 |
3/6✓ Branch 0 taken 60763 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 76177560 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 23010898 times.
|
99249221 | const auto quantized = quantize_symmetric<DstType>(src_ptr[y * width + x], scale); |
| 389 |
3/6✓ Branch 0 taken 60763 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60763 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 76177560 times.
✗ Branch 5 not taken.
|
99249221 | write_array(data.data(), y * width + x, quantized); |
| 390 | 99249221 | } | |
| 391 | 99249221 | } | |
| 392 | 2390994 | } | |
| 393 | 1377293 | } | |
| 394 | 39458 | return data; | |
| 395 | 39458 | } | |
| 396 | |||
| 397 | template Buffer quantize_symmetric_per_block<float, int32_t, float>( | ||
| 398 | const void* src, const void* scales, size_t height, size_t width, size_t quant_width); | ||
| 399 | |||
| 400 | template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType> | ||
| 401 | 39754 | std::tuple<Buffer, Buffer> compute_asymmetric_per_block_quantization_info( | |
| 402 | const void* src, size_t height, size_t width, size_t quant_width) { | ||
| 403 | static_assert(is_floating_point<SrcType>); | ||
| 404 | static_assert(is_integral<DstType>); | ||
| 405 | static_assert(is_floating_point<ScaleType>); | ||
| 406 | static_assert(is_integral<ZeroPointType>); | ||
| 407 | |||
| 408 | − | KAI_ASSUME_ALWAYS(quant_width != 0); | |
| 409 | |||
| 410 | 39754 | const auto num_quant_packets_x = round_up_division(width, quant_width); | |
| 411 | |||
| 412 | 39754 | const auto scales_bytes = height * num_quant_packets_x * sizeof(ScaleType); | |
| 413 | 39754 | Buffer scales(scales_bytes); | |
| 414 | |||
| 415 | 39754 | const auto zero_points_bytes = height * num_quant_packets_x * sizeof(ZeroPointType); | |
| 416 |
2/4✓ Branch 0 taken 14050 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 25704 times.
✗ Branch 3 not taken.
|
39754 | Buffer zero_points(zero_points_bytes); |
| 417 | |||
| 418 |
4/4✓ Branch 0 taken 14050 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 514080 times.
✓ Branch 3 taken 1002456 times.
|
2092402 | for (size_t y = 0; y < height; ++y) { |
| 419 |
4/4✓ Branch 0 taken 561816 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 2605176 times.
✓ Branch 3 taken 1490832 times.
|
5219640 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { |
| 420 | // Computes the quantization scale and zero point. | ||
| 421 | 3166992 | auto min_value = numeric_highest<SrcType>; | |
| 422 | 3166992 | auto max_value = numeric_lowest<SrcType>; | |
| 423 | |||
| 424 |
4/4✓ Branch 0 taken 104219084 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 106638336 times.
✓ Branch 3 taken 2605176 times.
|
214024412 | for (size_t x_element = 0; x_element < quant_width; ++x_element) { |
| 425 | 210857420 | const auto x = x_quant + x_element; | |
| 426 | |||
| 427 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 104219084 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 106638336 times.
|
210857420 | if (x < width) { |
| 428 |
3/4✓ Branch 0 taken 104219084 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 71092224 times.
✓ Branch 3 taken 35546112 times.
|
210857420 | const auto value = read_array<SrcType>(src, y * width + x); |
| 429 | |||
| 430 |
2/4✓ Branch 0 taken 104219084 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 71092224 times.
✗ Branch 3 not taken.
|
210857420 | min_value = std::min(min_value, value); |
| 431 |
2/4✓ Branch 0 taken 104219084 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 71092224 times.
✗ Branch 3 not taken.
|
210857420 | max_value = std::max(max_value, value); |
| 432 | 210857420 | } | |
| 433 | 210857420 | } | |
| 434 | |||
| 435 | 3166992 | const auto [scale, zero_point] = | |
| 436 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 561816 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2605176 times.
|
3166992 | get_scale_zero_point_from_range<SrcType, DstType, ZeroPointType>(min_value, max_value); |
| 437 | |||
| 438 | // Stores the scale and zero point. | ||
| 439 |
2/4✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 561816 times.
✗ Branch 3 not taken.
|
6333984 | write_array<ScaleType>(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale); |
| 440 |
4/8✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 561816 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2605176 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2605176 times.
|
6333984 | write_array<ZeroPointType>(zero_points.data(), y * num_quant_packets_x + x_quant / quant_width, zero_point); |
| 441 | 3166992 | } | |
| 442 | 2052648 | } | |
| 443 | |||
| 444 | 39754 | return {std::move(scales), std::move(zero_points)}; | |
| 445 | 39754 | } | |
| 446 | |||
| 447 | template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType> | ||
| 448 | 39754 | Buffer quantize_asymmetric_per_block( | |
| 449 | const void* src, const void* scales, const void* zero_points, size_t height, size_t width, size_t quant_width) { | ||
| 450 | static_assert(is_floating_point<SrcType>); | ||
| 451 | static_assert(is_integral<DstType>); | ||
| 452 | static_assert(is_floating_point<ScaleType>); | ||
| 453 | static_assert(is_integral<ZeroPointType>); | ||
| 454 | |||
| 455 | 39754 | const auto num_quant_packets_x = round_up_division(width, quant_width); | |
| 456 | |||
| 457 | 39754 | const auto data_bytes = round_up_division(height * width * size_in_bits<DstType>, 8); | |
| 458 | 39754 | Buffer data(data_bytes); | |
| 459 | |||
| 460 |
4/4✓ Branch 0 taken 14050 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 25704 times.
✓ Branch 3 taken 1490832 times.
|
2092402 | for (size_t y = 0; y < height; ++y) { |
| 461 |
4/4✓ Branch 0 taken 561816 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 2605176 times.
✓ Branch 3 taken 1490832 times.
|
5219640 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { |
| 462 |
2/4✓ Branch 0 taken 561816 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2605176 times.
✗ Branch 3 not taken.
|
3166992 | const auto scale = read_array<ScaleType>(scales, y * num_quant_packets_x + x_quant / quant_width); |
| 463 | 5772168 | const auto zero_point = | |
| 464 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 561816 times.
|
3166992 | read_array<ZeroPointType>(zero_points, y * num_quant_packets_x + x_quant / quant_width); |
| 465 | |||
| 466 | // Quantizes and stores the data. | ||
| 467 |
4/4✓ Branch 0 taken 104219084 times.
✓ Branch 1 taken 561816 times.
✓ Branch 2 taken 106638336 times.
✓ Branch 3 taken 2605176 times.
|
214024412 | for (size_t x_element = 0; x_element < quant_width; ++x_element) { |
| 468 | 210857420 | const auto x = x_quant + x_element; | |
| 469 | |||
| 470 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 104219084 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 106638336 times.
|
210857420 | if (x < width) { |
| 471 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 104219084 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 106638336 times.
|
210857420 | const auto value_f = read_array<SrcType>(src, y * width + x); |
| 472 | 210857420 | const auto value_q = | |
| 473 |
2/4✓ Branch 0 taken 104219084 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 106638336 times.
✗ Branch 3 not taken.
|
210857420 | quantize_asymmetric<SrcType, DstType, ZeroPointType>(value_f, scale, zero_point); |
| 474 | |||
| 475 |
2/4✓ Branch 0 taken 104219084 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 106638336 times.
✗ Branch 3 not taken.
|
210857420 | write_array<DstType>(data.data(), y * width + x, value_q); |
| 476 | 210857420 | } | |
| 477 | 210857420 | } | |
| 478 | 3166992 | } | |
| 479 | 2052648 | } | |
| 480 | |||
| 481 | 39754 | return data; | |
| 482 | 39754 | } | |
| 483 | |||
| 484 | 78170 | std::tuple<Buffer, QuantizationOutputs> quantize_dynamic( | |
| 485 | const void* src, DataType src_type, size_t height, size_t width, const QuantizationInfo& qinfo) { | ||
| 486 | − | KAI_ASSUME_ALWAYS(data_type_is_quantized(qinfo.dst_type)); | |
| 487 | 78170 | Buffer data; | |
| 488 | 78170 | QuantizationOutputs qoutputs; | |
| 489 |
3/4✓ Branch 0 taken 78170 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 39233 times.
✓ Branch 3 taken 38937 times.
|
78170 | if (data_type_is_quantized_asymm(qinfo.dst_type)) { |
| 490 | − | KAI_ASSUME_ALWAYS(qinfo.zero_point_type != DataType::UNKNOWN); | |
| 491 | 39233 | std::tie(data, qoutputs.scales, qoutputs.zero_points) = | |
| 492 |
1/2✓ Branch 0 taken 39233 times.
✗ Branch 1 not taken.
|
39233 | quantize_asymmetric_per_block_dynamic(src, src_type, height, width, qinfo); |
| 493 | 39233 | } else { | |
| 494 |
1/2✓ Branch 0 taken 38937 times.
✗ Branch 1 not taken.
|
38937 | std::tie(data, qoutputs.scales) = quantize_symmetric_per_block_dynamic(src, src_type, height, width, qinfo); |
| 495 | } | ||
| 496 | 78170 | return {std::move(data), std::move(qoutputs)}; | |
| 497 | 78170 | } | |
| 498 | } // namespace kai::test | ||
| 499 |