Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP | ||
8 | #define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP | ||
9 | |||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <functional> | ||
13 | #include <vector> | ||
14 | |||
15 | #include "kai/kai_common.h" | ||
16 | #include "matmul_interface.hpp" | ||
17 | #include "matmul_runner.hpp" | ||
18 | |||
19 | #ifdef __GNUC__ | ||
20 | #pragma GCC diagnostic push | ||
21 | #pragma GCC diagnostic ignored "-Wswitch-default" | ||
22 | #endif // __GNUC__ | ||
23 | |||
24 | #include <benchmark/benchmark.h> | ||
25 | |||
26 | #ifdef __GNUC__ | ||
27 | #pragma GCC diagnostic pop | ||
28 | #endif // __GNUC__ | ||
29 | |||
30 | namespace kai::benchmark { | ||
31 | using Buffer = std::vector<uint8_t>; | ||
32 | using CpuRequirement = std::function<bool()>; | ||
33 | |||
34 | /// High level description of the matrix multiplication operation. | ||
35 | enum class MatMulOp : uint8_t { | ||
36 | GEMM, | ||
37 | GEMV, | ||
38 | }; | ||
39 | |||
40 | /// Benchmarks a matrix multiplication micro-kernel. | ||
41 | /// | ||
42 | /// @tparam MatMulInterface Interface of the matrix multiplication micro-kernel. | ||
43 | /// @param state State for the benchmark to use. | ||
44 | /// @param matmul_interface Abstraction containing the micro-kernel to run. | ||
45 | /// @param dst_type Output type of the micro-kernel. Required for the micro-kernel to make certain assumptions | ||
46 | /// internally about the stride of the data. | ||
47 | /// @param matmul_op Type of matrix multiplication operation. | ||
48 | /// @param is_cpu_supported Function that checks the CPU feature requirement to run this benchmark. | ||
49 | template <typename MatMulInterface> | ||
50 | 83 | void kai_benchmark_matmul( | |
51 | ::benchmark::State& state, const MatMulInterface matmul_interface, const DataType dst_type, | ||
52 | const MatMulOp matmul_op, const CpuRequirement& is_cpu_supported) { | ||
53 |
6/12✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
|
83 | if (!is_cpu_supported()) { |
54 | ✗ | state.SkipWithMessage("Unsupported CPU feature"); | |
55 | ✗ | } | |
56 | |||
57 | 83 | const size_t m = state.range(0); | |
58 | 83 | const size_t n = state.range(1); | |
59 | 83 | const size_t k = state.range(2); | |
60 | 83 | const size_t bl = state.range(3); | |
61 | |||
62 |
6/24✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 11 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 24 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 19 times.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 8 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
|
83 | if (m > 1 && matmul_op == MatMulOp::GEMV) { |
63 | ✗ | state.SkipWithMessage("GEMV optimized for m=1 only"); | |
64 | ✗ | } | |
65 | |||
66 | if constexpr ( | ||
67 | std::is_same_v<MatMulInterface, MatMulBlockwiseDynamicQuantInterface> || | ||
68 | std::is_same_v<MatMulInterface, MatMulBlockwiseDynamicQuantGenericDstInterface>) { | ||
69 |
2/4✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
32 | if (k % bl != 0) { |
70 | ✗ | state.SkipWithMessage("K must be a multiple of block size"); | |
71 | ✗ | } | |
72 | } | ||
73 | |||
74 | // Create sufficiently large buffers | ||
75 | 83 | size_t lhs_size = m * k * sizeof(uint64_t); | |
76 | 83 | size_t rhs_size = n * k * sizeof(uint64_t); | |
77 | 83 | size_t dst_size = m * n * sizeof(uint32_t); | |
78 | |||
79 |
6/24✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 11 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 24 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 19 times.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 8 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
|
83 | if (test::cpu_has_sme() || test::cpu_has_sme2()) { |
80 | 83 | lhs_size *= kai_get_sme_vector_length_u32(); | |
81 | 83 | rhs_size *= kai_get_sme_vector_length_u32(); | |
82 | 83 | dst_size *= kai_get_sme_vector_length_u32(); | |
83 | 83 | } | |
84 | |||
85 | 83 | const Buffer lhs(lhs_size); | |
86 |
6/12✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
|
83 | const Buffer rhs(rhs_size); |
87 |
6/12✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
|
83 | Buffer dst(dst_size); |
88 | |||
89 |
6/12✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
|
83 | MatMulRunner matmul_runner(matmul_interface, dst_type); |
90 |
6/12✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
|
83 | matmul_runner.set_mnk(m, n, k); |
91 |
6/12✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
|
83 | matmul_runner.set_bl(bl); |
92 | |||
93 |
32/52✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
✓ Branch 7 taken 18 times.
✓ Branch 8 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 11 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 11 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 22 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 11 times.
✓ Branch 19 taken 11 times.
✓ Branch 20 taken 24 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 24 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 48 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 24 times.
✓ Branch 27 taken 24 times.
✓ Branch 28 taken 19 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 19 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 38 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 19 times.
✓ Branch 35 taken 19 times.
✓ Branch 36 taken 3 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 3 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 6 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 3 times.
✓ Branch 43 taken 3 times.
✓ Branch 44 taken 8 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 8 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 16 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 8 times.
✓ Branch 51 taken 8 times.
|
166 | for (auto _ : state) { |
94 |
6/12✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
|
83 | matmul_runner.run(lhs.data(), rhs.data(), dst.data()); |
95 | 83 | } | |
96 | 83 | } | |
97 | } // namespace kai::benchmark | ||
98 | |||
99 | #endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP | ||
100 |