KleidiAI Coverage Report


Directory: ./
File: benchmark/matmul/matmul_benchmark_logic.hpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 81.2% 26 0 32
Functions: 100.0% 6 0 6
Branches: 40.7% 88 0 216

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP
8 #define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP
9
10 #include <cstddef>
11 #include <cstdint>
12 #include <functional>
13 #include <vector>
14
15 #include "kai/kai_common.h"
16 #include "matmul_interface.hpp"
17 #include "matmul_runner.hpp"
18
19 #ifdef __GNUC__
20 #pragma GCC diagnostic push
21 #pragma GCC diagnostic ignored "-Wswitch-default"
22 #endif // __GNUC__
23
24 #include <benchmark/benchmark.h>
25
26 #ifdef __GNUC__
27 #pragma GCC diagnostic pop
28 #endif // __GNUC__
29
30 namespace kai::benchmark {
31 using Buffer = std::vector<uint8_t>;
32 using CpuRequirement = std::function<bool()>;
33
34 /// High level description of the matrix multiplication operation.
35 enum class MatMulOp : uint8_t {
36 GEMM,
37 GEMV,
38 };
39
40 /// Benchmarks a matrix multiplication micro-kernel.
41 ///
42 /// @tparam MatMulInterface Interface of the matrix multiplication micro-kernel.
43 /// @param state State for the benchmark to use.
44 /// @param matmul_interface Abstraction containing the micro-kernel to run.
45 /// @param dst_type Output type of the micro-kernel. Required for the micro-kernel to make certain assumptions
46 /// internally about the stride of the data.
47 /// @param matmul_op Type of matrix multiplication operation.
48 /// @param is_cpu_supported Function that checks the CPU feature requirement to run this benchmark.
49 template <typename MatMulInterface>
50 83 void kai_benchmark_matmul(
51 ::benchmark::State& state, const MatMulInterface matmul_interface, const DataType dst_type,
52 const MatMulOp matmul_op, const CpuRequirement& is_cpu_supported) {
53
6/12
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
83 if (!is_cpu_supported()) {
54 state.SkipWithMessage("Unsupported CPU feature");
55 }
56
57 83 const size_t m = state.range(0);
58 83 const size_t n = state.range(1);
59 83 const size_t k = state.range(2);
60 83 const size_t bl = state.range(3);
61
62
6/24
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 11 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 24 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 19 times.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 8 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
83 if (m > 1 && matmul_op == MatMulOp::GEMV) {
63 state.SkipWithMessage("GEMV optimized for m=1 only");
64 }
65
66 if constexpr (
67 std::is_same_v<MatMulInterface, MatMulBlockwiseDynamicQuantInterface> ||
68 std::is_same_v<MatMulInterface, MatMulBlockwiseDynamicQuantGenericDstInterface>) {
69
2/4
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
32 if (k % bl != 0) {
70 state.SkipWithMessage("K must be a multiple of block size");
71 }
72 }
73
74 // Create sufficiently large buffers
75 83 size_t lhs_size = m * k * sizeof(uint64_t);
76 83 size_t rhs_size = n * k * sizeof(uint64_t);
77 83 size_t dst_size = m * n * sizeof(uint32_t);
78
79
6/24
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 11 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 24 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 19 times.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 8 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
83 if (test::cpu_has_sme() || test::cpu_has_sme2()) {
80 83 lhs_size *= kai_get_sme_vector_length_u32();
81 83 rhs_size *= kai_get_sme_vector_length_u32();
82 83 dst_size *= kai_get_sme_vector_length_u32();
83 83 }
84
85 83 const Buffer lhs(lhs_size);
86
6/12
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
83 const Buffer rhs(rhs_size);
87
6/12
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
83 Buffer dst(dst_size);
88
89
6/12
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
83 MatMulRunner matmul_runner(matmul_interface, dst_type);
90
6/12
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
83 matmul_runner.set_mnk(m, n, k);
91
6/12
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
83 matmul_runner.set_bl(bl);
92
93
32/52
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
✓ Branch 7 taken 18 times.
✓ Branch 8 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 11 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 11 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 22 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 11 times.
✓ Branch 19 taken 11 times.
✓ Branch 20 taken 24 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 24 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 48 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 24 times.
✓ Branch 27 taken 24 times.
✓ Branch 28 taken 19 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 19 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 38 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 19 times.
✓ Branch 35 taken 19 times.
✓ Branch 36 taken 3 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 3 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 6 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 3 times.
✓ Branch 43 taken 3 times.
✓ Branch 44 taken 8 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 8 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 16 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 8 times.
✓ Branch 51 taken 8 times.
166 for (auto _ : state) {
94
6/12
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
83 matmul_runner.run(lhs.data(), rhs.data(), dst.data());
95 83 }
96 83 }
97 } // namespace kai::benchmark
98
99 #endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP
100