KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 97.0% 164 / 4 / 173
Functions: 100.0% 18 / 0 / 18
Branches: 40.0% 295 / 12 / 750

test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <functional>
13 #include <limits>
14 #include <map>
15 #include <string_view>
16 #include <tuple>
17 #include <utility>
18
19 #include "kai/kai_common.h"
20 #include "test/common/abi_checker.hpp"
21 #include "test/common/buffer.hpp"
22 #include "test/common/compare.hpp"
23 #include "test/common/cpu_info.hpp"
24 #include "test/common/data_format.hpp"
25 #include "test/common/data_type.hpp"
26 #include "test/common/matmul_test_common.hpp"
27 #include "test/common/matrix_portion.hpp"
28 #include "test/common/printer.hpp"
29 #include "test/common/seed.hpp"
30 #include "test/common/sme.hpp"
31 #include "test/reference/cast.hpp"
32 #include "test/reference/clamp.hpp"
33 #include "test/reference/fill.hpp"
34 #include "test/reference/matmul.hpp"
35 #include "test/reference/pack.hpp"
36
37 // matmul_clamp_f32_bf16p_bf16p
38 #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h"
39 #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h"
40 #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.h"
41 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.h"
42 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.h"
43 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.h"
44 #include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h"
45
46 // SME files here.
47 #include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
48 #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h"
49 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
50
51 namespace kai::test {
52
53 /// List of supported matrix multiplication methods.
54 namespace {
55
56 3 static const std::array<MatMulMethod, 5>& get_gemm_methods() {
57
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
3 static std::array<MatMulMethod, 5> gemm_methods{};
58 gemm_methods[0].name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa";
59 gemm_methods[0].m0 = 2 * get_sme_vector_length<float>();
60 gemm_methods[0].n0 = 2 * get_sme_vector_length<float>();
61 gemm_methods[0].k0 = 2;
62 gemm_methods[0].dst_format = DataFormat(DataType::FP32);
63 gemm_methods[0].lhs_format = DataFormat(DataType::FP32);
64 gemm_methods[0].packed_lhs_format = DataFormat(
65 DataType::BF16, 2 * get_sme_vector_length<float>(), 2, DataFormat::PackFormat::NONE, DataType::FP32,
66 DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 2);
67 gemm_methods[0].rhs_format = DataFormat(DataType::FP32);
68 gemm_methods[0].packed_rhs_format = DataFormat(
69 DataType::BF16, 2 * get_sme_vector_length<float>(), 2, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32,
70 DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 2);
71 gemm_methods[0].bias_format = DataFormat(DataType::FP32);
72 gemm_methods[0].fn_is_supported = cpu_has_sme2;
73 gemm_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
74 gemm_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
75 gemm_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
76 gemm_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
77 gemm_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
78 gemm_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme;
79 gemm_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
80 gemm_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme;
81 gemm_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme;
82 gemm_methods[0].fn_get_packed_lhs_offset =
83 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
84 gemm_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme;
85 gemm_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme;
86 gemm_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme;
87 gemm_methods[0].fn_get_main_packed_rhs_offset =
88 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
89 gemm_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme;
90 gemm_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme;
91 gemm_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
92 gemm_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
93 gemm_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
94
95 gemm_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla";
96 gemm_methods[1].m0 = 8;
97 gemm_methods[1].n0 = 12;
98 gemm_methods[1].k0 = 4;
99 gemm_methods[1].dst_format = DataFormat(DataType::FP32);
100 gemm_methods[1].lhs_format = DataFormat(DataType::FP32);
101 gemm_methods[1].packed_lhs_format =
102 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4);
103 gemm_methods[1].rhs_format = DataFormat(DataType::FP32);
104 gemm_methods[1].packed_rhs_format = DataFormat(
105 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
106 gemm_methods[1].bias_format = DataFormat(DataType::FP32);
107 gemm_methods[1].fn_is_supported = cpu_has_bf16;
108 gemm_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
109 gemm_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
110 gemm_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
111 gemm_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
112 gemm_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
113 gemm_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
114 gemm_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
115 gemm_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon;
116 gemm_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon;
117 gemm_methods[1].fn_get_packed_lhs_offset =
118 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
119 gemm_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon;
120 gemm_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
121 gemm_methods[1].fn_get_packed_rhs_size_generic_block_size =
122 kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
123 gemm_methods[1].fn_get_main_packed_rhs_offset =
124 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
125 gemm_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
126 gemm_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
127 gemm_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
128 gemm_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
129 gemm_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
130
131 gemm_methods[2].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output";
132 gemm_methods[2].m0 = 8;
133 gemm_methods[2].n0 = 12;
134 gemm_methods[2].k0 = 4;
135 gemm_methods[2].dst_format = DataFormat(DataType::FP32);
136 gemm_methods[2].lhs_format = DataFormat(DataType::FP16);
137 gemm_methods[2].packed_lhs_format =
138 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4);
139 gemm_methods[2].rhs_format = DataFormat(DataType::FP16);
140 gemm_methods[2].packed_rhs_format = DataFormat(
141 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
142 gemm_methods[2].bias_format = DataFormat(DataType::FP32);
143 gemm_methods[2].fn_is_supported = cpu_has_bf16;
144 gemm_methods[2].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
145 gemm_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
146 gemm_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
147 gemm_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
148 gemm_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
149 gemm_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
150 gemm_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
151 gemm_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon;
152 gemm_methods[2].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon;
153 gemm_methods[2].fn_get_packed_lhs_offset =
154 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
155 gemm_methods[2].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon;
156 gemm_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
157 gemm_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
158 gemm_methods[2].fn_get_main_packed_rhs_offset =
159 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
160 gemm_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
161 gemm_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
162 gemm_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
163 gemm_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
164 gemm_methods[2].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
165
166 gemm_methods[3].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output_opt_bias";
167 gemm_methods[3].m0 = 8;
168 gemm_methods[3].n0 = 12;
169 gemm_methods[3].k0 = 4;
170 gemm_methods[3].dst_format = DataFormat(DataType::FP32);
171 gemm_methods[3].lhs_format = DataFormat(DataType::FP16);
172 gemm_methods[3].packed_lhs_format =
173 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4);
174 gemm_methods[3].rhs_format = DataFormat(DataType::FP16);
175 gemm_methods[3].packed_rhs_format = DataFormat(
176 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
177 gemm_methods[3].bias_format = DataFormat(DataType::UNKNOWN);
178 gemm_methods[3].fn_is_supported = cpu_has_bf16;
179 gemm_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
180 gemm_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
181 gemm_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
182 gemm_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
183 gemm_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
184 gemm_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
185 gemm_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
186 gemm_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon;
187 gemm_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon;
188 gemm_methods[3].fn_get_packed_lhs_offset =
189 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
190 gemm_methods[3].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon;
191 gemm_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
192 gemm_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
193 gemm_methods[3].fn_get_main_packed_rhs_offset =
194 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
195 gemm_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
196 gemm_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
197 gemm_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
198 gemm_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
199 gemm_methods[3].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
200
201 gemm_methods[4].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias";
202 gemm_methods[4].m0 = 8;
203 gemm_methods[4].n0 = 12;
204 gemm_methods[4].k0 = 4;
205 gemm_methods[4].dst_format = DataFormat(DataType::FP32);
206 gemm_methods[4].lhs_format = DataFormat(DataType::FP32);
207 gemm_methods[4].packed_lhs_format =
208 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4);
209 gemm_methods[4].rhs_format = DataFormat(DataType::FP32);
210 gemm_methods[4].packed_rhs_format = DataFormat(
211 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
212 gemm_methods[4].bias_format = DataFormat(DataType::UNKNOWN);
213 gemm_methods[4].fn_is_supported = cpu_has_bf16;
214 gemm_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
215 gemm_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
216 gemm_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
217 gemm_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
218 gemm_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
219 gemm_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
220 gemm_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
221 gemm_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon;
222 gemm_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon;
223 gemm_methods[4].fn_get_packed_lhs_offset =
224 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
225 gemm_methods[4].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon;
226 gemm_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
227 gemm_methods[4].fn_get_packed_rhs_size_generic_block_size =
228 kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
229 gemm_methods[4].fn_get_main_packed_rhs_offset =
230 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
231 gemm_methods[4].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
232 gemm_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
233 gemm_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
234 gemm_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
235 gemm_methods[4].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
236
237 return gemm_methods;
238 }
239
240 3 static const std::array<MatMulMethod, 2>& get_gemv_methods() {
241
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
3 static std::array<MatMulMethod, 2> gemv_methods{};
242 gemv_methods[0].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot";
243 gemv_methods[0].m0 = 1;
244 gemv_methods[0].n0 = 12;
245 gemv_methods[0].k0 = 4;
246 gemv_methods[0].dst_format = DataFormat(DataType::FP32);
247 gemv_methods[0].lhs_format = DataFormat(DataType::FP32);
248 gemv_methods[0].packed_lhs_format =
249 DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4);
250 gemv_methods[0].rhs_format = DataFormat(DataType::FP32);
251 gemv_methods[0].packed_rhs_format = DataFormat(
252 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
253 gemv_methods[0].bias_format = DataFormat(DataType::FP32);
254 gemv_methods[0].fn_is_supported = cpu_has_bf16;
255 gemv_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
256 gemv_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
257 gemv_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
258 gemv_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
259 gemv_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
260 gemv_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
261 gemv_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
262 gemv_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon;
263 gemv_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon;
264 gemv_methods[0].fn_get_packed_lhs_offset =
265 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
266 gemv_methods[0].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon;
267 gemv_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
268 gemv_methods[0].fn_get_packed_rhs_size_generic_block_size =
269 kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
270 gemv_methods[0].fn_get_main_packed_rhs_offset =
271 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
272 gemv_methods[0].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
273 gemv_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
274 gemv_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
275 gemv_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
276 gemv_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
277
278 gemv_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias";
279 gemv_methods[1].m0 = 1;
280 gemv_methods[1].n0 = 12;
281 gemv_methods[1].k0 = 4;
282 gemv_methods[1].dst_format = DataFormat(DataType::FP32);
283 gemv_methods[1].lhs_format = DataFormat(DataType::FP32);
284 gemv_methods[1].packed_lhs_format =
285 DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4);
286 gemv_methods[1].rhs_format = DataFormat(DataType::FP32);
287 gemv_methods[1].packed_rhs_format = DataFormat(
288 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
289 gemv_methods[1].bias_format = DataFormat(DataType::UNKNOWN);
290 gemv_methods[1].fn_is_supported = cpu_has_bf16;
291 gemv_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
292 gemv_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
293 gemv_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
294 gemv_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
295 gemv_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
296 gemv_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
297 gemv_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
298 gemv_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon;
299 gemv_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon;
300 gemv_methods[1].fn_get_packed_lhs_offset =
301 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
302 gemv_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon;
303 gemv_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
304 gemv_methods[1].fn_get_packed_rhs_size_generic_block_size =
305 kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
306 gemv_methods[1].fn_get_main_packed_rhs_offset =
307 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
308 gemv_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
309 gemv_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
310 gemv_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
311 gemv_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
312 gemv_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
313
314 return gemv_methods;
315 }
316
317 } // namespace
318
319 /// Matrix multiplication test fixture.
320 class MatMulTestBf16 : public testing::TestWithParam<MatMulClampTestParams> {
321 private:
322 /// Unique ID: m, n, k
323 using TestDataId = std::tuple<size_t, size_t, size_t, float, std::string_view>;
324
325 protected:
326 /// Cached test data that is shared between multiple test case.
327 414 struct TestData {
328 828 Buffer lhs{}; ///< LHS operand.
329 828 Buffer ref_packed_lhs{}; ///< Reference packed LHS.
330 828 Buffer rhs{}; ///< RHS operand.
331 828 Buffer rhs_scales{}; ///< RHS per-row quantization scales.
332 828 Buffer bias{}; ///< Bias.
333 828 Buffer ref_packed_rhs{}; ///< Reference packed RHS.
334 828 Buffer ref_dst{}; ///< Reference output.
335 828 Buffer ref_clamped{}; ///< Reference clamped.
336 Range<float> clamp_range; ///< Clamping Range.
337 };
338
339 /// Gets the test data for the current test case.
340 1926 static const TestData& test_data() {
341 30960 const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam();
342 9630 const TestDataId data_id{info.m, info.n, info.k, clamp_keep_ratio, method.name};
343
344 // Creates a unique seed for the test data.
345
12/24
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1926 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1926 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1926 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1926 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1926 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1926 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1926 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 888 times.
✗ Branch 23 not taken.
7704 const auto key = std::string(method.name) + "_" + std::to_string(info.m) + "x" + std::to_string(info.n) + "x" +
346
8/14
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1038 times.
✓ Branch 7 taken 888 times.
✓ Branch 8 taken 1926 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1926 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1926 times.
✗ Branch 13 not taken.
7704 std::to_string(info.k) + "_" + (bias_mode == BiasMode::INTERNAL ? "internal" : "provided") + ":" +
347
2/4
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
3852 std::to_string(clamp_keep_ratio);
348
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 auto& feed = seed_stream(key);
349
350 // If the test data is already available, returns it.
351
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 const auto data_it = _data.find(data_id);
352
353
4/4
✓ Branch 0 taken 1734 times.
✓ Branch 1 taken 192 times.
✓ Branch 2 taken 816 times.
✓ Branch 3 taken 222 times.
1926 if (data_it != _data.end()) {
354
1/2
✓ Branch 0 taken 816 times.
✗ Branch 1 not taken.
1512 return data_it->second;
355 }
356
357 // Generates the test data.
358
2/4
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
828 const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN;
359
2/4
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
828 const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN;
360
2/4
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
828 const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN;
361
362 828 const auto lhs_h = info.m;
363 828 const auto lhs_w = info.k;
364
3/6
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
828 auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, feed());
365 414 Buffer ref_packed_lhs;
366
367
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 414 times.
414 if (has_lhs_pack) {
368 414 ref_packed_lhs =
369
3/6
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
828 pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w);
370 414 }
371
372 828 const auto rhs_h = info.k;
373 828 const auto rhs_w = info.n;
374
3/6
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
828 auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, feed());
375
376 414 Buffer rhs_scales;
377
3/8
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 414 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
414 if (data_type_is_quantized(method.rhs_format.data_type()) &&
378 method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) {
379 rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), feed());
380 }
381
382 414 const auto bias_h = 1;
383 828 const auto bias_w = info.n;
384 414 Buffer bias;
385
386
2/2
✓ Branch 0 taken 192 times.
✓ Branch 1 taken 222 times.
414 if (has_bias) {
387
3/6
✓ Branch 0 taken 222 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 222 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 222 times.
✗ Branch 5 not taken.
444 bias = fill_matrix_random(bias_h, bias_w, method.bias_format, feed());
388 222 }
389
390 414 constexpr size_t nr = 12;
391 414 constexpr size_t kr = 4;
392
393 414 size_t packed_rhs_size = 0;
394
395
2/2
✓ Branch 0 taken 150 times.
✓ Branch 1 taken 264 times.
414 if (method.fn_get_packed_rhs_size) {
396
1/2
✓ Branch 0 taken 150 times.
✗ Branch 1 not taken.
150 packed_rhs_size = method.fn_get_packed_rhs_size(rhs_w, rhs_h);
397
1/2
✓ Branch 0 taken 264 times.
✗ Branch 1 not taken.
414 } else if (method.fn_get_packed_rhs_size_generic_block_size) {
398
1/2
✓ Branch 0 taken 264 times.
✗ Branch 1 not taken.
264 packed_rhs_size = method.fn_get_packed_rhs_size_generic_block_size(rhs_w, rhs_h, nr, kr);
399 264 } else {
400 KAI_ERROR("No function to calculate Packed Rhs Matrix Size");
401 }
402
403
1/2
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
414 Buffer packed_rhs(packed_rhs_size);
404
405
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 414 times.
414 if (has_rhs_pack) {
406
2/4
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
828 const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w);
407
1/2
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
414 method.pack_rhs(
408
5/8
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 222 times.
✓ Branch 5 taken 192 times.
✓ Branch 6 taken 222 times.
✗ Branch 7 not taken.
828 info.n, info.k, rhs.data(), ref_rhs_row_stride, has_bias ? bias.data() : nullptr, nullptr,
409
1/2
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
414 packed_rhs.data());
410 414 }
411
412 KAI_ASSUME_ALWAYS(method.lhs_format.is_raw());
413 KAI_ASSUME_ALWAYS(method.rhs_format.is_raw());
414 KAI_ASSUME_ALWAYS(method.dst_format.is_raw());
415
416 414 Buffer tmp_lhs;
417 414 Buffer tmp_rhs;
418
1/2
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
414 const void* p_lhs_buff = lhs.data();
419
1/2
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
414 const void* p_rhs_buff = rhs.data();
420
421
5/8
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 120 times.
✓ Branch 3 taken 294 times.
✓ Branch 4 taken 120 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 120 times.
✗ Branch 7 not taken.
414 if (method.lhs_format.data_type() == DataType::FP32 || method.lhs_format.data_type() == DataType::FP16) {
422
3/6
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
828 tmp_lhs = cast(p_lhs_buff, method.lhs_format.data_type(), DataType::BF16, lhs_h, lhs_w);
423
1/2
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
414 p_lhs_buff = tmp_lhs.data();
424 414 }
425
5/8
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 120 times.
✓ Branch 3 taken 294 times.
✓ Branch 4 taken 120 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 120 times.
✗ Branch 7 not taken.
414 if (method.rhs_format.data_type() == DataType::FP32 || method.rhs_format.data_type() == DataType::FP16) {
426
3/6
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
828 tmp_rhs = cast(p_rhs_buff, method.rhs_format.data_type(), DataType::BF16, rhs_h, rhs_w);
427
1/2
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
414 p_rhs_buff = tmp_rhs.data();
428 414 }
429
430 414 auto ref_dst =
431
1/2
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
414 matmul_nt_nt_quantized<BFloat16<>, float, float, BFloat16<>, float, float, float, float, float, float>(
432 1656 info.m, info.n, info.k, p_lhs_buff, nullptr, nullptr, 1, info.k, p_rhs_buff, nullptr, nullptr, 1,
433
1/2
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
414 info.k, bias.data(), nullptr, nullptr, info.k);
434
435 1656 const auto [min, max] =
436
5/10
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 414 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 414 times.
✗ Branch 9 not taken.
414 find_clamp_range(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_keep_ratio);
437
438
5/10
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 414 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 414 times.
✗ Branch 9 not taken.
414 auto ref_clamped = clamp(DataType::FP32, ref_dst.data(), info.m * info.n, min, max);
439
440
9/18
✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 414 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 414 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 414 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 414 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 414 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 414 times.
✗ Branch 17 not taken.
3726 auto& data = _data[data_id] = {};
441 414 data.lhs = std::move(lhs);
442 414 data.ref_packed_lhs = std::move(ref_packed_lhs);
443 414 data.rhs = std::move(rhs);
444 414 data.rhs_scales = std::move(rhs_scales);
445 414 data.bias = std::move(bias);
446 414 data.ref_packed_rhs = std::move(packed_rhs);
447 414 data.ref_dst = std::move(ref_dst);
448 414 data.ref_clamped = std::move(ref_clamped);
449 1242 data.clamp_range = {min, max};
450
451 414 return data;
452 1926 }
453
454 private:
455 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
456 static std::map<TestDataId, TestData> _data;
457 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
458 };
459
460 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
461 3 std::map<MatMulTestBf16::TestDataId, MatMulTestBf16::TestData> MatMulTestBf16::_data;
462 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
463
464 /// Tests the output.
465
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
5196 TEST_P(MatMulTestBf16, Output) {
466 66672 const auto& [method, info, portion, clamp_keep_ratio, bias_mode] = GetParam();
467
468
3/4
✓ Branch 0 taken 2076 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✓ Branch 3 taken 150 times.
2076 if (method.fn_is_supported && !method.fn_is_supported()) {
469
3/6
✓ Branch 0 taken 150 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 150 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 150 times.
✗ Branch 5 not taken.
150 GTEST_SKIP() << "Unsupported CPU feature";
470 }
471
472
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 if (!method.has_main_kernel()) {
473 GTEST_SKIP() << "No main kernel available";
474 }
475
476 1926 const auto& data = test_data();
477 3852 const auto m_step = method.fn_get_main_m_step();
478
4/16
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
3852 ASSERT_TRUE(m_step % method.m0 == 0);
479
480 3852 const auto n_step = method.fn_get_main_n_step();
481
4/16
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
3852 ASSERT_TRUE(n_step % method.n0 == 0);
482
483 5778 const auto rect = portion.compute_portion(info.m, info.n, m_step, n_step);
484
485
2/4
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1926 times.
1926 if (rect.height() == 0 || rect.width() == 0) {
486 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
487 }
488
489 3852 const size_t lhs_w = info.k;
490 1926 const size_t rhs_w = rect.width();
491 3852 const size_t bias_w = info.n;
492 3852 const size_t dst_w = info.n;
493 1926 const bool has_bias = (data.bias.size() > 0);
494
495 1926 const auto lhs_start_row = rect.start_row();
496 3852 const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w);
497
498 11556 const size_t lhs_packed_size = method.fn_get_packed_lhs_size(info.m, info.k, method.m0, method.k0, 1 /* sr */);
499 1926 Buffer lhs_data(lhs_packed_size);
500
501
2/4
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
3852 uintptr_t lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride);
502
3/6
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
5778 uintptr_t lhs_packed_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k);
503
504 KAI_UNUSED(lhs_offset);
505
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 abi_check(
506
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 method.fn_pack_lhs, rect.height(), info.k, method.m0, method.k0, 1 /* sr */, 0 /* m_idx_start */,
507
2/4
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
1926 data.lhs.data() + lhs_offset, lhs_stride, lhs_data.data() + lhs_packed_offset);
508
509
3/6
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
5778 const auto rhs_stride = method.rhs_format.default_row_stride(info.n);
510
511 1926 size_t rhs_packed_size = 0;
512
513
2/2
✓ Branch 0 taken 1176 times.
✓ Branch 1 taken 750 times.
1926 if (method.fn_get_packed_rhs_size_generic_block_size) {
514
5/10
✓ Branch 0 taken 1176 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1176 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1176 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1176 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1176 times.
✗ Branch 9 not taken.
5880 rhs_packed_size = method.fn_get_packed_rhs_size_generic_block_size(info.n, info.k, method.n0, method.k0);
515
1/2
✓ Branch 0 taken 750 times.
✗ Branch 1 not taken.
1926 } else if (method.fn_get_packed_rhs_size) {
516
3/6
✓ Branch 0 taken 750 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 750 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 750 times.
✗ Branch 5 not taken.
2250 rhs_packed_size = method.fn_get_packed_rhs_size(info.n, info.k);
517 750 }
518
519
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 Buffer rhs_data(rhs_packed_size);
520
521
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 const auto packed_rhs_start_row = rect.start_col();
522 1926 const auto packed_rhs_start_col = 0;
523
524
3/6
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
3852 uintptr_t rhs_offset = method.fn_get_rhs_offset(rect.start_col());
525
3/6
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
5778 uintptr_t rhs_packed_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k);
526 1926 const auto ref_rhs_packed_offset =
527
2/4
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
3852 method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k);
528
529
4/16
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
1926 ASSERT_EQ(rhs_packed_offset, ref_rhs_packed_offset);
530
531
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 uintptr_t bias_offset = sizeof(float) * rect.start_col();
532
533
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 abi_check(
534 1926 method.fn_pack_rhs,
535 1926 1, // num_groups
536 5778 rhs_w, info.k, method.n0, method.k0,
537 1926 1, // sr
538
4/6
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1038 times.
✓ Branch 3 taken 888 times.
✓ Branch 4 taken 1038 times.
✗ Branch 5 not taken.
1926 rhs_stride, data.rhs.data() + rhs_offset, has_bias ? data.bias.data() + bias_offset : nullptr,
539 1926 nullptr, // Scale
540
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 rhs_data.data() + rhs_packed_offset, 0, nullptr);
541
542
2/2
✓ Branch 0 taken 888 times.
✓ Branch 1 taken 1038 times.
1926 if (has_bias) {
543
3/6
✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1038 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1038 times.
✗ Branch 5 not taken.
2076 const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w);
544
4/16
✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1038 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1038 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1038 times.
1038 ASSERT_EQ(ref_bias_offset, bias_offset);
545 1038 }
546
547
2/4
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
3852 const auto dst_stride = method.dst_format.default_row_stride(dst_w);
548
4/8
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
3852 const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
549
4/8
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
3852 const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w);
550
4/16
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
1926 ASSERT_EQ(dst_offset, ref_dst_offset);
551
552
4/8
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
7704 const auto dst_size = method.fn_get_dst_size(info.m, info.n);
553
4/8
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
7704 const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n);
554
4/16
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
1926 ASSERT_EQ(dst_size, ref_dst_size);
555
556
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 Buffer dst(dst_size);
557
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 abi_check(
558
5/10
✗ Branch 0 not taken.
✓ Branch 1 taken 1926 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1926 times.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1926 times.
✗ Branch 9 not taken.
3852 &MatMulMethod::main_kernel, method, rect.height(), rect.width(), info.k, lhs_data.data() + lhs_packed_offset,
559
2/4
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
1926 rhs_data.data() + rhs_packed_offset, nullptr, dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride,
560 1926 data.clamp_range.min, data.clamp_range.max);
561
562
1/2
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
1926 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
563
5/10
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1926 times.
✗ Branch 9 not taken.
1926 const auto success = compare(dst.data(), data.ref_clamped.data(), method.dst_format, info.m, info.n, rect, handler);
564
565
4/16
✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
1926 ASSERT_TRUE(success);
566 2076 }
567
568
30/104
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✓ Branch 28 taken 2 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 time.
✓ Branch 30 taken 2 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 2 times.
✓ Branch 32 taken 750 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✓ Branch 34 taken 1500 times.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
2256 INSTANTIATE_TEST_SUITE_P(
569 MatMulGemm, MatMulTestBf16,
570 testing::Combine(
571 testing::ValuesIn(get_gemm_methods()),
572 testing::Values(
573 MatMulShape{1, 1, 1}, // Smallest Possible Shape
574 MatMulShape{3, 7, 3}, // Smaller than block size
575 MatMulShape{12, 8, 4}, // Same block size
576 MatMulShape{1, 1, 1023}, // Long K
577 MatMulShape{1013, 1, 5}, // Long M
578 MatMulShape{2, 1013, 6}, // Long N
579 MatMulShape{13, 33, 23}, //
580 MatMulShape{93, 57, 89}, //
581 MatMulShape{256, 256, 256}, // Nice shapes
582 MatMulShape{257, 113, 373} // Prime numbers
583 ),
584 testing::Values(
585 MatrixPortion(0, 0, 1, 1), // Full matrix.
586 MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner.
587 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
588 MatrixPortion(0.75, 0, 1, 1), // Partial rows
589 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
590 ),
591 testing::Values(BiasMode::PROVIDED), //
592 testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio
593 testing::PrintToStringParamName());
594
595
28/96
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✓ Branch 28 taken 2 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
✓ Branch 30 taken 288 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✓ Branch 32 taken 576 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
870 INSTANTIATE_TEST_SUITE_P(
596 MatMulGemv, MatMulTestBf16,
597 testing::Combine(
598 testing::ValuesIn(get_gemv_methods()),
599 testing::Values(
600 MatMulShape{1, 1, 1}, // Smallest Possible Shape
601 MatMulShape{1, 1, 1023}, // Long K
602 MatMulShape{1, 1023, 1}, // Long N
603 MatMulShape{1, 1013, 1023}, // Large Rhs
604 MatMulShape{1, 37, 23}, //
605 MatMulShape{1, 57, 89}, //
606 MatMulShape{1, 36, 89}, //
607 MatMulShape{1, 98, 23}, //
608 MatMulShape{1, 64, 1024}, // Nice shapes - Long Rhs Rect
609 MatMulShape{1, 1024, 64}, // Nice shapes - Wide Rhs Rect
610 MatMulShape{1, 256, 256}, // Nice shapes - Square
611 MatMulShape{1, 113, 373} // Prime numbers
612 ),
613 testing::Values(
614 MatrixPortion(0, 0, 1, 1), // Full matrix.
615 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
616 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
617 MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle
618 ),
619 testing::Values(BiasMode::PROVIDED), //
620 testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio
621 testing::PrintToStringParamName());
622 } // namespace kai::test
623