KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 96.8% 150 4 159
Functions: 100.0% 16 0 16
Branches: 41.9% 223 20 552

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <functional>
13 #include <iosfwd>
14 #include <limits>
15 #include <map>
16 #include <string_view>
17 #include <tuple>
18 #include <utility>
19
20 #include "kai/kai_common.h"
21 #include "test/common/buffer.hpp"
22 #include "test/common/compare.hpp"
23 #include "test/common/cpu_info.hpp"
24 #include "test/common/data_format.hpp"
25 #include "test/common/data_type.hpp"
26 #include "test/common/matmul_test_common.hpp"
27 #include "test/common/matrix_portion.hpp"
28 #include "test/common/printer.hpp"
29 #include "test/common/sme.hpp"
30 #include "test/reference/cast.hpp"
31 #include "test/reference/fill.hpp"
32 #include "test/reference/matmul.hpp"
33 #include "test/reference/pack.hpp"
34
35 // matmul_clamp_f32_bf16p_bf16p
36 #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h"
37 #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h"
38 #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.h"
39 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.h"
40 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.h"
41 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.h"
42 #include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h"
43
44 // SME files here.
45 #include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
46 #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h"
47 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
48
49 namespace kai::test {
50
51 /// List of supported matrix multiplication methods.
52 namespace {
53
54 1 static const std::array<MatMulMethod, 5>& get_gemm_methods() {
55
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
1 static std::array<MatMulMethod, 5> gemm_methods{};
56 gemm_methods[0].name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa";
57 gemm_methods[0].m0 = 2 * get_sme_vector_length<float>();
58 gemm_methods[0].n0 = 2 * get_sme_vector_length<float>();
59 gemm_methods[0].k0 = 2;
60 gemm_methods[0].dst_format = DataFormat(DataType::FP32);
61 gemm_methods[0].lhs_format = DataFormat(DataType::FP32);
62 gemm_methods[0].packed_lhs_format = DataFormat(
63 DataType::BF16, 2 * get_sme_vector_length<float>(), 2, DataFormat::PackFormat::NONE, DataType::FP32,
64 DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 2);
65 gemm_methods[0].rhs_format = DataFormat(DataType::FP32);
66 gemm_methods[0].packed_rhs_format = DataFormat(
67 DataType::BF16, 2 * get_sme_vector_length<float>(), 2, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32,
68 DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 2);
69 gemm_methods[0].bias_format = DataFormat(DataType::FP32);
70 gemm_methods[0].fn_is_supported = cpu_has_sme2;
71 gemm_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
72 gemm_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
73 gemm_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
74 gemm_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
75 gemm_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
76 gemm_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme;
77 gemm_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
78 gemm_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme;
79 gemm_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme;
80 gemm_methods[0].fn_get_packed_lhs_offset =
81 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
82 gemm_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme;
83 gemm_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme;
84 gemm_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme;
85 gemm_methods[0].fn_get_main_packed_rhs_offset =
86 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
87 gemm_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme;
88 gemm_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme;
89 gemm_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
90 gemm_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
91 gemm_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa;
92
93 gemm_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla";
94 gemm_methods[1].m0 = 8;
95 gemm_methods[1].n0 = 12;
96 gemm_methods[1].k0 = 4;
97 gemm_methods[1].dst_format = DataFormat(DataType::FP32);
98 gemm_methods[1].lhs_format = DataFormat(DataType::FP32);
99 gemm_methods[1].packed_lhs_format =
100 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4);
101 gemm_methods[1].rhs_format = DataFormat(DataType::FP32);
102 gemm_methods[1].packed_rhs_format = DataFormat(
103 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
104 gemm_methods[1].bias_format = DataFormat(DataType::FP32);
105 gemm_methods[1].fn_is_supported = cpu_has_bf16;
106 gemm_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
107 gemm_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
108 gemm_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
109 gemm_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
110 gemm_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
111 gemm_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
112 gemm_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
113 gemm_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon;
114 gemm_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon;
115 gemm_methods[1].fn_get_packed_lhs_offset =
116 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
117 gemm_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon;
118 gemm_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
119 gemm_methods[1].fn_get_packed_rhs_size_generic_block_size =
120 kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
121 gemm_methods[1].fn_get_main_packed_rhs_offset =
122 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
123 gemm_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
124 gemm_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
125 gemm_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
126 gemm_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
127 gemm_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
128
129 gemm_methods[2].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output";
130 gemm_methods[2].m0 = 8;
131 gemm_methods[2].n0 = 12;
132 gemm_methods[2].k0 = 4;
133 gemm_methods[2].dst_format = DataFormat(DataType::FP32);
134 gemm_methods[2].lhs_format = DataFormat(DataType::FP16);
135 gemm_methods[2].packed_lhs_format =
136 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4);
137 gemm_methods[2].rhs_format = DataFormat(DataType::FP16);
138 gemm_methods[2].packed_rhs_format = DataFormat(
139 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
140 gemm_methods[2].bias_format = DataFormat(DataType::FP32);
141 gemm_methods[2].fn_is_supported = cpu_has_bf16;
142 gemm_methods[2].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
143 gemm_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
144 gemm_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
145 gemm_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
146 gemm_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
147 gemm_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
148 gemm_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
149 gemm_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon;
150 gemm_methods[2].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon;
151 gemm_methods[2].fn_get_packed_lhs_offset =
152 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
153 gemm_methods[2].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon;
154 gemm_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
155 gemm_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
156 gemm_methods[2].fn_get_main_packed_rhs_offset =
157 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
158 gemm_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
159 gemm_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
160 gemm_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
161 gemm_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
162 gemm_methods[2].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
163
164 gemm_methods[3].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output_opt_bias";
165 gemm_methods[3].m0 = 8;
166 gemm_methods[3].n0 = 12;
167 gemm_methods[3].k0 = 4;
168 gemm_methods[3].dst_format = DataFormat(DataType::FP32);
169 gemm_methods[3].lhs_format = DataFormat(DataType::FP16);
170 gemm_methods[3].packed_lhs_format =
171 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4);
172 gemm_methods[3].rhs_format = DataFormat(DataType::FP16);
173 gemm_methods[3].packed_rhs_format = DataFormat(
174 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
175 gemm_methods[3].bias_format = DataFormat(DataType::UNKNOWN);
176 gemm_methods[3].fn_is_supported = cpu_has_bf16;
177 gemm_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
178 gemm_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
179 gemm_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
180 gemm_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
181 gemm_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
182 gemm_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
183 gemm_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
184 gemm_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon;
185 gemm_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon;
186 gemm_methods[3].fn_get_packed_lhs_offset =
187 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
188 gemm_methods[3].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon;
189 gemm_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
190 gemm_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
191 gemm_methods[3].fn_get_main_packed_rhs_offset =
192 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
193 gemm_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
194 gemm_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon;
195 gemm_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
196 gemm_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
197 gemm_methods[3].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
198
199 gemm_methods[4].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias";
200 gemm_methods[4].m0 = 8;
201 gemm_methods[4].n0 = 12;
202 gemm_methods[4].k0 = 4;
203 gemm_methods[4].dst_format = DataFormat(DataType::FP32);
204 gemm_methods[4].lhs_format = DataFormat(DataType::FP32);
205 gemm_methods[4].packed_lhs_format =
206 DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4);
207 gemm_methods[4].rhs_format = DataFormat(DataType::FP32);
208 gemm_methods[4].packed_rhs_format = DataFormat(
209 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
210 gemm_methods[4].bias_format = DataFormat(DataType::UNKNOWN);
211 gemm_methods[4].fn_is_supported = cpu_has_bf16;
212 gemm_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
213 gemm_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
214 gemm_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
215 gemm_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
216 gemm_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
217 gemm_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
218 gemm_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
219 gemm_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon;
220 gemm_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon;
221 gemm_methods[4].fn_get_packed_lhs_offset =
222 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
223 gemm_methods[4].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon;
224 gemm_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
225 gemm_methods[4].fn_get_packed_rhs_size_generic_block_size =
226 kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
227 gemm_methods[4].fn_get_main_packed_rhs_offset =
228 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
229 gemm_methods[4].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
230 gemm_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
231 gemm_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
232 gemm_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
233 gemm_methods[4].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla;
234
235 return gemm_methods;
236 }
237
238 1 static const std::array<MatMulMethod, 2>& get_gemv_methods() {
239
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
1 static std::array<MatMulMethod, 2> gemv_methods{};
240 gemv_methods[0].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot";
241 gemv_methods[0].m0 = 1;
242 gemv_methods[0].n0 = 12;
243 gemv_methods[0].k0 = 4;
244 gemv_methods[0].dst_format = DataFormat(DataType::FP32);
245 gemv_methods[0].lhs_format = DataFormat(DataType::FP32);
246 gemv_methods[0].packed_lhs_format =
247 DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4);
248 gemv_methods[0].rhs_format = DataFormat(DataType::FP32);
249 gemv_methods[0].packed_rhs_format = DataFormat(
250 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
251 gemv_methods[0].bias_format = DataFormat(DataType::FP32);
252 gemv_methods[0].fn_is_supported = cpu_has_bf16;
253 gemv_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
254 gemv_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
255 gemv_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
256 gemv_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
257 gemv_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
258 gemv_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
259 gemv_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
260 gemv_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon;
261 gemv_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon;
262 gemv_methods[0].fn_get_packed_lhs_offset =
263 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
264 gemv_methods[0].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon;
265 gemv_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
266 gemv_methods[0].fn_get_packed_rhs_size_generic_block_size =
267 kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
268 gemv_methods[0].fn_get_main_packed_rhs_offset =
269 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
270 gemv_methods[0].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
271 gemv_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
272 gemv_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
273 gemv_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
274 gemv_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
275
276 gemv_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias";
277 gemv_methods[1].m0 = 1;
278 gemv_methods[1].n0 = 12;
279 gemv_methods[1].k0 = 4;
280 gemv_methods[1].dst_format = DataFormat(DataType::FP32);
281 gemv_methods[1].lhs_format = DataFormat(DataType::FP32);
282 gemv_methods[1].packed_lhs_format =
283 DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4);
284 gemv_methods[1].rhs_format = DataFormat(DataType::FP32);
285 gemv_methods[1].packed_rhs_format = DataFormat(
286 DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4);
287 gemv_methods[1].bias_format = DataFormat(DataType::UNKNOWN);
288 gemv_methods[1].fn_is_supported = cpu_has_bf16;
289 gemv_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
290 gemv_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
291 gemv_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
292 gemv_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
293 gemv_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
294 gemv_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
295 gemv_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
296 gemv_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon;
297 gemv_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon;
298 gemv_methods[1].fn_get_packed_lhs_offset =
299 kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
300 gemv_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon;
301 gemv_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
302 gemv_methods[1].fn_get_packed_rhs_size_generic_block_size =
303 kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
304 gemv_methods[1].fn_get_main_packed_rhs_offset =
305 kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
306 gemv_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
307 gemv_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon;
308 gemv_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
309 gemv_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
310 gemv_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot;
311
312 return gemv_methods;
313 }
314
315 } // namespace
316
317 /// Matrix multiplication test fixture.
318 class MatMulTestBf16 : public testing::TestWithParam<MatMulTestParams> {
319 private:
320 /// Unique ID: m, n, k
321 using TestDataId = std::tuple<size_t, size_t, size_t, std::string_view>;
322
323 protected:
324 /// Cached test data that is shared between multiple test case.
325 74 struct TestData {
326 148 Buffer lhs{}; ///< LHS operand.
327 148 Buffer ref_packed_lhs{}; ///< Reference packed LHS.
328 148 Buffer rhs{}; ///< RHS operand.
329 148 Buffer rhs_scales{}; ///< RHS per-row quantization scales.
330 148 Buffer bias{}; ///< Bias.
331 148 Buffer ref_packed_rhs{}; ///< Reference packed RHS.
332 148 Buffer ref_dst{}; ///< Reference output.
333 };
334
335 /// Gets the test data for the current test case.
336 346 static const TestData& test_data() {
337 3070 const auto& [method, info, portion] = GetParam();
338 1730 const TestDataId data_id{info.m, info.n, info.k, method.name};
339
340 // If the test data is already available, returns it.
341 346 const auto data_it = _data.find(data_id);
342
343
2/2
✓ Branch 0 taken 272 times.
✓ Branch 1 taken 74 times.
346 if (data_it != _data.end()) {
344 272 return data_it->second;
345 }
346
347 // Generates the test data.
348 148 const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN;
349 148 const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN;
350 148 const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN;
351
352 148 const auto lhs_h = info.m;
353 148 const auto lhs_w = info.k;
354 148 auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0);
355 74 Buffer ref_packed_lhs;
356
357
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 74 times.
74 if (has_lhs_pack) {
358 74 ref_packed_lhs =
359
3/6
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 74 times.
✗ Branch 5 not taken.
148 pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w);
360 74 }
361
362 148 const auto rhs_h = info.k;
363 148 const auto rhs_w = info.n;
364
2/4
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
148 auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1);
365
366 74 Buffer rhs_scales;
367
3/8
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 74 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
74 if (data_type_is_quantized(method.rhs_format.data_type()) &&
368 method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) {
369 rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2);
370 }
371
372 74 const auto bias_h = 1;
373 148 const auto bias_w = info.n;
374 74 Buffer bias;
375
376
2/2
✓ Branch 0 taken 32 times.
✓ Branch 1 taken 42 times.
74 if (has_bias) {
377
2/4
✓ Branch 0 taken 42 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 42 times.
✗ Branch 3 not taken.
84 bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3);
378 42 }
379
380 74 constexpr size_t nr = 12;
381 74 constexpr size_t kr = 4;
382
383 74 size_t packed_rhs_size = 0;
384
385
2/2
✓ Branch 0 taken 30 times.
✓ Branch 1 taken 44 times.
74 if (method.fn_get_packed_rhs_size) {
386
1/2
✓ Branch 0 taken 30 times.
✗ Branch 1 not taken.
30 packed_rhs_size = method.fn_get_packed_rhs_size(rhs_w, rhs_h);
387
1/2
✓ Branch 0 taken 44 times.
✗ Branch 1 not taken.
74 } else if (method.fn_get_packed_rhs_size_generic_block_size) {
388
1/2
✓ Branch 0 taken 44 times.
✗ Branch 1 not taken.
44 packed_rhs_size = method.fn_get_packed_rhs_size_generic_block_size(rhs_w, rhs_h, nr, kr);
389 44 } else {
390 KAI_ERROR("No function to calculate Packed Rhs Matrix Size");
391 }
392
393
1/2
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
74 Buffer packed_rhs(packed_rhs_size);
394
395
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 74 times.
74 if (has_rhs_pack) {
396
2/4
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
148 const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w);
397
1/2
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
74 method.pack_rhs(
398
5/8
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 42 times.
✓ Branch 5 taken 32 times.
✓ Branch 6 taken 42 times.
✗ Branch 7 not taken.
148 info.n, info.k, rhs.data(), ref_rhs_row_stride, has_bias ? bias.data() : nullptr, nullptr,
399
1/2
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
74 packed_rhs.data());
400 74 }
401
402 KAI_ASSUME(method.lhs_format.is_raw());
403 KAI_ASSUME(method.rhs_format.is_raw());
404 KAI_ASSUME(method.dst_format.is_raw());
405
406 74 Buffer tmp_lhs;
407 74 Buffer tmp_rhs;
408
1/2
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
74 const void* p_lhs_buff = lhs.data();
409
1/2
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
74 const void* p_rhs_buff = rhs.data();
410
411
5/8
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
✓ Branch 3 taken 54 times.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 7 not taken.
74 if (method.lhs_format.data_type() == DataType::FP32 || method.lhs_format.data_type() == DataType::FP16) {
412
3/6
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 74 times.
✗ Branch 5 not taken.
148 tmp_lhs = cast(p_lhs_buff, method.lhs_format.data_type(), DataType::BF16, lhs_h, lhs_w);
413
1/2
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
74 p_lhs_buff = tmp_lhs.data();
414 74 }
415
5/8
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
✓ Branch 3 taken 54 times.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 7 not taken.
74 if (method.rhs_format.data_type() == DataType::FP32 || method.rhs_format.data_type() == DataType::FP16) {
416
3/6
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 74 times.
✗ Branch 5 not taken.
148 tmp_rhs = cast(p_rhs_buff, method.rhs_format.data_type(), DataType::BF16, rhs_h, rhs_w);
417
1/2
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
74 p_rhs_buff = tmp_rhs.data();
418 74 }
419
420 74 auto ref_dst =
421
1/2
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
74 matmul_nt_nt_quantized<BFloat16<>, float, float, BFloat16<>, float, float, float, float, float, float>(
422 296 info.m, info.n, info.k, p_lhs_buff, nullptr, nullptr, 1, info.k, p_rhs_buff, nullptr, nullptr, 1,
423
1/2
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
74 info.k, bias.data(), nullptr, nullptr, info.k);
424
425
8/16
✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 74 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 74 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 74 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 74 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 74 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 74 times.
✗ Branch 15 not taken.
592 auto& data = _data[data_id] = {};
426 74 data.lhs = std::move(lhs);
427 74 data.ref_packed_lhs = std::move(ref_packed_lhs);
428 74 data.rhs = std::move(rhs);
429 74 data.rhs_scales = std::move(rhs_scales);
430 74 data.bias = std::move(bias);
431 74 data.ref_packed_rhs = std::move(packed_rhs);
432 74 data.ref_dst = std::move(ref_dst);
433
434 74 return data;
435 346 }
436
437 private:
438 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
439 static std::map<TestDataId, TestData> _data;
440 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
441 };
442
443 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
444 1 std::map<MatMulTestBf16::TestDataId, MatMulTestBf16::TestData> MatMulTestBf16::_data;
445 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
446
447 /// Tests the output.
448
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1040 TEST_P(MatMulTestBf16, Output) {
449 11252 const auto& [method, info, portion] = GetParam();
450
451
2/4
✓ Branch 0 taken 346 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 346 times.
✗ Branch 3 not taken.
346 if (method.fn_is_supported && !method.fn_is_supported()) {
452 GTEST_SKIP() << "Unsupported CPU feature";
453 }
454
455
1/2
✓ Branch 0 taken 346 times.
✗ Branch 1 not taken.
346 if (!method.has_main_kernel()) {
456 GTEST_SKIP() << "No main kernel available";
457 }
458
459 346 const auto& data = test_data();
460 692 const auto m_step = method.fn_get_main_m_step();
461
4/16
✓ Branch 0 taken 346 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 346 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 346 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 346 times.
692 ASSERT_TRUE(m_step % method.m0 == 0);
462
463 692 const auto n_step = method.fn_get_main_n_step();
464
4/16
✓ Branch 0 taken 346 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 346 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 346 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 346 times.
692 ASSERT_TRUE(n_step % method.n0 == 0);
465
466 1038 const auto rect = portion.compute_portion(info.m, info.n, m_step, n_step);
467
468
4/4
✓ Branch 0 taken 331 times.
✓ Branch 1 taken 15 times.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 322 times.
346 if (rect.height() == 0 || rect.width() == 0) {
469
9/18
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 24 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 24 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 24 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 24 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 24 times.
✗ Branch 17 not taken.
24 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
470 }
471
472 644 const size_t lhs_w = info.k;
473 322 const size_t rhs_w = rect.width();
474 644 const size_t bias_w = info.n;
475 644 const size_t dst_w = info.n;
476 322 const bool has_bias = (data.bias.size() > 0);
477
478 322 const auto lhs_start_row = rect.start_row();
479 644 const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w);
480
481 1932 const size_t lhs_packed_size = method.fn_get_packed_lhs_size(info.m, info.k, method.m0, method.k0, 1 /* sr */);
482 322 Buffer lhs_data(lhs_packed_size);
483
484
2/4
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
644 uintptr_t lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride);
485
3/6
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
966 uintptr_t lhs_packed_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k);
486
487 KAI_UNUSED(lhs_offset);
488
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
644 method.fn_pack_lhs(
489
4/8
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
322 rect.height(), info.k, method.m0, method.k0, 1 /* sr */, 0 /* m_idx_start */, data.lhs.data() + lhs_offset,
490
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
322 lhs_stride, lhs_data.data() + lhs_packed_offset);
491
492
3/6
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
966 const auto rhs_stride = method.rhs_format.default_row_stride(info.n);
493
494 322 size_t rhs_packed_size = 0;
495
496
2/2
✓ Branch 0 taken 184 times.
✓ Branch 1 taken 138 times.
322 if (method.fn_get_packed_rhs_size_generic_block_size) {
497
5/10
✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 184 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 184 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 184 times.
✗ Branch 9 not taken.
920 rhs_packed_size = method.fn_get_packed_rhs_size_generic_block_size(info.n, info.k, method.n0, method.k0);
498
1/2
✓ Branch 0 taken 138 times.
✗ Branch 1 not taken.
322 } else if (method.fn_get_packed_rhs_size) {
499
3/6
✓ Branch 0 taken 138 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 138 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 138 times.
✗ Branch 5 not taken.
414 rhs_packed_size = method.fn_get_packed_rhs_size(info.n, info.k);
500 138 }
501
502
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
322 Buffer rhs_data(rhs_packed_size);
503
504
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
322 const auto packed_rhs_start_row = rect.start_col();
505 322 const auto packed_rhs_start_col = 0;
506
507
3/6
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
644 uintptr_t rhs_offset = method.fn_get_rhs_offset(rect.start_col());
508
3/6
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
966 uintptr_t rhs_packed_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k);
509 322 const auto ref_rhs_packed_offset =
510
2/4
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
644 method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k);
511
512
4/16
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 322 times.
322 ASSERT_EQ(rhs_packed_offset, ref_rhs_packed_offset);
513
514
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
322 uintptr_t bias_offset = sizeof(float) * rect.start_col();
515
516
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
644 method.fn_pack_rhs(
517 1, // num_groups
518 1288 rhs_w, info.k, method.n0, method.k0,
519 1, // sr
520
4/6
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✓ Branch 3 taken 138 times.
✓ Branch 4 taken 184 times.
✗ Branch 5 not taken.
322 rhs_stride, data.rhs.data() + rhs_offset, has_bias ? data.bias.data() + bias_offset : nullptr,
521 NULL, // Scale
522
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
322 rhs_data.data() + rhs_packed_offset, 0, NULL);
523
524
2/2
✓ Branch 0 taken 138 times.
✓ Branch 1 taken 184 times.
322 if (has_bias) {
525
3/6
✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 184 times.
✗ Branch 5 not taken.
368 const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w);
526
4/16
✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 184 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 184 times.
184 ASSERT_EQ(ref_bias_offset, bias_offset);
527 184 }
528
529
2/4
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
644 const auto dst_stride = method.dst_format.default_row_stride(dst_w);
530
4/8
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
644 const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
531
4/8
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
644 const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w);
532
4/16
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 322 times.
322 ASSERT_EQ(dst_offset, ref_dst_offset);
533
534
4/8
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
1288 const auto dst_size = method.fn_get_dst_size(info.m, info.n);
535
4/8
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
1288 const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n);
536
4/16
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 322 times.
322 ASSERT_EQ(dst_size, ref_dst_size);
537
538
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
322 Buffer dst(dst_size);
539
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
322 method.main_kernel(
540
4/8
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
322 rect.height(), rect.width(), info.k, lhs_data.data() + lhs_packed_offset, rhs_data.data() + rhs_packed_offset,
541
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
322 NULL, dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits<float>::infinity(),
542 322 std::numeric_limits<float>::infinity());
543
544
1/2
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
322 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
545
5/10
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 322 times.
✗ Branch 9 not taken.
322 const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler);
546
547
4/16
✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 322 times.
322 ASSERT_TRUE(success);
548 346 }
549
550
15/48
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 250 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
252 INSTANTIATE_TEST_SUITE_P(
551 MatMulGemm, MatMulTestBf16,
552 testing::Combine(
553 testing::ValuesIn(get_gemm_methods()),
554 testing::Values(
555 MatMulShape{1, 1, 1}, // Smallest Possible Shape
556 MatMulShape{3, 7, 3}, // Smaller than block size
557 MatMulShape{12, 8, 4}, // Same block size
558 MatMulShape{1, 1, 1023}, // Long K
559 MatMulShape{1013, 1, 5}, // Long M
560 MatMulShape{2, 1013, 6}, // Long N
561 MatMulShape{13, 33, 23}, //
562 MatMulShape{93, 57, 89}, //
563 MatMulShape{256, 256, 256}, // Nice shapes
564 MatMulShape{257, 113, 373} // Prime numbers
565 ),
566 testing::Values(
567 MatrixPortion(0, 0, 1, 1), // Full matrix.
568 MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner.
569 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
570 MatrixPortion(0.75, 0, 1, 1), // Partial rows
571 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
572 )),
573 testing::PrintToStringParamName());
574
575
14/44
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 96 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
98 INSTANTIATE_TEST_SUITE_P(
576 MatMulGemv, MatMulTestBf16,
577 testing::Combine(
578 testing::ValuesIn(get_gemv_methods()),
579 testing::Values(
580 MatMulShape{1, 1, 1}, // Smallest Possible Shape
581 MatMulShape{1, 1, 1023}, // Long K
582 MatMulShape{1, 1023, 1}, // Long N
583 MatMulShape{1, 1013, 1023}, // Large Rhs
584 MatMulShape{1, 37, 23}, //
585 MatMulShape{1, 57, 89}, //
586 MatMulShape{1, 36, 89}, //
587 MatMulShape{1, 98, 23}, //
588 MatMulShape{1, 64, 1024}, // Nice shapes - Long Rhs Rect
589 MatMulShape{1, 1024, 64}, // Nice shapes - Wide Rhs Rect
590 MatMulShape{1, 256, 256}, // Nice shapes - Square
591 MatMulShape{1, 113, 373} // Prime numbers
592 ),
593 testing::Values(
594 MatrixPortion(0, 0, 1, 1), // Full matrix.
595 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
596 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
597 MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle
598 )),
599 testing::PrintToStringParamName());
600 } // namespace kai::test
601