KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 95.6% 130 / 3 / 139
Functions: 100.0% 14 / 0 / 14
Branches: 39.2% 229 / 12 / 596

test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <functional>
13 #include <limits>
14 #include <map>
15 #include <string_view>
16 #include <tuple>
17 #include <utility>
18 #include <vector>
19
20 #include "kai/kai_common.h"
21 #include "test/common/buffer.hpp"
22 #include "test/common/compare.hpp"
23 #include "test/common/cpu_info.hpp"
24 #include "test/common/data_format.hpp"
25 #include "test/common/data_type.hpp"
26 #include "test/common/matmul_test_common.hpp"
27 #include "test/common/matrix_portion.hpp"
28 #include "test/common/printer.hpp"
29 #include "test/common/seed.hpp"
30 #include "test/reference/clamp.hpp"
31 #include "test/reference/fill.hpp"
32 #include "test/reference/matmul.hpp"
33 #include "test/reference/pack.hpp"
34
35 // matmul_clamp_f16_bf16p_bf16p
36 #include "kai/ukernels/matmul/matmul_clamp_f16_bf16p_bf16p/kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h"
37 #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.h"
38 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.h"
39
40 namespace kai::test {
41
42 /// List of supported matrix multiplication methods.
43 namespace {
44
45 3 static const std::array<MatMulMethod, 2>& get_matmul_methods() {
46
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
3 static std::array<MatMulMethod, 2> matmul_methods{};
47
48 matmul_methods[0].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla";
49 matmul_methods[0].m0 = 8;
50 matmul_methods[0].n0 = 12;
51 matmul_methods[0].k0 = 4;
52 matmul_methods[0].dst_format = DataFormat(DataType::FP16);
53 matmul_methods[0].lhs_format = DataFormat(DataType::FP16);
54 matmul_methods[0].packed_lhs_format =
55 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4);
56 matmul_methods[0].rhs_format = DataFormat(DataType::FP16);
57 matmul_methods[0].packed_rhs_format = DataFormat(
58 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4);
59 matmul_methods[0].bias_format = DataFormat(DataType::FP16);
60 matmul_methods[0].fn_is_supported = cpu_has_bf16;
61 matmul_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
62 matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
63 matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
64 matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
65 matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
66 matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
67 matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
68 matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon;
69 matmul_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon;
70 matmul_methods[0].fn_get_packed_lhs_offset =
71 kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
72 matmul_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon;
73 matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
74 matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
75 matmul_methods[0].fn_get_main_packed_rhs_offset =
76 kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
77 matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
78 matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
79 matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
80 matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
81 matmul_methods[0].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
82
83 matmul_methods[1].name = "matmul_nt_nt_f16_bf16p_bf16p_8x12_neon_mla_opt_bias";
84 matmul_methods[1].m0 = 8;
85 matmul_methods[1].n0 = 12;
86 matmul_methods[1].k0 = 4;
87 matmul_methods[1].dst_format = DataFormat(DataType::FP16);
88 matmul_methods[1].lhs_format = DataFormat(DataType::FP16);
89 matmul_methods[1].packed_lhs_format =
90 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4);
91 matmul_methods[1].rhs_format = DataFormat(DataType::FP16);
92 matmul_methods[1].packed_rhs_format = DataFormat(
93 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 12, 4);
94 matmul_methods[1].bias_format = DataFormat(DataType::UNKNOWN);
95 matmul_methods[1].fn_is_supported = cpu_has_bf16;
96 matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
97 matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
98 matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
99 matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
100 matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
101 matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
102 matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
103 matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon;
104 matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon;
105 matmul_methods[1].fn_get_packed_lhs_offset =
106 kai_get_lhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
107 matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon;
108 matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
109 matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
110 matmul_methods[1].fn_get_main_packed_rhs_offset =
111 kai_get_rhs_packed_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
112 matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
113 matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon;
114 matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
115 matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
116 matmul_methods[1].fn_matmul_f16_bf16p_bf16p = kai_run_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
117
118 return matmul_methods;
119 }
120
121 } // namespace
122
123 /// Matrix multiplication test fixture.
124 class MatMulTestBf16OutFp16 : public testing::TestWithParam<MatMulClampTestParams> {
125 private:
126 /// Unique ID: m, n, k, clamp keep ratio, method name
127 using TestDataId = std::tuple<size_t, size_t, size_t, float, std::string_view>;
128
129 protected:
130 /// Cached test data that is shared between multiple test cases.
131 108 struct TestData {
132 216 Buffer lhs{}; ///< LHS operand.
133 216 Buffer ref_packed_lhs{}; ///< Reference packed LHS.
134 216 Buffer rhs{}; ///< RHS operand.
135 216 Buffer rhs_scales{}; ///< RHS per-row quantization scales.
136 216 Buffer bias{}; ///< Bias.
137 216 Buffer ref_packed_rhs{}; ///< Reference packed RHS.
138 216 Buffer ref_dst{}; ///< Reference output.
139 Range<float> clamp_range; ///< Clamp range
140 };
141
142 /// Gets the test data for the current test case.
143 540 static const TestData& test_data() {
144 5508 const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam();
145 2700 const TestDataId data_id{info.m, info.n, info.k, clamp_keep_ratio, method.name};
146
147 // If the test data is already available, returns it.
148 540 const auto data_it = _data.find(data_id);
149
150
2/2
✓ Branch 0 taken 432 times.
✓ Branch 1 taken 108 times.
540 if (data_it != _data.end()) {
151 432 return data_it->second;
152 }
153
154 // Generates the test data.
155 216 const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN;
156 216 const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN;
157 216 const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN;
158
159 // Seed the random generator.
160
12/24
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 108 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 108 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 108 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 108 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 108 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 108 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 108 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 54 times.
✗ Branch 23 not taken.
432 const auto key = std::string(method.name) + "_" + std::to_string(info.m) + "x" + std::to_string(info.n) + "x" +
161
7/14
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 108 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 108 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 108 times.
✗ Branch 13 not taken.
432 std::to_string(info.k) + "_" + (bias_mode == BiasMode::INTERNAL ? "internal" : "provided") + ":" +
162
2/4
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
216 std::to_string(clamp_keep_ratio);
163
1/2
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
108 auto& feed = seed_stream(key);
164
165 216 const auto lhs_h = info.m;
166 216 const auto lhs_w = info.k;
167
3/6
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
216 auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, feed());
168 108 Buffer ref_packed_lhs;
169
170
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 108 times.
108 if (has_lhs_pack) {
171 108 ref_packed_lhs =
172
3/6
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
216 pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w);
173 108 }
174
175 216 const auto rhs_h = info.k;
176 216 const auto rhs_w = info.n;
177
3/6
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
216 auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, feed());
178
179 108 Buffer rhs_scales;
180
3/8
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 108 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
108 if (data_type_is_quantized(method.rhs_format.data_type()) &&
181 method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) {
182 rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), feed());
183 }
184
185 108 const auto bias_h = 1;
186 216 const auto bias_w = info.n;
187 108 Buffer bias;
188
189
2/2
✓ Branch 0 taken 54 times.
✓ Branch 1 taken 54 times.
108 if (has_bias) {
190
3/6
✓ Branch 0 taken 54 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 54 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 54 times.
✗ Branch 5 not taken.
108 bias = fill_matrix_random(bias_h, bias_w, method.bias_format, feed());
191 54 }
192
193
3/6
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
216 Buffer packed_rhs(method.fn_get_packed_rhs_size(rhs_w, rhs_h));
194
195
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 108 times.
108 if (has_rhs_pack) {
196
2/4
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
216 const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w);
197
1/2
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
108 method.pack_rhs(
198
5/8
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 54 times.
✓ Branch 5 taken 54 times.
✓ Branch 6 taken 54 times.
✗ Branch 7 not taken.
216 info.n, info.k, rhs.data(), ref_rhs_row_stride, has_bias ? bias.data() : nullptr, nullptr,
199
1/2
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
108 packed_rhs.data());
200 108 }
201
202 KAI_ASSUME_ALWAYS(method.lhs_format.is_raw());
203 KAI_ASSUME_ALWAYS(method.rhs_format.is_raw());
204 KAI_ASSUME_ALWAYS(method.dst_format.is_raw());
205
206
1/2
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
216 auto ref_dst = matmul(
207
2/4
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
108 lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), //
208
3/6
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
108 rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), //
209
5/8
✓ Branch 0 taken 54 times.
✓ Branch 1 taken 54 times.
✓ Branch 2 taken 54 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
108 has_bias ? bias.data() : nullptr, nullptr, nullptr, method.bias_format.data_type(), //
210
1/2
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
108 method.dst_format.data_type(), //
211 324 info.m, info.n, info.k, false /* lhs_transposed */, false /* rhs_transposed */);
212
213 432 const auto [min, max] =
214
5/10
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 108 times.
✗ Branch 9 not taken.
108 find_clamp_range(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_keep_ratio);
215
216
5/10
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 108 times.
✗ Branch 9 not taken.
108 auto ref_clamped = clamp(DataType::FP16, ref_dst.data(), info.m * info.n, min, max);
217
218
8/16
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 108 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 108 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 108 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 108 times.
✗ Branch 15 not taken.
864 auto& data = _data[data_id] = {};
219 108 data.lhs = std::move(lhs);
220 108 data.ref_packed_lhs = std::move(ref_packed_lhs);
221 108 data.rhs = std::move(rhs);
222 108 data.rhs_scales = std::move(rhs_scales);
223 108 data.bias = std::move(bias);
224 108 data.ref_packed_rhs = std::move(packed_rhs);
225 108 data.ref_dst = std::move(ref_clamped);
226 324 data.clamp_range = {min, max};
227
228 108 return data;
229 540 }
230
231 private:
232 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
233 static std::map<TestDataId, TestData> _data;
234 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
235 };
236
237 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
238 3 std::map<MatMulTestBf16OutFp16::TestDataId, MatMulTestBf16OutFp16::TestData> MatMulTestBf16OutFp16::_data;
239 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
240
241 /// Tests the output.
242
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
1356 TEST_P(MatMulTestBf16OutFp16, Output) {
243 19710 const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam();
244
245
2/4
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
540 if (method.fn_is_supported && !method.fn_is_supported()) {
246 GTEST_SKIP() << "Unsupported CPU feature";
247 }
248
249
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
540 if (!method.has_main_kernel()) {
250 GTEST_SKIP() << "No main kernel available";
251 }
252
253 540 const auto& data = test_data();
254
255 1080 const auto m_step = method.fn_get_main_m_step();
256
4/16
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 540 times.
1080 ASSERT_EQ(m_step, method.m0);
257
258 1080 const auto n_step = method.fn_get_main_n_step();
259
4/16
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 540 times.
1080 ASSERT_EQ(n_step, method.n0);
260
261 2700 const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0);
262
263
2/4
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 540 times.
540 if (rect.height() == 0 || rect.width() == 0) {
264 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
265 }
266
267 1080 const size_t lhs_w = info.k;
268 540 const size_t rhs_w = rect.width();
269 1080 const size_t bias_w = info.n;
270 1080 const size_t dst_w = info.n;
271 540 const bool has_bias = (data.bias.size() > 0);
272
273 540 const auto lhs_start_row = rect.start_row();
274 1080 const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w);
275
276 3240 const size_t lhs_packed_size = method.fn_get_packed_lhs_size(info.m, info.k, method.m0, method.k0, 1 /* sr */);
277 540 Buffer lhs_data(lhs_packed_size);
278
279
2/4
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
1080 uintptr_t lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride);
280
3/6
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
1620 uintptr_t lhs_packed_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k);
281
282 KAI_UNUSED(lhs_offset);
283
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
1080 method.fn_pack_lhs(
284
4/8
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 540 times.
✗ Branch 7 not taken.
540 rect.height(), info.k, method.m0, method.k0, 1 /* sr */, 0 /* m_idx_start */, data.lhs.data() + lhs_offset,
285
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
540 lhs_stride, lhs_data.data() + lhs_packed_offset);
286
287
3/6
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
1620 const auto rhs_stride = method.rhs_format.default_row_stride(info.n);
288
289
4/8
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 540 times.
✗ Branch 7 not taken.
2160 const size_t rhs_packed_size = method.fn_get_packed_rhs_size(info.n, info.k);
290
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
540 Buffer rhs_data(rhs_packed_size);
291
292
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
540 const auto packed_rhs_start_row = rect.start_col();
293 540 const auto packed_rhs_start_col = 0;
294
295
3/6
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
1080 uintptr_t rhs_offset = method.fn_get_rhs_offset(rect.start_col());
296
3/6
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
1620 uintptr_t rhs_packed_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k);
297 540 const auto ref_rhs_packed_offset =
298
2/4
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
1080 method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k);
299
300
4/16
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 540 times.
540 ASSERT_EQ(rhs_packed_offset, ref_rhs_packed_offset);
301
302
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
540 uintptr_t bias_offset = sizeof(uint16_t) * rect.start_col();
303
304
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
1080 method.fn_pack_rhs(
305 1, // num_groups
306 2160 rhs_w, info.k, method.n0, method.k0,
307 1, // sr
308
4/6
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 270 times.
✓ Branch 3 taken 270 times.
✓ Branch 4 taken 270 times.
✗ Branch 5 not taken.
540 rhs_stride, data.rhs.data() + rhs_offset, has_bias ? data.bias.data() + bias_offset : nullptr,
309 NULL, // Scale
310
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
540 rhs_data.data() + rhs_packed_offset, 0, NULL);
311
312
2/2
✓ Branch 0 taken 270 times.
✓ Branch 1 taken 270 times.
540 if (has_bias) {
313
3/6
✓ Branch 0 taken 270 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 270 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 270 times.
✗ Branch 5 not taken.
540 const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w);
314
4/16
✓ Branch 0 taken 270 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 270 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 270 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 270 times.
270 ASSERT_EQ(ref_bias_offset, bias_offset);
315 270 }
316
317
2/4
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
1080 const auto dst_stride = method.dst_format.default_row_stride(dst_w);
318
4/8
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 540 times.
✗ Branch 7 not taken.
1080 const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
319
4/8
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 540 times.
✗ Branch 7 not taken.
1080 const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w);
320
4/16
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 540 times.
540 ASSERT_EQ(dst_offset, ref_dst_offset);
321
322
4/8
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 540 times.
✗ Branch 7 not taken.
2160 const auto dst_size = method.fn_get_dst_size(info.m, info.n);
323
4/8
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 540 times.
✗ Branch 7 not taken.
2160 const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n);
324
4/16
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 540 times.
540 ASSERT_EQ(dst_size, ref_dst_size);
325
326
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
540 Buffer dst(dst_size);
327
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
540 method.main_kernel(
328
4/8
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 540 times.
✗ Branch 7 not taken.
540 rect.height(), rect.width(), info.k, lhs_data.data() + lhs_packed_offset, rhs_data.data() + rhs_packed_offset,
329
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
540 NULL, dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, data.clamp_range.min, data.clamp_range.max);
330
331
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
540 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
332
5/10
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 540 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 540 times.
✗ Branch 9 not taken.
540 const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler);
333
4/16
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 540 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 540 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 540 times.
540 ASSERT_TRUE(success);
334 540 }
335
336
30/104
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✓ Branch 28 taken 2 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 time.
✓ Branch 30 taken 2 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 2 times.
✓ Branch 32 taken 270 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✓ Branch 34 taken 540 times.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
816 INSTANTIATE_TEST_SUITE_P(
337 MatMul, MatMulTestBf16OutFp16,
338 testing::Combine(
339 testing::ValuesIn(get_matmul_methods()),
340 testing::Values(
341 MatMulShape{3, 7, 3}, // Smaller than block size
342 MatMulShape{12, 8, 4}, // Same block size
343 MatMulShape{1, 1, 73}, // Long K
344 MatMulShape{73, 1, 5}, // Long M
345 MatMulShape{2, 73, 6}, // Long N
346 MatMulShape{13, 33, 23}, //
347 MatMulShape{73, 57, 69}, //
348 MatMulShape{70, 70, 70}, // Square
349 MatMulShape{59, 67, 73} // Prime numbers
350 ),
351 testing::Values(
352 MatrixPortion(0, 0, 1, 1), // Full matrix.
353 MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner.
354 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
355 MatrixPortion(0.75, 0, 1, 1), // Partial rows
356 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
357 ),
358 testing::Values(BiasMode::PROVIDED), //
359 testing::ValuesIn(std::initializer_list<float>{1.0f, 0.9f, 0.5f}) // Clamping
360 ),
361 testing::PrintToStringParamName());
362 } // namespace kai::test
363