KleidiAI Coverage Report


Directory: ./
File: test/common/data_format.hpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 100.0% 4 0 4
Functions: 100.0% 2 0 2
Branches: -% 0 0 0

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #pragma once
8
#include <cstddef>
#include <cstdint>
#include <functional>
#include <type_traits>
12
13 #include "test/common/data_type.hpp"
14
15 namespace kai::test {
16
17 /// Data format.
18 class DataFormat {
19 public:
20 /// Packing format.
21 enum class PackFormat : uint32_t {
22 NONE, ///< No quantization information is included.
23 BIAS_PER_ROW, ///< Per-row bias.
24 QUANTIZE_PER_ROW, ///< Per-row quantization.
25 };
26
27 /// Creates a new data format.
28 ///
29 /// @param[in] data_type Data type of data value.
30 /// @param[in] block_height Block height.
31 /// @param[in] block_width Block width.
32 /// @param[in] pack_format Packing format.
33 /// @param[in] zero_point_dt Data type of zero point value.
34 /// @param[in] scale_dt Data type of scale value.
35 /// @param[in] subblock_height Sub-block height.
36 /// @param[in] subblock_width Sub-block width.
37 DataFormat(
38 DataType data_type = DataType::UNKNOWN, size_t block_height = 0, size_t block_width = 0,
39 PackFormat pack_format = PackFormat::NONE, DataType zero_point_dt = DataType::UNKNOWN,
40 DataType scale_dt = DataType::UNKNOWN, size_t subblock_height = 0, size_t subblock_width = 0) noexcept;
41
42 /// Equality operator.
43 [[nodiscard]] bool operator==(const DataFormat& rhs) const;
44
45 /// Unequality operator.
46 [[nodiscard]] bool operator!=(const DataFormat& rhs) const;
47
48 /// Gets the packing format.
49 [[nodiscard]] PackFormat pack_format() const;
50
51 /// Gets the data type of data value.
52 [[nodiscard]] DataType data_type() const;
53
54 /// Gets the data type of scale value.
55 [[nodiscard]] DataType scale_data_type() const;
56
57 /// Gets the data type of zero point value.
58 [[nodiscard]] DataType zero_point_data_type() const;
59
60 /// Gets a value indicating whether this format has no blocking or packing information.
61 [[nodiscard]] bool is_raw() const;
62
63 /// Gets the block height.
64 [[nodiscard]] size_t block_height() const;
65
66 /// Gets the block width.
67 [[nodiscard]] size_t block_width() const;
68
69 /// Gets the sub-block height.
70 [[nodiscard]] size_t subblock_height() const;
71
72 /// Gets the sub-block width.
73 [[nodiscard]] size_t subblock_width() const;
74
75 /// Gets the block height given the full height of the matrix.
76 ///
77 /// @param[in] full_height Height of the full matrix.
78 ///
79 /// @return The block height.
80 [[nodiscard]] size_t actual_block_height(size_t full_height) const;
81
82 /// Gets the block width given the full width of the matrix.
83 ///
84 /// @param[in] full_width Width of the full matrix.
85 ///
86 /// @return The block width.
87 [[nodiscard]] size_t actual_block_width(size_t full_width) const;
88
89 /// Gets the sub-block height given the full height of the matrix.
90 ///
91 /// @param[in] full_height Height of the full matrix.
92 ///
93 /// @return The sub-block height.
94 [[nodiscard]] size_t actual_subblock_height(size_t full_height) const;
95
96 /// Gets the sub-block width given the full width of the matrix.
97 ///
98 /// @param[in] full_width Width of the full matrix.
99 ///
100 /// @return The sub-block width.
101 [[nodiscard]] size_t actual_subblock_width(size_t full_width) const;
102
103 /// Gets the scheduling block height.
104 ///
105 /// @param[in] full_height Height of the full matrix.
106 ///
107 /// @return The block height for scheduling purpose.
108 [[nodiscard]] size_t scheduler_block_height(size_t full_height) const;
109
110 /// Gets the scheduling block width.
111 ///
112 /// @param[in] full_width Width of the full matrix.
113 ///
114 /// @return The block width for scheduling purpose.
115 [[nodiscard]] size_t scheduler_block_width(size_t full_width) const;
116
117 /// Gets the row stride in bytes given the data is stored continuously without any gap in the memory.
118 ///
119 /// In case of per-row bias or quantization, the row stride is the number of bytes from one row group
120 /// to the next. One row group consists of `block_height` rows.
121 ///
122 /// @param[in] width Width of the full matrix.
123 ///
124 /// @return The default row stride in bytes of the matrix.
125 [[nodiscard]] uintptr_t default_row_stride(size_t width) const;
126
127 /// Gets the offsets in bytes in the data buffer given the data is stored continuously
128 /// without any gap in the memory.
129 ///
130 /// @param[in] row Row coordinate.
131 /// @param[in] col Colum coordinate.
132 /// @param[in] width Width of the full matrix.
133 ///
134 /// @return The default offset in bytes.
135 [[nodiscard]] uintptr_t default_offset_in_bytes(size_t row, size_t col, size_t width) const;
136
137 /// Gets the size in bytes of the matrix given the data is stored continuously without any gap in the memory.
138 ///
139 /// @param[in] height Height of the full matrix.
140 /// @param[in] width Width of the full matrix.
141 ///
142 /// @return The size in bytes of the matrix.
143 [[nodiscard]] size_t default_size_in_bytes(size_t height, size_t width) const;
144
145 /// Hash functor
146 struct Hash {
147 size_t operator()(const DataFormat& format) const;
148 };
149
150 private:
151 DataType _data_type;
152 PackFormat _pack_format;
153 DataType _scale_dt;
154 DataType _zero_point_dt;
155 size_t _block_height;
156 size_t _block_width;
157 size_t _subblock_height;
158 size_t _subblock_width;
159 };
160
161 } // namespace kai::test
162
163 template <>
164 struct std::hash<kai::test::DataFormat> {
165 3729 size_t operator()(const kai::test::DataFormat& df) const {
166 3729 return kai::test::DataFormat::Hash{}(df);
167 }
168 };
169
170 template <>
171 struct std::hash<kai::test::DataFormat::PackFormat> {
172 120417 size_t operator()(const kai::test::DataFormat::PackFormat& pf) const {
173 using PF = std::underlying_type_t<kai::test::DataFormat::PackFormat>;
174 120417 return std::hash<PF>{}(static_cast<PF>(pf));
175 }
176 };
177