test/nextgen/tests/matmul_test_next.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <gtest/gtest.h> | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <random> | ||
| 12 | #include <string> | ||
| 13 | #include <unordered_map> | ||
| 14 | #include <utility> | ||
| 15 | |||
| 16 | #include "test/common/assert.hpp" | ||
| 17 | #include "test/common/matrix_portion.hpp" | ||
| 18 | #include "test/common/span.hpp" | ||
| 19 | #include "test/nextgen/common/random.hpp" | ||
| 20 | #include "test/nextgen/common/test_registry.hpp" | ||
| 21 | #include "test/nextgen/operators/matmul/matmul_bias_mode.hpp" | ||
| 22 | #include "test/nextgen/operators/matmul/matmul_operator.hpp" | ||
| 23 | #include "test/nextgen/operators/matmul/matmul_tb.hpp" | ||
| 24 | |||
| 25 | namespace kai::test { | ||
| 26 | |||
| 27 | namespace { | ||
| 28 | |||
| 29 | struct MatMulFixtureParams { | ||
| 30 | size_t iteration_no; | ||
| 31 | |||
| 32 | size_t shape_m; | ||
| 33 | size_t shape_n; | ||
| 34 | size_t shape_k; | ||
| 35 | MatMulBiasMode bias_mode; | ||
| 36 | float clamp_ratio; | ||
| 37 | |||
| 38 | const MatMulOperator* op; | ||
| 39 | |||
| 40 | 2400 | [[nodiscard]] std::string name() const { | |
| 41 |
7/16✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2400 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2400 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2400 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2400 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2400 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
|
4800 | return std::string(op->name) + ",m=" + std::to_string(shape_m) + ",n=" + std::to_string(shape_n) + |
| 42 |
6/12✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2400 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2400 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2400 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2400 times.
✗ Branch 11 not taken.
|
2400 | ",k=" + std::to_string(shape_k) + ",bias=" + matmul_bias_mode_name(bias_mode) + |
| 43 |
5/10✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2400 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2400 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2400 times.
✗ Branch 9 not taken.
|
2400 | ",clamp_ratio=" + std::to_string(clamp_ratio) + ",iteration=" + std::to_string(iteration_no); |
| 44 | ✗ | } | |
| 45 | }; | ||
| 46 | |||
| 47 | struct MatMulTestParams { | ||
| 48 | MatrixPortion portion; | ||
| 49 | |||
| 50 | 600 | [[nodiscard]] std::string name() const { | |
| 51 |
6/12✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 600 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 600 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 600 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 600 times.
✗ Branch 11 not taken.
|
1200 | return "start_m=" + std::to_string(portion.start_row()) + ",size_m=" + std::to_string(portion.height()) + |
| 52 |
7/14✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 600 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 600 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 600 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 600 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 600 times.
✗ Branch 13 not taken.
|
600 | ",start_n=" + std::to_string(portion.start_col()) + ",size_n=" + std::to_string(portion.width()); |
| 53 | ✗ | } | |
| 54 | }; | ||
| 55 | |||
| 56 | class MatMulFixture : public testing::Test { | ||
| 57 | public: | ||
| 58 | 1800 | explicit MatMulFixture(const MatMulFixtureParams& params) : m_fixture_params(params) { | |
| 59 | 1800 | } | |
| 60 | |||
| 61 | protected: | ||
| 62 | 1800 | [[nodiscard]] const MatMulFixtureParams& fixture_params() const { | |
| 63 | 1800 | return m_fixture_params; | |
| 64 | } | ||
| 65 | |||
| 66 | 1800 | [[nodiscard]] MatMulTb& test_bench() { | |
| 67 | 1800 | const std::string name = m_fixture_params.name(); | |
| 68 | |||
| 69 | // If the test has already been created, uses it. | ||
| 70 |
1/2✓ Branch 0 taken 1800 times.
✗ Branch 1 not taken.
|
1800 | const auto it = test_benches.find(name); |
| 71 | |||
| 72 |
3/4✓ Branch 0 taken 1800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1600 times.
✓ Branch 3 taken 200 times.
|
1800 | if (it != test_benches.end()) { |
| 73 |
1/2✓ Branch 0 taken 1600 times.
✗ Branch 1 not taken.
|
1600 | return it->second; |
| 74 | } | ||
| 75 | |||
| 76 | // Creates a new test if it hasn't been created. | ||
| 77 |
2/4✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
|
400 | MatMulTb test( |
| 78 | 200 | m_fixture_params.shape_m, m_fixture_params.shape_n, m_fixture_params.shape_k, m_fixture_params.bias_mode, | |
| 79 | 200 | m_fixture_params.clamp_ratio, m_fixture_params.op); | |
| 80 | |||
| 81 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | Rng rng(m_fixture_params.iteration_no); // REVISIT: Derive the seed from the global seed. |
| 82 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | test.generate_test_data(rng); |
| 83 | |||
| 84 |
2/4✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
|
200 | return test_benches[name] = std::move(test); |
| 85 | 1800 | } | |
| 86 | |||
| 87 | private: | ||
| 88 | static std::unordered_map<std::string, MatMulTb> test_benches; | ||
| 89 | |||
| 90 | MatMulFixtureParams m_fixture_params; | ||
| 91 | }; | ||
| 92 | |||
| 93 | 3 | std::unordered_map<std::string, MatMulTb> MatMulFixture::test_benches; | |
| 94 | |||
| 95 | class MatMulPackLhsTest : public MatMulFixture { | ||
| 96 | public: | ||
| 97 | 1200 | explicit MatMulPackLhsTest(const MatMulFixtureParams& fixture_params, const MatMulTestParams& test_params) : | |
| 98 | 1200 | MatMulFixture(fixture_params), m_test_params(test_params) { | |
| 99 | 1200 | } | |
| 100 | |||
| 101 | 600 | void TestBody() override { | |
| 102 | 600 | MatMulTb& test = test_bench(); | |
| 103 | 600 | const MatMulFixtureParams& params = fixture_params(); | |
| 104 | 600 | const MatrixPortion& portion = m_test_params.portion; | |
| 105 | |||
| 106 | 1800 | const auto [step_m, step_k] = test.lhs_packing_steps(); | |
| 107 | 1800 | const Rect rect = portion.compute_portion(params.shape_m, params.shape_k, step_m, step_k); | |
| 108 | |||
| 109 | 600 | const size_t start_m = rect.start_row(); | |
| 110 | 600 | const size_t start_k = rect.start_col(); | |
| 111 | 600 | const size_t size_m = rect.height(); | |
| 112 | 600 | const size_t size_k = rect.width(); | |
| 113 | |||
| 114 | 600 | test.test_lhs_packing(start_m, start_k, size_m, size_k); | |
| 115 | 600 | } | |
| 116 | |||
| 117 | private: | ||
| 118 | MatMulTestParams m_test_params; | ||
| 119 | }; | ||
| 120 | |||
| 121 | class MatMulPackRhsTest : public MatMulFixture { | ||
| 122 | public: | ||
| 123 | 1200 | explicit MatMulPackRhsTest(const MatMulFixtureParams& fixture_params, const MatMulTestParams& test_params) : | |
| 124 | 1200 | MatMulFixture(fixture_params), m_test_params(test_params) { | |
| 125 | 1200 | } | |
| 126 | |||
| 127 | 600 | void TestBody() override { | |
| 128 | 600 | MatMulTb& test = test_bench(); | |
| 129 | 600 | const MatMulFixtureParams& params = fixture_params(); | |
| 130 | 600 | const MatrixPortion& portion = m_test_params.portion; | |
| 131 | |||
| 132 | 1800 | const auto [step_n, step_k] = test.rhs_packing_steps(); | |
| 133 | 1800 | const Rect rect = portion.compute_portion(params.shape_n, params.shape_k, step_n, step_k); | |
| 134 | |||
| 135 | 600 | const size_t start_n = rect.start_row(); | |
| 136 | 600 | const size_t start_k = rect.start_col(); | |
| 137 | 600 | const size_t size_n = rect.height(); | |
| 138 | 600 | const size_t size_k = rect.width(); | |
| 139 | |||
| 140 | 600 | test.test_rhs_packing(start_n, start_k, size_n, size_k); | |
| 141 | 600 | } | |
| 142 | |||
| 143 | private: | ||
| 144 | MatMulTestParams m_test_params; | ||
| 145 | }; | ||
| 146 | |||
| 147 | class MatMulMatMulTest : public MatMulFixture { | ||
| 148 | public: | ||
| 149 | 1200 | explicit MatMulMatMulTest(const MatMulFixtureParams& fixture_params, const MatMulTestParams& test_params) : | |
| 150 | 1200 | MatMulFixture(fixture_params), m_test_params(test_params) { | |
| 151 | 1200 | } | |
| 152 | |||
| 153 | 600 | void TestBody() override { | |
| 154 | 600 | MatMulTb& test = test_bench(); | |
| 155 | 600 | const MatMulFixtureParams& params = fixture_params(); | |
| 156 | 600 | const MatrixPortion& portion = m_test_params.portion; | |
| 157 | |||
| 158 | 1800 | const auto [step_m, step_n] = test.matmul_steps(); | |
| 159 | 1800 | const Rect rect = portion.compute_portion(params.shape_m, params.shape_n, step_m, step_n); | |
| 160 | |||
| 161 | 600 | const size_t start_m = rect.start_row(); | |
| 162 | 600 | const size_t start_n = rect.start_col(); | |
| 163 | 600 | const size_t size_m = rect.height(); | |
| 164 | 600 | const size_t size_n = rect.width(); | |
| 165 | |||
| 166 | 600 | test.test_matmul(start_m, start_n, size_m, size_n); | |
| 167 | 600 | } | |
| 168 | |||
| 169 | private: | ||
| 170 | MatMulTestParams m_test_params; | ||
| 171 | }; | ||
| 172 | |||
| 173 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
6 | const auto matmul_tests_setup = TestRegistry::register_setup([]() { |
| 174 | 3 | const size_t num_shapes_per_op = 100; | |
| 175 | 3 | const std::array output_portions = { | |
| 176 | 3 | MatrixPortion(0, 0, 1, 1), // Full matrix. | |
| 177 | MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. | ||
| 178 | MatrixPortion(0.75, 0.75, 1, 1) // Bottom-right corner. | ||
| 179 | }; | ||
| 180 | |||
| 181 | 3 | const Span<const MatMulOperator> available_operators = get_available_matmul_operators(); | |
| 182 | 3 | Rng rng(0); // REVISIT: Use the global seed to initialize this RNG. | |
| 183 | 3 | std::uniform_int_distribution<size_t> shape_dist(1, 150); | |
| 184 | 3 | std::uniform_real_distribution<float> probability_dist(0.0F, 1.0F); | |
| 185 | 3 | std::uniform_real_distribution<float> dist_70_to_100(0.7F, 1.0F); | |
| 186 | 3 | std::uniform_real_distribution<float> dist_0_to_70(0.0F, 0.7F); | |
| 187 | |||
| 188 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 6 times.
|
9 | for (const MatMulOperator& op : available_operators) { |
| 189 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 4 times.
|
6 | if (!op.is_cpu_supported()) { |
| 190 | 4 | continue; | |
| 191 | } | ||
| 192 | |||
| 193 | 2 | const bool test_pack_lhs = op.pack_lhs != nullptr; | |
| 194 | 2 | const bool test_pack_rhs = op.pack_rhs != nullptr; | |
| 195 | 2 | const bool test_matmul = op.matmul != nullptr; | |
| 196 | |||
| 197 | 2 | const char* test_suite_name = "MatMulNext"; | |
| 198 | |||
| 199 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 200 times.
|
202 | for (size_t shape_no = 0; shape_no < num_shapes_per_op; ++shape_no) { |
| 200 | 200 | size_t shape_m = 0; | |
| 201 | 200 | size_t shape_n = 0; | |
| 202 | 200 | size_t shape_k = 0; | |
| 203 | 200 | MatMulBiasMode bias_mode = MatMulBiasMode::NO_BIAS; | |
| 204 | 200 | float clamp_ratio = 0; | |
| 205 | |||
| 206 | 200 | while (true) { | |
| 207 | 200 | shape_m = shape_dist(rng); | |
| 208 | 200 | shape_n = shape_dist(rng); | |
| 209 | 200 | shape_k = shape_dist(rng); | |
| 210 | |||
| 211 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 200 times.
|
200 | if (!op.is_shape_suitable(shape_m, shape_n, shape_k)) { |
| 212 | ✗ | continue; | |
| 213 | } | ||
| 214 | |||
| 215 | // Bias mode: | ||
| 216 | // * 70% of tests have bias. | ||
| 217 | // * 30% of tests have no bias. | ||
| 218 | 200 | const float bias_prob = probability_dist(rng); | |
| 219 | 200 | const bool with_bias = bias_prob < 0.7F; | |
| 220 | |||
| 221 |
2/2✓ Branch 0 taken 145 times.
✓ Branch 1 taken 55 times.
|
200 | if (with_bias) { |
| 222 | 145 | bias_mode = MatMulBiasMode::PER_N; | |
| 223 | 145 | } else { | |
| 224 | 55 | bias_mode = MatMulBiasMode::NO_BIAS; | |
| 225 | } | ||
| 226 | |||
| 227 | // Clamping range: | ||
| 228 | // * 20% of tests have no clamping. | ||
| 229 | // * 40% of tests have clamping range between 70% to 100% the output range. | ||
| 230 | // * 40% of tests have clamping range between 0% to 70% the output range. | ||
| 231 | 200 | const float clamp_prob = probability_dist(rng); | |
| 232 | |||
| 233 | 200 | const bool no_clamp = clamp_prob < 0.2F; | |
| 234 |
2/2✓ Branch 0 taken 44 times.
✓ Branch 1 taken 156 times.
|
200 | const bool clamp_70_to_100 = clamp_prob >= 0.2F && clamp_prob < 0.6F; |
| 235 | 200 | const bool clamp_0_to_70 = clamp_prob >= 0.6F; | |
| 236 | |||
| 237 |
2/2✓ Branch 0 taken 44 times.
✓ Branch 1 taken 156 times.
|
200 | if (no_clamp) { |
| 238 | 44 | clamp_ratio = 1.0F; | |
| 239 |
2/2✓ Branch 0 taken 75 times.
✓ Branch 1 taken 81 times.
|
200 | } else if (clamp_70_to_100) { |
| 240 | 75 | clamp_ratio = dist_70_to_100(rng); | |
| 241 | 75 | } else { | |
| 242 |
1/4✓ Branch 0 taken 81 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
81 | KAI_TEST_ASSERT(clamp_0_to_70); |
| 243 | 81 | clamp_ratio = dist_0_to_70(rng); | |
| 244 | } | ||
| 245 | |||
| 246 | break; | ||
| 247 | 200 | } | |
| 248 | |||
| 249 | 1600 | const MatMulFixtureParams fixture_params{ | |
| 250 | 1400 | shape_no, shape_m, shape_n, shape_k, bias_mode, clamp_ratio, &op, | |
| 251 | }; | ||
| 252 | |||
| 253 |
2/2✓ Branch 0 taken 600 times.
✓ Branch 1 taken 200 times.
|
800 | for (const MatrixPortion& portion : output_portions) { |
| 254 | 600 | const MatMulTestParams test_params{ | |
| 255 | 600 | portion, | |
| 256 | }; | ||
| 257 | |||
| 258 |
3/6✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 600 times.
|
600 | const std::string params_name = fixture_params.name() + "," + test_params.name(); |
| 259 | |||
| 260 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
|
600 | if (test_pack_lhs) { |
| 261 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
|
600 | const std::string test_name = "PackLhs/" + params_name; |
| 262 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
|
1200 | KAI_REGISTER_TEST( |
| 263 | MatMulFixture, MatMulPackLhsTest, test_suite_name, test_name.c_str(), fixture_params, | ||
| 264 | test_params); | ||
| 265 | 600 | } | |
| 266 | |||
| 267 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
|
600 | if (test_pack_rhs) { |
| 268 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
|
600 | const std::string test_name = "PackRhs/" + params_name; |
| 269 |
2/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
|
1200 | KAI_REGISTER_TEST( |
| 270 | MatMulFixture, MatMulPackRhsTest, test_suite_name, test_name.c_str(), fixture_params, | ||
| 271 | test_params); | ||
| 272 | 600 | } | |
| 273 | |||
| 274 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
|
600 | if (test_matmul) { |
| 275 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
|
600 | const std::string test_name = "MatMul/" + params_name; |
| 276 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
|
1200 | KAI_REGISTER_TEST( |
| 277 | MatMulFixture, MatMulMatMulTest, test_suite_name, test_name.c_str(), fixture_params, | ||
| 278 | test_params); | ||
| 279 | 600 | } | |
| 280 | 600 | } | |
| 281 | 200 | } | |
| 282 | 6 | } | |
| 283 | 3 | }); | |
| 284 | |||
| 285 | } // namespace | ||
| 286 | |||
| 287 | } // namespace kai::test | ||
| 288 |