Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/reference/matmul.hpp" | ||
8 | |||
9 | #include <gtest/gtest.h> | ||
10 | |||
11 | #include <array> | ||
12 | #include <cstddef> | ||
13 | #include <cstdint> | ||
14 | #include <functional> | ||
15 | #include <map> | ||
16 | #include <string_view> | ||
17 | #include <tuple> | ||
18 | #include <utility> | ||
19 | |||
20 | #include "kai/kai_common.h" | ||
21 | #include "test/common/buffer.hpp" | ||
22 | #include "test/common/compare.hpp" | ||
23 | #include "test/common/cpu_info.hpp" | ||
24 | #include "test/common/data_format.hpp" | ||
25 | #include "test/common/data_type.hpp" | ||
26 | #include "test/common/float16.hpp" | ||
27 | #include "test/common/matmul_test_common.hpp" | ||
28 | #include "test/common/matrix_portion.hpp" | ||
29 | #include "test/common/sme.hpp" | ||
30 | #include "test/reference/clamp.hpp" | ||
31 | #include "test/reference/fill.hpp" | ||
32 | #include "test/reference/pack.hpp" | ||
33 | #include "test/reference/transpose.hpp" | ||
34 | |||
35 | // matmul_clamp_f16_f16_f16p | ||
36 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" | ||
37 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h" | ||
38 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla.h" | ||
39 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla.h" | ||
40 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55.h" | ||
41 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" | ||
42 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p32x1b_x16_x16_neon.h" | ||
43 | |||
44 | // matmul_clamp_f16_f16p_f16p | ||
45 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" | ||
46 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h" | ||
47 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h" | ||
48 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h" | ||
49 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h" | ||
50 | |||
51 | // matmul_clamp_f32_f32_f32p | ||
52 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" | ||
53 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla.h" | ||
54 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55.h" | ||
55 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" | ||
56 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h" | ||
57 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" | ||
58 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h" | ||
59 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h" | ||
60 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x32p16x1b_x32_x32_neon.h" | ||
61 | |||
62 | // matmul_clamp_f32_f32p_f32p | ||
63 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" | ||
64 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" | ||
65 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" | ||
66 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" | ||
67 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h" | ||
68 | |||
69 | namespace kai::test { | ||
70 | |||
71 | 4 | static const auto& get_matmul_methods() { | |
72 | // List of supported matrix multiplication methods. | ||
73 | 3/4 ✓ Branch 0 taken 1 time. ✓ Branch 1 taken 3 times. ✗ Branch 2 not taken. ✓ Branch 3 taken 1 time. | 4 | static std::array<MatMulMethod, 10> matmul_methods{}; |
74 | |||
75 | matmul_methods[0].name = "matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla"; | ||
76 | matmul_methods[0].m0 = 6; | ||
77 | matmul_methods[0].n0 = 16; | ||
78 | matmul_methods[0].dst_format = DataFormat(DataType::FP16); | ||
79 | matmul_methods[0].lhs_format = DataFormat(DataType::FP16); | ||
80 | matmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
81 | matmul_methods[0].rhs_format = DataFormat(DataType::FP16); | ||
82 | matmul_methods[0].packed_rhs_format = DataFormat( | ||
83 | DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1); | ||
84 | matmul_methods[0].bias_format = DataFormat(DataType::FP16); | ||
85 | matmul_methods[0].fn_is_supported = cpu_has_fp16; | ||
86 | matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
87 | matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
88 | matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
89 | matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
90 | matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
91 | matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
92 | matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
93 | matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
94 | matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
95 | matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = | ||
96 | kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
97 | matmul_methods[0].fn_get_main_packed_rhs_offset = | ||
98 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
99 | matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
100 | matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
101 | matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
102 | matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
103 | matmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
104 | |||
105 | matmul_methods[1].name = "matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa"; | ||
106 | matmul_methods[1].m0 = 2 * get_sme_vector_length<float>(); | ||
107 | matmul_methods[1].n0 = 2 * get_sme_vector_length<float>(); | ||
108 | matmul_methods[1].dst_format = DataFormat(DataType::FP16); | ||
109 | matmul_methods[1].lhs_format = DataFormat(DataType::FP16); | ||
110 | matmul_methods[1].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length<float>(), 2); | ||
111 | matmul_methods[1].rhs_format = DataFormat(DataType::FP16); | ||
112 | matmul_methods[1].packed_rhs_format = DataFormat( | ||
113 | DataType::FP16, // Output type | ||
114 | 2 * get_sme_vector_length<float>(), 2, // Block size | ||
115 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
116 | DataType::FP16, // Bias format | ||
117 | DataType::UNKNOWN, // Scaling type | ||
118 | 2 * get_sme_vector_length<float>(), 2); // Sub-block | ||
119 | matmul_methods[1].bias_format = DataFormat(DataType::FP16); | ||
120 | matmul_methods[1].fn_is_supported = cpu_has_sme2; | ||
121 | matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
122 | matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
123 | matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
124 | matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
125 | matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
126 | matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
127 | matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
128 | matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; | ||
129 | matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme; | ||
130 | matmul_methods[1].fn_get_packed_lhs_offset = | ||
131 | kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
132 | matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme; | ||
133 | matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
134 | matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
135 | matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
136 | matmul_methods[1].fn_get_main_packed_rhs_offset = | ||
137 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
138 | matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
139 | matmul_methods[1].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
140 | matmul_methods[1].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
141 | matmul_methods[1].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
142 | matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_offset = | ||
143 | kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
144 | matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
145 | matmul_methods[1].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
146 | matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
147 | matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
148 | matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
149 | matmul_methods[1].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
150 | |||
151 | matmul_methods[2].name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla"; | ||
152 | matmul_methods[2].m0 = 6; | ||
153 | matmul_methods[2].n0 = 8; | ||
154 | matmul_methods[2].dst_format = DataFormat(DataType::FP32); | ||
155 | matmul_methods[2].lhs_format = DataFormat(DataType::FP32); | ||
156 | matmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
157 | matmul_methods[2].rhs_format = DataFormat(DataType::FP32); | ||
158 | matmul_methods[2].packed_rhs_format = | ||
159 | DataFormat(DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1); | ||
160 | matmul_methods[2].bias_format = DataFormat(DataType::FP32); | ||
161 | matmul_methods[2].fn_is_supported = cpu_has_advsimd; | ||
162 | matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
163 | matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
164 | matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
165 | matmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
166 | matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
167 | matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
168 | matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
169 | matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
170 | matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
171 | matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset = | ||
172 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
173 | matmul_methods[2].fn_get_main_packed_rhs_offset = | ||
174 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
175 | matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
176 | matmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
177 | matmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
178 | matmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
179 | matmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
180 | |||
181 | matmul_methods[3].name = "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"; | ||
182 | matmul_methods[3].m0 = 2 * get_sme_vector_length<float>(); | ||
183 | matmul_methods[3].n0 = 2 * get_sme_vector_length<float>(); | ||
184 | matmul_methods[3].dst_format = DataFormat(DataType::FP32); | ||
185 | matmul_methods[3].lhs_format = DataFormat(DataType::FP32); | ||
186 | matmul_methods[3].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length<float>(), 1); | ||
187 | matmul_methods[3].rhs_format = DataFormat(DataType::FP32); | ||
188 | matmul_methods[3].packed_rhs_format = DataFormat( | ||
189 | DataType::FP32, 2 * get_sme_vector_length<float>(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, | ||
190 | DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 1); | ||
191 | matmul_methods[3].bias_format = DataFormat(DataType::FP32); | ||
192 | matmul_methods[3].fn_is_supported = cpu_has_sme2; | ||
193 | matmul_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
194 | matmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
195 | matmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
196 | matmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
197 | matmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
198 | matmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
199 | matmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
200 | matmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme; | ||
201 | matmul_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme; | ||
202 | matmul_methods[3].fn_get_packed_lhs_offset = | ||
203 | kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
204 | matmul_methods[3].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme; | ||
205 | matmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
206 | matmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
207 | matmul_methods[3].fn_get_pack_rhs_packed_rhs_offset = | ||
208 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
209 | matmul_methods[3].fn_get_main_packed_rhs_offset = | ||
210 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
211 | matmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
212 | matmul_methods[3].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
213 | matmul_methods[3].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
214 | matmul_methods[3].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
215 | matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_offset = | ||
216 | kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
217 | matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_size = | ||
218 | kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
219 | matmul_methods[3].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
220 | matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
221 | matmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
222 | matmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
223 | matmul_methods[3].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
224 | |||
225 | matmul_methods[4].name = "matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa"; | ||
226 | matmul_methods[4].m0 = 2 * get_sme_vector_length<float>(); | ||
227 | matmul_methods[4].n0 = 2 * get_sme_vector_length<float>(); | ||
228 | matmul_methods[4].dst_format = DataFormat(DataType::FP32); | ||
229 | matmul_methods[4].lhs_format = DataFormat(DataType::FP32); | ||
230 | matmul_methods[4].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length<float>(), 1); | ||
231 | matmul_methods[4].rhs_format = DataFormat(DataType::FP32); | ||
232 | matmul_methods[4].packed_rhs_format = DataFormat( | ||
233 | DataType::FP32, 2 * get_sme_vector_length<float>(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, | ||
234 | DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 1); | ||
235 | matmul_methods[4].bias_format = DataFormat(DataType::FP32); | ||
236 | matmul_methods[4].fn_is_supported = cpu_has_sme; | ||
237 | matmul_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
238 | matmul_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
239 | matmul_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
240 | matmul_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
241 | matmul_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
242 | matmul_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
243 | matmul_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
244 | matmul_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme; | ||
245 | matmul_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme; | ||
246 | matmul_methods[4].fn_get_packed_lhs_offset = | ||
247 | kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
248 | matmul_methods[4].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme; | ||
249 | matmul_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
250 | matmul_methods[4].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
251 | matmul_methods[4].fn_get_pack_rhs_packed_rhs_offset = | ||
252 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
253 | matmul_methods[4].fn_get_main_packed_rhs_offset = | ||
254 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
255 | matmul_methods[4].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
256 | matmul_methods[4].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
257 | matmul_methods[4].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
258 | matmul_methods[4].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
259 | matmul_methods[4].fn_pack_rhs_nxk_get_packed_rhs_offset = | ||
260 | kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
261 | matmul_methods[4].fn_pack_rhs_nxk_get_packed_rhs_size = | ||
262 | kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
263 | matmul_methods[4].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
264 | matmul_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
265 | matmul_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
266 | matmul_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
267 | matmul_methods[4].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
268 | |||
269 | matmul_methods[5].name = "matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa"; | ||
270 | matmul_methods[5].m0 = 2 * get_sme_vector_length<float>(); | ||
271 | matmul_methods[5].n0 = 2 * get_sme_vector_length<float>(); | ||
272 | matmul_methods[5].dst_format = DataFormat(DataType::FP16); | ||
273 | matmul_methods[5].lhs_format = DataFormat(DataType::FP16); | ||
274 | matmul_methods[5].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length<float>(), 2); | ||
275 | matmul_methods[5].rhs_format = DataFormat(DataType::FP16); | ||
276 | matmul_methods[5].packed_rhs_format = DataFormat( | ||
277 | DataType::FP16, // Output type | ||
278 | 2 * get_sme_vector_length<float>(), 2, // Block size | ||
279 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
280 | DataType::FP16, // Bias format | ||
281 | DataType::UNKNOWN, // Scaling type | ||
282 | 2 * get_sme_vector_length<float>(), 2); // Sub-block | ||
283 | matmul_methods[5].bias_format = DataFormat(DataType::FP16); | ||
284 | matmul_methods[5].fn_is_supported = cpu_has_sme; | ||
285 | matmul_methods[5].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
286 | matmul_methods[5].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
287 | matmul_methods[5].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
288 | matmul_methods[5].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
289 | matmul_methods[5].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
290 | matmul_methods[5].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
291 | matmul_methods[5].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
292 | matmul_methods[5].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; | ||
293 | matmul_methods[5].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme; | ||
294 | matmul_methods[5].fn_get_packed_lhs_offset = | ||
295 | kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
296 | matmul_methods[5].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme; | ||
297 | matmul_methods[5].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
298 | matmul_methods[5].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
299 | matmul_methods[5].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
300 | matmul_methods[5].fn_get_main_packed_rhs_offset = | ||
301 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
302 | matmul_methods[5].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
303 | matmul_methods[5].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
304 | matmul_methods[5].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
305 | matmul_methods[5].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
306 | matmul_methods[5].fn_pack_rhs_nxk_get_packed_rhs_offset = | ||
307 | kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
308 | matmul_methods[5].fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
309 | matmul_methods[5].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
310 | matmul_methods[5].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
311 | matmul_methods[5].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
312 | matmul_methods[5].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
313 | matmul_methods[5].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
314 | |||
315 | matmul_methods[6].name = "matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla"; | ||
316 | matmul_methods[6].m0 = 6; | ||
317 | matmul_methods[6].n0 = 32; | ||
318 | matmul_methods[6].dst_format = DataFormat(DataType::FP16); | ||
319 | matmul_methods[6].lhs_format = DataFormat(DataType::FP16); | ||
320 | matmul_methods[6].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
321 | matmul_methods[6].rhs_format = DataFormat(DataType::FP16); | ||
322 | matmul_methods[6].packed_rhs_format = DataFormat( | ||
323 | DataType::FP16, 32, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 32, 1); | ||
324 | matmul_methods[6].bias_format = DataFormat(DataType::FP16); | ||
325 | matmul_methods[6].fn_is_supported = cpu_has_fp16; | ||
326 | matmul_methods[6].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
327 | matmul_methods[6].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
328 | matmul_methods[6].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
329 | matmul_methods[6].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
330 | matmul_methods[6].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
331 | matmul_methods[6].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
332 | matmul_methods[6].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
333 | matmul_methods[6].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
334 | matmul_methods[6].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
335 | matmul_methods[6].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
336 | matmul_methods[6].fn_get_main_packed_rhs_offset = | ||
337 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
338 | matmul_methods[6].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
339 | matmul_methods[6].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
340 | matmul_methods[6].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
341 | matmul_methods[6].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
342 | matmul_methods[6].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
343 | |||
344 | matmul_methods[7].name = "matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55"; | ||
345 | matmul_methods[7].m0 = 6; | ||
346 | matmul_methods[7].n0 = 32; | ||
347 | matmul_methods[7].dst_format = DataFormat(DataType::FP16); | ||
348 | matmul_methods[7].lhs_format = DataFormat(DataType::FP16); | ||
349 | matmul_methods[7].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
350 | matmul_methods[7].rhs_format = DataFormat(DataType::FP16); | ||
351 | matmul_methods[7].packed_rhs_format = DataFormat( | ||
352 | DataType::FP16, 32, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 32, 1); | ||
353 | matmul_methods[7].bias_format = DataFormat(DataType::FP16); | ||
354 | matmul_methods[7].fn_is_supported = cpu_has_fp16; | ||
355 | matmul_methods[7].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
356 | matmul_methods[7].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
357 | matmul_methods[7].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
358 | matmul_methods[7].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
359 | matmul_methods[7].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
360 | matmul_methods[7].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
361 | matmul_methods[7].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
362 | matmul_methods[7].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
363 | matmul_methods[7].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
364 | matmul_methods[7].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
365 | matmul_methods[7].fn_get_main_packed_rhs_offset = | ||
366 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
367 | matmul_methods[7].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
368 | matmul_methods[7].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
369 | matmul_methods[7].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
370 | matmul_methods[7].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
371 | matmul_methods[7].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
372 | |||
373 | matmul_methods[8].name = "matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla"; | ||
374 | matmul_methods[8].m0 = 6; | ||
375 | matmul_methods[8].n0 = 16; | ||
376 | matmul_methods[8].dst_format = DataFormat(DataType::FP32); | ||
377 | matmul_methods[8].lhs_format = DataFormat(DataType::FP32); | ||
378 | matmul_methods[8].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
379 | matmul_methods[8].rhs_format = DataFormat(DataType::FP32); | ||
380 | matmul_methods[8].packed_rhs_format = DataFormat( | ||
381 | DataType::FP32, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 16, 1); | ||
382 | matmul_methods[8].bias_format = DataFormat(DataType::FP32); | ||
383 | matmul_methods[8].fn_is_supported = cpu_has_advsimd; | ||
384 | matmul_methods[8].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
385 | matmul_methods[8].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
386 | matmul_methods[8].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
387 | matmul_methods[8].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
388 | matmul_methods[8].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
389 | matmul_methods[8].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
390 | matmul_methods[8].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
391 | matmul_methods[8].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
392 | matmul_methods[8].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
393 | matmul_methods[8].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
394 | matmul_methods[8].fn_get_main_packed_rhs_offset = | ||
395 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
396 | matmul_methods[8].fn_pack_rhs = kai_run_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
397 | matmul_methods[8].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
398 | matmul_methods[8].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
399 | matmul_methods[8].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
400 | matmul_methods[8].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
401 | |||
402 | matmul_methods[9].name = "matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55"; | ||
403 | matmul_methods[9].m0 = 6; | ||
404 | matmul_methods[9].n0 = 16; | ||
405 | matmul_methods[9].dst_format = DataFormat(DataType::FP32); | ||
406 | matmul_methods[9].lhs_format = DataFormat(DataType::FP32); | ||
407 | matmul_methods[9].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
408 | matmul_methods[9].rhs_format = DataFormat(DataType::FP32); | ||
409 | matmul_methods[9].packed_rhs_format = DataFormat( | ||
410 | DataType::FP32, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 16, 1); | ||
411 | matmul_methods[9].bias_format = DataFormat(DataType::FP32); | ||
412 | matmul_methods[9].fn_is_supported = cpu_has_advsimd; | ||
413 | matmul_methods[9].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
414 | matmul_methods[9].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
415 | matmul_methods[9].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
416 | matmul_methods[9].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
417 | matmul_methods[9].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
418 | matmul_methods[9].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
419 | matmul_methods[9].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
420 | matmul_methods[9].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
421 | matmul_methods[9].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
422 | matmul_methods[9].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
423 | matmul_methods[9].fn_get_main_packed_rhs_offset = | ||
424 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
425 | matmul_methods[9].fn_pack_rhs = kai_run_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
426 | matmul_methods[9].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
427 | matmul_methods[9].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
428 | matmul_methods[9].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
429 | matmul_methods[9].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
430 | |||
431 | return matmul_methods; | ||
432 | } | ||
433 | |||
434 | 4 | static const auto& get_vecmul_methods() { | |
435 | // List of supported vector by matrix multiplication methods | ||
436 |
3/4✓ Branch 0 taken 1 time.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
4 | static std::array<MatMulMethod, 5> vecmul_methods{}; |
437 | |||
438 | vecmul_methods[0].name = "matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot"; | ||
439 | vecmul_methods[0].m0 = 1; | ||
440 | vecmul_methods[0].n0 = 16 * get_sme_vector_length<float>(); | ||
441 | vecmul_methods[0].dst_format = DataFormat(DataType::FP16); | ||
442 | vecmul_methods[0].lhs_format = DataFormat(DataType::FP16); | ||
443 | vecmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
444 | vecmul_methods[0].rhs_format = DataFormat(DataType::FP16); | ||
445 | vecmul_methods[0].packed_rhs_format = DataFormat( | ||
446 | DataType::FP16, // Output type | ||
447 | 2 * get_sme_vector_length<float>(), 2, // Block size | ||
448 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
449 | DataType::FP16, // Bias format | ||
450 | DataType::UNKNOWN, // Scaling type | ||
451 | 2 * get_sme_vector_length<float>(), 2); // Sub-block | ||
452 | vecmul_methods[0].bias_format = DataFormat(DataType::FP16); | ||
453 | vecmul_methods[0].fn_is_supported = cpu_has_sme2; | ||
454 | vecmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
455 | vecmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
456 | vecmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
457 | vecmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
458 | vecmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
459 | vecmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
460 | vecmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; | ||
461 | vecmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
462 | vecmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
463 | vecmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
464 | vecmul_methods[0].fn_get_main_packed_rhs_offset = | ||
465 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
466 | vecmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
467 | vecmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
468 | vecmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
469 | vecmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
470 | vecmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
471 | |||
472 | vecmul_methods[1].name = "matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla"; | ||
473 | vecmul_methods[1].m0 = 1; | ||
474 | vecmul_methods[1].n0 = 8 * get_sme_vector_length<float>(); | ||
475 | vecmul_methods[1].dst_format = DataFormat(DataType::FP16); | ||
476 | vecmul_methods[1].lhs_format = DataFormat(DataType::FP16); | ||
477 | vecmul_methods[1].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
478 | vecmul_methods[1].rhs_format = DataFormat(DataType::FP16); | ||
479 | vecmul_methods[1].packed_rhs_format = DataFormat( | ||
480 | DataType::FP16, // Output type | ||
481 | 2 * get_sme_vector_length<float>(), 2, // Block size | ||
482 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
483 | DataType::FP16, // Bias format | ||
484 | DataType::UNKNOWN, // Scaling type | ||
485 | 2 * get_sme_vector_length<float>(), 2); // Sub-block | ||
486 | vecmul_methods[1].bias_format = DataFormat(DataType::FP16); | ||
487 | vecmul_methods[1].fn_is_supported = cpu_has_sme; | ||
488 | vecmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
489 | vecmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
490 | vecmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
491 | vecmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
492 | vecmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
493 | vecmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
494 | vecmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; | ||
495 | vecmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
496 | vecmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
497 | vecmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
498 | vecmul_methods[1].fn_get_main_packed_rhs_offset = | ||
499 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
500 | vecmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
501 | vecmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
502 | vecmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
503 | vecmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
504 | vecmul_methods[1].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
505 | |||
506 | vecmul_methods[2].name = "matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla"; | ||
507 | vecmul_methods[2].m0 = 1; | ||
508 | vecmul_methods[2].n0 = 8 * get_sme_vector_length<float>(); | ||
509 | vecmul_methods[2].dst_format = DataFormat(DataType::FP32); | ||
510 | vecmul_methods[2].lhs_format = DataFormat(DataType::FP32); | ||
511 | vecmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
512 | vecmul_methods[2].rhs_format = DataFormat(DataType::FP32); | ||
513 | vecmul_methods[2].packed_rhs_format = DataFormat( | ||
514 | DataType::FP32, // Output type | ||
515 | 2 * get_sme_vector_length<float>(), 1, // Block size | ||
516 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
517 | DataType::FP32, // Bias format | ||
518 | DataType::UNKNOWN, // Scaling type | ||
519 | 2 * get_sme_vector_length<float>(), 1); // Sub-block | ||
520 | vecmul_methods[2].bias_format = DataFormat(DataType::FP32); | ||
521 | vecmul_methods[2].fn_is_supported = cpu_has_sme; | ||
522 | vecmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
523 | vecmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
524 | vecmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
525 | vecmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
526 | vecmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
527 | vecmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
528 | vecmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
529 | vecmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
530 | vecmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
531 | vecmul_methods[2].fn_get_pack_rhs_packed_rhs_offset = | ||
532 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
533 | vecmul_methods[2].fn_get_main_packed_rhs_offset = | ||
534 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
535 | vecmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
536 | vecmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
537 | vecmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
538 | vecmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
539 | vecmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
540 | |||
541 | vecmul_methods[3].name = "matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla"; | ||
542 | vecmul_methods[3].m0 = 1; | ||
543 | vecmul_methods[3].n0 = 16 * get_sme_vector_length<float>(); | ||
544 | vecmul_methods[3].dst_format = DataFormat(DataType::FP32); | ||
545 | vecmul_methods[3].lhs_format = DataFormat(DataType::FP32); | ||
546 | vecmul_methods[3].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
547 | vecmul_methods[3].rhs_format = DataFormat(DataType::FP32); | ||
548 | vecmul_methods[3].packed_rhs_format = DataFormat( | ||
549 | DataType::FP32, // Output type | ||
550 | 2 * get_sme_vector_length<float>(), 1, // Block size | ||
551 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
552 | DataType::FP32, // Bias format | ||
553 | DataType::UNKNOWN, // Scaling type | ||
554 | 2 * get_sme_vector_length<float>(), 1); // Sub-block | ||
555 | vecmul_methods[3].bias_format = DataFormat(DataType::FP32); | ||
556 | vecmul_methods[3].fn_is_supported = cpu_has_sme2; | ||
557 | vecmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
558 | vecmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
559 | vecmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
560 | vecmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
561 | vecmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
562 | vecmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
563 | vecmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
564 | vecmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
565 | vecmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
566 | vecmul_methods[3].fn_get_pack_rhs_packed_rhs_offset = | ||
567 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
568 | vecmul_methods[3].fn_get_main_packed_rhs_offset = | ||
569 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
570 | vecmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
571 | vecmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
572 | vecmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
573 | vecmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
574 | vecmul_methods[3].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
575 | |||
576 | vecmul_methods[4].name = "matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla"; | ||
577 | vecmul_methods[4].m0 = 1; | ||
578 | vecmul_methods[4].n0 = 16 * get_sme_vector_length<float>(); | ||
579 | vecmul_methods[4].dst_format = DataFormat(DataType::FP32); | ||
580 | vecmul_methods[4].lhs_format = DataFormat(DataType::FP32); | ||
581 | vecmul_methods[4].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
582 | vecmul_methods[4].rhs_format = DataFormat(DataType::FP32); | ||
583 | vecmul_methods[4].packed_rhs_format = DataFormat( | ||
584 | DataType::FP32, // Output type | ||
585 | 16 * get_sme_vector_length<float>(), 1, // Block size | ||
586 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
587 | DataType::FP32, // Bias format | ||
588 | DataType::UNKNOWN, // Scaling type | ||
589 | 16 * get_sme_vector_length<float>(), 1); // Sub-block | ||
590 | vecmul_methods[4].bias_format = DataFormat(DataType::FP32); | ||
591 | vecmul_methods[4].fn_is_supported = cpu_has_sme2; | ||
592 | vecmul_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
593 | vecmul_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
594 | vecmul_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
595 | vecmul_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
596 | vecmul_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
597 | vecmul_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
598 | vecmul_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
599 | vecmul_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
600 | vecmul_methods[4].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
601 | vecmul_methods[4].fn_get_pack_rhs_packed_rhs_offset = | ||
602 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
603 | vecmul_methods[4].fn_get_main_packed_rhs_offset = | ||
604 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
605 | vecmul_methods[4].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
606 | vecmul_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
607 | vecmul_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
608 | vecmul_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
609 | vecmul_methods[4].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
610 | |||
611 | return vecmul_methods; | ||
612 | } | ||
613 | |||
614 | /// Matrix multiplication test fixture. | ||
615 | class MatMulTest : public testing::TestWithParam<MatMulTestParams> { | ||
616 | private: | ||
617 | /// Unique ID: m, n, k, method_id. | ||
618 | using TestDataId = std::tuple<size_t, size_t, size_t, std::string_view>; | ||
619 | |||
620 | protected: | ||
621 | /// Cached test data that is shared between multiple test case. | ||
622 | 125 | struct TestData { | |
623 | 250 | Buffer lhs{}; ///< LHS operand. | |
624 | 250 | Buffer ref_packed_lhs{}; ///< Reference packed LHS. | |
625 | 250 | Buffer rhs{}; ///< RHS operand. | |
626 | 250 | Buffer rhs_scales{}; ///< RHS per-row quantization scales. | |
627 | 250 | Buffer bias{}; ///< Bias. | |
628 | 250 | Buffer rhs_t{}; ///< Transposed RHS matrix. | |
629 | 250 | Buffer ref_packed_rhs{}; ///< Reference packed RHS. | |
630 | 250 | Buffer ref_dst{}; ///< Reference output. | |
631 | 125 | float clamp_min{}; ///< Minimum output value. | |
632 | 125 | float clamp_max{}; ///< Maximum output value. | |
633 | }; | ||
634 | |||
635 | /// Gets the test data for the current test case. | ||
636 | 1024 | static const TestData& test_data() { | |
637 | 7442 | const auto& [method, info, portion] = GetParam(); | |
638 | 5120 | const TestDataId data_id{info.m, info.n, info.k, method.name}; | |
639 | |||
640 | // If the test data is already available, returns it. | ||
641 | 1024 | const auto data_it = _data.find(data_id); | |
642 | |||
643 |
2/2✓ Branch 0 taken 899 times.
✓ Branch 1 taken 125 times.
|
1024 | if (data_it != _data.end()) { |
644 | 899 | return data_it->second; | |
645 | } | ||
646 | |||
647 | // Generates the test data. | ||
648 | 250 | const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN; | |
649 | 250 | const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; | |
650 | 250 | const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; | |
651 | |||
652 | 250 | const auto lhs_h = info.m; | |
653 | 250 | const auto lhs_w = info.k; | |
654 | 250 | auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, 0); | |
655 | 125 | Buffer ref_packed_lhs; | |
656 | |||
657 |
2/2✓ Branch 0 taken 101 times.
✓ Branch 1 taken 24 times.
|
125 | if (has_lhs_pack) { |
658 | 24 | ref_packed_lhs = | |
659 |
3/6✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
|
48 | pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); |
660 | 24 | } | |
661 | |||
662 | 250 | const auto rhs_h = info.k; | |
663 | 250 | const auto rhs_w = info.n; | |
664 |
2/4✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
|
250 | auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, 1); |
665 | |||
666 | − | KAI_ASSUME(method.rhs_format.is_raw()); | |
667 |
3/6✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
|
125 | auto rhs_t = transpose(rhs.data(), method.rhs_format.data_type(), rhs_h, rhs_w); |
668 | |||
669 | 125 | Buffer rhs_scales; | |
670 |
3/8✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 125 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
125 | if (data_type_is_quantized(method.rhs_format.data_type()) && |
671 | ✗ | method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) { | |
672 | ✗ | rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), 2); | |
673 | ✗ | } | |
674 | |||
675 | 125 | const auto bias_h = 1; | |
676 | 250 | const auto bias_w = info.n; | |
677 | 125 | Buffer bias; | |
678 | |||
679 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 125 times.
|
125 | if (has_bias) { |
680 |
2/4✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
|
250 | bias = fill_matrix_random(bias_h, bias_w, method.bias_format, 3); |
681 | 125 | } | |
682 | |||
683 | 125 | Buffer packed_rhs; | |
684 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 125 times.
|
125 | if (has_rhs_pack) { |
685 |
1/2✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
|
250 | packed_rhs = matmul_pack_rhs( |
686 |
3/6✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
|
125 | rhs.data(), rhs_scales.data(), bias.data(), method.rhs_format, method.packed_rhs_format, info.n, info.k, |
687 | true); | ||
688 | 125 | } | |
689 | |||
690 | − | KAI_ASSUME(method.lhs_format.is_raw()); | |
691 | − | KAI_ASSUME(method.rhs_format.is_raw()); | |
692 | − | KAI_ASSUME(method.dst_format.is_raw()); | |
693 |
1/2✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
|
250 | auto ref_dst = matmul( |
694 |
2/4✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
|
125 | lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), // |
695 |
3/6✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
|
125 | rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // |
696 |
2/4✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
|
125 | bias.data(), nullptr, nullptr, method.bias_format.data_type(), // |
697 |
1/2✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
|
125 | method.dst_format.data_type(), // |
698 | 375 | info.m, info.n, info.k, false, false); | |
699 | |||
700 | static constexpr float clamp_ratio = 0.8F; | ||
701 | 750 | const auto [clamp_min, clamp_max] = | |
702 |
4/8✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125 times.
✗ Branch 7 not taken.
|
125 | find_clamp_range(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_ratio); |
703 |
7/14✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 125 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 125 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 125 times.
✗ Branch 13 not taken.
|
250 | ref_dst = clamp(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_min, clamp_max); |
704 | |||
705 |
9/18✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 125 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 125 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 125 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 125 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 125 times.
✗ Branch 17 not taken.
|
1125 | auto& data = _data[data_id] = {}; |
706 | 125 | data.lhs = std::move(lhs); | |
707 | 125 | data.ref_packed_lhs = std::move(ref_packed_lhs); | |
708 | 125 | data.rhs = std::move(rhs); | |
709 | 125 | data.rhs_scales = std::move(rhs_scales); | |
710 | 125 | data.bias = std::move(bias); | |
711 | 125 | data.rhs_t = std::move(rhs_t); | |
712 | 125 | data.ref_packed_rhs = std::move(packed_rhs); | |
713 | 125 | data.ref_dst = std::move(ref_dst); | |
714 | 125 | data.clamp_min = clamp_min; | |
715 | 125 | data.clamp_max = clamp_max; | |
716 | |||
717 | 125 | return data; | |
718 | 1024 | } | |
719 | |||
720 | private: | ||
721 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
722 | static std::map<TestDataId, TestData> _data; | ||
723 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
724 | }; | ||
725 | |||
726 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
727 | 1 | std::map<MatMulTest::TestDataId, MatMulTest::TestData> MatMulTest::_data; | |
728 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
729 | |||
730 | /// Tests the LHS packing micro-kernel. | ||
731 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1322 | TEST_P(MatMulTest, PackedLhs) { |
732 | 1460 | const auto& [method, info, portion] = GetParam(); | |
733 | |||
734 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | if (method.fn_is_supported && !method.fn_is_supported()) { |
735 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
736 | } | ||
737 | |||
738 |
2/2✓ Branch 0 taken 72 times.
✓ Branch 1 taken 368 times.
|
440 | if (!method.is_pack_lhs_needed()) { |
739 |
3/6✓ Branch 0 taken 368 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 368 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 368 times.
✗ Branch 5 not taken.
|
368 | GTEST_SKIP() << "Test not valid w/o LHS pack"; |
740 | } | ||
741 | |||
742 | 72 | const auto& data = test_data(); | |
743 | 144 | const auto lhs_h = info.m; | |
744 | 144 | const auto lhs_w = info.k; | |
745 | |||
746 | 144 | const auto rect = portion.compute_portion( | |
747 | 144 | lhs_h, lhs_w, method.packed_lhs_format.scheduler_block_height(lhs_h), | |
748 | 72 | lhs_w); // LHS packing micro-kernel API doesn't support scheduling over K dimension. | |
749 | |||
750 |
3/4✓ Branch 0 taken 68 times.
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 68 times.
|
72 | if (rect.height() == 0 || rect.width() == 0) { |
751 |
9/18✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
|
4 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
752 | } | ||
753 | |||
754 | 136 | const auto mr = method.fn_get_mr(); | |
755 | 136 | const auto kr = method.fn_get_kr(); | |
756 | 136 | const auto sr = method.fn_get_sr(); | |
757 | 136 | const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(lhs_w); | |
758 | |||
759 | 272 | const auto packed_lhs_size = method.fn_get_packed_lhs_size(info.m, info.k, mr, kr, sr); | |
760 | 136 | const auto ref_packed_lhs_size = method.packed_lhs_format.default_size_in_bytes(lhs_h, lhs_w); | |
761 |
3/14✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
|
68 | ASSERT_EQ(packed_lhs_size, ref_packed_lhs_size); |
762 | |||
763 | 136 | const auto lhs_offset = method.fn_get_lhs_offset(rect.start_row(), ref_lhs_row_stride); | |
764 | 136 | const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), lhs_w); | |
765 |
3/14✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
|
68 | ASSERT_EQ(lhs_offset, ref_lhs_offset); |
766 | |||
767 | 204 | const auto packed_lhs_offset = method.fn_get_packed_lhs_offset(rect.start_row(), info.k); | |
768 | 136 | const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w); | |
769 |
3/14✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
|
68 | ASSERT_EQ(packed_lhs_offset, ref_packed_lhs_offset); |
770 | |||
771 | 68 | Buffer packed_lhs(packed_lhs_size, 0); | |
772 |
1/2✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
|
136 | method.fn_pack_lhs( |
773 |
3/6✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
|
68 | rect.height(), rect.width(), mr, kr, sr, 0, data.lhs.data() + lhs_offset, ref_lhs_row_stride, |
774 |
1/2✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
|
68 | packed_lhs.data() + packed_lhs_offset); |
775 | |||
776 |
1/2✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
|
68 | DefaultMismatchHandler handler(0, 0.0001, 0, 0.001); |
777 | 136 | const auto success = | |
778 |
3/6✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
|
68 | compare(packed_lhs.data(), data.ref_packed_lhs.data(), method.packed_lhs_format, lhs_h, lhs_w, rect, handler); |
779 |
4/16✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 68 times.
|
68 | ASSERT_TRUE(success); |
780 | 440 | } | |
781 | |||
782 | /// Tests the RHS packing micro-kernel. | ||
783 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1322 | TEST_P(MatMulTest, PackedRhs) { |
784 | 6040 | const auto& [method, info, portion] = GetParam(); | |
785 | |||
786 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | if (method.fn_is_supported && !method.fn_is_supported()) { |
787 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
788 | } | ||
789 | |||
790 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | if (!method.is_pack_rhs_needed()) { |
791 | ✗ | GTEST_SKIP() << "Test not valid w/o RHS pack"; | |
792 | } | ||
793 | |||
794 | 440 | const auto& data = test_data(); | |
795 | 880 | const auto rhs_full_width = info.n; | |
796 | 880 | const auto rhs_full_height = info.k; | |
797 | |||
798 | 880 | const auto block_height = method.packed_rhs_format.scheduler_block_height(rhs_full_width); | |
799 | 880 | const auto block_width = method.packed_rhs_format.scheduler_block_width(rhs_full_height); | |
800 | |||
801 | 880 | const Rect rect = portion.compute_portion(rhs_full_width, rhs_full_height, block_height, block_width); | |
802 | |||
803 |
4/4✓ Branch 0 taken 430 times.
✓ Branch 1 taken 10 times.
✓ Branch 2 taken 30 times.
✓ Branch 3 taken 400 times.
|
440 | if (rect.height() == 0 || rect.width() == 0) { |
804 |
9/18✓ Branch 0 taken 40 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 40 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 40 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 40 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 40 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 40 times.
✗ Branch 17 not taken.
|
40 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
805 | } | ||
806 | |||
807 | 400 | const auto rhs_start_row = rect.start_row(); | |
808 | 400 | const auto rhs_start_col = rect.start_col(); | |
809 | 400 | const auto width = rect.width(); | |
810 | 400 | const auto height = rect.height(); | |
811 | 800 | const auto rhs_row_stride = method.rhs_format.default_row_stride(rhs_full_width); | |
812 | |||
813 | /** Ensure that all relevant parameters are sane **/ | ||
814 | 800 | const auto n_step = method.fn_get_pack_rhs_n_step(); | |
815 | 400 | const auto ref_n_step = block_height; | |
816 |
3/14✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 400 times.
|
400 | ASSERT_EQ(n_step, ref_n_step); |
817 | |||
818 | 800 | const auto rhs_offset = method.fn_get_rhs_offset(rhs_start_row); | |
819 | 800 | const auto ref_rhs_offset = | |
820 | 400 | method.rhs_format.default_offset_in_bytes(rhs_start_col, rhs_start_row, rhs_full_height); | |
821 |
3/14✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 400 times.
|
400 | ASSERT_EQ(rhs_offset, ref_rhs_offset); |
822 | |||
823 | 800 | const auto packed_rhs_size = method.fn_get_packed_rhs_size(rhs_full_width, rhs_full_height); | |
824 | 800 | const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(rhs_full_width, rhs_full_height); | |
825 |
3/14✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 400 times.
|
400 | ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size); |
826 | |||
827 | 800 | const auto packed_rhs_offset = method.fn_get_pack_rhs_packed_rhs_offset(rhs_start_row, rhs_full_height); | |
828 | 800 | const auto ref_packed_rhs_offset = | |
829 | 400 | method.packed_rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_full_height); | |
830 |
3/14✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 400 times.
|
400 | ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset); |
831 | |||
832 | 800 | const auto scale_type = method.packed_rhs_format.scale_data_type(); | |
833 | 400 | const auto ref_rhs_scales_offset = rhs_start_row * data_type_size_in_bits(scale_type) / 8; | |
834 | |||
835 | 800 | const auto bias_offset = method.fn_get_bias_offset(rhs_start_row); | |
836 | 800 | const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rhs_start_row, rhs_full_height); | |
837 |
3/14✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 400 times.
|
400 | ASSERT_EQ(bias_offset, ref_bias_offset); |
838 | |||
839 | /** Perform RHS packing, and compare with reference result **/ | ||
840 | 400 | Buffer packed_rhs(packed_rhs_size, 0); | |
841 |
1/2✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
|
400 | method.pack_rhs( |
842 |
2/4✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
|
400 | height, width, data.rhs.data() + rhs_offset, rhs_row_stride, data.bias.data() + bias_offset, |
843 |
2/6✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 400 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
400 | data.rhs_scales.data() != nullptr ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr, |
844 |
1/2✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
|
400 | packed_rhs.data() + packed_rhs_offset); |
845 | |||
846 |
2/4✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
|
800 | const bool exact = method.packed_rhs_format.pack_format() != DataFormat::PackFormat::QUANTIZE_PER_ROW; |
847 |
1/2✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
|
400 | DefaultMismatchHandler handler(0, exact ? 0 : 0.0001, 0, exact ? 0 : 0.001); |
848 |
1/2✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
|
800 | const auto success = compare( |
849 |
2/4✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
|
400 | packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, rhs_full_width, rhs_full_height, rect, |
850 | handler); | ||
851 |
4/16✓ Branch 0 taken 400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 400 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 400 times.
|
400 | ASSERT_TRUE(success); |
852 | 440 | } | |
853 | |||
854 | /// Tests the transposed RHS packing micro-kernel. | ||
855 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1322 | TEST_P(MatMulTest, PackedTransposedRhs) { |
856 | 2504 | const auto& [method, info, portion] = GetParam(); | |
857 | |||
858 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | if (method.fn_is_supported && !method.fn_is_supported()) { |
859 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
860 | } | ||
861 | |||
862 |
2/2✓ Branch 0 taken 72 times.
✓ Branch 1 taken 368 times.
|
440 | if (!method.is_pack_rhs_nxk_needed()) { |
863 |
3/6✓ Branch 0 taken 368 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 368 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 368 times.
✗ Branch 5 not taken.
|
368 | GTEST_SKIP() << "Test not valid w/o pre-processing of transposed RHS matrix"; |
864 | } | ||
865 | |||
866 | 72 | const auto& data = test_data(); | |
867 | 144 | const auto n_step = method.fn_pack_rhs_nxk_get_n_step(); | |
868 | 216 | const auto ref_n_step = method.packed_rhs_format.scheduler_block_height(info.n); | |
869 |
3/14✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 72 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 72 times.
|
72 | ASSERT_EQ(n_step, ref_n_step); |
870 | |||
871 | 144 | const auto rect = portion.compute_portion( | |
872 | 288 | info.n, info.k, method.packed_rhs_format.scheduler_block_height(info.n), | |
873 | 144 | method.packed_rhs_format.scheduler_block_width(info.k)); | |
874 | |||
875 |
3/4✓ Branch 0 taken 68 times.
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 68 times.
|
72 | if (rect.height() == 0 || rect.width() == 0) { |
876 |
9/18✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
|
4 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
877 | } | ||
878 | |||
879 | 204 | const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(info.k); | |
880 | |||
881 | 136 | const auto rhs_offset = method.fn_pack_rhs_nxk_get_rhs_offset(rect.start_row(), ref_rhs_row_stride); | |
882 | 204 | const auto ref_rhs_offset = method.rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.k); | |
883 |
3/14✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
|
68 | ASSERT_EQ(rhs_offset, ref_rhs_offset); |
884 | |||
885 | 272 | const auto packed_rhs_size = method.fn_pack_rhs_nxk_get_packed_rhs_size(info.n, info.k); | |
886 | 272 | const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(info.n, info.k); | |
887 |
3/14✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
|
68 | ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size); |
888 | |||
889 | 204 | const auto packed_rhs_offset = method.fn_pack_rhs_nxk_get_packed_rhs_offset(rect.start_row(), info.k); | |
890 | 136 | const auto ref_packed_rhs_offset = | |
891 | 136 | method.packed_rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.k); | |
892 |
3/14✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
|
68 | ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset); |
893 | |||
894 | 136 | const auto ref_rhs_scales_offset = | |
895 | 136 | rect.start_row() * data_type_size_in_bits(method.packed_rhs_format.scale_data_type()) / 8; | |
896 | |||
897 | 136 | const auto bias_offset = method.fn_get_bias_offset(rect.start_row()); | |
898 | 204 | const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_row(), info.n); | |
899 |
3/14✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 68 times.
|
68 | ASSERT_EQ(bias_offset, ref_bias_offset); |
900 | |||
901 | 68 | Buffer packed_rhs(packed_rhs_size, 0); | |
902 | |||
903 |
1/2✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
|
68 | method.pack_rhs_nxk( |
904 |
4/8✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 68 times.
✗ Branch 7 not taken.
|
68 | rect.height(), rect.width(), data.rhs_t.data() + rhs_offset, ref_rhs_row_stride, data.bias.data() + bias_offset, |
905 |
2/6✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 68 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
68 | data.rhs_scales.data() != nullptr ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr, |
906 |
1/2✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
|
68 | packed_rhs.data() + packed_rhs_offset); |
907 | |||
908 |
2/4✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
|
136 | const auto exact = method.packed_rhs_format.pack_format() != DataFormat::PackFormat::QUANTIZE_PER_ROW; |
909 |
1/2✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
|
68 | DefaultMismatchHandler handler(0, exact ? 0 : 0.0001, 0, exact ? 0 : 0.001); |
910 | 136 | const auto success = | |
911 |
5/10✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 68 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 68 times.
✗ Branch 9 not taken.
|
68 | compare(packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, info.n, info.k, rect, handler); |
912 |
4/16✓ Branch 0 taken 68 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 68 times.
|
68 | ASSERT_TRUE(success); |
913 | 440 | } | |
914 | |||
915 | /// Tests the output. | ||
916 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1322 | TEST_P(MatMulTest, Output) { |
917 | 12050 | const auto& [method, info, portion] = GetParam(); | |
918 | |||
919 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | if (method.fn_is_supported && !method.fn_is_supported()) { |
920 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
921 | } | ||
922 | |||
923 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | if (!method.has_main_kernel()) { |
924 | ✗ | GTEST_SKIP() << "No main kernel available"; | |
925 | } | ||
926 | |||
927 | 440 | const auto& data = test_data(); | |
928 | 880 | const auto m_step = method.fn_get_main_m_step(); | |
929 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
|
880 | ASSERT_EQ(m_step, method.m0); |
930 | |||
931 | 880 | const auto n_step = method.fn_get_main_n_step(); | |
932 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
|
880 | ASSERT_EQ(n_step, method.n0); |
933 | |||
934 | 2200 | const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0); | |
935 | |||
936 |
4/4✓ Branch 0 taken 430 times.
✓ Branch 1 taken 10 times.
✓ Branch 2 taken 40 times.
✓ Branch 3 taken 390 times.
|
440 | if (rect.height() == 0 || rect.width() == 0) { |
937 |
9/18✓ Branch 0 taken 50 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 50 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 50 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 50 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 50 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 50 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 50 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 50 times.
✗ Branch 17 not taken.
|
50 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
938 | } | ||
939 | |||
940 | 780 | const auto lhs_w = info.k; | |
941 | 780 | const auto rhs_w = info.n; | |
942 | 780 | const auto bias_w = info.n; | |
943 | 780 | const auto dst_w = info.n; | |
944 | |||
945 | 390 | const auto lhs_start_row = rect.start_row(); | |
946 | 390 | const auto lhs_start_col = 0; | |
947 | 780 | const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w); | |
948 | |||
949 | 390 | const std::byte* lhs_data = nullptr; | |
950 | 390 | uintptr_t lhs_offset = 0; | |
951 | |||
952 |
2/2✓ Branch 0 taken 64 times.
✓ Branch 1 taken 326 times.
|
390 | if (method.is_pack_lhs_needed()) { |
953 | 64 | lhs_data = data.ref_packed_lhs.data(); | |
954 | |||
955 | 128 | const auto ref_packed_lhs_offset = | |
956 | 128 | method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k); | |
957 | 128 | lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); | |
958 |
3/14✓ Branch 0 taken 64 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 64 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 64 times.
|
64 | ASSERT_EQ(lhs_offset, ref_packed_lhs_offset); |
959 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 64 times.
|
64 | } else { |
960 | 326 | lhs_data = data.lhs.data(); | |
961 | |||
962 | 326 | lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride); | |
963 | 652 | const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, lhs_w); | |
964 |
3/14✓ Branch 0 taken 326 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 326 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 326 times.
|
326 | ASSERT_EQ(lhs_offset, ref_lhs_offset); |
965 | 326 | } | |
966 | |||
967 | 780 | const auto rhs_stride = method.rhs_format.default_row_stride(rhs_w); | |
968 | |||
969 | 390 | const std::byte* rhs_data = nullptr; | |
970 | 390 | uintptr_t rhs_offset = 0; | |
971 | |||
972 |
1/2✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
|
390 | if (method.is_pack_rhs_needed()) { |
973 | 390 | const auto packed_rhs_start_row = rect.start_col(); | |
974 | 390 | const auto packed_rhs_start_col = 0; | |
975 | |||
976 | 390 | rhs_data = data.ref_packed_rhs.data(); | |
977 | |||
978 | 780 | rhs_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k); | |
979 | 780 | const auto ref_rhs_offset = | |
980 | 780 | method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k); | |
981 |
3/14✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 390 times.
|
390 | ASSERT_EQ(rhs_offset, ref_rhs_offset); |
982 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 390 times.
|
390 | } else { |
983 | ✗ | const auto rhs_start_row = 0; | |
984 | ✗ | const auto rhs_start_col = rect.start_col(); | |
985 | |||
986 | ✗ | rhs_data = data.rhs.data(); | |
987 | ✗ | rhs_offset = method.rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_w); | |
988 | ✗ | } | |
989 | |||
990 | 390 | const auto* bias_data = data.bias.data(); | |
991 | 780 | const auto bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_row(), bias_w); | |
992 | |||
993 | 780 | const auto dst_stride = method.dst_format.default_row_stride(dst_w); | |
994 | 780 | const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); | |
995 | 780 | const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w); | |
996 |
3/14✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 390 times.
|
390 | ASSERT_EQ(dst_offset, ref_dst_offset); |
997 | |||
998 | 1560 | const auto dst_size = method.fn_get_dst_size(info.m, info.n); | |
999 | 1560 | const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n); | |
1000 |
3/14✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 390 times.
|
390 | ASSERT_EQ(dst_size, ref_dst_size); |
1001 | |||
1002 | 390 | Buffer dst(dst_size, 0); | |
1003 | |||
1004 |
1/2✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
|
390 | method.main_kernel( |
1005 |
2/4✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
|
390 | rect.height(), rect.width(), info.k, lhs_data + lhs_offset, rhs_data + rhs_offset, bias_data + bias_offset, |
1006 |
1/2✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
|
390 | dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, data.clamp_min, data.clamp_max); |
1007 | |||
1008 |
1/2✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
|
390 | DefaultMismatchHandler handler(0, 0.1, 0, 0.05); |
1009 |
5/10✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 390 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 390 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 390 times.
✗ Branch 9 not taken.
|
390 | const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler); |
1010 |
4/16✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 390 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 390 times.
|
390 | ASSERT_TRUE(success); |
1011 | 440 | } | |
1012 | |||
1013 |
13/40✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 720 times.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
|
725 | INSTANTIATE_TEST_SUITE_P( |
1014 | MatMul, MatMulTest, | ||
1015 | testing::Combine( | ||
1016 | testing::ValuesIn(get_matmul_methods()), | ||
1017 | testing::Values( | ||
1018 | MatMulShape{1, 16, 16}, // | ||
1019 | MatMulShape{20, 1, 20}, // | ||
1020 | MatMulShape{6, 16, 32}, // | ||
1021 | MatMulShape{12, 32, 17}, // | ||
1022 | MatMulShape{13, 33, 23}, // | ||
1023 | MatMulShape{87, 93, 56} // | ||
1024 | ), | ||
1025 | testing::Values( | ||
1026 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
1027 | MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. | ||
1028 | MatrixPortion(0.75, 0.75, 1, 1) // Bottom-right corner. | ||
1029 | )), | ||
1030 | testing::PrintToStringParamName()); | ||
1031 | |||
1032 |
14/44✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 4 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1040 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
|
1045 | INSTANTIATE_TEST_SUITE_P( |
1033 | VecMul, MatMulTest, | ||
1034 | testing::Combine( | ||
1035 | testing::ValuesIn(get_vecmul_methods()), | ||
1036 | testing::Values( | ||
1037 | MatMulShape{1, 16, 16}, // | ||
1038 | MatMulShape{1, 1, 20}, // | ||
1039 | MatMulShape{1, 16, 32}, // | ||
1040 | MatMulShape{1, 32, 17}, // | ||
1041 | MatMulShape{1, 33, 23}, // | ||
1042 | MatMulShape{1, 1500, 20}, // | ||
1043 | MatMulShape{1, 93, 56}, // | ||
1044 | MatMulShape{1, 1, 1}, // | ||
1045 | MatMulShape{1, 16, 1}, // | ||
1046 | MatMulShape{1, 32, 64}, // | ||
1047 | MatMulShape{1, 7, 74}, // | ||
1048 | MatMulShape{1, 800, 64}, // | ||
1049 | MatMulShape{1, 512, 130} // | ||
1050 | ), | ||
1051 | testing::Values( | ||
1052 | MatrixPortion(0, 0, 1, 1), // Full row. | ||
1053 | MatrixPortion(0, 0, 1, 0.5), // First half | ||
1054 | MatrixPortion(0, .4, 1, 0.3), // mid row-section. | ||
1055 | MatrixPortion(0, 0.75, 1, .25) // right row section | ||
1056 | )), | ||
1057 | testing::PrintToStringParamName()); | ||
1058 | |||
1059 | } // namespace kai::test | ||
1060 |