KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 58.6% 136 / 0 / 232
Functions: 26.7% 4 / 0 / 15
Branches: 26.4% 104 / 0 / 394

test/nextgen/format/block2d_row_format.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/nextgen/format/block2d_row_format.hpp"
8
9 #include <algorithm>
10 #include <array>
11 #include <cstddef>
12 #include <ostream>
13 #include <vector>
14
15 #include "test/common/assert.hpp"
16 #include "test/common/buffer.hpp"
17 #include "test/common/compare.hpp"
18 #include "test/common/data_type.hpp"
19 #include "test/common/round.hpp"
20 #include "test/common/span.hpp"
21 #include "test/nextgen/common/random.hpp"
22 #include "test/nextgen/format/format.hpp"
23 #include "test/nextgen/reference/compare.hpp"
24 #include "test/nextgen/reference/pack.hpp"
25 #include "test/nextgen/reference/print.hpp"
26
27 namespace kai::test {
28
29 4400 size_t Block2dRowFormat::compute_offset(Span<const size_t> shape, Span<const size_t> indices) const {
30
1/4
✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
4400 KAI_TEST_ASSERT(shape.size() == 2);
31
1/4
✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
4400 KAI_TEST_ASSERT(shape.size() == indices.size());
32
33 4400 const size_t height = shape.at(0);
34 4400 const size_t width = shape.at(1);
35
36 4400 const size_t row = indices.at(0);
37 4400 const size_t col = indices.at(1);
38
39
1/4
✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
4400 KAI_TEST_ASSERT(row < height);
40
1/4
✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
4400 KAI_TEST_ASSERT(col < width);
41
42
1/4
✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
4400 KAI_TEST_ASSERT(row % m_block_height == 0);
43
1/4
✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
4400 KAI_TEST_ASSERT(col % m_block_width == 0);
44
45
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4400 times.
4400 const bool has_per_row_component = !m_pre_dtypes.empty() || !m_post_dtypes.empty();
46
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4400 times.
4400 if (has_per_row_component) {
47
1/4
✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
4400 KAI_TEST_ASSERT(col == 0);
48 4400 }
49
50 4400 const size_t block_row = row / m_block_height;
51 4400 const size_t block_col = col / m_block_width;
52
53 4400 const size_t block_size = m_block_height * m_block_width * data_type_size_in_bits(m_dtype) / 8;
54 4400 const size_t num_blocks_per_row = round_up_multiple(width, m_width_align) / m_block_width;
55
56
1/2
✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
4400 if (has_per_row_component) {
57 4400 size_t block_row_size = block_size * num_blocks_per_row;
58
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4400 times.
4400 for (const DataType dtype : m_pre_dtypes) {
59 block_row_size += m_block_height * data_type_size_in_bits(dtype) / 8;
60 }
61
2/2
✓ Branch 0 taken 11000 times.
✓ Branch 1 taken 4400 times.
15400 for (const DataType dtype : m_post_dtypes) {
62 11000 block_row_size += m_block_height * data_type_size_in_bits(dtype) / 8;
63 11000 }
64
65 4400 return block_row * block_row_size;
66 4400 } else {
67 return (block_row * num_blocks_per_row + block_col) * block_size;
68 }
69 4400 }
70
71 2000 size_t Block2dRowFormat::compute_size(Span<const size_t> shape) const {
72
1/4
✓ Branch 0 taken 2000 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
2000 KAI_TEST_ASSERT(shape.size() == 2);
73
74 2000 const size_t height = shape.at(0);
75 2000 const size_t width = shape.at(1);
76
77 2000 const size_t padded_height = round_up_multiple(height, m_block_height);
78
79 2000 const size_t size = compute_offset({padded_height + m_block_height, width}, {padded_height, 0});
80 4000 return size;
81 2000 }
82
83 Buffer Block2dRowFormat::generate_random([[maybe_unused]] Span<const size_t> shape, [[maybe_unused]] Rng& rng) const {
84 KAI_TEST_ERROR("Not supported!");
85 }
86
87 400 Buffer Block2dRowFormat::pack(Span<const size_t> shape, Span<const Span<const std::byte>> buffers) const {
88
1/4
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
400 KAI_TEST_ASSERT(shape.size() == 2);
89
90 400 const size_t height = shape.at(0);
91 400 const size_t width = shape.at(1);
92 400 const size_t num_block_rows = round_up_division(height, m_block_height);
93
94 400 const size_t packed_size = compute_size(shape);
95 400 Buffer packed_buffer(packed_size, 0);
96
1/2
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
400 Span<std::byte> packed_data(packed_buffer);
97
98
1/2
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
400 const PackBlock2dFn pack_data_fn = make_pack_block2d(m_dtype);
99
100 400 const size_t num_pres = m_pre_dtypes.size();
101 400 std::vector<Span<const std::byte>> pre_buffers;
102
1/2
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
400 pre_buffers.reserve(num_pres);
103
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 400 times.
400 for (size_t i = 0; i < num_pres; ++i) {
104 pre_buffers.emplace_back(buffers.at(i));
105 }
106
107
1/2
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
400 Span<const std::byte> data_buffer = buffers.at(num_pres);
108
109 400 const size_t num_posts = m_post_dtypes.size();
110 400 std::vector<Span<const std::byte>> post_buffers;
111
1/2
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
400 post_buffers.reserve(num_posts);
112
2/2
✓ Branch 0 taken 1000 times.
✓ Branch 1 taken 400 times.
1400 for (size_t i = 0; i < num_posts; ++i) {
113
2/4
✓ Branch 0 taken 1000 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1000 times.
✗ Branch 3 not taken.
1000 post_buffers.emplace_back(buffers.at(num_pres + 1 + i));
114 1000 }
115
116
2/2
✓ Branch 0 taken 9162 times.
✓ Branch 1 taken 400 times.
9562 for (size_t block_row = 0; block_row < num_block_rows; ++block_row) {
117
1/2
✓ Branch 0 taken 9162 times.
✗ Branch 1 not taken.
9162 const size_t remaining_height = std::min(m_block_height, height - block_row * m_block_height);
118
119
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 9162 times.
9162 for (size_t i = 0; i < num_pres; ++i) {
120 const size_t copy_size = remaining_height * data_type_size_in_bits(m_pre_dtypes.at(i)) / 8;
121 Span<const std::byte>& data = pre_buffers.at(i);
122
123 std::copy_n(data.begin(), copy_size, packed_data.begin());
124
125 data = data.subspan(copy_size);
126 packed_data = packed_data.subspan(m_block_height * data_type_size_in_bits(m_pre_dtypes.at(i)) / 8);
127 }
128
129 {
130
4/8
✓ Branch 0 taken 9162 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9162 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9162 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 9162 times.
✗ Branch 7 not taken.
36648 const size_t size = pack_data_fn(
131 18324 m_block_height, m_block_width, m_width_align, m_pad_right_same, remaining_height, width, packed_data,
132 9162 data_buffer);
133 9162 data_buffer =
134
3/6
✓ Branch 0 taken 9162 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9162 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9162 times.
✗ Branch 5 not taken.
9162 data_buffer.subspan(remaining_height * round_up_division(width * data_type_size_in_bits(m_dtype), 8));
135
1/2
✓ Branch 0 taken 9162 times.
✗ Branch 1 not taken.
9162 packed_data = packed_data.subspan(size);
136 9162 }
137
138
2/2
✓ Branch 0 taken 9162 times.
✓ Branch 1 taken 18685 times.
27847 for (size_t i = 0; i < num_posts; ++i) {
139
2/4
✓ Branch 0 taken 18685 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18685 times.
✗ Branch 3 not taken.
18685 const size_t copy_size = remaining_height * data_type_size_in_bits(m_post_dtypes.at(i)) / 8;
140
1/2
✓ Branch 0 taken 18685 times.
✗ Branch 1 not taken.
18685 Span<const std::byte>& data = post_buffers.at(i);
141
142
1/2
✓ Branch 0 taken 18685 times.
✗ Branch 1 not taken.
18685 std::copy_n(data.begin(), copy_size, packed_data.begin());
143
144
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 18685 times.
18685 data = data.subspan(copy_size);
145
3/6
✓ Branch 0 taken 18685 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18685 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18685 times.
✗ Branch 5 not taken.
18685 packed_data = packed_data.subspan(m_block_height * data_type_size_in_bits(m_post_dtypes.at(i)) / 8);
146 18685 }
147 9162 }
148
149
1/6
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
400 KAI_TEST_ASSERT(data_buffer.empty());
150
1/4
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
400 KAI_TEST_ASSERT(packed_data.empty());
151
152 400 return packed_buffer;
153 400 }
154
155 1200 bool Block2dRowFormat::compare(
156 Span<const size_t> shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape,
157 Span<const std::byte> imp_buffer, Span<const std::byte> ref_buffer, MismatchHandler& handler) const {
158
1/4
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
1200 KAI_TEST_ASSERT(shape.size() == 2);
159
1/4
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
1200 KAI_TEST_ASSERT(shape.size() == tile_coords.size());
160
1/4
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
1200 KAI_TEST_ASSERT(shape.size() == tile_shape.size());
161
162 1200 const size_t height = shape.at(0);
163 1200 const size_t width = shape.at(1);
164
165 1200 const size_t tile_row = tile_coords.at(0);
166 1200 const size_t tile_col = tile_coords.at(1);
167
168 1200 const size_t tile_height = tile_shape.at(0);
169 1200 size_t tile_width = tile_shape.at(1);
170
171
1/4
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
1200 KAI_TEST_ASSERT(tile_row % m_block_height == 0);
172
1/4
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
1200 KAI_TEST_ASSERT(tile_col % m_block_width == 0);
173
3/6
✓ Branch 0 taken 314 times.
✓ Branch 1 taken 886 times.
✓ Branch 2 taken 314 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
1200 KAI_TEST_ASSERT(tile_row + tile_height == height || (tile_row + tile_height) % m_block_height == 0);
174
1/6
✗ Branch 0 not taken.
✓ Branch 1 taken 1200 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
1200 KAI_TEST_ASSERT(tile_col + tile_width == width || (tile_col + tile_width) % m_block_width == 0);
175
176
2/2
✓ Branch 0 taken 600 times.
✓ Branch 1 taken 600 times.
1200 if (m_pad_right_same) {
177 // If the tile includes the last block column, extends the tile to cover the right padding blocks.
178 // In SAME padding mode, these blocks contain data even though they are outside the tile of interests.
179 // If we don't extend the tile, there will be mismatched because these data points are outside the tile
180 // and the data is not 0.
181 600 tile_width = round_up_multiple(tile_col + tile_width, m_width_align) - tile_col;
182 600 }
183
184 1200 const size_t num_pre_rows = m_pre_dtypes.size();
185 1200 std::vector<CompareFn> pre_compares;
186
1/2
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
1200 pre_compares.reserve(num_pre_rows);
187
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1200 times.
1200 for (const DataType dtype : m_pre_dtypes) {
188 pre_compares.emplace_back(make_compare_plain_2d(dtype));
189 }
190
191
1/2
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
1200 const CompareFn data_compare = make_compare_plain_2d(m_dtype);
192
193 1200 const size_t num_post_rows = m_post_dtypes.size();
194 1200 std::vector<CompareFn> post_compares;
195
1/2
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
1200 post_compares.reserve(num_post_rows);
196
2/2
✓ Branch 0 taken 3000 times.
✓ Branch 1 taken 1200 times.
4200 for (const DataType dtype : m_post_dtypes) {
197
2/4
✓ Branch 0 taken 3000 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3000 times.
✗ Branch 3 not taken.
3000 post_compares.emplace_back(make_compare_plain_2d(dtype));
198 3000 }
199
200
1/2
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
1200 const size_t num_block_rows = round_up_division(height, m_block_height);
201
1/2
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
1200 const size_t num_block_cols_padded = round_up_multiple(width, m_width_align) / m_block_width;
202
2/4
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1200 times.
✗ Branch 3 not taken.
1200 const size_t block_size = round_up_division(m_block_height * m_block_width * data_type_size_in_bits(m_dtype), 8);
203
204 1200 const size_t tile_block_col = tile_col / m_block_width;
205
1/2
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
1200 const size_t tile_num_block_cols = round_up_division(tile_width, m_block_width);
206
207 1200 size_t num_checks = 0;
208
209
2/2
✓ Branch 0 taken 27486 times.
✓ Branch 1 taken 1200 times.
28686 for (size_t block_row = 0; block_row < num_block_rows; ++block_row) {
210 48400 const bool block_row_in_tile =
211
2/2
✓ Branch 0 taken 6572 times.
✓ Branch 1 taken 20914 times.
27486 tile_row <= block_row * m_block_height && tile_row + tile_height > block_row * m_block_height;
212
213
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 27486 times.
27486 for (size_t i = 0; i < num_pre_rows; ++i) {
214 num_checks += pre_compares.at(i)(
215 {1, m_block_height}, {0, 0}, {1, block_row_in_tile ? m_block_height : 0}, imp_buffer, ref_buffer,
216 [&](std::ostream& os, Span<const size_t> coords) {
217 os << "Mismatched at block row " << block_row << ", prefix per-row component " << i << ", element "
218 << coords.at(1);
219 },
220 handler);
221
222 imp_buffer = imp_buffer.subspan(m_block_height * data_type_size_in_bits(m_pre_dtypes.at(i)) / 8);
223 ref_buffer = ref_buffer.subspan(m_block_height * data_type_size_in_bits(m_pre_dtypes.at(i)) / 8);
224 }
225
226 {
227
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 27486 times.
137430 num_checks += data_compare(
228 54972 {num_block_cols_padded, m_block_height * m_block_width}, {tile_block_col, 0},
229
2/2
✓ Branch 0 taken 14237 times.
✓ Branch 1 taken 13249 times.
27486 {tile_num_block_cols, block_row_in_tile ? m_block_height * m_block_width : 0}, imp_buffer, ref_buffer,
230
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 27486 times.
27486 [&](std::ostream& os, Span<const size_t> coords) {
231 os << "Mismatched at block row " << block_row << ", blocked data, block column " << coords.at(0)
232 << ", element " << coords.at(1);
233 },
234 27486 handler);
235
236
1/2
✓ Branch 0 taken 27486 times.
✗ Branch 1 not taken.
27486 imp_buffer = imp_buffer.subspan(num_block_cols_padded * block_size);
237
1/2
✓ Branch 0 taken 27486 times.
✗ Branch 1 not taken.
27486 ref_buffer = ref_buffer.subspan(num_block_cols_padded * block_size);
238 }
239
240
2/2
✓ Branch 0 taken 56055 times.
✓ Branch 1 taken 27486 times.
83541 for (size_t i = 0; i < num_post_rows; ++i) {
241
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 56055 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 56055 times.
168165 num_checks += post_compares.at(i)(
242
6/6
✓ Branch 0 taken 29300 times.
✓ Branch 1 taken 26755 times.
✓ Branch 2 taken 29300 times.
✓ Branch 3 taken 26755 times.
✓ Branch 4 taken 29300 times.
✓ Branch 5 taken 26755 times.
168165 {1, m_block_height}, {0, 0}, {1, block_row_in_tile ? m_block_height : 0}, imp_buffer, ref_buffer,
243
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 56055 times.
56055 [&](std::ostream& os, Span<const size_t> coords) {
244 os << "Mismatched at block row " << block_row << ", postfix per-row component " << i << ", element "
245 << coords.at(1);
246 },
247 56055 handler);
248
249
3/6
✓ Branch 0 taken 56055 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 56055 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 56055 times.
✗ Branch 5 not taken.
56055 imp_buffer = imp_buffer.subspan(m_block_height * data_type_size_in_bits(m_post_dtypes.at(i)) / 8);
250
3/6
✓ Branch 0 taken 56055 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 56055 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 56055 times.
✗ Branch 5 not taken.
56055 ref_buffer = ref_buffer.subspan(m_block_height * data_type_size_in_bits(m_post_dtypes.at(i)) / 8);
251 56055 }
252 27486 }
253
254
1/6
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
1200 KAI_TEST_ASSERT(imp_buffer.empty());
255
1/4
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
1200 KAI_TEST_ASSERT(ref_buffer.empty());
256
257
1/2
✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
1200 return handler.success(num_checks);
258 1200 }
259
260 void Block2dRowFormat::print(std::ostream& os, Span<const size_t> shape, Span<const std::byte> data) const {
261 if (shape.empty()) {
262 os << "None";
263 } else {
264 KAI_TEST_ASSERT(shape.size() == 2);
265
266 const size_t height = shape.at(0);
267 const size_t width = shape.at(1);
268
269 const PrintFn data_printer = make_print_array(m_dtype);
270
271 std::vector<PrintFn> pre_row_printers;
272 pre_row_printers.reserve(m_pre_dtypes.size());
273
274 for (const DataType dtype : m_pre_dtypes) {
275 pre_row_printers.emplace_back(make_print_array(dtype));
276 }
277
278 std::vector<PrintFn> post_row_printers;
279 post_row_printers.reserve(m_post_dtypes.size());
280
281 for (const DataType dtype : m_post_dtypes) {
282 post_row_printers.emplace_back(make_print_array(dtype));
283 }
284
285 const bool has_per_row_component = !m_pre_dtypes.empty() || !m_post_dtypes.empty();
286
287 const size_t num_block_rows = round_up_division(height, m_block_height);
288 const size_t num_block_cols_padded = round_up_multiple(width, m_width_align) / m_block_width;
289 const size_t block_size =
290 round_up_division(m_block_height * m_block_width * data_type_size_in_bits(m_dtype), 8);
291
292 os << "[\n";
293
294 for (size_t block_row = 0; block_row < num_block_rows; ++block_row) {
295 if (has_per_row_component) {
296 os << " {\n";
297
298 for (size_t i = 0; i < m_pre_dtypes.size(); ++i) {
299 os << " \"row_data_" << i << "\": ";
300 pre_row_printers.at(i)(os, std::array{m_block_height}, data, 0);
301 data =
302 data.subspan(round_up_division(m_block_height * data_type_size_in_bits(m_pre_dtypes.at(i)), 8));
303 os << ",\n";
304 }
305
306 os << " \"data\": [\n";
307
308 for (size_t i = 0; i < num_block_cols_padded; ++i) {
309 data_printer(os, std::array{m_block_height * m_block_width}, data, 3);
310 data = data.subspan(block_size);
311 os << ",\n";
312 }
313
314 os << " ],\n";
315
316 for (size_t i = 0; i < m_post_dtypes.size(); ++i) {
317 os << " \"row_data_" << i + m_pre_dtypes.size() << "\": ";
318 post_row_printers.at(i)(os, std::array{m_block_height}, data, 0);
319 data = data.subspan(
320 round_up_division(m_block_height * data_type_size_in_bits(m_post_dtypes.at(i)), 8));
321 os << ",\n";
322 }
323
324 os << " },\n";
325 } else {
326 for (size_t i = 0; i < num_block_cols_padded; ++i) {
327 data_printer(os, std::array{m_block_height * m_block_width}, data, 1);
328 data = data.subspan(block_size);
329 os << ",\n";
330 }
331 }
332 }
333
334 KAI_TEST_ASSERT(data.empty());
335
336 os << "]";
337 }
338 }
339
340 bool Block2dRowFormat::operator==(const Format& other) const {
341 const auto* rhs = dynamic_cast<const Block2dRowFormat*>(&other);
342 return rhs != nullptr && m_dtype == rhs->m_dtype;
343 }
344
345 } // namespace kai::test
346