test/nextgen/format/block2d_row_format.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/nextgen/format/block2d_row_format.hpp" | ||
| 8 | |||
| 9 | #include <algorithm> | ||
| 10 | #include <array> | ||
| 11 | #include <cstddef> | ||
| 12 | #include <ostream> | ||
| 13 | #include <vector> | ||
| 14 | |||
| 15 | #include "test/common/assert.hpp" | ||
| 16 | #include "test/common/buffer.hpp" | ||
| 17 | #include "test/common/compare.hpp" | ||
| 18 | #include "test/common/data_type.hpp" | ||
| 19 | #include "test/common/round.hpp" | ||
| 20 | #include "test/common/span.hpp" | ||
| 21 | #include "test/nextgen/common/random.hpp" | ||
| 22 | #include "test/nextgen/format/format.hpp" | ||
| 23 | #include "test/nextgen/reference/compare.hpp" | ||
| 24 | #include "test/nextgen/reference/pack.hpp" | ||
| 25 | #include "test/nextgen/reference/print.hpp" | ||
| 26 | |||
| 27 | namespace kai::test { | ||
| 28 | |||
| 29 | 4400 | size_t Block2dRowFormat::compute_offset(Span<const size_t> shape, Span<const size_t> indices) const { | |
| 30 |
1/4✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
4400 | KAI_TEST_ASSERT(shape.size() == 2); |
| 31 |
1/4✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
4400 | KAI_TEST_ASSERT(shape.size() == indices.size()); |
| 32 | |||
| 33 | 4400 | const size_t height = shape.at(0); | |
| 34 | 4400 | const size_t width = shape.at(1); | |
| 35 | |||
| 36 | 4400 | const size_t row = indices.at(0); | |
| 37 | 4400 | const size_t col = indices.at(1); | |
| 38 | |||
| 39 |
1/4✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
4400 | KAI_TEST_ASSERT(row < height); |
| 40 |
1/4✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
4400 | KAI_TEST_ASSERT(col < width); |
| 41 | |||
| 42 |
1/4✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
4400 | KAI_TEST_ASSERT(row % m_block_height == 0); |
| 43 |
1/4✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
4400 | KAI_TEST_ASSERT(col % m_block_width == 0); |
| 44 | |||
| 45 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4400 times.
|
4400 | const bool has_per_row_component = !m_pre_dtypes.empty() || !m_post_dtypes.empty(); |
| 46 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4400 times.
|
4400 | if (has_per_row_component) { |
| 47 |
1/4✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
4400 | KAI_TEST_ASSERT(col == 0); |
| 48 | 4400 | } | |
| 49 | |||
| 50 | 4400 | const size_t block_row = row / m_block_height; | |
| 51 | 4400 | const size_t block_col = col / m_block_width; | |
| 52 | |||
| 53 | 4400 | const size_t block_size = m_block_height * m_block_width * data_type_size_in_bits(m_dtype) / 8; | |
| 54 | 4400 | const size_t num_blocks_per_row = round_up_multiple(width, m_width_align) / m_block_width; | |
| 55 | |||
| 56 |
1/2✓ Branch 0 taken 4400 times.
✗ Branch 1 not taken.
|
4400 | if (has_per_row_component) { |
| 57 | 4400 | size_t block_row_size = block_size * num_blocks_per_row; | |
| 58 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4400 times.
|
4400 | for (const DataType dtype : m_pre_dtypes) { |
| 59 | ✗ | block_row_size += m_block_height * data_type_size_in_bits(dtype) / 8; | |
| 60 | ✗ | } | |
| 61 |
2/2✓ Branch 0 taken 11000 times.
✓ Branch 1 taken 4400 times.
|
15400 | for (const DataType dtype : m_post_dtypes) { |
| 62 | 11000 | block_row_size += m_block_height * data_type_size_in_bits(dtype) / 8; | |
| 63 | 11000 | } | |
| 64 | |||
| 65 | 4400 | return block_row * block_row_size; | |
| 66 | 4400 | } else { | |
| 67 | ✗ | return (block_row * num_blocks_per_row + block_col) * block_size; | |
| 68 | } | ||
| 69 | 4400 | } | |
| 70 | |||
| 71 | 2000 | size_t Block2dRowFormat::compute_size(Span<const size_t> shape) const { | |
| 72 |
1/4✓ Branch 0 taken 2000 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
2000 | KAI_TEST_ASSERT(shape.size() == 2); |
| 73 | |||
| 74 | 2000 | const size_t height = shape.at(0); | |
| 75 | 2000 | const size_t width = shape.at(1); | |
| 76 | |||
| 77 | 2000 | const size_t padded_height = round_up_multiple(height, m_block_height); | |
| 78 | |||
| 79 | 2000 | const size_t size = compute_offset({padded_height + m_block_height, width}, {padded_height, 0}); | |
| 80 | 4000 | return size; | |
| 81 | 2000 | } | |
| 82 | |||
| 83 | ✗ | Buffer Block2dRowFormat::generate_random([[maybe_unused]] Span<const size_t> shape, [[maybe_unused]] Rng& rng) const { | |
| 84 | ✗ | KAI_TEST_ERROR("Not supported!"); | |
| 85 | ✗ | } | |
| 86 | |||
| 87 | 400 | Buffer Block2dRowFormat::pack(Span<const size_t> shape, Span<const Span<const std::byte>> buffers) const { | |
| 88 |
1/4✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
400 | KAI_TEST_ASSERT(shape.size() == 2); |
| 89 | |||
| 90 | 400 | const size_t height = shape.at(0); | |
| 91 | 400 | const size_t width = shape.at(1); | |
| 92 | 400 | const size_t num_block_rows = round_up_division(height, m_block_height); | |
| 93 | |||
| 94 | 400 | const size_t packed_size = compute_size(shape); | |
| 95 | 400 | Buffer packed_buffer(packed_size, 0); | |
| 96 |
1/2✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
|
400 | Span<std::byte> packed_data(packed_buffer); |
| 97 | |||
| 98 |
1/2✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
|
400 | const PackBlock2dFn pack_data_fn = make_pack_block2d(m_dtype); |
| 99 | |||
| 100 | 400 | const size_t num_pres = m_pre_dtypes.size(); | |
| 101 | 400 | std::vector<Span<const std::byte>> pre_buffers; | |
| 102 |
1/2✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
|
400 | pre_buffers.reserve(num_pres); |
| 103 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 400 times.
|
400 | for (size_t i = 0; i < num_pres; ++i) { |
| 104 | ✗ | pre_buffers.emplace_back(buffers.at(i)); | |
| 105 | ✗ | } | |
| 106 | |||
| 107 |
1/2✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
|
400 | Span<const std::byte> data_buffer = buffers.at(num_pres); |
| 108 | |||
| 109 | 400 | const size_t num_posts = m_post_dtypes.size(); | |
| 110 | 400 | std::vector<Span<const std::byte>> post_buffers; | |
| 111 |
1/2✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
|
400 | post_buffers.reserve(num_posts); |
| 112 |
2/2✓ Branch 0 taken 1000 times.
✓ Branch 1 taken 400 times.
|
1400 | for (size_t i = 0; i < num_posts; ++i) { |
| 113 |
2/4✓ Branch 0 taken 1000 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1000 times.
✗ Branch 3 not taken.
|
1000 | post_buffers.emplace_back(buffers.at(num_pres + 1 + i)); |
| 114 | 1000 | } | |
| 115 | |||
| 116 |
2/2✓ Branch 0 taken 9162 times.
✓ Branch 1 taken 400 times.
|
9562 | for (size_t block_row = 0; block_row < num_block_rows; ++block_row) { |
| 117 |
1/2✓ Branch 0 taken 9162 times.
✗ Branch 1 not taken.
|
9162 | const size_t remaining_height = std::min(m_block_height, height - block_row * m_block_height); |
| 118 | |||
| 119 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 9162 times.
|
9162 | for (size_t i = 0; i < num_pres; ++i) { |
| 120 | ✗ | const size_t copy_size = remaining_height * data_type_size_in_bits(m_pre_dtypes.at(i)) / 8; | |
| 121 | ✗ | Span<const std::byte>& data = pre_buffers.at(i); | |
| 122 | |||
| 123 | ✗ | std::copy_n(data.begin(), copy_size, packed_data.begin()); | |
| 124 | |||
| 125 | ✗ | data = data.subspan(copy_size); | |
| 126 | ✗ | packed_data = packed_data.subspan(m_block_height * data_type_size_in_bits(m_pre_dtypes.at(i)) / 8); | |
| 127 | ✗ | } | |
| 128 | |||
| 129 | { | ||
| 130 |
4/8✓ Branch 0 taken 9162 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9162 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9162 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 9162 times.
✗ Branch 7 not taken.
|
36648 | const size_t size = pack_data_fn( |
| 131 | 18324 | m_block_height, m_block_width, m_width_align, m_pad_right_same, remaining_height, width, packed_data, | |
| 132 | 9162 | data_buffer); | |
| 133 | 9162 | data_buffer = | |
| 134 |
3/6✓ Branch 0 taken 9162 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9162 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9162 times.
✗ Branch 5 not taken.
|
9162 | data_buffer.subspan(remaining_height * round_up_division(width * data_type_size_in_bits(m_dtype), 8)); |
| 135 |
1/2✓ Branch 0 taken 9162 times.
✗ Branch 1 not taken.
|
9162 | packed_data = packed_data.subspan(size); |
| 136 | 9162 | } | |
| 137 | |||
| 138 |
2/2✓ Branch 0 taken 9162 times.
✓ Branch 1 taken 18685 times.
|
27847 | for (size_t i = 0; i < num_posts; ++i) { |
| 139 |
2/4✓ Branch 0 taken 18685 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18685 times.
✗ Branch 3 not taken.
|
18685 | const size_t copy_size = remaining_height * data_type_size_in_bits(m_post_dtypes.at(i)) / 8; |
| 140 |
1/2✓ Branch 0 taken 18685 times.
✗ Branch 1 not taken.
|
18685 | Span<const std::byte>& data = post_buffers.at(i); |
| 141 | |||
| 142 |
1/2✓ Branch 0 taken 18685 times.
✗ Branch 1 not taken.
|
18685 | std::copy_n(data.begin(), copy_size, packed_data.begin()); |
| 143 | |||
| 144 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 18685 times.
|
18685 | data = data.subspan(copy_size); |
| 145 |
3/6✓ Branch 0 taken 18685 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18685 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18685 times.
✗ Branch 5 not taken.
|
18685 | packed_data = packed_data.subspan(m_block_height * data_type_size_in_bits(m_post_dtypes.at(i)) / 8); |
| 146 | 18685 | } | |
| 147 | 9162 | } | |
| 148 | |||
| 149 |
1/6✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
400 | KAI_TEST_ASSERT(data_buffer.empty()); |
| 150 |
1/4✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
400 | KAI_TEST_ASSERT(packed_data.empty()); |
| 151 | |||
| 152 | 400 | return packed_buffer; | |
| 153 | 400 | } | |
| 154 | |||
| 155 | 1200 | bool Block2dRowFormat::compare( | |
| 156 | Span<const size_t> shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape, | ||
| 157 | Span<const std::byte> imp_buffer, Span<const std::byte> ref_buffer, MismatchHandler& handler) const { | ||
| 158 |
1/4✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
1200 | KAI_TEST_ASSERT(shape.size() == 2); |
| 159 |
1/4✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
1200 | KAI_TEST_ASSERT(shape.size() == tile_coords.size()); |
| 160 |
1/4✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
1200 | KAI_TEST_ASSERT(shape.size() == tile_shape.size()); |
| 161 | |||
| 162 | 1200 | const size_t height = shape.at(0); | |
| 163 | 1200 | const size_t width = shape.at(1); | |
| 164 | |||
| 165 | 1200 | const size_t tile_row = tile_coords.at(0); | |
| 166 | 1200 | const size_t tile_col = tile_coords.at(1); | |
| 167 | |||
| 168 | 1200 | const size_t tile_height = tile_shape.at(0); | |
| 169 | 1200 | size_t tile_width = tile_shape.at(1); | |
| 170 | |||
| 171 |
1/4✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
1200 | KAI_TEST_ASSERT(tile_row % m_block_height == 0); |
| 172 |
1/4✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
1200 | KAI_TEST_ASSERT(tile_col % m_block_width == 0); |
| 173 |
3/6✓ Branch 0 taken 314 times.
✓ Branch 1 taken 886 times.
✓ Branch 2 taken 314 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
1200 | KAI_TEST_ASSERT(tile_row + tile_height == height || (tile_row + tile_height) % m_block_height == 0); |
| 174 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 1200 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
1200 | KAI_TEST_ASSERT(tile_col + tile_width == width || (tile_col + tile_width) % m_block_width == 0); |
| 175 | |||
| 176 |
2/2✓ Branch 0 taken 600 times.
✓ Branch 1 taken 600 times.
|
1200 | if (m_pad_right_same) { |
| 177 | // If the tile includes the last block column, extends the tile to cover the right padding blocks. | ||
| 178 | // In SAME padding mode, these blocks contain data even though they are outside the tile of interests. | ||
| 179 | // If we don't extend the tile, there will be mismatched because these data points are outside the tile | ||
| 180 | // and the data is not 0. | ||
| 181 | 600 | tile_width = round_up_multiple(tile_col + tile_width, m_width_align) - tile_col; | |
| 182 | 600 | } | |
| 183 | |||
| 184 | 1200 | const size_t num_pre_rows = m_pre_dtypes.size(); | |
| 185 | 1200 | std::vector<CompareFn> pre_compares; | |
| 186 |
1/2✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
|
1200 | pre_compares.reserve(num_pre_rows); |
| 187 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1200 times.
|
1200 | for (const DataType dtype : m_pre_dtypes) { |
| 188 | ✗ | pre_compares.emplace_back(make_compare_plain_2d(dtype)); | |
| 189 | ✗ | } | |
| 190 | |||
| 191 |
1/2✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
|
1200 | const CompareFn data_compare = make_compare_plain_2d(m_dtype); |
| 192 | |||
| 193 | 1200 | const size_t num_post_rows = m_post_dtypes.size(); | |
| 194 | 1200 | std::vector<CompareFn> post_compares; | |
| 195 |
1/2✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
|
1200 | post_compares.reserve(num_post_rows); |
| 196 |
2/2✓ Branch 0 taken 3000 times.
✓ Branch 1 taken 1200 times.
|
4200 | for (const DataType dtype : m_post_dtypes) { |
| 197 |
2/4✓ Branch 0 taken 3000 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3000 times.
✗ Branch 3 not taken.
|
3000 | post_compares.emplace_back(make_compare_plain_2d(dtype)); |
| 198 | 3000 | } | |
| 199 | |||
| 200 |
1/2✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
|
1200 | const size_t num_block_rows = round_up_division(height, m_block_height); |
| 201 |
1/2✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
|
1200 | const size_t num_block_cols_padded = round_up_multiple(width, m_width_align) / m_block_width; |
| 202 |
2/4✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1200 times.
✗ Branch 3 not taken.
|
1200 | const size_t block_size = round_up_division(m_block_height * m_block_width * data_type_size_in_bits(m_dtype), 8); |
| 203 | |||
| 204 | 1200 | const size_t tile_block_col = tile_col / m_block_width; | |
| 205 |
1/2✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
|
1200 | const size_t tile_num_block_cols = round_up_division(tile_width, m_block_width); |
| 206 | |||
| 207 | 1200 | size_t num_checks = 0; | |
| 208 | |||
| 209 |
2/2✓ Branch 0 taken 27486 times.
✓ Branch 1 taken 1200 times.
|
28686 | for (size_t block_row = 0; block_row < num_block_rows; ++block_row) { |
| 210 | 48400 | const bool block_row_in_tile = | |
| 211 |
2/2✓ Branch 0 taken 6572 times.
✓ Branch 1 taken 20914 times.
|
27486 | tile_row <= block_row * m_block_height && tile_row + tile_height > block_row * m_block_height; |
| 212 | |||
| 213 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 27486 times.
|
27486 | for (size_t i = 0; i < num_pre_rows; ++i) { |
| 214 | ✗ | num_checks += pre_compares.at(i)( | |
| 215 | ✗ | {1, m_block_height}, {0, 0}, {1, block_row_in_tile ? m_block_height : 0}, imp_buffer, ref_buffer, | |
| 216 | ✗ | [&](std::ostream& os, Span<const size_t> coords) { | |
| 217 | ✗ | os << "Mismatched at block row " << block_row << ", prefix per-row component " << i << ", element " | |
| 218 | ✗ | << coords.at(1); | |
| 219 | ✗ | }, | |
| 220 | ✗ | handler); | |
| 221 | |||
| 222 | ✗ | imp_buffer = imp_buffer.subspan(m_block_height * data_type_size_in_bits(m_pre_dtypes.at(i)) / 8); | |
| 223 | ✗ | ref_buffer = ref_buffer.subspan(m_block_height * data_type_size_in_bits(m_pre_dtypes.at(i)) / 8); | |
| 224 | ✗ | } | |
| 225 | |||
| 226 | { | ||
| 227 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 27486 times.
|
137430 | num_checks += data_compare( |
| 228 | 54972 | {num_block_cols_padded, m_block_height * m_block_width}, {tile_block_col, 0}, | |
| 229 |
2/2✓ Branch 0 taken 14237 times.
✓ Branch 1 taken 13249 times.
|
27486 | {tile_num_block_cols, block_row_in_tile ? m_block_height * m_block_width : 0}, imp_buffer, ref_buffer, |
| 230 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 27486 times.
|
27486 | [&](std::ostream& os, Span<const size_t> coords) { |
| 231 | ✗ | os << "Mismatched at block row " << block_row << ", blocked data, block column " << coords.at(0) | |
| 232 | ✗ | << ", element " << coords.at(1); | |
| 233 | ✗ | }, | |
| 234 | 27486 | handler); | |
| 235 | |||
| 236 |
1/2✓ Branch 0 taken 27486 times.
✗ Branch 1 not taken.
|
27486 | imp_buffer = imp_buffer.subspan(num_block_cols_padded * block_size); |
| 237 |
1/2✓ Branch 0 taken 27486 times.
✗ Branch 1 not taken.
|
27486 | ref_buffer = ref_buffer.subspan(num_block_cols_padded * block_size); |
| 238 | } | ||
| 239 | |||
| 240 |
2/2✓ Branch 0 taken 56055 times.
✓ Branch 1 taken 27486 times.
|
83541 | for (size_t i = 0; i < num_post_rows; ++i) { |
| 241 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 56055 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 56055 times.
|
168165 | num_checks += post_compares.at(i)( |
| 242 |
6/6✓ Branch 0 taken 29300 times.
✓ Branch 1 taken 26755 times.
✓ Branch 2 taken 29300 times.
✓ Branch 3 taken 26755 times.
✓ Branch 4 taken 29300 times.
✓ Branch 5 taken 26755 times.
|
168165 | {1, m_block_height}, {0, 0}, {1, block_row_in_tile ? m_block_height : 0}, imp_buffer, ref_buffer, |
| 243 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 56055 times.
|
56055 | [&](std::ostream& os, Span<const size_t> coords) { |
| 244 | ✗ | os << "Mismatched at block row " << block_row << ", postfix per-row component " << i << ", element " | |
| 245 | ✗ | << coords.at(1); | |
| 246 | ✗ | }, | |
| 247 | 56055 | handler); | |
| 248 | |||
| 249 |
3/6✓ Branch 0 taken 56055 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 56055 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 56055 times.
✗ Branch 5 not taken.
|
56055 | imp_buffer = imp_buffer.subspan(m_block_height * data_type_size_in_bits(m_post_dtypes.at(i)) / 8); |
| 250 |
3/6✓ Branch 0 taken 56055 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 56055 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 56055 times.
✗ Branch 5 not taken.
|
56055 | ref_buffer = ref_buffer.subspan(m_block_height * data_type_size_in_bits(m_post_dtypes.at(i)) / 8); |
| 251 | 56055 | } | |
| 252 | 27486 | } | |
| 253 | |||
| 254 |
1/6✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
1200 | KAI_TEST_ASSERT(imp_buffer.empty()); |
| 255 |
1/4✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
1200 | KAI_TEST_ASSERT(ref_buffer.empty()); |
| 256 | |||
| 257 |
1/2✓ Branch 0 taken 1200 times.
✗ Branch 1 not taken.
|
1200 | return handler.success(num_checks); |
| 258 | 1200 | } | |
| 259 | |||
| 260 | ✗ | void Block2dRowFormat::print(std::ostream& os, Span<const size_t> shape, Span<const std::byte> data) const { | |
| 261 | ✗ | if (shape.empty()) { | |
| 262 | ✗ | os << "None"; | |
| 263 | ✗ | } else { | |
| 264 | ✗ | KAI_TEST_ASSERT(shape.size() == 2); | |
| 265 | |||
| 266 | ✗ | const size_t height = shape.at(0); | |
| 267 | ✗ | const size_t width = shape.at(1); | |
| 268 | |||
| 269 | ✗ | const PrintFn data_printer = make_print_array(m_dtype); | |
| 270 | |||
| 271 | ✗ | std::vector<PrintFn> pre_row_printers; | |
| 272 | ✗ | pre_row_printers.reserve(m_pre_dtypes.size()); | |
| 273 | |||
| 274 | ✗ | for (const DataType dtype : m_pre_dtypes) { | |
| 275 | ✗ | pre_row_printers.emplace_back(make_print_array(dtype)); | |
| 276 | ✗ | } | |
| 277 | |||
| 278 | ✗ | std::vector<PrintFn> post_row_printers; | |
| 279 | ✗ | post_row_printers.reserve(m_post_dtypes.size()); | |
| 280 | |||
| 281 | ✗ | for (const DataType dtype : m_post_dtypes) { | |
| 282 | ✗ | post_row_printers.emplace_back(make_print_array(dtype)); | |
| 283 | ✗ | } | |
| 284 | |||
| 285 | ✗ | const bool has_per_row_component = !m_pre_dtypes.empty() || !m_post_dtypes.empty(); | |
| 286 | |||
| 287 | ✗ | const size_t num_block_rows = round_up_division(height, m_block_height); | |
| 288 | ✗ | const size_t num_block_cols_padded = round_up_multiple(width, m_width_align) / m_block_width; | |
| 289 | ✗ | const size_t block_size = | |
| 290 | ✗ | round_up_division(m_block_height * m_block_width * data_type_size_in_bits(m_dtype), 8); | |
| 291 | |||
| 292 | ✗ | os << "[\n"; | |
| 293 | |||
| 294 | ✗ | for (size_t block_row = 0; block_row < num_block_rows; ++block_row) { | |
| 295 | ✗ | if (has_per_row_component) { | |
| 296 | ✗ | os << " {\n"; | |
| 297 | |||
| 298 | ✗ | for (size_t i = 0; i < m_pre_dtypes.size(); ++i) { | |
| 299 | ✗ | os << " \"row_data_" << i << "\": "; | |
| 300 | ✗ | pre_row_printers.at(i)(os, std::array{m_block_height}, data, 0); | |
| 301 | ✗ | data = | |
| 302 | ✗ | data.subspan(round_up_division(m_block_height * data_type_size_in_bits(m_pre_dtypes.at(i)), 8)); | |
| 303 | ✗ | os << ",\n"; | |
| 304 | ✗ | } | |
| 305 | |||
| 306 | ✗ | os << " \"data\": [\n"; | |
| 307 | |||
| 308 | ✗ | for (size_t i = 0; i < num_block_cols_padded; ++i) { | |
| 309 | ✗ | data_printer(os, std::array{m_block_height * m_block_width}, data, 3); | |
| 310 | ✗ | data = data.subspan(block_size); | |
| 311 | ✗ | os << ",\n"; | |
| 312 | ✗ | } | |
| 313 | |||
| 314 | ✗ | os << " ],\n"; | |
| 315 | |||
| 316 | ✗ | for (size_t i = 0; i < m_post_dtypes.size(); ++i) { | |
| 317 | ✗ | os << " \"row_data_" << i + m_pre_dtypes.size() << "\": "; | |
| 318 | ✗ | post_row_printers.at(i)(os, std::array{m_block_height}, data, 0); | |
| 319 | ✗ | data = data.subspan( | |
| 320 | ✗ | round_up_division(m_block_height * data_type_size_in_bits(m_post_dtypes.at(i)), 8)); | |
| 321 | ✗ | os << ",\n"; | |
| 322 | ✗ | } | |
| 323 | |||
| 324 | ✗ | os << " },\n"; | |
| 325 | ✗ | } else { | |
| 326 | ✗ | for (size_t i = 0; i < num_block_cols_padded; ++i) { | |
| 327 | ✗ | data_printer(os, std::array{m_block_height * m_block_width}, data, 1); | |
| 328 | ✗ | data = data.subspan(block_size); | |
| 329 | ✗ | os << ",\n"; | |
| 330 | ✗ | } | |
| 331 | } | ||
| 332 | ✗ | } | |
| 333 | |||
| 334 | ✗ | KAI_TEST_ASSERT(data.empty()); | |
| 335 | |||
| 336 | ✗ | os << "]"; | |
| 337 | ✗ | } | |
| 338 | ✗ | } | |
| 339 | |||
| 340 | ✗ | bool Block2dRowFormat::operator==(const Format& other) const { | |
| 341 | ✗ | const auto* rhs = dynamic_cast<const Block2dRowFormat*>(&other); | |
| 342 | ✗ | return rhs != nullptr && m_dtype == rhs->m_dtype; | |
| 343 | ✗ | } | |
| 344 | |||
| 345 | } // namespace kai::test | ||
| 346 |