test/nextgen/format/block2d_row_format.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #pragma once | ||
| 8 | |||
| 9 | #include <cstddef> | ||
| 10 | #include <vector> | ||
| 11 | |||
| 12 | #include "test/common/assert.hpp" | ||
| 13 | #include "test/common/data_type.hpp" | ||
| 14 | #include "test/common/span.hpp" | ||
| 15 | #include "test/nextgen/common/random.hpp" | ||
| 16 | #include "test/nextgen/format/format.hpp" | ||
| 17 | |||
| 18 | namespace kai::test { | ||
| 19 | |||
| 20 | /// 2D blocked data with optional per-row values. | ||
| 21 | /// | ||
| 22 | /// Example: | ||
| 23 | /// Shape: (5, 8) | ||
| 24 | /// Block size: (2, 3) | ||
| 25 | /// Prefix per-row data 0: | ||
| 26 | /// a0 a1 a2 a3 a4 | ||
| 27 | /// Prefix per-row data 1: | ||
| 28 | /// b0 b1 b2 b3 b4 | ||
| 29 | /// Data: | ||
| 30 | /// v00 v01 v02 v03 v04 v05 v06 v07 | ||
| 31 | /// v10 v11 v12 v13 v14 v15 v16 v17 | ||
| 32 | /// v20 v21 v22 v23 v24 v25 v26 v27 | ||
| 33 | /// v30 v31 v32 v33 v34 v35 v36 v37 | ||
| 34 | /// v40 v41 v42 v43 v44 v45 v46 v47 | ||
| 35 | /// Postfix per-row data 0: | ||
| 36 | /// c0 c1 c2 c3 c4 | ||
| 37 | /// Postfix per-row data 1: | ||
| 38 | /// d0 d1 d2 d3 d4 | ||
| 39 | /// | ||
| 40 | /// Combined blocked data with per-row data: | ||
| 41 | /// +----+----+-------------+--------------+------------+----+----+ | ||
| 42 | /// | a0 | b0 | v00 v01 v02 | v03 v04 v05 | v06 v07 ___ | c0 | d0 | | ||
| 43 | /// | a1 | b1 | v10 v11 v12 | v13 v14 v15 | v16 v17 ___ | c1 | d1 | | ||
| 44 | /// +----+----+-------------+--------------+------------+----+----+ | ||
| 45 | /// | a2 | b2 | v20 v21 v22 | v23 v24 v25 | v26 v27 ___ | c2 | d2 | | ||
| 46 | /// | a3 | b3 | v30 v31 v32 | v33 v34 v35 | v36 v37 ___ | c3 | d3 | | ||
| 47 | /// +----+----+-------------+--------------+------------+----+----+ | ||
| 48 | /// | a4 | b4 | v40 v41 v42 | v43 v44 v45 | v46 v47 ___ | c4 | d4 | | ||
| 49 | /// | __ | __ | ___ ___ ___ | ___ ___ ___ | ___ ___ ___ | __ | __ | | ||
| 50 | /// +----+----+-------------+--------------+------------+----+----+ | ||
| 51 | /// | ||
| 52 | /// Packed data stream: | ||
| 53 | /// +-------+-------+-------------------------+-------------------------+-------------------------+-------+-------+ | ||
| 54 | /// | a0 a1 | b0 b1 | v00 v01 v02 v10 v11 v12 | v03 v04 v05 v13 v14 v15 | v06 v07 0 v16 v17 0 | c0 c1 | d0 d1 | | ||
| 55 | /// +-------+-------+-------------------------+-------------------------+-------------------------+-------+-------+ | ||
| 56 | /// | a2 a3 | b2 b3 | v20 v21 v22 v30 v31 v32 | v23 v24 v25 v33 v34 v35 | v26 v27 0 v36 v37 0 | c2 c3 | d2 d3 | | ||
| 57 | /// +-------+-------+-------------------------+-------------------------+-------------------------+-------+-------+ | ||
| 58 | /// | a4 0 | b4 0 | v40 v41 v42 0 0 0 | v43 v44 v45 0 0 0 | v46 v47 0 0 0 0 | c4 0 | d4 0 | | ||
| 59 | /// +-------+-------+-------------------------+-------------------------+-------------------------+-------+-------+ | ||
| 60 |
1/2✓ Branch 0 taken 818 times.
✗ Branch 1 not taken.
|
818 | class Block2dRowFormat : public Format { |
| 61 | public: | ||
| 62 | /// Creates a 2D blocked data with optional per-row values. | ||
| 63 | /// | ||
| 64 | /// @param[in] block_height The block height. | ||
| 65 | /// @param[in] block_width The block width. | ||
| 66 | /// @param[in] width_align The input data is padded so that the width is multiple of this value | ||
| 67 | /// before the data is packed. This value must be divisible by block width. | ||
| 68 | /// @param[in] pad_right_same Right padding with the last element instead of 0. | ||
| 69 | /// @param[in] dtype The data type. | ||
| 70 | /// @param[in] pre_dtypes The data type of each prefix per-row component. | ||
| 71 | /// @param[in] post_dtypes The data type of each postfix per-row component. | ||
| 72 | 32 | Block2dRowFormat( | |
| 73 | size_t block_height, size_t block_width, size_t width_align, bool pad_right_same, DataType dtype, | ||
| 74 | Span<const DataType> pre_dtypes, Span<const DataType> post_dtypes) : | ||
| 75 | 24 | m_block_height(block_height), | |
| 76 | 24 | m_block_width(block_width), | |
| 77 | 24 | m_width_align(width_align), | |
| 78 | 24 | m_pad_right_same(pad_right_same), | |
| 79 | 24 | m_dtype(dtype), | |
| 80 |
1/2✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
|
24 | m_pre_dtypes(pre_dtypes.begin(), pre_dtypes.end()), |
| 81 |
1/2✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
|
32 | m_post_dtypes(post_dtypes.begin(), post_dtypes.end()) { |
| 82 |
1/4✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
24 | KAI_TEST_ASSERT(width_align % block_width == 0); |
| 83 |
2/6✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
24 | KAI_TEST_ASSERT(block_height * block_width * data_type_size_in_bits(dtype) % 8 == 0); |
| 84 | |||
| 85 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
|
24 | for (const DataType pre_dtype : pre_dtypes) { |
| 86 | ✗ | KAI_TEST_ASSERT(data_type_size_in_bits(pre_dtype) % 8 == 0); | |
| 87 | ✗ | } | |
| 88 | |||
| 89 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 60 times.
|
84 | for (const DataType post_dtype : post_dtypes) { |
| 90 |
2/6✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
60 | KAI_TEST_ASSERT(data_type_size_in_bits(post_dtype) % 8 == 0); |
| 91 | 60 | } | |
| 92 | 32 | } | |
| 93 | |||
| 94 | [[nodiscard]] size_t compute_offset(Span<const size_t> shape, Span<const size_t> indices) const override; | ||
| 95 | [[nodiscard]] size_t compute_size(Span<const size_t> shape) const override; | ||
| 96 | [[nodiscard]] Buffer generate_random(Span<const size_t> shape, Rng& rng) const override; | ||
| 97 | [[nodiscard]] Buffer pack(Span<const size_t> shape, Span<const Span<const std::byte>> buffers) const override; | ||
| 98 | [[nodiscard]] bool compare( | ||
| 99 | Span<const size_t> shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape, | ||
| 100 | Span<const std::byte> imp_buffer, Span<const std::byte> ref_buffer, MismatchHandler& handler) const override; | ||
| 101 | void print(std::ostream& os, Span<const size_t> shape, Span<const std::byte> data) const override; | ||
| 102 | [[nodiscard]] bool operator==(const Format& other) const override; | ||
| 103 | |||
| 104 | private: | ||
| 105 | size_t m_block_height; | ||
| 106 | size_t m_block_width; | ||
| 107 | size_t m_width_align; | ||
| 108 | bool m_pad_right_same; | ||
| 109 | DataType m_dtype; | ||
| 110 | std::vector<DataType> m_pre_dtypes; | ||
| 111 | std::vector<DataType> m_post_dtypes; | ||
| 112 | }; | ||
| 113 | |||
| 114 | } // namespace kai::test | ||
| 115 |