test/reference/pack.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/reference/pack.hpp" | ||
| 8 | |||
| 9 | #include <arm_neon.h> | ||
| 10 | |||
| 11 | #include <algorithm> | ||
| 12 | #include <cstddef> | ||
| 13 | #include <cstdint> | ||
| 14 | #include <cstring> | ||
| 15 | |||
| 16 | #include "kai/kai_common.h" | ||
| 17 | #include "test/common/bfloat16.hpp" | ||
| 18 | #include "test/common/buffer.hpp" | ||
| 19 | #include "test/common/data_format.hpp" | ||
| 20 | #include "test/common/data_type.hpp" | ||
| 21 | #include "test/common/float16.hpp" | ||
| 22 | #include "test/common/memory.hpp" | ||
| 23 | #include "test/common/round.hpp" | ||
| 24 | |||
| 25 | namespace kai::test { | ||
| 26 | |||
| 27 | namespace { | ||
| 28 | |||
| 29 | 4984005 | BFloat16<> convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) { | |
| 30 | − | KAI_ASSUME_ALWAYS((src_dtype == DataType::FP32 || src_dtype == DataType::FP16) && dst_dtype == DataType::BF16); | |
| 31 | |||
| 32 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 2689833 times.
✓ Branch 2 taken 2294172 times.
|
4984005 | switch (src_dtype) { |
| 33 | case DataType::FP32: | ||
| 34 | 2689833 | return BFloat16<>(*reinterpret_cast<const float*>(src_ptr_elm)); | |
| 35 | case DataType::FP16: | ||
| 36 | 2294172 | return BFloat16<>(static_cast<float>(*reinterpret_cast<const Float16*>(src_ptr_elm))); | |
| 37 | default: | ||
| 38 | − | KAI_ERROR("Unsupported Data Type"); | |
| 39 | ✗ | } | |
| 40 | 4984005 | } | |
| 41 | |||
| 42 | 594 | Buffer pack_block( | |
| 43 | const void* src, DataType src_dtype, DataType dst_dtype, size_t src_esize, size_t dst_esize, size_t full_height, | ||
| 44 | size_t full_width, size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) { | ||
| 45 | 1188 | const auto dst_bytes = | |
| 46 | 594 | round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * dst_esize; | |
| 47 | |||
| 48 | 594 | Buffer dst(dst_bytes, 0); | |
| 49 | |||
| 50 | 594 | const auto* src_ptr = reinterpret_cast<const uint8_t*>(src); | |
| 51 | 594 | auto* dst_ptr = dst.data(); | |
| 52 | |||
| 53 |
2/2✓ Branch 0 taken 6030 times.
✓ Branch 1 taken 594 times.
|
6624 | for (size_t y_block = 0; y_block < full_height; y_block += block_height) { |
| 54 |
2/2✓ Branch 0 taken 96540 times.
✓ Branch 1 taken 83910 times.
|
180450 | for (size_t x_block = 0; x_block < full_width; x_block += block_width) { |
| 55 |
2/2✓ Branch 0 taken 174420 times.
✓ Branch 1 taken 174420 times.
|
348840 | for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { |
| 56 |
2/2✓ Branch 0 taken 174420 times.
✓ Branch 1 taken 174420 times.
|
348840 | for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { |
| 57 |
2/2✓ Branch 0 taken 697656 times.
✓ Branch 1 taken 1097556 times.
|
1795212 | for (size_t y_element = 0; y_element < subblock_height; ++y_element) { |
| 58 |
2/2✓ Branch 0 taken 79680 times.
✓ Branch 1 taken 1541112 times.
|
1620792 | if (src_dtype == dst_dtype) { |
| 59 | 79680 | const size_t esize = dst_esize; | |
| 60 | |||
| 61 |
2/2✓ Branch 0 taken 25758 times.
✓ Branch 1 taken 53922 times.
|
79680 | if (y_block + y_subblock + y_element < full_height) { |
| 62 | 53922 | const size_t y_offset = (y_block + y_subblock + y_element) * full_width; | |
| 63 | 53922 | const size_t x_offset = x_block + x_subblock; | |
| 64 | 53922 | const size_t offset = y_offset + x_offset; | |
| 65 |
1/2✓ Branch 0 taken 53922 times.
✗ Branch 1 not taken.
|
53922 | const auto len = std::min(subblock_width, full_width - x_offset); |
| 66 | |||
| 67 | 53922 | memcpy(dst_ptr, src_ptr + offset * esize, len * esize); | |
| 68 | 53922 | } | |
| 69 | |||
| 70 | 79680 | dst_ptr += subblock_width * esize; | |
| 71 |
2/4✗ Branch 0 not taken.
✗ Branch 0 not taken.
✓ Branch 1 taken 603996 times.
✓ Branch 1 taken 937116 times.
|
1620792 | } else if (dst_esize == 2 /* 16 bits */) { |
| 72 |
2/2✓ Branch 0 taken 5498208 times.
✓ Branch 1 taken 1541112 times.
|
7039320 | for (size_t x_element = 0; x_element < subblock_width; ++x_element) { |
| 73 |
2/2✓ Branch 0 taken 401784 times.
✓ Branch 1 taken 5096424 times.
|
5498208 | if (y_block + y_subblock + y_element < full_height) { |
| 74 |
2/2✓ Branch 0 taken 112419 times.
✓ Branch 1 taken 4984005 times.
|
5096424 | if (x_block + x_subblock + x_element < full_width) { |
| 75 | 9968010 | const uint8_t* src_ptr_elm = src_ptr + | |
| 76 | 9968010 | ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock + | |
| 77 | 9968010 | x_element) * | |
| 78 | 4984005 | src_esize; | |
| 79 | |||
| 80 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4984005 times.
|
4984005 | const BFloat16 src_value = convert(src_ptr_elm, src_dtype, dst_dtype); |
| 81 | 4984005 | memcpy(dst_ptr, &src_value, dst_esize); | |
| 82 | 4984005 | } | |
| 83 | 5096424 | } | |
| 84 | |||
| 85 | 5498208 | dst_ptr += dst_esize; | |
| 86 | 5498208 | } | |
| 87 | 1541112 | } | |
| 88 | 1620792 | } | |
| 89 | 174420 | } | |
| 90 | 174420 | } | |
| 91 | 174420 | } | |
| 92 | 6030 | } | |
| 93 | |||
| 94 | − | KAI_ASSERT_ALWAYS(reinterpret_cast<uintptr_t>(dst_ptr) - reinterpret_cast<uintptr_t>(dst.data()) == dst_bytes); | |
| 95 | |||
| 96 | 594 | return dst; | |
| 97 | 594 | } | |
| 98 | |||
| 99 | /// Packs the matrix from raw to per-row bias format. | ||
| 100 | 663 | Buffer pack_bias_per_row( | |
| 101 | DataType src_dtype, DataType bias_dtype, DataType dst_dtype, size_t src_esize, size_t bias_esize, size_t dst_esize, | ||
| 102 | const void* src, const void* bias, size_t height, size_t width, size_t block_height, size_t block_width, | ||
| 103 | size_t subblock_height, size_t subblock_width) { | ||
| 104 | − | KAI_ASSUME_ALWAYS(src_dtype == bias_dtype); | |
| 105 | |||
| 106 | 663 | const auto num_groups = (height + block_height - 1) / block_height; | |
| 107 | 663 | const auto group_num_blocks = (width + block_width - 1) / block_width; | |
| 108 | 663 | const auto group_bias_bytes = block_height * bias_esize; | |
| 109 | 663 | const auto block_data_bytes = block_height * block_width * dst_esize; | |
| 110 | 663 | const auto group_bytes = group_bias_bytes + group_num_blocks * block_data_bytes; | |
| 111 | 663 | const auto dst_bytes = num_groups * group_bytes; | |
| 112 | |||
| 113 | 663 | Buffer dst(dst_bytes, 0); | |
| 114 | |||
| 115 | 663 | const auto* src_ptr = reinterpret_cast<const uint8_t*>(src); | |
| 116 | 663 | const auto* bias_ptr = reinterpret_cast<const uint8_t*>(bias); | |
| 117 | 663 | auto* dst_ptr = dst.data(); | |
| 118 | |||
| 119 |
2/2✓ Branch 0 taken 2247 times.
✓ Branch 1 taken 663 times.
|
2910 | for (size_t y_block = 0; y_block < height; y_block += block_height) { |
| 120 | // Packs the bias. | ||
| 121 |
1/2✓ Branch 0 taken 2247 times.
✗ Branch 1 not taken.
|
2247 | const auto bias_len = std::min(block_height, height - y_block); |
| 122 | 2247 | memcpy(dst_ptr, bias_ptr, bias_len * bias_esize); | |
| 123 | 2247 | bias_ptr += block_height * bias_esize; | |
| 124 | 2247 | dst_ptr += block_height * bias_esize; | |
| 125 | |||
| 126 |
2/2✓ Branch 0 taken 50187 times.
✓ Branch 1 taken 2247 times.
|
52434 | for (size_t x_block = 0; x_block < width; x_block += block_width) { |
| 127 |
2/2✓ Branch 0 taken 50187 times.
✓ Branch 1 taken 50187 times.
|
100374 | for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { |
| 128 |
2/2✓ Branch 0 taken 50187 times.
✓ Branch 1 taken 81696 times.
|
131883 | for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { |
| 129 |
2/2✓ Branch 0 taken 81696 times.
✓ Branch 1 taken 2838432 times.
|
2920128 | for (size_t y_element = 0; y_element < subblock_height; ++y_element) { |
| 130 |
1/2✓ Branch 0 taken 2838432 times.
✗ Branch 1 not taken.
|
2838432 | if (src_dtype == dst_dtype) { |
| 131 | 2838432 | const size_t esize = dst_esize; | |
| 132 |
2/2✓ Branch 0 taken 398862 times.
✓ Branch 1 taken 2439570 times.
|
2838432 | if (y_block + y_subblock + y_element < height) { |
| 133 |
1/2✓ Branch 0 taken 2439570 times.
✗ Branch 1 not taken.
|
2439570 | const auto len = std::min(subblock_width, width - x_block - x_subblock); |
| 134 | |||
| 135 | 2439570 | memcpy( | |
| 136 | 2439570 | dst_ptr, | |
| 137 | 4879140 | src_ptr + | |
| 138 | 2439570 | ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * esize, | |
| 139 | 2439570 | len * esize); | |
| 140 | 2439570 | } | |
| 141 | |||
| 142 | 2838432 | dst_ptr += subblock_width * esize; | |
| 143 |
0/2✗ Branch 0 not taken.
✗ Branch 1 not taken.
|
2838432 | } else if (dst_esize == 2 /* 16 bits */) { |
| 144 | ✗ | for (size_t x_element = 0; x_element < subblock_width; ++x_element) { | |
| 145 | ✗ | if (y_block + y_subblock + y_element < height) { | |
| 146 | ✗ | if (x_block + x_subblock + x_element < width) { | |
| 147 | ✗ | const uint8_t* src_ptr_elm = src_ptr + | |
| 148 | ✗ | ((y_block + y_subblock + y_element) * width + x_block + x_subblock + | |
| 149 | ✗ | x_element) * | |
| 150 | ✗ | src_esize; | |
| 151 | |||
| 152 | ✗ | const BFloat16 dst_value = convert(src_ptr_elm, src_dtype, dst_dtype); | |
| 153 | ✗ | memcpy(dst_ptr, &dst_value, dst_esize); | |
| 154 | ✗ | } | |
| 155 | ✗ | } | |
| 156 | |||
| 157 | ✗ | dst_ptr += dst_esize; | |
| 158 | ✗ | } | |
| 159 | ✗ | } | |
| 160 | 2838432 | } | |
| 161 | 81696 | } | |
| 162 | 50187 | } | |
| 163 | 50187 | } | |
| 164 | 2247 | } | |
| 165 | |||
| 166 | − | KAI_ASSERT_ALWAYS(reinterpret_cast<uintptr_t>(dst_ptr) - reinterpret_cast<uintptr_t>(dst.data()) == dst_bytes); | |
| 167 | |||
| 168 | 663 | return dst; | |
| 169 | 663 | } | |
| 170 | |||
| 171 | } // namespace | ||
| 172 | |||
| 173 | 1257 | Buffer pack( | |
| 174 | const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* bias, | ||
| 175 | const DataFormat& src_format, size_t height, size_t width) { | ||
| 176 | 1257 | const auto dst_dt = dst_format.data_type(); | |
| 177 | 1257 | const auto dst_qf = dst_format.pack_format(); | |
| 178 | 1257 | const auto src_dt = src_format.data_type(); | |
| 179 | 1257 | const auto src_qf = src_format.pack_format(); | |
| 180 | |||
| 181 | 1257 | const auto block_height = dst_format.actual_block_height(height); | |
| 182 | 1257 | const auto block_width = dst_format.actual_block_width(width); | |
| 183 | 1257 | const auto subblock_height = dst_format.actual_subblock_height(height); | |
| 184 | 1257 | const auto subblock_width = dst_format.actual_subblock_width(width); | |
| 185 | |||
| 186 |
3/4✓ Branch 0 taken 1257 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✓ Branch 3 taken 663 times.
|
1257 | if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::BIAS_PER_ROW) { |
| 187 | − | KAI_ASSUME_ALWAYS( | |
| 188 | (src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16) || | ||
| 189 | (src_dt == DataType::FP16 && dst_dt == DataType::BF16)); | ||
| 190 | |||
| 191 | 663 | const auto src_esize = data_type_size_in_bits(src_dt); | |
| 192 | 663 | const auto dst_esize = data_type_size_in_bits(dst_dt); | |
| 193 | 663 | const auto bias_esize = data_type_size_in_bits(dst_format.zero_point_data_type()); | |
| 194 | 663 | const auto bias_dt = dst_format.zero_point_data_type(); | |
| 195 | |||
| 196 | − | KAI_ASSUME_ALWAYS(dst_esize % 8 == 0 && bias_esize % 8 == 0 && src_esize % 8 == 0); | |
| 197 | |||
| 198 | 663 | return pack_bias_per_row( | |
| 199 | 663 | src_dt, bias_dt, dst_dt, src_esize / 8, bias_esize / 8, dst_esize / 8, src, bias, height, width, | |
| 200 | 663 | block_height, block_width, subblock_height, subblock_width); | |
| 201 | 663 | } | |
| 202 | |||
| 203 |
1/2✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
|
594 | if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) { |
| 204 | − | KAI_ASSUME_ALWAYS( | |
| 205 | (src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16) || | ||
| 206 | (src_dt == DataType::FP16 && dst_dt == DataType::BF16)); | ||
| 207 | |||
| 208 | 594 | const auto dst_esize = data_type_size_in_bits(dst_dt); | |
| 209 | 594 | const auto src_esize = data_type_size_in_bits(src_dt); | |
| 210 | |||
| 211 | − | KAI_ASSUME_ALWAYS(src_esize % 8 == 0 && dst_esize % 8 == 0); | |
| 212 | |||
| 213 | 594 | return pack_block( | |
| 214 | 594 | src, src_dt, dst_dt, src_esize / 8, dst_esize / 8, height, width, block_height, block_width, | |
| 215 | 594 | subblock_height, subblock_width); | |
| 216 | 594 | } | |
| 217 | |||
| 218 | − | KAI_ERROR("Unsupported operation!"); | |
| 219 | 1257 | } | |
| 220 | |||
| 221 | template <typename Data, typename Scale> | ||
| 222 | Buffer pack_data_scales(const void* data, const void* scales, size_t height, size_t width, size_t quant_width) { | ||
| 223 | KAI_ASSUME_ALWAYS_IF(size_in_bits<Data> < 8, quant_width % (8 / size_in_bits<Data>) == 0); | ||
| 224 | KAI_ASSUME_ALWAYS_IF(size_in_bits<Data> < 8, width % (8 / size_in_bits<Data>) == 0); | ||
| 225 | |||
| 226 | const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width; | ||
| 227 | |||
| 228 | const auto data_bytes = height * width * size_in_bits<Data> / 8; | ||
| 229 | const auto scales_bytes = height * num_quant_packets_x * sizeof(Scale); | ||
| 230 | |||
| 231 | Buffer dst(data_bytes + scales_bytes); | ||
| 232 | |||
| 233 | const auto* scales_ptr = reinterpret_cast<const Scale*>(scales); | ||
| 234 | auto* dst_ptr = dst.data(); | ||
| 235 | |||
| 236 | for (size_t y = 0; y < height; ++y) { | ||
| 237 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { | ||
| 238 | write_array(dst_ptr, 0, *scales_ptr); | ||
| 239 | dst_ptr += sizeof(Scale); | ||
| 240 | ++scales_ptr; | ||
| 241 | |||
| 242 | const auto len = std::min(x_quant + quant_width, width) - x_quant; | ||
| 243 | |||
| 244 | for (size_t x_element = 0; x_element < len; ++x_element) { | ||
| 245 | const auto x = x_quant + x_element; | ||
| 246 | write_array(dst_ptr, x_element, read_array<Data>(data, y * width + x)); | ||
| 247 | } | ||
| 248 | |||
| 249 | dst_ptr += len * size_in_bits<Data> / 8; | ||
| 250 | } | ||
| 251 | } | ||
| 252 | |||
| 253 | KAI_ASSERT_ALWAYS(dst_ptr == dst.data() + dst.size()); | ||
| 254 | |||
| 255 | return dst; | ||
| 256 | } | ||
| 257 | |||
| 258 | template <typename ZeroPoint, typename Data, typename Scale> | ||
| 259 | 521 | Buffer pack_zero_points_data_scales_per_block( | |
| 260 | const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points, | ||
| 261 | size_t block_num_data, size_t block_num_scales) { | ||
| 262 | // Only data is allowed to be sub-byte. | ||
| 263 | − | KAI_ASSUME_ALWAYS(size_in_bits<ZeroPoint> % 8 == 0); | |
| 264 | − | KAI_ASSUME_ALWAYS(size_in_bits<Scale> % 8 == 0); | |
| 265 | |||
| 266 | // Checks for memory alignment. | ||
| 267 | − | KAI_ASSUME_ALWAYS(size_in_bits<ZeroPoint> % size_in_bits<Data> == 0); | |
| 268 | − | KAI_ASSUME_ALWAYS( | |
| 269 | (block_num_zero_points * size_in_bits<ZeroPoint> + block_num_data * size_in_bits<Data>) % size_in_bits<Scale> == | ||
| 270 | 0); | ||
| 271 | − | KAI_ASSUME_ALWAYS( | |
| 272 | (block_num_data * size_in_bits<Data> + block_num_scales * size_in_bits<Scale>) % size_in_bits<ZeroPoint> == 0); | ||
| 273 | |||
| 274 | 1042 | Buffer dst(round_up_division( | |
| 275 | 1042 | num_blocks * | |
| 276 | 1042 | (block_num_zero_points * size_in_bits<ZeroPoint> + block_num_data * size_in_bits<Data> + | |
| 277 | 521 | block_num_scales * size_in_bits<Scale>), | |
| 278 | 8)); | ||
| 279 |
1/2✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
|
521 | auto* dst_ptr = dst.data(); |
| 280 | |||
| 281 |
2/2✓ Branch 0 taken 521 times.
✓ Branch 1 taken 2101 times.
|
2622 | for (size_t block_no = 0; block_no < num_blocks; ++block_no) { |
| 282 |
2/2✓ Branch 0 taken 67232 times.
✓ Branch 1 taken 2101 times.
|
69333 | for (size_t i = 0; i < block_num_zero_points; ++i) { |
| 283 |
1/2✓ Branch 0 taken 67232 times.
✗ Branch 1 not taken.
|
67232 | write_array<ZeroPoint>( |
| 284 |
1/2✓ Branch 0 taken 67232 times.
✗ Branch 1 not taken.
|
67232 | dst_ptr, i, read_array<ZeroPoint>(zero_points, block_no * block_num_zero_points + i)); |
| 285 | 67232 | } | |
| 286 | 2101 | dst_ptr += block_num_zero_points * sizeof(ZeroPoint); | |
| 287 | |||
| 288 |
2/2✓ Branch 0 taken 7001728 times.
✓ Branch 1 taken 2101 times.
|
7003829 | for (size_t i = 0; i < block_num_data; ++i) { |
| 289 |
2/4✓ Branch 0 taken 7001728 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 7001728 times.
✗ Branch 3 not taken.
|
7001728 | write_array<Data>(dst_ptr, i, read_array<Data>(data, block_no * block_num_data + i)); |
| 290 | 7001728 | } | |
| 291 |
1/2✓ Branch 0 taken 2101 times.
✗ Branch 1 not taken.
|
2101 | dst_ptr += round_up_division(block_num_data * size_in_bits<Data>, 8); |
| 292 | |||
| 293 |
2/2✓ Branch 0 taken 2101 times.
✓ Branch 1 taken 67232 times.
|
69333 | for (size_t i = 0; i < block_num_scales; ++i) { |
| 294 |
2/4✓ Branch 0 taken 67232 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 67232 times.
✗ Branch 3 not taken.
|
67232 | write_array<Scale>(dst_ptr, i, read_array<Scale>(scales, block_no * block_num_scales + i)); |
| 295 | 67232 | } | |
| 296 | 2101 | dst_ptr += block_num_scales * sizeof(Scale); | |
| 297 | 2101 | } | |
| 298 | |||
| 299 | − | KAI_ASSERT_ALWAYS(dst_ptr == dst.data() + dst.size()); | |
| 300 | |||
| 301 | 521 | return dst; | |
| 302 | 521 | } | |
| 303 | |||
| 304 | template Buffer pack_zero_points_data_scales_per_block<int32_t, int8_t, float>( | ||
| 305 | const void* zero_points, const void* data, const void* scales, size_t num_blocks, size_t block_num_zero_points, | ||
| 306 | size_t block_num_data, size_t block_num_scales); | ||
| 307 | |||
| 308 | template <typename Data, typename Scale> | ||
| 309 | 920 | Buffer pack_data_scales_interleave_block( | |
| 310 | const void* data, const void* scales, size_t height, size_t width, size_t quant_width) { | ||
| 311 | − | KAI_ASSUME_ALWAYS_IF(size_in_bits<Data> < 8, quant_width % (8 / size_in_bits<Data>) == 0); | |
| 312 | − | KAI_ASSUME_ALWAYS_IF(size_in_bits<Data> < 8, width % (8 / size_in_bits<Data>) == 0); | |
| 313 | − | KAI_ASSUME_ALWAYS(width % quant_width == 0); | |
| 314 | − | KAI_ASSUME_ALWAYS(quant_width % 2 == 0); | |
| 315 | |||
| 316 | 920 | const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width; | |
| 317 | |||
| 318 | 920 | const auto data_bytes = height * width * size_in_bits<Data> / 8; | |
| 319 |
1/4✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
920 | const auto scales_bytes = scales != nullptr ? height * num_quant_packets_x * sizeof(Scale) : 0; |
| 320 | |||
| 321 | 920 | Buffer dst(data_bytes + scales_bytes); | |
| 322 | |||
| 323 | 920 | const auto* scales_ptr = reinterpret_cast<const Scale*>(scales); | |
| 324 |
1/4✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
920 | auto* dst_ptr = dst.data(); |
| 325 | |||
| 326 |
2/4✓ Branch 0 taken 920 times.
✓ Branch 1 taken 36460 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
37380 | for (size_t y = 0; y < height; ++y) { |
| 327 |
2/4✓ Branch 0 taken 61920 times.
✓ Branch 1 taken 36460 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
98380 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { |
| 328 |
1/4✗ Branch 0 not taken.
✓ Branch 1 taken 61920 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
61920 | if (scales_ptr != nullptr) { |
| 329 |
1/4✓ Branch 0 taken 61920 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
61920 | write_array(dst_ptr, 0, *scales_ptr); |
| 330 | 61920 | dst_ptr += sizeof(Scale); | |
| 331 | 61920 | ++scales_ptr; | |
| 332 | 61920 | } | |
| 333 | |||
| 334 |
2/4✓ Branch 0 taken 61920 times.
✓ Branch 1 taken 1981440 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
2043360 | for (size_t x_element = 0; x_element < quant_width; ++x_element) { |
| 335 |
2/4✓ Branch 0 taken 990720 times.
✓ Branch 1 taken 990720 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
1981440 | const auto x = x_quant + x_element / 2 + (x_element % 2 != 0 ? quant_width / 2 : 0); |
| 336 |
2/8✓ Branch 0 taken 1981440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1981440 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
1981440 | write_array(dst_ptr, x_element, read_array<Data>(data, y * width + x)); |
| 337 | 1981440 | } | |
| 338 | |||
| 339 | 61920 | dst_ptr += quant_width * size_in_bits<Data> / 8; | |
| 340 | 61920 | } | |
| 341 | 36460 | } | |
| 342 | |||
| 343 | − | KAI_ASSERT_ALWAYS(dst_ptr == dst.data() + dst.size()); | |
| 344 | |||
| 345 | 920 | return dst; | |
| 346 | 920 | } | |
| 347 | |||
| 348 | template Buffer pack_data_scales_interleave_block<UInt4, Float16>( | ||
| 349 | const void* data, const void* scales, size_t height, size_t width, size_t quant_width); | ||
| 350 | template Buffer pack_data_scales_interleave_block<UInt4, std::nullptr_t>( | ||
| 351 | const void* data, const void* scales, size_t height, size_t width, size_t quant_width); | ||
| 352 | |||
| 353 | template <typename Data, typename ZeroPoint, typename Scale, typename Bias> | ||
| 354 | Buffer pack_block_data_zero_points_scale_bias( | ||
| 355 | const void* data, const void* zero_points, const void* scales, const void* biases, size_t height, size_t width, | ||
| 356 | size_t quant_height, size_t quant_width, size_t block_height, size_t block_width, size_t interleave_x_blocks) { | ||
| 357 | if (quant_width == width) { | ||
| 358 | quant_width = round_up_multiple(quant_width, block_width); | ||
| 359 | } | ||
| 360 | |||
| 361 | KAI_ASSERT_ALWAYS(quant_height == block_height); | ||
| 362 | KAI_ASSERT_ALWAYS(quant_width % block_width == 0); | ||
| 363 | |||
| 364 | if (interleave_x_blocks == 0) { | ||
| 365 | interleave_x_blocks = quant_width / block_width; | ||
| 366 | } | ||
| 367 | |||
| 368 | const auto has_zero_points = zero_points != nullptr; | ||
| 369 | const auto has_biases = biases != nullptr; | ||
| 370 | |||
| 371 | const auto num_quant_packets_y = round_up_division(height, quant_height); | ||
| 372 | const auto num_quant_packets_x = round_up_division(width, quant_width); | ||
| 373 | |||
| 374 | const auto quant_packet_data_bytes = quant_height * quant_width * size_in_bits<Data> / 8; | ||
| 375 | const auto quant_packet_zero_points_bytes = has_zero_points ? quant_height * sizeof(ZeroPoint) : 0; | ||
| 376 | const auto quant_packet_scales_bytes = quant_height * sizeof(Scale); | ||
| 377 | const auto quant_packet_bytes = | ||
| 378 | quant_packet_zero_points_bytes + quant_packet_data_bytes + quant_packet_scales_bytes; | ||
| 379 | |||
| 380 | const auto num_quant_packets_per_row = round_up_division(width, quant_width); | ||
| 381 | const auto biases_bytes = has_biases ? height * sizeof(Bias) : 0; | ||
| 382 | |||
| 383 | const auto dst_bytes = num_quant_packets_y * num_quant_packets_x * quant_packet_bytes + biases_bytes; | ||
| 384 | Buffer dst(dst_bytes); | ||
| 385 | |||
| 386 | const auto* zero_points_ptr = reinterpret_cast<const ZeroPoint*>(zero_points); | ||
| 387 | const auto* scales_ptr = reinterpret_cast<const Scale*>(scales); | ||
| 388 | const auto* biases_ptr = reinterpret_cast<const Bias*>(biases); | ||
| 389 | auto* dst_ptr = dst.data(); | ||
| 390 | |||
| 391 | for (size_t y_quant = 0; y_quant < height; y_quant += quant_height) { | ||
| 392 | for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { | ||
| 393 | size_t dst_index = 0; | ||
| 394 | |||
| 395 | // Packs the data. | ||
| 396 | for (size_t y_pack = 0; y_pack < quant_height; y_pack += block_height) { | ||
| 397 | for (size_t x_pack = 0; x_pack < block_width * interleave_x_blocks; x_pack += block_width) { | ||
| 398 | for (size_t y_element = 0; y_element < block_height; ++y_element) { | ||
| 399 | for (size_t x_element = 0; x_element < block_width; ++x_element) { | ||
| 400 | for (size_t x_interleave = 0; x_interleave < quant_width; | ||
| 401 | x_interleave += block_width * interleave_x_blocks) { | ||
| 402 | const auto y = y_quant + y_pack + y_element; | ||
| 403 | const auto x = x_quant + x_pack + x_element + x_interleave; | ||
| 404 | |||
| 405 | if (y < height && x < width) { | ||
| 406 | write_array(dst_ptr, dst_index, read_array<Data>(data, y * width + x)); | ||
| 407 | } | ||
| 408 | |||
| 409 | ++dst_index; | ||
| 410 | } | ||
| 411 | } | ||
| 412 | } | ||
| 413 | } | ||
| 414 | } | ||
| 415 | |||
| 416 | dst_ptr += dst_index * size_in_bits<Data> / 8; | ||
| 417 | |||
| 418 | // Packs the zero points. | ||
| 419 | if (has_zero_points) { | ||
| 420 | for (size_t y_element = 0; y_element < quant_height; ++y_element) { | ||
| 421 | const auto y = y_quant + y_element; | ||
| 422 | const auto x = x_quant / quant_width; | ||
| 423 | memcpy(dst_ptr, &zero_points_ptr[y * num_quant_packets_per_row + x], sizeof(ZeroPoint)); | ||
| 424 | dst_ptr += sizeof(ZeroPoint); | ||
| 425 | } | ||
| 426 | } | ||
| 427 | |||
| 428 | // Packs the scales. | ||
| 429 | for (size_t y_element = 0; y_element < quant_height; ++y_element) { | ||
| 430 | const auto y = y_quant + y_element; | ||
| 431 | const auto x = x_quant / quant_width; | ||
| 432 | memcpy(dst_ptr, &scales_ptr[y * num_quant_packets_per_row + x], sizeof(Scale)); | ||
| 433 | dst_ptr += sizeof(Scale); | ||
| 434 | } | ||
| 435 | } | ||
| 436 | |||
| 437 | // Packs the biases. | ||
| 438 | if (has_biases) { | ||
| 439 | for (size_t y_element = 0; y_element < quant_height; ++y_element) { | ||
| 440 | const auto y = y_quant + y_element; | ||
| 441 | memcpy(dst_ptr, &biases_ptr[y], sizeof(Bias)); | ||
| 442 | dst_ptr += sizeof(Bias); | ||
| 443 | } | ||
| 444 | } | ||
| 445 | } | ||
| 446 | |||
| 447 | KAI_ASSERT_ALWAYS(dst_ptr == dst.data() + dst.size()); | ||
| 448 | |||
| 449 | return dst; | ||
| 450 | } | ||
| 451 | |||
| 452 | } // namespace kai::test | ||
| 453 |