Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <functional> | ||
13 | #include <limits> | ||
14 | #include <map> | ||
15 | #include <string_view> | ||
16 | #include <tuple> | ||
17 | #include <utility> | ||
18 | #include <vector> | ||
19 | |||
20 | #include "kai/kai_common.h" | ||
21 | #include "test/common/buffer.hpp" | ||
22 | #include "test/common/compare.hpp" | ||
23 | #include "test/common/cpu_info.hpp" | ||
24 | #include "test/common/data_format.hpp" | ||
25 | #include "test/common/data_type.hpp" | ||
26 | #include "test/common/matmul_test_common.hpp" | ||
27 | #include "test/common/matrix_portion.hpp" | ||
28 | #include "test/common/printer.hpp" | ||
29 | #include "test/reference/fill.hpp" | ||
30 | #include "test/reference/matmul.hpp" | ||
31 | #include "test/reference/pack.hpp" | ||
32 | |||
33 | // matmul_clamp_f16_bf16p_bf16p | ||
34 | #include "kai/ukernels/matmul/matmul_clamp_f16_bf16p_bf16p/kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" | ||
35 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.h" | ||
36 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.h" | ||
37 | namespace kai::test { | ||
38 | |||
39 | /// List of supported matrix multiplication methods. | ||
40 | namespace { | ||
41 | |||
42 | 1 | static const std::array<MatMulMethod, 2>& get_matmul_methods() { | |
43 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
1 | static std::array<MatMulMethod, 2> matmul_methods{}; |
44 | |||
45 | matmul_methods[0].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla"; | ||
46 | matmul_methods[0].m0 = 8; | ||
47 | matmul_methods[0].n0 = 12; | ||
48 | matmul_methods[0].k0 = 4; | ||
49 | matmul_methods[0].dst_format = DataFormat(DataType::FP16); | ||
50 | matmul_methods[0].lhs_format = DataFormat(DataType::FP16); | ||
51 | matmul_methods[0].packed_lhs_format = | ||
52 | DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); | ||
53 | matmul_methods[0].rhs_format = DataFormat(DataType::FP16); | ||
54 | matmul_methods[0].packed_rhs_format = DataFormat( | ||
55 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4); | ||
56 | matmul_methods[0].bias_format = DataFormat(DataType::FP16); | ||
57 | matmul_methods[0].fn_is_supported = cpu_has_bf16; | ||
58 | matmul_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
59 | matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
60 | matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
61 | matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
62 | matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
63 | matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; | ||
64 | matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
65 | matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; | ||
66 | matmul_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; | ||
67 | matmul_methods[0].fn_get_packed_lhs_offset = | ||
68 | kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
69 | matmul_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; | ||
70 | matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; | ||
71 | matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; | ||
72 | matmul_methods[0].fn_get_main_packed_rhs_offset = | ||
73 | kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
74 | matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; | ||
75 | matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; | ||
76 | matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
77 | matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
78 | matmul_methods[0].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
79 | |||
80 | matmul_methods[1].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla_opt_bias"; | ||
81 | matmul_methods[1].m0 = 8; | ||
82 | matmul_methods[1].n0 = 12; | ||
83 | matmul_methods[1].k0 = 4; | ||
84 | matmul_methods[1].dst_format = DataFormat(DataType::FP16); | ||
85 | matmul_methods[1].lhs_format = DataFormat(DataType::FP16); | ||
86 | matmul_methods[1].packed_lhs_format = | ||
87 | DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); | ||
88 | matmul_methods[1].rhs_format = DataFormat(DataType::FP16); | ||
89 | matmul_methods[1].packed_rhs_format = DataFormat( | ||
90 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4); | ||
91 | matmul_methods[1].bias_format = DataFormat(DataType::UNKNOWN); | ||
92 | matmul_methods[1].fn_is_supported = cpu_has_bf16; | ||
93 | matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
94 | matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
95 | matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
96 | matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
97 | matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
98 | matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; | ||
99 | matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
100 | matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; | ||
101 | matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; | ||
102 | matmul_methods[1].fn_get_packed_lhs_offset = | ||
103 | kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
104 | matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; | ||
105 | matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; | ||
106 | matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; | ||
107 | matmul_methods[1].fn_get_main_packed_rhs_offset = | ||
108 | kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
109 | matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; | ||
110 | matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon; | ||
111 | matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
112 | matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
113 | matmul_methods[1].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
114 | |||
115 | return matmul_methods; | ||
116 | } | ||
117 | |||
118 | } // namespace | ||
119 | |||
120 | /// Matrix multiplication test fixture. | ||
121 | class MatMulTestBf16OutFp16 : public testing::TestWithParam<MatMulTestParams> { | ||
122 | private: | ||
123 | /// Unique ID: m, n, k | ||
124 | using TestDataId = std::tuple<size_t, size_t, size_t, std::string_view>; | ||
125 | |||
126 | protected: | ||
127 | /// Cached test data that is shared between multiple test case. | ||
128 | 18 | struct TestData { | |
129 | 36 | Buffer lhs{}; ///< LHS operand. | |
130 | 36 | Buffer ref_packed_lhs{}; ///< Reference packed LHS. | |
131 | 36 | Buffer rhs{}; ///< RHS operand. | |
132 | 36 | Buffer rhs_scales{}; ///< RHS per-row quantization scales. | |
133 | 36 | Buffer bias{}; ///< Bias. | |
134 | 36 | Buffer ref_packed_rhs{}; ///< Reference packed RHS. | |
135 | 36 | Buffer ref_dst{}; ///< Reference output. | |
136 | }; | ||
137 | |||
138 | /// Gets the test data for the current test case. | ||
139 | 90 | static const TestData& test_data() { | |
140 | 657 | const auto& [method, info, portion] = GetParam(); | |
141 | 450 | const TestDataId data_id{info.m, info.n, info.k, method.name}; | |
142 | |||
143 | // If the test data is already available, returns it. | ||
144 | 90 | const auto data_it = _data.find(data_id); | |
145 | |||
146 |
2/2✓ Branch 0 taken 72 times.
✓ Branch 1 taken 18 times.
|
90 | if (data_it != _data.end()) { |
147 | 72 | return data_it->second; | |
148 | } | ||
149 | |||
150 | // Generates the test data. | ||
151 | 36 | const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN; | |
152 | 36 | const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; | |
153 | 36 | const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; | |
154 | |||
155 | 36 | const auto lhs_h = info.m; | |
156 | 36 | const auto lhs_w = info.k; | |
157 | 36 | auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0); | |
158 | 18 | Buffer ref_packed_lhs; | |
159 | |||
160 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
|
18 | if (has_lhs_pack) { |
161 | 18 | ref_packed_lhs = | |
162 |
3/6✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
|
36 | pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); |
163 | 18 | } | |
164 | |||
165 | 36 | const auto rhs_h = info.k; | |
166 | 36 | const auto rhs_w = info.n; | |
167 |
2/4✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
|
36 | auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1); |
168 | |||
169 | 18 | Buffer rhs_scales; | |
170 |
3/8✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
18 | if (data_type_is_quantized(method.rhs_format.data_type()) && |
171 | ✗ | method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) { | |
172 | ✗ | rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2); | |
173 | ✗ | } | |
174 | |||
175 | 18 | const auto bias_h = 1; | |
176 | 36 | const auto bias_w = info.n; | |
177 | 18 | Buffer bias; | |
178 | |||
179 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | if (has_bias) { |
180 |
2/4✓ Branch 0 taken 9 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
|
18 | bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3); |
181 | 9 | } | |
182 | |||
183 |
3/6✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
|
36 | Buffer packed_rhs(method.fn_get_packed_rhs_size(rhs_w, rhs_h)); |
184 | |||
185 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
|
18 | if (has_rhs_pack) { |
186 |
2/4✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
|
36 | const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); |
187 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | method.pack_rhs( |
188 |
5/8✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✓ Branch 5 taken 9 times.
✓ Branch 6 taken 9 times.
✗ Branch 7 not taken.
|
36 | info.n, info.k, rhs.data(), ref_rhs_row_stride, has_bias ? bias.data() : nullptr, nullptr, |
189 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | packed_rhs.data()); |
190 | 18 | } | |
191 | |||
192 | − | KAI_ASSUME(method.lhs_format.is_raw()); | |
193 | − | KAI_ASSUME(method.rhs_format.is_raw()); | |
194 | − | KAI_ASSUME(method.dst_format.is_raw()); | |
195 | |||
196 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
36 | auto ref_dst = matmul( |
197 |
2/4✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
|
18 | lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), // |
198 |
3/6✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
|
18 | rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // |
199 |
5/8✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
✗ Branch 7 not taken.
|
18 | has_bias ? bias.data() : nullptr, nullptr, nullptr, method.bias_format.data_type(), // |
200 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | method.dst_format.data_type(), // |
201 | 54 | info.m, info.n, info.k, false /* lhs_transposed */, false /* rhs_transposed */); | |
202 | |||
203 |
8/16✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 18 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 18 times.
✗ Branch 15 not taken.
|
144 | auto& data = _data[data_id] = {}; |
204 | 18 | data.lhs = std::move(lhs); | |
205 | 18 | data.ref_packed_lhs = std::move(ref_packed_lhs); | |
206 | 18 | data.rhs = std::move(rhs); | |
207 | 18 | data.rhs_scales = std::move(rhs_scales); | |
208 | 18 | data.bias = std::move(bias); | |
209 | 18 | data.ref_packed_rhs = std::move(packed_rhs); | |
210 | 18 | data.ref_dst = std::move(ref_dst); | |
211 | |||
212 | 18 | return data; | |
213 | 90 | } | |
214 | |||
215 | private: | ||
216 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
217 | static std::map<TestDataId, TestData> _data; | ||
218 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
219 | }; | ||
220 | |||
221 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
222 | 1 | std::map<MatMulTestBf16OutFp16::TestDataId, MatMulTestBf16OutFp16::TestData> MatMulTestBf16OutFp16::_data; | |
223 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
224 | |||
225 | /// Tests the output. | ||
226 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
272 | TEST_P(MatMulTestBf16OutFp16, Output) { |
227 | 3108 | const auto& [method, info, portion] = GetParam(); | |
228 | |||
229 |
2/4✓ Branch 0 taken 90 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 90 times.
✗ Branch 3 not taken.
|
90 | if (method.fn_is_supported && !method.fn_is_supported()) { |
230 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
231 | } | ||
232 | |||
233 |
1/2✓ Branch 0 taken 90 times.
✗ Branch 1 not taken.
|
90 | if (!method.has_main_kernel()) { |
234 | ✗ | GTEST_SKIP() << "No main kernel available"; | |
235 | } | ||
236 | |||
237 | 90 | const auto& data = test_data(); | |
238 | |||
239 | 180 | const auto m_step = method.fn_get_main_m_step(); | |
240 |
4/16✓ Branch 0 taken 90 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 90 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 90 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 90 times.
|
180 | ASSERT_EQ(m_step, method.m0); |
241 | |||
242 | 180 | const auto n_step = method.fn_get_main_n_step(); | |
243 |
4/16✓ Branch 0 taken 90 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 90 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 90 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 90 times.
|
180 | ASSERT_EQ(n_step, method.n0); |
244 | |||
245 | 450 | const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0); | |
246 | |||
247 |
4/4✓ Branch 0 taken 86 times.
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 2 times.
✓ Branch 3 taken 84 times.
|
90 | if (rect.height() == 0 || rect.width() == 0) { |
248 |
9/18✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 6 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 17 not taken.
|
6 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
249 | } | ||
250 | |||
251 | 168 | const size_t lhs_w = info.k; | |
252 | 84 | const size_t rhs_w = rect.width(); | |
253 | 168 | const size_t bias_w = info.n; | |
254 | 168 | const size_t dst_w = info.n; | |
255 | 84 | const bool has_bias = (data.bias.size() > 0); | |
256 | |||
257 | 84 | const auto lhs_start_row = rect.start_row(); | |
258 | 168 | const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w); | |
259 | |||
260 | 504 | const size_t lhs_packed_size = method.fn_get_packed_lhs_size(info.m, info.k, method.m0, method.k0, 1 /* sr */); | |
261 | 84 | Buffer lhs_data(lhs_packed_size); | |
262 | |||
263 |
2/4✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
|
168 | uintptr_t lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride); |
264 |
3/6✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
|
252 | uintptr_t lhs_packed_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); |
265 | |||
266 | KAI_UNUSED(lhs_offset); | ||
267 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
168 | method.fn_pack_lhs( |
268 |
4/8✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
|
84 | rect.height(), info.k, method.m0, method.k0, 1 /* sr */, 0 /* m_idx_start */, data.lhs.data() + lhs_offset, |
269 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
84 | lhs_stride, lhs_data.data() + lhs_packed_offset); |
270 | |||
271 |
3/6✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
|
252 | const auto rhs_stride = method.rhs_format.default_row_stride(info.n); |
272 | |||
273 |
4/8✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
|
336 | const size_t rhs_packed_size = method.fn_get_packed_rhs_size(info.n, info.k); |
274 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
84 | Buffer rhs_data(rhs_packed_size); |
275 | |||
276 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
84 | const auto packed_rhs_start_row = rect.start_col(); |
277 | 84 | const auto packed_rhs_start_col = 0; | |
278 | |||
279 |
3/6✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
|
168 | uintptr_t rhs_offset = method.fn_get_rhs_offset(rect.start_col()); |
280 |
3/6✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
|
252 | uintptr_t rhs_packed_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k); |
281 | 84 | const auto ref_rhs_packed_offset = | |
282 |
2/4✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
|
168 | method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k); |
283 | |||
284 |
4/16✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 84 times.
|
84 | ASSERT_EQ(rhs_packed_offset, ref_rhs_packed_offset); |
285 | |||
286 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
84 | uintptr_t bias_offset = sizeof(uint16_t) * rect.start_col(); |
287 | |||
288 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
168 | method.fn_pack_rhs( |
289 | 1, // num_groups | ||
290 | 336 | rhs_w, info.k, method.n0, method.k0, | |
291 | 1, // sr | ||
292 |
4/6✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 42 times.
✓ Branch 3 taken 42 times.
✓ Branch 4 taken 42 times.
✗ Branch 5 not taken.
|
84 | rhs_stride, data.rhs.data() + rhs_offset, has_bias ? data.bias.data() + bias_offset : nullptr, |
293 | NULL, // Scale | ||
294 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
84 | rhs_data.data() + rhs_packed_offset, 0, NULL); |
295 | |||
296 |
2/2✓ Branch 0 taken 42 times.
✓ Branch 1 taken 42 times.
|
84 | if (has_bias) { |
297 |
3/6✓ Branch 0 taken 42 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 42 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 42 times.
✗ Branch 5 not taken.
|
84 | const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w); |
298 |
4/16✓ Branch 0 taken 42 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 42 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 42 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 42 times.
|
42 | ASSERT_EQ(ref_bias_offset, bias_offset); |
299 | 42 | } | |
300 | |||
301 |
2/4✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
|
168 | const auto dst_stride = method.dst_format.default_row_stride(dst_w); |
302 |
4/8✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
|
168 | const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
303 |
4/8✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
|
168 | const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w); |
304 |
4/16✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 84 times.
|
84 | ASSERT_EQ(dst_offset, ref_dst_offset); |
305 | |||
306 |
4/8✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
|
336 | const auto dst_size = method.fn_get_dst_size(info.m, info.n); |
307 |
4/8✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
|
336 | const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n); |
308 |
4/16✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 84 times.
|
84 | ASSERT_EQ(dst_size, ref_dst_size); |
309 | |||
310 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
84 | Buffer dst(dst_size); |
311 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
84 | method.main_kernel( |
312 |
4/8✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
|
84 | rect.height(), rect.width(), info.k, lhs_data.data() + lhs_packed_offset, rhs_data.data() + rhs_packed_offset, |
313 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
84 | NULL, dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits<float>::infinity(), |
314 | 84 | std::numeric_limits<float>::infinity()); | |
315 | |||
316 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
84 | DefaultMismatchHandler handler(0, 0.02, 0, 0.05); |
317 |
5/10✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 84 times.
✗ Branch 9 not taken.
|
84 | const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler); |
318 |
4/16✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 84 times.
|
84 | ASSERT_TRUE(success); |
319 | 90 | } | |
320 | |||
321 |
15/48✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 90 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
|
92 | INSTANTIATE_TEST_SUITE_P( |
322 | MatMul, MatMulTestBf16OutFp16, | ||
323 | testing::Combine( | ||
324 | testing::ValuesIn(get_matmul_methods()), | ||
325 | testing::Values( | ||
326 | MatMulShape{3, 7, 3}, // Smaller than block size | ||
327 | MatMulShape{12, 8, 4}, // Same block size | ||
328 | MatMulShape{1, 1, 73}, // Long K | ||
329 | MatMulShape{73, 1, 5}, // Long M | ||
330 | MatMulShape{2, 73, 6}, // Long N | ||
331 | MatMulShape{13, 33, 23}, // | ||
332 | MatMulShape{73, 57, 69}, // | ||
333 | MatMulShape{70, 70, 70}, // Square | ||
334 | MatMulShape{59, 67, 73} // Prime numbers | ||
335 | ), | ||
336 | testing::Values( | ||
337 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
338 | MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. | ||
339 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
340 | MatrixPortion(0.75, 0, 1, 1), // Partial rows | ||
341 | MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle | ||
342 | )), | ||
343 | testing::PrintToStringParamName()); | ||
344 | } // namespace kai::test | ||
345 |