Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #pragma once | ||
7 | |||
8 | #include <cstddef> | ||
9 | #include <cstdint> | ||
10 | #include <functional> | ||
11 | #include <iosfwd> | ||
12 | #include <string> | ||
13 | #include <string_view> | ||
14 | #include <tuple> | ||
15 | |||
16 | #include "kai/kai_common.h" | ||
17 | #include "test/common/data_format.hpp" | ||
18 | #include "test/common/float16.hpp" | ||
19 | #include "test/common/matrix_portion.hpp" | ||
20 | |||
21 | namespace kai::test { | ||
22 | /// Matrix multiplication shape. | ||
23 | struct MatMulShape { | ||
24 | size_t m; ///< LHS height. | ||
25 | size_t n; ///< RHS width. | ||
26 | size_t k; ///< LHS width and RHS height. | ||
27 | |||
28 | struct Hash { | ||
29 | 81997 | size_t operator()(const MatMulShape& shape) const { | |
30 | 81997 | return // | |
31 | 163994 | (std::hash<size_t>{}(shape.m) << 0) ^ // | |
32 | 163994 | (std::hash<size_t>{}(shape.n) << 1) ^ // | |
33 | 81997 | (std::hash<size_t>{}(shape.k) << 2); // | |
34 | } | ||
35 | }; | ||
36 | |||
37 | private: | ||
38 | 79726 | friend bool operator==(const MatMulShape& lhs, const MatMulShape& rhs) { | |
39 | 79726 | return // | |
40 |
2/2✓ Branch 0 taken 74236 times.
✓ Branch 1 taken 5490 times.
|
79726 | lhs.m == rhs.m && // |
41 |
2/2✓ Branch 0 taken 2224 times.
✓ Branch 1 taken 72012 times.
|
74236 | lhs.n == rhs.n && // |
42 | 72012 | lhs.k == rhs.k; | |
43 | } | ||
44 | friend std::ostream& operator<<(std::ostream& os, const MatMulShape& shape); | ||
45 | }; | ||
46 | |||
47 | /// Value range | ||
48 | template <typename T> | ||
49 | struct Range { | ||
50 | T min; | ||
51 | T max; | ||
52 | |||
53 | [[nodiscard]] T range() const { | ||
54 | return max - min; | ||
55 | } | ||
56 | }; | ||
57 | |||
58 | // NOLINTBEGIN(misc-non-private-member-variables-in-classes) | ||
59 | |||
60 | /// Matrix multiplication method. | ||
61 |
31/62✓ Branch 0 taken 11118 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11118 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11118 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 11118 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 11118 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 11118 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 11118 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 11118 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11118 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 11118 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 11118 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 11118 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 11118 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 11118 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 11118 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 11118 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 11118 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 11118 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 11118 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 11118 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 11118 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 11118 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 11118 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 11118 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 11118 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 11118 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 11118 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 11118 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 11118 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 11118 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 11118 times.
✗ Branch 61 not taken.
|
11118 | struct MatMulMethod { |
62 | std::string_view name{}; ///< Name of matmul method. | ||
63 | |||
64 | size_t m0{0}; ///< Block size in M dimension. | ||
65 | size_t n0{0}; ///< Block size in N dimension. | ||
66 | size_t k0{0}; ///< Block size in K dimension. | ||
67 | |||
68 | DataFormat dst_format{}; ///< Data format of the destination matrix. | ||
69 | DataFormat lhs_format{}; ///< Data format of the LHS matrix. | ||
70 | DataFormat packed_lhs_format{}; ///< Data format of the packed LHS matrix. | ||
71 | DataFormat rhs_format{}; ///< Data format of the RHS matrix. | ||
72 | DataFormat packed_rhs_format{}; ///< Data format of the packed RHS matrix. | ||
73 | DataFormat bias_format{}; ///< Data format of the bias vector. | ||
74 | |||
75 | /// Check if CPU supports required features. | ||
76 | /// | ||
77 | /// @return Supported (true) or not supported (false). | ||
78 | std::function<bool(void)> fn_is_supported{nullptr}; | ||
79 | |||
80 | /// Gets mr value. | ||
81 | /// | ||
82 | /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). | ||
83 | /// | ||
84 | /// @return The mr value. | ||
85 | std::function<size_t(void)> fn_get_mr{nullptr}; | ||
86 | |||
87 | /// Gets nr value. | ||
88 | /// | ||
89 | /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). | ||
90 | /// | ||
91 | /// @return The nr value. | ||
92 | std::function<size_t(void)> fn_get_nr{nullptr}; | ||
93 | |||
94 | /// Gets kr value. | ||
95 | /// | ||
96 | /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). | ||
97 | /// | ||
98 | /// @return The kr value. | ||
99 | std::function<size_t(void)> fn_get_kr{nullptr}; | ||
100 | |||
101 | /// Gets sr value. | ||
102 | /// | ||
103 | /// This is the packing parameter which must be used to pack the RHS matrix. | ||
104 | /// | ||
105 | /// @return The sr value. | ||
106 | std::function<size_t(void)> fn_get_sr{nullptr}; | ||
107 | |||
108 | /// Gets m step value for main kernel. | ||
109 | /// | ||
110 | /// The starting row index must be divisible by `m_step`. | ||
111 | /// | ||
112 | /// @return The m step value. | ||
113 | std::function<size_t(void)> fn_get_main_m_step{nullptr}; | ||
114 | |||
115 | /// Gets n step value for RHS packing micro-kernel. | ||
116 | /// | ||
117 | /// The starting row index must be divisible by `n_step`. | ||
118 | /// | ||
119 | /// @return The n step value. | ||
120 | std::function<size_t(void)> fn_get_pack_rhs_n_step{nullptr}; | ||
121 | |||
122 | /// Gets n step value for main kernel. | ||
123 | /// | ||
124 | /// The starting column index must be divisible by `n_step`. | ||
125 | /// | ||
126 | /// @return The n step value. | ||
127 | std::function<size_t(void)> fn_get_main_n_step{nullptr}; | ||
128 | |||
129 | /// Gets the offset in bytes of the LHS matrix. | ||
130 | /// | ||
131 | /// @param[in] m_idx Coordinate of the matrix in M dimension. | ||
132 | /// @param[in] stride Row stride in bytes. | ||
133 | /// | ||
134 | /// @return The offset in bytes. | ||
135 | std::function<size_t(size_t m_idx, size_t stride)> fn_get_lhs_offset{nullptr}; | ||
136 | |||
137 | /// Gets the size in bytes of the packed LHS matrix. | ||
138 | /// | ||
139 | /// @param[in] m Number of rows in the unpacked LHS matrix. | ||
140 | /// @param[in] k Number of columns in the unpacked LHS matrix. | ||
141 | /// @param[in] mr Number of rows to be interleaved. | ||
142 | /// @param[in] kr Unused. Must be 1. | ||
143 | /// @param[in] sr Unused. Must be 1. | ||
144 | /// | ||
145 | /// @return The size in bytes. | ||
146 | std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> fn_get_packed_lhs_size{nullptr}; | ||
147 | |||
148 | /// Gets the offset in bytes of the packed LHS matrix. | ||
149 | /// | ||
150 | /// @param[in] m_idx Coordinate of the matrix in M dimension. | ||
151 | /// @param[in] k Size of the matrix in K dimension. | ||
152 | /// | ||
153 | /// @return The offset in bytes. | ||
154 | std::function<size_t(size_t m_idx, size_t k)> fn_get_packed_lhs_offset{nullptr}; | ||
155 | |||
156 | /// Preprocesses the LHS matrix. | ||
157 | /// | ||
158 | /// @param[in] m Number of rows of the unpacked LHS matrix. | ||
159 | /// @param[in] k Common dimension between the LHS and RHS matrix. | ||
160 | /// @param[in] mr Block size in M dimension. It must be {{ kernel.interleave_by }}VL. | ||
161 | /// @param[in] kr Block size in K dimension. It must be {{ kernel.block_by }}. | ||
162 | /// @param[in] sr Number of kr splits. It must be 1. | ||
163 | /// @param[in] m_idx_start Unused. Must be 0. | ||
164 | /// @param[in] lhs LHS matrix data buffer. | ||
165 | /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. | ||
166 | /// @param[out] lhs_packed Packed RHS matrix. | ||
167 | std::function<void( | ||
168 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||
169 | void* lhs_packed)> | ||
170 | fn_pack_lhs{nullptr}; | ||
171 | |||
172 | /// Gets a value indicating whether LHS packing is needed. | ||
173 | 830 | [[nodiscard]] bool is_pack_lhs_needed() const { | |
174 | 830 | return fn_pack_lhs != nullptr; | |
175 | } | ||
176 | |||
177 | /// Gets the offset in bytes of the RHS matrix. | ||
178 | /// | ||
179 | /// @param[in] n_idx Coordinate of the matrix in N dimension. | ||
180 | /// | ||
181 | /// @return The offset in bytes. | ||
182 | std::function<size_t(size_t n_idx)> fn_get_rhs_offset{nullptr}; | ||
183 | |||
184 | /// Gets the size in bytes of the packed RHS matrix. | ||
185 | /// | ||
186 | /// @param[in] n Size of the matrix in N dimension. | ||
187 | /// @param[in] k Size of the matrix in K dimension. | ||
188 | /// | ||
189 | /// @return The size in bytes. | ||
190 | std::function<size_t(size_t n, size_t k)> fn_get_packed_rhs_size{nullptr}; | ||
191 | |||
192 | /// Gets the size in bytes of the packed RHS matrix. | ||
193 | /// | ||
194 | /// @param[in] n Size of the matrix in N dimension. | ||
195 | /// @param[in] k Size of the matrix in K dimension. | ||
196 | /// @param[in] nr Block size in N dimension. | ||
197 | /// @param[in] kr Block size in K dimension. | ||
198 | /// | ||
199 | /// @return The size in bytes. | ||
200 | std::function<size_t(size_t n, size_t k, size_t nr, size_t kr)> fn_get_packed_rhs_size_generic_block_size = nullptr; | ||
201 | |||
202 | /// Gets the offset in bytes of the packed RHS matrix in the RHS packing micro-kernel | ||
203 | /// | ||
204 | /// @param[in] n_idx Coordinate of the matrix in N dimension. | ||
205 | /// @param[in] k Size of the matrix in K dimension. | ||
206 | /// | ||
207 | /// @return The offset in bytes. | ||
208 | std::function<size_t(size_t n_idx, size_t k)> fn_get_pack_rhs_packed_rhs_offset{nullptr}; | ||
209 | |||
210 | /// Gets the offset in bytes of the packed RHS matrix in the main kernel. | ||
211 | /// | ||
212 | /// @param[in] n_idx Coordinate of the matrix in N dimension. | ||
213 | /// @param[in] k Size of the matrix in K dimension. | ||
214 | /// | ||
215 | /// @return The offset in bytes. | ||
216 | std::function<size_t(size_t n_idx, size_t k)> fn_get_main_packed_rhs_offset{nullptr}; | ||
217 | |||
218 | std::function<void( | ||
219 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, | ||
220 | const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)> | ||
221 | fn_pack_rhs{nullptr}; | ||
222 | |||
223 | /// Gets n step value. | ||
224 | /// | ||
225 | /// The starting row index must be divisible by `n_step`. | ||
226 | /// | ||
227 | /// @return The n step value. | ||
228 | std::function<size_t()> fn_pack_rhs_nxk_get_n_step{nullptr}; | ||
229 | |||
230 | /// Gets the offset in bytes to the data element in the RHS matrix buffer. | ||
231 | /// | ||
232 | /// @param[in] n_idx Column index. | ||
233 | /// @param[in] rhs_offset Row stride in bytes of the RHS matrix. | ||
234 | /// | ||
235 | /// @return The offset in bytes to the data element. | ||
236 | std::function<size_t(size_t n_idx, size_t rhs_stride)> fn_pack_rhs_nxk_get_rhs_offset{nullptr}; | ||
237 | |||
238 | /// Gets the offset in bytes to the data element in the bias buffer. | ||
239 | /// | ||
240 | /// @param[in] n_idx Column index. | ||
241 | /// | ||
242 | /// @return The offset in bytes to the data element. | ||
243 | std::function<size_t(size_t n_idx)> fn_pack_rhs_nxk_get_bias_offset{nullptr}; | ||
244 | |||
245 | /// Gets the offset in bytes to the data element in the packed RHS buffer. | ||
246 | /// | ||
247 | /// @param[in] n_idx Row index. | ||
248 | /// @param[in] k Number of columns. | ||
249 | /// | ||
250 | /// @return The offset in bytes to the data element. | ||
251 | std::function<size_t(size_t n_idx, size_t k)> fn_pack_rhs_nxk_get_packed_rhs_offset{nullptr}; | ||
252 | |||
253 | /// Gets the size in bytes of the packed RHS buffer. | ||
254 | /// | ||
255 | /// @param[in] n Number of rows. | ||
256 | /// @param[in] k Number of columns. | ||
257 | /// | ||
258 | /// @return The size in bytes of the packed RHS buffer. | ||
259 | std::function<size_t(size_t n, size_t k)> fn_pack_rhs_nxk_get_packed_rhs_size{nullptr}; | ||
260 | |||
261 | /// Runs the RHS packing micro-kernel for matrix multiplication. | ||
262 | /// | ||
263 | /// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset | ||
264 | /// calculated using the following functions: | ||
265 | /// | ||
266 | /// * RHS: @ref kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme. | ||
267 | /// * Bias: @ref kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme. | ||
268 | /// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme. | ||
269 | /// | ||
270 | /// @param[in] num_groups Number of groups. It must be 1. | ||
271 | /// @param[in] n Number of columns of the output matrix. | ||
272 | /// @param[in] k Common dimension between the LHS and RHS matrix. | ||
273 | /// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u32(). | ||
274 | /// @param[in] kr Block size in K dimension. It must be 1. | ||
275 | /// @param[in] sr Number of kr splits. It must be 1. | ||
276 | /// @param[in] rhs_stride Row stride in bytes of the RHS matrix. | ||
277 | /// @param[in] rhs RHS matrix data buffer. | ||
278 | /// @param[in] bias Bias matrix data buffer. | ||
279 | /// @param[in] scale Scale data buffer. It must be NULL. | ||
280 | /// @param[out] rhs_packed Packed RHS matrix. | ||
281 | /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. | ||
282 | /// @param[in] params Extra packing parameters. It must be NULL. | ||
283 | std::function<void( | ||
284 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, | ||
285 | const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)> | ||
286 | fn_pack_rhs_nxk{nullptr}; | ||
287 | |||
288 | /// Gets the offset in bytes to the data element in the bias buffer. | ||
289 | /// | ||
290 | /// @param[in] n_idx Column index. | ||
291 | /// | ||
292 | /// @return The offset in bytes to the data element. | ||
293 | std::function<size_t(size_t n_idx)> fn_get_bias_offset{nullptr}; | ||
294 | |||
295 | /// Gets the offset in bytes to the data element in the destination matrix buffer. | ||
296 | /// | ||
297 | /// @param[in] m_idx Row index. | ||
298 | /// @param[in] n_idx Column index. | ||
299 | /// @param[in] stride Row stride in bytes. | ||
300 | /// | ||
301 | /// @return The offset in bytes to the data element. | ||
302 | std::function<size_t(size_t m_idx, size_t n_idx, size_t stride)> fn_get_dst_offset{nullptr}; | ||
303 | |||
304 | /// Gets the size in bytes of the destination matrix buffer. | ||
305 | /// | ||
306 | /// @param[in] m Number of rows. | ||
307 | /// @param[in] n Number of columns. | ||
308 | /// | ||
309 | /// @return The size in bytes of the destination matrix buffer. | ||
310 | std::function<size_t(size_t m, size_t n)> fn_get_dst_size{nullptr}; | ||
311 | |||
312 | /// Performs F16 or F32 matrix multiplication with RHS packing | ||
313 | /// followed by clamp operation. | ||
314 | /// | ||
315 | /// @param[in] m Size of the matrix in M dimension. | ||
316 | /// @param[in] n Size of the matrix in N dimension. | ||
317 | /// @param[in] k Size of the matrix in K dimension. | ||
318 | /// @param[in] lhs LHS data buffer. | ||
319 | /// @param[in] packed_rhs Packed RHS data buffer. | ||
320 | /// @param[out] dst Output data buffer. | ||
321 | /// @param[in] lhs_stride LHS row stride. | ||
322 | /// @param[in] dst_stride_row Output row stride. | ||
323 | /// @param[in] dst_stride_col Output column stride. | ||
324 | /// @param[in] clamp_min Lower bound of the output data. | ||
325 | /// @param[in] clamp_max Upper bound of the output data. | ||
326 | std::function<void( | ||
327 | size_t m, size_t n, size_t k, // | ||
328 | const void* lhs, size_t lhs_stride, // | ||
329 | const void* packed_rhs, // | ||
330 | void* dst, size_t dst_stride_row, size_t dst_stride_col, // | ||
331 | float clamp_min, float clamp_max)> | ||
332 | fn_matmul_f16_f16_f16p = nullptr; | ||
333 | |||
334 | std::function<void( | ||
335 | size_t m, size_t n, size_t k, // | ||
336 | const void* lhs, size_t lhs_stride, // | ||
337 | const void* packed_rhs, // | ||
338 | void* dst, size_t dst_stride_row, size_t dst_stride_col, // | ||
339 | float clamp_min, float clamp_max)> | ||
340 | fn_matmul_f32_f32_f32p = nullptr; | ||
341 | |||
342 | /// Performs BF16 matrix multiplication with LHS and RHS packing | ||
343 | /// followed by clamp operation. | ||
344 | /// | ||
345 | /// @param[in] m Size of the matrix in M dimension. | ||
346 | /// @param[in] n Size of the matrix in N dimension. | ||
347 | /// @param[in] k Size of the matrix in K dimension. | ||
348 | /// @param[in] packed_lhs Packed LHS data buffer. | ||
349 | /// @param[in] packed_rhs Packed RHS data buffer. | ||
350 | /// @param[out] dst Output data buffer. | ||
351 | /// @param[in] dst_stride_row Output row stride. | ||
352 | /// @param[in] dst_stride_col Output column stride. | ||
353 | /// @param[in] clamp_min Lower bound of the output data. | ||
354 | /// @param[in] clamp_max Upper bound of the output data. | ||
355 | std::function<void( | ||
356 | size_t m, size_t n, size_t k, // | ||
357 | const void* packed_lhs, // | ||
358 | const void* packed_rhs, // | ||
359 | void* dst, size_t dst_stride_row, size_t dst_stride_col, // | ||
360 | float clamp_min, float clamp_max)> | ||
361 | fn_matmul_f32_bf16p_bf16p = nullptr; | ||
362 | |||
363 | std::function<void( | ||
364 | size_t m, size_t n, size_t k, // | ||
365 | const void* packed_lhs, // | ||
366 | const void* packed_rhs, // | ||
367 | void* dst, size_t dst_stride_row, size_t dst_stride_col, // | ||
368 | float clamp_min, float clamp_max)> | ||
369 | fn_matmul_f16_bf16p_bf16p = nullptr; | ||
370 | |||
371 | /// Performs F16 or F32 matrix multiplication with LHS & RHS packing | ||
372 | /// followed by clamp operation. | ||
373 | /// | ||
374 | /// @param[in] m Number of output rows to be computed. | ||
375 | /// @param[in] n Number of output columns to be computed. | ||
376 | /// @param[in] k Common dimension of the LHS and RHS operands. | ||
377 | /// @param[in] packed_lhs Packed LHS matrix buffer. | ||
378 | /// @param[in] packed_rhs Packed RHS matrix buffer. | ||
379 | /// @param[out] dst Output matrix buffer. | ||
380 | /// @param[in] dst_stride_row Row stride in bytes of the output matrix. | ||
381 | /// @param[in] dst_stride_col Column stride in bytes of the output matrix. | ||
382 | /// @param[in] clamp_min Minimum value to clamp the final result. | ||
383 | /// @param[in] clamp_max Maximum value to clamp the final result. | ||
384 | std::function<void( | ||
385 | size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, | ||
386 | size_t dst_stride_col, float clamp_min, float clamp_max)> | ||
387 | fn_matmul_f16_f16p_f16p = nullptr; | ||
388 | |||
389 | std::function<void( | ||
390 | size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, | ||
391 | size_t dst_stride_col, float clamp_min, float clamp_max)> | ||
392 | fn_matmul_f32_f32p_f32p = nullptr; | ||
393 | |||
394 | /// Gets a value indicating whether pre-processing the RHS matrix is needed. | ||
395 | 830 | [[nodiscard]] bool is_pack_rhs_needed() const { | |
396 | 830 | return fn_pack_rhs != nullptr; | |
397 | } | ||
398 | |||
399 | /// Gets a value indicating whether pre-processing the transposed RHS matrix is needed. | ||
400 | 440 | [[nodiscard]] bool is_pack_rhs_nxk_needed() const { | |
401 | 440 | return fn_pack_rhs_nxk != nullptr; | |
402 | } | ||
403 | |||
404 | /// Preprocesses the RHS matrix. | ||
405 | /// | ||
406 | /// @param[in] n Size of the matrix in N dimension. | ||
407 | /// @param[in] k Size of the matrix in K dimension. | ||
408 | /// @param[in] rhs RHS data buffer. | ||
409 | /// @param[in] rhs_row_stride RHS row stride. | ||
410 | /// @param[in] bias Bias data buffer. | ||
411 | /// @param[in] scale Quantization scales data buffer. | ||
412 | /// @param[out] packed_rhs Packed RHS data buffer. | ||
413 | 492 | void pack_rhs( | |
414 | size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, | ||
415 | void* packed_rhs) const { | ||
416 | KAI_UNUSED(scale); | ||
417 | |||
418 |
1/2✓ Branch 0 taken 492 times.
✗ Branch 1 not taken.
|
492 | if (fn_pack_rhs != nullptr) { |
419 | 984 | fn_pack_rhs( | |
420 | 492 | 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, | |
421 | nullptr); | ||
422 | 492 | } else { | |
423 | − | KAI_ERROR("RHS pre-processing is not supported!"); | |
424 | } | ||
425 | 492 | } | |
426 | |||
427 | /// Preprocesses the transposed RHS matrix. | ||
428 | /// | ||
429 | /// @param[in] n Size of the matrix in N dimension. | ||
430 | /// @param[in] k Size of the matrix in K dimension. | ||
431 | /// @param[in] rhs RHS data buffer. | ||
432 | /// @param[in] rhs_row_stride RHS row stride. | ||
433 | /// @param[in] bias Bias data buffer. | ||
434 | /// @param[in] scale Quantization scales data buffer. | ||
435 | /// @param[out] packed_rhs Packed RHS data buffer. | ||
436 | 68 | void pack_rhs_nxk( | |
437 | size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, | ||
438 | void* packed_rhs) const { | ||
439 | KAI_UNUSED(scale); | ||
440 | |||
441 |
1/2✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
|
68 | if (fn_pack_rhs_nxk != nullptr) { |
442 | 136 | fn_pack_rhs_nxk( | |
443 | 68 | 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, | |
444 | nullptr); | ||
445 | 68 | } else { | |
446 | − | KAI_ERROR("RHS pre-processing is not supported!"); | |
447 | } | ||
448 | 68 | } | |
449 | |||
450 | 876 | [[nodiscard]] bool has_main_kernel() const { | |
451 |
2/2✓ Branch 0 taken 718 times.
✓ Branch 1 taken 158 times.
|
1594 | return fn_matmul_f16_f16_f16p != nullptr || // |
452 |
2/2✓ Branch 0 taken 682 times.
✓ Branch 1 taken 36 times.
|
718 | fn_matmul_f16_f16p_f16p != nullptr || // |
453 |
2/2✓ Branch 0 taken 646 times.
✓ Branch 1 taken 36 times.
|
682 | fn_matmul_f32_f32p_f32p != nullptr || // |
454 |
2/2✓ Branch 0 taken 436 times.
✓ Branch 1 taken 210 times.
|
646 | fn_matmul_f32_f32_f32p != nullptr || // |
455 |
2/2✓ Branch 0 taken 346 times.
✓ Branch 1 taken 90 times.
|
436 | fn_matmul_f32_bf16p_bf16p != nullptr || // |
456 | 90 | fn_matmul_f16_bf16p_bf16p != nullptr; | |
457 | } | ||
458 | |||
459 | 796 | void main_kernel( | |
460 | size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride, | ||
461 | size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const { | ||
462 | KAI_UNUSED(bias); | ||
463 | KAI_UNUSED(rhs_stride); | ||
464 | |||
465 |
2/2✓ Branch 0 taken 140 times.
✓ Branch 1 taken 656 times.
|
796 | if (fn_matmul_f16_f16_f16p) { |
466 | 280 | fn_matmul_f16_f16_f16p( | |
467 | 140 | m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(uint16_t), clamp_min, clamp_max); | |
468 |
2/2✓ Branch 0 taken 186 times.
✓ Branch 1 taken 470 times.
|
796 | } else if (fn_matmul_f32_f32_f32p) { |
469 | 186 | fn_matmul_f32_f32_f32p(m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); | |
470 |
2/2✓ Branch 0 taken 32 times.
✓ Branch 1 taken 438 times.
|
656 | } else if (fn_matmul_f16_f16p_f16p) { |
471 | 32 | fn_matmul_f16_f16p_f16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(Float16), clamp_min, clamp_max); | |
472 |
2/2✓ Branch 0 taken 32 times.
✓ Branch 1 taken 406 times.
|
470 | } else if (fn_matmul_f32_f32p_f32p) { |
473 | 32 | fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); | |
474 |
2/2✓ Branch 0 taken 322 times.
✓ Branch 1 taken 84 times.
|
438 | } else if (fn_matmul_f32_bf16p_bf16p) { |
475 | 644 | fn_matmul_f32_bf16p_bf16p( | |
476 | 322 | m, n, k, reinterpret_cast<const uint16_t*>(lhs), rhs, reinterpret_cast<float*>(dst), dst_stride, | |
477 | 322 | sizeof(float), clamp_min, clamp_max); | |
478 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
406 | } else if (fn_matmul_f16_bf16p_bf16p) { |
479 | 84 | fn_matmul_f16_bf16p_bf16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(uint16_t), clamp_min, clamp_max); | |
480 | 84 | } else { | |
481 | − | KAI_ERROR("Main kernel is not available!"); | |
482 | } | ||
483 | 796 | } | |
484 | }; | ||
485 | |||
486 | // NOLINTEND(misc-non-private-member-variables-in-classes) | ||
487 | |||
488 | /// Matrix multiplication test information. | ||
489 | using MatMulTestParams = std::tuple<MatMulMethod, MatMulShape, MatrixPortion>; | ||
490 | using MatMulTestPortionedParams = std::tuple<size_t, MatMulShape, MatrixPortion>; | ||
491 | using MatMulTestPortionedParamsWithBias = std::tuple<size_t, MatMulShape, MatrixPortion, bool>; | ||
492 | using MatMulTestPortionedParamsWithBias_WithBL = std::tuple<size_t, MatMulShape, size_t, MatrixPortion, bool>; | ||
493 | |||
494 | /// Prints the test information. | ||
495 | void PrintTo(const MatMulTestParams& param, std::ostream* os); | ||
496 | void PrintTo(const MatMulShape& shape, std::ostream* os); | ||
497 | void PrintTo(const MatrixPortion& portion, std::ostream* os); | ||
498 | |||
499 | /// Generate test information. | ||
500 | std::string test_description( | ||
501 | const std::string_view& name, const MatMulShape& shape, const MatrixPortion& portion, bool bias); | ||
502 | |||
503 | } // namespace kai::test | ||
504 | |||
505 | template <> | ||
506 | struct std::hash<kai::test::MatMulShape> { | ||
507 | 1243 | size_t operator()(const kai::test::MatMulShape& ms) const { | |
508 | 1243 | return kai::test::MatMulShape::Hash{}(ms); | |
509 | } | ||
510 | }; | ||
511 |