test/common/data_format.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/common/data_format.hpp" | ||
| 8 | |||
| 9 | #include <cstddef> | ||
| 10 | #include <cstdint> | ||
| 11 | #include <functional> | ||
| 12 | |||
| 13 | #include "kai/kai_common.h" | ||
| 14 | #include "test/common/data_type.hpp" | ||
| 15 | #include "test/common/round.hpp" | ||
| 16 | |||
| 17 | namespace kai::test { | ||
| 18 | |||
| 19 | 173820 | DataFormat::DataFormat( | |
| 20 | DataType data_type, size_t block_height, size_t block_width, PackFormat pack_format, DataType zero_point_dt, | ||
| 21 | DataType scale_dt, size_t subblock_height, size_t subblock_width) noexcept : | ||
| 22 | 108204 | _data_type(data_type), | |
| 23 | 108204 | _pack_format(pack_format), | |
| 24 | 108204 | _scale_dt(scale_dt), | |
| 25 | 108204 | _zero_point_dt(zero_point_dt), | |
| 26 | 108204 | _block_height(block_height), | |
| 27 | 108204 | _block_width(block_width), | |
| 28 | 108204 | _subblock_height(subblock_height), | |
| 29 | 173820 | _subblock_width(subblock_width) { | |
| 30 | 173820 | } | |
| 31 | |||
| 32 | 85348 | bool DataFormat::operator==(const DataFormat& rhs) const { | |
| 33 |
4/6✓ Branch 0 taken 85294 times.
✓ Branch 1 taken 54 times.
✓ Branch 2 taken 85294 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 85294 times.
✗ Branch 5 not taken.
|
170642 | return _data_type == rhs._data_type && _pack_format == rhs._pack_format && _scale_dt == rhs._scale_dt && |
| 34 |
2/4✓ Branch 0 taken 85294 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 85294 times.
|
85294 | _zero_point_dt == rhs._zero_point_dt && _block_height == rhs._block_height && _block_width == rhs._block_width; |
| 35 | } | ||
| 36 | |||
| 37 | ✗ | bool DataFormat::operator!=(const DataFormat& rhs) const { | |
| 38 | ✗ | return !(*this == rhs); | |
| 39 | } | ||
| 40 | |||
| 41 | 116165 | DataType DataFormat::data_type() const { | |
| 42 | 116165 | return _data_type; | |
| 43 | } | ||
| 44 | |||
| 45 | 72240 | DataFormat::PackFormat DataFormat::pack_format() const { | |
| 46 | 72240 | return _pack_format; | |
| 47 | } | ||
| 48 | |||
| 49 | 62670 | DataType DataFormat::scale_data_type() const { | |
| 50 | 62670 | return _scale_dt; | |
| 51 | } | ||
| 52 | |||
| 53 | 61596 | DataType DataFormat::zero_point_data_type() const { | |
| 54 | 61596 | return _zero_point_dt; | |
| 55 | } | ||
| 56 | |||
| 57 | 4218 | bool DataFormat::is_raw() const { | |
| 58 |
1/2✓ Branch 0 taken 4218 times.
✗ Branch 1 not taken.
|
8436 | return _pack_format == PackFormat::NONE && // |
| 59 |
3/6✓ Branch 0 taken 4218 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4218 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 4218 times.
|
4218 | _block_height == 0 && _block_width == 0 && _subblock_height == 0 && _subblock_width == 0; |
| 60 | } | ||
| 61 | |||
| 62 | ✗ | size_t DataFormat::block_height() const { | |
| 63 | ✗ | return _block_height; | |
| 64 | } | ||
| 65 | |||
| 66 | ✗ | size_t DataFormat::block_width() const { | |
| 67 | ✗ | return _block_width; | |
| 68 | } | ||
| 69 | |||
| 70 | ✗ | size_t DataFormat::subblock_height() const { | |
| 71 | ✗ | return _subblock_height; | |
| 72 | } | ||
| 73 | |||
| 74 | ✗ | size_t DataFormat::subblock_width() const { | |
| 75 | ✗ | return _subblock_width; | |
| 76 | } | ||
| 77 | |||
| 78 | 119469 | size_t DataFormat::actual_block_height(size_t full_height) const { | |
| 79 |
2/2✓ Branch 0 taken 4161 times.
✓ Branch 1 taken 115308 times.
|
234777 | return _block_height > 0 ? _block_height |
| 80 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 115308 times.
|
115308 | : round_up_multiple(full_height, _subblock_height > 0 ? _subblock_height : 1); |
| 81 | } | ||
| 82 | |||
| 83 | 165567 | size_t DataFormat::actual_block_width(size_t full_width) const { | |
| 84 |
4/4✓ Branch 0 taken 8319 times.
✓ Branch 1 taken 157248 times.
✓ Branch 2 taken 5940 times.
✓ Branch 3 taken 151308 times.
|
165567 | return _block_width > 0 ? _block_width : round_up_multiple(full_width, _subblock_width > 0 ? _subblock_width : 1); |
| 85 | } | ||
| 86 | |||
| 87 | 61527 | size_t DataFormat::actual_subblock_height(size_t full_height) const { | |
| 88 |
2/2✓ Branch 0 taken 3585 times.
✓ Branch 1 taken 57942 times.
|
61527 | return _subblock_height > 0 ? _subblock_height : actual_block_height(full_height); |
| 89 | } | ||
| 90 | |||
| 91 | 61527 | size_t DataFormat::actual_subblock_width(size_t full_width) const { | |
| 92 |
2/2✓ Branch 0 taken 3585 times.
✓ Branch 1 taken 57942 times.
|
61527 | return _subblock_width > 0 ? _subblock_width : actual_block_width(full_width); |
| 93 | } | ||
| 94 | |||
| 95 | 5232 | size_t DataFormat::scheduler_block_height([[maybe_unused]] size_t full_height) const { | |
| 96 |
2/2✓ Branch 0 taken 5016 times.
✓ Branch 1 taken 216 times.
|
5232 | const auto padded_block_height = round_up_multiple(_block_height, _subblock_height > 0 ? _subblock_height : 1); |
| 97 | |||
| 98 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 5016 times.
✓ Branch 2 taken 216 times.
|
5232 | switch (_pack_format) { |
| 99 | case PackFormat::NONE: | ||
| 100 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | return _block_height > 0 ? padded_block_height : 1; |
| 101 | |||
| 102 | case PackFormat::BIAS_PER_ROW: | ||
| 103 | case PackFormat::QUANTIZE_PER_ROW: | ||
| 104 | − | KAI_ASSUME_ALWAYS(_block_height > 0); | |
| 105 | 5016 | return padded_block_height; | |
| 106 | |||
| 107 | default: | ||
| 108 | − | KAI_ERROR("Unsupported packing format!"); | |
| 109 | ✗ | } | |
| 110 | 5232 | } | |
| 111 | |||
| 112 | 26544 | size_t DataFormat::scheduler_block_width(size_t full_width) const { | |
| 113 |
2/2✓ Branch 0 taken 11850 times.
✓ Branch 1 taken 14694 times.
|
26544 | const auto padded_block_width = round_up_multiple(_block_width, _subblock_width > 0 ? _subblock_width : 1); |
| 114 | |||
| 115 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 11850 times.
✓ Branch 2 taken 14694 times.
|
26544 | switch (_pack_format) { |
| 116 | case PackFormat::NONE: | ||
| 117 |
2/2✓ Branch 0 taken 432 times.
✓ Branch 1 taken 14262 times.
|
14694 | return _block_width > 0 ? padded_block_width : 1; |
| 118 | |||
| 119 | case PackFormat::BIAS_PER_ROW: | ||
| 120 | case PackFormat::QUANTIZE_PER_ROW: | ||
| 121 | 11850 | return full_width; | |
| 122 | |||
| 123 | default: | ||
| 124 | − | KAI_ERROR("Unsupported packing format!"); | |
| 125 | ✗ | } | |
| 126 | 26544 | } | |
| 127 | |||
| 128 | 46098 | uintptr_t DataFormat::default_row_stride(size_t width) const { | |
| 129 | 46098 | const auto padded_width = round_up_multiple(width, actual_block_width(width)); | |
| 130 | |||
| 131 |
3/4✗ Branch 0 not taken.
✓ Branch 1 taken 27087 times.
✓ Branch 2 taken 6348 times.
✓ Branch 3 taken 12663 times.
|
46098 | switch (_pack_format) { |
| 132 | case PackFormat::NONE: | ||
| 133 |
2/2✓ Branch 0 taken 648 times.
✓ Branch 1 taken 36000 times.
|
36648 | return (_block_height > 0 ? _block_height : 1) * padded_width * data_type_size_in_bits(_data_type) / 8; |
| 134 | |||
| 135 | case PackFormat::BIAS_PER_ROW: | ||
| 136 | − | KAI_ASSUME_ALWAYS(_block_height > 0); | |
| 137 | 18900 | return _block_height * data_type_size_in_bits(_zero_point_dt) / 8 + // | |
| 138 | 9450 | _block_height * padded_width * data_type_size_in_bits(_data_type) / 8; | |
| 139 | |||
| 140 | case PackFormat::QUANTIZE_PER_ROW: | ||
| 141 | − | KAI_ASSUME_ALWAYS(_block_height > 0); | |
| 142 | ✗ | return _block_height * data_type_size_in_bits(_zero_point_dt) / 8 + // | |
| 143 | ✗ | _block_height * padded_width * data_type_size_in_bits(_data_type) / 8 + // | |
| 144 | ✗ | _block_height * data_type_size_in_bits(_scale_dt) / 8; | |
| 145 | |||
| 146 | default: | ||
| 147 | − | KAI_ERROR("Unsupported packing format!"); | |
| 148 | ✗ | } | |
| 149 | 46098 | } | |
| 150 | |||
| 151 | 21744 | uintptr_t DataFormat::default_offset_in_bytes(size_t row, size_t col, size_t width) const { | |
| 152 | 21744 | const auto row_stride = default_row_stride(width); | |
| 153 | 21744 | const auto block_width = scheduler_block_width(width); | |
| 154 | |||
| 155 | − | KAI_ASSERT_ALWAYS(col % block_width == 0); | |
| 156 | |||
| 157 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 7050 times.
✓ Branch 2 taken 14694 times.
|
21744 | switch (_pack_format) { |
| 158 | case PackFormat::NONE: | ||
| 159 |
2/2✓ Branch 0 taken 432 times.
✓ Branch 1 taken 14262 times.
|
14694 | return row * row_stride / (_block_height > 0 ? _block_height : 1) + |
| 160 | 14694 | col * data_type_size_in_bits(_data_type) / 8; | |
| 161 | |||
| 162 | case PackFormat::BIAS_PER_ROW: | ||
| 163 | case PackFormat::QUANTIZE_PER_ROW: | ||
| 164 | − | KAI_ASSUME_ALWAYS(row % _block_height == 0); | |
| 165 | − | KAI_ASSUME_ALWAYS(col == 0); | |
| 166 | 7050 | return (row / _block_height) * row_stride; | |
| 167 | |||
| 168 | default: | ||
| 169 | − | KAI_ERROR("Unsupported packing format!"); | |
| 170 | ✗ | } | |
| 171 | 21744 | } | |
| 172 | |||
| 173 | 7266 | size_t DataFormat::default_size_in_bytes(size_t height, size_t width) const { | |
| 174 |
2/2✓ Branch 0 taken 2616 times.
✓ Branch 1 taken 4650 times.
|
7266 | const auto num_rows = _block_height > 0 ? (height + _block_height - 1) / _block_height : height; |
| 175 | 7266 | const auto block_stride = default_row_stride(width); | |
| 176 | 14532 | return num_rows * block_stride; | |
| 177 | 7266 | } | |
| 178 | |||
| 179 | 110101 | size_t DataFormat::Hash::operator()(const DataFormat& format) const { | |
| 180 | 110101 | return // | |
| 181 | 220202 | (std::hash<DataType>{}(format._data_type) << 0) ^ // | |
| 182 | 220202 | (std::hash<DataFormat::PackFormat>{}(format._pack_format) << 1) ^ // | |
| 183 | 220202 | (std::hash<DataType>{}(format._scale_dt) << 2) ^ // | |
| 184 | 220202 | (std::hash<DataType>{}(format._zero_point_dt) << 3) ^ // | |
| 185 | 220202 | (std::hash<size_t>{}(format._block_height) << 4) ^ // | |
| 186 | 220202 | (std::hash<size_t>{}(format._block_width) << 5) ^ // | |
| 187 | 220202 | (std::hash<size_t>{}(format._subblock_height) << 6) ^ // | |
| 188 | 110101 | (std::hash<size_t>{}(format._subblock_width) << 7); // | |
| 189 | } | ||
| 190 | |||
| 191 | } // namespace kai::test | ||
| 192 |