KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 97.9% 137 / 0 / 140
Functions: 90.5% 19 / 0 / 21
Branches: 55.1% 76 / 0 / 138

test/nextgen/tests/matmul_test_next.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <random>
12 #include <string>
13 #include <unordered_map>
14 #include <utility>
15
16 #include "test/common/assert.hpp"
17 #include "test/common/matrix_portion.hpp"
18 #include "test/common/span.hpp"
19 #include "test/nextgen/common/random.hpp"
20 #include "test/nextgen/common/test_registry.hpp"
21 #include "test/nextgen/operators/matmul/matmul_bias_mode.hpp"
22 #include "test/nextgen/operators/matmul/matmul_operator.hpp"
23 #include "test/nextgen/operators/matmul/matmul_tb.hpp"
24
25 namespace kai::test {
26
27 namespace {
28
29 struct MatMulFixtureParams {
30 size_t iteration_no;
31
32 size_t shape_m;
33 size_t shape_n;
34 size_t shape_k;
35 MatMulBiasMode bias_mode;
36 float clamp_ratio;
37
38 const MatMulOperator* op;
39
40 2400 [[nodiscard]] std::string name() const {
41
7/16
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2400 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2400 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2400 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2400 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2400 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
4800 return std::string(op->name) + ",m=" + std::to_string(shape_m) + ",n=" + std::to_string(shape_n) +
42
6/12
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2400 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2400 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2400 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2400 times.
✗ Branch 11 not taken.
2400 ",k=" + std::to_string(shape_k) + ",bias=" + matmul_bias_mode_name(bias_mode) +
43
5/10
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2400 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2400 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2400 times.
✗ Branch 9 not taken.
2400 ",clamp_ratio=" + std::to_string(clamp_ratio) + ",iteration=" + std::to_string(iteration_no);
44 }
45 };
46
47 struct MatMulTestParams {
48 MatrixPortion portion;
49
50 600 [[nodiscard]] std::string name() const {
51
6/12
✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 600 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 600 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 600 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 600 times.
✗ Branch 11 not taken.
1200 return "start_m=" + std::to_string(portion.start_row()) + ",size_m=" + std::to_string(portion.height()) +
52
7/14
✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 600 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 600 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 600 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 600 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 600 times.
✗ Branch 13 not taken.
600 ",start_n=" + std::to_string(portion.start_col()) + ",size_n=" + std::to_string(portion.width());
53 }
54 };
55
56 class MatMulFixture : public testing::Test {
57 public:
58 1800 explicit MatMulFixture(const MatMulFixtureParams& params) : m_fixture_params(params) {
59 1800 }
60
61 protected:
62 1800 [[nodiscard]] const MatMulFixtureParams& fixture_params() const {
63 1800 return m_fixture_params;
64 }
65
66 1800 [[nodiscard]] MatMulTb& test_bench() {
67 1800 const std::string name = m_fixture_params.name();
68
69 // If the test has already been created, uses it.
70
1/2
✓ Branch 0 taken 1800 times.
✗ Branch 1 not taken.
1800 const auto it = test_benches.find(name);
71
72
3/4
✓ Branch 0 taken 1800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1600 times.
✓ Branch 3 taken 200 times.
1800 if (it != test_benches.end()) {
73
1/2
✓ Branch 0 taken 1600 times.
✗ Branch 1 not taken.
1600 return it->second;
74 }
75
76 // Creates a new test if it hasn't been created.
77
2/4
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
400 MatMulTb test(
78 200 m_fixture_params.shape_m, m_fixture_params.shape_n, m_fixture_params.shape_k, m_fixture_params.bias_mode,
79 200 m_fixture_params.clamp_ratio, m_fixture_params.op);
80
81
1/2
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
200 Rng rng(m_fixture_params.iteration_no); // REVISIT: Derive the seed from the global seed.
82
1/2
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
200 test.generate_test_data(rng);
83
84
2/4
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
200 return test_benches[name] = std::move(test);
85 1800 }
86
87 private:
88 static std::unordered_map<std::string, MatMulTb> test_benches;
89
90 MatMulFixtureParams m_fixture_params;
91 };
92
93 3 std::unordered_map<std::string, MatMulTb> MatMulFixture::test_benches;
94
95 class MatMulPackLhsTest : public MatMulFixture {
96 public:
97 1200 explicit MatMulPackLhsTest(const MatMulFixtureParams& fixture_params, const MatMulTestParams& test_params) :
98 1200 MatMulFixture(fixture_params), m_test_params(test_params) {
99 1200 }
100
101 600 void TestBody() override {
102 600 MatMulTb& test = test_bench();
103 600 const MatMulFixtureParams& params = fixture_params();
104 600 const MatrixPortion& portion = m_test_params.portion;
105
106 1800 const auto [step_m, step_k] = test.lhs_packing_steps();
107 1800 const Rect rect = portion.compute_portion(params.shape_m, params.shape_k, step_m, step_k);
108
109 600 const size_t start_m = rect.start_row();
110 600 const size_t start_k = rect.start_col();
111 600 const size_t size_m = rect.height();
112 600 const size_t size_k = rect.width();
113
114 600 test.test_lhs_packing(start_m, start_k, size_m, size_k);
115 600 }
116
117 private:
118 MatMulTestParams m_test_params;
119 };
120
121 class MatMulPackRhsTest : public MatMulFixture {
122 public:
123 1200 explicit MatMulPackRhsTest(const MatMulFixtureParams& fixture_params, const MatMulTestParams& test_params) :
124 1200 MatMulFixture(fixture_params), m_test_params(test_params) {
125 1200 }
126
127 600 void TestBody() override {
128 600 MatMulTb& test = test_bench();
129 600 const MatMulFixtureParams& params = fixture_params();
130 600 const MatrixPortion& portion = m_test_params.portion;
131
132 1800 const auto [step_n, step_k] = test.rhs_packing_steps();
133 1800 const Rect rect = portion.compute_portion(params.shape_n, params.shape_k, step_n, step_k);
134
135 600 const size_t start_n = rect.start_row();
136 600 const size_t start_k = rect.start_col();
137 600 const size_t size_n = rect.height();
138 600 const size_t size_k = rect.width();
139
140 600 test.test_rhs_packing(start_n, start_k, size_n, size_k);
141 600 }
142
143 private:
144 MatMulTestParams m_test_params;
145 };
146
147 class MatMulMatMulTest : public MatMulFixture {
148 public:
149 1200 explicit MatMulMatMulTest(const MatMulFixtureParams& fixture_params, const MatMulTestParams& test_params) :
150 1200 MatMulFixture(fixture_params), m_test_params(test_params) {
151 1200 }
152
153 600 void TestBody() override {
154 600 MatMulTb& test = test_bench();
155 600 const MatMulFixtureParams& params = fixture_params();
156 600 const MatrixPortion& portion = m_test_params.portion;
157
158 1800 const auto [step_m, step_n] = test.matmul_steps();
159 1800 const Rect rect = portion.compute_portion(params.shape_m, params.shape_n, step_m, step_n);
160
161 600 const size_t start_m = rect.start_row();
162 600 const size_t start_n = rect.start_col();
163 600 const size_t size_m = rect.height();
164 600 const size_t size_n = rect.width();
165
166 600 test.test_matmul(start_m, start_n, size_m, size_n);
167 600 }
168
169 private:
170 MatMulTestParams m_test_params;
171 };
172
173
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
6 const auto matmul_tests_setup = TestRegistry::register_setup([]() {
174 3 const size_t num_shapes_per_op = 100;
175 3 const std::array output_portions = {
176 3 MatrixPortion(0, 0, 1, 1), // Full matrix.
177 MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner.
178 MatrixPortion(0.75, 0.75, 1, 1) // Bottom-right corner.
179 };
180
181 3 const Span<const MatMulOperator> available_operators = get_available_matmul_operators();
182 3 Rng rng(0); // REVISIT: Use the global seed to initialize this RNG.
183 3 std::uniform_int_distribution<size_t> shape_dist(1, 150);
184 3 std::uniform_real_distribution<float> probability_dist(0.0F, 1.0F);
185 3 std::uniform_real_distribution<float> dist_70_to_100(0.7F, 1.0F);
186 3 std::uniform_real_distribution<float> dist_0_to_70(0.0F, 0.7F);
187
188
2/2
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 6 times.
9 for (const MatMulOperator& op : available_operators) {
189
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 4 times.
6 if (!op.is_cpu_supported()) {
190 4 continue;
191 }
192
193 2 const bool test_pack_lhs = op.pack_lhs != nullptr;
194 2 const bool test_pack_rhs = op.pack_rhs != nullptr;
195 2 const bool test_matmul = op.matmul != nullptr;
196
197 2 const char* test_suite_name = "MatMulNext";
198
199
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 200 times.
202 for (size_t shape_no = 0; shape_no < num_shapes_per_op; ++shape_no) {
200 200 size_t shape_m = 0;
201 200 size_t shape_n = 0;
202 200 size_t shape_k = 0;
203 200 MatMulBiasMode bias_mode = MatMulBiasMode::NO_BIAS;
204 200 float clamp_ratio = 0;
205
206 200 while (true) {
207 200 shape_m = shape_dist(rng);
208 200 shape_n = shape_dist(rng);
209 200 shape_k = shape_dist(rng);
210
211
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 200 times.
200 if (!op.is_shape_suitable(shape_m, shape_n, shape_k)) {
212 continue;
213 }
214
215 // Bias mode:
216 // * 70% of tests have bias.
217 // * 30% of tests have no bias.
218 200 const float bias_prob = probability_dist(rng);
219 200 const bool with_bias = bias_prob < 0.7F;
220
221
2/2
✓ Branch 0 taken 145 times.
✓ Branch 1 taken 55 times.
200 if (with_bias) {
222 145 bias_mode = MatMulBiasMode::PER_N;
223 145 } else {
224 55 bias_mode = MatMulBiasMode::NO_BIAS;
225 }
226
227 // Clamping range:
228 // * 20% of tests have no clamping.
229 // * 40% of tests have clamping range between 70% to 100% the output range.
230 // * 40% of tests have clamping range between 0% to 70% the output range.
231 200 const float clamp_prob = probability_dist(rng);
232
233 200 const bool no_clamp = clamp_prob < 0.2F;
234
2/2
✓ Branch 0 taken 44 times.
✓ Branch 1 taken 156 times.
200 const bool clamp_70_to_100 = clamp_prob >= 0.2F && clamp_prob < 0.6F;
235 200 const bool clamp_0_to_70 = clamp_prob >= 0.6F;
236
237
2/2
✓ Branch 0 taken 44 times.
✓ Branch 1 taken 156 times.
200 if (no_clamp) {
238 44 clamp_ratio = 1.0F;
239
2/2
✓ Branch 0 taken 75 times.
✓ Branch 1 taken 81 times.
200 } else if (clamp_70_to_100) {
240 75 clamp_ratio = dist_70_to_100(rng);
241 75 } else {
242
1/4
✓ Branch 0 taken 81 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
81 KAI_TEST_ASSERT(clamp_0_to_70);
243 81 clamp_ratio = dist_0_to_70(rng);
244 }
245
246 break;
247 200 }
248
249 1600 const MatMulFixtureParams fixture_params{
250 1400 shape_no, shape_m, shape_n, shape_k, bias_mode, clamp_ratio, &op,
251 };
252
253
2/2
✓ Branch 0 taken 600 times.
✓ Branch 1 taken 200 times.
800 for (const MatrixPortion& portion : output_portions) {
254 600 const MatMulTestParams test_params{
255 600 portion,
256 };
257
258
3/6
✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 600 times.
600 const std::string params_name = fixture_params.name() + "," + test_params.name();
259
260
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
600 if (test_pack_lhs) {
261
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
600 const std::string test_name = "PackLhs/" + params_name;
262
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
1200 KAI_REGISTER_TEST(
263 MatMulFixture, MatMulPackLhsTest, test_suite_name, test_name.c_str(), fixture_params,
264 test_params);
265 600 }
266
267
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
600 if (test_pack_rhs) {
268
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
600 const std::string test_name = "PackRhs/" + params_name;
269
2/4
✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
1200 KAI_REGISTER_TEST(
270 MatMulFixture, MatMulPackRhsTest, test_suite_name, test_name.c_str(), fixture_params,
271 test_params);
272 600 }
273
274
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
600 if (test_matmul) {
275
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
600 const std::string test_name = "MatMul/" + params_name;
276
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
1200 KAI_REGISTER_TEST(
277 MatMulFixture, MatMulMatMulTest, test_suite_name, test_name.c_str(), fixture_params,
278 test_params);
279 600 }
280 600 }
281 200 }
282 6 }
283 3 });
284
285 } // namespace
286
287 } // namespace kai::test
288