KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 87.5% 28 / 0 / 32
Functions: 100.0% 12 / 0 / 12
Branches: 44.6% 131 / 0 / 294

benchmark/matmul/matmul_benchmark_logic.hpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP
8 #define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP
9
10 #include <cstddef>
11 #include <cstdint>
12 #include <functional>
13 #include <vector>
14
15 #include "kai/kai_common.h"
16 #include "matmul_interface.hpp"
17 #include "matmul_runner.hpp"
18
19 #ifdef __GNUC__
20 #pragma GCC diagnostic push
21 #pragma GCC diagnostic ignored "-Wswitch-default"
22 #endif // __GNUC__
23
24 #include <benchmark/benchmark.h>
25
26 #ifdef __GNUC__
27 #pragma GCC diagnostic pop
28 #endif // __GNUC__
29
30 namespace kai::benchmark {
31 using Buffer = std::vector<uint8_t>;
32 using CpuRequirement = std::function<bool()>;
33
34 /// High level description of the matrix multiplication operation.
35 enum class MatMulOp : uint8_t {
36 GEMM,
37 GEMV,
38 };
39
40 /// Benchmarks a matrix multiplication micro-kernel.
41 ///
42 /// @tparam MatMulInterface Interface of the matrix multiplication micro-kernel.
43 /// @param state State for the benchmark to use.
44 /// @param matmul_interface Abstraction containing the micro-kernel to run.
45 /// @param dst_type Output type of the micro-kernel. Required for the micro-kernel to make certain assumptions
46 /// internally about the stride of the data.
47 /// @param matmul_op Type of matrix multiplication operation.
48 /// @param is_cpu_supported Function that checks the CPU feature requirement to run this benchmark.
49 template <typename MatMulInterface>
50 534 void kai_benchmark_matmul(
51 ::benchmark::State& state, const MatMulInterface matmul_interface, const DataType dst_type,
52 const MatMulOp matmul_op, const CpuRequirement& is_cpu_supported) {
53
12/12
✓ Branch 0 taken 88 times.
✓ Branch 1 taken 20 times.
✓ Branch 2 taken 50 times.
✓ Branch 3 taken 22 times.
✓ Branch 4 taken 138 times.
✓ Branch 5 taken 36 times.
✓ Branch 6 taken 90 times.
✓ Branch 7 taken 24 times.
✓ Branch 8 taken 6 times.
✓ Branch 9 taken 12 times.
✓ Branch 10 taken 40 times.
✓ Branch 11 taken 8 times.
534 if (!is_cpu_supported()) {
54
14/34
✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 30 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 30 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 24 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 24 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 12 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 12 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 8 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 8 times.
✗ Branch 23 not taken.
122 state.SkipWithMessage("Unsupported CPU feature");
55 122 return;
56 }
57
58 412 const size_t m = state.range(0);
59 412 const size_t n = state.range(1);
60 412 const size_t k = state.range(2);
61 412 const size_t bl = state.range(3);
62
63
6/24
✗ Branch 0 not taken.
✓ Branch 1 taken 88 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 50 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 138 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 90 times.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 6 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 40 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
412 if (m > 1 && matmul_op == MatMulOp::GEMV) {
64 state.SkipWithMessage("GEMV optimized for m=1 only");
65 return;
66 }
67
68 if constexpr (
69 std::is_same_v<MatMulInterface, MatMulBlockwiseDynamicQuantInterface> ||
70 std::is_same_v<MatMulInterface, MatMulBlockwiseDynamicQuantGenericDstInterface>) {
71
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 138 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 40 times.
178 if (k % bl != 0) {
72 state.SkipWithMessage("K must be a multiple of block size");
73 return;
74 }
75 }
76
77 // Create sufficiently large buffers
78 412 size_t lhs_size = m * k * sizeof(uint64_t);
79 412 size_t rhs_size = n * k * sizeof(uint64_t);
80 412 size_t dst_size = m * n * sizeof(uint32_t);
81
82
16/24
✓ Branch 0 taken 52 times.
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 52 times.
✓ Branch 4 taken 28 times.
✓ Branch 5 taken 22 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 28 times.
✓ Branch 8 taken 86 times.
✓ Branch 9 taken 52 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 86 times.
✓ Branch 12 taken 52 times.
✓ Branch 13 taken 38 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 52 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 6 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 24 times.
✓ Branch 21 taken 16 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 24 times.
412 if (test::cpu_has_sme() || test::cpu_has_sme2()) {
83 170 lhs_size *= kai_get_sme_vector_length_u32();
84 170 rhs_size *= kai_get_sme_vector_length_u32();
85 170 dst_size *= kai_get_sme_vector_length_u32();
86 170 }
87
88
5/12
✓ Branch 0 taken 52 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 28 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 86 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 52 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
412 const Buffer lhs(lhs_size);
89
10/22
✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22 times.
✓ Branch 2 taken 28 times.
✗ Branch 3 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 52 times.
✓ Branch 4 taken 86 times.
✗ Branch 5 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 38 times.
✓ Branch 6 taken 52 times.
✗ Branch 7 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
412 const Buffer rhs(rhs_size);
90
10/22
✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22 times.
✓ Branch 2 taken 28 times.
✗ Branch 3 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 52 times.
✓ Branch 4 taken 86 times.
✗ Branch 5 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 38 times.
✓ Branch 6 taken 52 times.
✗ Branch 7 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
412 Buffer dst(dst_size);
91
92
6/12
✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 138 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 90 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
412 MatMulRunner matmul_runner(matmul_interface, dst_type);
93
6/12
✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 138 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 90 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
412 matmul_runner.set_mnk(m, n, k);
94
6/12
✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 138 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 90 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
412 matmul_runner.set_bl(bl);
95
96
32/52
✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 88 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 176 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 88 times.
✓ Branch 7 taken 88 times.
✓ Branch 8 taken 88 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 88 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 50 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 50 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 100 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 50 times.
✓ Branch 19 taken 50 times.
✓ Branch 20 taken 138 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 138 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 276 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 138 times.
✓ Branch 27 taken 138 times.
✓ Branch 28 taken 90 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 90 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 180 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 90 times.
✓ Branch 35 taken 90 times.
✓ Branch 36 taken 6 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 6 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 12 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 6 times.
✓ Branch 43 taken 6 times.
✓ Branch 44 taken 40 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 40 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 80 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 40 times.
✓ Branch 51 taken 40 times.
824 for (auto _ : state) {
97
6/12
✓ Branch 0 taken 88 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 138 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 90 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
412 matmul_runner.run(lhs.data(), rhs.data(), dst.data());
98 412 }
99 534 }
100 } // namespace kai::benchmark
101
102 #endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_BENCHMARK_LOGIC_HPP
103