Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <functional> | ||
13 | #include <iosfwd> | ||
14 | #include <limits> | ||
15 | #include <map> | ||
16 | #include <string_view> | ||
17 | #include <tuple> | ||
18 | #include <utility> | ||
19 | |||
20 | #include "kai/kai_common.h" | ||
21 | #include "test/common/buffer.hpp" | ||
22 | #include "test/common/compare.hpp" | ||
23 | #include "test/common/cpu_info.hpp" | ||
24 | #include "test/common/data_format.hpp" | ||
25 | #include "test/common/data_type.hpp" | ||
26 | #include "test/common/matmul_test_common.hpp" | ||
27 | #include "test/common/matrix_portion.hpp" | ||
28 | #include "test/common/printer.hpp" | ||
29 | #include "test/common/sme.hpp" | ||
30 | #include "test/reference/cast.hpp" | ||
31 | #include "test/reference/fill.hpp" | ||
32 | #include "test/reference/matmul.hpp" | ||
33 | #include "test/reference/pack.hpp" | ||
34 | |||
35 | // matmul_clamp_f32_bf16p_bf16p | ||
36 | #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h" | ||
37 | #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" | ||
38 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.h" | ||
39 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.h" | ||
40 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.h" | ||
41 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.h" | ||
42 | #include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h" | ||
43 | |||
44 | // SME files here. | ||
45 | #include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" | ||
46 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h" | ||
47 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" | ||
48 | |||
49 | namespace kai::test { | ||
50 | |||
51 | /// List of supported matrix multiplication methods. | ||
52 | namespace { | ||
53 | |||
54 | 1 | static const std::array<MatMulMethod, 5>& get_gemm_methods() { | |
55 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
1 | static std::array<MatMulMethod, 5> gemm_methods{}; |
56 | gemm_methods[0].name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa"; | ||
57 | gemm_methods[0].m0 = 2 * get_sme_vector_length<float>(); | ||
58 | gemm_methods[0].n0 = 2 * get_sme_vector_length<float>(); | ||
59 | gemm_methods[0].k0 = 2; | ||
60 | gemm_methods[0].dst_format = DataFormat(DataType::FP32); | ||
61 | gemm_methods[0].lhs_format = DataFormat(DataType::FP32); | ||
62 | gemm_methods[0].packed_lhs_format = DataFormat( | ||
63 | DataType::BF16, 2 * get_sme_vector_length<float>(), 2, DataFormat::PackFormat::NONE, DataType::FP32, | ||
64 | DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 2); | ||
65 | gemm_methods[0].rhs_format = DataFormat(DataType::FP32); | ||
66 | gemm_methods[0].packed_rhs_format = DataFormat( | ||
67 | DataType::BF16, 2 * get_sme_vector_length<float>(), 2, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, | ||
68 | DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 2); | ||
69 | gemm_methods[0].bias_format = DataFormat(DataType::FP32); | ||
70 | gemm_methods[0].fn_is_supported = cpu_has_sme2; | ||
71 | gemm_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
72 | gemm_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
73 | gemm_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
74 | gemm_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
75 | gemm_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
76 | gemm_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; | ||
77 | gemm_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
78 | gemm_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme; | ||
79 | gemm_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme; | ||
80 | gemm_methods[0].fn_get_packed_lhs_offset = | ||
81 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
82 | gemm_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme; | ||
83 | gemm_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; | ||
84 | gemm_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; | ||
85 | gemm_methods[0].fn_get_main_packed_rhs_offset = | ||
86 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
87 | gemm_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; | ||
88 | gemm_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; | ||
89 | gemm_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
90 | gemm_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
91 | gemm_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
92 | |||
93 | gemm_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla"; | ||
94 | gemm_methods[1].m0 = 8; | ||
95 | gemm_methods[1].n0 = 12; | ||
96 | gemm_methods[1].k0 = 4; | ||
97 | gemm_methods[1].dst_format = DataFormat(DataType::FP32); | ||
98 | gemm_methods[1].lhs_format = DataFormat(DataType::FP32); | ||
99 | gemm_methods[1].packed_lhs_format = | ||
100 | DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); | ||
101 | gemm_methods[1].rhs_format = DataFormat(DataType::FP32); | ||
102 | gemm_methods[1].packed_rhs_format = DataFormat( | ||
103 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
104 | gemm_methods[1].bias_format = DataFormat(DataType::FP32); | ||
105 | gemm_methods[1].fn_is_supported = cpu_has_bf16; | ||
106 | gemm_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
107 | gemm_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
108 | gemm_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
109 | gemm_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
110 | gemm_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
111 | gemm_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
112 | gemm_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
113 | gemm_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; | ||
114 | gemm_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; | ||
115 | gemm_methods[1].fn_get_packed_lhs_offset = | ||
116 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
117 | gemm_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; | ||
118 | gemm_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
119 | gemm_methods[1].fn_get_packed_rhs_size_generic_block_size = | ||
120 | kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
121 | gemm_methods[1].fn_get_main_packed_rhs_offset = | ||
122 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
123 | gemm_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
124 | gemm_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
125 | gemm_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
126 | gemm_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
127 | gemm_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
128 | |||
129 | gemm_methods[2].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output"; | ||
130 | gemm_methods[2].m0 = 8; | ||
131 | gemm_methods[2].n0 = 12; | ||
132 | gemm_methods[2].k0 = 4; | ||
133 | gemm_methods[2].dst_format = DataFormat(DataType::FP32); | ||
134 | gemm_methods[2].lhs_format = DataFormat(DataType::FP16); | ||
135 | gemm_methods[2].packed_lhs_format = | ||
136 | DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); | ||
137 | gemm_methods[2].rhs_format = DataFormat(DataType::FP16); | ||
138 | gemm_methods[2].packed_rhs_format = DataFormat( | ||
139 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
140 | gemm_methods[2].bias_format = DataFormat(DataType::FP32); | ||
141 | gemm_methods[2].fn_is_supported = cpu_has_bf16; | ||
142 | gemm_methods[2].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
143 | gemm_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
144 | gemm_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
145 | gemm_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
146 | gemm_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
147 | gemm_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
148 | gemm_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
149 | gemm_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; | ||
150 | gemm_methods[2].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; | ||
151 | gemm_methods[2].fn_get_packed_lhs_offset = | ||
152 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
153 | gemm_methods[2].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; | ||
154 | gemm_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
155 | gemm_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
156 | gemm_methods[2].fn_get_main_packed_rhs_offset = | ||
157 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
158 | gemm_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
159 | gemm_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
160 | gemm_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
161 | gemm_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
162 | gemm_methods[2].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
163 | |||
164 | gemm_methods[3].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output_opt_bias"; | ||
165 | gemm_methods[3].m0 = 8; | ||
166 | gemm_methods[3].n0 = 12; | ||
167 | gemm_methods[3].k0 = 4; | ||
168 | gemm_methods[3].dst_format = DataFormat(DataType::FP32); | ||
169 | gemm_methods[3].lhs_format = DataFormat(DataType::FP16); | ||
170 | gemm_methods[3].packed_lhs_format = | ||
171 | DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); | ||
172 | gemm_methods[3].rhs_format = DataFormat(DataType::FP16); | ||
173 | gemm_methods[3].packed_rhs_format = DataFormat( | ||
174 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
175 | gemm_methods[3].bias_format = DataFormat(DataType::UNKNOWN); | ||
176 | gemm_methods[3].fn_is_supported = cpu_has_bf16; | ||
177 | gemm_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
178 | gemm_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
179 | gemm_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
180 | gemm_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
181 | gemm_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
182 | gemm_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
183 | gemm_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
184 | gemm_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; | ||
185 | gemm_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; | ||
186 | gemm_methods[3].fn_get_packed_lhs_offset = | ||
187 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
188 | gemm_methods[3].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; | ||
189 | gemm_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
190 | gemm_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
191 | gemm_methods[3].fn_get_main_packed_rhs_offset = | ||
192 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
193 | gemm_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
194 | gemm_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
195 | gemm_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
196 | gemm_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
197 | gemm_methods[3].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
198 | |||
199 | gemm_methods[4].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias"; | ||
200 | gemm_methods[4].m0 = 8; | ||
201 | gemm_methods[4].n0 = 12; | ||
202 | gemm_methods[4].k0 = 4; | ||
203 | gemm_methods[4].dst_format = DataFormat(DataType::FP32); | ||
204 | gemm_methods[4].lhs_format = DataFormat(DataType::FP32); | ||
205 | gemm_methods[4].packed_lhs_format = | ||
206 | DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); | ||
207 | gemm_methods[4].rhs_format = DataFormat(DataType::FP32); | ||
208 | gemm_methods[4].packed_rhs_format = DataFormat( | ||
209 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
210 | gemm_methods[4].bias_format = DataFormat(DataType::UNKNOWN); | ||
211 | gemm_methods[4].fn_is_supported = cpu_has_bf16; | ||
212 | gemm_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
213 | gemm_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
214 | gemm_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
215 | gemm_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
216 | gemm_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
217 | gemm_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
218 | gemm_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
219 | gemm_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; | ||
220 | gemm_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; | ||
221 | gemm_methods[4].fn_get_packed_lhs_offset = | ||
222 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
223 | gemm_methods[4].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; | ||
224 | gemm_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
225 | gemm_methods[4].fn_get_packed_rhs_size_generic_block_size = | ||
226 | kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
227 | gemm_methods[4].fn_get_main_packed_rhs_offset = | ||
228 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
229 | gemm_methods[4].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
230 | gemm_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
231 | gemm_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
232 | gemm_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
233 | gemm_methods[4].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
234 | |||
235 | return gemm_methods; | ||
236 | } | ||
237 | |||
238 | 1 | static const std::array<MatMulMethod, 2>& get_gemv_methods() { | |
239 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
1 | static std::array<MatMulMethod, 2> gemv_methods{}; |
240 | gemv_methods[0].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot"; | ||
241 | gemv_methods[0].m0 = 1; | ||
242 | gemv_methods[0].n0 = 12; | ||
243 | gemv_methods[0].k0 = 4; | ||
244 | gemv_methods[0].dst_format = DataFormat(DataType::FP32); | ||
245 | gemv_methods[0].lhs_format = DataFormat(DataType::FP32); | ||
246 | gemv_methods[0].packed_lhs_format = | ||
247 | DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); | ||
248 | gemv_methods[0].rhs_format = DataFormat(DataType::FP32); | ||
249 | gemv_methods[0].packed_rhs_format = DataFormat( | ||
250 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
251 | gemv_methods[0].bias_format = DataFormat(DataType::FP32); | ||
252 | gemv_methods[0].fn_is_supported = cpu_has_bf16; | ||
253 | gemv_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
254 | gemv_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
255 | gemv_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
256 | gemv_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
257 | gemv_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
258 | gemv_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
259 | gemv_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
260 | gemv_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; | ||
261 | gemv_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; | ||
262 | gemv_methods[0].fn_get_packed_lhs_offset = | ||
263 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
264 | gemv_methods[0].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; | ||
265 | gemv_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
266 | gemv_methods[0].fn_get_packed_rhs_size_generic_block_size = | ||
267 | kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
268 | gemv_methods[0].fn_get_main_packed_rhs_offset = | ||
269 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
270 | gemv_methods[0].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
271 | gemv_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
272 | gemv_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
273 | gemv_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
274 | gemv_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
275 | |||
276 | gemv_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias"; | ||
277 | gemv_methods[1].m0 = 1; | ||
278 | gemv_methods[1].n0 = 12; | ||
279 | gemv_methods[1].k0 = 4; | ||
280 | gemv_methods[1].dst_format = DataFormat(DataType::FP32); | ||
281 | gemv_methods[1].lhs_format = DataFormat(DataType::FP32); | ||
282 | gemv_methods[1].packed_lhs_format = | ||
283 | DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); | ||
284 | gemv_methods[1].rhs_format = DataFormat(DataType::FP32); | ||
285 | gemv_methods[1].packed_rhs_format = DataFormat( | ||
286 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
287 | gemv_methods[1].bias_format = DataFormat(DataType::UNKNOWN); | ||
288 | gemv_methods[1].fn_is_supported = cpu_has_bf16; | ||
289 | gemv_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
290 | gemv_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
291 | gemv_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
292 | gemv_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
293 | gemv_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
294 | gemv_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
295 | gemv_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
296 | gemv_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; | ||
297 | gemv_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; | ||
298 | gemv_methods[1].fn_get_packed_lhs_offset = | ||
299 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
300 | gemv_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; | ||
301 | gemv_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
302 | gemv_methods[1].fn_get_packed_rhs_size_generic_block_size = | ||
303 | kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
304 | gemv_methods[1].fn_get_main_packed_rhs_offset = | ||
305 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
306 | gemv_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
307 | gemv_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
308 | gemv_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
309 | gemv_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
310 | gemv_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
311 | |||
312 | return gemv_methods; | ||
313 | } | ||
314 | |||
315 | } // namespace | ||
316 | |||
317 | /// Matrix multiplication test fixture. | ||
318 | class MatMulTestBf16 : public testing::TestWithParam<MatMulTestParams> { | ||
319 | private: | ||
320 | /// Unique ID: m, n, k | ||
321 | using TestDataId = std::tuple<size_t, size_t, size_t, std::string_view>; | ||
322 | |||
323 | protected: | ||
324 | /// Cached test data that is shared between multiple test case. | ||
325 | 74 | struct TestData { | |
326 | 148 | Buffer lhs{}; ///< LHS operand. | |
327 | 148 | Buffer ref_packed_lhs{}; ///< Reference packed LHS. | |
328 | 148 | Buffer rhs{}; ///< RHS operand. | |
329 | 148 | Buffer rhs_scales{}; ///< RHS per-row quantization scales. | |
330 | 148 | Buffer bias{}; ///< Bias. | |
331 | 148 | Buffer ref_packed_rhs{}; ///< Reference packed RHS. | |
332 | 148 | Buffer ref_dst{}; ///< Reference output. | |
333 | }; | ||
334 | |||
335 | /// Gets the test data for the current test case. | ||
336 | 346 | static const TestData& test_data() { | |
337 | 3070 | const auto& [method, info, portion] = GetParam(); | |
338 | 1730 | const TestDataId data_id{info.m, info.n, info.k, method.name}; | |
339 | |||
340 | // If the test data is already available, returns it. | ||
341 | 346 | const auto data_it = _data.find(data_id); | |
342 | |||
343 |
2/2✓ Branch 0 taken 272 times.
✓ Branch 1 taken 74 times.
|
346 | if (data_it != _data.end()) { |
344 | 272 | return data_it->second; | |
345 | } | ||
346 | |||
347 | // Generates the test data. | ||
348 | 148 | const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN; | |
349 | 148 | const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; | |
350 | 148 | const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; | |
351 | |||
352 | 148 | const auto lhs_h = info.m; | |
353 | 148 | const auto lhs_w = info.k; | |
354 | 148 | auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0); | |
355 | 74 | Buffer ref_packed_lhs; | |
356 | |||
357 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 74 times.
|
74 | if (has_lhs_pack) { |
358 | 74 | ref_packed_lhs = | |
359 |
3/6✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 74 times.
✗ Branch 5 not taken.
|
148 | pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); |
360 | 74 | } | |
361 | |||
362 | 148 | const auto rhs_h = info.k; | |
363 | 148 | const auto rhs_w = info.n; | |
364 |
2/4✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
|
148 | auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1); |
365 | |||
366 | 74 | Buffer rhs_scales; | |
367 |
3/8✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 74 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
74 | if (data_type_is_quantized(method.rhs_format.data_type()) && |
368 | ✗ | method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) { | |
369 | ✗ | rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2); | |
370 | ✗ | } | |
371 | |||
372 | 74 | const auto bias_h = 1; | |
373 | 148 | const auto bias_w = info.n; | |
374 | 74 | Buffer bias; | |
375 | |||
376 |
2/2✓ Branch 0 taken 32 times.
✓ Branch 1 taken 42 times.
|
74 | if (has_bias) { |
377 |
2/4✓ Branch 0 taken 42 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 42 times.
✗ Branch 3 not taken.
|
84 | bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3); |
378 | 42 | } | |
379 | |||
380 | 74 | constexpr size_t nr = 12; | |
381 | 74 | constexpr size_t kr = 4; | |
382 | |||
383 | 74 | size_t packed_rhs_size = 0; | |
384 | |||
385 |
2/2✓ Branch 0 taken 30 times.
✓ Branch 1 taken 44 times.
|
74 | if (method.fn_get_packed_rhs_size) { |
386 |
1/2✓ Branch 0 taken 30 times.
✗ Branch 1 not taken.
|
30 | packed_rhs_size = method.fn_get_packed_rhs_size(rhs_w, rhs_h); |
387 |
1/2✓ Branch 0 taken 44 times.
✗ Branch 1 not taken.
|
74 | } else if (method.fn_get_packed_rhs_size_generic_block_size) { |
388 |
1/2✓ Branch 0 taken 44 times.
✗ Branch 1 not taken.
|
44 | packed_rhs_size = method.fn_get_packed_rhs_size_generic_block_size(rhs_w, rhs_h, nr, kr); |
389 | 44 | } else { | |
390 | − | KAI_ERROR("No function to calculate Packed Rhs Matrix Size"); | |
391 | } | ||
392 | |||
393 |
1/2✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
|
74 | Buffer packed_rhs(packed_rhs_size); |
394 | |||
395 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 74 times.
|
74 | if (has_rhs_pack) { |
396 |
2/4✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
|
148 | const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); |
397 |
1/2✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
|
74 | method.pack_rhs( |
398 |
5/8✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 42 times.
✓ Branch 5 taken 32 times.
✓ Branch 6 taken 42 times.
✗ Branch 7 not taken.
|
148 | info.n, info.k, rhs.data(), ref_rhs_row_stride, has_bias ? bias.data() : nullptr, nullptr, |
399 |
1/2✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
|
74 | packed_rhs.data()); |
400 | 74 | } | |
401 | |||
402 | − | KAI_ASSUME(method.lhs_format.is_raw()); | |
403 | − | KAI_ASSUME(method.rhs_format.is_raw()); | |
404 | − | KAI_ASSUME(method.dst_format.is_raw()); | |
405 | |||
406 | 74 | Buffer tmp_lhs; | |
407 | 74 | Buffer tmp_rhs; | |
408 |
1/2✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
|
74 | const void* p_lhs_buff = lhs.data(); |
409 |
1/2✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
|
74 | const void* p_rhs_buff = rhs.data(); |
410 | |||
411 |
5/8✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
✓ Branch 3 taken 54 times.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 7 not taken.
|
74 | if (method.lhs_format.data_type() == DataType::FP32 || method.lhs_format.data_type() == DataType::FP16) { |
412 |
3/6✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 74 times.
✗ Branch 5 not taken.
|
148 | tmp_lhs = cast(p_lhs_buff, method.lhs_format.data_type(), DataType::BF16, lhs_h, lhs_w); |
413 |
1/2✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
|
74 | p_lhs_buff = tmp_lhs.data(); |
414 | 74 | } | |
415 |
5/8✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
✓ Branch 3 taken 54 times.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 7 not taken.
|
74 | if (method.rhs_format.data_type() == DataType::FP32 || method.rhs_format.data_type() == DataType::FP16) { |
416 |
3/6✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 74 times.
✗ Branch 5 not taken.
|
148 | tmp_rhs = cast(p_rhs_buff, method.rhs_format.data_type(), DataType::BF16, rhs_h, rhs_w); |
417 |
1/2✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
|
74 | p_rhs_buff = tmp_rhs.data(); |
418 | 74 | } | |
419 | |||
420 | 74 | auto ref_dst = | |
421 |
1/2✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
|
74 | matmul_nt_nt_quantized<BFloat16<>, float, float, BFloat16<>, float, float, float, float, float, float>( |
422 | 296 | info.m, info.n, info.k, p_lhs_buff, nullptr, nullptr, 1, info.k, p_rhs_buff, nullptr, nullptr, 1, | |
423 |
1/2✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
|
74 | info.k, bias.data(), nullptr, nullptr, info.k); |
424 | |||
425 |
8/16✓ Branch 0 taken 74 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 74 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 74 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 74 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 74 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 74 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 74 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 74 times.
✗ Branch 15 not taken.
|
592 | auto& data = _data[data_id] = {}; |
426 | 74 | data.lhs = std::move(lhs); | |
427 | 74 | data.ref_packed_lhs = std::move(ref_packed_lhs); | |
428 | 74 | data.rhs = std::move(rhs); | |
429 | 74 | data.rhs_scales = std::move(rhs_scales); | |
430 | 74 | data.bias = std::move(bias); | |
431 | 74 | data.ref_packed_rhs = std::move(packed_rhs); | |
432 | 74 | data.ref_dst = std::move(ref_dst); | |
433 | |||
434 | 74 | return data; | |
435 | 346 | } | |
436 | |||
437 | private: | ||
438 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
439 | static std::map<TestDataId, TestData> _data; | ||
440 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
441 | }; | ||
442 | |||
// Out-of-class definition of the static map declared in MatMulTestBf16 (`static std::map<TestDataId, TestData> _data`).
// The test-data generator above stores each generated TestData under `_data[data_id]` and returns a reference to
// the stored entry.
// NOTE(review): presumably this exists so parameterized cases that share a TestDataId can reuse the generated
// buffers; the lookup side is outside this excerpt, so confirm against the start of the generator.
443 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
444 | 1 | std::map<MatMulTestBf16::TestDataId, MatMulTestBf16::TestData> MatMulTestBf16::_data; | |
445 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
446 | |||
447 | /// Tests the output. | ||
448 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1040 | TEST_P(MatMulTestBf16, Output) { |
449 | 11252 | const auto& [method, info, portion] = GetParam(); | |
450 | |||
451 |
2/4✓ Branch 0 taken 346 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 346 times.
✗ Branch 3 not taken.
|
346 | if (method.fn_is_supported && !method.fn_is_supported()) { |
452 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
453 | } | ||
454 | |||
455 |
1/2✓ Branch 0 taken 346 times.
✗ Branch 1 not taken.
|
346 | if (!method.has_main_kernel()) { |
456 | ✗ | GTEST_SKIP() << "No main kernel available"; | |
457 | } | ||
458 | |||
459 | 346 | const auto& data = test_data(); | |
460 | 692 | const auto m_step = method.fn_get_main_m_step(); | |
461 |
4/16✓ Branch 0 taken 346 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 346 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 346 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 346 times.
|
692 | ASSERT_TRUE(m_step % method.m0 == 0); |
462 | |||
463 | 692 | const auto n_step = method.fn_get_main_n_step(); | |
464 |
4/16✓ Branch 0 taken 346 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 346 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 346 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 346 times.
|
692 | ASSERT_TRUE(n_step % method.n0 == 0); |
465 | |||
466 | 1038 | const auto rect = portion.compute_portion(info.m, info.n, m_step, n_step); | |
467 | |||
468 |
4/4✓ Branch 0 taken 331 times.
✓ Branch 1 taken 15 times.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 322 times.
|
346 | if (rect.height() == 0 || rect.width() == 0) { |
469 |
9/18✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 24 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 24 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 24 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 24 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 24 times.
✗ Branch 17 not taken.
|
24 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
470 | } | ||
471 | |||
472 | 644 | const size_t lhs_w = info.k; | |
473 | 322 | const size_t rhs_w = rect.width(); | |
474 | 644 | const size_t bias_w = info.n; | |
475 | 644 | const size_t dst_w = info.n; | |
476 | 322 | const bool has_bias = (data.bias.size() > 0); | |
477 | |||
478 | 322 | const auto lhs_start_row = rect.start_row(); | |
479 | 644 | const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w); | |
480 | |||
481 | 1932 | const size_t lhs_packed_size = method.fn_get_packed_lhs_size(info.m, info.k, method.m0, method.k0, 1 /* sr */); | |
482 | 322 | Buffer lhs_data(lhs_packed_size); | |
483 | |||
484 |
2/4✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
|
644 | uintptr_t lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride); |
485 |
3/6✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
|
966 | uintptr_t lhs_packed_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); |
486 | |||
487 | KAI_UNUSED(lhs_offset); | ||
488 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
644 | method.fn_pack_lhs( |
489 |
4/8✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
|
322 | rect.height(), info.k, method.m0, method.k0, 1 /* sr */, 0 /* m_idx_start */, data.lhs.data() + lhs_offset, |
490 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
322 | lhs_stride, lhs_data.data() + lhs_packed_offset); |
491 | |||
492 |
3/6✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
|
966 | const auto rhs_stride = method.rhs_format.default_row_stride(info.n); |
493 | |||
494 | 322 | size_t rhs_packed_size = 0; | |
495 | |||
496 |
2/2✓ Branch 0 taken 184 times.
✓ Branch 1 taken 138 times.
|
322 | if (method.fn_get_packed_rhs_size_generic_block_size) { |
497 |
5/10✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 184 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 184 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 184 times.
✗ Branch 9 not taken.
|
920 | rhs_packed_size = method.fn_get_packed_rhs_size_generic_block_size(info.n, info.k, method.n0, method.k0); |
498 |
1/2✓ Branch 0 taken 138 times.
✗ Branch 1 not taken.
|
322 | } else if (method.fn_get_packed_rhs_size) { |
499 |
3/6✓ Branch 0 taken 138 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 138 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 138 times.
✗ Branch 5 not taken.
|
414 | rhs_packed_size = method.fn_get_packed_rhs_size(info.n, info.k); |
500 | 138 | } | |
501 | |||
502 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
322 | Buffer rhs_data(rhs_packed_size); |
503 | |||
504 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
322 | const auto packed_rhs_start_row = rect.start_col(); |
505 | 322 | const auto packed_rhs_start_col = 0; | |
506 | |||
507 |
3/6✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
|
644 | uintptr_t rhs_offset = method.fn_get_rhs_offset(rect.start_col()); |
508 |
3/6✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
|
966 | uintptr_t rhs_packed_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k); |
509 | 322 | const auto ref_rhs_packed_offset = | |
510 |
2/4✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
|
644 | method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k); |
511 | |||
512 |
4/16✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 322 times.
|
322 | ASSERT_EQ(rhs_packed_offset, ref_rhs_packed_offset); |
513 | |||
514 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
322 | uintptr_t bias_offset = sizeof(float) * rect.start_col(); |
515 | |||
516 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
644 | method.fn_pack_rhs( |
517 | 1, // num_groups | ||
518 | 1288 | rhs_w, info.k, method.n0, method.k0, | |
519 | 1, // sr | ||
520 |
4/6✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✓ Branch 3 taken 138 times.
✓ Branch 4 taken 184 times.
✗ Branch 5 not taken.
|
322 | rhs_stride, data.rhs.data() + rhs_offset, has_bias ? data.bias.data() + bias_offset : nullptr, |
521 | NULL, // Scale | ||
522 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
322 | rhs_data.data() + rhs_packed_offset, 0, NULL); |
523 | |||
524 |
2/2✓ Branch 0 taken 138 times.
✓ Branch 1 taken 184 times.
|
322 | if (has_bias) { |
525 |
3/6✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 184 times.
✗ Branch 5 not taken.
|
368 | const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w); |
526 |
4/16✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 184 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 184 times.
|
184 | ASSERT_EQ(ref_bias_offset, bias_offset); |
527 | 184 | } | |
528 | |||
529 |
2/4✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
|
644 | const auto dst_stride = method.dst_format.default_row_stride(dst_w); |
530 |
4/8✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
|
644 | const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
531 |
4/8✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
|
644 | const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w); |
532 |
4/16✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 322 times.
|
322 | ASSERT_EQ(dst_offset, ref_dst_offset); |
533 | |||
534 |
4/8✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
|
1288 | const auto dst_size = method.fn_get_dst_size(info.m, info.n); |
535 |
4/8✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
|
1288 | const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n); |
536 |
4/16✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 322 times.
|
322 | ASSERT_EQ(dst_size, ref_dst_size); |
537 | |||
538 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
322 | Buffer dst(dst_size); |
539 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
322 | method.main_kernel( |
540 |
4/8✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
|
322 | rect.height(), rect.width(), info.k, lhs_data.data() + lhs_packed_offset, rhs_data.data() + rhs_packed_offset, |
541 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
322 | NULL, dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, -std::numeric_limits<float>::infinity(), |
542 | 322 | std::numeric_limits<float>::infinity()); | |
543 | |||
544 |
1/2✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
|
322 | DefaultMismatchHandler handler(0, 0.02, 0, 0.05); |
545 |
5/10✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 322 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 322 times.
✗ Branch 9 not taken.
|
322 | const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler); |
546 | |||
547 |
4/16✓ Branch 0 taken 322 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 322 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 322 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 322 times.
|
322 | ASSERT_TRUE(success); |
548 | 346 | } | |
549 | |||
550 |
15/48✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 250 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
|
252 | INSTANTIATE_TEST_SUITE_P( |
551 | MatMulGemm, MatMulTestBf16, | ||
552 | testing::Combine( | ||
553 | testing::ValuesIn(get_gemm_methods()), | ||
554 | testing::Values( | ||
555 | MatMulShape{1, 1, 1}, // Smallest Possible Shape | ||
556 | MatMulShape{3, 7, 3}, // Smaller than block size | ||
557 | MatMulShape{12, 8, 4}, // Same block size | ||
558 | MatMulShape{1, 1, 1023}, // Long K | ||
559 | MatMulShape{1013, 1, 5}, // Long M | ||
560 | MatMulShape{2, 1013, 6}, // Long N | ||
561 | MatMulShape{13, 33, 23}, // | ||
562 | MatMulShape{93, 57, 89}, // | ||
563 | MatMulShape{256, 256, 256}, // Nice shapes | ||
564 | MatMulShape{257, 113, 373} // Prime numbers | ||
565 | ), | ||
566 | testing::Values( | ||
567 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
568 | MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. | ||
569 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
570 | MatrixPortion(0.75, 0, 1, 1), // Partial rows | ||
571 | MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle | ||
572 | )), | ||
573 | testing::PrintToStringParamName()); | ||
574 | |||
575 |
14/44✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 96 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
|
98 | INSTANTIATE_TEST_SUITE_P( |
576 | MatMulGemv, MatMulTestBf16, | ||
577 | testing::Combine( | ||
578 | testing::ValuesIn(get_gemv_methods()), | ||
579 | testing::Values( | ||
580 | MatMulShape{1, 1, 1}, // Smallest Possible Shape | ||
581 | MatMulShape{1, 1, 1023}, // Long K | ||
582 | MatMulShape{1, 1023, 1}, // Long N | ||
583 | MatMulShape{1, 1013, 1023}, // Large Rhs | ||
584 | MatMulShape{1, 37, 23}, // | ||
585 | MatMulShape{1, 57, 89}, // | ||
586 | MatMulShape{1, 36, 89}, // | ||
587 | MatMulShape{1, 98, 23}, // | ||
588 | MatMulShape{1, 64, 1024}, // Nice shapes - Long Rhs Rect | ||
589 | MatMulShape{1, 1024, 64}, // Nice shapes - Wide Rhs Rect | ||
590 | MatMulShape{1, 256, 256}, // Nice shapes - Square | ||
591 | MatMulShape{1, 113, 373} // Prime numbers | ||
592 | ), | ||
593 | testing::Values( | ||
594 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
595 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
596 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
597 | MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle | ||
598 | )), | ||
599 | testing::PrintToStringParamName()); | ||
600 | } // namespace kai::test | ||
601 |