KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 88.9% 16 / 0 / 18
Functions: 100.0% 3 / 0 / 3
Branches: 34.4% 11 / 0 / 32

test/nextgen/format/block2d_row_format.hpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #pragma once
8
9 #include <cstddef>
10 #include <vector>
11
12 #include "test/common/assert.hpp"
13 #include "test/common/data_type.hpp"
14 #include "test/common/span.hpp"
15 #include "test/nextgen/common/random.hpp"
16 #include "test/nextgen/format/format.hpp"
17
18 namespace kai::test {
19
20 /// 2D blocked data format with optional per-row values.
21 ///
22 /// Example:
23 /// Shape: (5, 8)
24 /// Block size: (2, 3)
25 /// Prefix per-row data 0:
26 /// a0 a1 a2 a3 a4
27 /// Prefix per-row data 1:
28 /// b0 b1 b2 b3 b4
29 /// Data:
30 /// v00 v01 v02 v03 v04 v05 v06 v07
31 /// v10 v11 v12 v13 v14 v15 v16 v17
32 /// v20 v21 v22 v23 v24 v25 v26 v27
33 /// v30 v31 v32 v33 v34 v35 v36 v37
34 /// v40 v41 v42 v43 v44 v45 v46 v47
35 /// Postfix per-row data 0:
36 /// c0 c1 c2 c3 c4
37 /// Postfix per-row data 1:
38 /// d0 d1 d2 d3 d4
39 ///
40 /// Combined blocked data with per-row data:
41 /// +----+----+-------------+-------------+-------------+----+----+
42 /// | a0 | b0 | v00 v01 v02 | v03 v04 v05 | v06 v07 ___ | c0 | d0 |
43 /// | a1 | b1 | v10 v11 v12 | v13 v14 v15 | v16 v17 ___ | c1 | d1 |
44 /// +----+----+-------------+-------------+-------------+----+----+
45 /// | a2 | b2 | v20 v21 v22 | v23 v24 v25 | v26 v27 ___ | c2 | d2 |
46 /// | a3 | b3 | v30 v31 v32 | v33 v34 v35 | v36 v37 ___ | c3 | d3 |
47 /// +----+----+-------------+-------------+-------------+----+----+
48 /// | a4 | b4 | v40 v41 v42 | v43 v44 v45 | v46 v47 ___ | c4 | d4 |
49 /// | __ | __ | ___ ___ ___ | ___ ___ ___ | ___ ___ ___ | __ | __ |
50 /// +----+----+-------------+-------------+-------------+----+----+
51 ///
52 /// Packed data stream:
53 /// +-------+-------+-------------------------+-------------------------+-------------------------+-------+-------+
54 /// | a0 a1 | b0 b1 | v00 v01 v02 v10 v11 v12 | v03 v04 v05 v13 v14 v15 | v06 v07 0   v16 v17 0   | c0 c1 | d0 d1 |
55 /// +-------+-------+-------------------------+-------------------------+-------------------------+-------+-------+
56 /// | a2 a3 | b2 b3 | v20 v21 v22 v30 v31 v32 | v23 v24 v25 v33 v34 v35 | v26 v27 0   v36 v37 0   | c2 c3 | d2 d3 |
57 /// +-------+-------+-------------------------+-------------------------+-------------------------+-------+-------+
58 /// | a4 0  | b4 0  | v40 v41 v42 0   0   0   | v43 v44 v45 0   0   0   | v46 v47 0   0   0   0   | c4 0  | d4 0  |
59 /// +-------+-------+-------------------------+-------------------------+-------------------------+-------+-------+
60
1/2
✓ Branch 0 taken 818 times.
✗ Branch 1 not taken.
818 class Block2dRowFormat : public Format {
61 public:
62 /// Creates a 2D blocked data format with optional per-row values.
63 ///
64 /// @param[in] block_height The block height.
65 /// @param[in] block_width The block width.
66 /// @param[in] width_align The input data is padded so that the width is a multiple of this value
67 /// before the data is packed. This value must be divisible by the block width.
68 /// @param[in] pad_right_same Whether to pad on the right with the last element instead of 0.
69 /// @param[in] dtype The data type.
70 /// @param[in] pre_dtypes The data type of each prefix per-row component.
71 /// @param[in] post_dtypes The data type of each postfix per-row component.
72 32 Block2dRowFormat(
73 size_t block_height, size_t block_width, size_t width_align, bool pad_right_same, DataType dtype,
74 Span<const DataType> pre_dtypes, Span<const DataType> post_dtypes) :
75 24 m_block_height(block_height),
76 24 m_block_width(block_width),
77 24 m_width_align(width_align),
78 24 m_pad_right_same(pad_right_same),
79 24 m_dtype(dtype),
80
1/2
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
24 m_pre_dtypes(pre_dtypes.begin(), pre_dtypes.end()),
81
1/2
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
32 m_post_dtypes(post_dtypes.begin(), post_dtypes.end()) {
82
1/4
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
24 KAI_TEST_ASSERT(width_align % block_width == 0);
83
2/6
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
24 KAI_TEST_ASSERT(block_height * block_width * data_type_size_in_bits(dtype) % 8 == 0);
84
85
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
24 for (const DataType pre_dtype : pre_dtypes) {
86 KAI_TEST_ASSERT(data_type_size_in_bits(pre_dtype) % 8 == 0);
87 }
88
89
2/2
✓ Branch 0 taken 24 times.
✓ Branch 1 taken 60 times.
84 for (const DataType post_dtype : post_dtypes) {
90
2/6
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
60 KAI_TEST_ASSERT(data_type_size_in_bits(post_dtype) % 8 == 0);
91 60 }
92 32 }
93
94 [[nodiscard]] size_t compute_offset(Span<const size_t> shape, Span<const size_t> indices) const override;
95 [[nodiscard]] size_t compute_size(Span<const size_t> shape) const override;
96 [[nodiscard]] Buffer generate_random(Span<const size_t> shape, Rng& rng) const override;
97 [[nodiscard]] Buffer pack(Span<const size_t> shape, Span<const Span<const std::byte>> buffers) const override;
98 [[nodiscard]] bool compare(
99 Span<const size_t> shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape,
100 Span<const std::byte> imp_buffer, Span<const std::byte> ref_buffer, MismatchHandler& handler) const override;
101 void print(std::ostream& os, Span<const size_t> shape, Span<const std::byte> data) const override;
102 [[nodiscard]] bool operator==(const Format& other) const override;
103
104 private:
105 size_t m_block_height;
106 size_t m_block_width;
107 size_t m_width_align;
108 bool m_pad_right_same;
109 DataType m_dtype;
110 std::vector<DataType> m_pre_dtypes;
111 std::vector<DataType> m_post_dtypes;
112 };
113
114 } // namespace kai::test
115