KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 94.9% 259 4 277
Functions: 100.0% 28 0 28
Branches: 36.1% 313 24 892

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/matmul.hpp"
8
9 #include <gtest/gtest.h>
10
11 #include <array>
12 #include <cstddef>
13 #include <cstdint>
14 #include <functional>
15 #include <map>
16 #include <string_view>
17 #include <tuple>
18 #include <utility>
19
20 #include "kai/kai_common.h"
21 #include "test/common/buffer.hpp"
22 #include "test/common/compare.hpp"
23 #include "test/common/cpu_info.hpp"
24 #include "test/common/data_format.hpp"
25 #include "test/common/data_type.hpp"
26 #include "test/common/float16.hpp"
27 #include "test/common/matmul_test_common.hpp"
28 #include "test/common/matrix_portion.hpp"
29 #include "test/common/sme.hpp"
30 #include "test/reference/clamp.hpp"
31 #include "test/reference/fill.hpp"
32 #include "test/reference/pack.hpp"
33 #include "test/reference/transpose.hpp"
34
35 // matmul_clamp_f16_f16_f16p
36 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h"
37 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h"
38 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla.h"
39 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla.h"
40 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55.h"
41 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h"
42 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p32x1b_x16_x16_neon.h"
43
44 // matmul_clamp_f16_f16p_f16p
45 #include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h"
46 #include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h"
47 #include "kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h"
48 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
49 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h"
50
51 // matmul_clamp_f32_f32_f32p
52 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h"
53 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla.h"
54 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55.h"
55 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
56 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h"
57 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"
58 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h"
59 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h"
60 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x32p16x1b_x32_x32_neon.h"
61
62 // matmul_clamp_f32_f32p_f32p
63 #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
64 #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
65 #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h"
66 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h"
67 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"
68
69 namespace kai::test {
70
71 4 static const auto& get_matmul_methods() {
72 // List of supported matrix multiplication methods.
73
3/4
✓ Branch 0 taken 1 time.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
4 static std::array<MatMulMethod, 10> matmul_methods{};
74
75 matmul_methods[0].name = "matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla";
76 matmul_methods[0].m0 = 6;
77 matmul_methods[0].n0 = 16;
78 matmul_methods[0].dst_format = DataFormat(DataType::FP16);
79 matmul_methods[0].lhs_format = DataFormat(DataType::FP16);
80 matmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN);
81 matmul_methods[0].rhs_format = DataFormat(DataType::FP16);
82 matmul_methods[0].packed_rhs_format = DataFormat(
83 DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1);
84 matmul_methods[0].bias_format = DataFormat(DataType::FP16);
85 matmul_methods[0].fn_is_supported = cpu_has_fp16;
86 matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
87 matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
88 matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
89 matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
90 matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
91 matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
92 matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
93 matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
94 matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
95 matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset =
96 kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
97 matmul_methods[0].fn_get_main_packed_rhs_offset =
98 kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
99 matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
100 matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
101 matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
102 matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
103 matmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
104
105 matmul_methods[1].name = "matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa";
106 matmul_methods[1].m0 = 2 * get_sme_vector_length<float>();
107 matmul_methods[1].n0 = 2 * get_sme_vector_length<float>();
108 matmul_methods[1].dst_format = DataFormat(DataType::FP16);
109 matmul_methods[1].lhs_format = DataFormat(DataType::FP16);
110 matmul_methods[1].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length<float>(), 2);
111 matmul_methods[1].rhs_format = DataFormat(DataType::FP16);
112 matmul_methods[1].packed_rhs_format = DataFormat(
113 DataType::FP16, // Output type
114 2 * get_sme_vector_length<float>(), 2, // Block size
115 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
116 DataType::FP16, // Bias format
117 DataType::UNKNOWN, // Scaling type
118 2 * get_sme_vector_length<float>(), 2); // Sub-block
119 matmul_methods[1].bias_format = DataFormat(DataType::FP16);
120 matmul_methods[1].fn_is_supported = cpu_has_sme2;
121 matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
122 matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
123 matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
124 matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
125 matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
126 matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
127 matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
128 matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme;
129 matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme;
130 matmul_methods[1].fn_get_packed_lhs_offset =
131 kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
132 matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme;
133 matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
134 matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
135 matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
136 matmul_methods[1].fn_get_main_packed_rhs_offset =
137 kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
138 matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
139 matmul_methods[1].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
140 matmul_methods[1].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
141 matmul_methods[1].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
142 matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_offset =
143 kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
144 matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
145 matmul_methods[1].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
146 matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
147 matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
148 matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
149 matmul_methods[1].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
150
151 matmul_methods[2].name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla";
152 matmul_methods[2].m0 = 6;
153 matmul_methods[2].n0 = 8;
154 matmul_methods[2].dst_format = DataFormat(DataType::FP32);
155 matmul_methods[2].lhs_format = DataFormat(DataType::FP32);
156 matmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN);
157 matmul_methods[2].rhs_format = DataFormat(DataType::FP32);
158 matmul_methods[2].packed_rhs_format =
159 DataFormat(DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1);
160 matmul_methods[2].bias_format = DataFormat(DataType::FP32);
161 matmul_methods[2].fn_is_supported = cpu_has_advsimd;
162 matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
163 matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
164 matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
165 matmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
166 matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
167 matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
168 matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
169 matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
170 matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
171 matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset =
172 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
173 matmul_methods[2].fn_get_main_packed_rhs_offset =
174 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
175 matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
176 matmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
177 matmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
178 matmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
179 matmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
180
181 matmul_methods[3].name = "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa";
182 matmul_methods[3].m0 = 2 * get_sme_vector_length<float>();
183 matmul_methods[3].n0 = 2 * get_sme_vector_length<float>();
184 matmul_methods[3].dst_format = DataFormat(DataType::FP32);
185 matmul_methods[3].lhs_format = DataFormat(DataType::FP32);
186 matmul_methods[3].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length<float>(), 1);
187 matmul_methods[3].rhs_format = DataFormat(DataType::FP32);
188 matmul_methods[3].packed_rhs_format = DataFormat(
189 DataType::FP32, 2 * get_sme_vector_length<float>(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32,
190 DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 1);
191 matmul_methods[3].bias_format = DataFormat(DataType::FP32);
192 matmul_methods[3].fn_is_supported = cpu_has_sme2;
193 matmul_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
194 matmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
195 matmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
196 matmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
197 matmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
198 matmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
199 matmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
200 matmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme;
201 matmul_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme;
202 matmul_methods[3].fn_get_packed_lhs_offset =
203 kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
204 matmul_methods[3].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme;
205 matmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
206 matmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
207 matmul_methods[3].fn_get_pack_rhs_packed_rhs_offset =
208 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
209 matmul_methods[3].fn_get_main_packed_rhs_offset =
210 kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
211 matmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
212 matmul_methods[3].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
213 matmul_methods[3].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
214 matmul_methods[3].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
215 matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_offset =
216 kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
217 matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_size =
218 kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
219 matmul_methods[3].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
220 matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
221 matmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
222 matmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
223 matmul_methods[3].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
224
225 matmul_methods[4].name = "matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa";
226 matmul_methods[4].m0 = 2 * get_sme_vector_length<float>();
227 matmul_methods[4].n0 = 2 * get_sme_vector_length<float>();
228 matmul_methods[4].dst_format = DataFormat(DataType::FP32);
229 matmul_methods[4].lhs_format = DataFormat(DataType::FP32);
230 matmul_methods[4].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length<float>(), 1);
231 matmul_methods[4].rhs_format = DataFormat(DataType::FP32);
232 matmul_methods[4].packed_rhs_format = DataFormat(
233 DataType::FP32, 2 * get_sme_vector_length<float>(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32,
234 DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 1);
235 matmul_methods[4].bias_format = DataFormat(DataType::FP32);
236 matmul_methods[4].fn_is_supported = cpu_has_sme;
237 matmul_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
238 matmul_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
239 matmul_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
240 matmul_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
241 matmul_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
242 matmul_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
243 matmul_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
244 matmul_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme;
245 matmul_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme;
246 matmul_methods[4].fn_get_packed_lhs_offset =
247 kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
248 matmul_methods[4].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme;
249 matmul_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
250 matmul_methods[4].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
251 matmul_methods[4].fn_get_pack_rhs_packed_rhs_offset =
252 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
253 matmul_methods[4].fn_get_main_packed_rhs_offset =
254 kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
255 matmul_methods[4].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
256 matmul_methods[4].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
257 matmul_methods[4].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
258 matmul_methods[4].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
259 matmul_methods[4].fn_pack_rhs_nxk_get_packed_rhs_offset =
260 kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
261 matmul_methods[4].fn_pack_rhs_nxk_get_packed_rhs_size =
262 kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
263 matmul_methods[4].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
264 matmul_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
265 matmul_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
266 matmul_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
267 matmul_methods[4].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
268
269 matmul_methods[5].name = "matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa";
270 matmul_methods[5].m0 = 2 * get_sme_vector_length<float>();
271 matmul_methods[5].n0 = 2 * get_sme_vector_length<float>();
272 matmul_methods[5].dst_format = DataFormat(DataType::FP16);
273 matmul_methods[5].lhs_format = DataFormat(DataType::FP16);
274 matmul_methods[5].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length<float>(), 2);
275 matmul_methods[5].rhs_format = DataFormat(DataType::FP16);
276 matmul_methods[5].packed_rhs_format = DataFormat(
277 DataType::FP16, // Output type
278 2 * get_sme_vector_length<float>(), 2, // Block size
279 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
280 DataType::FP16, // Bias format
281 DataType::UNKNOWN, // Scaling type
282 2 * get_sme_vector_length<float>(), 2); // Sub-block
283 matmul_methods[5].bias_format = DataFormat(DataType::FP16);
284 matmul_methods[5].fn_is_supported = cpu_has_sme;
285 matmul_methods[5].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
286 matmul_methods[5].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
287 matmul_methods[5].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
288 matmul_methods[5].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
289 matmul_methods[5].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
290 matmul_methods[5].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
291 matmul_methods[5].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
292 matmul_methods[5].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme;
293 matmul_methods[5].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme;
294 matmul_methods[5].fn_get_packed_lhs_offset =
295 kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
296 matmul_methods[5].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme;
297 matmul_methods[5].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
298 matmul_methods[5].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
299 matmul_methods[5].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
300 matmul_methods[5].fn_get_main_packed_rhs_offset =
301 kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
302 matmul_methods[5].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
303 matmul_methods[5].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
304 matmul_methods[5].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
305 matmul_methods[5].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
306 matmul_methods[5].fn_pack_rhs_nxk_get_packed_rhs_offset =
307 kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
308 matmul_methods[5].fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
309 matmul_methods[5].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
310 matmul_methods[5].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
311 matmul_methods[5].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
312 matmul_methods[5].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
313 matmul_methods[5].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
314
315 matmul_methods[6].name = "matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla";
316 matmul_methods[6].m0 = 6;
317 matmul_methods[6].n0 = 32;
318 matmul_methods[6].dst_format = DataFormat(DataType::FP16);
319 matmul_methods[6].lhs_format = DataFormat(DataType::FP16);
320 matmul_methods[6].packed_lhs_format = DataFormat(DataType::UNKNOWN);
321 matmul_methods[6].rhs_format = DataFormat(DataType::FP16);
322 matmul_methods[6].packed_rhs_format = DataFormat(
323 DataType::FP16, 32, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 32, 1);
324 matmul_methods[6].bias_format = DataFormat(DataType::FP16);
325 matmul_methods[6].fn_is_supported = cpu_has_fp16;
326 matmul_methods[6].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
327 matmul_methods[6].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
328 matmul_methods[6].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
329 matmul_methods[6].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
330 matmul_methods[6].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
331 matmul_methods[6].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
332 matmul_methods[6].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
333 matmul_methods[6].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
334 matmul_methods[6].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
335 matmul_methods[6].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
336 matmul_methods[6].fn_get_main_packed_rhs_offset =
337 kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
338 matmul_methods[6].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
339 matmul_methods[6].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
340 matmul_methods[6].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
341 matmul_methods[6].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
342 matmul_methods[6].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
343
344 matmul_methods[7].name = "matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55";
345 matmul_methods[7].m0 = 6;
346 matmul_methods[7].n0 = 32;
347 matmul_methods[7].dst_format = DataFormat(DataType::FP16);
348 matmul_methods[7].lhs_format = DataFormat(DataType::FP16);
349 matmul_methods[7].packed_lhs_format = DataFormat(DataType::UNKNOWN);
350 matmul_methods[7].rhs_format = DataFormat(DataType::FP16);
351 matmul_methods[7].packed_rhs_format = DataFormat(
352 DataType::FP16, 32, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 32, 1);
353 matmul_methods[7].bias_format = DataFormat(DataType::FP16);
354 matmul_methods[7].fn_is_supported = cpu_has_fp16;
355 matmul_methods[7].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
356 matmul_methods[7].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
357 matmul_methods[7].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
358 matmul_methods[7].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
359 matmul_methods[7].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
360 matmul_methods[7].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
361 matmul_methods[7].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
362 matmul_methods[7].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
363 matmul_methods[7].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
364 matmul_methods[7].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
365 matmul_methods[7].fn_get_main_packed_rhs_offset =
366 kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
367 matmul_methods[7].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
368 matmul_methods[7].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
369 matmul_methods[7].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
370 matmul_methods[7].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
371 matmul_methods[7].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
372
373 matmul_methods[8].name = "matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla";
374 matmul_methods[8].m0 = 6;
375 matmul_methods[8].n0 = 16;
376 matmul_methods[8].dst_format = DataFormat(DataType::FP32);
377 matmul_methods[8].lhs_format = DataFormat(DataType::FP32);
378 matmul_methods[8].packed_lhs_format = DataFormat(DataType::UNKNOWN);
379 matmul_methods[8].rhs_format = DataFormat(DataType::FP32);
380 matmul_methods[8].packed_rhs_format = DataFormat(
381 DataType::FP32, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 16, 1);
382 matmul_methods[8].bias_format = DataFormat(DataType::FP32);
383 matmul_methods[8].fn_is_supported = cpu_has_advsimd;
384 matmul_methods[8].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
385 matmul_methods[8].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
386 matmul_methods[8].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
387 matmul_methods[8].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
388 matmul_methods[8].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
389 matmul_methods[8].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
390 matmul_methods[8].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
391 matmul_methods[8].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
392 matmul_methods[8].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
393 matmul_methods[8].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
394 matmul_methods[8].fn_get_main_packed_rhs_offset =
395 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
396 matmul_methods[8].fn_pack_rhs = kai_run_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
397 matmul_methods[8].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
398 matmul_methods[8].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
399 matmul_methods[8].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
400 matmul_methods[8].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
401
402 matmul_methods[9].name = "matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55";
403 matmul_methods[9].m0 = 6;
404 matmul_methods[9].n0 = 16;
405 matmul_methods[9].dst_format = DataFormat(DataType::FP32);
406 matmul_methods[9].lhs_format = DataFormat(DataType::FP32);
407 matmul_methods[9].packed_lhs_format = DataFormat(DataType::UNKNOWN);
408 matmul_methods[9].rhs_format = DataFormat(DataType::FP32);
409 matmul_methods[9].packed_rhs_format = DataFormat(
410 DataType::FP32, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 16, 1);
411 matmul_methods[9].bias_format = DataFormat(DataType::FP32);
412 matmul_methods[9].fn_is_supported = cpu_has_advsimd;
413 matmul_methods[9].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
414 matmul_methods[9].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
415 matmul_methods[9].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
416 matmul_methods[9].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
417 matmul_methods[9].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
418 matmul_methods[9].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
419 matmul_methods[9].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
420 matmul_methods[9].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
421 matmul_methods[9].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
422 matmul_methods[9].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
423 matmul_methods[9].fn_get_main_packed_rhs_offset =
424 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
425 matmul_methods[9].fn_pack_rhs = kai_run_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
426 matmul_methods[9].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
427 matmul_methods[9].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
428 matmul_methods[9].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
429 matmul_methods[9].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
430
431 return matmul_methods;
432 }
433
434 4 static const auto& get_vecmul_methods() {
435 // List of supported vector by matrix multiplication methods
436
3/4
✓ Branch 0 taken 1 time.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
4 static std::array<MatMulMethod, 5> vecmul_methods{};
437
438 vecmul_methods[0].name = "matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot";
439 vecmul_methods[0].m0 = 1;
440 vecmul_methods[0].n0 = 16 * get_sme_vector_length<float>();
441 vecmul_methods[0].dst_format = DataFormat(DataType::FP16);
442 vecmul_methods[0].lhs_format = DataFormat(DataType::FP16);
443 vecmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN);
444 vecmul_methods[0].rhs_format = DataFormat(DataType::FP16);
445 vecmul_methods[0].packed_rhs_format = DataFormat(
446 DataType::FP16, // Output type
447 2 * get_sme_vector_length<float>(), 2, // Block size
448 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
449 DataType::FP16, // Bias format
450 DataType::UNKNOWN, // Scaling type
451 2 * get_sme_vector_length<float>(), 2); // Sub-block
452 vecmul_methods[0].bias_format = DataFormat(DataType::FP16);
453 vecmul_methods[0].fn_is_supported = cpu_has_sme2;
454 vecmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
455 vecmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
456 vecmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
457 vecmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
458 vecmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
459 vecmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
460 vecmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme;
461 vecmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
462 vecmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
463 vecmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
464 vecmul_methods[0].fn_get_main_packed_rhs_offset =
465 kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
466 vecmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
467 vecmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
468 vecmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
469 vecmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
470 vecmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
471
472 vecmul_methods[1].name = "matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla";
473 vecmul_methods[1].m0 = 1;
474 vecmul_methods[1].n0 = 8 * get_sme_vector_length<float>();
475 vecmul_methods[1].dst_format = DataFormat(DataType::FP16);
476 vecmul_methods[1].lhs_format = DataFormat(DataType::FP16);
477 vecmul_methods[1].packed_lhs_format = DataFormat(DataType::UNKNOWN);
478 vecmul_methods[1].rhs_format = DataFormat(DataType::FP16);
479 vecmul_methods[1].packed_rhs_format = DataFormat(
480 DataType::FP16, // Output type
481 2 * get_sme_vector_length<float>(), 2, // Block size
482 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
483 DataType::FP16, // Bias format
484 DataType::UNKNOWN, // Scaling type
485 2 * get_sme_vector_length<float>(), 2); // Sub-block
486 vecmul_methods[1].bias_format = DataFormat(DataType::FP16);
487 vecmul_methods[1].fn_is_supported = cpu_has_sme;
488 vecmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
489 vecmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
490 vecmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
491 vecmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
492 vecmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
493 vecmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
494 vecmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme;
495 vecmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
496 vecmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
497 vecmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
498 vecmul_methods[1].fn_get_main_packed_rhs_offset =
499 kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
500 vecmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
501 vecmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
502 vecmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
503 vecmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
504 vecmul_methods[1].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
505
506 vecmul_methods[2].name = "matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla";
507 vecmul_methods[2].m0 = 1;
508 vecmul_methods[2].n0 = 8 * get_sme_vector_length<float>();
509 vecmul_methods[2].dst_format = DataFormat(DataType::FP32);
510 vecmul_methods[2].lhs_format = DataFormat(DataType::FP32);
511 vecmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN);
512 vecmul_methods[2].rhs_format = DataFormat(DataType::FP32);
513 vecmul_methods[2].packed_rhs_format = DataFormat(
514 DataType::FP32, // Output type
515 2 * get_sme_vector_length<float>(), 1, // Block size
516 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
517 DataType::FP32, // Bias format
518 DataType::UNKNOWN, // Scaling type
519 2 * get_sme_vector_length<float>(), 1); // Sub-block
520 vecmul_methods[2].bias_format = DataFormat(DataType::FP32);
521 vecmul_methods[2].fn_is_supported = cpu_has_sme;
522 vecmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
523 vecmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
524 vecmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
525 vecmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
526 vecmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
527 vecmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
528 vecmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
529 vecmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
530 vecmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
531 vecmul_methods[2].fn_get_pack_rhs_packed_rhs_offset =
532 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
533 vecmul_methods[2].fn_get_main_packed_rhs_offset =
534 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
535 vecmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
536 vecmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
537 vecmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
538 vecmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
539 vecmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
540
541 vecmul_methods[3].name = "matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla";
542 vecmul_methods[3].m0 = 1;
543 vecmul_methods[3].n0 = 16 * get_sme_vector_length<float>();
544 vecmul_methods[3].dst_format = DataFormat(DataType::FP32);
545 vecmul_methods[3].lhs_format = DataFormat(DataType::FP32);
546 vecmul_methods[3].packed_lhs_format = DataFormat(DataType::UNKNOWN);
547 vecmul_methods[3].rhs_format = DataFormat(DataType::FP32);
548 vecmul_methods[3].packed_rhs_format = DataFormat(
549 DataType::FP32, // Output type
550 2 * get_sme_vector_length<float>(), 1, // Block size
551 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
552 DataType::FP32, // Bias format
553 DataType::UNKNOWN, // Scaling type
554 2 * get_sme_vector_length<float>(), 1); // Sub-block
555 vecmul_methods[3].bias_format = DataFormat(DataType::FP32);
556 vecmul_methods[3].fn_is_supported = cpu_has_sme2;
557 vecmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
558 vecmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
559 vecmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
560 vecmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
561 vecmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
562 vecmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
563 vecmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
564 vecmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
565 vecmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
566 vecmul_methods[3].fn_get_pack_rhs_packed_rhs_offset =
567 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
568 vecmul_methods[3].fn_get_main_packed_rhs_offset =
569 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
570 vecmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
571 vecmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
572 vecmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
573 vecmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
574 vecmul_methods[3].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
575
576 vecmul_methods[4].name = "matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla";
577 vecmul_methods[4].m0 = 1;
578 vecmul_methods[4].n0 = 16 * get_sme_vector_length<float>();
579 vecmul_methods[4].dst_format = DataFormat(DataType::FP32);
580 vecmul_methods[4].lhs_format = DataFormat(DataType::FP32);
581 vecmul_methods[4].packed_lhs_format = DataFormat(DataType::UNKNOWN);
582 vecmul_methods[4].rhs_format = DataFormat(DataType::FP32);
583 vecmul_methods[4].packed_rhs_format = DataFormat(
584 DataType::FP32, // Output type
585 16 * get_sme_vector_length<float>(), 1, // Block size
586 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
587 DataType::FP32, // Bias format
588 DataType::UNKNOWN, // Scaling type
589 16 * get_sme_vector_length<float>(), 1); // Sub-block
590 vecmul_methods[4].bias_format = DataFormat(DataType::FP32);
591 vecmul_methods[4].fn_is_supported = cpu_has_sme2;
592 vecmul_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
593 vecmul_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
594 vecmul_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
595 vecmul_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
596 vecmul_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
597 vecmul_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
598 vecmul_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
599 vecmul_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
600 vecmul_methods[4].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
601 vecmul_methods[4].fn_get_pack_rhs_packed_rhs_offset =
602 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
603 vecmul_methods[4].fn_get_main_packed_rhs_offset =
604 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
605 vecmul_methods[4].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
606 vecmul_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
607 vecmul_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
608 vecmul_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
609 vecmul_methods[4].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
610
611 return vecmul_methods;
612 }
613
614 /// Matrix multiplication test fixture.
615 class MatMulTest : public testing::TestWithParam<MatMulTestParams> {
616 private:
617 /// Unique ID: m, n, k, method_id.
618 using TestDataId = std::tuple<size_t, size_t, size_t, std::string_view>;
619
620 protected:
621 /// Cached test data that is shared between multiple test case.
622 125 struct TestData {
623 250 Buffer lhs{}; ///< LHS operand.
624 250 Buffer ref_packed_lhs{}; ///< Reference packed LHS.
625 250 Buffer rhs{}; ///< RHS operand.
626 250 Buffer rhs_scales{}; ///< RHS per-row quantization scales.
627 250 Buffer bias{}; ///< Bias.
628 250 Buffer rhs_t{}; ///< Transposed RHS matrix.
629 250 Buffer ref_packed_rhs{}; ///< Reference packed RHS.
630 250 Buffer ref_dst{}; ///< Reference output.
631 125 float clamp_min{}; ///< Minimum output value.
632 125 float clamp_max{}; ///< Maximum output value.
633 };
634
635 /// Gets the test data for the current test case.
636 1024 static const TestData& test_data() {
637 7442 const auto& [method, info, portion] = GetParam();
638 5120 const TestDataId data_id{info.m, info.n, info.k, method.name};
639
640 // If the test data is already available, returns it.
641 1024 const auto data_it = _data.find(data_id);
642
643
2/2
✓ Branch 0 taken 899 times.
✓ Branch 1 taken 125 times.
1024 if (data_it != _data.end()) {
644 899 return data_it->second;
645 }
646
647 // Generates the test data.
648 250 const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN;
649 250 const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN;
650 250 const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN;
651
652 250 const auto lhs_h = info.m;
653 250 const auto lhs_w = info.k;
654 250 auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0);
655 125 Buffer ref_packed_lhs;
656
657
2/2
✓ Branch 0 taken 101 times.
✓ Branch 1 taken 24 times.
125 if (has_lhs_pack) {
658 24 ref_packed_lhs =
659
3/6
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
48 pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w);
660 24 }
661
662 250 const auto rhs_h = info.k;
663 250 const auto rhs_w = info.n;
664
2/4
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
250 auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1);
665
666 KAI_ASSUME(method.rhs_format.is_raw());
667
3/6
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
125 auto rhs_t = transpose(rhs.data(), method.rhs_format.data_type(), rhs_h, rhs_w);
668
669 125 Buffer rhs_scales;
670
3/8
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 125 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
125 if (data_type_is_quantized(method.rhs_format.data_type()) &&
671 method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) {
672 rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2);
673 }
674
675 125 const auto bias_h = 1;
676 250 const auto bias_w = info.n;
677 125 Buffer bias;
678
679
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 125 times.
125 if (has_bias) {
680
2/4
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
250 bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3);
681 125 }
682
683 125 Buffer packed_rhs;
684
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 125 times.
125 if (has_rhs_pack) {
685
1/2
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
250 packed_rhs = matmul_pack_rhs(
686
3/6
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
125 rhs.data(), rhs_scales.data(), bias.data(), method.rhs_format, method.packed_rhs_format, info.n, info.k,
687 true);
688 125 }
689
690 KAI_ASSUME(method.lhs_format.is_raw());
691 KAI_ASSUME(method.rhs_format.is_raw());
692 KAI_ASSUME(method.dst_format.is_raw());
693
1/2
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
250 auto ref_dst = matmul(
694
2/4
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
125 lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), //
695
3/6
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
125 rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), //
696
2/4
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
125 bias.data(), nullptr, nullptr, method.bias_format.data_type(), //
697
1/2
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
125 method.dst_format.data_type(), //
698 375 info.m, info.n, info.k, false, false);
699
700 static constexpr float clamp_ratio = 0.8F;
701 750 const auto [clamp_min, clamp_max] =
702
4/8
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125 times.
✗ Branch 7 not taken.
125 find_clamp_range(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_ratio);
703
7/14
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 125 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 125 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 125 times.
✗ Branch 13 not taken.
250 ref_dst = clamp(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_min, clamp_max);
704
705
9/18
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 125 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 125 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 125 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 125 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 125 times.
✗ Branch 17 not taken.
1125 auto& data = _data[data_id] = {};
706 125 data.lhs = std::move(lhs);
707 125 data.ref_packed_lhs = std::move(ref_packed_lhs);
708 125 data.rhs = std::move(rhs);
709 125 data.rhs_scales = std::move(rhs_scales);
710 125 data.bias = std::move(bias);
711 125 data.rhs_t = std::move(rhs_t);
712 125 data.ref_packed_rhs = std::move(packed_rhs);
713 125 data.ref_dst = std::move(ref_dst);
714 125 data.clamp_min = clamp_min;
715 125 data.clamp_max = clamp_max;
716
717 125 return data;
718 1024 }
719
720 private:
721 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
722 static std::map<TestDataId, TestData> _data;
723 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
724 };
725
726 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
727 1 std::map<MatMulTest::TestDataId, MatMulTest::TestData> MatMulTest::_data;
728 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
729
730 /// Tests the LHS packing micro-kernel.
731
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1322 TEST_P(MatMulTest, PackedLhs) {
732 1460 const auto& [method, info, portion] = GetParam();
733
734
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 if (method.fn_is_supported && !method.fn_is_supported()) {
735 GTEST_SKIP() << "Unsupported CPU feature";
736 }
737
738
2/2
✓ Branch 0 taken 72 times.
✓ Branch 1 taken 368 times.
440 if (!method.is_pack_lhs_needed()) {
739
3/6
✓ Branch 0 taken 368 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 368 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 368 times.
✗ Branch 5 not taken.
368 GTEST_SKIP() << "Test not valid w/o LHS pack";
740 }
741
742 72 const auto& data = test_data();
743 144 const auto lhs_h = info.m;
744 144 const auto lhs_w = info.k;
745
746 144 const auto rect = portion.compute_portion(
747 144 lhs_h, lhs_w, method.packed_lhs_format.scheduler_block_height(lhs_h),
748 72 lhs_w); // LHS packing micro-kernel API doesn't support scheduling over K dimension.
749
750
3/4
✓ Branch 0 taken 68 times.
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 68 times.
72 if (rect.height() == 0 || rect.width() == 0) {
751
9/18
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
4 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
752 }
753
754 136 const auto mr = method.fn_get_mr();
755 136 const auto kr = method.fn_get_kr();
756 136 const auto sr = method.fn_get_sr();
757 136 const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(lhs_w);
758
759 272 const auto packed_lhs_size = method.fn_get_packed_lhs_size(info.m, info.k, mr, kr, sr);
760 136 const auto ref_packed_lhs_size = method.packed_lhs_format.default_size_in_bytes(lhs_h, lhs_w);
761
3/14
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
68 ASSERT_EQ(packed_lhs_size, ref_packed_lhs_size);
762
763 136 const auto lhs_offset = method.fn_get_lhs_offset(rect.start_row(), ref_lhs_row_stride);
764 136 const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), lhs_w);
765
3/14
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
68 ASSERT_EQ(lhs_offset, ref_lhs_offset);
766
767 204 const auto packed_lhs_offset = method.fn_get_packed_lhs_offset(rect.start_row(), info.k);
768 136 const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w);
769
3/14
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
68 ASSERT_EQ(packed_lhs_offset, ref_packed_lhs_offset);
770
771 68 Buffer packed_lhs(packed_lhs_size, 0);
772
1/2
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
136 method.fn_pack_lhs(
773
3/6
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
68 rect.height(), rect.width(), mr, kr, sr, 0, data.lhs.data() + lhs_offset, ref_lhs_row_stride,
774
1/2
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
68 packed_lhs.data() + packed_lhs_offset);
775
776
1/2
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
68 DefaultMismatchHandler handler(0, 0.0001, 0, 0.001);
777 136 const auto success =
778
3/6
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
68 compare(packed_lhs.data(), data.ref_packed_lhs.data(), method.packed_lhs_format, lhs_h, lhs_w, rect, handler);
779
4/16
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 68 times.
68 ASSERT_TRUE(success);
780 440 }
781
782 /// Tests the RHS packing micro-kernel.
783
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1322 TEST_P(MatMulTest, PackedRhs) {
784 6040 const auto& [method, info, portion] = GetParam();
785
786
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 if (method.fn_is_supported && !method.fn_is_supported()) {
787 GTEST_SKIP() << "Unsupported CPU feature";
788 }
789
790
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 if (!method.is_pack_rhs_needed()) {
791 GTEST_SKIP() << "Test not valid w/o RHS pack";
792 }
793
794 440 const auto& data = test_data();
795 880 const auto rhs_full_width = info.n;
796 880 const auto rhs_full_height = info.k;
797
798 880 const auto block_height = method.packed_rhs_format.scheduler_block_height(rhs_full_width);
799 880 const auto block_width = method.packed_rhs_format.scheduler_block_width(rhs_full_height);
800
801 880 const Rect rect = portion.compute_portion(rhs_full_width, rhs_full_height, block_height, block_width);
802
803
4/4
✓ Branch 0 taken 430 times.
✓ Branch 1 taken 10 times.
✓ Branch 2 taken 30 times.
✓ Branch 3 taken 400 times.
440 if (rect.height() == 0 || rect.width() == 0) {
804
9/18
✓ Branch 0 taken 40 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 40 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 40 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 40 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 40 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 40 times.
✗ Branch 17 not taken.
40 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
805 }
806
807 400 const auto rhs_start_row = rect.start_row();
808 400 const auto rhs_start_col = rect.start_col();
809 400 const auto width = rect.width();
810 400 const auto height = rect.height();
811 800 const auto rhs_row_stride = method.rhs_format.default_row_stride(rhs_full_width);
812
813 /** Ensure that all relevant parameters are sane **/
814 800 const auto n_step = method.fn_get_pack_rhs_n_step();
815 400 const auto ref_n_step = block_height;
816
3/14
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 400 times.
400 ASSERT_EQ(n_step, ref_n_step);
817
818 800 const auto rhs_offset = method.fn_get_rhs_offset(rhs_start_row);
819 800 const auto ref_rhs_offset =
820 400 method.rhs_format.default_offset_in_bytes(rhs_start_col, rhs_start_row, rhs_full_height);
821
3/14
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 400 times.
400 ASSERT_EQ(rhs_offset, ref_rhs_offset);
822
823 800 const auto packed_rhs_size = method.fn_get_packed_rhs_size(rhs_full_width, rhs_full_height);
824 800 const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(rhs_full_width, rhs_full_height);
825
3/14
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 400 times.
400 ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size);
826
827 800 const auto packed_rhs_offset = method.fn_get_pack_rhs_packed_rhs_offset(rhs_start_row, rhs_full_height);
828 800 const auto ref_packed_rhs_offset =
829 400 method.packed_rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_full_height);
830
3/14
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 400 times.
400 ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset);
831
832 800 const auto scale_type = method.packed_rhs_format.scale_data_type();
833 400 const auto ref_rhs_scales_offset = rhs_start_row * data_type_size_in_bits(scale_type) / 8;
834
835 800 const auto bias_offset = method.fn_get_bias_offset(rhs_start_row);
836 800 const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rhs_start_row, rhs_full_height);
837
3/14
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 400 times.
400 ASSERT_EQ(bias_offset, ref_bias_offset);
838
839 /** Perform RHS packing, and compare with reference result **/
840 400 Buffer packed_rhs(packed_rhs_size, 0);
841
1/2
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
400 method.pack_rhs(
842
2/4
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
400 height, width, data.rhs.data() + rhs_offset, rhs_row_stride, data.bias.data() + bias_offset,
843
2/6
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 400 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
400 data.rhs_scales.data() != nullptr ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr,
844
1/2
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
400 packed_rhs.data() + packed_rhs_offset);
845
846
2/4
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
800 const bool exact = method.packed_rhs_format.pack_format() != DataFormat::PackFormat::QUANTIZE_PER_ROW;
847
1/2
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
400 DefaultMismatchHandler handler(0, exact ? 0 : 0.0001, 0, exact ? 0 : 0.001);
848
1/2
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
800 const auto success = compare(
849
2/4
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
400 packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, rhs_full_width, rhs_full_height, rect,
850 handler);
851
4/16
✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 400 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 400 times.
400 ASSERT_TRUE(success);
852 440 }
853
854 /// Tests the transposed RHS packing micro-kernel.
855
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1322 TEST_P(MatMulTest, PackedTransposedRhs) {
856 2504 const auto& [method, info, portion] = GetParam();
857
858
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 if (method.fn_is_supported && !method.fn_is_supported()) {
859 GTEST_SKIP() << "Unsupported CPU feature";
860 }
861
862
2/2
✓ Branch 0 taken 72 times.
✓ Branch 1 taken 368 times.
440 if (!method.is_pack_rhs_nxk_needed()) {
863
3/6
✓ Branch 0 taken 368 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 368 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 368 times.
✗ Branch 5 not taken.
368 GTEST_SKIP() << "Test not valid w/o pre-processing of transposed RHS matrix";
864 }
865
866 72 const auto& data = test_data();
867 144 const auto n_step = method.fn_pack_rhs_nxk_get_n_step();
868 216 const auto ref_n_step = method.packed_rhs_format.scheduler_block_height(info.n);
869
3/14
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 72 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 72 times.
72 ASSERT_EQ(n_step, ref_n_step);
870
871 144 const auto rect = portion.compute_portion(
872 288 info.n, info.k, method.packed_rhs_format.scheduler_block_height(info.n),
873 144 method.packed_rhs_format.scheduler_block_width(info.k));
874
875
3/4
✓ Branch 0 taken 68 times.
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 68 times.
72 if (rect.height() == 0 || rect.width() == 0) {
876
9/18
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
4 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
877 }
878
879 204 const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(info.k);
880
881 136 const auto rhs_offset = method.fn_pack_rhs_nxk_get_rhs_offset(rect.start_row(), ref_rhs_row_stride);
882 204 const auto ref_rhs_offset = method.rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.k);
883
3/14
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
68 ASSERT_EQ(rhs_offset, ref_rhs_offset);
884
885 272 const auto packed_rhs_size = method.fn_pack_rhs_nxk_get_packed_rhs_size(info.n, info.k);
886 272 const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(info.n, info.k);
887
3/14
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
68 ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size);
888
889 204 const auto packed_rhs_offset = method.fn_pack_rhs_nxk_get_packed_rhs_offset(rect.start_row(), info.k);
890 136 const auto ref_packed_rhs_offset =
891 136 method.packed_rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.k);
892
3/14
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
68 ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset);
893
894 136 const auto ref_rhs_scales_offset =
895 136 rect.start_row() * data_type_size_in_bits(method.packed_rhs_format.scale_data_type()) / 8;
896
897 136 const auto bias_offset = method.fn_get_bias_offset(rect.start_row());
898 204 const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_row(), info.n);
899
3/14
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
68 ASSERT_EQ(bias_offset, ref_bias_offset);
900
901 68 Buffer packed_rhs(packed_rhs_size, 0);
902
903
1/2
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
68 method.pack_rhs_nxk(
904
4/8
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 68 times.
✗ Branch 7 not taken.
68 rect.height(), rect.width(), data.rhs_t.data() + rhs_offset, ref_rhs_row_stride, data.bias.data() + bias_offset,
905
2/6
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 68 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
68 data.rhs_scales.data() != nullptr ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr,
906
1/2
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
68 packed_rhs.data() + packed_rhs_offset);
907
908
2/4
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
136 const auto exact = method.packed_rhs_format.pack_format() != DataFormat::PackFormat::QUANTIZE_PER_ROW;
909
1/2
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
68 DefaultMismatchHandler handler(0, exact ? 0 : 0.0001, 0, exact ? 0 : 0.001);
910 136 const auto success =
911
5/10
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 68 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 68 times.
✗ Branch 9 not taken.
68 compare(packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, info.n, info.k, rect, handler);
912
4/16
✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 68 times.
68 ASSERT_TRUE(success);
913 440 }
914
915 /// Tests the output.
916
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1322 TEST_P(MatMulTest, Output) {
917 12050 const auto& [method, info, portion] = GetParam();
918
919
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 if (method.fn_is_supported && !method.fn_is_supported()) {
920 GTEST_SKIP() << "Unsupported CPU feature";
921 }
922
923
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 if (!method.has_main_kernel()) {
924 GTEST_SKIP() << "No main kernel available";
925 }
926
927 440 const auto& data = test_data();
928 880 const auto m_step = method.fn_get_main_m_step();
929
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
880 ASSERT_EQ(m_step, method.m0);
930
931 880 const auto n_step = method.fn_get_main_n_step();
932
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
880 ASSERT_EQ(n_step, method.n0);
933
934 2200 const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0);
935
936
4/4
✓ Branch 0 taken 430 times.
✓ Branch 1 taken 10 times.
✓ Branch 2 taken 40 times.
✓ Branch 3 taken 390 times.
440 if (rect.height() == 0 || rect.width() == 0) {
937
9/18
✓ Branch 0 taken 50 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 50 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 50 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 50 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 50 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 50 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 50 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 50 times.
✗ Branch 17 not taken.
50 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
938 }
939
940 780 const auto lhs_w = info.k;
941 780 const auto rhs_w = info.n;
942 780 const auto bias_w = info.n;
943 780 const auto dst_w = info.n;
944
945 390 const auto lhs_start_row = rect.start_row();
946 390 const auto lhs_start_col = 0;
947 780 const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w);
948
949 390 const std::byte* lhs_data = nullptr;
950 390 uintptr_t lhs_offset = 0;
951
952
2/2
✓ Branch 0 taken 64 times.
✓ Branch 1 taken 326 times.
390 if (method.is_pack_lhs_needed()) {
953 64 lhs_data = data.ref_packed_lhs.data();
954
955 128 const auto ref_packed_lhs_offset =
956 128 method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k);
957 128 lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k);
958
3/14
✓ Branch 0 taken 64 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 64 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 64 times.
64 ASSERT_EQ(lhs_offset, ref_packed_lhs_offset);
959
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 64 times.
64 } else {
960 326 lhs_data = data.lhs.data();
961
962 326 lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride);
963 652 const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, lhs_w);
964
3/14
✓ Branch 0 taken 326 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 326 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 326 times.
326 ASSERT_EQ(lhs_offset, ref_lhs_offset);
965 326 }
966
967 780 const auto rhs_stride = method.rhs_format.default_row_stride(rhs_w);
968
969 390 const std::byte* rhs_data = nullptr;
970 390 uintptr_t rhs_offset = 0;
971
972
1/2
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
390 if (method.is_pack_rhs_needed()) {
973 390 const auto packed_rhs_start_row = rect.start_col();
974 390 const auto packed_rhs_start_col = 0;
975
976 390 rhs_data = data.ref_packed_rhs.data();
977
978 780 rhs_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k);
979 780 const auto ref_rhs_offset =
980 780 method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k);
981
3/14
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 390 times.
390 ASSERT_EQ(rhs_offset, ref_rhs_offset);
982
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 390 times.
390 } else {
983 const auto rhs_start_row = 0;
984 const auto rhs_start_col = rect.start_col();
985
986 rhs_data = data.rhs.data();
987 rhs_offset = method.rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_w);
988 }
989
990 390 const auto* bias_data = data.bias.data();
991 780 const auto bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_row(), bias_w);
992
993 780 const auto dst_stride = method.dst_format.default_row_stride(dst_w);
994 780 const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
995 780 const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w);
996
3/14
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 390 times.
390 ASSERT_EQ(dst_offset, ref_dst_offset);
997
998 1560 const auto dst_size = method.fn_get_dst_size(info.m, info.n);
999 1560 const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n);
1000
3/14
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 390 times.
390 ASSERT_EQ(dst_size, ref_dst_size);
1001
1002 390 Buffer dst(dst_size, 0);
1003
1004
1/2
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
390 method.main_kernel(
1005
2/4
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
390 rect.height(), rect.width(), info.k, lhs_data + lhs_offset, rhs_data + rhs_offset, bias_data + bias_offset,
1006
1/2
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
390 dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, data.clamp_min, data.clamp_max);
1007
1008
1/2
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
390 DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
1009
5/10
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 390 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 390 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 390 times.
✗ Branch 9 not taken.
390 const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler);
1010
4/16
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 390 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 390 times.
390 ASSERT_TRUE(success);
1011 440 }
1012
1013
13/40
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 720 times.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
725 INSTANTIATE_TEST_SUITE_P(
1014 MatMul, MatMulTest,
1015 testing::Combine(
1016 testing::ValuesIn(get_matmul_methods()),
1017 testing::Values(
1018 MatMulShape{1, 16, 16}, //
1019 MatMulShape{20, 1, 20}, //
1020 MatMulShape{6, 16, 32}, //
1021 MatMulShape{12, 32, 17}, //
1022 MatMulShape{13, 33, 23}, //
1023 MatMulShape{87, 93, 56} //
1024 ),
1025 testing::Values(
1026 MatrixPortion(0, 0, 1, 1), // Full matrix.
1027 MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner.
1028 MatrixPortion(0.75, 0.75, 1, 1) // Bottom-right corner.
1029 )),
1030 testing::PrintToStringParamName());
1031
1032
14/44
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 4 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1040 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
1045 INSTANTIATE_TEST_SUITE_P(
1033 VecMul, MatMulTest,
1034 testing::Combine(
1035 testing::ValuesIn(get_vecmul_methods()),
1036 testing::Values(
1037 MatMulShape{1, 16, 16}, //
1038 MatMulShape{1, 1, 20}, //
1039 MatMulShape{1, 16, 32}, //
1040 MatMulShape{1, 32, 17}, //
1041 MatMulShape{1, 33, 23}, //
1042 MatMulShape{1, 1500, 20}, //
1043 MatMulShape{1, 93, 56}, //
1044 MatMulShape{1, 1, 1}, //
1045 MatMulShape{1, 16, 1}, //
1046 MatMulShape{1, 32, 64}, //
1047 MatMulShape{1, 7, 74}, //
1048 MatMulShape{1, 800, 64}, //
1049 MatMulShape{1, 512, 130} //
1050 ),
1051 testing::Values(
1052 MatrixPortion(0, 0, 1, 1), // Full row.
1053 MatrixPortion(0, 0, 1, 0.5), // First half
1054 MatrixPortion(0, .4, 1, 0.3), // mid row-section.
1055 MatrixPortion(0, 0.75, 1, .25) // right row section
1056 )),
1057 testing::PrintToStringParamName());
1058
1059 } // namespace kai::test
1060