KleidiAI Coverage Report


Directory: ./
File: test/reference/pack.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 91.4% 159 21 195
Functions: 85.7% 6 0 7
Branches: 58.3% 74 78 205

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/pack.hpp"
8
9 #include <arm_neon.h>
10
11 #include <algorithm>
12 #include <cstddef>
13 #include <cstdint>
14 #include <cstring>
15
16 #include "kai/kai_common.h"
17 #include "test/common/bfloat16.hpp"
18 #include "test/common/buffer.hpp"
19 #include "test/common/data_format.hpp"
20 #include "test/common/data_type.hpp"
21 #include "test/common/float16.hpp"
22 #include "test/common/memory.hpp"
23 #include "test/common/round.hpp"
24
25 namespace kai::test {
26
27 namespace {
28
29 918733 BFloat16<> convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) {
30 KAI_ASSUME((src_dtype == DataType::FP32 || src_dtype == DataType::FP16) && dst_dtype == DataType::BF16);
31
32
2/3
✗ Branch 0 not taken.
✓ Branch 1 taken 536371 times.
✓ Branch 2 taken 382362 times.
918733 switch (src_dtype) {
33 case DataType::FP32:
34 536371 return BFloat16<>(*reinterpret_cast<const float*>(src_ptr_elm));
35 case DataType::FP16:
36 382362 return BFloat16<>(static_cast<float>(*reinterpret_cast<const Float16*>(src_ptr_elm)));
37 default:
38 KAI_ERROR("Unsupported Data Type");
39 }
40 918733 }
41
42 116 Buffer pack_block(
43 const void* src, DataType src_dtype, DataType dst_dtype, size_t src_esize, size_t dst_esize, size_t full_height,
44 size_t full_width, size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) {
45 232 const auto dst_bytes =
46 116 round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * dst_esize;
47
48 116 Buffer dst(dst_bytes, 0);
49
50 116 const auto* src_ptr = reinterpret_cast<const uint8_t*>(src);
51 116 auto* dst_ptr = dst.data();
52
53
2/2
✓ Branch 0 taken 1050 times.
✓ Branch 1 taken 116 times.
1166 for (size_t y_block = 0; y_block < full_height; y_block += block_height) {
54
2/2
✓ Branch 0 taken 31220 times.
✓ Branch 1 taken 1050 times.
32270 for (size_t x_block = 0; x_block < full_width; x_block += block_width) {
55
2/2
✓ Branch 0 taken 31220 times.
✓ Branch 1 taken 31220 times.
62440 for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
56
2/2
✓ Branch 0 taken 31220 times.
✓ Branch 1 taken 31220 times.
62440 for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
57
2/2
✓ Branch 0 taken 31220 times.
✓ Branch 1 taken 338932 times.
370152 for (size_t y_element = 0; y_element < subblock_height; ++y_element) {
58
2/2
✓ Branch 0 taken 26560 times.
✓ Branch 1 taken 312372 times.
338932 if (src_dtype == dst_dtype) {
59 26560 const size_t esize = dst_esize;
60
61
2/2
✓ Branch 0 taken 8586 times.
✓ Branch 1 taken 17974 times.
26560 if (y_block + y_subblock + y_element < full_height) {
62 17974 const size_t y_offset = (y_block + y_subblock + y_element) * full_width;
63 17974 const size_t x_offset = x_block + x_subblock;
64 17974 const size_t offset = y_offset + x_offset;
65
1/2
✓ Branch 0 taken 17974 times.
✗ Branch 1 not taken.
17974 const auto len = std::min(subblock_width, full_width - x_offset);
66
67 17974 memcpy(dst_ptr, src_ptr + offset * esize, len * esize);
68 17974 }
69
70 26560 dst_ptr += subblock_width * esize;
71
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 312372 times.
338932 } else if (dst_esize == 2 /* 16 bits */) {
72
2/2
✓ Branch 0 taken 1027408 times.
✓ Branch 1 taken 312372 times.
1339780 for (size_t x_element = 0; x_element < subblock_width; ++x_element) {
73
2/2
✓ Branch 0 taken 89248 times.
✓ Branch 1 taken 938160 times.
1027408 if (y_block + y_subblock + y_element < full_height) {
74
2/2
✓ Branch 0 taken 19427 times.
✓ Branch 1 taken 918733 times.
938160 if (x_block + x_subblock + x_element < full_width) {
75 1837466 const uint8_t* src_ptr_elm = src_ptr +
76 1837466 ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock +
77 1837466 x_element) *
78 918733 src_esize;
79
80
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 918733 times.
918733 const BFloat16 src_value = convert(src_ptr_elm, src_dtype, dst_dtype);
81 918733 memcpy(dst_ptr, &src_value, dst_esize);
82 918733 }
83 938160 }
84
85 1027408 dst_ptr += dst_esize;
86 1027408 }
87 312372 }
88 338932 }
89 31220 }
90 31220 }
91 31220 }
92 1050 }
93
94 KAI_ASSERT(reinterpret_cast<uintptr_t>(dst_ptr) - reinterpret_cast<uintptr_t>(dst.data()) == dst_bytes);
95
96 116 return dst;
97 116 }
98
99 /// Packs the matrix from raw to per-row bias format.
100 125 Buffer pack_bias_per_row(
101 DataType src_dtype, DataType bias_dtype, DataType dst_dtype, size_t src_esize, size_t bias_esize, size_t dst_esize,
102 const void* src, const void* bias, size_t height, size_t width, size_t block_height, size_t block_width,
103 size_t subblock_height, size_t subblock_width) {
104 KAI_ASSUME(src_dtype == bias_dtype);
105
106 125 const auto num_groups = (height + block_height - 1) / block_height;
107 125 const auto group_num_blocks = (width + block_width - 1) / block_width;
108 125 const auto group_bias_bytes = block_height * bias_esize;
109 125 const auto block_data_bytes = block_height * block_width * dst_esize;
110 125 const auto group_bytes = group_bias_bytes + group_num_blocks * block_data_bytes;
111 125 const auto dst_bytes = num_groups * group_bytes;
112
113 125 Buffer dst(dst_bytes, 0);
114
115 125 const auto* src_ptr = reinterpret_cast<const uint8_t*>(src);
116 125 const auto* bias_ptr = reinterpret_cast<const uint8_t*>(bias);
117 125 auto* dst_ptr = dst.data();
118
119
2/2
✓ Branch 0 taken 548 times.
✓ Branch 1 taken 125 times.
673 for (size_t y_block = 0; y_block < height; y_block += block_height) {
120 // Packs the bias.
121
1/2
✓ Branch 0 taken 548 times.
✗ Branch 1 not taken.
548 const auto bias_len = std::min(block_height, height - y_block);
122 548 memcpy(dst_ptr, bias_ptr, bias_len * bias_esize);
123 548 bias_ptr += block_height * bias_esize;
124 548 dst_ptr += block_height * bias_esize;
125
126
2/2
✓ Branch 0 taken 16528 times.
✓ Branch 1 taken 548 times.
17076 for (size_t x_block = 0; x_block < width; x_block += block_width) {
127
2/2
✓ Branch 0 taken 16528 times.
✓ Branch 1 taken 16528 times.
33056 for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
128
2/2
✓ Branch 0 taken 16528 times.
✓ Branch 1 taken 20112 times.
36640 for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
129
2/2
✓ Branch 0 taken 20112 times.
✓ Branch 1 taken 806504 times.
826616 for (size_t y_element = 0; y_element < subblock_height; ++y_element) {
130
1/2
✓ Branch 0 taken 806504 times.
✗ Branch 1 not taken.
806504 if (src_dtype == dst_dtype) {
131 806504 const size_t esize = dst_esize;
132
2/2
✓ Branch 0 taken 110098 times.
✓ Branch 1 taken 696406 times.
806504 if (y_block + y_subblock + y_element < height) {
133
1/2
✓ Branch 0 taken 696406 times.
✗ Branch 1 not taken.
696406 const auto len = std::min(subblock_width, width - x_block - x_subblock);
134
135 696406 memcpy(
136 696406 dst_ptr,
137 1392812 src_ptr +
138 696406 ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * esize,
139 696406 len * esize);
140 696406 }
141
142 806504 dst_ptr += subblock_width * esize;
143
0/2
✗ Branch 0 not taken.
✗ Branch 1 not taken.
806504 } else if (dst_esize == 2 /* 16 bits */) {
144 for (size_t x_element = 0; x_element < subblock_width; ++x_element) {
145 if (y_block + y_subblock + y_element < height) {
146 if (x_block + x_subblock + x_element < width) {
147 const uint8_t* src_ptr_elm = src_ptr +
148 ((y_block + y_subblock + y_element) * width + x_block + x_subblock +
149 x_element) *
150 src_esize;
151
152 const BFloat16 dst_value = convert(src_ptr_elm, src_dtype, dst_dtype);
153 memcpy(dst_ptr, &dst_value, dst_esize);
154 }
155 }
156
157 dst_ptr += dst_esize;
158 }
159 }
160 806504 }
161 20112 }
162 16528 }
163 16528 }
164 548 }
165
166 KAI_ASSERT(reinterpret_cast<uintptr_t>(dst_ptr) - reinterpret_cast<uintptr_t>(dst.data()) == dst_bytes);
167
168 125 return dst;
169 125 }
170
171 } // namespace
172
173 241 Buffer pack(
174 const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* bias,
175 const DataFormat& src_format, size_t height, size_t width) {
176 241 const auto dst_dt = dst_format.data_type();
177 241 const auto dst_qf = dst_format.pack_format();
178 241 const auto src_dt = src_format.data_type();
179 241 const auto src_qf = src_format.pack_format();
180
181 241 const auto block_height = dst_format.actual_block_height(height);
182 241 const auto block_width = dst_format.actual_block_width(width);
183 241 const auto subblock_height = dst_format.actual_subblock_height(height);
184 241 const auto subblock_width = dst_format.actual_subblock_width(width);
185
186
3/4
✓ Branch 0 taken 241 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 116 times.
✓ Branch 3 taken 125 times.
241 if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::BIAS_PER_ROW) {
187 KAI_ASSUME(
188 (src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16) ||
189 (src_dt == DataType::FP16 && dst_dt == DataType::BF16));
190
191 125 const auto src_esize = data_type_size_in_bits(src_dt);
192 125 const auto dst_esize = data_type_size_in_bits(dst_dt);
193 125 const auto bias_esize = data_type_size_in_bits(dst_format.zero_point_data_type());
194 125 const auto bias_dt = dst_format.zero_point_data_type();
195
196 KAI_ASSUME(dst_esize % 8 == 0 && bias_esize % 8 == 0 && src_esize % 8 == 0);
197
198 125 return pack_bias_per_row(
199 125 src_dt, bias_dt, dst_dt, src_esize / 8, bias_esize / 8, dst_esize / 8, src, bias, height, width,
200 125 block_height, block_width, subblock_height, subblock_width);
201 125 }
202
203
1/2
✓ Branch 0 taken 116 times.
✗ Branch 1 not taken.
116 if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) {
204 KAI_ASSUME(
205 (src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16) ||
206 (src_dt == DataType::FP16 && dst_dt == DataType::BF16));
207
208 116 const auto dst_esize = data_type_size_in_bits(dst_dt);
209 116 const auto src_esize = data_type_size_in_bits(src_dt);
210
211 KAI_ASSUME(src_esize % 8 == 0 && dst_esize % 8 == 0);
212
213 116 return pack_block(
214 116 src, src_dt, dst_dt, src_esize / 8, dst_esize / 8, height, width, block_height, block_width,
215 116 subblock_height, subblock_width);
216 116 }
217
218 KAI_ERROR("Unsupported operation!");
219 241 }
220
221 template <typename Data, typename Scale>
222 Buffer pack_data_scales(const void* data, const void* scales, size_t height, size_t width, size_t quant_width) {
223 KAI_ASSUME_IF(size_in_bits<Data> < 8, quant_width % (8 / size_in_bits<Data>) == 0);
224 KAI_ASSUME_IF(size_in_bits<Data> < 8, width % (8 / size_in_bits<Data>) == 0);
225
226 const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width;
227
228 const auto data_bytes = height * width * size_in_bits<Data> / 8;
229 const auto scales_bytes = height * num_quant_packets_x * sizeof(Scale);
230
231 Buffer dst(data_bytes + scales_bytes);
232
233 const auto* scales_ptr = reinterpret_cast<const Scale*>(scales);
234 auto* dst_ptr = dst.data();
235
236 for (size_t y = 0; y < height; ++y) {
237 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
238 write_array(dst_ptr, 0, *scales_ptr);
239 dst_ptr += sizeof(Scale);
240 ++scales_ptr;
241
242 const auto len = std::min(x_quant + quant_width, width) - x_quant;
243
244 for (size_t x_element = 0; x_element < len; ++x_element) {
245 const auto x = x_quant + x_element;
246 write_array(dst_ptr, x_element, read_array<Data>(data, y * width + x));
247 }
248
249 dst_ptr += len * size_in_bits<Data> / 8;
250 }
251 }
252
253 KAI_ASSERT(dst_ptr == dst.data() + dst.size());
254
255 return dst;
256 }
257
258 template <typename ZeroPoint, typename Data, typename Scale>
259 927 Buffer pack_zero_points_data_scales_per_block(
260 const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points,
261 size_t block_num_data, size_t block_num_scales) {
262 // Only data is allowed to be sub-byte.
263 KAI_ASSUME(size_in_bits<ZeroPoint> % 8 == 0);
264 KAI_ASSUME(size_in_bits<Scale> % 8 == 0);
265
266 // Checks for memory alignment.
267 KAI_ASSUME(size_in_bits<ZeroPoint> % size_in_bits<Data> == 0);
268 KAI_ASSUME(
269 (block_num_zero_points * size_in_bits<ZeroPoint> + block_num_data * size_in_bits<Data>) % size_in_bits<Scale> ==
270 0);
271 KAI_ASSUME(
272 (block_num_data * size_in_bits<Data> + block_num_scales * size_in_bits<Scale>) % size_in_bits<ZeroPoint> == 0);
273
274 1854 Buffer dst(round_up_division(
275 1854 num_blocks *
276 1854 (block_num_zero_points * size_in_bits<ZeroPoint> + block_num_data * size_in_bits<Data> +
277 927 block_num_scales * size_in_bits<Scale>),
278 8));
279
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 auto* dst_ptr = dst.data();
280
281
2/2
✓ Branch 0 taken 927 times.
✓ Branch 1 taken 3651 times.
4578 for (size_t block_no = 0; block_no < num_blocks; ++block_no) {
282
2/2
✓ Branch 0 taken 116832 times.
✓ Branch 1 taken 3651 times.
120483 for (size_t i = 0; i < block_num_zero_points; ++i) {
283
1/2
✓ Branch 0 taken 116832 times.
✗ Branch 1 not taken.
116832 write_array<ZeroPoint>(
284
1/2
✓ Branch 0 taken 116832 times.
✗ Branch 1 not taken.
116832 dst_ptr, i, read_array<ZeroPoint>(zero_points, block_no * block_num_zero_points + i));
285 116832 }
286 3651 dst_ptr += block_num_zero_points * sizeof(ZeroPoint);
287
288
2/2
✓ Branch 0 taken 19128576 times.
✓ Branch 1 taken 3651 times.
19132227 for (size_t i = 0; i < block_num_data; ++i) {
289
2/4
✓ Branch 0 taken 19128576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 19128576 times.
✗ Branch 3 not taken.
19128576 write_array<Data>(dst_ptr, i, read_array<Data>(data, block_no * block_num_data + i));
290 19128576 }
291
1/2
✓ Branch 0 taken 3651 times.
✗ Branch 1 not taken.
3651 dst_ptr += round_up_division(block_num_data * size_in_bits<Data>, 8);
292
293
2/2
✓ Branch 0 taken 3651 times.
✓ Branch 1 taken 116832 times.
120483 for (size_t i = 0; i < block_num_scales; ++i) {
294
2/4
✓ Branch 0 taken 116832 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 116832 times.
✗ Branch 3 not taken.
116832 write_array<Scale>(dst_ptr, i, read_array<Scale>(scales, block_no * block_num_scales + i));
295 116832 }
296 3651 dst_ptr += block_num_scales * sizeof(Scale);
297 3651 }
298
299 KAI_ASSERT(dst_ptr == dst.data() + dst.size());
300
301 927 return dst;
302 927 }
303
304 template Buffer pack_zero_points_data_scales_per_block<int32_t, int8_t, float>(
305 const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points,
306 size_t block_num_data, size_t block_num_scales);
307
308 template <typename Data, typename Scale>
309 124 Buffer pack_data_scales_interleave_block(
310 const void* data, const void* scales, size_t height, size_t width, size_t quant_width) {
311 KAI_ASSUME_IF(size_in_bits<Data> < 8, quant_width % (8 / size_in_bits<Data>) == 0);
312 KAI_ASSUME_IF(size_in_bits<Data> < 8, width % (8 / size_in_bits<Data>) == 0);
313 KAI_ASSUME(width % quant_width == 0);
314 KAI_ASSUME(quant_width % 2 == 0);
315
316 124 const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width;
317
318 124 const auto data_bytes = height * width * size_in_bits<Data> / 8;
319
1/4
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
124 const auto scales_bytes = scales != nullptr ? height * num_quant_packets_x * sizeof(Scale) : 0;
320
321 124 Buffer dst(data_bytes + scales_bytes);
322
323 124 const auto* scales_ptr = reinterpret_cast<const Scale*>(scales);
324
1/4
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
124 auto* dst_ptr = dst.data();
325
326
2/4
✓ Branch 0 taken 124 times.
✓ Branch 1 taken 5228 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
5352 for (size_t y = 0; y < height; ++y) {
327
2/4
✓ Branch 0 taken 9768 times.
✓ Branch 1 taken 5228 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
14996 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
328
1/4
✗ Branch 0 not taken.
✓ Branch 1 taken 9768 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
9768 if (scales_ptr != nullptr) {
329
1/4
✓ Branch 0 taken 9768 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
9768 write_array(dst_ptr, 0, *scales_ptr);
330 9768 dst_ptr += sizeof(Scale);
331 9768 ++scales_ptr;
332 9768 }
333
334
2/4
✓ Branch 0 taken 9768 times.
✓ Branch 1 taken 312576 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
322344 for (size_t x_element = 0; x_element < quant_width; ++x_element) {
335
2/4
✓ Branch 0 taken 156288 times.
✓ Branch 1 taken 156288 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
312576 const auto x = x_quant + x_element / 2 + (x_element % 2 != 0 ? quant_width / 2 : 0);
336
2/8
✓ Branch 0 taken 312576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 312576 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
312576 write_array(dst_ptr, x_element, read_array<Data>(data, y * width + x));
337 312576 }
338
339 9768 dst_ptr += quant_width * size_in_bits<Data> / 8;
340 9768 }
341 5228 }
342
343 KAI_ASSERT(dst_ptr == dst.data() + dst.size());
344
345 124 return dst;
346 124 }
347
348 template Buffer pack_data_scales_interleave_block<UInt4, Float16>(
349 const void* data, const void* scales, size_t height, size_t width, size_t quant_width);
350 template Buffer pack_data_scales_interleave_block<UInt4, std::nullptr_t>(
351 const void* data, const void* scales, size_t height, size_t width, size_t quant_width);
352
353 template <typename Data, typename ZeroPoint, typename Scale, typename Bias>
354 Buffer pack_block_data_zero_points_scale_bias(
355 const void* data, const void* zero_points, const void* scales, const void* biases, size_t height, size_t width,
356 size_t quant_height, size_t quant_width, size_t block_height, size_t block_width, size_t interleave_x_blocks) {
357 if (quant_width == width) {
358 quant_width = round_up_multiple(quant_width, block_width);
359 }
360
361 KAI_ASSERT(quant_height == block_height);
362 KAI_ASSERT(quant_width % block_width == 0);
363
364 if (interleave_x_blocks == 0) {
365 interleave_x_blocks = quant_width / block_width;
366 }
367
368 const auto has_zero_points = zero_points != nullptr;
369 const auto has_biases = biases != nullptr;
370
371 const auto num_quant_packets_y = round_up_division(height, quant_height);
372 const auto num_quant_packets_x = round_up_division(width, quant_width);
373
374 const auto quant_packet_data_bytes = quant_height * quant_width * size_in_bits<Data> / 8;
375 const auto quant_packet_zero_points_bytes = has_zero_points ? quant_height * sizeof(ZeroPoint) : 0;
376 const auto quant_packet_scales_bytes = quant_height * sizeof(Scale);
377 const auto quant_packet_bytes =
378 quant_packet_zero_points_bytes + quant_packet_data_bytes + quant_packet_scales_bytes;
379
380 const auto num_quant_packets_per_row = round_up_division(width, quant_width);
381 const auto biases_bytes = has_biases ? height * sizeof(Bias) : 0;
382
383 const auto dst_bytes = num_quant_packets_y * num_quant_packets_x * quant_packet_bytes + biases_bytes;
384 Buffer dst(dst_bytes);
385
386 const auto* zero_points_ptr = reinterpret_cast<const ZeroPoint*>(zero_points);
387 const auto* scales_ptr = reinterpret_cast<const Scale*>(scales);
388 const auto* biases_ptr = reinterpret_cast<const Bias*>(biases);
389 auto* dst_ptr = dst.data();
390
391 for (size_t y_quant = 0; y_quant < height; y_quant += quant_height) {
392 for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
393 size_t dst_index = 0;
394
395 // Packs the data.
396 for (size_t y_pack = 0; y_pack < quant_height; y_pack += block_height) {
397 for (size_t x_pack = 0; x_pack < block_width * interleave_x_blocks; x_pack += block_width) {
398 for (size_t y_element = 0; y_element < block_height; ++y_element) {
399 for (size_t x_element = 0; x_element < block_width; ++x_element) {
400 for (size_t x_interleave = 0; x_interleave < quant_width;
401 x_interleave += block_width * interleave_x_blocks) {
402 const auto y = y_quant + y_pack + y_element;
403 const auto x = x_quant + x_pack + x_element + x_interleave;
404
405 if (y < height && x < width) {
406 write_array(dst_ptr, dst_index, read_array<Data>(data, y * width + x));
407 }
408
409 ++dst_index;
410 }
411 }
412 }
413 }
414 }
415
416 dst_ptr += dst_index * size_in_bits<Data> / 8;
417
418 // Packs the zero points.
419 if (has_zero_points) {
420 for (size_t y_element = 0; y_element < quant_height; ++y_element) {
421 const auto y = y_quant + y_element;
422 const auto x = x_quant / quant_width;
423 memcpy(dst_ptr, &zero_points_ptr[y * num_quant_packets_per_row + x], sizeof(ZeroPoint));
424 dst_ptr += sizeof(ZeroPoint);
425 }
426 }
427
428 // Packs the scales.
429 for (size_t y_element = 0; y_element < quant_height; ++y_element) {
430 const auto y = y_quant + y_element;
431 const auto x = x_quant / quant_width;
432 memcpy(dst_ptr, &scales_ptr[y * num_quant_packets_per_row + x], sizeof(Scale));
433 dst_ptr += sizeof(Scale);
434 }
435 }
436
437 // Packs the biases.
438 if (has_biases) {
439 for (size_t y_element = 0; y_element < quant_height; ++y_element) {
440 const auto y = y_quant + y_element;
441 memcpy(dst_ptr, &biases_ptr[y], sizeof(Bias));
442 dst_ptr += sizeof(Bias);
443 }
444 }
445 }
446
447 KAI_ASSERT(dst_ptr == dst.data() + dst.size());
448
449 return dst;
450 }
451
452 } // namespace kai::test
453