KleidiAI Coverage Report


Directory: ./
File: test/reference/quantize.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 96.0% 145 2 153
Functions: 85.7% 24 0 28
Branches: 54.4% 137 10 262

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/quantize.hpp"
8
9 #include <algorithm>
10 #include <cmath>
11 #include <cstddef>
12 #include <cstdint>
13 #include <tuple>
14
15 #include "test/common/bfloat16.hpp"
16 #include "test/common/buffer.hpp"
17 #include "test/common/int4.hpp"
18 #include "test/common/memory.hpp"
19 #include "test/common/numeric_limits.hpp"
20 #include "test/common/round.hpp"
21 #include "test/common/type_traits.hpp"
22 #include "test/reference/cast.hpp"
23 #include "test/reference/transpose.hpp"
24
25 namespace kai::test {
26
27 namespace {
28
29 template <typename FloatData, typename IntData, typename ZeroPoint>
30 645006 std::tuple<FloatData, ZeroPoint> get_scale_zero_point_from_range(FloatData min_value, FloatData max_value) {
31 645006 const FloatData q_min = numeric_lowest<IntData>;
32 645006 const FloatData q_max = numeric_highest<IntData>;
33
34
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 119838 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 525168 times.
645006 if (min_value > 0) {
35 645006 min_value = 0;
36 645006 }
37
38
2/4
✓ Branch 0 taken 119838 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 525168 times.
✗ Branch 3 not taken.
645006 if (max_value < 0) {
39 max_value = 0;
40 }
41
42 // The reason for computing the inverted scale first is to make it bit-perfect with quantized packing
43 // micro-kernels. If those micro-kernels don't do it this way anymore, it makes more sense to calculate
44 // the scale directly.
45
2/4
✓ Branch 0 taken 119838 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 525168 times.
✗ Branch 3 not taken.
645006 const FloatData inv_scale = max_value != min_value ? (q_max - q_min) / (max_value - min_value) : 1.0F;
46 645006 const FloatData scale = 1.0F / inv_scale;
47
48 645006 const FloatData scaled_min = min_value / scale;
49 645006 const FloatData scaled_max = max_value / scale;
50
51
2/4
✓ Branch 0 taken 119838 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 525168 times.
✗ Branch 3 not taken.
645006 const FloatData zero_point_f = -(scaled_min + q_min) < scaled_max + q_max ? scaled_min - q_min : scaled_max - q_max;
52 645006 const ZeroPoint zero_point = -round_to_nearest_even<ZeroPoint>(zero_point_f);
53
54 645006 return {scale, zero_point};
55 645006 }
56
57 } // namespace
58
59 template <typename IntType>
60 55999055 IntType quantize_symmetric(float value, float scale) {
61
3/6
✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 25768996 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 30124681 times.
✗ Branch 5 not taken.
55999055 const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F;
62 55999055 auto qsi32 = round_to_nearest_even_i32(value * inv_scale);
63
64 if (is_unsigned<IntType>) {
65 qsi32 += 1 << (size_in_bits<IntType> - 1);
66 }
67
68 86229114 return static_cast<IntType>(std::clamp<int32_t>(qsi32, numeric_lowest<IntType>, numeric_highest<IntType>));
69 55999055 }
70
71 template <typename FloatType, typename IntType, typename ZeroPointType>
72 36253141 IntType quantize_asymmetric(FloatType value, FloatType scale, ZeroPointType zero_point) {
73
2/4
✓ Branch 0 taken 14734805 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 21518336 times.
✗ Branch 3 not taken.
36253141 const auto inv_scale = scale != 0 ? 1.0F / scale : 0.0F;
74 36253141 auto quantized_value = round_to_nearest_even<ZeroPointType>(value * inv_scale) + zero_point;
75 50987946 return static_cast<IntType>(
76 36253141 std::clamp<ZeroPointType>(quantized_value, numeric_lowest<IntType>, numeric_highest<IntType>));
77 36253141 }
78
79 template int8_t quantize_asymmetric(float value, float scale, int32_t zero_point);
80
81 template <typename SrcType, typename DstType, typename ScaleType>
82 20815 Buffer compute_symmetric_per_block_quantization_info(const void* src, size_t height, size_t width, size_t quant_width) {
83 static_assert(is_floating_point<SrcType>);
84 static_assert(is_integral<DstType>);
85 static_assert(is_floating_point<ScaleType>);
86
87 KAI_ASSUME(quant_width != 0);
88
89 20815 const auto num_quant_packets_x = round_up_division(width, quant_width);
90
91 20815 const auto scales_bytes = height * num_quant_packets_x * sizeof(ScaleType);
92 20815 Buffer scales(scales_bytes);
93
94 20815 const auto* src_ptr = reinterpret_cast<const SrcType*>(src);
95
96
4/6
✓ Branch 0 taken 275340 times.
✓ Branch 1 taken 4452 times.
✓ Branch 2 taken 379590 times.
✓ Branch 3 taken 16363 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
675745 for (size_t y = 0; y < height; ++y) {
97
4/6
✓ Branch 0 taken 577814 times.
✓ Branch 1 taken 275340 times.
✓ Branch 2 taken 379590 times.
✓ Branch 3 taken 457562 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
1690306 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
98 // Computes the quantization scale.
99 1035376 SrcType max_abs = 0;
100
101
4/6
✓ Branch 0 taken 25768996 times.
✓ Branch 1 taken 577814 times.
✓ Branch 2 taken 30124681 times.
✓ Branch 3 taken 457562 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
56929053 for (size_t x_element = 0; x_element < quant_width; ++x_element) {
102 55893677 const auto x = x_quant + x_element;
103
104
2/6
✗ Branch 0 not taken.
✓ Branch 1 taken 25768996 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 30124681 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
55893677 if (x < width) {
105
2/6
✗ Branch 0 not taken.
✓ Branch 1 taken 25768996 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 30124681 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
55893677 max_abs = std::max<SrcType>(max_abs, std::abs(src_ptr[y * width + x]));
106 55893677 }
107 55893677 }
108
109 2070752 const auto scale =
110 1035376 max_abs / static_cast<SrcType>((static_cast<uint64_t>(1) << (size_in_bits<DstType> - 1)) - 1);
111
112 // Stores the scales.
113
1/2
✓ Branch 0 taken 577814 times.
✗ Branch 1 not taken.
1035376 write_array<ScaleType>(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale);
114 1035376 }
115 654930 }
116
117 20815 return scales;
118 20815 }
119
120 template <typename SrcType, typename DstType, typename ScaleType>
121 21742 Buffer quantize_symmetric_per_block(
122 const void* src, const void* scales, size_t height, size_t width, size_t quant_width) {
123 static_assert(is_floating_point<SrcType>);
124 static_assert(is_integral<DstType>);
125 static_assert(is_floating_point<ScaleType>);
126
127 21742 const auto num_quant_packets_x = round_up_division(width, quant_width);
128
129 21742 const auto data_bytes = round_up_division(height * width * size_in_bits<DstType>, 8);
130 21742 Buffer data(data_bytes);
131
132 21742 const auto* src_ptr = reinterpret_cast<const SrcType*>(src);
133
134
6/6
✓ Branch 0 taken 105378 times.
✓ Branch 1 taken 927 times.
✓ Branch 2 taken 275340 times.
✓ Branch 3 taken 4452 times.
✓ Branch 4 taken 379590 times.
✓ Branch 5 taken 16363 times.
782050 for (size_t y = 0; y < height; ++y) {
135
6/6
✓ Branch 0 taken 105378 times.
✓ Branch 1 taken 105378 times.
✓ Branch 2 taken 275340 times.
✓ Branch 3 taken 577814 times.
✓ Branch 4 taken 379590 times.
✓ Branch 5 taken 457562 times.
1901062 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
136
1/2
✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
1140754 const auto scale = read_array<ScaleType>(scales, y * num_quant_packets_x + x_quant / quant_width);
137
138 // Quantizes and stores the data.
139
6/6
✓ Branch 0 taken 105378 times.
✓ Branch 1 taken 105378 times.
✓ Branch 2 taken 577814 times.
✓ Branch 3 taken 25768996 times.
✓ Branch 4 taken 457562 times.
✓ Branch 5 taken 30124681 times.
57139809 for (size_t x_element = 0; x_element < quant_width; ++x_element) {
140 55999055 const auto x = x_quant + x_element;
141
142
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 105378 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 25768996 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 30124681 times.
55999055 if (x < width) {
143
3/6
✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 25768996 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 30124681 times.
✗ Branch 5 not taken.
55999055 const auto quantized = quantize_symmetric<DstType>(src_ptr[y * width + x], scale);
144
4/8
✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 105378 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 25768996 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 30124681 times.
✗ Branch 7 not taken.
55999055 write_array(data.data(), y * width + x, quantized);
145 55999055 }
146 55999055 }
147 1140754 }
148 760308 }
149 21742 return data;
150 21742 }
151
152 template Buffer quantize_symmetric_per_block<float, int32_t, float>(
153 const void* src, const void* scales, size_t height, size_t width, size_t quant_width);
154
155 template <typename SrcType, typename DstType, typename ScaleType>
156 20815 std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic(
157 const void* src, size_t height, size_t width, size_t quant_width) {
158 20815 auto scales_src_type =
159 20815 compute_symmetric_per_block_quantization_info<SrcType, DstType, SrcType>(src, height, width, quant_width);
160
5/14
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2924 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1404 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 124 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16239 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
20815 auto data = quantize_symmetric_per_block<SrcType, DstType, SrcType>(
161
5/14
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2924 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1404 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 124 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16239 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
20815 src, scales_src_type.data(), height, width, quant_width);
162
163 if constexpr (std::is_same_v<ScaleType, SrcType>) {
164 19163 return {std::move(data), std::move(scales_src_type)};
165 } else {
166 1652 auto scales =
167
9/24
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1404 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1404 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1404 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 124 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 124 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 124 times.
✗ Branch 23 not taken.
1652 cast<ScaleType, SrcType>(scales_src_type.data(), scales_src_type.size() * 8 / size_in_bits<SrcType>);
168
169 1652 return {std::move(data), std::move(scales)};
170 1652 }
171 20815 }
172
173 template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, Int4, Float16>(
174 const void* src, size_t height, size_t width, size_t quant_width);
175 template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, Int4, float>(
176 const void* src, size_t height, size_t width, size_t quant_width);
177 template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, Int4, BFloat16<true>>(
178 const void* src, size_t height, size_t width, size_t quant_width);
179 template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, Int4, BFloat16<false>>(
180 const void* src, size_t height, size_t width, size_t quant_width);
181 template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, int8_t, Float16>(
182 const void* src, size_t height, size_t width, size_t quant_width);
183 template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, int8_t, float>(
184 const void* src, size_t height, size_t width, size_t quant_width);
185 template std::tuple<Buffer, Buffer> quantize_symmetric_per_block_dynamic<float, int32_t, float>(
186 const void* src, size_t height, size_t width, size_t quant_width);
187
188 template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType>
189 13206 std::tuple<Buffer, Buffer> compute_asymmetric_per_block_quantization_info(
190 const void* src, size_t height, size_t width, size_t quant_width) {
191 static_assert(is_floating_point<SrcType>);
192 static_assert(is_integral<DstType>);
193 static_assert(is_floating_point<ScaleType>);
194 static_assert(is_integral<ZeroPointType>);
195
196 KAI_ASSUME(quant_width != 0);
197
198 13206 const auto num_quant_packets_x = round_up_division(width, quant_width);
199
200 13206 const auto scales_bytes = height * num_quant_packets_x * sizeof(ScaleType);
201 13206 Buffer scales(scales_bytes);
202
203 13206 const auto zero_points_bytes = height * num_quant_packets_x * sizeof(ZeroPointType);
204
2/4
✓ Branch 0 taken 7974 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5232 times.
✗ Branch 3 not taken.
13206 Buffer zero_points(zero_points_bytes);
205
206
4/4
✓ Branch 0 taken 119838 times.
✓ Branch 1 taken 7974 times.
✓ Branch 2 taken 304160 times.
✓ Branch 3 taken 5232 times.
437204 for (size_t y = 0; y < height; ++y) {
207
4/4
✓ Branch 0 taken 119838 times.
✓ Branch 1 taken 119838 times.
✓ Branch 2 taken 304160 times.
✓ Branch 3 taken 525168 times.
1069004 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
208 // Computes the quantization scale and zero point.
209 645006 auto min_value = numeric_highest<SrcType>;
210 645006 auto max_value = numeric_lowest<SrcType>;
211
212
4/4
✓ Branch 0 taken 14732951 times.
✓ Branch 1 taken 119838 times.
✓ Branch 2 taken 21518336 times.
✓ Branch 3 taken 525168 times.
36896293 for (size_t x_element = 0; x_element < quant_width; ++x_element) {
213 36251287 const auto x = x_quant + x_element;
214
215
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14732951 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 21518336 times.
36251287 if (x < width) {
216 36251287 const auto value = read_array<SrcType>(src, y * width + x);
217
218
2/4
✓ Branch 0 taken 14732951 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 21518336 times.
✗ Branch 3 not taken.
36251287 min_value = std::min(min_value, value);
219
2/4
✓ Branch 0 taken 14732951 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 21518336 times.
✗ Branch 3 not taken.
36251287 max_value = std::max(max_value, value);
220 36251287 }
221 36251287 }
222
223 645006 const auto [scale, zero_point] =
224
2/4
✓ Branch 0 taken 119838 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 525168 times.
✗ Branch 3 not taken.
645006 get_scale_zero_point_from_range<SrcType, DstType, ZeroPointType>(min_value, max_value);
225
226 // Stores the scale and zero point.
227 1290012 write_array<ScaleType>(scales.data(), y * num_quant_packets_x + x_quant / quant_width, scale);
228 1290012 write_array<ZeroPointType>(zero_points.data(), y * num_quant_packets_x + x_quant / quant_width, zero_point);
229 645006 }
230 423998 }
231
232 13206 return {std::move(scales), std::move(zero_points)};
233 13206 }
234
235 template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType>
236 13206 Buffer quantize_asymmetric_per_block(
237 const void* src, const void* scales, const void* zero_points, size_t height, size_t width, size_t quant_width) {
238 static_assert(is_floating_point<SrcType>);
239 static_assert(is_integral<DstType>);
240 static_assert(is_floating_point<ScaleType>);
241 static_assert(is_integral<ZeroPointType>);
242
243 13206 const auto num_quant_packets_x = round_up_division(width, quant_width);
244
245 13206 const auto data_bytes = round_up_division(height * width * size_in_bits<DstType>, 8);
246 13206 Buffer data(data_bytes);
247
248
4/4
✓ Branch 0 taken 119838 times.
✓ Branch 1 taken 7974 times.
✓ Branch 2 taken 304160 times.
✓ Branch 3 taken 5232 times.
437204 for (size_t y = 0; y < height; ++y) {
249
4/4
✓ Branch 0 taken 119838 times.
✓ Branch 1 taken 119838 times.
✓ Branch 2 taken 304160 times.
✓ Branch 3 taken 525168 times.
1069004 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
250 645006 const auto scale = read_array<ScaleType>(scales, y * num_quant_packets_x + x_quant / quant_width);
251 1170174 const auto zero_point =
252
1/2
✓ Branch 0 taken 119838 times.
✗ Branch 1 not taken.
645006 read_array<ZeroPointType>(zero_points, y * num_quant_packets_x + x_quant / quant_width);
253
254 // Quantizes and stores the data.
255
4/4
✓ Branch 0 taken 14732951 times.
✓ Branch 1 taken 119838 times.
✓ Branch 2 taken 525168 times.
✓ Branch 3 taken 21518336 times.
36896293 for (size_t x_element = 0; x_element < quant_width; ++x_element) {
256 36251287 const auto x = x_quant + x_element;
257
258
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14732951 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 21518336 times.
36251287 if (x < width) {
259 36251287 const auto value_f = read_array<SrcType>(src, y * width + x);
260 36251287 const auto value_q =
261
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14732951 times.
✓ Branch 2 taken 21518336 times.
✗ Branch 3 not taken.
36251287 quantize_asymmetric<SrcType, DstType, ZeroPointType>(value_f, scale, zero_point);
262
263
1/2
✓ Branch 0 taken 21518336 times.
✗ Branch 1 not taken.
36251287 write_array<DstType>(data.data(), y * width + x, value_q);
264 36251287 }
265 36251287 }
266 645006 }
267 423998 }
268
269 13206 return data;
270 13206 }
271
272 template <typename SrcType, typename DstType, typename ScaleType, typename ZeroPointType>
273 12279 std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic(
274 const void* src, size_t height, size_t width, size_t quant_width) {
275 /* Calculate the asymmetric quantization information, one scaling per row */
276 49116 auto [scales_src_type, zero_points] =
277 12279 compute_asymmetric_per_block_quantization_info<SrcType, DstType, SrcType, ZeroPointType>(
278 12279 src, height, width, quant_width);
279
280 /* Do the actual quantization */
281
2/6
✓ Branch 0 taken 7047 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 5232 times.
✗ Branch 5 not taken.
24558 auto data = quantize_asymmetric_per_block<SrcType, DstType, SrcType, ZeroPointType>(
282
6/18
✓ Branch 0 taken 7047 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 7047 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7047 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 5232 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 5232 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5232 times.
✗ Branch 17 not taken.
24558 src, scales_src_type.data(), zero_points.data(), height, width, quant_width);
283
284 if constexpr (std::is_same_v<ScaleType, SrcType>) {
285 12279 return {std::move(data), std::move(scales_src_type), std::move(zero_points)};
286 } else {
287 auto scales =
288 cast<ScaleType, SrcType>(scales_src_type.data(), scales_src_type.size() * 8 / size_in_bits<SrcType>);
289
290 return {std::move(data), std::move(scales), std::move(zero_points)};
291 }
292 12279 }
293
294 template std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(
295 const void* src, size_t height, size_t width, size_t quant_width);
296 template std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic<float, int8_t, BFloat16<>, int32_t>(
297 const void* src, size_t height, size_t width, size_t quant_width);
298 template std::tuple<Buffer, Buffer, Buffer> quantize_asymmetric_per_block_dynamic<float, Int4, float, int32_t>(
299 const void* src, size_t height, size_t width, size_t quant_width);
300
301 // Reference quantization and packing => Int4 per-block.
302 // * Generates signed values for reference matmul
303 // * Generates reference scales from input RHS matrix
304 template <typename SrcData, typename ScaleType>
305 1404 inline std::tuple<Buffer, Buffer> quantize_rhs_qsi4c32p(
306 size_t N, size_t K, size_t bl, const Buffer& rhs, bool transposed) {
307 4914 auto [rhs_values_qsi4, rhs_scales] =
308 1404 quantize_symmetric_per_block_dynamic<SrcData, Int4, ScaleType>(rhs.data(), N, K, bl);
309
310
2/2
✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
1404 const size_t width = transposed ? K : N;
311
2/2
✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
1404 const size_t height = transposed ? N : K;
312
313
1/2
✓ Branch 0 taken 1404 times.
✗ Branch 1 not taken.
1404 const size_t qsi4_stride = round_up_multiple(width, 2);
314
1/2
✓ Branch 0 taken 1404 times.
✗ Branch 1 not taken.
1404 const size_t qsi4_size_bytes = round_up_division(height * qsi4_stride, 2);
315
316
2/2
✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
1404 if (!transposed) {
317
3/6
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 702 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 702 times.
✗ Branch 5 not taken.
1404 rhs_values_qsi4 = transpose_with_padding<Int4>(rhs_values_qsi4.data(), N, K, K, qsi4_stride, qsi4_size_bytes);
318 702 }
319
320 1404 return {std::move(rhs_values_qsi4), std::move(rhs_scales)};
321 1404 }
322
323 template std::tuple<Buffer, Buffer> quantize_rhs_qsi4c32p<float, BFloat16<false>>(
324 size_t N, size_t K, size_t bl, const Buffer& rhs, bool transposed);
325
326 } // namespace kai::test
327