KleidiAI Coverage Report


Directory: ./
File: test/common/data_format.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 81.5% 75 10 102
Functions: 77.3% 17 0 22
Branches: 71.9% 41 12 69

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/common/data_format.hpp"
8
9 #include <cstddef>
10 #include <cstdint>
11 #include <functional>
12
13 #include "kai/kai_common.h"
14 #include "test/common/data_type.hpp"
15 #include "test/common/round.hpp"
16
17 namespace kai::test {
18
19 72180 DataFormat::DataFormat(
20 DataType data_type, size_t block_height, size_t block_width, PackFormat pack_format, DataType zero_point_dt,
21 DataType scale_dt, size_t subblock_height, size_t subblock_width) noexcept :
22 36090 _data_type(data_type),
23 36090 _pack_format(pack_format),
24 36090 _scale_dt(scale_dt),
25 36090 _zero_point_dt(zero_point_dt),
26 36090 _block_height(block_height),
27 36090 _block_width(block_width),
28 36090 _subblock_height(subblock_height),
29 72180 _subblock_width(subblock_width) {
30 72180 }
31
32 102399 bool DataFormat::operator==(const DataFormat& rhs) const {
33
3/6
✓ Branch 0 taken 102399 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 102399 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 102399 times.
✗ Branch 5 not taken.
204798 return _data_type == rhs._data_type && _pack_format == rhs._pack_format && _scale_dt == rhs._scale_dt &&
34
2/4
✓ Branch 0 taken 102399 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 102399 times.
102399 _zero_point_dt == rhs._zero_point_dt && _block_height == rhs._block_height && _block_width == rhs._block_width;
35 }
36
37 bool DataFormat::operator!=(const DataFormat& rhs) const {
38 return !(*this == rhs);
39 }
40
41 121184 DataType DataFormat::data_type() const {
42 121184 return _data_type;
43 }
44
45 42413 DataFormat::PackFormat DataFormat::pack_format() const {
46 42413 return _pack_format;
47 }
48
49 34126 DataType DataFormat::scale_data_type() const {
50 34126 return _scale_dt;
51 }
52
53 33908 DataType DataFormat::zero_point_data_type() const {
54 33908 return _zero_point_dt;
55 }
56
57 776 bool DataFormat::is_raw() const {
58
1/2
✓ Branch 0 taken 776 times.
✗ Branch 1 not taken.
1552 return _pack_format == PackFormat::NONE && //
59
3/6
✓ Branch 0 taken 776 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 776 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 776 times.
776 _block_height == 0 && _block_width == 0 && _subblock_height == 0 && _subblock_width == 0;
60 }
61
62 size_t DataFormat::block_height() const {
63 return _block_height;
64 }
65
66 size_t DataFormat::block_width() const {
67 return _block_width;
68 }
69
70 size_t DataFormat::subblock_height() const {
71 return _subblock_height;
72 }
73
74 size_t DataFormat::subblock_width() const {
75 return _subblock_width;
76 }
77
78 67113 size_t DataFormat::actual_block_height(size_t full_height) const {
79
2/2
✓ Branch 0 taken 869 times.
✓ Branch 1 taken 66244 times.
133357 return _block_height > 0 ? _block_height
80
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 66244 times.
66244 : round_up_multiple(full_height, _subblock_height > 0 ? _subblock_height : 1);
81 }
82
83 75599 size_t DataFormat::actual_block_width(size_t full_width) const {
84
4/4
✓ Branch 0 taken 2115 times.
✓ Branch 1 taken 73484 times.
✓ Branch 2 taken 686 times.
✓ Branch 3 taken 72798 times.
75599 return _block_width > 0 ? _block_width : round_up_multiple(full_width, _subblock_width > 0 ? _subblock_width : 1);
85 }
86
87 33899 size_t DataFormat::actual_subblock_height(size_t full_height) const {
88
2/2
✓ Branch 0 taken 685 times.
✓ Branch 1 taken 33214 times.
33899 return _subblock_height > 0 ? _subblock_height : actual_block_height(full_height);
89 }
90
91 33899 size_t DataFormat::actual_subblock_width(size_t full_width) const {
92
2/2
✓ Branch 0 taken 685 times.
✓ Branch 1 taken 33214 times.
33899 return _subblock_width > 0 ? _subblock_width : actual_block_width(full_width);
93 }
94
95 1124 size_t DataFormat::scheduler_block_height([[maybe_unused]] size_t full_height) const {
96
2/2
✓ Branch 0 taken 1052 times.
✓ Branch 1 taken 72 times.
1124 const auto padded_block_height = round_up_multiple(_block_height, _subblock_height > 0 ? _subblock_height : 1);
97
98
2/3
✗ Branch 0 not taken.
✓ Branch 1 taken 1052 times.
✓ Branch 2 taken 72 times.
1124 switch (_pack_format) {
99 case PackFormat::NONE:
100
1/2
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
72 return _block_height > 0 ? padded_block_height : 1;
101
102 case PackFormat::BIAS_PER_ROW:
103 case PackFormat::QUANTIZE_PER_ROW:
104 KAI_ASSUME(_block_height > 0);
105 1052 return padded_block_height;
106
107 default:
108 KAI_ERROR("Unsupported packing format!");
109 }
110 1124 }
111
112 5118 size_t DataFormat::scheduler_block_width(size_t full_width) const {
113
2/2
✓ Branch 0 taken 2244 times.
✓ Branch 1 taken 2874 times.
5118 const auto padded_block_width = round_up_multiple(_block_width, _subblock_width > 0 ? _subblock_width : 1);
114
115
2/3
✗ Branch 0 not taken.
✓ Branch 1 taken 2244 times.
✓ Branch 2 taken 2874 times.
5118 switch (_pack_format) {
116 case PackFormat::NONE:
117
2/2
✓ Branch 0 taken 132 times.
✓ Branch 1 taken 2742 times.
2874 return _block_width > 0 ? padded_block_width : 1;
118
119 case PackFormat::BIAS_PER_ROW:
120 case PackFormat::QUANTIZE_PER_ROW:
121 2244 return full_width;
122
123 default:
124 KAI_ERROR("Unsupported packing format!");
125 }
126 5118 }
127
128 8486 uintptr_t DataFormat::default_row_stride(size_t width) const {
129 8486 const auto padded_width = round_up_multiple(width, actual_block_width(width));
130
131
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6754 times.
✓ Branch 2 taken 1732 times.
✗ Branch 3 not taken.
8486 switch (_pack_format) {
132 case PackFormat::NONE:
133
2/2
✓ Branch 0 taken 200 times.
✓ Branch 1 taken 6554 times.
6754 return (_block_height > 0 ? _block_height : 1) * padded_width * data_type_size_in_bits(_data_type) / 8;
134
135 case PackFormat::BIAS_PER_ROW:
136 KAI_ASSUME(_block_height > 0);
137 3464 return _block_height * data_type_size_in_bits(_zero_point_dt) / 8 + //
138 1732 _block_height * padded_width * data_type_size_in_bits(_data_type) / 8;
139
140 case PackFormat::QUANTIZE_PER_ROW:
141 KAI_ASSUME(_block_height > 0);
142 return _block_height * data_type_size_in_bits(_zero_point_dt) / 8 + //
143 _block_height * padded_width * data_type_size_in_bits(_data_type) / 8 + //
144 _block_height * data_type_size_in_bits(_scale_dt) / 8;
145
146 default:
147 KAI_ERROR("Unsupported packing format!");
148 }
149 8486 }
150
151 4138 uintptr_t DataFormat::default_offset_in_bytes(size_t row, size_t col, size_t width) const {
152 4138 const auto row_stride = default_row_stride(width);
153 4138 const auto block_width = scheduler_block_width(width);
154
155 KAI_ASSERT(col % block_width == 0);
156
157
2/3
✗ Branch 0 not taken.
✓ Branch 1 taken 1264 times.
✓ Branch 2 taken 2874 times.
4138 switch (_pack_format) {
158 case PackFormat::NONE:
159
2/2
✓ Branch 0 taken 132 times.
✓ Branch 1 taken 2742 times.
2874 return row * row_stride / (_block_height > 0 ? _block_height : 1) +
160 2874 col * data_type_size_in_bits(_data_type) / 8;
161
162 case PackFormat::BIAS_PER_ROW:
163 case PackFormat::QUANTIZE_PER_ROW:
164 KAI_ASSUME(row % _block_height == 0);
165 KAI_ASSUME(col == 0);
166 1264 return (row / _block_height) * row_stride;
167
168 default:
169 KAI_ERROR("Unsupported packing format!");
170 }
171 4138 }
172
173 1332 size_t DataFormat::default_size_in_bytes(size_t height, size_t width) const {
174
2/2
✓ Branch 0 taken 536 times.
✓ Branch 1 taken 796 times.
1332 const auto num_rows = _block_height > 0 ? (height + _block_height - 1) / _block_height : height;
175 1332 const auto block_stride = default_row_stride(width);
176 2664 return num_rows * block_stride;
177 1332 }
178
179 120417 size_t DataFormat::Hash::operator()(const DataFormat& format) const {
180 120417 return //
181 240834 (std::hash<DataType>{}(format._data_type) << 0) ^ //
182 240834 (std::hash<DataFormat::PackFormat>{}(format._pack_format) << 1) ^ //
183 240834 (std::hash<DataType>{}(format._scale_dt) << 2) ^ //
184 240834 (std::hash<DataType>{}(format._zero_point_dt) << 3) ^ //
185 240834 (std::hash<size_t>{}(format._block_height) << 4) ^ //
186 240834 (std::hash<size_t>{}(format._block_width) << 5) ^ //
187 240834 (std::hash<size_t>{}(format._subblock_height) << 6) ^ //
188 120417 (std::hash<size_t>{}(format._subblock_width) << 7); //
189 }
190
191 } // namespace kai::test
192