Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/reference/quantize.hpp" | ||
8 | |||
9 | #include <algorithm> | ||
10 | #include <cmath> | ||
11 | #include <cstddef> | ||
12 | #include <cstdint> | ||
13 | #include <tuple> | ||
14 | |||
15 | #include "test/common/bfloat16.hpp" | ||
16 | #include "test/common/buffer.hpp" | ||
17 | #include "test/common/int4.hpp" | ||
18 | #include "test/common/memory.hpp" | ||
19 | #include "test/common/numeric_limits.hpp" | ||
20 | #include "test/common/round.hpp" | ||
21 | #include "test/common/type_traits.hpp" | ||
22 | #include "test/reference/cast.hpp" | ||
23 | #include "test/reference/transpose.hpp" | ||
24 | |||
25 | namespace kai::test { | ||
26 | |||
27 | namespace { | ||
28 | |||
29 | template <typename FloatData, typename IntData, typename ZeroPoint> | ||
30 | 645006 | std::tuple<FloatData, ZeroPoint> get_scale_zero_point_from_range(FloatData min_value, FloatData max_value) { | |
31 | 645006 | const FloatData q_min = numeric_lowest<IntData>; | |
32 | 645006 | const FloatData q_max = numeric_highest<IntData>; | |
33 | |||
34 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 119838 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 525168 times.
|
645006 | if (min_value > 0) { |
35 | 645006 | min_value = 0; | |
36 | 645006 | } | |
37 | |||
38 |
2/4✓ Branch 0 taken 119838 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 525168 times.
✗ Branch 3 not taken.
|
645006 | if (max_value < 0) { |
39 | ✗ | max_value = 0; | |
40 | ✗ | } | |
41 | |||
42 | // The reason for computing the inverted scale first is to make it bit-perfect with quantized packing | ||
43 | // micro-kernels. If those micro-kernels don't do it this way anymore, it makes more sense to calculate | ||
44 | // the scale directly. | ||
45 |
2/4✓ Branch 0 taken 119838 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 525168 times.
✗ Branch 3 not taken.
|
645006 | const FloatData inv_scale = max_value != min_value ? (q_max - q_min) / (max_value - min_value) : 1.0F; |
46 | 645006 | const FloatData scale = 1.0F / inv_scale; | |
47 | |||
48 | 645006 | const FloatData scaled_min = min_value / scale; | |
49 | 645006 | const FloatData scaled_max = max_value / scale; | |
50 | |||
51 |
2/4✓ Branch 0 taken 119838 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 525168 times.
✗ Branch 3 not taken.
|
645006 | const FloatData zero_point_f = -(scaled_min + q_min) < scaled_max + q_max ? scaled_min - q_min : scaled_max - q_max; |
52 | 645006 | const ZeroPoint zero_point = -round_to_nearest_even<ZeroPoint>(zero_point_f); | |
53 | |||
54 | 645006 | return {scale, zero_point}; | |
55 | 645006 | } | |
56 | |||
57 | } // namespace | ||
58 | |||
59 | template <typename IntType> | ||
60 | 55999055 | IntType quantize_symmetric(float value, float scale) { | |
61 |
3/6✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 25768996 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 30124681 times.
✗ Branch 5 not taken.
|
55999055 | const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F; |
62 | 55999055 | auto qsi32 = round_to_nearest_even_i32(value * inv_scale); | |
63 | |||
64 | if (is_unsigned<IntType>) { | ||
65 | qsi32 += 1 << (size_in_bits<IntType> - 1); | ||
66 | } | ||
67 | |||
68 | 86229114 | return static_cast<IntType>(std::clamp<int32_t>(qsi32, numeric_lowest<IntType>, numeric_highest<IntType>)); | |
69 | 55999055 | } | |
70 | |||
71 | template <typename FloatType, typename IntType, typename ZeroPointType> | ||
72 | 36253141 | IntType quantize_asymmetric(FloatType value, FloatType scale, ZeroPointType zero_point) { | |
73 |
2/4✓ Branch 0 taken 14734805 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 21518336 times.
✗ Branch 3 not taken.
|
36253141 | const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F; |
74 | 36253141 | auto quantized_value = round_to_nearest_even<ZeroPointType>(value * inv_scale) + zero_point; | |
75 | 50987946 | return static_cast<IntType>( | |
76 | 36253141 | std::clamp<ZeroPointType>(quantized_value, numeric_lowest<IntType>, numeric_highest<IntType>)); | |
77 | 36253141 | } | |
78 | |||
79 | template int8_t quantize_asymmetric(float value, float scale, int32_t zero_point); | ||
80 | |||
81 | template <typename SrcType, typename DstType, typename ScaleType> | ||
82 | 20815 | Buffer compute_symmetric_per_block_quantization_info(const void* src, size_t height, size_t width, size_t quant_width) { | |
83 | static_assert(is_floating_point<SrcType>); | ||
84 | static_assert(is_integral<DstType>); | ||
85 | static_assert(is_floating_point<ScaleType>); | ||
86 | |||
87 | − | KAI_ASSUME(quant_width != 0); | |
88 | |||
89 | 20815 | const auto num_quant_packets_x = round_up_division(width, quant_width); | |
90 | |||
91 | 20815 | const auto scales_bytes = height * num_quant_packets_x * sizeof(ScaleType); | |
92 | 20815 | Buffer scales(scales_bytes); | |
93 | |||
94 | 20815 | const auto* src_ptr = reinterpret_cast<const SrcType*>(src); | |
95 | |||
96 |
4/6✓ Branch 0 taken 275340 times.
✓ Branch 1 taken 4452 times.
✓ Branch 2 taken 379590 times.
✓ Branch 3 taken 16363 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
675745 | for (size_t y = 0; y < height; ++y) { |
97 |
4/6✓ Branch 0 taken 577814 times.
✓ Branch 1 taken 275340 times.
✓ Branch 2 taken 379590 times.
✓ Branch 3 taken 457562 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
1690306 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { |
98 | // Computes the quantization scale. | ||
99 | 1035376 | SrcType max_abs = 0; | |
100 | |||
101 |
4/6✓ Branch 0 taken 25768996 times.
✓ Branch 1 taken 577814 times.
✓ Branch 2 taken 30124681 times.
✓ Branch 3 taken 457562 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
56929053 | for (size_t x_element = 0; x_element < quant_width; ++x_element) { |
102 | 55893677 | const auto x = x_quant + x_element; | |
103 | |||
104 |
2/6✗ Branch 0 not taken.
✓ Branch 1 taken 25768996 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 30124681 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
55893677 | if (x < width) { |
105 |
2/6✗ Branch 0 not taken.
✓ Branch 1 taken 25768996 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 30124681 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
55893677 | max_abs = std::max<SrcType>(max_abs, std::abs(src_ptr[y * width + x])); |
106 | 55893677 | } | |
107 | 55893677 | } | |
108 | |||
109 | 2070752 | const auto scale = | |
110 | 1035376 | max_abs / static_cast<SrcType>((static_cast<uint64_t>(1) << (size_in_bits<DstType> - 1)) - 1); | |
111 | |||
112 | // Stores the scales. | ||
113 |
1/2✓ Branch 0 taken 577814 times.
✗ Branch 1 not taken.
|
1035376 | write_array<ScaleType>(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale); |
114 | 1035376 | } | |
115 | 654930 | } | |
116 | |||
117 | 20815 | return scales; | |
118 | 20815 | } | |
119 | |||
120 | template <typename SrcType, typename DstType, typename ScaleType> | ||
121 | 21742 | Buffer quantize_symmetric_per_block( | |
122 | const void* src, const void* scales, size_t height, size_t width, size_t quant_width) { | ||
123 | static_assert(is_floating_point<SrcType>); | ||
124 | static_assert(is_integral<DstType>); | ||
125 | static_assert(is_floating_point<ScaleType>); | ||
126 | |||
127 | 21742 | const auto num_quant_packets_x = round_up_division(width, quant_width); | |
128 | |||
129 | 21742 | const auto data_bytes = round_up_division(height * width * size_in_bits<DstType>, 8); | |
130 | 21742 | Buffer data(data_bytes); | |
131 | |||
132 | 21742 | const auto* src_ptr = reinterpret_cast<const SrcType*>(src); | |
133 | |||
134 |
6/6✓ Branch 0 taken 105378 times.
✓ Branch 1 taken 927 times.
✓ Branch 2 taken 275340 times.
✓ Branch 3 taken 4452 times.
✓ Branch 4 taken 379590 times.
✓ Branch 5 taken 16363 times.
|
782050 | for (size_t y = 0; y < height; ++y) { |
135 |
6/6✓ Branch 0 taken 105378 times.
✓ Branch 1 taken 105378 times.
✓ Branch 2 taken 275340 times.
✓ Branch 3 taken 577814 times.
✓ Branch 4 taken 379590 times.
✓ Branch 5 taken 457562 times.
|
1901062 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { |
136 |
1/2✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
|
1140754 | const auto scale = read_array<ScaleType>(scales, y * num_quant_packets_x + x_quant / quant_width); |
137 | |||
138 | // Quantizes and stores the data. | ||
139 |
6/6✓ Branch 0 taken 105378 times.
✓ Branch 1 taken 105378 times.
✓ Branch 2 taken 577814 times.
✓ Branch 3 taken 25768996 times.
✓ Branch 4 taken 457562 times.
✓ Branch 5 taken 30124681 times.
|
57139809 | for (size_t x_element = 0; x_element < quant_width; ++x_element) { |
140 | 55999055 | const auto x = x_quant + x_element; | |
141 | |||
142 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 105378 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 25768996 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 30124681 times.
|
55999055 | if (x < width) { |
143 |
3/6✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 25768996 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 30124681 times.
✗ Branch 5 not taken.
|
55999055 | const auto quantized = quantize_symmetric<DstType>(src_ptr[y * width + x], scale); |
144 |
4/8✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 105378 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 25768996 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 30124681 times.
✗ Branch 7 not taken.
|
55999055 | write_array(data.data(), y * width + x, quantized); |
145 | 55999055 | } | |
146 | 55999055 | } | |
147 | 1140754 | } | |
148 | 760308 | } | |
149 | 21742 | return data; | |
150 | 21742 | } | |
151 | |||
152 | template Buffer quantize_symmetric_per_block<float, int32_t, float>( | ||
153 | const void* src, const void* scales, size_t height, size_t width, size_t quant_width); | ||
154 | |||
155 | template <typename SrcType, typename DstType, typename ScaleType> | ||
156 | 20815 | std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic( | |
157 | const void* src, size_t height, size_t width, size_t quant_width) { | ||
158 | 20815 | auto scales_src_type = | |
159 | 20815 | compute_symmetric_per_block_quantization_info<SrcType, DstType, SrcType>(src, height, width, quant_width); | |
160 |
5/14✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2924 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1404 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 124 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16239 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
|
20815 | auto data = quantize_symmetric_per_block<SrcType, DstType, SrcType>( |
161 |
5/14✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2924 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1404 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 124 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16239 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
|
20815 | src, scales_src_type.data(), height, width, quant_width); |
162 | |||
163 | if constexpr (std::is_same_v<ScaleType, SrcType>) { | ||
164 | 19163 | return {std::move(data), std::move(scales_src_type)}; | |
165 | } else { | ||
166 | 1652 | auto scales = | |
167 |
9/24✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1404 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1404 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1404 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 124 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 124 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 124 times.
✗ Branch 23 not taken.
|
1652 | cast<ScaleType, SrcType>(scales_src_type.data(), scales_src_type.size() * 8 / size_in_bits<SrcType>); |
168 | |||
169 | 1652 | return {std::move(data), std::move(scales)}; | |
170 | 1652 | } | |
171 | 20815 | } | |
172 | |||
173 | template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, Int4, Float16>( | ||
174 | const void* src, size_t height, size_t width, size_t quant_width); | ||
175 | template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, Int4, float>( | ||
176 | const void* src, size_t height, size_t width, size_t quant_width); | ||
177 | template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, Int4, BFloat16<true>>( | ||
178 | const void* src, size_t height, size_t width, size_t quant_width); | ||
179 | template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, Int4, BFloat16<false>>( | ||
180 | const void* src, size_t height, size_t width, size_t quant_width); | ||
181 | template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, int8_t, Float16>( | ||
182 | const void* src, size_t height, size_t width, size_t quant_width); | ||
183 | template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, int8_t, float>( | ||
184 | const void* src, size_t height, size_t width, size_t quant_width); | ||
185 | template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, int32_t, float>( | ||
186 | const void* src, size_t height, size_t width, size_t quant_width); | ||
187 | |||
188 | template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType> | ||
189 | 13206 | std::tuple<Buffer, Buffer> compute_asymmetric_per_block_quantization_info( | |
190 | const void* src, size_t height, size_t width, size_t quant_width) { | ||
191 | static_assert(is_floating_point<SrcType>); | ||
192 | static_assert(is_integral<DstType>); | ||
193 | static_assert(is_floating_point<ScaleType>); | ||
194 | static_assert(is_integral<ZeroPointType>); | ||
195 | |||
196 | − | KAI_ASSUME(quant_width != 0); | |
197 | |||
198 | 13206 | const auto num_quant_packets_x = round_up_division(width, quant_width); | |
199 | |||
200 | 13206 | const auto scales_bytes = height * num_quant_packets_x * sizeof(ScaleType); | |
201 | 13206 | Buffer scales(scales_bytes); | |
202 | |||
203 | 13206 | const auto zero_points_bytes = height * num_quant_packets_x * sizeof(ZeroPointType); | |
204 |
2/4✓ Branch 0 taken 7974 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5232 times.
✗ Branch 3 not taken.
|
13206 | Buffer zero_points(zero_points_bytes); |
205 | |||
206 |
4/4✓ Branch 0 taken 119838 times.
✓ Branch 1 taken 7974 times.
✓ Branch 2 taken 304160 times.
✓ Branch 3 taken 5232 times.
|
437204 | for (size_t y = 0; y < height; ++y) { |
207 |
4/4✓ Branch 0 taken 119838 times.
✓ Branch 1 taken 119838 times.
✓ Branch 2 taken 304160 times.
✓ Branch 3 taken 525168 times.
|
1069004 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { |
208 | // Computes the quantization scale and zero point. | ||
209 | 645006 | auto min_value = numeric_highest<SrcType>; | |
210 | 645006 | auto max_value = numeric_lowest<SrcType>; | |
211 | |||
212 |
4/4✓ Branch 0 taken 14732951 times.
✓ Branch 1 taken 119838 times.
✓ Branch 2 taken 21518336 times.
✓ Branch 3 taken 525168 times.
|
36896293 | for (size_t x_element = 0; x_element < quant_width; ++x_element) { |
213 | 36251287 | const auto x = x_quant + x_element; | |
214 | |||
215 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14732951 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 21518336 times.
|
36251287 | if (x < width) { |
216 | 36251287 | const auto value = read_array<SrcType>(src, y * width + x); | |
217 | |||
218 |
2/4✓ Branch 0 taken 14732951 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 21518336 times.
✗ Branch 3 not taken.
|
36251287 | min_value = std::min(min_value, value); |
219 |
2/4✓ Branch 0 taken 14732951 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 21518336 times.
✗ Branch 3 not taken.
|
36251287 | max_value = std::max(max_value, value); |
220 | 36251287 | } | |
221 | 36251287 | } | |
222 | |||
223 | 645006 | const auto [scale, zero_point] = | |
224 |
2/4✓ Branch 0 taken 119838 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 525168 times.
✗ Branch 3 not taken.
|
645006 | get_scale_zero_point_from_range<SrcType, DstType, ZeroPointType>(min_value, max_value); |
225 | |||
226 | // Stores the scale and zero point. | ||
227 | 1290012 | write_array<ScaleType>(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale); | |
228 | 1290012 | write_array<ZeroPointType>(zero_points.data(), y * num_quant_packets_x + x_quant / quant_width, zero_point); | |
229 | 645006 | } | |
230 | 423998 | } | |
231 | |||
232 | 13206 | return {std::move(scales), std::move(zero_points)}; | |
233 | 13206 | } | |
234 | |||
235 | template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType> | ||
236 | 13206 | Buffer quantize_asymmetric_per_block( | |
237 | const void* src, const void* scales, const void* zero_points, size_t height, size_t width, size_t quant_width) { | ||
238 | static_assert(is_floating_point<SrcType>); | ||
239 | static_assert(is_integral<DstType>); | ||
240 | static_assert(is_floating_point<ScaleType>); | ||
241 | static_assert(is_integral<ZeroPointType>); | ||
242 | |||
243 | 13206 | const auto num_quant_packets_x = round_up_division(width, quant_width); | |
244 | |||
245 | 13206 | const auto data_bytes = round_up_division(height * width * size_in_bits<DstType>, 8); | |
246 | 13206 | Buffer data(data_bytes); | |
247 | |||
248 |
4/4✓ Branch 0 taken 119838 times.
✓ Branch 1 taken 7974 times.
✓ Branch 2 taken 304160 times.
✓ Branch 3 taken 5232 times.
|
437204 | for (size_t y = 0; y < height; ++y) { |
249 |
4/4✓ Branch 0 taken 119838 times.
✓ Branch 1 taken 119838 times.
✓ Branch 2 taken 304160 times.
✓ Branch 3 taken 525168 times.
|
1069004 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { |
250 | 645006 | const auto scale = read_array<ScaleType>(scales, y * num_quant_packets_x + x_quant / quant_width); | |
251 | 1170174 | const auto zero_point = | |
252 |
1/2✓ Branch 0 taken 119838 times.
✗ Branch 1 not taken.
|
645006 | read_array<ZeroPointType>(zero_points, y * num_quant_packets_x + x_quant / quant_width); |
253 | |||
254 | // Quantizes and stores the data. | ||
255 |
4/4✓ Branch 0 taken 14732951 times.
✓ Branch 1 taken 119838 times.
✓ Branch 2 taken 525168 times.
✓ Branch 3 taken 21518336 times.
|
36896293 | for (size_t x_element = 0; x_element < quant_width; ++x_element) { |
256 | 36251287 | const auto x = x_quant + x_element; | |
257 | |||
258 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14732951 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 21518336 times.
|
36251287 | if (x < width) { |
259 | 36251287 | const auto value_f = read_array<SrcType>(src, y * width + x); | |
260 | 36251287 | const auto value_q = | |
261 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14732951 times.
✓ Branch 2 taken 21518336 times.
✗ Branch 3 not taken.
|
36251287 | quantize_asymmetric<SrcType, DstType, ZeroPointType>(value_f, scale, zero_point); |
262 | |||
263 |
1/2✓ Branch 0 taken 21518336 times.
✗ Branch 1 not taken.
|
36251287 | write_array<DstType>(data.data(), y * width + x, value_q); |
264 | 36251287 | } | |
265 | 36251287 | } | |
266 | 645006 | } | |
267 | 423998 | } | |
268 | |||
269 | 13206 | return data; | |
270 | 13206 | } | |
271 | |||
272 | template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType> | ||
273 | 12279 | std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic( | |
274 | const void* src, size_t height, size_t width, size_t quant_width) { | ||
275 | /* Calculate the asymmetric quantization information, one scaling per row */ | ||
276 | 49116 | auto [scales_src_type, zero_points] = | |
277 | 12279 | compute_asymmetric_per_block_quantization_info<SrcType, DstType, SrcType, ZeroPointType>( | |
278 | 12279 | src, height, width, quant_width); | |
279 | |||
280 | /* Do the actual quantization */ | ||
281 |
2/6✓ Branch 0 taken 7047 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 5232 times.
✗ Branch 5 not taken.
|
24558 | auto data = quantize_asymmetric_per_block<SrcType, DstType, SrcType, ZeroPointType>( |
282 |
6/18✓ Branch 0 taken 7047 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 7047 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7047 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 5232 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 5232 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5232 times.
✗ Branch 17 not taken.
|
24558 | src, scales_src_type.data(), zero_points.data(), height, width, quant_width); |
283 | |||
284 | if constexpr (std::is_same_v<ScaleType, SrcType>) { | ||
285 | 12279 | return {std::move(data), std::move(scales_src_type), std::move(zero_points)}; | |
286 | } else { | ||
287 | ✗ | auto scales = | |
288 | ✗ | cast<ScaleType, SrcType>(scales_src_type.data(), scales_src_type.size() * 8 / size_in_bits<SrcType>); | |
289 | |||
290 | ✗ | return {std::move(data), std::move(scales), std::move(zero_points)}; | |
291 | ✗ | } | |
292 | 12279 | } | |
293 | |||
294 | template std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>( | ||
295 | const void* src, size_t height, size_t width, size_t quant_width); | ||
296 | template std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic<float, int8_t, BFloat16<>, int32_t>( | ||
297 | const void* src, size_t height, size_t width, size_t quant_width); | ||
298 | template std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic<float, Int4, float, int32_t>( | ||
299 | const void* src, size_t height, size_t width, size_t quant_width); | ||
300 | |||
301 | // Reference quantization and packing => Int4 per-block. | ||
302 | // * Generates signed values for reference matmul | ||
303 | // * Generates reference scales from input RHS matrix | ||
304 | template <typename SrcData, typename ScaleType> | ||
305 | 1404 | inline std::tuple<Buffer, Buffer> quantize_rhs_qsi4c32p( | |
306 | size_t N, size_t K, size_t bl, const Buffer& rhs, bool transposed) { | ||
307 | 4914 | auto [rhs_values_qsi4, rhs_scales] = | |
308 | 1404 | quantize_symmetric_per_block_dynamic<SrcData, Int4, ScaleType>(rhs.data(), N, K, bl); | |
309 | |||
310 |
2/2✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
|
1404 | const size_t width = transposed ? K : N; |
311 |
2/2✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
|
1404 | const size_t height = transposed ? N : K; |
312 | |||
313 |
1/2✓ Branch 0 taken 1404 times.
✗ Branch 1 not taken.
|
1404 | const size_t qsi4_stride = round_up_multiple(width, 2); |
314 |
1/2✓ Branch 0 taken 1404 times.
✗ Branch 1 not taken.
|
1404 | const size_t qsi4_size_bytes = round_up_division(height * qsi4_stride, 2); |
315 | |||
316 |
2/2✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
|
1404 | if (!transposed) { |
317 |
3/6✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 702 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 702 times.
✗ Branch 5 not taken.
|
1404 | rhs_values_qsi4 = transpose_with_padding<Int4>(rhs_values_qsi4.data(), N, K, K, qsi4_stride, qsi4_size_bytes); |
318 | 702 | } | |
319 | |||
320 | 1404 | return {std::move(rhs_values_qsi4), std::move(rhs_scales)}; | |
321 | 1404 | } | |
322 | |||
323 | template std::tuple<Buffer, Buffer> quantize_rhs_qsi4c32p<float, BFloat16<false>>( | ||
324 | size_t N, size_t K, size_t bl, const Buffer& rhs, bool transposed); | ||
325 | |||
326 | } // namespace kai::test | ||
327 |