KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 87.0% 388 / 4 / 450
Functions: 100.0% 82 / 0 / 82
Branches: 36.3% 444 / 16 / 1238

test/tests/matmul_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/matmul.hpp"
8
9 #include <gtest/gtest.h>
10
11 #include <array>
12 #include <cstddef>
13 #include <cstdint>
14 #include <functional>
15 #include <map>
16 #include <string_view>
17 #include <tuple>
18 #include <utility>
19
20 #include "kai/kai_common.h"
21 #include "test/common/abi_checker.hpp"
22 #include "test/common/buffer.hpp"
23 #include "test/common/compare.hpp"
24 #include "test/common/cpu_info.hpp"
25 #include "test/common/data_format.hpp"
26 #include "test/common/data_type.hpp"
27 #include "test/common/matmul_test_common.hpp"
28 #include "test/common/matrix_portion.hpp"
29 #include "test/common/seed.hpp"
30 #include "test/common/sme.hpp"
31 #include "test/common/sve.hpp"
32 #include "test/reference/clamp.hpp"
33 #include "test/reference/fill.hpp"
34 #include "test/reference/generators.hpp"
35 #include "test/reference/pack.hpp"
36 #include "test/reference/transpose.hpp"
37
38 // matmul_clamp_f16_f16_f16p
39 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h"
40 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h"
41 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla.h"
42 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla.h"
43 #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55.h"
44 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h"
45 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p32x1b_x16_x16_neon.h"
46
47 // matmul_clamp_f16_f16p_f16p
48 #include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h"
49 #include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h"
50 #include "kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h"
51 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
52 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h"
53
54 // matmul_clamp_f32_f32_f32p
55 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h"
56 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla.h"
57 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55.h"
58 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
59 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h"
60 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla.h"
61 #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"
62 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h"
63 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h"
64 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x32p16x1b_x32_x32_neon.h"
65 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve.h"
66
67 // matmul_clamp_f32_f32p_f32p
68 #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
69 #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
70 #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h"
71 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h"
72 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"
73
74 namespace kai::test {
75
76 12 static const auto& get_matmul_methods() {
77 // List of supported matrix multiplication methods.
78
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
12 static std::array<MatMulMethod, 7> matmul_methods{};
79
80 matmul_methods[0].name = "matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla";
81 matmul_methods[0].m0 = 6;
82 matmul_methods[0].n0 = 16;
83 matmul_methods[0].dst_format = DataFormat(DataType::FP16);
84 matmul_methods[0].lhs_format = DataFormat(DataType::FP16);
85 matmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN);
86 matmul_methods[0].rhs_format = DataFormat(DataType::FP16);
87 matmul_methods[0].packed_rhs_format = DataFormat(
88 DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1);
89 matmul_methods[0].bias_format = DataFormat(DataType::FP16);
90 36 matmul_methods[0].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
91
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return NormalRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
92 };
93 36 matmul_methods[0].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
94
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return NormalRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
95 };
96 36 matmul_methods[0].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
97 KAI_UNUSED(null_bias_mode);
98
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return NormalRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
99 };
100 matmul_methods[0].fn_is_supported = cpu_has_fp16;
101 matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
102 matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
103 matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
104 matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
105 matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
106 matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
107 matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
108 matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
109 matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
110 matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset =
111 kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
112 matmul_methods[0].fn_get_main_packed_rhs_offset =
113 kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
114 matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
115 matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon;
116 matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
117 matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
118 matmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla;
119
120 matmul_methods[1].name = "matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa";
121 matmul_methods[1].m0 = 2 * get_sme_vector_length<float>();
122 matmul_methods[1].n0 = 2 * get_sme_vector_length<float>();
123 matmul_methods[1].dst_format = DataFormat(DataType::FP16);
124 matmul_methods[1].lhs_format = DataFormat(DataType::FP16);
125 matmul_methods[1].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length<float>(), 2);
126 matmul_methods[1].rhs_format = DataFormat(DataType::FP16);
127 matmul_methods[1].packed_rhs_format = DataFormat(
128 DataType::FP16, // Output type
129 2 * get_sme_vector_length<float>(), 2, // Block size
130 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
131 DataType::FP16, // Bias format
132 DataType::UNKNOWN, // Scaling type
133 2 * get_sme_vector_length<float>(), 2); // Sub-block
134 matmul_methods[1].bias_format = DataFormat(DataType::FP16);
135 18 matmul_methods[1].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
136
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
137 };
138 18 matmul_methods[1].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
139
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
140 };
141 18 matmul_methods[1].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
142 KAI_UNUSED(null_bias_mode);
143
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
144 };
145 matmul_methods[1].fn_is_supported = cpu_has_sme2;
146 matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
147 matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
148 matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
149 matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
150 matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
151 matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
152 matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
153 matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme;
154 matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme;
155 matmul_methods[1].fn_get_packed_lhs_offset =
156 kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
157 matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme;
158 matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
159 matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
160 matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
161 matmul_methods[1].fn_get_main_packed_rhs_offset =
162 kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
163 matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
164 matmul_methods[1].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
165 matmul_methods[1].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
166 matmul_methods[1].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
167 matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_offset =
168 kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
169 matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
170 matmul_methods[1].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
171 matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
172 matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
173 matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
174 matmul_methods[1].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
175
176 matmul_methods[2].name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla";
177 matmul_methods[2].m0 = 6;
178 matmul_methods[2].n0 = 8;
179 matmul_methods[2].dst_format = DataFormat(DataType::FP32);
180 matmul_methods[2].lhs_format = DataFormat(DataType::FP32);
181 matmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN);
182 matmul_methods[2].rhs_format = DataFormat(DataType::FP32);
183 matmul_methods[2].packed_rhs_format =
184 DataFormat(DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1);
185 matmul_methods[2].bias_format = DataFormat(DataType::FP32);
186 36 matmul_methods[2].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
187
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
188 };
189 36 matmul_methods[2].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
190
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
191 };
192 36 matmul_methods[2].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
193 KAI_UNUSED(null_bias_mode);
194
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
195 };
196 matmul_methods[2].fn_is_supported = cpu_has_advsimd;
197 matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
198 matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
199 matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
200 matmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
201 matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
202 matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
203 matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
204 matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
205 matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
206 matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset =
207 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
208 matmul_methods[2].fn_get_main_packed_rhs_offset =
209 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
210 matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
211 matmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon;
212 matmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
213 matmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
214 matmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla;
215
216 matmul_methods[3].name = "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa";
217 matmul_methods[3].m0 = 2 * get_sme_vector_length<float>();
218 matmul_methods[3].n0 = 2 * get_sme_vector_length<float>();
219 matmul_methods[3].dst_format = DataFormat(DataType::FP32);
220 matmul_methods[3].lhs_format = DataFormat(DataType::FP32);
221 matmul_methods[3].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length<float>(), 1);
222 matmul_methods[3].rhs_format = DataFormat(DataType::FP32);
223 matmul_methods[3].packed_rhs_format = DataFormat(
224 DataType::FP32, 2 * get_sme_vector_length<float>(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32,
225 DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 1);
226 matmul_methods[3].bias_format = DataFormat(DataType::FP32);
227 18 matmul_methods[3].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
228
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
229 };
230 18 matmul_methods[3].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
231
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
232 };
233 18 matmul_methods[3].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
234 KAI_UNUSED(null_bias_mode);
235
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
236 };
237 matmul_methods[3].fn_is_supported = cpu_has_sme2;
238 matmul_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
239 matmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
240 matmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
241 matmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
242 matmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
243 matmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
244 matmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
245 matmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme;
246 matmul_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme;
247 matmul_methods[3].fn_get_packed_lhs_offset =
248 kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
249 matmul_methods[3].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme;
250 matmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
251 matmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
252 matmul_methods[3].fn_get_pack_rhs_packed_rhs_offset =
253 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
254 matmul_methods[3].fn_get_main_packed_rhs_offset =
255 kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
256 matmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
257 matmul_methods[3].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
258 matmul_methods[3].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
259 matmul_methods[3].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
260 matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_offset =
261 kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
262 matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_size =
263 kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
264 matmul_methods[3].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
265 matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
266 matmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
267 matmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
268 matmul_methods[3].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa;
269
270 matmul_methods[4].name = "matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa";
271 matmul_methods[4].m0 = 2 * get_sme_vector_length<float>();
272 matmul_methods[4].n0 = 2 * get_sme_vector_length<float>();
273 matmul_methods[4].dst_format = DataFormat(DataType::FP32);
274 matmul_methods[4].lhs_format = DataFormat(DataType::FP32);
275 matmul_methods[4].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length<float>(), 1);
276 matmul_methods[4].rhs_format = DataFormat(DataType::FP32);
277 matmul_methods[4].packed_rhs_format = DataFormat(
278 DataType::FP32, 2 * get_sme_vector_length<float>(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32,
279 DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 1);
280 matmul_methods[4].bias_format = DataFormat(DataType::FP32);
281 18 matmul_methods[4].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
282
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
283 };
284 18 matmul_methods[4].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
285
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
286 };
287 18 matmul_methods[4].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
288 KAI_UNUSED(null_bias_mode);
289
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
290 };
291 matmul_methods[4].fn_is_supported = cpu_has_sme;
292 matmul_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
293 matmul_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
294 matmul_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
295 matmul_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
296 matmul_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
297 matmul_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
298 matmul_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
299 matmul_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme;
300 matmul_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme;
301 matmul_methods[4].fn_get_packed_lhs_offset =
302 kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
303 matmul_methods[4].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme;
304 matmul_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
305 matmul_methods[4].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
306 matmul_methods[4].fn_get_pack_rhs_packed_rhs_offset =
307 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
308 matmul_methods[4].fn_get_main_packed_rhs_offset =
309 kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
310 matmul_methods[4].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
311 matmul_methods[4].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
312 matmul_methods[4].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
313 matmul_methods[4].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
314 matmul_methods[4].fn_pack_rhs_nxk_get_packed_rhs_offset =
315 kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
316 matmul_methods[4].fn_pack_rhs_nxk_get_packed_rhs_size =
317 kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
318 matmul_methods[4].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme;
319 matmul_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
320 matmul_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
321 matmul_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
322 matmul_methods[4].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
323
324 matmul_methods[5].name = "matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa";
325 matmul_methods[5].m0 = 2 * get_sme_vector_length<float>();
326 matmul_methods[5].n0 = 2 * get_sme_vector_length<float>();
327 matmul_methods[5].dst_format = DataFormat(DataType::FP16);
328 matmul_methods[5].lhs_format = DataFormat(DataType::FP16);
329 matmul_methods[5].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length<float>(), 2);
330 matmul_methods[5].rhs_format = DataFormat(DataType::FP16);
331 matmul_methods[5].packed_rhs_format = DataFormat(
332 DataType::FP16, // Output type
333 2 * get_sme_vector_length<float>(), 2, // Block size
334 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
335 DataType::FP16, // Bias format
336 DataType::UNKNOWN, // Scaling type
337 2 * get_sme_vector_length<float>(), 2); // Sub-block
338 matmul_methods[5].bias_format = DataFormat(DataType::FP16);
339 18 matmul_methods[5].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
340
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
341 };
342 18 matmul_methods[5].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
343
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
344 };
345 18 matmul_methods[5].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
346 KAI_UNUSED(null_bias_mode);
347
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
348 };
349 matmul_methods[5].fn_is_supported = cpu_has_sme;
350 matmul_methods[5].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
351 matmul_methods[5].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
352 matmul_methods[5].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
353 matmul_methods[5].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
354 matmul_methods[5].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
355 matmul_methods[5].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
356 matmul_methods[5].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
357 matmul_methods[5].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme;
358 matmul_methods[5].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme;
359 matmul_methods[5].fn_get_packed_lhs_offset =
360 kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
361 matmul_methods[5].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme;
362 matmul_methods[5].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
363 matmul_methods[5].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
364 matmul_methods[5].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
365 matmul_methods[5].fn_get_main_packed_rhs_offset =
366 kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
367 matmul_methods[5].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
368 matmul_methods[5].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
369 matmul_methods[5].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
370 matmul_methods[5].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
371 matmul_methods[5].fn_pack_rhs_nxk_get_packed_rhs_offset =
372 kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
373 matmul_methods[5].fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
374 matmul_methods[5].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme;
375 matmul_methods[5].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
376 matmul_methods[5].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
377 matmul_methods[5].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
378 matmul_methods[5].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
379
380 matmul_methods[6].name = "matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla";
381 matmul_methods[6].m0 = 1;
382 matmul_methods[6].n0 = 4 * get_sve_vector_length<float>();
383 matmul_methods[6].dst_format = DataFormat(DataType::FP32);
384 matmul_methods[6].lhs_format = DataFormat(DataType::FP32);
385 matmul_methods[6].packed_lhs_format = DataFormat(DataType::UNKNOWN);
386 matmul_methods[6].rhs_format = DataFormat(DataType::FP32);
387 matmul_methods[6].packed_rhs_format = DataFormat(
388 DataType::FP32, 4 * get_sve_vector_length<float>(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32,
389 DataType::UNKNOWN, 4 * get_sve_vector_length<float>(), 1);
390 matmul_methods[6].bias_format = DataFormat(DataType::FP32);
391 36 matmul_methods[6].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
392
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
393 };
394 36 matmul_methods[6].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
395
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
396 };
397 36 matmul_methods[6].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
398 KAI_UNUSED(null_bias_mode);
399
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
400 };
401 matmul_methods[6].fn_is_supported = cpu_has_sve;
402 matmul_methods[6].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla;
403 matmul_methods[6].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla;
404 matmul_methods[6].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla;
405 matmul_methods[6].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla;
406 matmul_methods[6].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve;
407 matmul_methods[6].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla;
408 matmul_methods[6].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla;
409 matmul_methods[6].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve;
410 matmul_methods[6].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve;
411 matmul_methods[6].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve;
412 matmul_methods[6].fn_get_main_packed_rhs_offset =
413 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla;
414 matmul_methods[6].fn_pack_rhs = kai_run_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve;
415 matmul_methods[6].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve;
416 matmul_methods[6].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla;
417 matmul_methods[6].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla;
418 matmul_methods[6].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla;
419
420 return matmul_methods;
421 }
422
423 12 static const auto& get_vecmul_methods() {
424 // List of supported vector by matrix multiplication methods
425
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
12 static std::array<MatMulMethod, 5> vecmul_methods{};
426
427 vecmul_methods[0].name = "matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot";
428 vecmul_methods[0].m0 = 1;
429 vecmul_methods[0].n0 = 16 * get_sme_vector_length<float>();
430 vecmul_methods[0].dst_format = DataFormat(DataType::FP16);
431 vecmul_methods[0].lhs_format = DataFormat(DataType::FP16);
432 vecmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN);
433 vecmul_methods[0].rhs_format = DataFormat(DataType::FP16);
434 vecmul_methods[0].packed_rhs_format = DataFormat(
435 DataType::FP16, // Output type
436 2 * get_sme_vector_length<float>(), 2, // Block size
437 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
438 DataType::FP16, // Bias format
439 DataType::UNKNOWN, // Scaling type
440 2 * get_sme_vector_length<float>(), 2); // Sub-block
441 vecmul_methods[0].bias_format = DataFormat(DataType::FP16);
442 39 vecmul_methods[0].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
443
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
444 };
445 39 vecmul_methods[0].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
446
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
447 };
448 39 vecmul_methods[0].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
449 KAI_UNUSED(null_bias_mode);
450
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
451 };
452 vecmul_methods[0].fn_is_supported = cpu_has_sme2;
453 vecmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
454 vecmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
455 vecmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
456 vecmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
457 vecmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
458 vecmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
459 vecmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme;
460 vecmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
461 vecmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
462 vecmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
463 vecmul_methods[0].fn_get_main_packed_rhs_offset =
464 kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
465 vecmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
466 vecmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
467 vecmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
468 vecmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
469 vecmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot;
470
471 vecmul_methods[1].name = "matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla";
472 vecmul_methods[1].m0 = 1;
473 vecmul_methods[1].n0 = 8 * get_sme_vector_length<float>();
474 vecmul_methods[1].dst_format = DataFormat(DataType::FP16);
475 vecmul_methods[1].lhs_format = DataFormat(DataType::FP16);
476 vecmul_methods[1].packed_lhs_format = DataFormat(DataType::UNKNOWN);
477 vecmul_methods[1].rhs_format = DataFormat(DataType::FP16);
478 vecmul_methods[1].packed_rhs_format = DataFormat(
479 DataType::FP16, // Output type
480 2 * get_sme_vector_length<float>(), 2, // Block size
481 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
482 DataType::FP16, // Bias format
483 DataType::UNKNOWN, // Scaling type
484 2 * get_sme_vector_length<float>(), 2); // Sub-block
485 vecmul_methods[1].bias_format = DataFormat(DataType::FP16);
486 39 vecmul_methods[1].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
487
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
488 };
489 39 vecmul_methods[1].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
490
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
491 };
492 39 vecmul_methods[1].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
493 KAI_UNUSED(null_bias_mode);
494
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
495 };
496 vecmul_methods[1].fn_is_supported = cpu_has_sme;
497 vecmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
498 vecmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
499 vecmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
500 vecmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
501 vecmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
502 vecmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
503 vecmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme;
504 vecmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
505 vecmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
506 vecmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
507 vecmul_methods[1].fn_get_main_packed_rhs_offset =
508 kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
509 vecmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
510 vecmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme;
511 vecmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
512 vecmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
513 vecmul_methods[1].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla;
514
515 vecmul_methods[2].name = "matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla";
516 vecmul_methods[2].m0 = 1;
517 vecmul_methods[2].n0 = 8 * get_sme_vector_length<float>();
518 vecmul_methods[2].dst_format = DataFormat(DataType::FP32);
519 vecmul_methods[2].lhs_format = DataFormat(DataType::FP32);
520 vecmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN);
521 vecmul_methods[2].rhs_format = DataFormat(DataType::FP32);
522 vecmul_methods[2].packed_rhs_format = DataFormat(
523 DataType::FP32, // Output type
524 2 * get_sme_vector_length<float>(), 1, // Block size
525 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
526 DataType::FP32, // Bias format
527 DataType::UNKNOWN, // Scaling type
528 2 * get_sme_vector_length<float>(), 1); // Sub-block
529 vecmul_methods[2].bias_format = DataFormat(DataType::FP32);
530 39 vecmul_methods[2].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
531
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
532 };
533 39 vecmul_methods[2].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
534
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
535 };
536 39 vecmul_methods[2].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
537 KAI_UNUSED(null_bias_mode);
538
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
539 };
540 vecmul_methods[2].fn_is_supported = cpu_has_sme;
541 vecmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
542 vecmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
543 vecmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
544 vecmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
545 vecmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
546 vecmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
547 vecmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
548 vecmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
549 vecmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
550 vecmul_methods[2].fn_get_pack_rhs_packed_rhs_offset =
551 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
552 vecmul_methods[2].fn_get_main_packed_rhs_offset =
553 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
554 vecmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
555 vecmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
556 vecmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
557 vecmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
558 vecmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla;
559
560 vecmul_methods[3].name = "matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla";
561 vecmul_methods[3].m0 = 1;
562 vecmul_methods[3].n0 = 16 * get_sme_vector_length<float>();
563 vecmul_methods[3].dst_format = DataFormat(DataType::FP32);
564 vecmul_methods[3].lhs_format = DataFormat(DataType::FP32);
565 vecmul_methods[3].packed_lhs_format = DataFormat(DataType::UNKNOWN);
566 vecmul_methods[3].rhs_format = DataFormat(DataType::FP32);
567 vecmul_methods[3].packed_rhs_format = DataFormat(
568 DataType::FP32, // Output type
569 2 * get_sme_vector_length<float>(), 1, // Block size
570 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
571 DataType::FP32, // Bias format
572 DataType::UNKNOWN, // Scaling type
573 2 * get_sme_vector_length<float>(), 1); // Sub-block
574 vecmul_methods[3].bias_format = DataFormat(DataType::FP32);
575 39 vecmul_methods[3].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
576
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
577 };
578 39 vecmul_methods[3].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
579
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
580 };
581 39 vecmul_methods[3].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
582 KAI_UNUSED(null_bias_mode);
583
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
584 };
585 vecmul_methods[3].fn_is_supported = cpu_has_sme2;
586 vecmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
587 vecmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
588 vecmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
589 vecmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
590 vecmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
591 vecmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
592 vecmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
593 vecmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
594 vecmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
595 vecmul_methods[3].fn_get_pack_rhs_packed_rhs_offset =
596 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
597 vecmul_methods[3].fn_get_main_packed_rhs_offset =
598 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
599 vecmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
600 vecmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme;
601 vecmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
602 vecmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
603 vecmul_methods[3].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla;
604
605 vecmul_methods[4].name = "matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla";
606 vecmul_methods[4].m0 = 1;
607 vecmul_methods[4].n0 = 16 * get_sme_vector_length<float>();
608 vecmul_methods[4].dst_format = DataFormat(DataType::FP32);
609 vecmul_methods[4].lhs_format = DataFormat(DataType::FP32);
610 vecmul_methods[4].packed_lhs_format = DataFormat(DataType::UNKNOWN);
611 vecmul_methods[4].rhs_format = DataFormat(DataType::FP32);
612 vecmul_methods[4].packed_rhs_format = DataFormat(
613 DataType::FP32, // Output type
614 16 * get_sme_vector_length<float>(), 1, // Block size
615 DataFormat::PackFormat::BIAS_PER_ROW, // Data layout
616 DataType::FP32, // Bias format
617 DataType::UNKNOWN, // Scaling type
618 16 * get_sme_vector_length<float>(), 1); // Sub-block
619 vecmul_methods[4].bias_format = DataFormat(DataType::FP32);
620 39 vecmul_methods[4].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
621
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
622 };
623 39 vecmul_methods[4].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
624
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
625 };
626 39 vecmul_methods[4].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
627 KAI_UNUSED(null_bias_mode);
628
1/2
✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
39 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
629 };
630 vecmul_methods[4].fn_is_supported = cpu_has_sme2;
631 vecmul_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
632 vecmul_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
633 vecmul_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
634 vecmul_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
635 vecmul_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
636 vecmul_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
637 vecmul_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
638 vecmul_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
639 vecmul_methods[4].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
640 vecmul_methods[4].fn_get_pack_rhs_packed_rhs_offset =
641 kai_get_rhs_packed_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
642 vecmul_methods[4].fn_get_main_packed_rhs_offset =
643 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
644 vecmul_methods[4].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
645 vecmul_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme;
646 vecmul_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
647 vecmul_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
648 vecmul_methods[4].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla;
649
650 return vecmul_methods;
651 }
652
653 12 static const auto& get_nullbias_matmul_methods() {
654 // List of supported vector by matrix multiplication methods
655
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
12 static std::array<MatMulMethod, 4> nullbias_matmul_methods{};
656
657 nullbias_matmul_methods[0].name = "matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla";
658 nullbias_matmul_methods[0].m0 = 6;
659 nullbias_matmul_methods[0].n0 = 16;
660 nullbias_matmul_methods[0].dst_format = DataFormat(DataType::FP32);
661 nullbias_matmul_methods[0].lhs_format = DataFormat(DataType::FP32);
662 nullbias_matmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN);
663 nullbias_matmul_methods[0].rhs_format = DataFormat(DataType::FP32);
664 nullbias_matmul_methods[0].packed_rhs_format = DataFormat(
665 DataType::FP32, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 16, 1);
666 nullbias_matmul_methods[0].bias_format = DataFormat(DataType::FP32);
667 72 nullbias_matmul_methods[0].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
668
1/2
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
72 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
669 };
670 72 nullbias_matmul_methods[0].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
671
1/2
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
72 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
672 };
673 72 nullbias_matmul_methods[0].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
674
2/2
✓ Branch 0 taken 36 times.
✓ Branch 1 taken 36 times.
72 if (null_bias_mode) {
675
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return ConstantGenerator<float>(0.0)(rows, cols);
676 } else {
677
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
678 }
679 72 };
680 nullbias_matmul_methods[0].fn_is_supported = cpu_has_advsimd;
681 nullbias_matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
682 nullbias_matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
683 nullbias_matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
684 nullbias_matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
685 nullbias_matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
686 nullbias_matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
687 nullbias_matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
688 nullbias_matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
689 nullbias_matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
690 nullbias_matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset =
691 kai_get_rhs_packed_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
692 nullbias_matmul_methods[0].fn_get_main_packed_rhs_offset =
693 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
694 nullbias_matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
695 nullbias_matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
696 nullbias_matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
697 nullbias_matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
698 nullbias_matmul_methods[0].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla;
699
700 nullbias_matmul_methods[1].name = "matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55";
701 nullbias_matmul_methods[1].m0 = 6;
702 nullbias_matmul_methods[1].n0 = 16;
703 nullbias_matmul_methods[1].dst_format = DataFormat(DataType::FP32);
704 nullbias_matmul_methods[1].lhs_format = DataFormat(DataType::FP32);
705 nullbias_matmul_methods[1].packed_lhs_format = DataFormat(DataType::UNKNOWN);
706 nullbias_matmul_methods[1].rhs_format = DataFormat(DataType::FP32);
707 nullbias_matmul_methods[1].packed_rhs_format = DataFormat(
708 DataType::FP32, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 16, 1);
709 nullbias_matmul_methods[1].bias_format = DataFormat(DataType::FP32);
710 72 nullbias_matmul_methods[1].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
711
1/2
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
72 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
712 };
713 72 nullbias_matmul_methods[1].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
714
1/2
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
72 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
715 };
716 72 nullbias_matmul_methods[1].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
717
2/2
✓ Branch 0 taken 36 times.
✓ Branch 1 taken 36 times.
72 if (null_bias_mode) {
718
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return ConstantGenerator<float>(0.0)(rows, cols);
719 } else {
720
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols);
721 }
722 72 };
723 nullbias_matmul_methods[1].fn_is_supported = cpu_has_advsimd;
724 nullbias_matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
725 nullbias_matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
726 nullbias_matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
727 nullbias_matmul_methods[1].fn_get_main_m_step =
728 kai_get_m_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
729 nullbias_matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
730 nullbias_matmul_methods[1].fn_get_main_n_step =
731 kai_get_n_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
732 nullbias_matmul_methods[1].fn_get_lhs_offset =
733 kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
734 nullbias_matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
735 nullbias_matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
736 nullbias_matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset =
737 kai_get_rhs_packed_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
738 nullbias_matmul_methods[1].fn_get_main_packed_rhs_offset =
739 kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
740 nullbias_matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
741 nullbias_matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon;
742 nullbias_matmul_methods[1].fn_get_dst_offset =
743 kai_get_dst_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
744 nullbias_matmul_methods[1].fn_get_dst_size =
745 kai_get_dst_size_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
746 nullbias_matmul_methods[1].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55;
747
748 nullbias_matmul_methods[2].name = "matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla";
749 nullbias_matmul_methods[2].m0 = 6;
750 nullbias_matmul_methods[2].n0 = 32;
751 nullbias_matmul_methods[2].dst_format = DataFormat(DataType::FP16);
752 nullbias_matmul_methods[2].lhs_format = DataFormat(DataType::FP16);
753 nullbias_matmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN);
754 nullbias_matmul_methods[2].rhs_format = DataFormat(DataType::FP16);
755 nullbias_matmul_methods[2].packed_rhs_format = DataFormat(
756 DataType::FP16, 32, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 32, 1);
757 nullbias_matmul_methods[2].bias_format = DataFormat(DataType::FP16);
758 72 nullbias_matmul_methods[2].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
759
1/2
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
72 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
760 };
761 72 nullbias_matmul_methods[2].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
762
1/2
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
72 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
763 };
764 72 nullbias_matmul_methods[2].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
765
2/2
✓ Branch 0 taken 36 times.
✓ Branch 1 taken 36 times.
72 if (null_bias_mode) {
766
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return ConstantGenerator<Float16>(0.0)(rows, cols);
767 } else {
768
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
769 }
770 72 };
771 nullbias_matmul_methods[2].fn_is_supported = cpu_has_fp16;
772 nullbias_matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
773 nullbias_matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
774 nullbias_matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
775 nullbias_matmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
776 nullbias_matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
777 nullbias_matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
778 nullbias_matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
779 nullbias_matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
780 nullbias_matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
781 nullbias_matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset =
782 kai_get_rhs_packed_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
783 nullbias_matmul_methods[2].fn_get_main_packed_rhs_offset =
784 kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
785 nullbias_matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
786 nullbias_matmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
787 nullbias_matmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
788 nullbias_matmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
789 nullbias_matmul_methods[2].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla;
790
791 nullbias_matmul_methods[3].name = "matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55";
792 nullbias_matmul_methods[3].m0 = 6;
793 nullbias_matmul_methods[3].n0 = 32;
794 nullbias_matmul_methods[3].dst_format = DataFormat(DataType::FP16);
795 nullbias_matmul_methods[3].lhs_format = DataFormat(DataType::FP16);
796 nullbias_matmul_methods[3].packed_lhs_format = DataFormat(DataType::UNKNOWN);
797 nullbias_matmul_methods[3].rhs_format = DataFormat(DataType::FP16);
798 nullbias_matmul_methods[3].packed_rhs_format = DataFormat(
799 DataType::FP16, 32, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 32, 1);
800 nullbias_matmul_methods[3].bias_format = DataFormat(DataType::FP16);
801 72 nullbias_matmul_methods[3].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) {
802
1/2
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
72 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
803 };
804 72 nullbias_matmul_methods[3].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) {
805
1/2
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
72 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
806 };
807 72 nullbias_matmul_methods[3].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) {
808
2/2
✓ Branch 0 taken 36 times.
✓ Branch 1 taken 36 times.
72 if (null_bias_mode) {
809
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return ConstantGenerator<Float16>(0.0)(rows, cols);
810 } else {
811
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols);
812 }
813 72 };
814 nullbias_matmul_methods[3].fn_is_supported = cpu_has_fp16;
815 nullbias_matmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
816 nullbias_matmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
817 nullbias_matmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
818 nullbias_matmul_methods[3].fn_get_main_m_step =
819 kai_get_m_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
820 nullbias_matmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
821 nullbias_matmul_methods[3].fn_get_main_n_step =
822 kai_get_n_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
823 nullbias_matmul_methods[3].fn_get_lhs_offset =
824 kai_get_lhs_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
825 nullbias_matmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
826 nullbias_matmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
827 nullbias_matmul_methods[3].fn_get_pack_rhs_packed_rhs_offset =
828 kai_get_rhs_packed_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
829 nullbias_matmul_methods[3].fn_get_main_packed_rhs_offset =
830 kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
831 nullbias_matmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
832 nullbias_matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon;
833 nullbias_matmul_methods[3].fn_get_dst_offset =
834 kai_get_dst_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
835 nullbias_matmul_methods[3].fn_get_dst_size =
836 kai_get_dst_size_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
837 nullbias_matmul_methods[3].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55;
838
839 return nullbias_matmul_methods;
840 }
841
842 using MatMulClampTestParams = std::tuple<MatMulMethod, MatMulShape, MatrixPortion, BiasMode, float>;
843
844 /// Matrix multiplication test fixture.
845 class MatMulTest : public testing::TestWithParam<MatMulClampTestParams> {
846 private:
847 /// Unique ID: m, n, k, method_id.
848 using TestDataId = std::tuple<size_t, size_t, size_t, std::string_view, BiasMode, float>;
849
850 protected:
851 /// Cached test data that is shared between multiple test case.
852 663 struct TestData {
853 1326 Buffer lhs{}; ///< LHS operand.
854 1326 Buffer ref_packed_lhs{}; ///< Reference packed LHS.
855 1326 Buffer rhs{}; ///< RHS operand.
856 1326 Buffer rhs_scales{}; ///< RHS per-row quantization scales.
857 1326 Buffer bias{}; ///< Bias.
858 1326 Buffer rhs_t{}; ///< Transposed RHS matrix.
859 1326 Buffer ref_packed_rhs{}; ///< Reference packed RHS.
860 1326 Buffer ref_dst{}; ///< Reference output.
861 663 float clamp_min{}; ///< Minimum output value.
862 663 float clamp_max{}; ///< Maximum output value.
863 };
864
865 /// Gets the test data for the current test case.
866 4800 static const TestData& test_data() {
867 69417 const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam();
868 28800 const TestDataId data_id{info.m, info.n, info.k, method.name, bias_mode, clamp_keep_ratio};
869
870 // Creates a unique seed for the test data.
871
12/24
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4800 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4800 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4800 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4800 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4800 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4800 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4800 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4800 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 4800 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 4800 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1296 times.
✗ Branch 23 not taken.
19200 const auto key = std::string(method.name) + "_" + std::to_string(info.m) + "x" + std::to_string(info.n) + "x" +
872
8/14
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4800 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4800 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3504 times.
✓ Branch 7 taken 1296 times.
✓ Branch 8 taken 4800 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4800 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4800 times.
✗ Branch 13 not taken.
19200 std::to_string(info.k) + "_" + (bias_mode == BiasMode::INTERNAL ? "internal" : "provided") + "_" +
873
2/4
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4800 times.
✗ Branch 3 not taken.
9600 std::to_string(clamp_keep_ratio);
874
1/2
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
4800 auto& feed = seed_stream(key);
875
876 // If the test data is already available, returns it.
877
1/2
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
4800 const auto data_it = _data.find(data_id);
878
879
4/4
✓ Branch 0 taken 4584 times.
✓ Branch 1 taken 216 times.
✓ Branch 2 taken 3057 times.
✓ Branch 3 taken 447 times.
4800 if (data_it != _data.end()) {
880
1/2
✓ Branch 0 taken 3057 times.
✗ Branch 1 not taken.
4137 return data_it->second;
881 }
882
883 // Generates the test data.
884
2/4
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
1326 const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN;
885
2/4
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
1326 const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN;
886
2/4
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
1326 const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN;
887 1326 const bool null_bias_mode = bias_mode == BiasMode::INTERNAL;
888
889 1326 const auto lhs_h = info.m;
890 1326 const auto lhs_w = info.k;
891
2/4
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
1326 auto lhs = method.fn_generate_lhs(lhs_h, lhs_w, feed);
892
893 663 Buffer ref_packed_lhs;
894
2/2
✓ Branch 0 taken 591 times.
✓ Branch 1 taken 72 times.
663 if (has_lhs_pack) {
895 72 ref_packed_lhs =
896
3/6
✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 72 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 72 times.
✗ Branch 5 not taken.
144 pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w);
897 72 }
898
899 1326 const auto rhs_h = info.k;
900 1326 const auto rhs_w = info.n;
901
2/4
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
1326 auto rhs = method.fn_generate_rhs(rhs_h, rhs_w, feed);
902
903 KAI_ASSUME_ALWAYS(method.rhs_format.is_raw());
904
3/6
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
663 auto rhs_t = transpose(rhs.data(), method.rhs_format.data_type(), rhs_h, rhs_w);
905
906 663 Buffer rhs_scales;
907
3/8
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 663 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
663 if (data_type_is_quantized(method.rhs_format.data_type()) &&
908 method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) {
909 rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), feed());
910 }
911
912 663 const auto bias_h = 1;
913 1326 const auto bias_w = info.n;
914 663 Buffer bias;
915
916
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 663 times.
663 if (has_bias) {
917
2/4
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
1326 bias = method.fn_generate_bias(bias_h, bias_w, feed, null_bias_mode);
918 663 }
919
920 663 Buffer packed_rhs;
921
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 663 times.
663 if (has_rhs_pack) {
922
1/2
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
1326 packed_rhs = matmul_pack_rhs(
923
3/6
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
663 rhs.data(), rhs_scales.data(), bias.data(), method.rhs_format, method.packed_rhs_format, info.n, info.k,
924 true);
925 663 }
926
927 KAI_ASSUME_ALWAYS(method.lhs_format.is_raw());
928 KAI_ASSUME_ALWAYS(method.rhs_format.is_raw());
929 KAI_ASSUME_ALWAYS(method.dst_format.is_raw());
930
1/2
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
1326 auto ref_dst = matmul(
931
2/4
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
663 lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), //
932
3/6
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
663 rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), //
933
2/4
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
663 bias.data(), nullptr, nullptr, method.bias_format.data_type(), //
934
1/2
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
663 method.dst_format.data_type(), //
935 1989 info.m, info.n, info.k, false, false);
936
937 3978 const auto [clamp_min, clamp_max] =
938
5/10
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 663 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 663 times.
✗ Branch 9 not taken.
663 find_clamp_range(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_keep_ratio);
939
7/14
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 663 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 663 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 663 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 663 times.
✗ Branch 13 not taken.
1326 ref_dst = clamp(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_min, clamp_max);
940
941
9/18
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 663 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 663 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 663 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 663 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 663 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 663 times.
✗ Branch 17 not taken.
5967 auto& data = _data[data_id] = {};
942 663 data.lhs = std::move(lhs);
943 663 data.ref_packed_lhs = std::move(ref_packed_lhs);
944 663 data.rhs = std::move(rhs);
945 663 data.rhs_scales = std::move(rhs_scales);
946 663 data.bias = std::move(bias);
947 663 data.rhs_t = std::move(rhs_t);
948 663 data.ref_packed_rhs = std::move(packed_rhs);
949 663 data.ref_dst = std::move(ref_dst);
950 663 data.clamp_min = clamp_min;
951 663 data.clamp_max = clamp_max;
952
953 663 return data;
954 4800 }
955
956 private:
957 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
958 static std::map<TestDataId, TestData> _data;
959 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
960 };
961
962 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
963 3 std::map<MatMulTest::TestDataId, MatMulTest::TestData> MatMulTest::_data;
964 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
965
966 /// Tests the LHS packing micro-kernel.
967
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
8064 TEST_P(MatMulTest, PackedLhs) {
968 6474 const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam();
969
970
3/4
✓ Branch 0 taken 3234 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✓ Branch 3 taken 1050 times.
3234 if (method.fn_is_supported && !method.fn_is_supported()) {
971
3/6
✓ Branch 0 taken 1050 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1050 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1050 times.
✗ Branch 5 not taken.
1050 GTEST_SKIP() << "Unsupported CPU feature";
972 }
973
974
2/2
✓ Branch 0 taken 216 times.
✓ Branch 1 taken 1968 times.
2184 if (!method.is_pack_lhs_needed()) {
975
3/6
✓ Branch 0 taken 1968 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1968 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1968 times.
✗ Branch 5 not taken.
1968 GTEST_SKIP() << "Test not valid w/o LHS pack";
976 }
977
978 216 const auto& data = test_data();
979 432 const auto lhs_h = info.m;
980 432 const auto lhs_w = info.k;
981
982 432 const auto rect = portion.compute_portion(
983 432 lhs_h, lhs_w, method.packed_lhs_format.scheduler_block_height(lhs_h),
984 216 lhs_w); // LHS packing micro-kernel API doesn't support scheduling over K dimension.
985
986
2/4
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 216 times.
216 if (rect.height() == 0 || rect.width() == 0) {
987 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
988 }
989
990 432 const auto mr = method.fn_get_mr();
991 432 const auto kr = method.fn_get_kr();
992 432 const auto sr = method.fn_get_sr();
993 432 const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(lhs_w);
994
995 864 const auto packed_lhs_size = method.fn_get_packed_lhs_size(info.m, info.k, mr, kr, sr);
996 432 const auto ref_packed_lhs_size = method.packed_lhs_format.default_size_in_bytes(lhs_h, lhs_w);
997
3/14
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
216 ASSERT_EQ(packed_lhs_size, ref_packed_lhs_size);
998
999 432 const auto lhs_offset = method.fn_get_lhs_offset(rect.start_row(), ref_lhs_row_stride);
1000 432 const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), lhs_w);
1001
3/14
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
216 ASSERT_EQ(lhs_offset, ref_lhs_offset);
1002
1003 648 const auto packed_lhs_offset = method.fn_get_packed_lhs_offset(rect.start_row(), info.k);
1004 432 const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w);
1005
3/14
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
216 ASSERT_EQ(packed_lhs_offset, ref_packed_lhs_offset);
1006
1007 216 Buffer packed_lhs(packed_lhs_size, 0);
1008
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 abi_check(
1009
3/6
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
216 method.fn_pack_lhs, rect.height(), rect.width(), mr, kr, sr, 0, data.lhs.data() + lhs_offset,
1010
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 ref_lhs_row_stride, packed_lhs.data() + packed_lhs_offset);
1011
1012
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 DefaultMismatchHandler handler(0, 0.0001, 0, 0.001);
1013 432 const auto success =
1014
3/6
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
216 compare(packed_lhs.data(), data.ref_packed_lhs.data(), method.packed_lhs_format, lhs_h, lhs_w, rect, handler);
1015
4/16
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 216 times.
216 ASSERT_TRUE(success);
1016 3234 }
1017
1018 /// Tests the RHS packing micro-kernel.
1019
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
8064 TEST_P(MatMulTest, PackedRhs) {
1020 31626 const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam();
1021
1022
3/4
✓ Branch 0 taken 3234 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✓ Branch 3 taken 1050 times.
3234 if (method.fn_is_supported && !method.fn_is_supported()) {
1023
3/6
✓ Branch 0 taken 1050 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1050 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1050 times.
✗ Branch 5 not taken.
1050 GTEST_SKIP() << "Unsupported CPU feature";
1024 }
1025
1026
1/2
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
2184 if (!method.is_pack_rhs_needed()) {
1027 GTEST_SKIP() << "Test not valid w/o RHS pack";
1028 }
1029
1030 2184 const auto& data = test_data();
1031 4368 const auto rhs_full_width = info.n;
1032 4368 const auto rhs_full_height = info.k;
1033
1034 4368 const auto block_height = method.packed_rhs_format.scheduler_block_height(rhs_full_width);
1035 4368 const auto block_width = method.packed_rhs_format.scheduler_block_width(rhs_full_height);
1036
1037 4368 const auto null_bias_mode = bias_mode == BiasMode::INTERNAL;
1038
1039 4368 const Rect rect = portion.compute_portion(rhs_full_width, rhs_full_height, block_height, block_width);
1040
1041
2/4
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2184 times.
2184 if (rect.height() == 0 || rect.width() == 0) {
1042 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
1043 }
1044
1045 2184 const auto rhs_start_row = rect.start_row();
1046 2184 const auto rhs_start_col = rect.start_col();
1047 2184 const auto width = rect.width();
1048 2184 const auto height = rect.height();
1049 4368 const auto rhs_row_stride = method.rhs_format.default_row_stride(rhs_full_width);
1050
1051 /** Ensure that all relevant parameters are sane **/
1052 4368 const auto n_step = method.fn_get_pack_rhs_n_step();
1053 2184 const auto ref_n_step = block_height;
1054
3/14
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
2184 ASSERT_EQ(n_step, ref_n_step);
1055
1056 4368 const auto rhs_offset = method.fn_get_rhs_offset(rhs_start_row);
1057 4368 const auto ref_rhs_offset =
1058 2184 method.rhs_format.default_offset_in_bytes(rhs_start_col, rhs_start_row, rhs_full_height);
1059
3/14
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
2184 ASSERT_EQ(rhs_offset, ref_rhs_offset);
1060
1061 4368 const auto packed_rhs_size = method.fn_get_packed_rhs_size(rhs_full_width, rhs_full_height);
1062 4368 const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(rhs_full_width, rhs_full_height);
1063
3/14
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
2184 ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size);
1064
1065 4368 const auto packed_rhs_offset = method.fn_get_pack_rhs_packed_rhs_offset(rhs_start_row, rhs_full_height);
1066 4368 const auto ref_packed_rhs_offset =
1067 2184 method.packed_rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_full_height);
1068
3/14
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
2184 ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset);
1069
1070 4368 const auto scale_type = method.packed_rhs_format.scale_data_type();
1071 2184 const auto ref_rhs_scales_offset = rhs_start_row * data_type_size_in_bits(scale_type) / 8;
1072
1073 4368 const auto bias_offset = method.fn_get_bias_offset(rhs_start_row);
1074 4368 const auto ref_bias_offset =
1075
2/2
✓ Branch 0 taken 432 times.
✓ Branch 1 taken 1752 times.
2184 !null_bias_mode ? method.bias_format.default_offset_in_bytes(0, rhs_start_row, rhs_full_height) : bias_offset;
1076
3/14
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
2184 ASSERT_EQ(bias_offset, ref_bias_offset);
1077
1078 /** Perform RHS packing, and compare with reference result **/
1079 2184 Buffer packed_rhs(packed_rhs_size, 0);
1080
1/2
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
2184 abi_check(
1081
2/4
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
4368 &MatMulMethod::pack_rhs, method, height, width, data.rhs.data() + rhs_offset, rhs_row_stride,
1082
3/4
✓ Branch 0 taken 432 times.
✓ Branch 1 taken 1752 times.
✓ Branch 2 taken 1752 times.
✗ Branch 3 not taken.
2184 !null_bias_mode ? data.bias.data() + bias_offset : nullptr,
1083
2/6
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2184 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
2184 data.rhs_scales.data() != nullptr ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr,
1084
1/2
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
2184 packed_rhs.data() + packed_rhs_offset);
1085
1086
2/4
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
4368 const bool exact = method.packed_rhs_format.pack_format() != DataFormat::PackFormat::QUANTIZE_PER_ROW;
1087
1/2
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
2184 DefaultMismatchHandler handler(0, exact ? 0 : 0.0001, 0, exact ? 0 : 0.001);
1088
1/2
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
4368 const auto success = compare(
1089
2/4
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
2184 packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, rhs_full_width, rhs_full_height, rect,
1090 handler);
1091
4/16
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2184 times.
2184 ASSERT_TRUE(success);
1092 3234 }
1093
1094 /// Tests the transposed RHS packing micro-kernel.
1095
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
8064 TEST_P(MatMulTest, PackedTransposedRhs) {
1096 9282 const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam();
1097
1098
3/4
✓ Branch 0 taken 3234 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✓ Branch 3 taken 1050 times.
3234 if (method.fn_is_supported && !method.fn_is_supported()) {
1099
3/6
✓ Branch 0 taken 1050 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1050 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1050 times.
✗ Branch 5 not taken.
1050 GTEST_SKIP() << "Unsupported CPU feature";
1100 }
1101
1102
2/2
✓ Branch 0 taken 216 times.
✓ Branch 1 taken 1968 times.
2184 if (!method.is_pack_rhs_nxk_needed()) {
1103
3/6
✓ Branch 0 taken 1968 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1968 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1968 times.
✗ Branch 5 not taken.
1968 GTEST_SKIP() << "Test not valid w/o pre-processing of transposed RHS matrix";
1104 }
1105
1106 216 const auto& data = test_data();
1107 432 const bool null_bias_mode = bias_mode == BiasMode::INTERNAL;
1108
1109 432 const auto n_step = method.fn_pack_rhs_nxk_get_n_step();
1110 648 const auto ref_n_step = method.packed_rhs_format.scheduler_block_height(info.n);
1111
3/14
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
216 ASSERT_EQ(n_step, ref_n_step);
1112
1113 432 const auto rect = portion.compute_portion(
1114 864 info.n, info.k, method.packed_rhs_format.scheduler_block_height(info.n),
1115 432 method.packed_rhs_format.scheduler_block_width(info.k));
1116
1117
2/4
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 216 times.
216 if (rect.height() == 0 || rect.width() == 0) {
1118 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
1119 }
1120
1121 648 const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(info.k);
1122
1123 432 const auto rhs_offset = method.fn_pack_rhs_nxk_get_rhs_offset(rect.start_row(), ref_rhs_row_stride);
1124 648 const auto ref_rhs_offset = method.rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.k);
1125
3/14
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
216 ASSERT_EQ(rhs_offset, ref_rhs_offset);
1126
1127 864 const auto packed_rhs_size = method.fn_pack_rhs_nxk_get_packed_rhs_size(info.n, info.k);
1128 864 const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(info.n, info.k);
1129
3/14
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
216 ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size);
1130
1131 648 const auto packed_rhs_offset = method.fn_pack_rhs_nxk_get_packed_rhs_offset(rect.start_row(), info.k);
1132 432 const auto ref_packed_rhs_offset =
1133 432 method.packed_rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.k);
1134
3/14
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
216 ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset);
1135
1136 432 const auto ref_rhs_scales_offset =
1137 432 rect.start_row() * data_type_size_in_bits(method.packed_rhs_format.scale_data_type()) / 8;
1138
1139 432 const auto bias_offset = method.fn_get_bias_offset(rect.start_row());
1140 432 const auto ref_bias_offset =
1141
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 216 times.
216 !null_bias_mode ? method.bias_format.default_offset_in_bytes(0, rect.start_row(), info.n) : bias_offset;
1142
3/14
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
216 ASSERT_EQ(bias_offset, ref_bias_offset);
1143
1144 216 Buffer packed_rhs(packed_rhs_size, 0);
1145
1146
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 abi_check(
1147
4/8
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
432 &MatMulMethod::pack_rhs_nxk, method, rect.height(), rect.width(), data.rhs_t.data() + rhs_offset,
1148
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 216 times.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
216 ref_rhs_row_stride, null_bias_mode ? nullptr : data.bias.data() + bias_offset,
1149
2/6
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 216 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
216 data.rhs_scales.data() != nullptr ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr,
1150
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 packed_rhs.data() + packed_rhs_offset);
1151
1152
2/4
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
432 const auto exact = method.packed_rhs_format.pack_format() != DataFormat::PackFormat::QUANTIZE_PER_ROW;
1153
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 DefaultMismatchHandler handler(0, exact ? 0 : 0.0001, 0, exact ? 0 : 0.001);
1154 432 const auto success =
1155
5/10
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 216 times.
✗ Branch 9 not taken.
216 compare(packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, info.n, info.k, rect, handler);
1156
4/16
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 216 times.
216 ASSERT_TRUE(success);
1157 3234 }
1158
1159 /// Tests the output.
1160
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
8064 TEST_P(MatMulTest, Output) {
1161 63066 const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam();
1162
1163
3/4
✓ Branch 0 taken 3234 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✓ Branch 3 taken 1050 times.
3234 if (method.fn_is_supported && !method.fn_is_supported()) {
1164
3/6
✓ Branch 0 taken 1050 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1050 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1050 times.
✗ Branch 5 not taken.
1050 GTEST_SKIP() << "Unsupported CPU feature";
1165 }
1166
1167
1/2
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
2184 if (!method.has_main_kernel()) {
1168 GTEST_SKIP() << "No main kernel available";
1169 }
1170
1171 2184 const auto& data = test_data();
1172 4368 const auto m_step = method.fn_get_main_m_step();
1173
4/16
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2184 times.
4368 ASSERT_EQ(m_step, method.m0);
1174
1175 4368 const auto n_step = method.fn_get_main_n_step();
1176
4/16
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2184 times.
4368 ASSERT_EQ(n_step, method.n0);
1177
1178 10920 const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0);
1179
1180
2/4
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2184 times.
2184 if (rect.height() == 0 || rect.width() == 0) {
1181 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
1182 }
1183
1184 4368 const auto lhs_w = info.k;
1185 4368 const auto rhs_w = info.n;
1186 4368 const auto bias_w = info.n;
1187 4368 const auto dst_w = info.n;
1188
1189 4368 const bool null_bias_mode = bias_mode == BiasMode::INTERNAL;
1190
1191 2184 const auto lhs_start_row = rect.start_row();
1192 2184 const auto lhs_start_col = 0;
1193 4368 const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w);
1194
1195 2184 const std::byte* lhs_data = nullptr;
1196 2184 uintptr_t lhs_offset = 0;
1197
1198
2/2
✓ Branch 0 taken 216 times.
✓ Branch 1 taken 1968 times.
2184 if (method.is_pack_lhs_needed()) {
1199 216 lhs_data = data.ref_packed_lhs.data();
1200
1201 432 const auto ref_packed_lhs_offset =
1202 432 method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k);
1203 432 lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k);
1204
3/14
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
216 ASSERT_EQ(lhs_offset, ref_packed_lhs_offset);
1205
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 216 times.
216 } else {
1206 1968 lhs_data = data.lhs.data();
1207
1208 1968 lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride);
1209 3936 const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, lhs_w);
1210
3/14
✓ Branch 0 taken 1968 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1968 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1968 times.
1968 ASSERT_EQ(lhs_offset, ref_lhs_offset);
1211 1968 }
1212
1213 4368 const auto rhs_stride = method.rhs_format.default_row_stride(rhs_w);
1214
1215 2184 const std::byte* rhs_data = nullptr;
1216 2184 uintptr_t rhs_offset = 0;
1217
1218
1/2
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
2184 if (method.is_pack_rhs_needed()) {
1219 2184 const auto packed_rhs_start_row = rect.start_col();
1220 2184 const auto packed_rhs_start_col = 0;
1221
1222 2184 rhs_data = data.ref_packed_rhs.data();
1223
1224 4368 rhs_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k);
1225 4368 const auto ref_rhs_offset =
1226 4368 method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k);
1227
3/14
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
2184 ASSERT_EQ(rhs_offset, ref_rhs_offset);
1228
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2184 times.
2184 } else {
1229 const auto rhs_start_row = 0;
1230 const auto rhs_start_col = rect.start_col();
1231
1232 rhs_data = data.rhs.data();
1233 rhs_offset = method.rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_w);
1234 }
1235
1236 2184 const std::byte* bias_data = nullptr;
1237
2/2
✓ Branch 0 taken 432 times.
✓ Branch 1 taken 1752 times.
2184 if (!null_bias_mode) {
1238 3504 const auto bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_row(), bias_w);
1239 1752 bias_data = data.bias.data() + bias_offset;
1240 1752 }
1241
1242 4368 const auto dst_stride = method.dst_format.default_row_stride(dst_w);
1243 4368 const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
1244 4368 const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w);
1245
3/14
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
2184 ASSERT_EQ(dst_offset, ref_dst_offset);
1246
1247 8736 const auto dst_size = method.fn_get_dst_size(info.m, info.n);
1248 8736 const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n);
1249
3/14
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
2184 ASSERT_EQ(dst_size, ref_dst_size);
1250
1251 2184 Buffer dst(dst_size, 0);
1252
1253
1/2
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
2184 abi_check(
1254
3/6
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
4368 &MatMulMethod::main_kernel, method, rect.height(), rect.width(), info.k, lhs_data + lhs_offset,
1255
1/2
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
2184 rhs_data + rhs_offset, bias_data, dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, data.clamp_min,
1256 2184 data.clamp_max);
1257
1258
1/2
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
2184 DefaultMismatchHandler handler(0, 0.1, 0, 0.1);
1259
5/10
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2184 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2184 times.
✗ Branch 9 not taken.
2184 const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler);
1260
4/16
✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2184 times.
2184 ASSERT_TRUE(success);
1261 3234 }
1262
1263
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
3 const std::vector<MatrixPortion> MatrixPortions = {
1264 {0, 0, 1, 1},
1265 {0, 0, 0.25, 0.25},
1266 {0.75, 0.75, 1, 1},
1267 };
1268
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
3 const std::vector<MatMulShape> MatMulShapes = {
1269 {1, 16, 16}, //
1270 {20, 1, 20}, //
1271 {6, 16, 32}, //
1272 {12, 32, 17}, //
1273 {13, 33, 23}, //
1274 {87, 93, 56}, //
1275 };
1276
1277
20/64
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✓ Branch 18 taken 8 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✓ Branch 20 taken 8 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 8 times.
✓ Branch 22 taken 1512 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 3024 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
4551 INSTANTIATE_TEST_SUITE_P(
1278 MatMul, MatMulTest,
1279 testing::Combine(
1280 testing::ValuesIn(get_matmul_methods()), //
1281 testing::ValuesIn(MatMulShapes), //
1282 testing::ValuesIn(MatrixPortions), //
1283 testing::Values(BiasMode::PROVIDED), //
1284 testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio
1285 testing::PrintToStringParamName());
1286
1287
20/64
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✓ Branch 18 taken 8 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✓ Branch 20 taken 8 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 8 times.
✓ Branch 22 taken 1728 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 3456 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
5199 INSTANTIATE_TEST_SUITE_P(
1288 NullBiasMatMul, MatMulTest,
1289 testing::Combine(
1290 testing::ValuesIn(get_nullbias_matmul_methods()), //
1291 testing::ValuesIn(MatMulShapes), //
1292 testing::ValuesIn(MatrixPortions), //
1293 testing::Values(BiasMode::INTERNAL, BiasMode::PROVIDED), //
1294 testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio
1295 testing::PrintToStringParamName());
1296
1297
28/96
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✓ Branch 18 taken 8 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✓ Branch 20 taken 8 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✓ Branch 22 taken 8 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 4 times.
✓ Branch 24 taken 8 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 4 times.
✓ Branch 26 taken 8 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✓ Branch 28 taken 8 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 8 times.
✓ Branch 30 taken 3120 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✓ Branch 32 taken 6240 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
9375 INSTANTIATE_TEST_SUITE_P(
1298 VecMul, MatMulTest,
1299 testing::Combine(
1300 testing::ValuesIn(get_vecmul_methods()),
1301 testing::Values(
1302 MatMulShape{1, 16, 16}, //
1303 MatMulShape{1, 1, 20}, //
1304 MatMulShape{1, 16, 32}, //
1305 MatMulShape{1, 32, 17}, //
1306 MatMulShape{1, 33, 23}, //
1307 MatMulShape{1, 1500, 20}, //
1308 MatMulShape{1, 93, 56}, //
1309 MatMulShape{1, 1, 1}, //
1310 MatMulShape{1, 16, 1}, //
1311 MatMulShape{1, 32, 64}, //
1312 MatMulShape{1, 7, 74}, //
1313 MatMulShape{1, 800, 64}, //
1314 MatMulShape{1, 512, 130} //
1315 ),
1316 testing::Values(
1317 MatrixPortion(0, 0, 1, 1), // Full row.
1318 MatrixPortion(0, 0, 1, 0.5), // First half
1319 MatrixPortion(0, .4, 1, 0.3), // mid row-section.
1320 MatrixPortion(0, 0.75, 1, .25) // right row section
1321 ),
1322 testing::Values(BiasMode::PROVIDED), //
1323 testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio
1324 testing::PrintToStringParamName());
1325
1326 } // namespace kai::test
1327