KleidiAI Coverage Report


Directory: ./
File: benchmark/matmul/matmul_runner.hpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 100.0% 58 0 58
Functions: 100.0% 24 0 24
Branches: -% 0 0 0

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP
8 #define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP
9
10 #include <cfloat>
11 #include <cstddef>
12 #include <cstdint>
13 #include <test/common/data_type.hpp>
14
15 #include "kai/kai_common.h"
16 #include "matmul_interface.hpp"
17
18 namespace kai::benchmark {
19
20 using DataType = test::DataType;
21
22 /// Runner for the matrix multiplication micro-kernel.
23 ///
24 /// Prepares and executes the run method of the micro-kernel.
25 ///
26 /// @tparam MatMulInterface Interface of the matrix multiplication micro-kernel.
27 template <typename MatMulInterface>
28 class MatMulRunner {
29 public:
30 /// Constructs a MatMulRunner object.
31 ///
32 /// @param matmul_interface Abstraction containing the micro-kernel to run.
33 /// @param dst_type Output type of the micro-kernel. Required for the micro-kernel to make certain assumptions
34 /// internally about the stride of the data.
35 166 MatMulRunner(const MatMulInterface& matmul_interface, const DataType dst_type) :
36 166 matmul_interface_(matmul_interface), dst_type_(dst_type) {
37 166 }
38
39 /// Sets the M, N and K dimensions to describe the operand and result matrices.
40 ///
41 /// @param m Rows in a non-transposed LHS and DST matrix.
42 /// @param n Columns in a non-transposed RHS and DST matrix.
43 /// @param k Columns in a non-transposed LHS matrix, and rows in a non-transposed RHS matrix.
44 83 void set_mnk(const size_t m, const size_t n, const size_t k) {
45 83 m_ = m;
46 83 n_ = n;
47 83 k_ = k;
48
49 83 lhs_stride_ = k_ * data_type_size_in_bits(dst_type_) / 8;
50 83 dst_stride_row_ = n_ * data_type_size_in_bits(dst_type_) / 8;
51 83 dst_stride_col_ = data_type_size_in_bits(dst_type_) / 8;
52 83 }
53
54 /// Sets the block size to use.
55 ///
56 /// @param bl Block size. Used for micro-kernels with dynamic blockwise quantization.
57 83 void set_bl(const size_t bl) {
58 83 bl_ = bl;
59 83 }
60
61 /// Runs the matrix multiplication micro-kernel.
62 ///
63 /// @param lhs Buffer containing LHS matrix data.
64 /// @param rhs Buffer containing RHS matrix data.
65 /// @param dst Destination buffer to write to.
66 void run(const void* lhs, const void* rhs, void* dst);
67
68 private:
69 MatMulInterface matmul_interface_ = {};
70
71 DataType dst_type_ = DataType::FP32;
72
73 83 size_t m_ = 1;
74 83 size_t n_ = 1;
75 83 size_t k_ = 1;
76 83 size_t bl_ = 32;
77
78 83 size_t lhs_stride_ = 1;
79 83 size_t dst_stride_row_ = 1;
80 83 size_t dst_stride_col_ = 1;
81 };
82
83 /// Runs the matrix multiplication micro-kernel.
84 ///
85 /// @param lhs Buffer containing LHS matrix data.
86 /// @param rhs Buffer containing RHS matrix data.
87 /// @param dst Destination buffer to write to.
88 template <typename MatMulInterface>
89 18 void MatMulRunner<MatMulInterface>::run(const void* lhs, const void* rhs, void* dst) {
90 36 matmul_interface_.run_matmul(
91 18 m_, n_, k_, //
92 18 lhs, rhs, dst, //
93 18 dst_stride_row_, dst_stride_col_, //
94 -FLT_MAX, FLT_MAX //
95 );
96 18 }
97
98 /// Runs the matrix multiplication micro-kernel. Specialized on the strided LHS interface.
99 ///
100 /// @param lhs Buffer containing LHS matrix data.
101 /// @param rhs Buffer containing RHS matrix data.
102 /// @param dst Destination buffer to write to.
103 template <>
104 11 inline void MatMulRunner<MatMulStridedLhsInterface>::run(const void* lhs, const void* rhs, void* dst) {
105 22 matmul_interface_.run_matmul(
106 11 m_, n_, k_, //
107 11 lhs, lhs_stride_, rhs, dst, //
108 11 dst_stride_row_, dst_stride_col_, //
109 -FLT_MAX, FLT_MAX //
110 );
111 11 }
112
113 /// Runs the matrix multiplication micro-kernel. Specialized on the interface with a floating point destination buffer.
114 ///
115 /// @param lhs Buffer containing LHS matrix data.
116 /// @param rhs Buffer containing RHS matrix data.
117 /// @param dst Destination buffer to write to.
118 template <>
119 19 inline void MatMulRunner<MatMulFloatInterface>::run(const void* lhs, const void* rhs, void* dst) {
120 38 matmul_interface_.run_matmul(
121 19 m_, n_, k_, //
122 19 lhs, rhs, static_cast<float*>(dst), //
123 19 dst_stride_row_, dst_stride_col_, //
124 -FLT_MAX, FLT_MAX //
125 );
126 19 }
127
128 /// Runs the matrix multiplication micro-kernel. Specialized on the static quantization interface.
129 ///
130 /// @param lhs Buffer containing LHS matrix data.
131 /// @param rhs Buffer containing RHS matrix data.
132 /// @param dst Destination buffer to write to.
133 template <>
134 3 inline void MatMulRunner<MatMulStaticQuantInterface>::run(const void* lhs, const void* rhs, void* dst) {
135 3 constexpr kai_matmul_requantize32_params params = {INT8_MIN, INT8_MAX, 0};
136 6 matmul_interface_.run_matmul(
137 3 m_, n_, k_, //
138 3 lhs, rhs, dst, //
139 3 dst_stride_row_, dst_stride_col_, //
140 &params //
141 );
142 3 }
143
144 /// Runs the matrix multiplication micro-kernel. Specialized on the dynamic blockwise quantization interface with
145 /// generic destination buffer.
146 ///
147 /// @param lhs Buffer containing LHS matrix data.
148 /// @param rhs Buffer containing RHS matrix data.
149 /// @param dst Destination buffer to write to.
150 template <>
151 8 inline void MatMulRunner<MatMulBlockwiseDynamicQuantGenericDstInterface>::run(
152 const void* lhs, const void* rhs, void* dst) {
153 16 matmul_interface_.run_matmul(
154 8 m_, n_, k_, bl_, //
155 8 lhs, rhs, dst, //
156 8 dst_stride_row_, dst_stride_col_, //
157 -FLT_MAX, FLT_MAX //
158 );
159 8 }
160
161 /// Runs the matrix multiplication micro-kernel. Specialized on the dynamic blockwise quantization interface.
162 ///
163 /// @param lhs Buffer containing LHS matrix data.
164 /// @param rhs Buffer containing RHS matrix data.
165 /// @param dst Destination buffer to write to.
166 template <>
167 24 inline void MatMulRunner<MatMulBlockwiseDynamicQuantInterface>::run(const void* lhs, const void* rhs, void* dst) {
168 48 matmul_interface_.run_matmul(
169 24 m_, n_, k_, bl_, //
170 24 lhs, rhs, static_cast<float*>(dst), //
171 24 dst_stride_row_, dst_stride_col_, //
172 -FLT_MAX, FLT_MAX //
173 );
174 24 }
175
176 } // namespace kai::benchmark
177
178 #endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP
179