KleidiAI Coverage Report


Directory: ./
File: test/common/matmul_test_common.hpp
Date: 2025-10-20 13:18:31
             Coverage    Exec    Excl   Total
Lines:         100.0%      56       3      59
Functions:     100.0%      12       0      12
Branches:       63.0%      58       0      92

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #pragma once
7
8 #include <cstddef>
9 #include <cstdint>
10 #include <functional>
11 #include <iosfwd>
12 #include <string>
13 #include <string_view>
14 #include <tuple>
15
16 #include "kai/kai_common.h"
17 #include "test/common/data_format.hpp"
18 #include "test/common/float16.hpp"
19 #include "test/common/matrix_portion.hpp"
20
21 namespace kai::test {
22 /// Matrix multiplication shape.
23 struct MatMulShape {
24 size_t m; ///< LHS height.
25 size_t n; ///< RHS width.
26 size_t k; ///< LHS width and RHS height.
27
28 struct Hash {
29 81997 size_t operator()(const MatMulShape& shape) const {
30 81997 return //
31 163994 (std::hash<size_t>{}(shape.m) << 0) ^ //
32 163994 (std::hash<size_t>{}(shape.n) << 1) ^ //
33 81997 (std::hash<size_t>{}(shape.k) << 2); //
34 }
35 };
36
37 private:
38 79726 friend bool operator==(const MatMulShape& lhs, const MatMulShape& rhs) {
39 79726 return //
40
2/2
✓ Branch 0 taken 74236 times.
✓ Branch 1 taken 5490 times.
79726 lhs.m == rhs.m && //
41
2/2
✓ Branch 0 taken 2224 times.
✓ Branch 1 taken 72012 times.
74236 lhs.n == rhs.n && //
42 72012 lhs.k == rhs.k;
43 }
44 friend std::ostream& operator<<(std::ostream& os, const MatMulShape& shape);
45 };
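MatMulShape is hashable through the nested Hash functor, which XORs the per-dimension size_t hashes after shifting them by 0, 1 and 2 bits, and comparable through the hidden-friend operator==. A minimal sketch of keying an unordered container with it (the shape values and the function name are illustrative only):

    #include <unordered_set>

    #include "test/common/matmul_test_common.hpp"

    void deduplicate_example() {
        // Deduplicate the shapes exercised by a test run.
        std::unordered_set<kai::test::MatMulShape, kai::test::MatMulShape::Hash> shapes;
        shapes.insert(kai::test::MatMulShape{32, 64, 128});
        shapes.insert(kai::test::MatMulShape{32, 64, 128});  // Duplicate, rejected via operator==.
        // shapes.size() == 1 at this point.
    }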
46
47 /// Value range
48 template <typename T>
49 struct Range {
50 T min;
51 T max;
52
53 [[nodiscard]] T range() const {
54 return max - min;
55 }
56 };
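Range is a small closed-interval helper; range() returns the extent max - min. It could, for example, describe the interval that generated input data or clamp bounds are drawn from (the bounds below are illustrative):

    #include "test/common/matmul_test_common.hpp"

    void range_example() {
        const kai::test::Range<float> bounds{-1.0F, 1.0F};
        const float extent = bounds.range();  // 2.0F
        (void)extent;
    }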
57
58 // NOLINTBEGIN(misc-non-private-member-variables-in-classes)
59
60 /// Matrix multiplication method.
61
31/62
✓ Branch 0 taken 11118 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11118 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11118 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 11118 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 11118 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 11118 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 11118 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 11118 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11118 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 11118 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 11118 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 11118 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 11118 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 11118 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 11118 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 11118 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 11118 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 11118 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 11118 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 11118 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 11118 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 11118 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 11118 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 11118 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 11118 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 11118 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 11118 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 11118 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 11118 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 11118 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 11118 times.
✗ Branch 61 not taken.
11118 struct MatMulMethod {
62 std::string_view name{}; ///< Name of matmul method.
63
64 size_t m0{0}; ///< Block size in M dimension.
65 size_t n0{0}; ///< Block size in N dimension.
66 size_t k0{0}; ///< Block size in K dimension.
67
68 DataFormat dst_format{}; ///< Data format of the destination matrix.
69 DataFormat lhs_format{}; ///< Data format of the LHS matrix.
70 DataFormat packed_lhs_format{}; ///< Data format of the packed LHS matrix.
71 DataFormat rhs_format{}; ///< Data format of the RHS matrix.
72 DataFormat packed_rhs_format{}; ///< Data format of the packed RHS matrix.
73 DataFormat bias_format{}; ///< Data format of the bias vector.
74
75 /// Check if CPU supports required features.
76 ///
77 /// @return Supported (true) or not supported (false).
78 std::function<bool(void)> fn_is_supported{nullptr};
79
80 /// Gets mr value.
81 ///
82 /// This is the packing parameter which must be used to pack the LHS matrix (if necessary).
83 ///
84 /// @return The mr value.
85 std::function<size_t(void)> fn_get_mr{nullptr};
86
87 /// Gets nr value.
88 ///
89 /// This is the packing parameter which must be used to pack the RHS matrix (if necessary).
90 ///
91 /// @return The nr value.
92 std::function<size_t(void)> fn_get_nr{nullptr};
93
94 /// Gets kr value.
95 ///
96 /// This is the packing parameter which must be used to pack the LHS and RHS matrices (if necessary).
97 ///
98 /// @return The kr value.
99 std::function<size_t(void)> fn_get_kr{nullptr};
100
101 /// Gets sr value.
102 ///
103 /// This is the packing parameter which must be used to pack the RHS matrix.
104 ///
105 /// @return The sr value.
106 std::function<size_t(void)> fn_get_sr{nullptr};
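The mr, nr, kr and sr getters above expose the packing parameters of the kernel variant under test; the packed-buffer helpers further down are meant to be fed from them. A minimal sketch, assuming a populated MatMulMethod instance (the helper name is illustrative, not part of this header):

    #include <cstddef>

    #include "test/common/matmul_test_common.hpp"

    // Hypothetical helper: bytes needed for the packed LHS of one shape.
    size_t packed_lhs_bytes(const kai::test::MatMulMethod& method, const kai::test::MatMulShape& shape) {
        // Per fn_get_packed_lhs_size's contract, kr and sr are unused and must be 1 here.
        return method.fn_get_packed_lhs_size(shape.m, shape.k, method.fn_get_mr(), 1, 1);
    }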
107
108 /// Gets m step value for main kernel.
109 ///
110 /// The starting row index must be divisible by `m_step`.
111 ///
112 /// @return The m step value.
113 std::function<size_t(void)> fn_get_main_m_step{nullptr};
114
115 /// Gets n step value for RHS packing micro-kernel.
116 ///
117 /// The starting row index must be divisible by `n_step`.
118 ///
119 /// @return The n step value.
120 std::function<size_t(void)> fn_get_pack_rhs_n_step{nullptr};
121
122 /// Gets n step value for main kernel.
123 ///
124 /// The starting column index must be divisible by `n_step`.
125 ///
126 /// @return The n step value.
127 std::function<size_t(void)> fn_get_main_n_step{nullptr};
128
129 /// Gets the offset in bytes of the LHS matrix.
130 ///
131 /// @param[in] m_idx Coordinate of the matrix in M dimension.
132 /// @param[in] stride Row stride in bytes.
133 ///
134 /// @return The offset in bytes.
135 std::function<size_t(size_t m_idx, size_t stride)> fn_get_lhs_offset{nullptr};
136
137 /// Gets the size in bytes of the packed LHS matrix.
138 ///
139 /// @param[in] m Number of rows in the unpacked LHS matrix.
140 /// @param[in] k Number of columns in the unpacked LHS matrix.
141 /// @param[in] mr Number of rows to be interleaved.
142 /// @param[in] kr Unused. Must be 1.
143 /// @param[in] sr Unused. Must be 1.
144 ///
145 /// @return The size in bytes.
146 std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> fn_get_packed_lhs_size{nullptr};
147
148 /// Gets the offset in bytes of the packed LHS matrix.
149 ///
150 /// @param[in] m_idx Coordinate of the matrix in M dimension.
151 /// @param[in] k Size of the matrix in K dimension.
152 ///
153 /// @return The offset in bytes.
154 std::function<size_t(size_t m_idx, size_t k)> fn_get_packed_lhs_offset{nullptr};
155
156 /// Preprocesses the LHS matrix.
157 ///
158 /// @param[in] m Number of rows of the unpacked LHS matrix.
159 /// @param[in] k Common dimension between the LHS and RHS matrix.
160 /// @param[in] mr Block size in M dimension. It must be the value returned by @ref fn_get_mr.
161 /// @param[in] kr Block size in K dimension. It must be the value returned by @ref fn_get_kr.
162 /// @param[in] sr Number of kr splits. It must be 1.
163 /// @param[in] m_idx_start Unused. Must be 0.
164 /// @param[in] lhs LHS matrix data buffer.
165 /// @param[in] lhs_stride Row stride in bytes of the LHS matrix.
166 /// @param[out] lhs_packed Packed LHS matrix.
167 std::function<void(
168 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
169 void* lhs_packed)>
170 fn_pack_lhs{nullptr};
171
172 /// Gets a value indicating whether LHS packing is needed.
173 830 [[nodiscard]] bool is_pack_lhs_needed() const {
174 830 return fn_pack_lhs != nullptr;
175 }
176
177 /// Gets the offset in bytes of the RHS matrix.
178 ///
179 /// @param[in] n_idx Coordinate of the matrix in N dimension.
180 ///
181 /// @return The offset in bytes.
182 std::function<size_t(size_t n_idx)> fn_get_rhs_offset{nullptr};
183
184 /// Gets the size in bytes of the packed RHS matrix.
185 ///
186 /// @param[in] n Size of the matrix in N dimension.
187 /// @param[in] k Size of the matrix in K dimension.
188 ///
189 /// @return The size in bytes.
190 std::function<size_t(size_t n, size_t k)> fn_get_packed_rhs_size{nullptr};
191
192 /// Gets the size in bytes of the packed RHS matrix.
193 ///
194 /// @param[in] n Size of the matrix in N dimension.
195 /// @param[in] k Size of the matrix in K dimension.
196 /// @param[in] nr Block size in N dimension.
197 /// @param[in] kr Block size in K dimension.
198 ///
199 /// @return The size in bytes.
200 std::function<size_t(size_t n, size_t k, size_t nr, size_t kr)> fn_get_packed_rhs_size_generic_block_size = nullptr;
201
202 /// Gets the offset in bytes of the packed RHS matrix in the RHS packing micro-kernel.
203 ///
204 /// @param[in] n_idx Coordinate of the matrix in N dimension.
205 /// @param[in] k Size of the matrix in K dimension.
206 ///
207 /// @return The offset in bytes.
208 std::function<size_t(size_t n_idx, size_t k)> fn_get_pack_rhs_packed_rhs_offset{nullptr};
209
210 /// Gets the offset in bytes of the packed RHS matrix in the main kernel.
211 ///
212 /// @param[in] n_idx Coordinate of the matrix in N dimension.
213 /// @param[in] k Size of the matrix in K dimension.
214 ///
215 /// @return The offset in bytes.
216 std::function<size_t(size_t n_idx, size_t k)> fn_get_main_packed_rhs_offset{nullptr};
217
218 std::function<void(
219 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
220 const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
221 fn_pack_rhs{nullptr};
222
223 /// Gets n step value.
224 ///
225 /// The starting row index must be divisible by `n_step`.
226 ///
227 /// @return The n step value.
228 std::function<size_t()> fn_pack_rhs_nxk_get_n_step{nullptr};
229
230 /// Gets the offset in bytes to the data element in the RHS matrix buffer.
231 ///
232 /// @param[in] n_idx Column index.
233 /// @param[in] rhs_stride Row stride in bytes of the RHS matrix.
234 ///
235 /// @return The offset in bytes to the data element.
236 std::function<size_t(size_t n_idx, size_t rhs_stride)> fn_pack_rhs_nxk_get_rhs_offset{nullptr};
237
238 /// Gets the offset in bytes to the data element in the bias buffer.
239 ///
240 /// @param[in] n_idx Column index.
241 ///
242 /// @return The offset in bytes to the data element.
243 std::function<size_t(size_t n_idx)> fn_pack_rhs_nxk_get_bias_offset{nullptr};
244
245 /// Gets the offset in bytes to the data element in the packed RHS buffer.
246 ///
247 /// @param[in] n_idx Row index.
248 /// @param[in] k Number of columns.
249 ///
250 /// @return The offset in bytes to the data element.
251 std::function<size_t(size_t n_idx, size_t k)> fn_pack_rhs_nxk_get_packed_rhs_offset{nullptr};
252
253 /// Gets the size in bytes of the packed RHS buffer.
254 ///
255 /// @param[in] n Number of rows.
256 /// @param[in] k Number of columns.
257 ///
258 /// @return The size in bytes of the packed RHS buffer.
259 std::function<size_t(size_t n, size_t k)> fn_pack_rhs_nxk_get_packed_rhs_size{nullptr};
260
261 /// Runs the RHS packing micro-kernel for matrix multiplication.
262 ///
263 /// The pointer to each buffer (RHS, bias and packed RHS) must be adjusted by the offset
264 /// calculated using the following functions:
265 ///
266 /// * RHS: @ref kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme.
267 /// * Bias: @ref kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme.
268 /// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme.
269 ///
270 /// @param[in] num_groups Number of groups. It must be 1.
271 /// @param[in] n Number of columns of the output matrix.
272 /// @param[in] k Common dimension between the LHS and RHS matrix.
273 /// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u32().
274 /// @param[in] kr Block size in K dimension. It must be 1.
275 /// @param[in] sr Number of kr splits. It must be 1.
276 /// @param[in] rhs_stride Row stride in bytes of the RHS matrix.
277 /// @param[in] rhs RHS matrix data buffer.
278 /// @param[in] bias Bias matrix data buffer.
279 /// @param[in] scale Scale data buffer. It must be NULL.
280 /// @param[out] rhs_packed Packed RHS matrix.
281 /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0.
282 /// @param[in] params Extra packing parameters. It must be NULL.
283 std::function<void(
284 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
285 const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
286 fn_pack_rhs_nxk{nullptr};
287
288 /// Gets the offset in bytes to the data element in the bias buffer.
289 ///
290 /// @param[in] n_idx Column index.
291 ///
292 /// @return The offset in bytes to the data element.
293 std::function<size_t(size_t n_idx)> fn_get_bias_offset{nullptr};
294
295 /// Gets the offset in bytes to the data element in the destination matrix buffer.
296 ///
297 /// @param[in] m_idx Row index.
298 /// @param[in] n_idx Column index.
299 /// @param[in] stride Row stride in bytes.
300 ///
301 /// @return The offset in bytes to the data element.
302 std::function<size_t(size_t m_idx, size_t n_idx, size_t stride)> fn_get_dst_offset{nullptr};
303
304 /// Gets the size in bytes of the destination matrix buffer.
305 ///
306 /// @param[in] m Number of rows.
307 /// @param[in] n Number of columns.
308 ///
309 /// @return The size in bytes of the destination matrix buffer.
310 std::function<size_t(size_t m, size_t n)> fn_get_dst_size{nullptr};
311
312 /// Performs F16 or F32 matrix multiplication with RHS packing
313 /// followed by clamp operation.
314 ///
315 /// @param[in] m Size of the matrix in M dimension.
316 /// @param[in] n Size of the matrix in N dimension.
317 /// @param[in] k Size of the matrix in K dimension.
318 /// @param[in] lhs LHS data buffer.
319 /// @param[in] packed_rhs Packed RHS data buffer.
320 /// @param[out] dst Output data buffer.
321 /// @param[in] lhs_stride LHS row stride.
322 /// @param[in] dst_stride_row Output row stride.
323 /// @param[in] dst_stride_col Output column stride.
324 /// @param[in] clamp_min Lower bound of the output data.
325 /// @param[in] clamp_max Upper bound of the output data.
326 std::function<void(
327 size_t m, size_t n, size_t k, //
328 const void* lhs, size_t lhs_stride, //
329 const void* packed_rhs, //
330 void* dst, size_t dst_stride_row, size_t dst_stride_col, //
331 float clamp_min, float clamp_max)>
332 fn_matmul_f16_f16_f16p = nullptr;
333
334 std::function<void(
335 size_t m, size_t n, size_t k, //
336 const void* lhs, size_t lhs_stride, //
337 const void* packed_rhs, //
338 void* dst, size_t dst_stride_row, size_t dst_stride_col, //
339 float clamp_min, float clamp_max)>
340 fn_matmul_f32_f32_f32p = nullptr;
341
342 /// Performs BF16 matrix multiplication with LHS and RHS packing
343 /// followed by clamp operation.
344 ///
345 /// @param[in] m Size of the matrix in M dimension.
346 /// @param[in] n Size of the matrix in N dimension.
347 /// @param[in] k Size of the matrix in K dimension.
348 /// @param[in] packed_lhs Packed LHS data buffer.
349 /// @param[in] packed_rhs Packed RHS data buffer.
350 /// @param[out] dst Output data buffer.
351 /// @param[in] dst_stride_row Output row stride.
352 /// @param[in] dst_stride_col Output column stride.
353 /// @param[in] clamp_min Lower bound of the output data.
354 /// @param[in] clamp_max Upper bound of the output data.
355 std::function<void(
356 size_t m, size_t n, size_t k, //
357 const void* packed_lhs, //
358 const void* packed_rhs, //
359 void* dst, size_t dst_stride_row, size_t dst_stride_col, //
360 float clamp_min, float clamp_max)>
361 fn_matmul_f32_bf16p_bf16p = nullptr;
362
363 std::function<void(
364 size_t m, size_t n, size_t k, //
365 const void* packed_lhs, //
366 const void* packed_rhs, //
367 void* dst, size_t dst_stride_row, size_t dst_stride_col, //
368 float clamp_min, float clamp_max)>
369 fn_matmul_f16_bf16p_bf16p = nullptr;
370
371 /// Performs F16 or F32 matrix multiplication with LHS & RHS packing
372 /// followed by clamp operation.
373 ///
374 /// @param[in] m Number of output rows to be computed.
375 /// @param[in] n Number of output columns to be computed.
376 /// @param[in] k Common dimension of the LHS and RHS operands.
377 /// @param[in] packed_lhs Packed LHS matrix buffer.
378 /// @param[in] packed_rhs Packed RHS matrix buffer.
379 /// @param[out] dst Output matrix buffer.
380 /// @param[in] dst_stride_row Row stride in bytes of the output matrix.
381 /// @param[in] dst_stride_col Column stride in bytes of the output matrix.
382 /// @param[in] clamp_min Minimum value to clamp the final result.
383 /// @param[in] clamp_max Maximum value to clamp the final result.
384 std::function<void(
385 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
386 size_t dst_stride_col, float clamp_min, float clamp_max)>
387 fn_matmul_f16_f16p_f16p = nullptr;
388
389 std::function<void(
390 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
391 size_t dst_stride_col, float clamp_min, float clamp_max)>
392 fn_matmul_f32_f32p_f32p = nullptr;
393
394 /// Gets a value indicating whether pre-processing the RHS matrix is needed.
395 830 [[nodiscard]] bool is_pack_rhs_needed() const {
396 830 return fn_pack_rhs != nullptr;
397 }
398
399 /// Gets a value indicating whether pre-processing the transposed RHS matrix is needed.
400 440 [[nodiscard]] bool is_pack_rhs_nxk_needed() const {
401 440 return fn_pack_rhs_nxk != nullptr;
402 }
403
404 /// Preprocesses the RHS matrix.
405 ///
406 /// @param[in] n Size of the matrix in N dimension.
407 /// @param[in] k Size of the matrix in K dimension.
408 /// @param[in] rhs RHS data buffer.
409 /// @param[in] rhs_row_stride RHS row stride.
410 /// @param[in] bias Bias data buffer.
411 /// @param[in] scale Quantization scales data buffer.
412 /// @param[out] packed_rhs Packed RHS data buffer.
413 492 void pack_rhs(
414 size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale,
415 void* packed_rhs) const {
416 KAI_UNUSED(scale);
417
418
1/2
✓ Branch 0 taken 492 times.
✗ Branch 1 not taken.
492 if (fn_pack_rhs != nullptr) {
419 984 fn_pack_rhs(
420 492 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0,
421 nullptr);
422 492 } else {
423 KAI_ERROR("RHS pre-processing is not supported!");
424 }
425 492 }
426
427 /// Preprocesses the transposed RHS matrix.
428 ///
429 /// @param[in] n Size of the matrix in N dimension.
430 /// @param[in] k Size of the matrix in K dimension.
431 /// @param[in] rhs RHS data buffer.
432 /// @param[in] rhs_row_stride RHS row stride.
433 /// @param[in] bias Bias data buffer.
434 /// @param[in] scale Quantization scales data buffer.
435 /// @param[out] packed_rhs Packed RHS data buffer.
436 68 void pack_rhs_nxk(
437 size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale,
438 void* packed_rhs) const {
439 KAI_UNUSED(scale);
440
441
1/2
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
68 if (fn_pack_rhs_nxk != nullptr) {
442 136 fn_pack_rhs_nxk(
443 68 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0,
444 nullptr);
445 68 } else {
446 KAI_ERROR("RHS pre-processing is not supported!");
447 }
448 68 }
449
450 876 [[nodiscard]] bool has_main_kernel() const {
451
2/2
✓ Branch 0 taken 718 times.
✓ Branch 1 taken 158 times.
1594 return fn_matmul_f16_f16_f16p != nullptr || //
452
2/2
✓ Branch 0 taken 682 times.
✓ Branch 1 taken 36 times.
718 fn_matmul_f16_f16p_f16p != nullptr || //
453
2/2
✓ Branch 0 taken 646 times.
✓ Branch 1 taken 36 times.
682 fn_matmul_f32_f32p_f32p != nullptr || //
454
2/2
✓ Branch 0 taken 436 times.
✓ Branch 1 taken 210 times.
646 fn_matmul_f32_f32_f32p != nullptr || //
455
2/2
✓ Branch 0 taken 346 times.
✓ Branch 1 taken 90 times.
436 fn_matmul_f32_bf16p_bf16p != nullptr || //
456 90 fn_matmul_f16_bf16p_bf16p != nullptr;
457 }
458
459 796 void main_kernel(
460 size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride,
461 size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const {
462 KAI_UNUSED(bias);
463 KAI_UNUSED(rhs_stride);
464
465
2/2
✓ Branch 0 taken 140 times.
✓ Branch 1 taken 656 times.
796 if (fn_matmul_f16_f16_f16p) {
466 280 fn_matmul_f16_f16_f16p(
467 140 m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(uint16_t), clamp_min, clamp_max);
468
2/2
✓ Branch 0 taken 186 times.
✓ Branch 1 taken 470 times.
796 } else if (fn_matmul_f32_f32_f32p) {
469 186 fn_matmul_f32_f32_f32p(m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max);
470
2/2
✓ Branch 0 taken 32 times.
✓ Branch 1 taken 438 times.
656 } else if (fn_matmul_f16_f16p_f16p) {
471 32 fn_matmul_f16_f16p_f16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(Float16), clamp_min, clamp_max);
472
2/2
✓ Branch 0 taken 32 times.
✓ Branch 1 taken 406 times.
470 } else if (fn_matmul_f32_f32p_f32p) {
473 32 fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max);
474
2/2
✓ Branch 0 taken 322 times.
✓ Branch 1 taken 84 times.
438 } else if (fn_matmul_f32_bf16p_bf16p) {
475 644 fn_matmul_f32_bf16p_bf16p(
476 322 m, n, k, reinterpret_cast<const uint16_t*>(lhs), rhs, reinterpret_cast<float*>(dst), dst_stride,
477 322 sizeof(float), clamp_min, clamp_max);
478
1/2
✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
406 } else if (fn_matmul_f16_bf16p_bf16p) {
479 84 fn_matmul_f16_bf16p_bf16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(uint16_t), clamp_min, clamp_max);
480 84 } else {
481 KAI_ERROR("Main kernel is not available!");
482 }
483 796 }
484 };
485
486 // NOLINTEND(misc-non-private-member-variables-in-classes)
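Taken together, a test built on MatMulMethod typically checks fn_is_supported, packs whichever operands the method declares packing for, and then dispatches through main_kernel. The sketch below assumes dense row-major operands with caller-supplied byte strides and clamp bounds that effectively disable clamping; the function name and flow are illustrative, not the actual test driver:

    #include <cfloat>
    #include <cstdint>
    #include <vector>

    #include "test/common/matmul_test_common.hpp"

    void run_once(
        const kai::test::MatMulMethod& method, const kai::test::MatMulShape& shape, const void* lhs,
        size_t lhs_stride, const void* rhs, size_t rhs_stride, const void* bias) {
        if (method.fn_is_supported && !method.fn_is_supported()) {
            return;  // Required CPU features are missing; skip this method.
        }

        const void* lhs_data = lhs;
        std::vector<uint8_t> packed_lhs;
        if (method.is_pack_lhs_needed()) {
            // fn_get_packed_lhs_size documents kr and sr as unused ("must be 1").
            packed_lhs.resize(method.fn_get_packed_lhs_size(shape.m, shape.k, method.fn_get_mr(), 1, 1));
            method.fn_pack_lhs(
                shape.m, shape.k, method.fn_get_mr(), method.fn_get_kr(), method.fn_get_sr(), 0, lhs, lhs_stride,
                packed_lhs.data());
            lhs_data = packed_lhs.data();
        }

        const void* rhs_data = rhs;
        std::vector<uint8_t> packed_rhs;
        if (method.is_pack_rhs_needed()) {
            packed_rhs.resize(method.fn_get_packed_rhs_size(shape.n, shape.k));
            method.pack_rhs(shape.n, shape.k, rhs, rhs_stride, bias, nullptr, packed_rhs.data());
            rhs_data = packed_rhs.data();
        }

        const size_t dst_bytes = method.fn_get_dst_size(shape.m, shape.n);
        std::vector<uint8_t> dst(dst_bytes);
        // Row stride assumes dense destination rows; the clamp bounds leave the result unclamped.
        method.main_kernel(
            shape.m, shape.n, shape.k, lhs_data, rhs_data, bias, dst.data(), lhs_stride, rhs_stride,
            dst_bytes / shape.m, -FLT_MAX, FLT_MAX);
    }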
487
488 /// Matrix multiplication test information.
489 using MatMulTestParams = std::tuple<MatMulMethod, MatMulShape, MatrixPortion>;
490 using MatMulTestPortionedParams = std::tuple<size_t, MatMulShape, MatrixPortion>;
491 using MatMulTestPortionedParamsWithBias = std::tuple<size_t, MatMulShape, MatrixPortion, bool>;
492 using MatMulTestPortionedParamsWithBias_WithBL = std::tuple<size_t, MatMulShape, size_t, MatrixPortion, bool>;
493
494 /// Prints the test information.
495 void PrintTo(const MatMulTestParams& param, std::ostream* os);
496 void PrintTo(const MatMulShape& shape, std::ostream* os);
497 void PrintTo(const MatrixPortion& portion, std::ostream* os);
498
499 /// Generate test information.
500 std::string test_description(
501 const std::string_view& name, const MatMulShape& shape, const MatrixPortion& portion, bool bias);
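The tuple aliases above carry a method (or kernel variant index), a shape, the output portion to compute and, in the WithBias variants, a bias flag; PrintTo and test_description turn them into readable parameterized-test names. A small sketch of unpacking one entry (the helper name is illustrative):

    #include <string>
    #include <string_view>

    #include "test/common/matmul_test_common.hpp"

    std::string name_of(const kai::test::MatMulTestPortionedParamsWithBias& params, std::string_view kernel_name) {
        const auto& [variant_idx, shape, portion, has_bias] = params;
        (void)variant_idx;  // Typically indexes into a table of kernel variants.
        return kai::test::test_description(kernel_name, shape, portion, has_bias);
    }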
502
503 } // namespace kai::test
504
505 template <>
506 struct std::hash<kai::test::MatMulShape> {
507 1243 size_t operator()(const kai::test::MatMulShape& ms) const {
508 1243 return kai::test::MatMulShape::Hash{}(ms);
509 }
510 };
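The std::hash specialization simply forwards to MatMulShape::Hash, so the default hasher works without naming it explicitly, for example to memoise per-shape reference results (the cached value type and shape are illustrative):

    #include <unordered_map>
    #include <vector>

    #include "test/common/matmul_test_common.hpp"

    void cache_example() {
        std::unordered_map<kai::test::MatMulShape, std::vector<float>> reference_cache;
        reference_cache[kai::test::MatMulShape{1, 32, 17}] = {};  // Illustrative entry.
    }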
511