KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 91.4% 159 / 21 / 195
Functions: 85.7% 6 / 0 / 7
Branches: 58.1% 75 / 68 / 197

test/reference/pack.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/pack.hpp"
8
9 #include <arm_neon.h>
10
11 #include <algorithm>
12 #include <cstddef>
13 #include <cstdint>
14 #include <cstring>
15
16 #include "kai/kai_common.h"
17 #include "test/common/bfloat16.hpp"
18 #include "test/common/buffer.hpp"
19 #include "test/common/data_format.hpp"
20 #include "test/common/data_type.hpp"
21 #include "test/common/float16.hpp"
22 #include "test/common/memory.hpp"
23 #include "test/common/round.hpp"
24
25 namespace kai::test {
26
27 namespace {
28
29 4984005 BFloat16<> convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) {
30 KAI_ASSUME_ALWAYS((src_dtype == DataType::FP32 || src_dtype == DataType::FP16) && dst_dtype == DataType::BF16);
31
32
2/3
✗ Branch 0 not taken.
✓ Branch 1 taken 2689833 times.
✓ Branch 2 taken 2294172 times.
4984005 switch (src_dtype) {
33 case DataType::FP32:
34 2689833 return BFloat16<>(*reinterpret_cast<const float*>(src_ptr_elm));
35 case DataType::FP16:
36 2294172 return BFloat16<>(static_cast<float>(*reinterpret_cast<const Float16*>(src_ptr_elm)));
37 default:
38 KAI_ERROR("Unsupported Data Type");
39 }
40 4984005 }
41
42 594 Buffer pack_block(
43 const void* src, DataType src_dtype, DataType dst_dtype, size_t src_esize, size_t dst_esize, size_t full_height,
44 size_t full_width, size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) {
45 1188 const auto dst_bytes =
46 594 round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * dst_esize;
47
48 594 Buffer dst(dst_bytes, 0);
49
50 594 const auto* src_ptr = reinterpret_cast<const uint8_t*>(src);
51 594 auto* dst_ptr = dst.data();
52
53
2/2
✓ Branch 0 taken 6030 times.
✓ Branch 1 taken 594 times.
6624 for (size_t y_block = 0; y_block < full_height; y_block += block_height) {
54
2/2
✓ Branch 0 taken 96540 times.
✓ Branch 1 taken 83910 times.
180450 for (size_t x_block = 0; x_block < full_width; x_block += block_width) {
55
2/2
✓ Branch 0 taken 174420 times.
✓ Branch 1 taken 174420 times.
348840 for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
56
2/2
✓ Branch 0 taken 174420 times.
✓ Branch 1 taken 174420 times.
348840 for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
57
2/2
✓ Branch 0 taken 697656 times.
✓ Branch 1 taken 1097556 times.
1795212 for (size_t y_element = 0; y_element < subblock_height; ++y_element) {
58
2/2
✓ Branch 0 taken 79680 times.
✓ Branch 1 taken 1541112 times.
1620792 if (src_dtype == dst_dtype) {
59 79680 const size_t esize = dst_esize;
60
61
2/2
✓ Branch 0 taken 25758 times.
✓ Branch 1 taken 53922 times.
79680 if (y_block + y_subblock + y_element < full_height) {
62 53922 const size_t y_offset = (y_block + y_subblock + y_element) * full_width;
63 53922 const size_t x_offset = x_block + x_subblock;
64 53922 const size_t offset = y_offset + x_offset;
65
1/2
✓ Branch 0 taken 53922 times.
✗ Branch 1 not taken.
53922 const auto len = std::min(subblock_width, full_width - x_offset);
66
67 53922 memcpy(dst_ptr, src_ptr + offset * esize, len * esize);
68 53922 }
69
70 79680 dst_ptr += subblock_width * esize;
71
2/4
✗ Branch 0 not taken.
✗ Branch 0 not taken.
✓ Branch 1 taken 603996 times.
✓ Branch 1 taken 937116 times.
1620792 } else if (dst_esize == 2 /* 16 bits */) {
72
2/2
✓ Branch 0 taken 5498208 times.
✓ Branch 1 taken 1541112 times.
7039320 for (size_t x_element = 0; x_element < subblock_width; ++x_element) {
73
2/2
✓ Branch 0 taken 401784 times.
✓ Branch 1 taken 5096424 times.
5498208 if (y_block + y_subblock + y_element < full_height) {
74
2/2
✓ Branch 0 taken 112419 times.
✓ Branch 1 taken 4984005 times.
5096424 if (x_block + x_subblock + x_element < full_width) {
75 9968010 const uint8_t* src_ptr_elm = src_ptr +
76 9968010 ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock +
77 9968010 x_element) *
78 4984005 src_esize;
79
80
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4984005 times.
4984005 const BFloat16 src_value = convert(src_ptr_elm, src_dtype, dst_dtype);
81 4984005 memcpy(dst_ptr, &src_value, dst_esize);
82 4984005 }
83 5096424 }
84
85 5498208 dst_ptr += dst_esize;
86 5498208 }
87 1541112 }
88 1620792 }
89 174420 }
90 174420 }
91 174420 }
92 6030 }
93
94 KAI_ASSERT_ALWAYS(reinterpret_cast<uintptr_t>(dst_ptr) - reinterpret_cast<uintptr_t>(dst.data()) == dst_bytes);
95
96 594 return dst;
97 594 }
98
99 /// Packs the matrix from raw to per-row bias format.
100 663 Buffer pack_bias_per_row(
101 DataType src_dtype, DataType bias_dtype, DataType dst_dtype, size_t src_esize, size_t bias_esize, size_t dst_esize,
102 const void* src, const void* bias, size_t height, size_t width, size_t block_height, size_t block_width,
103 size_t subblock_height, size_t subblock_width) {
104 KAI_ASSUME_ALWAYS(src_dtype == bias_dtype);
105
106 663 const auto num_groups = (height + block_height - 1) / block_height;
107 663 const auto group_num_blocks = (width + block_width - 1) / block_width;
108 663 const auto group_bias_bytes = block_height * bias_esize;
109 663 const auto block_data_bytes = block_height * block_width * dst_esize;
110 663 const auto group_bytes = group_bias_bytes + group_num_blocks * block_data_bytes;
111 663 const auto dst_bytes = num_groups * group_bytes;
112
113 663 Buffer dst(dst_bytes, 0);
114
115 663 const auto* src_ptr = reinterpret_cast<const uint8_t*>(src);
116 663 const auto* bias_ptr = reinterpret_cast<const uint8_t*>(bias);
117 663 auto* dst_ptr = dst.data();
118
119
2/2
✓ Branch 0 taken 2247 times.
✓ Branch 1 taken 663 times.
2910 for (size_t y_block = 0; y_block < height; y_block += block_height) {
120 // Packs the bias.
121
1/2
✓ Branch 0 taken 2247 times.
✗ Branch 1 not taken.
2247 const auto bias_len = std::min(block_height, height - y_block);
122 2247 memcpy(dst_ptr, bias_ptr, bias_len * bias_esize);
123 2247 bias_ptr += block_height * bias_esize;
124 2247 dst_ptr += block_height * bias_esize;
125
126
2/2
✓ Branch 0 taken 50187 times.
✓ Branch 1 taken 2247 times.
52434 for (size_t x_block = 0; x_block < width; x_block += block_width) {
127
2/2
✓ Branch 0 taken 50187 times.
✓ Branch 1 taken 50187 times.
100374 for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
128
2/2
✓ Branch 0 taken 50187 times.
✓ Branch 1 taken 81696 times.
131883 for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
129
2/2
✓ Branch 0 taken 81696 times.
✓ Branch 1 taken 2838432 times.
2920128 for (size_t y_element = 0; y_element < subblock_height; ++y_element) {
130
1/2
✓ Branch 0 taken 2838432 times.
✗ Branch 1 not taken.
2838432 if (src_dtype == dst_dtype) {
131 2838432 const size_t esize = dst_esize;
132
2/2
✓ Branch 0 taken 398862 times.
✓ Branch 1 taken 2439570 times.
2838432 if (y_block + y_subblock + y_element < height) {
133
1/2
✓ Branch 0 taken 2439570 times.
✗ Branch 1 not taken.
2439570 const auto len = std::min(subblock_width, width - x_block - x_subblock);
134
135 2439570 memcpy(
136 2439570 dst_ptr,
137 4879140 src_ptr +
138 2439570 ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * esize,
139 2439570 len * esize);
140 2439570 }
141
142 2838432 dst_ptr += subblock_width * esize;
143
0/2
✗ Branch 0 not taken.
✗ Branch 1 not taken.
2838432 } else if (dst_esize == 2 /* 16 bits */) {
144 for (size_t x_element = 0; x_element < subblock_width; ++x_element) {
145 if (y_block + y_subblock + y_element < height) {
146 if (x_block + x_subblock + x_element < width) {
147 const uint8_t* src_ptr_elm = src_ptr +
148 ((y_block + y_subblock + y_element) * width + x_block + x_subblock +
149 x_element) *
150 src_esize;
151
152 const BFloat16 dst_value = convert(src_ptr_elm, src_dtype, dst_dtype);
153 memcpy(dst_ptr, &dst_value, dst_esize);
154 }
155 }
156
157 dst_ptr += dst_esize;
158 }
159 }
160 2838432 }
161 81696 }
162 50187 }
163 50187 }
164 2247 }
165
166 KAI_ASSERT_ALWAYS(reinterpret_cast<uintptr_t>(dst_ptr) - reinterpret_cast<uintptr_t>(dst.data()) == dst_bytes);
167
168 663 return dst;
169 663 }
170
171 } // namespace
172
173 1257 Buffer pack(
174 const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* bias,
175 const DataFormat& src_format, size_t height, size_t width) {
176 1257 const auto dst_dt = dst_format.data_type();
177 1257 const auto dst_qf = dst_format.pack_format();
178 1257 const auto src_dt = src_format.data_type();
179 1257 const auto src_qf = src_format.pack_format();
180
181 1257 const auto block_height = dst_format.actual_block_height(height);
182 1257 const auto block_width = dst_format.actual_block_width(width);
183 1257 const auto subblock_height = dst_format.actual_subblock_height(height);
184 1257 const auto subblock_width = dst_format.actual_subblock_width(width);
185
186
3/4
✓ Branch 0 taken 1257 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✓ Branch 3 taken 663 times.
1257 if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::BIAS_PER_ROW) {
187 KAI_ASSUME_ALWAYS(
188 (src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16) ||
189 (src_dt == DataType::FP16 && dst_dt == DataType::BF16));
190
191 663 const auto src_esize = data_type_size_in_bits(src_dt);
192 663 const auto dst_esize = data_type_size_in_bits(dst_dt);
193 663 const auto bias_esize = data_type_size_in_bits(dst_format.zero_point_data_type());
194 663 const auto bias_dt = dst_format.zero_point_data_type();
195
196 KAI_ASSUME_ALWAYS(dst_esize % 8 == 0 && bias_esize % 8 == 0 && src_esize % 8 == 0);
197
198 663 return pack_bias_per_row(
199 663 src_dt, bias_dt, dst_dt, src_esize / 8, bias_esize / 8, dst_esize / 8, src, bias, height, width,
200 663 block_height, block_width, subblock_height, subblock_width);
201 663 }
202
203
1/2
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
594 if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) {
204 KAI_ASSUME_ALWAYS(
205 (src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16) ||
206 (src_dt == DataType::FP16 && dst_dt == DataType::BF16));
207
208 594 const auto dst_esize = data_type_size_in_bits(dst_dt);
209 594 const auto src_esize = data_type_size_in_bits(src_dt);
210
211 KAI_ASSUME_ALWAYS(src_esize % 8 == 0 && dst_esize % 8 == 0);
212
213 594 return pack_block(
214 594 src, src_dt, dst_dt, src_esize / 8, dst_esize / 8, height, width, block_height, block_width,
215 594 subblock_height, subblock_width);
216 594 }
217
218 KAI_ERROR("Unsupported operation!");
219 1257 }
220
221 template <typename Data, typename Scale>
222 Buffer pack_data_scales(const void* data, const void* scales, size_t height, size_t width, size_t quant_width) {
223 KAI_ASSUME_ALWAYS_IF(size_in_bits<Data> < 8, quant_width % (8 / size_in_bits<Data>) == 0);
224 KAI_ASSUME_ALWAYS_IF(size_in_bits<Data> < 8, width % (8 / size_in_bits<Data>) == 0);
225
226 const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width;
227
228 const auto data_bytes = height * width * size_in_bits<Data> / 8;
229 const auto scales_bytes = height * num_quant_packets_x * sizeof(Scale);
230
231 Buffer dst(data_bytes + scales_bytes);
232
233 const auto* scales_ptr = reinterpret_cast<const Scale*>(scales);
234 auto* dst_ptr = dst.data();
235
236 for (size_t y = 0; y < height; ++y) {
237 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
238 write_array(dst_ptr, 0, *scales_ptr);
239 dst_ptr += sizeof(Scale);
240 ++scales_ptr;
241
242 const auto len = std::min(x_quant + quant_width, width) - x_quant;
243
244 for (size_t x_element = 0; x_element < len; ++x_element) {
245 const auto x = x_quant + x_element;
246 write_array(dst_ptr, x_element, read_array<Data>(data, y * width + x));
247 }
248
249 dst_ptr += len * size_in_bits<Data> / 8;
250 }
251 }
252
253 KAI_ASSERT_ALWAYS(dst_ptr == dst.data() + dst.size());
254
255 return dst;
256 }
257
258 template <typename ZeroPoint, typename Data, typename Scale>
259 521 Buffer pack_zero_points_data_scales_per_block(
260 const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points,
261 size_t block_num_data, size_t block_num_scales) {
262 // Only data is allowed to be sub-byte.
263 KAI_ASSUME_ALWAYS(size_in_bits<ZeroPoint> % 8 == 0);
264 KAI_ASSUME_ALWAYS(size_in_bits<Scale> % 8 == 0);
265
266 // Checks for memory alignment.
267 KAI_ASSUME_ALWAYS(size_in_bits<ZeroPoint> % size_in_bits<Data> == 0);
268 KAI_ASSUME_ALWAYS(
269 (block_num_zero_points * size_in_bits<ZeroPoint> + block_num_data * size_in_bits<Data>) % size_in_bits<Scale> ==
270 0);
271 KAI_ASSUME_ALWAYS(
272 (block_num_data * size_in_bits<Data> + block_num_scales * size_in_bits<Scale>) % size_in_bits<ZeroPoint> == 0);
273
274 1042 Buffer dst(round_up_division(
275 1042 num_blocks *
276 1042 (block_num_zero_points * size_in_bits<ZeroPoint> + block_num_data * size_in_bits<Data> +
277 521 block_num_scales * size_in_bits<Scale>),
278 8));
279
1/2
✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
521 auto* dst_ptr = dst.data();
280
281
2/2
✓ Branch 0 taken 521 times.
✓ Branch 1 taken 2101 times.
2622 for (size_t block_no = 0; block_no < num_blocks; ++block_no) {
282
2/2
✓ Branch 0 taken 67232 times.
✓ Branch 1 taken 2101 times.
69333 for (size_t i = 0; i < block_num_zero_points; ++i) {
283
1/2
✓ Branch 0 taken 67232 times.
✗ Branch 1 not taken.
67232 write_array<ZeroPoint>(
284
1/2
✓ Branch 0 taken 67232 times.
✗ Branch 1 not taken.
67232 dst_ptr, i, read_array<ZeroPoint>(zero_points, block_no * block_num_zero_points + i));
285 67232 }
286 2101 dst_ptr += block_num_zero_points * sizeof(ZeroPoint);
287
288
2/2
✓ Branch 0 taken 7001728 times.
✓ Branch 1 taken 2101 times.
7003829 for (size_t i = 0; i < block_num_data; ++i) {
289
2/4
✓ Branch 0 taken 7001728 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 7001728 times.
✗ Branch 3 not taken.
7001728 write_array<Data>(dst_ptr, i, read_array<Data>(data, block_no * block_num_data + i));
290 7001728 }
291
1/2
✓ Branch 0 taken 2101 times.
✗ Branch 1 not taken.
2101 dst_ptr += round_up_division(block_num_data * size_in_bits<Data>, 8);
292
293
2/2
✓ Branch 0 taken 2101 times.
✓ Branch 1 taken 67232 times.
69333 for (size_t i = 0; i < block_num_scales; ++i) {
294
2/4
✓ Branch 0 taken 67232 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 67232 times.
✗ Branch 3 not taken.
67232 write_array<Scale>(dst_ptr, i, read_array<Scale>(scales, block_no * block_num_scales + i));
295 67232 }
296 2101 dst_ptr += block_num_scales * sizeof(Scale);
297 2101 }
298
299 KAI_ASSERT_ALWAYS(dst_ptr == dst.data() + dst.size());
300
301 521 return dst;
302 521 }
303
304 template Buffer pack_zero_points_data_scales_per_block<int32_t, int8_t, float>(
305 const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points,
306 size_t block_num_data, size_t block_num_scales);
307
308 template <typename Data, typename Scale>
309 920 Buffer pack_data_scales_interleave_block(
310 const void* data, const void* scales, size_t height, size_t width, size_t quant_width) {
311 KAI_ASSUME_ALWAYS_IF(size_in_bits<Data> < 8, quant_width % (8 / size_in_bits<Data>) == 0);
312 KAI_ASSUME_ALWAYS_IF(size_in_bits<Data> < 8, width % (8 / size_in_bits<Data>) == 0);
313 KAI_ASSUME_ALWAYS(width % quant_width == 0);
314 KAI_ASSUME_ALWAYS(quant_width % 2 == 0);
315
316 920 const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width;
317
318 920 const auto data_bytes = height * width * size_in_bits<Data> / 8;
319
1/4
✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
920 const auto scales_bytes = scales != nullptr ? height * num_quant_packets_x * sizeof(Scale) : 0;
320
321 920 Buffer dst(data_bytes + scales_bytes);
322
323 920 const auto* scales_ptr = reinterpret_cast<const Scale*>(scales);
324
1/4
✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
920 auto* dst_ptr = dst.data();
325
326
2/4
✓ Branch 0 taken 920 times.
✓ Branch 1 taken 36460 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
37380 for (size_t y = 0; y < height; ++y) {
327
2/4
✓ Branch 0 taken 61920 times.
✓ Branch 1 taken 36460 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
98380 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
328
1/4
✗ Branch 0 not taken.
✓ Branch 1 taken 61920 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
61920 if (scales_ptr != nullptr) {
329
1/4
✓ Branch 0 taken 61920 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
61920 write_array(dst_ptr, 0, *scales_ptr);
330 61920 dst_ptr += sizeof(Scale);
331 61920 ++scales_ptr;
332 61920 }
333
334
2/4
✓ Branch 0 taken 61920 times.
✓ Branch 1 taken 1981440 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
2043360 for (size_t x_element = 0; x_element < quant_width; ++x_element) {
335
2/4
✓ Branch 0 taken 990720 times.
✓ Branch 1 taken 990720 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
1981440 const auto x = x_quant + x_element / 2 + (x_element % 2 != 0 ? quant_width / 2 : 0);
336
2/8
✓ Branch 0 taken 1981440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1981440 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
1981440 write_array(dst_ptr, x_element, read_array<Data>(data, y * width + x));
337 1981440 }
338
339 61920 dst_ptr += quant_width * size_in_bits<Data> / 8;
340 61920 }
341 36460 }
342
343 KAI_ASSERT_ALWAYS(dst_ptr == dst.data() + dst.size());
344
345 920 return dst;
346 920 }
347
348 template Buffer pack_data_scales_interleave_block<UInt4, Float16>(
349 const void* data, const void* scales, size_t height, size_t width, size_t quant_width);
350 template Buffer pack_data_scales_interleave_block<UInt4, std::nullptr_t>(
351 const void* data, const void* scales, size_t height, size_t width, size_t quant_width);
352
353 template <typename Data, typename ZeroPoint, typename Scale, typename Bias>
354 Buffer pack_block_data_zero_points_scale_bias(
355 const void* data, const void* zero_points, const void* scales, const void* biases, size_t height, size_t width,
356 size_t quant_height, size_t quant_width, size_t block_height, size_t block_width, size_t interleave_x_blocks) {
357 if (quant_width == width) {
358 quant_width = round_up_multiple(quant_width, block_width);
359 }
360
361 KAI_ASSERT_ALWAYS(quant_height == block_height);
362 KAI_ASSERT_ALWAYS(quant_width % block_width == 0);
363
364 if (interleave_x_blocks == 0) {
365 interleave_x_blocks = quant_width / block_width;
366 }
367
368 const auto has_zero_points = zero_points != nullptr;
369 const auto has_biases = biases != nullptr;
370
371 const auto num_quant_packets_y = round_up_division(height, quant_height);
372 const auto num_quant_packets_x = round_up_division(width, quant_width);
373
374 const auto quant_packet_data_bytes = quant_height * quant_width * size_in_bits<Data> / 8;
375 const auto quant_packet_zero_points_bytes = has_zero_points ? quant_height * sizeof(ZeroPoint) : 0;
376 const auto quant_packet_scales_bytes = quant_height * sizeof(Scale);
377 const auto quant_packet_bytes =
378 quant_packet_zero_points_bytes + quant_packet_data_bytes + quant_packet_scales_bytes;
379
380 const auto num_quant_packets_per_row = round_up_division(width, quant_width);
381 const auto biases_bytes = has_biases ? height * sizeof(Bias) : 0;
382
383 const auto dst_bytes = num_quant_packets_y * num_quant_packets_x * quant_packet_bytes + biases_bytes;
384 Buffer dst(dst_bytes);
385
386 const auto* zero_points_ptr = reinterpret_cast<const ZeroPoint*>(zero_points);
387 const auto* scales_ptr = reinterpret_cast<const Scale*>(scales);
388 const auto* biases_ptr = reinterpret_cast<const Bias*>(biases);
389 auto* dst_ptr = dst.data();
390
391 for (size_t y_quant = 0; y_quant < height; y_quant += quant_height) {
392 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
393 size_t dst_index = 0;
394
395 // Packs the data.
396 for (size_t y_pack = 0; y_pack < quant_height; y_pack += block_height) {
397 for (size_t x_pack = 0; x_pack < block_width * interleave_x_blocks; x_pack += block_width) {
398 for (size_t y_element = 0; y_element < block_height; ++y_element) {
399 for (size_t x_element = 0; x_element < block_width; ++x_element) {
400 for (size_t x_interleave = 0; x_interleave < quant_width;
401 x_interleave += block_width * interleave_x_blocks) {
402 const auto y = y_quant + y_pack + y_element;
403 const auto x = x_quant + x_pack + x_element + x_interleave;
404
405 if (y < height && x < width) {
406 write_array(dst_ptr, dst_index, read_array<Data>(data, y * width + x));
407 }
408
409 ++dst_index;
410 }
411 }
412 }
413 }
414 }
415
416 dst_ptr += dst_index * size_in_bits<Data> / 8;
417
418 // Packs the zero points.
419 if (has_zero_points) {
420 for (size_t y_element = 0; y_element < quant_height; ++y_element) {
421 const auto y = y_quant + y_element;
422 const auto x = x_quant / quant_width;
423 memcpy(dst_ptr, &zero_points_ptr[y * num_quant_packets_per_row + x], sizeof(ZeroPoint));
424 dst_ptr += sizeof(ZeroPoint);
425 }
426 }
427
428 // Packs the scales.
429 for (size_t y_element = 0; y_element < quant_height; ++y_element) {
430 const auto y = y_quant + y_element;
431 const auto x = x_quant / quant_width;
432 memcpy(dst_ptr, &scales_ptr[y * num_quant_packets_per_row + x], sizeof(Scale));
433 dst_ptr += sizeof(Scale);
434 }
435 }
436
437 // Packs the biases.
438 if (has_biases) {
439 for (size_t y_element = 0; y_element < quant_height; ++y_element) {
440 const auto y = y_quant + y_element;
441 memcpy(dst_ptr, &biases_ptr[y], sizeof(Bias));
442 dst_ptr += sizeof(Bias);
443 }
444 }
445 }
446
447 KAI_ASSERT_ALWAYS(dst_ptr == dst.data() + dst.size());
448
449 return dst;
450 }
451
452 } // namespace kai::test
453