Line |
Branch |
Exec |
Source |
1 |
|
|
// |
2 |
|
|
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> |
3 |
|
|
// |
4 |
|
|
// SPDX-License-Identifier: Apache-2.0 |
5 |
|
|
// |
6 |
|
|
|
7 |
|
|
#ifndef KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP |
8 |
|
|
#define KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP |
9 |
|
|
|
10 |
|
|
#include <cfloat> |
11 |
|
|
#include <cstddef> |
12 |
|
|
#include <cstdint> |
13 |
|
|
#include <test/common/data_type.hpp> |
14 |
|
|
|
15 |
|
|
#include "kai/kai_common.h" |
16 |
|
|
#include "matmul_interface.hpp" |
17 |
|
|
|
18 |
|
|
namespace kai::benchmark { |
19 |
|
|
|
20 |
|
|
using DataType = test::DataType; |
21 |
|
|
|
22 |
|
|
/// Runner for the matrix multiplication micro-kernel. |
23 |
|
|
/// |
24 |
|
|
/// Prepares and executes the run method of the micro-kernel. |
25 |
|
|
/// |
26 |
|
|
/// @tparam MatMulInterface Interface of the matrix multiplication micro-kernel. |
27 |
|
|
template <typename MatMulInterface> |
28 |
|
|
class MatMulRunner { |
29 |
|
|
public: |
30 |
|
|
/// Constructs a MatMulRunner object. |
31 |
|
|
/// |
32 |
|
|
/// @param matmul_interface Abstraction containing the micro-kernel to run. |
33 |
|
|
/// @param dst_type Output type of the micro-kernel. Required for the micro-kernel to make certain assumptions |
34 |
|
|
/// internally about the stride of the data. |
35 |
|
166 |
MatMulRunner(const MatMulInterface& matmul_interface, const DataType dst_type) : |
36 |
|
166 |
matmul_interface_(matmul_interface), dst_type_(dst_type) { |
37 |
|
166 |
} |
38 |
|
|
|
39 |
|
|
/// Sets the M, N and K dimensions to describe the operand and result matrices. |
40 |
|
|
/// |
41 |
|
|
/// @param m Rows in a non-transposed LHS and DST matrix. |
42 |
|
|
/// @param n Columns in a non-transposed RHS and DST matrix. |
43 |
|
|
/// @param k Columns in a non-transposed LHS matrix, and rows in a non-transposed RHS matrix. |
44 |
|
83 |
void set_mnk(const size_t m, const size_t n, const size_t k) { |
45 |
|
83 |
m_ = m; |
46 |
|
83 |
n_ = n; |
47 |
|
83 |
k_ = k; |
48 |
|
|
|
49 |
|
83 |
lhs_stride_ = k_ * data_type_size_in_bits(dst_type_) / 8; |
50 |
|
83 |
dst_stride_row_ = n_ * data_type_size_in_bits(dst_type_) / 8; |
51 |
|
83 |
dst_stride_col_ = data_type_size_in_bits(dst_type_) / 8; |
52 |
|
83 |
} |
53 |
|
|
|
54 |
|
|
/// Sets the block size to use. |
55 |
|
|
/// |
56 |
|
|
/// @param bl Block size. Used for micro-kernels with dynamic blockwise quantization. |
57 |
|
83 |
void set_bl(const size_t bl) { |
58 |
|
83 |
bl_ = bl; |
59 |
|
83 |
} |
60 |
|
|
|
61 |
|
|
/// Runs the matrix multiplication micro-kernel. |
62 |
|
|
/// |
63 |
|
|
/// @param lhs Buffer containing LHS matrix data. |
64 |
|
|
/// @param rhs Buffer containing RHS matrix data. |
65 |
|
|
/// @param dst Destination buffer to write to. |
66 |
|
|
void run(const void* lhs, const void* rhs, void* dst); |
67 |
|
|
|
68 |
|
|
private: |
69 |
|
|
MatMulInterface matmul_interface_ = {}; |
70 |
|
|
|
71 |
|
|
DataType dst_type_ = DataType::FP32; |
72 |
|
|
|
73 |
|
83 |
size_t m_ = 1; |
74 |
|
83 |
size_t n_ = 1; |
75 |
|
83 |
size_t k_ = 1; |
76 |
|
83 |
size_t bl_ = 32; |
77 |
|
|
|
78 |
|
83 |
size_t lhs_stride_ = 1; |
79 |
|
83 |
size_t dst_stride_row_ = 1; |
80 |
|
83 |
size_t dst_stride_col_ = 1; |
81 |
|
|
}; |
82 |
|
|
|
83 |
|
|
/// Runs the matrix multiplication micro-kernel. |
84 |
|
|
/// |
85 |
|
|
/// @param lhs Buffer containing LHS matrix data. |
86 |
|
|
/// @param rhs Buffer containing RHS matrix data. |
87 |
|
|
/// @param dst Destination buffer to write to. |
88 |
|
|
template <typename MatMulInterface> |
89 |
|
18 |
void MatMulRunner<MatMulInterface>::run(const void* lhs, const void* rhs, void* dst) { |
90 |
|
36 |
matmul_interface_.run_matmul( |
91 |
|
18 |
m_, n_, k_, // |
92 |
|
18 |
lhs, rhs, dst, // |
93 |
|
18 |
dst_stride_row_, dst_stride_col_, // |
94 |
|
|
-FLT_MAX, FLT_MAX // |
95 |
|
|
); |
96 |
|
18 |
} |
97 |
|
|
|
98 |
|
|
/// Runs the matrix multiplication micro-kernel. Specialized on the strided LHS interface. |
99 |
|
|
/// |
100 |
|
|
/// @param lhs Buffer containing LHS matrix data. |
101 |
|
|
/// @param rhs Buffer containing RHS matrix data. |
102 |
|
|
/// @param dst Destination buffer to write to. |
103 |
|
|
template <> |
104 |
|
11 |
inline void MatMulRunner<MatMulStridedLhsInterface>::run(const void* lhs, const void* rhs, void* dst) { |
105 |
|
22 |
matmul_interface_.run_matmul( |
106 |
|
11 |
m_, n_, k_, // |
107 |
|
11 |
lhs, lhs_stride_, rhs, dst, // |
108 |
|
11 |
dst_stride_row_, dst_stride_col_, // |
109 |
|
|
-FLT_MAX, FLT_MAX // |
110 |
|
|
); |
111 |
|
11 |
} |
112 |
|
|
|
113 |
|
|
/// Runs the matrix multiplication micro-kernel. Specialized on the interface with a floating point destination buffer. |
114 |
|
|
/// |
115 |
|
|
/// @param lhs Buffer containing LHS matrix data. |
116 |
|
|
/// @param rhs Buffer containing RHS matrix data. |
117 |
|
|
/// @param dst Destination buffer to write to. |
118 |
|
|
template <> |
119 |
|
19 |
inline void MatMulRunner<MatMulFloatInterface>::run(const void* lhs, const void* rhs, void* dst) { |
120 |
|
38 |
matmul_interface_.run_matmul( |
121 |
|
19 |
m_, n_, k_, // |
122 |
|
19 |
lhs, rhs, static_cast<float*>(dst), // |
123 |
|
19 |
dst_stride_row_, dst_stride_col_, // |
124 |
|
|
-FLT_MAX, FLT_MAX // |
125 |
|
|
); |
126 |
|
19 |
} |
127 |
|
|
|
128 |
|
|
/// Runs the matrix multiplication micro-kernel. Specialized on the static quantization interface. |
129 |
|
|
/// |
130 |
|
|
/// @param lhs Buffer containing LHS matrix data. |
131 |
|
|
/// @param rhs Buffer containing RHS matrix data. |
132 |
|
|
/// @param dst Destination buffer to write to. |
133 |
|
|
template <> |
134 |
|
3 |
inline void MatMulRunner<MatMulStaticQuantInterface>::run(const void* lhs, const void* rhs, void* dst) { |
135 |
|
3 |
constexpr kai_matmul_requantize32_params params = {INT8_MIN, INT8_MAX, 0}; |
136 |
|
6 |
matmul_interface_.run_matmul( |
137 |
|
3 |
m_, n_, k_, // |
138 |
|
3 |
lhs, rhs, dst, // |
139 |
|
3 |
dst_stride_row_, dst_stride_col_, // |
140 |
|
|
¶ms // |
141 |
|
|
); |
142 |
|
3 |
} |
143 |
|
|
|
144 |
|
|
/// Runs the matrix multiplication micro-kernel. Specialized on the dynamic blockwise quantization interface with |
145 |
|
|
/// generic destination buffer. |
146 |
|
|
/// |
147 |
|
|
/// @param lhs Buffer containing LHS matrix data. |
148 |
|
|
/// @param rhs Buffer containing RHS matrix data. |
149 |
|
|
/// @param dst Destination buffer to write to. |
150 |
|
|
template <> |
151 |
|
8 |
inline void MatMulRunner<MatMulBlockwiseDynamicQuantGenericDstInterface>::run( |
152 |
|
|
const void* lhs, const void* rhs, void* dst) { |
153 |
|
16 |
matmul_interface_.run_matmul( |
154 |
|
8 |
m_, n_, k_, bl_, // |
155 |
|
8 |
lhs, rhs, dst, // |
156 |
|
8 |
dst_stride_row_, dst_stride_col_, // |
157 |
|
|
-FLT_MAX, FLT_MAX // |
158 |
|
|
); |
159 |
|
8 |
} |
160 |
|
|
|
161 |
|
|
/// Runs the matrix multiplication micro-kernel. Specialized on the dynamic blockwise quantization interface. |
162 |
|
|
/// |
163 |
|
|
/// @param lhs Buffer containing LHS matrix data. |
164 |
|
|
/// @param rhs Buffer containing RHS matrix data. |
165 |
|
|
/// @param dst Destination buffer to write to. |
166 |
|
|
template <> |
167 |
|
24 |
inline void MatMulRunner<MatMulBlockwiseDynamicQuantInterface>::run(const void* lhs, const void* rhs, void* dst) { |
168 |
|
48 |
matmul_interface_.run_matmul( |
169 |
|
24 |
m_, n_, k_, bl_, // |
170 |
|
24 |
lhs, rhs, static_cast<float*>(dst), // |
171 |
|
24 |
dst_stride_row_, dst_stride_col_, // |
172 |
|
|
-FLT_MAX, FLT_MAX // |
173 |
|
|
); |
174 |
|
24 |
} |
175 |
|
|
|
176 |
|
|
} // namespace kai::benchmark |
177 |
|
|
|
178 |
|
|
#endif // KLEIDIAI_BENCHMARK_MATMUL_MATMUL_RUNNER_HPP |
179 |
|
|
|