Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/reference/pack.hpp" | ||
8 | |||
9 | #include <arm_neon.h> | ||
10 | |||
11 | #include <algorithm> | ||
12 | #include <cstddef> | ||
13 | #include <cstdint> | ||
14 | #include <cstring> | ||
15 | |||
16 | #include "kai/kai_common.h" | ||
17 | #include "test/common/bfloat16.hpp" | ||
18 | #include "test/common/buffer.hpp" | ||
19 | #include "test/common/data_format.hpp" | ||
20 | #include "test/common/data_type.hpp" | ||
21 | #include "test/common/float16.hpp" | ||
22 | #include "test/common/memory.hpp" | ||
23 | #include "test/common/round.hpp" | ||
24 | |||
25 | namespace kai::test { | ||
26 | |||
27 | namespace { | ||
28 | |||
29 | 918733 | BFloat16<> convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) { | |
30 | − | KAI_ASSUME((src_dtype == DataType::FP32 || src_dtype == DataType::FP16) && dst_dtype == DataType::BF16); | |
31 | |||
32 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 536371 times.
✓ Branch 2 taken 382362 times.
|
918733 | switch (src_dtype) { |
33 | case DataType::FP32: | ||
34 | 536371 | return BFloat16<>(*reinterpret_cast<const float*>(src_ptr_elm)); | |
35 | case DataType::FP16: | ||
36 | 382362 | return BFloat16<>(static_cast<float>(*reinterpret_cast<const Float16*>(src_ptr_elm))); | |
37 | default: | ||
38 | − | KAI_ERROR("Unsupported Data Type"); | |
39 | ✗ | } | |
40 | 918733 | } | |
41 | |||
42 | 116 | Buffer pack_block( | |
43 | const void* src, DataType src_dtype, DataType dst_dtype, size_t src_esize, size_t dst_esize, size_t full_height, | ||
44 | size_t full_width, size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) { | ||
45 | 232 | const auto dst_bytes = | |
46 | 116 | round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * dst_esize; | |
47 | |||
48 | 116 | Buffer dst(dst_bytes, 0); | |
49 | |||
50 | 116 | const auto* src_ptr = reinterpret_cast<const uint8_t*>(src); | |
51 | 116 | auto* dst_ptr = dst.data(); | |
52 | |||
53 |
2/2✓ Branch 0 taken 1050 times.
✓ Branch 1 taken 116 times.
|
1166 | for (size_t y_block = 0; y_block < full_height; y_block += block_height) { |
54 |
2/2✓ Branch 0 taken 31220 times.
✓ Branch 1 taken 1050 times.
|
32270 | for (size_t x_block = 0; x_block < full_width; x_block += block_width) { |
55 |
2/2✓ Branch 0 taken 31220 times.
✓ Branch 1 taken 31220 times.
|
62440 | for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { |
56 |
2/2✓ Branch 0 taken 31220 times.
✓ Branch 1 taken 31220 times.
|
62440 | for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { |
57 |
2/2✓ Branch 0 taken 31220 times.
✓ Branch 1 taken 338932 times.
|
370152 | for (size_t y_element = 0; y_element < subblock_height; ++y_element) { |
58 |
2/2✓ Branch 0 taken 26560 times.
✓ Branch 1 taken 312372 times.
|
338932 | if (src_dtype == dst_dtype) { |
59 | 26560 | const size_t esize = dst_esize; | |
60 | |||
61 |
2/2✓ Branch 0 taken 8586 times.
✓ Branch 1 taken 17974 times.
|
26560 | if (y_block + y_subblock + y_element < full_height) { |
62 | 17974 | const size_t y_offset = (y_block + y_subblock + y_element) * full_width; | |
63 | 17974 | const size_t x_offset = x_block + x_subblock; | |
64 | 17974 | const size_t offset = y_offset + x_offset; | |
65 |
1/2✓ Branch 0 taken 17974 times.
✗ Branch 1 not taken.
|
17974 | const auto len = std::min(subblock_width, full_width - x_offset); |
66 | |||
67 | 17974 | memcpy(dst_ptr, src_ptr + offset * esize, len * esize); | |
68 | 17974 | } | |
69 | |||
70 | 26560 | dst_ptr += subblock_width * esize; | |
71 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 312372 times.
|
338932 | } else if (dst_esize == 2 /* 16 bits */) { |
72 |
2/2✓ Branch 0 taken 1027408 times.
✓ Branch 1 taken 312372 times.
|
1339780 | for (size_t x_element = 0; x_element < subblock_width; ++x_element) { |
73 |
2/2✓ Branch 0 taken 89248 times.
✓ Branch 1 taken 938160 times.
|
1027408 | if (y_block + y_subblock + y_element < full_height) { |
74 |
2/2✓ Branch 0 taken 19427 times.
✓ Branch 1 taken 918733 times.
|
938160 | if (x_block + x_subblock + x_element < full_width) { |
75 | 1837466 | const uint8_t* src_ptr_elm = src_ptr + | |
76 | 1837466 | ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock + | |
77 | 1837466 | x_element) * | |
78 | 918733 | src_esize; | |
79 | |||
80 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 918733 times.
|
918733 | const BFloat16 src_value = convert(src_ptr_elm, src_dtype, dst_dtype); |
81 | 918733 | memcpy(dst_ptr, &src_value, dst_esize); | |
82 | 918733 | } | |
83 | 938160 | } | |
84 | |||
85 | 1027408 | dst_ptr += dst_esize; | |
86 | 1027408 | } | |
87 | 312372 | } | |
88 | 338932 | } | |
89 | 31220 | } | |
90 | 31220 | } | |
91 | 31220 | } | |
92 | 1050 | } | |
93 | |||
94 | − | KAI_ASSERT(reinterpret_cast<uintptr_t>(dst_ptr) - reinterpret_cast<uintptr_t>(dst.data()) == dst_bytes); | |
95 | |||
96 | 116 | return dst; | |
97 | 116 | } | |
98 | |||
99 | /// Packs the matrix from raw to per-row bias format. | ||
100 | 125 | Buffer pack_bias_per_row( | |
101 | DataType src_dtype, DataType bias_dtype, DataType dst_dtype, size_t src_esize, size_t bias_esize, size_t dst_esize, | ||
102 | const void* src, const void* bias, size_t height, size_t width, size_t block_height, size_t block_width, | ||
103 | size_t subblock_height, size_t subblock_width) { | ||
104 | − | KAI_ASSUME(src_dtype == bias_dtype); | |
105 | |||
106 | 125 | const auto num_groups = (height + block_height - 1) / block_height; | |
107 | 125 | const auto group_num_blocks = (width + block_width - 1) / block_width; | |
108 | 125 | const auto group_bias_bytes = block_height * bias_esize; | |
109 | 125 | const auto block_data_bytes = block_height * block_width * dst_esize; | |
110 | 125 | const auto group_bytes = group_bias_bytes + group_num_blocks * block_data_bytes; | |
111 | 125 | const auto dst_bytes = num_groups * group_bytes; | |
112 | |||
113 | 125 | Buffer dst(dst_bytes, 0); | |
114 | |||
115 | 125 | const auto* src_ptr = reinterpret_cast<const uint8_t*>(src); | |
116 | 125 | const auto* bias_ptr = reinterpret_cast<const uint8_t*>(bias); | |
117 | 125 | auto* dst_ptr = dst.data(); | |
118 | |||
119 |
2/2✓ Branch 0 taken 548 times.
✓ Branch 1 taken 125 times.
|
673 | for (size_t y_block = 0; y_block < height; y_block += block_height) { |
120 | // Packs the bias. | ||
121 |
1/2✓ Branch 0 taken 548 times.
✗ Branch 1 not taken.
|
548 | const auto bias_len = std::min(block_height, height - y_block); |
122 | 548 | memcpy(dst_ptr, bias_ptr, bias_len * bias_esize); | |
123 | 548 | bias_ptr += block_height * bias_esize; | |
124 | 548 | dst_ptr += block_height * bias_esize; | |
125 | |||
126 |
2/2✓ Branch 0 taken 16528 times.
✓ Branch 1 taken 548 times.
|
17076 | for (size_t x_block = 0; x_block < width; x_block += block_width) { |
127 |
2/2✓ Branch 0 taken 16528 times.
✓ Branch 1 taken 16528 times.
|
33056 | for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { |
128 |
2/2✓ Branch 0 taken 16528 times.
✓ Branch 1 taken 20112 times.
|
36640 | for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { |
129 |
2/2✓ Branch 0 taken 20112 times.
✓ Branch 1 taken 806504 times.
|
826616 | for (size_t y_element = 0; y_element < subblock_height; ++y_element) { |
130 |
1/2✓ Branch 0 taken 806504 times.
✗ Branch 1 not taken.
|
806504 | if (src_dtype == dst_dtype) { |
131 | 806504 | const size_t esize = dst_esize; | |
132 |
2/2✓ Branch 0 taken 110098 times.
✓ Branch 1 taken 696406 times.
|
806504 | if (y_block + y_subblock + y_element < height) { |
133 |
1/2✓ Branch 0 taken 696406 times.
✗ Branch 1 not taken.
|
696406 | const auto len = std::min(subblock_width, width - x_block - x_subblock); |
134 | |||
135 | 696406 | memcpy( | |
136 | 696406 | dst_ptr, | |
137 | 1392812 | src_ptr + | |
138 | 696406 | ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * esize, | |
139 | 696406 | len * esize); | |
140 | 696406 | } | |
141 | |||
142 | 806504 | dst_ptr += subblock_width * esize; | |
143 |
0/2✗ Branch 0 not taken.
✗ Branch 1 not taken.
|
806504 | } else if (dst_esize == 2 /* 16 bits */) { |
144 | ✗ | for (size_t x_element = 0; x_element < subblock_width; ++x_element) { | |
145 | ✗ | if (y_block + y_subblock + y_element < height) { | |
146 | ✗ | if (x_block + x_subblock + x_element < width) { | |
147 | ✗ | const uint8_t* src_ptr_elm = src_ptr + | |
148 | ✗ | ((y_block + y_subblock + y_element) * width + x_block + x_subblock + | |
149 | ✗ | x_element) * | |
150 | ✗ | src_esize; | |
151 | |||
152 | ✗ | const BFloat16 dst_value = convert(src_ptr_elm, src_dtype, dst_dtype); | |
153 | ✗ | memcpy(dst_ptr, &dst_value, dst_esize); | |
154 | ✗ | } | |
155 | ✗ | } | |
156 | |||
157 | ✗ | dst_ptr += dst_esize; | |
158 | ✗ | } | |
159 | ✗ | } | |
160 | 806504 | } | |
161 | 20112 | } | |
162 | 16528 | } | |
163 | 16528 | } | |
164 | 548 | } | |
165 | |||
166 | − | KAI_ASSERT(reinterpret_cast<uintptr_t>(dst_ptr) - reinterpret_cast<uintptr_t>(dst.data()) == dst_bytes); | |
167 | |||
168 | 125 | return dst; | |
169 | 125 | } | |
170 | |||
171 | } // namespace | ||
172 | |||
173 | 241 | Buffer pack( | |
174 | const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* bias, | ||
175 | const DataFormat& src_format, size_t height, size_t width) { | ||
176 | 241 | const auto dst_dt = dst_format.data_type(); | |
177 | 241 | const auto dst_qf = dst_format.pack_format(); | |
178 | 241 | const auto src_dt = src_format.data_type(); | |
179 | 241 | const auto src_qf = src_format.pack_format(); | |
180 | |||
181 | 241 | const auto block_height = dst_format.actual_block_height(height); | |
182 | 241 | const auto block_width = dst_format.actual_block_width(width); | |
183 | 241 | const auto subblock_height = dst_format.actual_subblock_height(height); | |
184 | 241 | const auto subblock_width = dst_format.actual_subblock_width(width); | |
185 | |||
186 |
3/4✓ Branch 0 taken 241 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 116 times.
✓ Branch 3 taken 125 times.
|
241 | if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::BIAS_PER_ROW) { |
187 | − | KAI_ASSUME( | |
188 | (src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16) || | ||
189 | (src_dt == DataType::FP16 && dst_dt == DataType::BF16)); | ||
190 | |||
191 | 125 | const auto src_esize = data_type_size_in_bits(src_dt); | |
192 | 125 | const auto dst_esize = data_type_size_in_bits(dst_dt); | |
193 | 125 | const auto bias_esize = data_type_size_in_bits(dst_format.zero_point_data_type()); | |
194 | 125 | const auto bias_dt = dst_format.zero_point_data_type(); | |
195 | |||
196 | − | KAI_ASSUME(dst_esize % 8 == 0 && bias_esize % 8 == 0 && src_esize % 8 == 0); | |
197 | |||
198 | 125 | return pack_bias_per_row( | |
199 | 125 | src_dt, bias_dt, dst_dt, src_esize / 8, bias_esize / 8, dst_esize / 8, src, bias, height, width, | |
200 | 125 | block_height, block_width, subblock_height, subblock_width); | |
201 | 125 | } | |
202 | |||
203 |
1/2✓ Branch 0 taken 116 times.
✗ Branch 1 not taken.
|
116 | if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) { |
204 | − | KAI_ASSUME( | |
205 | (src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16) || | ||
206 | (src_dt == DataType::FP16 && dst_dt == DataType::BF16)); | ||
207 | |||
208 | 116 | const auto dst_esize = data_type_size_in_bits(dst_dt); | |
209 | 116 | const auto src_esize = data_type_size_in_bits(src_dt); | |
210 | |||
211 | − | KAI_ASSUME(src_esize % 8 == 0 && dst_esize % 8 == 0); | |
212 | |||
213 | 116 | return pack_block( | |
214 | 116 | src, src_dt, dst_dt, src_esize / 8, dst_esize / 8, height, width, block_height, block_width, | |
215 | 116 | subblock_height, subblock_width); | |
216 | 116 | } | |
217 | |||
218 | − | KAI_ERROR("Unsupported operation!"); | |
219 | 241 | } | |
220 | |||
221 | template <typename Data, typename Scale> | ||
222 | Buffer pack_data_scales(const void* data, const void* scales, size_t height, size_t width, size_t quant_width) { | ||
223 | KAI_ASSUME_IF(size_in_bits<Data> < 8, quant_width % (8 / size_in_bits<Data>) == 0); | ||
224 | KAI_ASSUME_IF(size_in_bits<Data> < 8, width % (8 / size_in_bits<Data>) == 0); | ||
225 | |||
226 | const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width; | ||
227 | |||
228 | const auto data_bytes = height * width * size_in_bits<Data> / 8; | ||
229 | const auto scales_bytes = height * num_quant_packets_x * sizeof(Scale); | ||
230 | |||
231 | Buffer dst(data_bytes + scales_bytes); | ||
232 | |||
233 | const auto* scales_ptr = reinterpret_cast<const Scale*>(scales); | ||
234 | auto* dst_ptr = dst.data(); | ||
235 | |||
236 | for (size_t y = 0; y < height; ++y) { | ||
237 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { | ||
238 | write_array(dst_ptr, 0, *scales_ptr); | ||
239 | dst_ptr += sizeof(Scale); | ||
240 | ++scales_ptr; | ||
241 | |||
242 | const auto len = std::min(x_quant + quant_width, width) - x_quant; | ||
243 | |||
244 | for (size_t x_element = 0; x_element < len; ++x_element) { | ||
245 | const auto x = x_quant + x_element; | ||
246 | write_array(dst_ptr, x_element, read_array<Data>(data, y * width + x)); | ||
247 | } | ||
248 | |||
249 | dst_ptr += len * size_in_bits<Data> / 8; | ||
250 | } | ||
251 | } | ||
252 | |||
253 | KAI_ASSERT(dst_ptr == dst.data() + dst.size()); | ||
254 | |||
255 | return dst; | ||
256 | } | ||
257 | |||
258 | template <typename ZeroPoint, typename Data, typename Scale> | ||
259 | 927 | Buffer pack_zero_points_data_scales_per_block( | |
260 | const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points, | ||
261 | size_t block_num_data, size_t block_num_scales) { | ||
262 | // Only data is allowed to be sub-byte. | ||
263 | − | KAI_ASSUME(size_in_bits<ZeroPoint> % 8 == 0); | |
264 | − | KAI_ASSUME(size_in_bits<Scale> % 8 == 0); | |
265 | |||
266 | // Checks for memory alignment. | ||
267 | − | KAI_ASSUME(size_in_bits<ZeroPoint> % size_in_bits<Data> == 0); | |
268 | − | KAI_ASSUME( | |
269 | (block_num_zero_points * size_in_bits<ZeroPoint> + block_num_data * size_in_bits<Data>) % size_in_bits<Scale> == | ||
270 | 0); | ||
271 | − | KAI_ASSUME( | |
272 | (block_num_data * size_in_bits<Data> + block_num_scales * size_in_bits<Scale>) % size_in_bits<ZeroPoint> == 0); | ||
273 | |||
274 | 1854 | Buffer dst(round_up_division( | |
275 | 1854 | num_blocks * | |
276 | 1854 | (block_num_zero_points * size_in_bits<ZeroPoint> + block_num_data * size_in_bits<Data> + | |
277 | 927 | block_num_scales * size_in_bits<Scale>), | |
278 | 8)); | ||
279 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | auto* dst_ptr = dst.data(); |
280 | |||
281 |
2/2✓ Branch 0 taken 927 times.
✓ Branch 1 taken 3651 times.
|
4578 | for (size_t block_no = 0; block_no < num_blocks; ++block_no) { |
282 |
2/2✓ Branch 0 taken 116832 times.
✓ Branch 1 taken 3651 times.
|
120483 | for (size_t i = 0; i < block_num_zero_points; ++i) { |
283 |
1/2✓ Branch 0 taken 116832 times.
✗ Branch 1 not taken.
|
116832 | write_array<ZeroPoint>( |
284 |
1/2✓ Branch 0 taken 116832 times.
✗ Branch 1 not taken.
|
116832 | dst_ptr, i, read_array<ZeroPoint>(zero_points, block_no * block_num_zero_points + i)); |
285 | 116832 | } | |
286 | 3651 | dst_ptr += block_num_zero_points * sizeof(ZeroPoint); | |
287 | |||
288 |
2/2✓ Branch 0 taken 19128576 times.
✓ Branch 1 taken 3651 times.
|
19132227 | for (size_t i = 0; i < block_num_data; ++i) { |
289 |
2/4✓ Branch 0 taken 19128576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 19128576 times.
✗ Branch 3 not taken.
|
19128576 | write_array<Data>(dst_ptr, i, read_array<Data>(data, block_no * block_num_data + i)); |
290 | 19128576 | } | |
291 |
1/2✓ Branch 0 taken 3651 times.
✗ Branch 1 not taken.
|
3651 | dst_ptr += round_up_division(block_num_data * size_in_bits<Data>, 8); |
292 | |||
293 |
2/2✓ Branch 0 taken 3651 times.
✓ Branch 1 taken 116832 times.
|
120483 | for (size_t i = 0; i < block_num_scales; ++i) { |
294 |
2/4✓ Branch 0 taken 116832 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 116832 times.
✗ Branch 3 not taken.
|
116832 | write_array<Scale>(dst_ptr, i, read_array<Scale>(scales, block_no * block_num_scales + i)); |
295 | 116832 | } | |
296 | 3651 | dst_ptr += block_num_scales * sizeof(Scale); | |
297 | 3651 | } | |
298 | |||
299 | − | KAI_ASSERT(dst_ptr == dst.data() + dst.size()); | |
300 | |||
301 | 927 | return dst; | |
302 | 927 | } | |
303 | |||
304 | template Buffer pack_zero_points_data_scales_per_block<int32_t, int8_t, float>( | ||
305 | const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points, | ||
306 | size_t block_num_data, size_t block_num_scales); | ||
307 | |||
308 | template <typename Data, typename Scale> | ||
309 | 124 | Buffer pack_data_scales_interleave_block( | |
310 | const void* data, const void* scales, size_t height, size_t width, size_t quant_width) { | ||
311 | − | KAI_ASSUME_IF(size_in_bits<Data> < 8, quant_width % (8 / size_in_bits<Data>) == 0); | |
312 | − | KAI_ASSUME_IF(size_in_bits<Data> < 8, width % (8 / size_in_bits<Data>) == 0); | |
313 | − | KAI_ASSUME(width % quant_width == 0); | |
314 | − | KAI_ASSUME(quant_width % 2 == 0); | |
315 | |||
316 | 124 | const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width; | |
317 | |||
318 | 124 | const auto data_bytes = height * width * size_in_bits<Data> / 8; | |
319 |
1/4✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
124 | const auto scales_bytes = scales != nullptr ? height * num_quant_packets_x * sizeof(Scale) : 0; |
320 | |||
321 | 124 | Buffer dst(data_bytes + scales_bytes); | |
322 | |||
323 | 124 | const auto* scales_ptr = reinterpret_cast<const Scale*>(scales); | |
324 |
1/4✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
124 | auto* dst_ptr = dst.data(); |
325 | |||
326 |
2/4✓ Branch 0 taken 124 times.
✓ Branch 1 taken 5228 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
5352 | for (size_t y = 0; y < height; ++y) { |
327 |
2/4✓ Branch 0 taken 9768 times.
✓ Branch 1 taken 5228 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
14996 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { |
328 |
1/4✗ Branch 0 not taken.
✓ Branch 1 taken 9768 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
9768 | if (scales_ptr != nullptr) { |
329 |
1/4✓ Branch 0 taken 9768 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
9768 | write_array(dst_ptr, 0, *scales_ptr); |
330 | 9768 | dst_ptr += sizeof(Scale); | |
331 | 9768 | ++scales_ptr; | |
332 | 9768 | } | |
333 | |||
334 |
2/4✓ Branch 0 taken 9768 times.
✓ Branch 1 taken 312576 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
322344 | for (size_t x_element = 0; x_element < quant_width; ++x_element) { |
335 |
2/4✓ Branch 0 taken 156288 times.
✓ Branch 1 taken 156288 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
312576 | const auto x = x_quant + x_element / 2 + (x_element % 2 != 0 ? quant_width / 2 : 0); |
336 |
2/8✓ Branch 0 taken 312576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 312576 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
312576 | write_array(dst_ptr, x_element, read_array<Data>(data, y * width + x)); |
337 | 312576 | } | |
338 | |||
339 | 9768 | dst_ptr += quant_width * size_in_bits<Data> / 8; | |
340 | 9768 | } | |
341 | 5228 | } | |
342 | |||
343 | − | KAI_ASSERT(dst_ptr == dst.data() + dst.size()); | |
344 | |||
345 | 124 | return dst; | |
346 | 124 | } | |
347 | |||
348 | template Buffer pack_data_scales_interleave_block<UInt4, Float16>( | ||
349 | const void* data, const void* scales, size_t height, size_t width, size_t quant_width); | ||
350 | template Buffer pack_data_scales_interleave_block<UInt4, std::nullptr_t>( | ||
351 | const void* data, const void* scales, size_t height, size_t width, size_t quant_width); | ||
352 | |||
353 | template <typename Data, typename ZeroPoint, typename Scale, typename Bias> | ||
354 | Buffer pack_block_data_zero_points_scale_bias( | ||
355 | const void* data, const void* zero_points, const void* scales, const void* biases, size_t height, size_t width, | ||
356 | size_t quant_height, size_t quant_width, size_t block_height, size_t block_width, size_t interleave_x_blocks) { | ||
357 | if (quant_width == width) { | ||
358 | quant_width = round_up_multiple(quant_width, block_width); | ||
359 | } | ||
360 | |||
361 | KAI_ASSERT(quant_height == block_height); | ||
362 | KAI_ASSERT(quant_width % block_width == 0); | ||
363 | |||
364 | if (interleave_x_blocks == 0) { | ||
365 | interleave_x_blocks = quant_width / block_width; | ||
366 | } | ||
367 | |||
368 | const auto has_zero_points = zero_points != nullptr; | ||
369 | const auto has_biases = biases != nullptr; | ||
370 | |||
371 | const auto num_quant_packets_y = round_up_division(height, quant_height); | ||
372 | const auto num_quant_packets_x = round_up_division(width, quant_width); | ||
373 | |||
374 | const auto quant_packet_data_bytes = quant_height * quant_width * size_in_bits<Data> / 8; | ||
375 | const auto quant_packet_zero_points_bytes = has_zero_points ? quant_height * sizeof(ZeroPoint) : 0; | ||
376 | const auto quant_packet_scales_bytes = quant_height * sizeof(Scale); | ||
377 | const auto quant_packet_bytes = | ||
378 | quant_packet_zero_points_bytes + quant_packet_data_bytes + quant_packet_scales_bytes; | ||
379 | |||
380 | const auto num_quant_packets_per_row = round_up_division(width, quant_width); | ||
381 | const auto biases_bytes = has_biases ? height * sizeof(Bias) : 0; | ||
382 | |||
383 | const auto dst_bytes = num_quant_packets_y * num_quant_packets_x * quant_packet_bytes + biases_bytes; | ||
384 | Buffer dst(dst_bytes); | ||
385 | |||
386 | const auto* zero_points_ptr = reinterpret_cast<const ZeroPoint*>(zero_points); | ||
387 | const auto* scales_ptr = reinterpret_cast<const Scale*>(scales); | ||
388 | const auto* biases_ptr = reinterpret_cast<const Bias*>(biases); | ||
389 | auto* dst_ptr = dst.data(); | ||
390 | |||
391 | for (size_t y_quant = 0; y_quant < height; y_quant += quant_height) { | ||
392 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { | ||
393 | size_t dst_index = 0; | ||
394 | |||
395 | // Packs the data. | ||
396 | for (size_t y_pack = 0; y_pack < quant_height; y_pack += block_height) { | ||
397 | for (size_t x_pack = 0; x_pack < block_width * interleave_x_blocks; x_pack += block_width) { | ||
398 | for (size_t y_element = 0; y_element < block_height; ++y_element) { | ||
399 | for (size_t x_element = 0; x_element < block_width; ++x_element) { | ||
400 | for (size_t x_interleave = 0; x_interleave < quant_width; | ||
401 | x_interleave += block_width * interleave_x_blocks) { | ||
402 | const auto y = y_quant + y_pack + y_element; | ||
403 | const auto x = x_quant + x_pack + x_element + x_interleave; | ||
404 | |||
405 | if (y < height && x < width) { | ||
406 | write_array(dst_ptr, dst_index, read_array<Data>(data, y * width + x)); | ||
407 | } | ||
408 | |||
409 | ++dst_index; | ||
410 | } | ||
411 | } | ||
412 | } | ||
413 | } | ||
414 | } | ||
415 | |||
416 | dst_ptr += dst_index * size_in_bits<Data> / 8; | ||
417 | |||
418 | // Packs the zero points. | ||
419 | if (has_zero_points) { | ||
420 | for (size_t y_element = 0; y_element < quant_height; ++y_element) { | ||
421 | const auto y = y_quant + y_element; | ||
422 | const auto x = x_quant / quant_width; | ||
423 | memcpy(dst_ptr, &zero_points_ptr[y * num_quant_packets_per_row + x], sizeof(ZeroPoint)); | ||
424 | dst_ptr += sizeof(ZeroPoint); | ||
425 | } | ||
426 | } | ||
427 | |||
428 | // Packs the scales. | ||
429 | for (size_t y_element = 0; y_element < quant_height; ++y_element) { | ||
430 | const auto y = y_quant + y_element; | ||
431 | const auto x = x_quant / quant_width; | ||
432 | memcpy(dst_ptr, &scales_ptr[y * num_quant_packets_per_row + x], sizeof(Scale)); | ||
433 | dst_ptr += sizeof(Scale); | ||
434 | } | ||
435 | } | ||
436 | |||
437 | // Packs the biases. | ||
438 | if (has_biases) { | ||
439 | for (size_t y_element = 0; y_element < quant_height; ++y_element) { | ||
440 | const auto y = y_quant + y_element; | ||
441 | memcpy(dst_ptr, &biases_ptr[y], sizeof(Bias)); | ||
442 | dst_ptr += sizeof(Bias); | ||
443 | } | ||
444 | } | ||
445 | } | ||
446 | |||
447 | KAI_ASSERT(dst_ptr == dst.data() + dst.size()); | ||
448 | |||
449 | return dst; | ||
450 | } | ||
451 | |||
452 | } // namespace kai::test | ||
453 |