KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 96.1% 124 3 132
Functions: 100.0% 13 0 13
Branches: 41.6% 187 18 468

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <functional>
13 #include <limits>
14 #include <map>
15 #include <string_view>
16 #include <tuple>
17 #include <utility>
18 #include <vector>
19
20 #include "kai/kai_common.h"
21 #include "test/common/buffer.hpp"
22 #include "test/common/compare.hpp"
23 #include "test/common/cpu_info.hpp"
24 #include "test/common/data_format.hpp"
25 #include "test/common/data_type.hpp"
26 #include "test/common/matmul_test_common.hpp"
27 #include "test/common/matrix_portion.hpp"
28 #include "test/common/printer.hpp"
29 #include "test/reference/fill.hpp"
30 #include "test/reference/matmul.hpp"
31 #include "test/reference/pack.hpp"
32
33 // matmul_clamp_f16_bf16p_bf16p
34 #include "kai/ukernels/matmul/matmul_clamp_f16_bf16p_bf16p/kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h"
35 #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.h"
36 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.h"
37 namespace kai::test {
38
39 /// List of supported matrix multiplication methods.
40 namespace {
41
42 1 static const std::array<MatMulMethod, 2>& get_matmul_methods() {
43
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
1 static std::array<MatMulMethod, 2> matmul_methods{};
44
45 matmul_methods[0].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla";
46 matmul_methods[0].m0 = 8;
47 matmul_methods[0].n0 = 12;
48 matmul_methods[0].k0 = 4;
49 matmul_methods[0].dst_format = DataFormat(DataType::FP16);
50 matmul_methods[0].lhs_format = DataFormat(DataType::FP16);
51 matmul_methods[0].packed_lhs_format =
52 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4);
53 matmul_methods[0].rhs_format = DataFormat(DataType::FP16);
54 matmul_methods[0].packed_rhs_format = DataFormat(
55 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4);
56 matmul_methods[0].bias_format = DataFormat(DataType::FP16);
57 matmul_methods[0].fn_is_supported = cpu_has_bf16;
58 matmul_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
59 matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
60 matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
61 matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
62 matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
63 matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
64 matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
65 matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon;
66 matmul_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon;
67 matmul_methods[0].fn_get_packed_lhs_offset =
68 kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
69 matmul_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon;
70 matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
71 matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
72 matmul_methods[0].fn_get_main_packed_rhs_offset =
73 kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
74 matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
75 matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
76 matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
77 matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
78 matmul_methods[0].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
79
80 matmul_methods[1].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla_opt_bias";
81 matmul_methods[1].m0 = 8;
82 matmul_methods[1].n0 = 12;
83 matmul_methods[1].k0 = 4;
84 matmul_methods[1].dst_format = DataFormat(DataType::FP16);
85 matmul_methods[1].lhs_format = DataFormat(DataType::FP16);
86 matmul_methods[1].packed_lhs_format =
87 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4);
88 matmul_methods[1].rhs_format = DataFormat(DataType::FP16);
89 matmul_methods[1].packed_rhs_format = DataFormat(
90 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4);
91 matmul_methods[1].bias_format = DataFormat(DataType::UNKNOWN);
92 matmul_methods[1].fn_is_supported = cpu_has_bf16;
93 matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
94 matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
95 matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
96 matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
97 matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
98 matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
99 matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
100 matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon;
101 matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon;
102 matmul_methods[1].fn_get_packed_lhs_offset =
103 kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
104 matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon;
105 matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
106 matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
107 matmul_methods[1].fn_get_main_packed_rhs_offset =
108 kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
109 matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
110 matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
111 matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
112 matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
113 matmul_methods[1].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
114
115 return matmul_methods;
116 }
117
118 } // namespace
119
120 /// Matrix multiplication test fixture.
121 class MatMulTestBf16OutFp16 : public testing::TestWithParam<MatMulTestParams> {
122 private:
123 /// Unique ID: m, n, k and the method name
124 using TestDataId = std::tuple<size_t, size_t, size_t, std::string_view>;
125
126 protected:
127 /// Cached test data that is shared between multiple test cases.
128 18 struct TestData {
129 36 Buffer lhs{}; ///< LHS operand.
130 36 Buffer ref_packed_lhs{}; ///< Reference packed LHS.
131 36 Buffer rhs{}; ///< RHS operand.
132 36 Buffer rhs_scales{}; ///< RHS per-row quantization scales.
133 36 Buffer bias{}; ///< Bias.
134 36 Buffer ref_packed_rhs{}; ///< Reference packed RHS.
135 36 Buffer ref_dst{}; ///< Reference output.
136 };
137
138 /// Gets the test data for the current test case.
139 90 static const TestData& test_data() {
140 657 const auto& [method, info, portion] = GetParam();
141 450 const TestDataId data_id{info.m, info.n, info.k, method.name};
142
143 // If the test data is already available, returns it.
144 90 const auto data_it = _data.find(data_id);
145
146
2/2
✓ Branch 0 taken 72 times.
✓ Branch 1 taken 18 times.
90 if (data_it != _data.end()) {
147 72 return data_it->second;
148 }
149
150 // Generates the test data.
151 36 const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN;
152 36 const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN;
153 36 const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN;
154
155 36 const auto lhs_h = info.m;
156 36 const auto lhs_w = info.k;
157 36 auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0);
158 18 Buffer ref_packed_lhs;
159
160
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
18 if (has_lhs_pack) {
161 18 ref_packed_lhs =
162
3/6
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
36 pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w);
163 18 }
164
165 36 const auto rhs_h = info.k;
166 36 const auto rhs_w = info.n;
167
2/4
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
36 auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1);
168
169 18 Buffer rhs_scales;
170
3/8
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
18 if (data_type_is_quantized(method.rhs_format.data_type()) &&
171 method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) {
172 rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2);
173 }
174
175 18 const auto bias_h = 1;
176 36 const auto bias_w = info.n;
177 18 Buffer bias;
178
179
2/2
✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
18 if (has_bias) {
180
2/4
✓ Branch 0 taken 9 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
18 bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3);
181 9 }
182
183
3/6
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
36 Buffer packed_rhs(method.fn_get_packed_rhs_size(rhs_w, rhs_h));
184
185
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
18 if (has_rhs_pack) {
186
2/4
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
36 const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w);
187
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 method.pack_rhs(
188
5/8
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✓ Branch 5 taken 9 times.
✓ Branch 6 taken 9 times.
✗ Branch 7 not taken.
36 info.n, info.k, rhs.data(), ref_rhs_row_stride, has_bias ? bias.data() : nullptr, nullptr,
189
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 packed_rhs.data());
190 18 }
191
192 KAI_ASSUME(method.lhs_format.is_raw());
193 KAI_ASSUME(method.rhs_format.is_raw());
194 KAI_ASSUME(method.dst_format.is_raw());
195
196
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
36 auto ref_dst = matmul(
197
2/4
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
18 lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), //
198
3/6
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
18 rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), //
199
5/8
✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
✗ Branch 7 not taken.
18 has_bias ? bias.data() : nullptr, nullptr, nullptr, method.bias_format.data_type(), //
200
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 method.dst_format.data_type(), //
201 54 info.m, info.n, info.k, false /* lhs_transposed */, false /* rhs_transposed */);
202
203
8/16
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 18 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 18 times.
✗ Branch 15 not taken.
144 auto& data = _data[data_id] = {};
204 18 data.lhs = std::move(lhs);
205 18 data.ref_packed_lhs = std::move(ref_packed_lhs);
206 18 data.rhs = std::move(rhs);
207 18 data.rhs_scales = std::move(rhs_scales);
208 18 data.bias = std::move(bias);
209 18 data.ref_packed_rhs = std::move(packed_rhs);
210 18 data.ref_dst = std::move(ref_dst);
211
212 18 return data;
213 90 }
214
215 private:
216 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
217 static std::map<TestDataId, TestData> _data;
218 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
219 };
220
221 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
222 1 std::map<MatMulTestBf16OutFp16::TestDataId, MatMulTestBf16OutFp16::TestData> MatMulTestBf16OutFp16::_data;
223 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
224
225 /// Tests the output.
226
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
272 TEST_P(MatMulTestBf16OutFp16, Output) {
227 3108 const auto& [method, info, portion] = GetParam();
228
229
2/4
✓ Branch 0 taken 90 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 90 times.
✗ Branch 3 not taken.
90 if (method.fn_is_supported && !method.fn_is_supported()) {
230 GTEST_SKIP() << "Unsupported CPU feature";
231 }
232
233
1/2
✓ Branch 0 taken 90 times.
✗ Branch 1 not taken.
90 if (!method.has_main_kernel()) {
234 GTEST_SKIP() << "No main kernel available";
235 }
236
237 90 const auto& data = test_data();
238
239 180 const auto m_step = method.fn_get_main_m_step();
240
4/16
✓ Branch 0 taken 90 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 90 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 90 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 90 times.
180 ASSERT_EQ(m_step, method.m0);
241
242 180 const auto n_step = method.fn_get_main_n_step();
243
4/16
✓ Branch 0 taken 90 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 90 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 90 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 90 times.
180 ASSERT_EQ(n_step, method.n0);
244
245 450 const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0);
246
247
4/4
✓ Branch 0 taken 86 times.
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 2 times.
✓ Branch 3 taken 84 times.
90 if (rect.height() == 0 || rect.width() == 0) {
248
9/18
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 6 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 17 not taken.
6 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
249 }
250
251 168 const size_t lhs_w = info.k;
252 84 const size_t rhs_w = rect.width();
253 168 const size_t bias_w = info.n;
254 168 const size_t dst_w = info.n;
255 84 const bool has_bias = (data.bias.size() > 0);
256
257 84 const auto lhs_start_row = rect.start_row();
258 168 const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w);
259
260 504 const size_t lhs_packed_size = method.fn_get_packed_lhs_size(info.m, info.k, method.m0, method.k0, 1 /* sr */);
261 84 Buffer lhs_data(lhs_packed_size);
262
263
2/4
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
168 uintptr_t lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride);
264
3/6
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
252 uintptr_t lhs_packed_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k);
265
266 KAI_UNUSED(lhs_offset);
267
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
168 method.fn_pack_lhs(
268
4/8
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
84 rect.height(), info.k, method.m0, method.k0, 1 /* sr */, 0 /* m_idx_start */, data.lhs.data() + lhs_offset,
269
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
84 lhs_stride, lhs_data.data() + lhs_packed_offset);
270
271
3/6
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
252 const auto rhs_stride = method.rhs_format.default_row_stride(info.n);
272
273
4/8
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
336 const size_t rhs_packed_size = method.fn_get_packed_rhs_size(info.n, info.k);
274
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
84 Buffer rhs_data(rhs_packed_size);
275
276
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
84 const auto packed_rhs_start_row = rect.start_col();
277 84 const auto packed_rhs_start_col = 0;
278
279
3/6
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
168 uintptr_t rhs_offset = method.fn_get_rhs_offset(rect.start_col());
280
3/6
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
252 uintptr_t rhs_packed_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k);
281 84 const auto ref_rhs_packed_offset =
282
2/4
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
168 method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k);
283
284
4/16
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 84 times.
84 ASSERT_EQ(rhs_packed_offset, ref_rhs_packed_offset);
285
286
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
84 uintptr_t bias_offset = sizeof(uint16_t) * rect.start_col();
287
288
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
168 method.fn_pack_rhs(
289 1, // num_groups
290 336 rhs_w, info.k, method.n0, method.k0,
291 1, // sr
292
4/6
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 42 times.
✓ Branch 3 taken 42 times.
✓ Branch 4 taken 42 times.
✗ Branch 5 not taken.
84 rhs_stride, data.rhs.data() + rhs_offset, has_bias ? data.bias.data() + bias_offset : nullptr,
293 NULL, // Scale
294
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
84 rhs_data.data() + rhs_packed_offset, 0, NULL);
295
296
2/2
✓ Branch 0 taken 42 times.
✓ Branch 1 taken 42 times.
84 if (has_bias) {
297
3/6
✓ Branch 0 taken 42 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 42 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 42 times.
✗ Branch 5 not taken.
84 const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w);
298
4/16
✓ Branch 0 taken 42 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 42 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 42 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 42 times.
42 ASSERT_EQ(ref_bias_offset, bias_offset);
299 42 }
300
301
2/4
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
168 const auto dst_stride = method.dst_format.default_row_stride(dst_w);
302
4/8
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
168 const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
303
4/8
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
168 const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w);
304
4/16
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 84 times.
84 ASSERT_EQ(dst_offset, ref_dst_offset);
305
306
4/8
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
336 const auto dst_size = method.fn_get_dst_size(info.m, info.n);
307
4/8
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
336 const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n);
308
4/16
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 84 times.
84 ASSERT_EQ(dst_size, ref_dst_size);
309
310
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
84 Buffer dst(dst_size);
311
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
84 method.main_kernel(
312
4/8
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
84 rect.height(), rect.width(), info.k, lhs_data.data() + lhs_packed_offset, rhs_data.data() + rhs_packed_offset,
313
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
84 NULL, dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits<float>::infinity(),
314 84 std::numeric_limits<float>::infinity());
315
316
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
84 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
317
5/10
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 84 times.
✗ Branch 9 not taken.
84 const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler);
318
4/16
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 84 times.
84 ASSERT_TRUE(success);
319 90 }
320
321
15/48
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 90 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
92 INSTANTIATE_TEST_SUITE_P(
322 MatMul, MatMulTestBf16OutFp16,
323 testing::Combine(
324 testing::ValuesIn(get_matmul_methods()),
325 testing::Values(
326 MatMulShape{3, 7, 3}, // Smaller than block size
327 MatMulShape{12, 8, 4}, // Same block size
328 MatMulShape{1, 1, 73}, // Long K
329 MatMulShape{73, 1, 5}, // Long M
330 MatMulShape{2, 73, 6}, // Long N
331 MatMulShape{13, 33, 23}, //
332 MatMulShape{73, 57, 69}, //
333 MatMulShape{70, 70, 70}, // Square
334 MatMulShape{59, 67, 73} // Prime numbers
335 ),
336 testing::Values(
337 MatrixPortion(0, 0, 1, 1), // Full matrix.
338 MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner.
339 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
340 MatrixPortion(0.75, 0, 1, 1), // Partial rows
341 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere in the middle
342 )),
343 testing::PrintToStringParamName());
344 } // namespace kai::test
345