KleidiAI Coverage Report


Directory: ./
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%
            Coverage   Exec / Excl / Total
Lines:      100.0%     58 / 0 / 58
Functions:  100.0%     24 / 0 / 24
Branches:   -%         0 / 0 / 0

benchmark/matmul/matmul_runner.hpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP
8 #define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP
9
10 #include <cfloat>
11 #include <cstddef>
12 #include <cstdint>
13 #include <test/common/data_type.hpp>
14
15 #include "kai/kai_common.h"
16 #include "matmul_interface.hpp"
17
18 namespace kai::benchmark {
19
20 using DataType = test::DataType;
21
22 /// Runner for the matrix multiplication micro-kernel.
23 ///
24 /// Prepares and executes the run method of the micro-kernel.
25 ///
26 /// @tparam MatMulInterface Interface of the matrix multiplication micro-kernel.
27 template <typename MatMulInterface>
28 class MatMulRunner {
29 public:
30 /// Constructs a MatMulRunner object.
31 ///
32 /// @param matmul_interface Abstraction containing the micro-kernel to run.
33 /// @param dst_type Output type of the micro-kernel. Used to derive the LHS and destination strides
34 /// in bytes from the size of this type.
35 582 MatMulRunner(const MatMulInterface& matmul_interface, const DataType dst_type) :
36 582 matmul_interface_(matmul_interface), dst_type_(dst_type) {
37 582 }
38
39 /// Sets the M, N and K dimensions to describe the operand and result matrices.
40 ///
41 /// @param m Rows in a non-transposed LHS and DST matrix.
42 /// @param n Columns in a non-transposed RHS and DST matrix.
43 /// @param k Columns in a non-transposed LHS matrix, and rows in a non-transposed RHS matrix.
44 412 void set_mnk(const size_t m, const size_t n, const size_t k) {
45 412 m_ = m;
46 412 n_ = n;
47 412 k_ = k;
48
49 412 lhs_stride_ = k_ * data_type_size_in_bits(dst_type_) / 8;
50 412 dst_stride_row_ = n_ * data_type_size_in_bits(dst_type_) / 8;
51 412 dst_stride_col_ = data_type_size_in_bits(dst_type_) / 8;
52 412 }
53
54 /// Sets the block size to use.
55 ///
56 /// @param bl Block size. Used for micro-kernels with dynamic blockwise quantization.
57 412 void set_bl(const size_t bl) {
58 412 bl_ = bl;
59 412 }
60
61 /// Runs the matrix multiplication micro-kernel.
62 ///
63 /// @param lhs Buffer containing LHS matrix data.
64 /// @param rhs Buffer containing RHS matrix data.
65 /// @param dst Destination buffer to write to.
66 void run(const void* lhs, const void* rhs, void* dst);
67
68 private:
69 MatMulInterface matmul_interface_ = {};
70
71 DataType dst_type_ = DataType::FP32;
72
73 412 size_t m_ = 1;
74 412 size_t n_ = 1;
75 412 size_t k_ = 1;
76 412 size_t bl_ = 32;
77
78 412 size_t lhs_stride_ = 1;
79 412 size_t dst_stride_row_ = 1;
80 412 size_t dst_stride_col_ = 1;
81 };
82
83 /// Runs the matrix multiplication micro-kernel.
84 ///
85 /// @param lhs Buffer containing LHS matrix data.
86 /// @param rhs Buffer containing RHS matrix data.
87 /// @param dst Destination buffer to write to.
88 template <typename MatMulInterface>
89 88 void MatMulRunner<MatMulInterface>::run(const void* lhs, const void* rhs, void* dst) {
90 176 matmul_interface_.run_matmul(
91 88 m_, n_, k_, //
92 88 lhs, rhs, dst, //
93 88 dst_stride_row_, dst_stride_col_, //
94 -FLT_MAX, FLT_MAX //
95 );
96 88 }
97
98 /// Runs the matrix multiplication micro-kernel. Specialized on the strided LHS interface.
99 ///
100 /// @param lhs Buffer containing LHS matrix data.
101 /// @param rhs Buffer containing RHS matrix data.
102 /// @param dst Destination buffer to write to.
103 template <>
104 50 inline void MatMulRunner<MatMulStridedLhsInterface>::run(const void* lhs, const void* rhs, void* dst) {
105 100 matmul_interface_.run_matmul(
106 50 m_, n_, k_, //
107 50 lhs, lhs_stride_, rhs, dst, //
108 50 dst_stride_row_, dst_stride_col_, //
109 -FLT_MAX, FLT_MAX //
110 );
111 50 }
112
113 /// Runs the matrix multiplication micro-kernel. Specialized on the interface with a floating point destination buffer.
114 ///
115 /// @param lhs Buffer containing LHS matrix data.
116 /// @param rhs Buffer containing RHS matrix data.
117 /// @param dst Destination buffer to write to.
118 template <>
119 90 inline void MatMulRunner<MatMulFloatInterface>::run(const void* lhs, const void* rhs, void* dst) {
120 180 matmul_interface_.run_matmul(
121 90 m_, n_, k_, //
122 90 lhs, rhs, static_cast<float*>(dst), //
123 90 dst_stride_row_, dst_stride_col_, //
124 -FLT_MAX, FLT_MAX //
125 );
126 90 }
127
128 /// Runs the matrix multiplication micro-kernel. Specialized on the static quantization interface.
129 ///
130 /// @param lhs Buffer containing LHS matrix data.
131 /// @param rhs Buffer containing RHS matrix data.
132 /// @param dst Destination buffer to write to.
133 template <>
134 6 inline void MatMulRunner<MatMulStaticQuantInterface>::run(const void* lhs, const void* rhs, void* dst) {
135 6 constexpr kai_matmul_requantize32_params params = {INT8_MIN, INT8_MAX, 0};
136 12 matmul_interface_.run_matmul(
137 6 m_, n_, k_, //
138 6 lhs, rhs, dst, //
139 6 dst_stride_row_, dst_stride_col_, //
140 &params //
141 );
142 6 }
143
144 /// Runs the matrix multiplication micro-kernel. Specialized on the dynamic blockwise quantization interface with
145 /// a generic destination buffer.
146 ///
147 /// @param lhs Buffer containing LHS matrix data.
148 /// @param rhs Buffer containing RHS matrix data.
149 /// @param dst Destination buffer to write to.
150 template <>
151 40 inline void MatMulRunner<MatMulBlockwiseDynamicQuantGenericDstInterface>::run(
152 const void* lhs, const void* rhs, void* dst) {
153 80 matmul_interface_.run_matmul(
154 40 m_, n_, k_, bl_, //
155 40 lhs, rhs, dst, //
156 40 dst_stride_row_, dst_stride_col_, //
157 -FLT_MAX, FLT_MAX //
158 );
159 40 }
160
161 /// Runs the matrix multiplication micro-kernel. Specialized on the dynamic blockwise quantization interface.
162 ///
163 /// @param lhs Buffer containing LHS matrix data.
164 /// @param rhs Buffer containing RHS matrix data.
165 /// @param dst Destination buffer to write to.
166 template <>
167 138 inline void MatMulRunner<MatMulBlockwiseDynamicQuantInterface>::run(const void* lhs, const void* rhs, void* dst) {
168 276 matmul_interface_.run_matmul(
169 138 m_, n_, k_, bl_, //
170 138 lhs, rhs, static_cast<float*>(dst), //
171 138 dst_stride_row_, dst_stride_col_, //
172 -FLT_MAX, FLT_MAX //
173 );
174 138 }
175
176 } // namespace kai::benchmark
177
178 #endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP
179
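
Usage sketch (illustrative)

For reference, below is a minimal usage sketch of MatMulRunner; it is not taken from the repository and is not part of the coverage data. It assumes the KleidiAI benchmark and test include directories are on the include path, and it substitutes a hypothetical DummyMatMulInterface plus a naive reference kernel, naive_matmul_f32, for a real micro-kernel wrapper from matmul_interface.hpp. The interface is wired to match the generic run() overload shown in the listing above.

// Minimal usage sketch (illustrative only). DummyMatMulInterface and
// naive_matmul_f32 are hypothetical stand-ins for a real KleidiAI
// micro-kernel wrapper.
#include <algorithm>
#include <cstddef>
#include <vector>

#include "matmul_runner.hpp"

namespace {

// Hypothetical interface: a struct exposing a run_matmul member whose argument
// order matches the generic MatMulRunner::run overload in the listing
// (m, n, k, lhs, rhs, dst, row stride, column stride, clamp range).
struct DummyMatMulInterface {
    void (*run_matmul)(
        size_t m, size_t n, size_t k, const void* lhs, const void* rhs, void* dst, size_t dst_stride_row,
        size_t dst_stride_col, float clamp_min, float clamp_max) = nullptr;
};

// Naive FP32 reference kernel with the same signature; it exists only to make
// the sketch self-contained. Destination strides are in bytes, as computed by
// MatMulRunner::set_mnk.
void naive_matmul_f32(
    size_t m, size_t n, size_t k, const void* lhs, const void* rhs, void* dst, size_t dst_stride_row,
    size_t dst_stride_col, float clamp_min, float clamp_max) {
    const auto* a = static_cast<const float*>(lhs);
    const auto* b = static_cast<const float*>(rhs);
    auto* c = static_cast<char*>(dst);

    for (size_t i = 0; i < m; ++i) {
        for (size_t j = 0; j < n; ++j) {
            float acc = 0.0F;
            for (size_t p = 0; p < k; ++p) {
                acc += a[i * k + p] * b[p * n + j];
            }
            acc = std::clamp(acc, clamp_min, clamp_max);
            *reinterpret_cast<float*>(c + i * dst_stride_row + j * dst_stride_col) = acc;
        }
    }
}

}  // namespace

int main() {
    constexpr size_t m = 4;
    constexpr size_t n = 8;
    constexpr size_t k = 16;

    std::vector<float> lhs(m * k, 1.0F);
    std::vector<float> rhs(k * n, 1.0F);
    std::vector<float> dst(m * n, 0.0F);

    const DummyMatMulInterface iface{naive_matmul_f32};

    // The runner derives the LHS and destination strides from the FP32 element size.
    kai::benchmark::MatMulRunner<DummyMatMulInterface> runner(iface, kai::benchmark::DataType::FP32);
    runner.set_mnk(m, n, k);
    runner.run(lhs.data(), rhs.data(), dst.data());

    // With all-ones operands, every destination element should equal k.
    return dst[0] == static_cast<float>(k) ? 0 : 1;
}

The specialized run() overloads in the listing (strided LHS, floating-point destination, static quantization, blockwise dynamic quantization) follow the same pattern but change the run_matmul argument list, so an interface type would need to match the argument order of the corresponding specialization.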