KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 56 / 3 / 59
Functions: 100.0% 13 / 0 / 13
Branches: 62.2% 61 / 0 / 98

test/common/matmul_test_common.hpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #pragma once
7
8 #include <cstddef>
9 #include <cstdint>
10 #include <functional>
11 #include <iosfwd>
12 #include <string>
13 #include <string_view>
14 #include <tuple>
15
16 #include "kai/kai_common.h"
17 #include "test/common/buffer.hpp"
18 #include "test/common/data_format.hpp"
19 #include "test/common/float16.hpp"
20 #include "test/common/matrix_portion.hpp"
21 #include "test/common/seed.hpp"
22
23 namespace kai::test {
24 /// Matrix multiplication shape.
25 struct MatMulShape {
26 size_t m; ///< LHS height.
27 size_t n; ///< RHS width.
28 size_t k; ///< LHS width and RHS height.
29
30 struct Hash {
31 59329 size_t operator()(const MatMulShape& shape) const {
32 59329 return //
33 118658 (std::hash<size_t>{}(shape.m) << 0) ^ //
34 118658 (std::hash<size_t>{}(shape.n) << 1) ^ //
35 59329 (std::hash<size_t>{}(shape.k) << 2); //
36 }
37 };
38
39 private:
40 50910 friend bool operator==(const MatMulShape& lhs, const MatMulShape& rhs) {
41 50910 return //
42
2/2
✓ Branch 0 taken 44994 times.
✓ Branch 1 taken 5916 times.
50910 lhs.m == rhs.m && //
43
2/2
✓ Branch 0 taken 1308 times.
✓ Branch 1 taken 43686 times.
44994 lhs.n == rhs.n && //
44 43686 lhs.k == rhs.k;
45 }
46 friend std::ostream& operator<<(std::ostream& os, const MatMulShape& shape);
47 };
48
49 /// Value range
50 template <typename T>
51 struct Range {
52 T min;
53 T max;
54
55 [[nodiscard]] T range() const {
56 return max - min;
57 }
58 };
59
60 // NOLINTBEGIN(misc-non-private-member-variables-in-classes)
61
62 /// Matrix multiplication method.
63
34/68
✓ Branch 0 taken 115458 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 115458 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 115458 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 115458 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 115458 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 115458 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 115458 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 115458 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 115458 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 115458 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 115458 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 115458 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 115458 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 115458 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 115458 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 115458 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 115458 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 115458 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 115458 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 115458 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 115458 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 115458 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 115458 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 115458 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 115458 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 115458 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 115458 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 115458 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 115458 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 115458 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 115458 times.
✗ Branch 61 not taken.
✓ Branch 62 taken 115458 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 115458 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 115458 times.
✗ Branch 67 not taken.
115458 struct MatMulMethod {
64 std::string_view name{}; ///< Name of matmul method.
65
66 size_t m0{0}; ///< Block size in M dimension.
67 size_t n0{0}; ///< Block size in N dimension.
68 size_t k0{0}; ///< Block size in K dimension.
69
70 DataFormat dst_format{}; ///< Data format of the destination matrix.
71 DataFormat lhs_format{}; ///< Data format of the LHS matrix.
72 DataFormat packed_lhs_format{}; ///< Data format of the packed LHS matrix.
73 DataFormat rhs_format{}; ///< Data format of the RHS matrix.
74 DataFormat packed_rhs_format{}; ///< Data format of the packed RHS matrix.
75 DataFormat bias_format{}; ///< Data format of the bias vector.
76 bool nb_support{}; ///< Does the kernel support null_bias.
77
78 /// Generate LHS matrix.
79 ///
80 /// @param[in] m Number of rows in the LHS matrix.
81 /// @param[in] k Number of columns in the LHS matrix.
82 /// @param[in] seed_feed Seed feed for random number generation.
83 ///
84 /// @return LHS matrix data buffer.
85 std::function<Buffer(size_t, size_t, SeedFeed&)> fn_generate_lhs{nullptr};
86
87 /// Generate RHS matrix.
88 ///
89 /// @param[in] k Number of rows in the RHS matrix.
90 /// @param[in] n Number of columns in the RHS matrix.
91 /// @param[in] seed_feed Seed feed for random number generation.
92 ///
93 /// @return RHS matrix data buffer.
94 std::function<Buffer(size_t, size_t, SeedFeed&)> fn_generate_rhs{nullptr};
95
96 /// Generate bias.
97 ///
98 /// @param[in] n Number of rows in the bias.
99 /// @param[in] k Number of columns in the bias.
100 /// @param[in] seed_feed Seed feed for random number generation.
101 /// @param[in] null_bias_mode Whether to generate null bias (true) or real bias (false).
102 ///
103 /// @return Bias data buffer.
104 std::function<Buffer(size_t, size_t, SeedFeed&, bool)> fn_generate_bias{nullptr};
105
106 /// Check if CPU supports required features.
107 ///
108 /// @return Supported (true) or not supported (false).
109 std::function<bool(void)> fn_is_supported{nullptr};
110
111 /// Gets mr value.
112 ///
113 /// This is the packing parameter which must be used to pack the LHS matrix (if necessary).
114 ///
115 /// @return The mr value.
116 std::function<size_t(void)> fn_get_mr{nullptr};
117
118 /// Gets nr value.
119 ///
120 /// This is the packing parameter which must be used to pack the RHS matrix (if necessary).
121 ///
122 /// @return The nr value.
123 std::function<size_t(void)> fn_get_nr{nullptr};
124
125 /// Gets kr value.
126 ///
127 /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary).
128 ///
129 /// @return The kr value.
130 std::function<size_t(void)> fn_get_kr{nullptr};
131
132 /// Gets sr value.
133 ///
134 /// This is the packing parameter which must be used to pack the RHS matrix.
135 ///
136 /// @return The sr value.
137 std::function<size_t(void)> fn_get_sr{nullptr};
138
139 /// Gets m step value for main kernel.
140 ///
141 /// The starting row index must be divisible by `m_step`.
142 ///
143 /// @return The m step value.
144 std::function<size_t(void)> fn_get_main_m_step{nullptr};
145
146 /// Gets n step value for RHS packing micro-kernel.
147 ///
148 /// The starting row index must be divisible by `n_step`.
149 ///
150 /// @return The n step value.
151 std::function<size_t(void)> fn_get_pack_rhs_n_step{nullptr};
152
153 /// Gets n step value for main kernel.
154 ///
155 /// The starting column index must be divisible by `n_step`.
156 ///
157 /// @return The n step value.
158 std::function<size_t(void)> fn_get_main_n_step{nullptr};
159
160 /// Gets the offset in bytes of the LHS matrix.
161 ///
162 /// @param[in] m_idx Coordinate of the matrix in M dimension.
163 /// @param[in] stride Row stride in bytes.
164 ///
165 /// @return The offset in bytes.
166 std::function<size_t(size_t m_idx, size_t stride)> fn_get_lhs_offset{nullptr};
167
168 /// Gets the size in bytes of the packed LHS matrix.
169 ///
170 /// @param[in] m Number of rows in the unpacked LHS matrix.
171 /// @param[in] k Number of columns in the unpacked LHS matrix.
172 /// @param[in] mr Number of rows to be interleaved.
173 /// @param[in] kr Unused. Must be 1.
174 /// @param[in] sr Unused. Must be 1.
175 ///
176 /// @return The size in bytes.
177 std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> fn_get_packed_lhs_size{nullptr};
178
179 /// Gets the offset in bytes of the packed LHS matrix.
180 ///
181 /// @param[in] m_idx Coordinate of the matrix in M dimension.
182 /// @param[in] k Size of the matrix in K dimension.
183 ///
184 /// @return The offset in bytes.
185 std::function<size_t(size_t m_idx, size_t k)> fn_get_packed_lhs_offset{nullptr};
186
187 /// Preprocesses the LHS matrix.
188 ///
189 /// @param[in] m Number of rows of the unpacked LHS matrix.
190 /// @param[in] k Common dimension between the LHS and RHS matrix.
191 /// @param[in] mr Block size in M dimension. It must be the mr value of the kernel (see fn_get_mr).
192 /// @param[in] kr Block size in K dimension. It must be the kr value of the kernel (see fn_get_kr).
193 /// @param[in] sr Number of kr splits. It must be 1.
194 /// @param[in] m_idx_start Unused. Must be 0.
195 /// @param[in] lhs LHS matrix data buffer.
196 /// @param[in] lhs_stride Row stride in bytes of the LHS matrix.
197 /// @param[out] lhs_packed Packed LHS matrix.
198 std::function<void(
199 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
200 void* lhs_packed)>
201 fn_pack_lhs{nullptr};
202
203 /// Gets a value indicating whether LHS packing is needed.
204 4368 [[nodiscard]] bool is_pack_lhs_needed() const {
205 4368 return fn_pack_lhs != nullptr;
206 }
207
208 /// Gets the offset in bytes of the RHS matrix.
209 ///
210 /// @param[in] n_idx Coordinate of the matrix in N dimension.
211 ///
212 /// @return The offset in bytes.
213 std::function<size_t(size_t n_idx)> fn_get_rhs_offset{nullptr};
214
215 /// Gets the size in bytes of the packed RHS matrix.
216 ///
217 /// @param[in] n Size of the matrix in N dimension.
218 /// @param[in] k Size of the matrix in K dimension.
219 ///
220 /// @return The size in bytes.
221 std::function<size_t(size_t n, size_t k)> fn_get_packed_rhs_size{nullptr};
222
223 /// Gets the size in bytes of the packed RHS matrix for the given block sizes.
224 ///
225 /// @param[in] n Size of the matrix in N dimension.
226 /// @param[in] k Size of the matrix in K dimension.
227 /// @param[in] nr Block size in N dimension.
228 /// @param[in] kr Block size in K dimension.
229 ///
230 /// @return The size in bytes.
231 std::function<size_t(size_t n, size_t k, size_t nr, size_t kr)> fn_get_packed_rhs_size_generic_block_size = nullptr;
232
233 /// Gets the offset in bytes of the packed RHS matrix in the RHS packing micro-kernel.
234 ///
235 /// @param[in] n_idx Coordinate of the matrix in N dimension.
236 /// @param[in] k Size of the matrix in K dimension.
237 ///
238 /// @return The offset in bytes.
239 std::function<size_t(size_t n_idx, size_t k)> fn_get_pack_rhs_packed_rhs_offset{nullptr};
240
241 /// Gets the offset in bytes of the packed RHS matrix in the main kernel.
242 ///
243 /// @param[in] n_idx Coordinate of the matrix in N dimension.
244 /// @param[in] k Size of the matrix in K dimension.
245 ///
246 /// @return The offset in bytes.
247 std::function<size_t(size_t n_idx, size_t k)> fn_get_main_packed_rhs_offset{nullptr};
248
249 std::function<void(
250 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
251 const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
252 fn_pack_rhs{nullptr};
253
254 /// Gets n step value.
255 ///
256 /// The starting row index must be divisible by `n_step`.
257 ///
258 /// @return The n step value.
259 std::function<size_t()> fn_pack_rhs_nxk_get_n_step{nullptr};
260
261 /// Gets the offset in bytes to the data element in the RHS matrix buffer.
262 ///
263 /// @param[in] n_idx Column index.
264 /// @param[in] rhs_stride Row stride in bytes of the RHS matrix.
265 ///
266 /// @return The offset in bytes to the data element.
267 std::function<size_t(size_t n_idx, size_t rhs_stride)> fn_pack_rhs_nxk_get_rhs_offset{nullptr};
268
269 /// Gets the offset in bytes to the data element in the bias buffer.
270 ///
271 /// @param[in] n_idx Column index.
272 ///
273 /// @return The offset in bytes to the data element.
274 std::function<size_t(size_t n_idx)> fn_pack_rhs_nxk_get_bias_offset{nullptr};
275
276 /// Gets the offset in bytes to the data element in the packed RHS buffer.
277 ///
278 /// @param[in] n_idx Row index.
279 /// @param[in] k Number of columns.
280 ///
281 /// @return The offset in bytes to the data element.
282 std::function<size_t(size_t n_idx, size_t k)> fn_pack_rhs_nxk_get_packed_rhs_offset{nullptr};
283
284 /// Gets the size in bytes of the packed RHS buffer.
285 ///
286 /// @param[in] n Number of rows.
287 /// @param[in] k Number of columns.
288 ///
289 /// @return The size in bytes of the packed RHS buffer.
290 std::function<size_t(size_t n, size_t k)> fn_pack_rhs_nxk_get_packed_rhs_size{nullptr};
291
292 /// Runs the RHS packing micro-kernel for matrix multiplication.
293 ///
294 /// The pointer to each buffer (RHS, bias and packed RHS) must be advanced by the offset
295 /// calculated using the following functions:
296 ///
297 /// * RHS: @ref kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme.
298 /// * Bias: @ref kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme.
299 /// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme.
300 ///
301 /// @param[in] num_groups Number of groups. It must be 1.
302 /// @param[in] n Number of columns of the output matrix.
303 /// @param[in] k Common dimension between the LHS and RHS matrix.
304 /// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u32().
305 /// @param[in] kr Block size in K dimension. It must be 1.
306 /// @param[in] sr Number of kr splits. It must be 1.
307 /// @param[in] rhs_stride Row stride in bytes of the RHS matrix.
308 /// @param[in] rhs RHS matrix data buffer.
309 /// @param[in] bias Bias matrix data buffer.
310 /// @param[in] scale Scale data buffer. It must be NULL.
311 /// @param[out] rhs_packed Packed RHS matrix.
312 /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0.
313 /// @param[in] params Extra packing parameters. It must be NULL.
314 std::function<void(
315 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
316 const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
317 fn_pack_rhs_nxk{nullptr};
318
319 /// Gets the offset in bytes to the data element in the bias buffer.
320 ///
321 /// @param[in] n_idx Column index.
322 ///
323 /// @return The offset in bytes to the data element.
324 std::function<size_t(size_t n_idx)> fn_get_bias_offset{nullptr};
325
326 /// Gets the offset in bytes to the data element in the destination matrix buffer.
327 ///
328 /// @param[in] m_idx Row index.
329 /// @param[in] n_idx Column index.
330 /// @param[in] stride Row stride in bytes.
331 ///
332 /// @return The offset in bytes to the data element.
333 std::function<size_t(size_t m_idx, size_t n_idx, size_t stride)> fn_get_dst_offset{nullptr};
334
335 /// Gets the size in bytes of the destination matrix buffer.
336 ///
337 /// @param[in] m Number of rows.
338 /// @param[in] n Number of columns.
339 ///
340 /// @return The size in bytes of the destination matrix buffer.
341 std::function<size_t(size_t m, size_t n)> fn_get_dst_size{nullptr};
342
343 /// Performs F16 or F32 matrix multiplication with RHS packing
344 /// followed by clamp operation.
345 ///
346 /// @param[in] m Size of the matrix in M dimension.
347 /// @param[in] n Size of the matrix in N dimension.
348 /// @param[in] k Size of the matrix in K dimension.
349 /// @param[in] lhs LHS data buffer.
350 /// @param[in] packed_rhs Packed RHS data buffer.
351 /// @param[out] dst Output data buffer.
352 /// @param[in] lhs_stride LHS row stride.
353 /// @param[in] dst_stride_row Output row stride.
354 /// @param[in] dst_stride_col Output column stride.
355 /// @param[in] clamp_min Lower bound of the output data.
356 /// @param[in] clamp_max Upper bound of the output data.
357 std::function<void(
358 size_t m, size_t n, size_t k, //
359 const void* lhs, size_t lhs_stride, //
360 const void* packed_rhs, //
361 void* dst, size_t dst_stride_row, size_t dst_stride_col, //
362 float clamp_min, float clamp_max)>
363 fn_matmul_f16_f16_f16p = nullptr;
364
365 std::function<void(
366 size_t m, size_t n, size_t k, //
367 const void* lhs, size_t lhs_stride, //
368 const void* packed_rhs, //
369 void* dst, size_t dst_stride_row, size_t dst_stride_col, //
370 float clamp_min, float clamp_max)>
371 fn_matmul_f32_f32_f32p = nullptr;
372
373 /// Performs BF16 matrix multiplication with LHS and RHS packing
374 /// followed by clamp operation.
375 ///
376 /// @param[in] m Size of the matrix in M dimension.
377 /// @param[in] n Size of the matrix in N dimension.
378 /// @param[in] k Size of the matrix in K dimension.
379 /// @param[in] packed_lhs Packed LHS data buffer.
380 /// @param[in] packed_rhs Packed RHS data buffer.
381 /// @param[out] dst Output data buffer.
382 /// @param[in] dst_stride_row Output row stride.
383 /// @param[in] dst_stride_col Output column stride.
384 /// @param[in] clamp_min Lower bound of the output data.
385 /// @param[in] clamp_max Upper bound of the output data.
386 std::function<void(
387 size_t m, size_t n, size_t k, //
388 const void* packed_lhs, //
389 const void* packed_rhs, //
390 void* dst, size_t dst_stride_row, size_t dst_stride_col, //
391 float clamp_min, float clamp_max)>
392 fn_matmul_f32_bf16p_bf16p = nullptr;
393
394 std::function<void(
395 size_t m, size_t n, size_t k, //
396 const void* packed_lhs, //
397 const void* packed_rhs, //
398 void* dst, size_t dst_stride_row, size_t dst_stride_col, //
399 float clamp_min, float clamp_max)>
400 fn_matmul_f16_bf16p_bf16p = nullptr;
401
402 /// Performs F16 or F32 matrix multiplication with LHS & RHS packing
403 /// followed by clamp operation.
404 ///
405 /// @param[in] m Number of output rows to be computed.
406 /// @param[in] n Number of output columns to be computed.
407 /// @param[in] k Common dimension of the LHS and RHS operands.
408 /// @param[in] packed_lhs Packed LHS matrix buffer.
409 /// @param[in] packed_rhs Packed RHS matrix buffer.
410 /// @param[out] dst Output matrix buffer.
411 /// @param[in] dst_stride_row Row stride in bytes of the output matrix.
412 /// @param[in] dst_stride_col Column stride in bytes of the output matrix.
413 /// @param[in] clamp_min Minimum value to clamp the final result.
414 /// @param[in] clamp_max Maximum value to clamp the final result.
415 std::function<void(
416 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
417 size_t dst_stride_col, float clamp_min, float clamp_max)>
418 fn_matmul_f16_f16p_f16p = nullptr;
419
420 std::function<void(
421 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
422 size_t dst_stride_col, float clamp_min, float clamp_max)>
423 fn_matmul_f32_f32p_f32p = nullptr;
424
425 /// Gets a value indicating whether pre-processing the RHS matrix is needed.
426 4368 [[nodiscard]] bool is_pack_rhs_needed() const {
427 4368 return fn_pack_rhs != nullptr;
428 }
429
430 /// Gets a value indicating whether pre-processing the transposed RHS matrix is needed.
431 2184 [[nodiscard]] bool is_pack_rhs_nxk_needed() const {
432 2184 return fn_pack_rhs_nxk != nullptr;
433 }
434
435 /// Preprocesses the RHS matrix.
436 ///
437 /// @param[in] n Size of the matrix in N dimension.
438 /// @param[in] k Size of the matrix in K dimension.
439 /// @param[in] rhs RHS data buffer.
440 /// @param[in] rhs_row_stride RHS row stride.
441 /// @param[in] bias Bias data buffer.
442 /// @param[in] scale Quantization scales data buffer. Currently unused: nullptr is forwarded to fn_pack_rhs instead.
443 /// @param[out] packed_rhs Packed RHS data buffer.
444 2706 void pack_rhs(
445 size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale,
446 void* packed_rhs) const {
447 KAI_UNUSED(scale);
448
449
1/2
✓ Branch 0 taken 2706 times.
✗ Branch 1 not taken.
2706 if (fn_pack_rhs != nullptr) {
450 5412 fn_pack_rhs(
451 2706 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0,
452 nullptr);
453 2706 } else {
454 KAI_ERROR("RHS pre-processing is not supported!");
455 }
456 2706 }
457
458 /// Preprocesses the transposed RHS matrix.
459 ///
460 /// @param[in] n Size of the matrix in N dimension.
461 /// @param[in] k Size of the matrix in K dimension.
462 /// @param[in] rhs RHS data buffer.
463 /// @param[in] rhs_row_stride RHS row stride.
464 /// @param[in] bias Bias data buffer.
465 /// @param[in] scale Quantization scales data buffer. Currently unused: nullptr is forwarded to fn_pack_rhs_nxk instead.
466 /// @param[out] packed_rhs Packed RHS data buffer.
467 216 void pack_rhs_nxk(
468 size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale,
469 void* packed_rhs) const {
470 KAI_UNUSED(scale);
471
472
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 if (fn_pack_rhs_nxk != nullptr) {
473 432 fn_pack_rhs_nxk(
474 216 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0,
475 nullptr);
476 216 } else {
477 KAI_ERROR("RHS pre-processing is not supported!");
478 }
479 216 }
480
481 4650 [[nodiscard]] bool has_main_kernel() const {
482
2/2
✓ Branch 0 taken 3798 times.
✓ Branch 1 taken 852 times.
8448 return fn_matmul_f16_f16_f16p != nullptr || //
483
2/2
✓ Branch 0 taken 3690 times.
✓ Branch 1 taken 108 times.
3798 fn_matmul_f16_f16p_f16p != nullptr || //
484
2/2
✓ Branch 0 taken 3582 times.
✓ Branch 1 taken 108 times.
3690 fn_matmul_f32_f32p_f32p != nullptr || //
485
2/2
✓ Branch 0 taken 2466 times.
✓ Branch 1 taken 1116 times.
3582 fn_matmul_f32_f32_f32p != nullptr || //
486
2/2
✓ Branch 0 taken 1926 times.
✓ Branch 1 taken 540 times.
2466 fn_matmul_f32_bf16p_bf16p != nullptr || //
487 540 fn_matmul_f16_bf16p_bf16p != nullptr;
488 }
489
490 4650 void main_kernel(
491 size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride,
492 size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const {
493 KAI_UNUSED(bias);
494 KAI_UNUSED(rhs_stride);
495
496
2/2
✓ Branch 0 taken 852 times.
✓ Branch 1 taken 3798 times.
4650 if (fn_matmul_f16_f16_f16p) {
497 1704 fn_matmul_f16_f16_f16p(
498 852 m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(uint16_t), clamp_min, clamp_max);
499
2/2
✓ Branch 0 taken 1116 times.
✓ Branch 1 taken 2682 times.
4650 } else if (fn_matmul_f32_f32_f32p) {
500 1116 fn_matmul_f32_f32_f32p(m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max);
501
2/2
✓ Branch 0 taken 108 times.
✓ Branch 1 taken 2574 times.
3798 } else if (fn_matmul_f16_f16p_f16p) {
502 108 fn_matmul_f16_f16p_f16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(Float16), clamp_min, clamp_max);
503
2/2
✓ Branch 0 taken 108 times.
✓ Branch 1 taken 2466 times.
2682 } else if (fn_matmul_f32_f32p_f32p) {
504 108 fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max);
505
2/2
✓ Branch 0 taken 1926 times.
✓ Branch 1 taken 540 times.
2574 } else if (fn_matmul_f32_bf16p_bf16p) {
506 3852 fn_matmul_f32_bf16p_bf16p(
507 1926 m, n, k, reinterpret_cast<const uint16_t*>(lhs), rhs, reinterpret_cast<float*>(dst), dst_stride,
508 1926 sizeof(float), clamp_min, clamp_max);
509
1/2
✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
2466 } else if (fn_matmul_f16_bf16p_bf16p) {
510 540 fn_matmul_f16_bf16p_bf16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(uint16_t), clamp_min, clamp_max);
511 540 } else {
512 KAI_ERROR("Main kernel is not available!");
513 }
514 4650 }
515 };
516
517 // NOLINTEND(misc-non-private-member-variables-in-classes)
518
519 /// Describes bias handling
520 enum class BiasMode {
521 INTERNAL, // Zero bias internally generated in kernel
522 PROVIDED, // Bias provided by kernel caller
523 };
524
525 /// Matrix multiplication test information.
526 using MatMulClampTestParams = std::tuple<MatMulMethod, MatMulShape, MatrixPortion, BiasMode, float>;
527 using MatMulClampTestPortionedParams = std::tuple<size_t, MatMulShape, MatrixPortion, float>;
528 using MatMulClampTestPortionedParamsWithBias = std::tuple<size_t, MatMulShape, MatrixPortion, float, bool>;
529 using MatMulClampTestPortionedParamsWithBias_WithBL =
530 std::tuple<size_t, MatMulShape, size_t, MatrixPortion, float, bool>;
531
532 /// Prints the test information.
533 void PrintTo(const MatMulClampTestParams& param, std::ostream* os);
534 void PrintTo(const MatMulShape& shape, std::ostream* os);
535 void PrintTo(const MatrixPortion& portion, std::ostream* os);
536 void PrintTo(const BiasMode& bias_mode, std::ostream* os);
537
538 /// Generate test information
539 std::string test_description(
540 const std::string_view& name, const MatMulShape& shape, const MatrixPortion& portion, bool bias,
541 float clamp_keep_ratio);
542
543 } // namespace kai::test
544
545 template <>
546 struct std::hash<kai::test::MatMulShape> {
547 27295 size_t operator()(const kai::test::MatMulShape& ms) const {
548 27295 return kai::test::MatMulShape::Hash{}(ms);
549 }
550 };
551