Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/common/data_format.hpp" | ||
8 | |||
9 | #include <cstddef> | ||
10 | #include <cstdint> | ||
11 | #include <functional> | ||
12 | |||
13 | #include "kai/kai_common.h" | ||
14 | #include "test/common/data_type.hpp" | ||
15 | #include "test/common/round.hpp" | ||
16 | |||
17 | namespace kai::test { | ||
18 | |||
19 | 72180 | DataFormat::DataFormat( | |
20 | DataType data_type, size_t block_height, size_t block_width, PackFormat pack_format, DataType zero_point_dt, | ||
21 | DataType scale_dt, size_t subblock_height, size_t subblock_width) noexcept : | ||
22 | 36090 | _data_type(data_type), | |
23 | 36090 | _pack_format(pack_format), | |
24 | 36090 | _scale_dt(scale_dt), | |
25 | 36090 | _zero_point_dt(zero_point_dt), | |
26 | 36090 | _block_height(block_height), | |
27 | 36090 | _block_width(block_width), | |
28 | 36090 | _subblock_height(subblock_height), | |
29 | 72180 | _subblock_width(subblock_width) { | |
30 | 72180 | } | |
31 | |||
32 | 102399 | bool DataFormat::operator==(const DataFormat& rhs) const { | |
33 |
3/6✓ Branch 0 taken 102399 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 102399 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 102399 times.
✗ Branch 5 not taken.
|
204798 | return _data_type == rhs._data_type && _pack_format == rhs._pack_format && _scale_dt == rhs._scale_dt && |
34 |
2/4✓ Branch 0 taken 102399 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 102399 times.
|
102399 | _zero_point_dt == rhs._zero_point_dt && _block_height == rhs._block_height && _block_width == rhs._block_width; |
35 | } | ||
36 | |||
37 | ✗ | bool DataFormat::operator!=(const DataFormat& rhs) const { | |
38 | ✗ | return !(*this == rhs); | |
39 | } | ||
40 | |||
41 | 121184 | DataType DataFormat::data_type() const { | |
42 | 121184 | return _data_type; | |
43 | } | ||
44 | |||
45 | 42413 | DataFormat::PackFormat DataFormat::pack_format() const { | |
46 | 42413 | return _pack_format; | |
47 | } | ||
48 | |||
49 | 34126 | DataType DataFormat::scale_data_type() const { | |
50 | 34126 | return _scale_dt; | |
51 | } | ||
52 | |||
53 | 33908 | DataType DataFormat::zero_point_data_type() const { | |
54 | 33908 | return _zero_point_dt; | |
55 | } | ||
56 | |||
57 | 776 | bool DataFormat::is_raw() const { | |
58 |
1/2✓ Branch 0 taken 776 times.
✗ Branch 1 not taken.
|
1552 | return _pack_format == PackFormat::NONE && // |
59 |
3/6✓ Branch 0 taken 776 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 776 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 776 times.
|
776 | _block_height == 0 && _block_width == 0 && _subblock_height == 0 && _subblock_width == 0; |
60 | } | ||
61 | |||
62 | ✗ | size_t DataFormat::block_height() const { | |
63 | ✗ | return _block_height; | |
64 | } | ||
65 | |||
66 | ✗ | size_t DataFormat::block_width() const { | |
67 | ✗ | return _block_width; | |
68 | } | ||
69 | |||
70 | ✗ | size_t DataFormat::subblock_height() const { | |
71 | ✗ | return _subblock_height; | |
72 | } | ||
73 | |||
74 | ✗ | size_t DataFormat::subblock_width() const { | |
75 | ✗ | return _subblock_width; | |
76 | } | ||
77 | |||
78 | 67113 | size_t DataFormat::actual_block_height(size_t full_height) const { | |
79 |
2/2✓ Branch 0 taken 869 times.
✓ Branch 1 taken 66244 times.
|
133357 | return _block_height > 0 ? _block_height |
80 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 66244 times.
|
66244 | : round_up_multiple(full_height, _subblock_height > 0 ? _subblock_height : 1); |
81 | } | ||
82 | |||
83 | 75599 | size_t DataFormat::actual_block_width(size_t full_width) const { | |
84 |
4/4✓ Branch 0 taken 2115 times.
✓ Branch 1 taken 73484 times.
✓ Branch 2 taken 686 times.
✓ Branch 3 taken 72798 times.
|
75599 | return _block_width > 0 ? _block_width : round_up_multiple(full_width, _subblock_width > 0 ? _subblock_width : 1); |
85 | } | ||
86 | |||
87 | 33899 | size_t DataFormat::actual_subblock_height(size_t full_height) const { | |
88 |
2/2✓ Branch 0 taken 685 times.
✓ Branch 1 taken 33214 times.
|
33899 | return _subblock_height > 0 ? _subblock_height : actual_block_height(full_height); |
89 | } | ||
90 | |||
91 | 33899 | size_t DataFormat::actual_subblock_width(size_t full_width) const { | |
92 |
2/2✓ Branch 0 taken 685 times.
✓ Branch 1 taken 33214 times.
|
33899 | return _subblock_width > 0 ? _subblock_width : actual_block_width(full_width); |
93 | } | ||
94 | |||
95 | 1124 | size_t DataFormat::scheduler_block_height([[maybe_unused]] size_t full_height) const { | |
96 |
2/2✓ Branch 0 taken 1052 times.
✓ Branch 1 taken 72 times.
|
1124 | const auto padded_block_height = round_up_multiple(_block_height, _subblock_height > 0 ? _subblock_height : 1); |
97 | |||
98 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 1052 times.
✓ Branch 2 taken 72 times.
|
1124 | switch (_pack_format) { |
99 | case PackFormat::NONE: | ||
100 |
1/2✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
|
72 | return _block_height > 0 ? padded_block_height : 1; |
101 | |||
102 | case PackFormat::BIAS_PER_ROW: | ||
103 | case PackFormat::QUANTIZE_PER_ROW: | ||
104 | − | KAI_ASSUME(_block_height > 0); | |
105 | 1052 | return padded_block_height; | |
106 | |||
107 | default: | ||
108 | − | KAI_ERROR("Unsupported packing format!"); | |
109 | ✗ | } | |
110 | 1124 | } | |
111 | |||
112 | 5118 | size_t DataFormat::scheduler_block_width(size_t full_width) const { | |
113 |
2/2✓ Branch 0 taken 2244 times.
✓ Branch 1 taken 2874 times.
|
5118 | const auto padded_block_width = round_up_multiple(_block_width, _subblock_width > 0 ? _subblock_width : 1); |
114 | |||
115 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 2244 times.
✓ Branch 2 taken 2874 times.
|
5118 | switch (_pack_format) { |
116 | case PackFormat::NONE: | ||
117 |
2/2✓ Branch 0 taken 132 times.
✓ Branch 1 taken 2742 times.
|
2874 | return _block_width > 0 ? padded_block_width : 1; |
118 | |||
119 | case PackFormat::BIAS_PER_ROW: | ||
120 | case PackFormat::QUANTIZE_PER_ROW: | ||
121 | 2244 | return full_width; | |
122 | |||
123 | default: | ||
124 | − | KAI_ERROR("Unsupported packing format!"); | |
125 | ✗ | } | |
126 | 5118 | } | |
127 | |||
128 | 8486 | uintptr_t DataFormat::default_row_stride(size_t width) const { | |
129 | 8486 | const auto padded_width = round_up_multiple(width, actual_block_width(width)); | |
130 | |||
131 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 6754 times.
✓ Branch 2 taken 1732 times.
✗ Branch 3 not taken.
|
8486 | switch (_pack_format) { |
132 | case PackFormat::NONE: | ||
133 |
2/2✓ Branch 0 taken 200 times.
✓ Branch 1 taken 6554 times.
|
6754 | return (_block_height > 0 ? _block_height : 1) * padded_width * data_type_size_in_bits(_data_type) / 8; |
134 | |||
135 | case PackFormat::BIAS_PER_ROW: | ||
136 | − | KAI_ASSUME(_block_height > 0); | |
137 | 3464 | return _block_height * data_type_size_in_bits(_zero_point_dt) / 8 + // | |
138 | 1732 | _block_height * padded_width * data_type_size_in_bits(_data_type) / 8; | |
139 | |||
140 | case PackFormat::QUANTIZE_PER_ROW: | ||
141 | − | KAI_ASSUME(_block_height > 0); | |
142 | ✗ | return _block_height * data_type_size_in_bits(_zero_point_dt) / 8 + // | |
143 | ✗ | _block_height * padded_width * data_type_size_in_bits(_data_type) / 8 + // | |
144 | ✗ | _block_height * data_type_size_in_bits(_scale_dt) / 8; | |
145 | |||
146 | default: | ||
147 | − | KAI_ERROR("Unsupported packing format!"); | |
148 | ✗ | } | |
149 | 8486 | } | |
150 | |||
151 | 4138 | uintptr_t DataFormat::default_offset_in_bytes(size_t row, size_t col, size_t width) const { | |
152 | 4138 | const auto row_stride = default_row_stride(width); | |
153 | 4138 | const auto block_width = scheduler_block_width(width); | |
154 | |||
155 | − | KAI_ASSERT(col % block_width == 0); | |
156 | |||
157 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 1264 times.
✓ Branch 2 taken 2874 times.
|
4138 | switch (_pack_format) { |
158 | case PackFormat::NONE: | ||
159 |
2/2✓ Branch 0 taken 132 times.
✓ Branch 1 taken 2742 times.
|
2874 | return row * row_stride / (_block_height > 0 ? _block_height : 1) + |
160 | 2874 | col * data_type_size_in_bits(_data_type) / 8; | |
161 | |||
162 | case PackFormat::BIAS_PER_ROW: | ||
163 | case PackFormat::QUANTIZE_PER_ROW: | ||
164 | − | KAI_ASSUME(row % _block_height == 0); | |
165 | − | KAI_ASSUME(col == 0); | |
166 | 1264 | return (row / _block_height) * row_stride; | |
167 | |||
168 | default: | ||
169 | − | KAI_ERROR("Unsupported packing format!"); | |
170 | ✗ | } | |
171 | 4138 | } | |
172 | |||
173 | 1332 | size_t DataFormat::default_size_in_bytes(size_t height, size_t width) const { | |
174 |
2/2✓ Branch 0 taken 536 times.
✓ Branch 1 taken 796 times.
|
1332 | const auto num_rows = _block_height > 0 ? (height + _block_height - 1) / _block_height : height; |
175 | 1332 | const auto block_stride = default_row_stride(width); | |
176 | 2664 | return num_rows * block_stride; | |
177 | 1332 | } | |
178 | |||
179 | 120417 | size_t DataFormat::Hash::operator()(const DataFormat& format) const { | |
180 | 120417 | return // | |
181 | 240834 | (std::hash<DataType>{}(format._data_type) << 0) ^ // | |
182 | 240834 | (std::hash<DataFormat::PackFormat>{}(format._pack_format) << 1) ^ // | |
183 | 240834 | (std::hash<DataType>{}(format._scale_dt) << 2) ^ // | |
184 | 240834 | (std::hash<DataType>{}(format._zero_point_dt) << 3) ^ // | |
185 | 240834 | (std::hash<size_t>{}(format._block_height) << 4) ^ // | |
186 | 240834 | (std::hash<size_t>{}(format._block_width) << 5) ^ // | |
187 | 240834 | (std::hash<size_t>{}(format._subblock_height) << 6) ^ // | |
188 | 120417 | (std::hash<size_t>{}(format._subblock_width) << 7); // | |
189 | } | ||
190 | |||
191 | } // namespace kai::test | ||
192 |