benchmark/matmul/matmul_benchmark_logic.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP | ||
| 8 | #define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP | ||
| 9 | |||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <functional> | ||
| 13 | #include <vector> | ||
| 14 | |||
| 15 | #include "kai/kai_common.h" | ||
| 16 | #include "matmul_interface.hpp" | ||
| 17 | #include "matmul_runner.hpp" | ||
| 18 | |||
| 19 | #ifdef __GNUC__ | ||
| 20 | #pragma GCC diagnostic push | ||
| 21 | #pragma GCC diagnostic ignored "-Wswitch-default" | ||
| 22 | #endif // __GNUC__ | ||
| 23 | |||
| 24 | #include <benchmark/benchmark.h> | ||
| 25 | |||
| 26 | #ifdef __GNUC__ | ||
| 27 | #pragma GCC diagnostic pop | ||
| 28 | #endif // __GNUC__ | ||
| 29 | |||
| 30 | namespace kai::benchmark { | ||
| 31 | using Buffer = std::vector<uint8_t>; | ||
| 32 | using CpuRequirement = std::function<bool()>; | ||
| 33 | |||
| 34 | /// High level description of the matrix multiplication operation. | ||
| 35 | enum class MatMulOp : uint8_t { | ||
| 36 | GEMM, | ||
| 37 | GEMV, | ||
| 38 | }; | ||
| 39 | |||
| 40 | /// Benchmarks a matrix multiplication micro-kernel. | ||
| 41 | /// | ||
| 42 | /// @tparam MatMulInterface Interface of the matrix multiplication micro-kernel. | ||
| 43 | /// @param state State for the benchmark to use. | ||
| 44 | /// @param matmul_interface Abstraction containing the micro-kernel to run. | ||
| 45 | /// @param dst_type Output type of the micro-kernel. Required for the micro-kernel to make certain assumptions | ||
| 46 | /// internally about the stride of the data. | ||
| 47 | /// @param matmul_op Type of matrix multiplication operation. | ||
| 48 | /// @param is_cpu_supported Function that checks the CPU feature requirement to run this benchmark. | ||
| 49 | template <typename MatMulInterface> | ||
| 50 | 534 | void kai_benchmark_matmul( | |
| 51 | ::benchmark::State& state, const MatMulInterface matmul_interface, const DataType dst_type, | ||
| 52 | const MatMulOp matmul_op, const CpuRequirement& is_cpu_supported) { | ||
| 53 |
12/12✓ Branch 0 taken 88 times.
✓ Branch 1 taken 20 times.
✓ Branch 2 taken 50 times.
✓ Branch 3 taken 22 times.
✓ Branch 4 taken 138 times.
✓ Branch 5 taken 36 times.
✓ Branch 6 taken 90 times.
✓ Branch 7 taken 24 times.
✓ Branch 8 taken 6 times.
✓ Branch 9 taken 12 times.
✓ Branch 10 taken 40 times.
✓ Branch 11 taken 8 times.
|
534 | if (!is_cpu_supported()) { |
| 54 |
14/34✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 30 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 30 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 24 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 24 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 12 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 12 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 8 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 8 times.
✗ Branch 23 not taken.
|
122 | state.SkipWithMessage("Unsupported CPU feature"); |
| 55 | 122 | return; | |
| 56 | } | ||
| 57 | |||
| 58 | 412 | const size_t m = state.range(0); | |
| 59 | 412 | const size_t n = state.range(1); | |
| 60 | 412 | const size_t k = state.range(2); | |
| 61 | 412 | const size_t bl = state.range(3); | |
| 62 | |||
| 63 |
6/24✗ Branch 0 not taken.
✓ Branch 1 taken 88 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 50 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 138 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 90 times.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 6 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 40 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
|
412 | if (m > 1 && matmul_op == MatMulOp::GEMV) { |
| 64 | ✗ | state.SkipWithMessage("GEMV optimized for m=1 only"); | |
| 65 | ✗ | return; | |
| 66 | } | ||
| 67 | |||
| 68 | if constexpr ( | ||
| 69 | std::is_same_v<MatMulInterface, MatMulBlockwiseDynamicQuantInterface> || | ||
| 70 | std::is_same_v<MatMulInterface, MatMulBlockwiseDynamicQuantGenericDstInterface>) { | ||
| 71 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 138 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 40 times.
|
178 | if (k % bl != 0) { |
| 72 | ✗ | state.SkipWithMessage("K must be a multiple of block size"); | |
| 73 | ✗ | return; | |
| 74 | } | ||
| 75 | } | ||
| 76 | |||
| 77 | // Create sufficiently large buffers | ||
| 78 | 412 | size_t lhs_size = m * k * sizeof(uint64_t); | |
| 79 | 412 | size_t rhs_size = n * k * sizeof(uint64_t); | |
| 80 | 412 | size_t dst_size = m * n * sizeof(uint32_t); | |
| 81 | |||
| 82 |
16/24✓ Branch 0 taken 52 times.
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 52 times.
✓ Branch 4 taken 28 times.
✓ Branch 5 taken 22 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 28 times.
✓ Branch 8 taken 86 times.
✓ Branch 9 taken 52 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 86 times.
✓ Branch 12 taken 52 times.
✓ Branch 13 taken 38 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 52 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 6 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 24 times.
✓ Branch 21 taken 16 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 24 times.
|
412 | if (test::cpu_has_sme() || test::cpu_has_sme2()) { |
| 83 | 170 | lhs_size *= kai_get_sme_vector_length_u32(); | |
| 84 | 170 | rhs_size *= kai_get_sme_vector_length_u32(); | |
| 85 | 170 | dst_size *= kai_get_sme_vector_length_u32(); | |
| 86 | 170 | } | |
| 87 | |||
| 88 |
5/12✓ Branch 0 taken 52 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 28 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 86 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 52 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
|
412 | const Buffer lhs(lhs_size); |
| 89 |
10/22✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22 times.
✓ Branch 2 taken 28 times.
✗ Branch 3 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 52 times.
✓ Branch 4 taken 86 times.
✗ Branch 5 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 38 times.
✓ Branch 6 taken 52 times.
✗ Branch 7 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
|
412 | const Buffer rhs(rhs_size); |
| 90 |
10/22✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22 times.
✓ Branch 2 taken 28 times.
✗ Branch 3 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 52 times.
✓ Branch 4 taken 86 times.
✗ Branch 5 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 38 times.
✓ Branch 6 taken 52 times.
✗ Branch 7 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
|
412 | Buffer dst(dst_size); |
| 91 | |||
| 92 |
6/12✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 138 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 90 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
|
412 | MatMulRunner matmul_runner(matmul_interface, dst_type); |
| 93 |
6/12✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 138 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 90 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
|
412 | matmul_runner.set_mnk(m, n, k); |
| 94 |
6/12✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 138 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 90 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
|
412 | matmul_runner.set_bl(bl); |
| 95 | |||
| 96 |
32/52✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 88 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 176 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 88 times.
✓ Branch 7 taken 88 times.
✓ Branch 8 taken 88 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 88 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 50 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 50 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 100 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 50 times.
✓ Branch 19 taken 50 times.
✓ Branch 20 taken 138 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 138 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 276 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 138 times.
✓ Branch 27 taken 138 times.
✓ Branch 28 taken 90 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 90 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 180 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 90 times.
✓ Branch 35 taken 90 times.
✓ Branch 36 taken 6 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 6 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 12 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 6 times.
✓ Branch 43 taken 6 times.
✓ Branch 44 taken 40 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 40 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 80 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 40 times.
✓ Branch 51 taken 40 times.
|
824 | for (auto _ : state) { |
| 97 |
6/12✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 138 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 90 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
|
412 | matmul_runner.run(lhs.data(), rhs.data(), dst.data()); |
| 98 | 412 | } | |
| 99 | 534 | } | |
| 100 | } // namespace kai::benchmark | ||
| 101 | |||
| 102 | #endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP | ||
| 103 |