Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <string_view> | ||
12 | #include <tuple> | ||
13 | #include <unordered_map> | ||
14 | |||
15 | #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" | ||
16 | #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h" | ||
17 | #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h" | ||
18 | #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" | ||
19 | #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" | ||
20 | #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h" | ||
21 | #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" | ||
22 | #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" | ||
23 | #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" | ||
24 | #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" | ||
25 | #include "test/common/buffer.hpp" | ||
26 | #include "test/common/compare.hpp" | ||
27 | #include "test/common/cpu_info.hpp" | ||
28 | #include "test/common/matmul_test_common.hpp" | ||
29 | #include "test/common/matrix_portion.hpp" | ||
30 | #include "test/common/memory.hpp" | ||
31 | #include "test/common/round.hpp" | ||
32 | #include "test/common/sme.hpp" | ||
33 | #include "test/reference/clamp.hpp" | ||
34 | #include "test/reference/fill.hpp" | ||
35 | #include "test/reference/matmul.hpp" | ||
36 | #include "test/reference/reorder.hpp" | ||
37 | |||
38 | namespace kai::test { | ||
39 | |||
40 | // Ensure internal linkage for all functionality local to this test file | ||
41 | namespace { | ||
42 | |||
43 | /// Convenience wrapper for K-chunk handling | ||
44 | struct KChunk { | ||
45 | size_t count; | ||
46 | size_t length; | ||
47 | }; | ||
48 | |||
49 | /// Interface for indirect matmul LHS packing micro-kernel | ||
50 |
2/4✓ Branch 0 taken 134648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 134648 times.
✗ Branch 3 not taken.
|
134648 | struct LhsPackIndirectKernel { |
51 | std::function<size_t()> get_m_step; | ||
52 | std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_offset; | ||
53 | std::function<size_t(size_t m, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_size; | ||
54 | std::function<void( | ||
55 | size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, | ||
56 | const void* zero, void* lhs_packed)> | ||
57 | pack; | ||
58 | }; | ||
59 | |||
60 | /// Interface for indirect matmul RHS packing micro-kernel | ||
61 |
4/8✓ Branch 0 taken 134648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 134648 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 134648 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 134648 times.
✗ Branch 7 not taken.
|
134648 | struct RhsPackIndirectKernel { |
62 | std::function<size_t()> get_n_step; | ||
63 | std::function<size_t(size_t n_idx)> get_rhs_offset; | ||
64 | std::function<size_t(size_t n_idx)> get_bias_offset; | ||
65 | std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_offset; | ||
66 | std::function<size_t(size_t n, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_size; | ||
67 | std::function<void( | ||
68 | size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, | ||
69 | void* rhs_packed)> | ||
70 | pack; | ||
71 | }; | ||
72 | |||
73 | /// Interface for indirect matmul kernel | ||
74 |
8/16✓ Branch 0 taken 134648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 134648 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 134648 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 134648 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 134648 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 134648 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 134648 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 134648 times.
✗ Branch 15 not taken.
|
134648 | struct MatMulIndirectKernel { |
75 | std::function<size_t(void)> get_m_step; | ||
76 | std::function<size_t(void)> get_n_step; | ||
77 | std::function<size_t(void)> get_mr; | ||
78 | std::function<size_t(void)> get_nr; | ||
79 | std::function<size_t(void)> get_kr; | ||
80 | std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_offset; | ||
81 | std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_offset; | ||
82 | std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride_row)> get_dst_offset; | ||
83 | std::function<size_t(size_t m, size_t n)> get_dst_size; | ||
84 | std::function<void( | ||
85 | size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, | ||
86 | void* dst, size_t dst_stride_row, float clamp_min, float clamp_max)> | ||
87 | imatmul; | ||
88 | }; | ||
89 | |||
90 | /// Description of an Indirect Matmul kernel set | ||
91 |
2/4✓ Branch 0 taken 134648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 134648 times.
✗ Branch 3 not taken.
|
134648 | struct IndirectMatMul { |
92 | std::string_view name; | ||
93 | std::function<bool(void)> is_supported; | ||
94 | |||
95 | MatMulShape pack_shape; | ||
96 | struct Format { | ||
97 | DataFormat lhs; | ||
98 | DataFormat rhs; | ||
99 | DataFormat bias; | ||
100 | DataFormat out; | ||
101 | |||
102 | struct Hash { | ||
103 | 29172 | size_t operator()(const Format& format) const { | |
104 | 29172 | return // | |
105 | 58344 | (DataFormat::Hash{}(format.lhs) << 0) ^ // | |
106 | 58344 | (DataFormat::Hash{}(format.rhs) << 1) ^ // | |
107 | 58344 | (DataFormat::Hash{}(format.bias) << 2) ^ // | |
108 | 29172 | (DataFormat::Hash{}(format.out) << 3); | |
109 | } | ||
110 | }; | ||
111 | |||
112 | private: | ||
113 | 24684 | friend bool operator==(const Format& lhs, const Format& rhs) { | |
114 | 24684 | return // | |
115 |
1/2✓ Branch 0 taken 24684 times.
✗ Branch 1 not taken.
|
24684 | lhs.lhs == rhs.lhs && // |
116 |
1/2✓ Branch 0 taken 24684 times.
✗ Branch 1 not taken.
|
24684 | lhs.rhs == rhs.rhs && // |
117 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 24684 times.
|
24684 | lhs.bias == rhs.bias && // |
118 | 24684 | lhs.out == rhs.out; | |
119 | } | ||
120 | } format; | ||
121 | |||
122 | LhsPackIndirectKernel lhs; | ||
123 | RhsPackIndirectKernel rhs; | ||
124 | MatMulIndirectKernel imatmul; | ||
125 | }; | ||
126 | |||
127 | /// Convenience type for test list | ||
128 | using IndirectMatMulArray = std::array<IndirectMatMul, 4>; | ||
129 | |||
130 | /// Test parameter bundle type | ||
131 | using IndirectMatMulTestParams = std::tuple<IndirectMatMul, MatMulShape, size_t, MatrixPortion, float>; | ||
132 | |||
133 | /// Test type | ||
134 | using IndirectMatMulTest = testing::TestWithParam<IndirectMatMulTestParams>; | ||
135 | |||
136 | /// Use interface for matmul kernel | ||
137 | 1 | const kai_imatmul_clamp_f16_f16p_f16p_ukernel& get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() { | |
138 | static kai_imatmul_clamp_f16_f16p_f16p_ukernel ukernel; | ||
139 | 1 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
140 | 1 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
141 | 1 | ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
142 | 1 | ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
143 | 1 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
144 | 1 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
145 | 1 | ukernel.run_imatmul = kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
146 | 1 | return ukernel; | |
147 | } | ||
148 | |||
149 | 1 | const kai_imatmul_clamp_f16_f16p_f16p_ukernel& get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa() { | |
150 | static kai_imatmul_clamp_f16_f16p_f16p_ukernel ukernel; | ||
151 | 1 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
152 | 1 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
153 | 1 | ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
154 | 1 | ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
155 | 1 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
156 | 1 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
157 | 1 | ukernel.run_imatmul = kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
158 | 1 | return ukernel; | |
159 | } | ||
160 | |||
161 | /// Use interface for matmul kernel | ||
162 | 1 | const kai_imatmul_clamp_f32_f32p_f32p_ukernel& get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() { | |
163 | static kai_imatmul_clamp_f32_f32p_f32p_ukernel ukernel; | ||
164 | 1 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
165 | 1 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
166 | 1 | ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
167 | 1 | ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
168 | 1 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
169 | 1 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
170 | 1 | ukernel.run_imatmul = kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
171 | 1 | return ukernel; | |
172 | } | ||
173 | |||
174 | 1 | const kai_imatmul_clamp_f32_f32p_f32p_ukernel& get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa() { | |
175 | static kai_imatmul_clamp_f32_f32p_f32p_ukernel ukernel; | ||
176 | 1 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
177 | 1 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
178 | 1 | ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
179 | 1 | ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
180 | 1 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
181 | 1 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
182 | 1 | ukernel.run_imatmul = kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
183 | 1 | return ukernel; | |
184 | } | ||
185 | |||
186 | /// Retrieve the test list | ||
187 | 1 | const IndirectMatMulArray& get_indirect_matmul_methods() { | |
188 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
1 | static IndirectMatMulArray indirect_matmul_methods{}; |
189 | |||
190 | // F16 IMATMUL SME2 /////////////////////////////////////////////////////// | ||
191 | 1 | indirect_matmul_methods[0].name = "imatmul_f16_f16p_f16p_2vlx2vl_sme2_mopa"; | |
192 | 1 | indirect_matmul_methods[0].is_supported = cpu_has_sme2; | |
193 | 1 | indirect_matmul_methods[0].pack_shape.m = 2 * get_sme_vector_length<int32_t>(); | |
194 | 1 | indirect_matmul_methods[0].pack_shape.n = 2 * get_sme_vector_length<int32_t>(); | |
195 | 1 | indirect_matmul_methods[0].pack_shape.k = sizeof(int32_t); | |
196 | 1 | indirect_matmul_methods[0].format.lhs = DataFormat(DataType::FP16); | |
197 | 1 | indirect_matmul_methods[0].format.rhs = DataFormat(DataType::FP16); | |
198 | 1 | indirect_matmul_methods[0].format.bias = DataFormat(DataType::FP16); | |
199 | 1 | indirect_matmul_methods[0].format.out = DataFormat(DataType::FP16); | |
200 | |||
201 | // LHS | ||
202 | 1 | indirect_matmul_methods[0].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
203 | 1 | indirect_matmul_methods[0].lhs.get_lhs_packed_offset = | |
204 | kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | ||
205 | 1 | indirect_matmul_methods[0].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
206 | 1 | indirect_matmul_methods[0].lhs.pack = kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
207 | |||
208 | // RHS | ||
209 | 1 | indirect_matmul_methods[0].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
210 | 1 | indirect_matmul_methods[0].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
211 | 1 | indirect_matmul_methods[0].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
212 | 1 | indirect_matmul_methods[0].rhs.get_rhs_packed_offset = | |
213 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
214 | 1 | indirect_matmul_methods[0].rhs.get_rhs_packed_size = | |
215 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
216 | 1 | indirect_matmul_methods[0].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
217 | |||
218 | // IMATMUL | ||
219 | 2 | const kai_imatmul_clamp_f16_f16p_f16p_ukernel& ukernel_f16_sme2 = | |
220 | 1 | get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(); | |
221 | 1 | indirect_matmul_methods[0].imatmul.get_m_step = ukernel_f16_sme2.get_m_step; | |
222 | 1 | indirect_matmul_methods[0].imatmul.get_n_step = ukernel_f16_sme2.get_n_step; | |
223 | 1 | indirect_matmul_methods[0].imatmul.get_lhs_packed_offset = ukernel_f16_sme2.get_lhs_packed_offset; | |
224 | 1 | indirect_matmul_methods[0].imatmul.get_rhs_packed_offset = ukernel_f16_sme2.get_rhs_packed_offset; | |
225 | 1 | indirect_matmul_methods[0].imatmul.get_dst_offset = ukernel_f16_sme2.get_dst_offset; | |
226 | 1 | indirect_matmul_methods[0].imatmul.get_dst_size = ukernel_f16_sme2.get_dst_size; | |
227 | 1 | indirect_matmul_methods[0].imatmul.imatmul = ukernel_f16_sme2.run_imatmul; | |
228 | |||
229 | // F32 IMATMUL SME2 /////////////////////////////////////////////////////// | ||
230 | 1 | indirect_matmul_methods[1].name = "imatmul_f32_f32p_f32p_2vlx2vl_sme2_mopa"; | |
231 | 1 | indirect_matmul_methods[1].is_supported = cpu_has_sme2; | |
232 | 1 | indirect_matmul_methods[1].pack_shape.m = 2 * get_sme_vector_length<int32_t>(); | |
233 | 1 | indirect_matmul_methods[1].pack_shape.n = 2 * get_sme_vector_length<int32_t>(); | |
234 | 1 | indirect_matmul_methods[1].pack_shape.k = sizeof(int32_t); | |
235 | 1 | indirect_matmul_methods[1].format.lhs = DataFormat(DataType::FP32); | |
236 | 1 | indirect_matmul_methods[1].format.rhs = DataFormat(DataType::FP32); | |
237 | 1 | indirect_matmul_methods[1].format.bias = DataFormat(DataType::FP32); | |
238 | 1 | indirect_matmul_methods[1].format.out = DataFormat(DataType::FP32); | |
239 | |||
240 | // LHS | ||
241 | 1 | indirect_matmul_methods[1].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
242 | 1 | indirect_matmul_methods[1].lhs.get_lhs_packed_offset = | |
243 | kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | ||
244 | 1 | indirect_matmul_methods[1].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
245 | 1 | indirect_matmul_methods[1].lhs.pack = kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
246 | |||
247 | // RHS | ||
248 | 1 | indirect_matmul_methods[1].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
249 | 1 | indirect_matmul_methods[1].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
250 | 1 | indirect_matmul_methods[1].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
251 | 1 | indirect_matmul_methods[1].rhs.get_rhs_packed_offset = | |
252 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | ||
253 | 1 | indirect_matmul_methods[1].rhs.get_rhs_packed_size = | |
254 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | ||
255 | 1 | indirect_matmul_methods[1].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
256 | |||
257 | // IMATMUL | ||
258 | 2 | const kai_imatmul_clamp_f32_f32p_f32p_ukernel& ukernel_f32_sme2 = | |
259 | 1 | get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); | |
260 | 1 | indirect_matmul_methods[1].imatmul.get_m_step = ukernel_f32_sme2.get_m_step; | |
261 | 1 | indirect_matmul_methods[1].imatmul.get_n_step = ukernel_f32_sme2.get_n_step; | |
262 | 1 | indirect_matmul_methods[1].imatmul.get_lhs_packed_offset = ukernel_f32_sme2.get_lhs_packed_offset; | |
263 | 1 | indirect_matmul_methods[1].imatmul.get_rhs_packed_offset = ukernel_f32_sme2.get_rhs_packed_offset; | |
264 | 1 | indirect_matmul_methods[1].imatmul.get_dst_offset = ukernel_f32_sme2.get_dst_offset; | |
265 | 1 | indirect_matmul_methods[1].imatmul.get_dst_size = ukernel_f32_sme2.get_dst_size; | |
266 | 1 | indirect_matmul_methods[1].imatmul.imatmul = ukernel_f32_sme2.run_imatmul; | |
267 | |||
268 | // F16 IMATMUL SME //////////////////////////////////////////////////////// | ||
269 | 1 | indirect_matmul_methods[2].name = "imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa"; | |
270 | 1 | indirect_matmul_methods[2].is_supported = cpu_has_sme; | |
271 | 1 | indirect_matmul_methods[2].pack_shape.m = 2 * get_sme_vector_length<int32_t>(); | |
272 | 1 | indirect_matmul_methods[2].pack_shape.n = 2 * get_sme_vector_length<int32_t>(); | |
273 | 1 | indirect_matmul_methods[2].pack_shape.k = sizeof(int32_t); | |
274 | 1 | indirect_matmul_methods[2].format.lhs = DataFormat(DataType::FP16); | |
275 | 1 | indirect_matmul_methods[2].format.rhs = DataFormat(DataType::FP16); | |
276 | 1 | indirect_matmul_methods[2].format.bias = DataFormat(DataType::FP16); | |
277 | 1 | indirect_matmul_methods[2].format.out = DataFormat(DataType::FP16); | |
278 | |||
279 | // LHS | ||
280 | 1 | indirect_matmul_methods[2].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
281 | 1 | indirect_matmul_methods[2].lhs.get_lhs_packed_offset = | |
282 | kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | ||
283 | 1 | indirect_matmul_methods[2].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
284 | 1 | indirect_matmul_methods[2].lhs.pack = kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
285 | |||
286 | // RHS | ||
287 | 1 | indirect_matmul_methods[2].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
288 | 1 | indirect_matmul_methods[2].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
289 | 1 | indirect_matmul_methods[2].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
290 | 1 | indirect_matmul_methods[2].rhs.get_rhs_packed_offset = | |
291 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
292 | 1 | indirect_matmul_methods[2].rhs.get_rhs_packed_size = | |
293 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
294 | 1 | indirect_matmul_methods[2].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
295 | |||
296 | // IMATMUL | ||
297 | 2 | const kai_imatmul_clamp_f16_f16p_f16p_ukernel& ukernel_f16_sme = | |
298 | 1 | get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(); | |
299 | 1 | indirect_matmul_methods[2].imatmul.get_m_step = ukernel_f16_sme.get_m_step; | |
300 | 1 | indirect_matmul_methods[2].imatmul.get_n_step = ukernel_f16_sme.get_n_step; | |
301 | 1 | indirect_matmul_methods[2].imatmul.get_lhs_packed_offset = ukernel_f16_sme.get_lhs_packed_offset; | |
302 | 1 | indirect_matmul_methods[2].imatmul.get_rhs_packed_offset = ukernel_f16_sme.get_rhs_packed_offset; | |
303 | 1 | indirect_matmul_methods[2].imatmul.get_dst_offset = ukernel_f16_sme.get_dst_offset; | |
304 | 1 | indirect_matmul_methods[2].imatmul.get_dst_size = ukernel_f16_sme.get_dst_size; | |
305 | 1 | indirect_matmul_methods[2].imatmul.imatmul = ukernel_f16_sme.run_imatmul; | |
306 | |||
307 | // F32 IMATMUL SME //////////////////////////////////////////////////////// | ||
308 | 1 | indirect_matmul_methods[3].name = "imatmul_f32_f32p_f32p_2vlx2vl_sme_mopa"; | |
309 | 1 | indirect_matmul_methods[3].is_supported = cpu_has_sme; | |
310 | 1 | indirect_matmul_methods[3].pack_shape.m = 2 * get_sme_vector_length<int32_t>(); | |
311 | 1 | indirect_matmul_methods[3].pack_shape.n = 2 * get_sme_vector_length<int32_t>(); | |
312 | 1 | indirect_matmul_methods[3].pack_shape.k = sizeof(int32_t); | |
313 | 1 | indirect_matmul_methods[3].format.lhs = DataFormat(DataType::FP32); | |
314 | 1 | indirect_matmul_methods[3].format.rhs = DataFormat(DataType::FP32); | |
315 | 1 | indirect_matmul_methods[3].format.bias = DataFormat(DataType::FP32); | |
316 | 1 | indirect_matmul_methods[3].format.out = DataFormat(DataType::FP32); | |
317 | |||
318 | // LHS | ||
319 | 1 | indirect_matmul_methods[3].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
320 | 1 | indirect_matmul_methods[3].lhs.get_lhs_packed_offset = | |
321 | kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | ||
322 | 1 | indirect_matmul_methods[3].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
323 | 1 | indirect_matmul_methods[3].lhs.pack = kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
324 | |||
325 | // RHS | ||
326 | 1 | indirect_matmul_methods[3].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
327 | 1 | indirect_matmul_methods[3].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
328 | 1 | indirect_matmul_methods[3].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
329 | 1 | indirect_matmul_methods[3].rhs.get_rhs_packed_offset = | |
330 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | ||
331 | 1 | indirect_matmul_methods[3].rhs.get_rhs_packed_size = | |
332 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | ||
333 | 1 | indirect_matmul_methods[3].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
334 | |||
335 | // IMATMUL | ||
336 | 2 | const kai_imatmul_clamp_f32_f32p_f32p_ukernel& ukernel_f32_sme = | |
337 | 1 | get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); | |
338 | 1 | indirect_matmul_methods[3].imatmul.get_m_step = ukernel_f32_sme.get_m_step; | |
339 | 1 | indirect_matmul_methods[3].imatmul.get_n_step = ukernel_f32_sme.get_n_step; | |
340 | 1 | indirect_matmul_methods[3].imatmul.get_lhs_packed_offset = ukernel_f32_sme.get_lhs_packed_offset; | |
341 | 1 | indirect_matmul_methods[3].imatmul.get_rhs_packed_offset = ukernel_f32_sme.get_rhs_packed_offset; | |
342 | 1 | indirect_matmul_methods[3].imatmul.get_dst_offset = ukernel_f32_sme.get_dst_offset; | |
343 | 1 | indirect_matmul_methods[3].imatmul.get_dst_size = ukernel_f32_sme.get_dst_size; | |
344 | 1 | indirect_matmul_methods[3].imatmul.imatmul = ukernel_f32_sme.run_imatmul; | |
345 | |||
346 | 1 | return indirect_matmul_methods; | |
347 | 1 | } | |
348 | |||
349 | /// Test reference identification | ||
350 | struct TestDataId { | ||
351 | MatMulShape shape; | ||
352 | MatMulShape pack_shape; | ||
353 | IndirectMatMul::Format format; | ||
354 | size_t k_chunk_length; | ||
355 | float clamp_rate; | ||
356 | |||
357 | struct Hash { | ||
358 | 29172 | size_t operator()(const TestDataId& test_id) const { | |
359 | 29172 | return // | |
360 | 58344 | (MatMulShape::Hash{}(test_id.shape) << 0) ^ // | |
361 | 58344 | (MatMulShape::Hash{}(test_id.pack_shape) << 1) ^ // | |
362 | 58344 | (IndirectMatMul::Format::Hash{}(test_id.format) << 2) ^ // | |
363 | 58344 | (std::hash<size_t>{}(test_id.k_chunk_length) << 3) ^ // | |
364 | 29172 | (std::hash<float>{}(test_id.clamp_rate) << 4); // | |
365 | } | ||
366 | }; | ||
367 | |||
368 | private: | ||
369 | 31212 | friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { | |
370 | 31212 | return // | |
371 |
2/2✓ Branch 0 taken 24684 times.
✓ Branch 1 taken 6528 times.
|
31212 | lhs.shape == rhs.shape && // |
372 |
1/2✓ Branch 0 taken 24684 times.
✗ Branch 1 not taken.
|
24684 | lhs.pack_shape == rhs.pack_shape && // |
373 |
1/2✓ Branch 0 taken 24684 times.
✗ Branch 1 not taken.
|
24684 | lhs.format == rhs.format && // |
374 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 24684 times.
|
24684 | lhs.k_chunk_length == rhs.k_chunk_length && // |
375 | 24684 | lhs.clamp_rate == rhs.clamp_rate; | |
376 | } | ||
377 | }; | ||
378 | |||
379 | /// Test reference data | ||
380 | struct TestData { | ||
381 | Buffer lhs; ///< LHS input matrix | ||
382 | Buffer rhs; ///< RHS input matrix | ||
383 | Buffer bias; ///< Bias vector | ||
384 | Buffer out; ///< Reference imatmul result | ||
385 | Buffer indirection; ///< LHS indirection buffer | ||
386 | uintptr_t indirection_offset; ///< LHS indirection buffer offset | ||
387 | Buffer padding; ///< Padding buffer | ||
388 | Range<float> clamp_range; ///< Clamp range | ||
389 | }; | ||
390 | |||
391 | /// Reference data generator | ||
392 | /// | ||
393 | /// Uses test id to generate reference data, and caches it. | ||
394 | struct ReferenceGenerator { | ||
395 | /// Retrieve reference data for the provided test identification | ||
396 | 26928 | static const TestData& get_test_reference(const TestDataId& test_id) { | |
397 |
3/4✓ Branch 0 taken 1 time.
✓ Branch 1 taken 26927 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
26928 | static std::unordered_map<TestDataId, TestData, TestDataId::Hash> m_data; |
398 |
4/5✓ Branch 0 taken 24684 times.
✓ Branch 1 taken 2244 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 24684 times.
✓ Branch 4 taken 2244 times.
|
51612 | if (const auto itr = m_data.find(test_id); itr != end(m_data)) { |
399 | 24684 | return itr->second; | |
400 | } | ||
401 | |||
402 |
1/2✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
|
2244 | return m_data[test_id] = generate_reference(test_id); |
403 | 26928 | } | |
404 | |||
405 | private: | ||
406 | /// Return incremented seed value | ||
407 | 6732 | static size_t get_seed() { | |
408 | static size_t seed = 0; | ||
409 | 6732 | return seed++; | |
410 | } | ||
411 | |||
412 | /// Generate reference data. Not intended to be called | ||
413 | /// directly, as this would bypass the caching mechanism. | ||
414 | 2244 | static TestData generate_reference(const TestDataId& test_id) { | |
415 | 1097496 | const auto& [chunked_shape, pack_shape, format, k_chunk_length, clamp_rate] = test_id; | |
416 | |||
417 | // The LHS matrix will be split into several chunks in the K dimension | ||
418 | 4488 | const size_t k_chunk_count = chunked_shape.k; | |
419 | 8976 | MatMulShape shape = {chunked_shape.m, chunked_shape.n, k_chunk_count * k_chunk_length}; | |
420 | |||
421 | // Generate random input data | ||
422 | 4488 | Buffer lhs = fill_matrix_random(shape.m, shape.k, format.lhs, get_seed()); | |
423 |
3/6✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
|
4488 | Buffer rhs = fill_matrix_random(shape.k, shape.n, format.rhs, get_seed()); |
424 |
3/6✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
|
4488 | Buffer bias = fill_matrix_random(1, shape.n, format.bias, get_seed()); |
425 | |||
426 | // Data types used | ||
427 |
2/4✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
|
4488 | const DataType lhs_dt = format.lhs.data_type(); |
428 |
2/4✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
|
4488 | const DataType rhs_dt = format.rhs.data_type(); |
429 |
2/4✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
|
4488 | const DataType out_dt = format.out.data_type(); |
430 |
2/4✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
|
4488 | const DataType bias_dt = format.bias.data_type(); |
431 | |||
432 | // Create a padding chunk | ||
433 |
3/6✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
|
4488 | const size_t k_chunk_size = round_up_division(k_chunk_length * data_type_size_in_bits(lhs_dt), 8); |
434 | 2244 | const size_t row_size = k_chunk_count * k_chunk_size; | |
435 |
1/2✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
|
2244 | Buffer lhs_padding(k_chunk_size); |
436 |
4/4✓ Branch 0 taken 48756 times.
✓ Branch 1 taken 2244 times.
✓ Branch 2 taken 48756 times.
✓ Branch 3 taken 2244 times.
|
51000 | for (size_t i = 0; i < k_chunk_length; i += 1) { |
437 | static constexpr double padding_value = 0; | ||
438 |
2/4✓ Branch 0 taken 48756 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 48756 times.
✗ Branch 3 not taken.
|
48756 | write_array(lhs_dt, lhs_padding.data(), i, padding_value); |
439 | 48756 | } | |
440 | |||
441 | // Set up indirection buffer | ||
442 |
1/2✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
|
2244 | const uintptr_t indirection_offset = reinterpret_cast<uintptr_t>(lhs.data()); |
443 |
3/6✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
|
6732 | std::vector<const void*> indirection(chunked_shape.m * chunked_shape.k); |
444 |
4/4✓ Branch 0 taken 71478 times.
✓ Branch 1 taken 2244 times.
✓ Branch 2 taken 71478 times.
✓ Branch 3 taken 2244 times.
|
73722 | for (size_t i_m = 0; i_m < chunked_shape.m; i_m += 1) { |
445 |
4/4✓ Branch 0 taken 71478 times.
✓ Branch 1 taken 858660 times.
✓ Branch 2 taken 71478 times.
✓ Branch 3 taken 858660 times.
|
930138 | for (size_t i_k = 0; i_k < chunked_shape.k; i_k += 1) { |
446 | 1717320 | const size_t idx = i_m * chunked_shape.k + i_k; | |
447 | // Test padding pointers using first LHS row for shapes where M > 1 | ||
448 |
4/4✓ Branch 0 taken 839256 times.
✓ Branch 1 taken 19404 times.
✓ Branch 2 taken 815100 times.
✓ Branch 3 taken 24156 times.
|
858660 | if (chunked_shape.m > 1 && i_m == 0) { |
449 |
2/4✓ Branch 0 taken 24156 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 24156 times.
✗ Branch 3 not taken.
|
24156 | indirection.at(idx) = lhs_padding.data(); |
450 | 24156 | } else { | |
451 | 834504 | uintptr_t offset = i_m * row_size + i_k * k_chunk_size; | |
452 |
1/2✓ Branch 0 taken 834504 times.
✗ Branch 1 not taken.
|
834504 | indirection.at(idx) = reinterpret_cast<const void*>(offset); |
453 | 834504 | } | |
454 | 858660 | } | |
455 | 71478 | } | |
456 | |||
457 | // Pack indirection buffer | ||
458 |
2/4✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
|
4488 | Buffer indirection_packed = reorder_block<const void*>( |
459 | 8976 | reinterpret_cast<const void* const*>(indirection.data()), chunked_shape.m, chunked_shape.k, pack_shape.m, | |
460 | 1); | ||
461 | |||
462 |
1/2✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
|
4488 | Buffer out = indirect_matmul( // |
463 |
1/2✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
|
2244 | indirection.data(), indirection_offset, lhs_padding.data(), nullptr, nullptr, lhs_dt, // LHS |
464 |
1/2✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
|
2244 | rhs.data(), nullptr, nullptr, rhs_dt, // RHS |
465 |
1/2✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
|
2244 | bias.data(), nullptr, nullptr, bias_dt, // Bias |
466 | 2244 | out_dt, // Out | |
467 | 8976 | chunked_shape.m, chunked_shape.n, chunked_shape.k, k_chunk_length); | |
468 | |||
469 | // Calculate clamping range based on full range of values, and then clamp values | ||
470 |
3/6✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
|
2244 | const auto [min, max] = find_clamp_range(out_dt, out.data(), shape.m * shape.n, 1.0F - clamp_rate); |
471 |
4/8✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2244 times.
✗ Branch 7 not taken.
|
2244 | Buffer out_clamped = clamp(out_dt, out.data(), shape.m * shape.n, min, max); |
472 | |||
473 | // Populate reference data | ||
474 | 2244 | TestData test_reference; | |
475 | 2244 | test_reference.lhs = std::move(lhs); | |
476 | 2244 | test_reference.rhs = std::move(rhs); | |
477 | 2244 | test_reference.bias = std::move(bias); | |
478 | 2244 | test_reference.padding = std::move(lhs_padding); | |
479 | 2244 | test_reference.out = std::move(out_clamped); | |
480 | 2244 | test_reference.indirection_offset = indirection_offset; | |
481 | 2244 | test_reference.indirection = std::move(indirection_packed); | |
482 | 6732 | test_reference.clamp_range = {min, max}; | |
483 | |||
484 | 2244 | return test_reference; | |
485 |
1/2✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
|
2244 | }; |
486 | }; | ||
487 | |||
488 | /// Perform LHS packing for indirect matmul | ||
489 | 22704 | Buffer pack_lhs( | |
490 | const LhsPackIndirectKernel& kernel, const Rect& portion, const TestData& reference, size_t m, | ||
491 | const KChunk& k_chunk) { | ||
492 | 22704 | const void* const* indirection_pointer = reinterpret_cast<const void* const*>(reference.indirection.data()); | |
493 | |||
494 | // Calculate size, and allocate buffer | ||
495 | 22704 | const size_t dst_size = kernel.get_lhs_packed_size(m, k_chunk.count, k_chunk.length); | |
496 | 22704 | Buffer dst(dst_size); | |
497 | |||
498 | // Calculate portion offsets | ||
499 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | const size_t input_offset = portion.start_row() * k_chunk.count; |
500 |
2/4✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
|
22704 | const size_t dst_offset = kernel.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); |
501 | |||
502 | // Perform packing | ||
503 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
45408 | kernel.pack( |
504 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | portion.height(), k_chunk.count, k_chunk.length, // Dimensions |
505 | 22704 | indirection_pointer + input_offset, // Indirection input | |
506 | 22704 | reference.indirection_offset, // Chunk offset | |
507 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | reference.padding.data(), // Padding pointer |
508 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | dst.data() + dst_offset); |
509 | 22704 | return dst; | |
510 | 22704 | } | |
511 | |||
512 | /// Perform RHS packign for indirect matmul | ||
513 | 22704 | Buffer pack_rhs( | |
514 | const RhsPackIndirectKernel& kernel, const Rect& portion, const TestData& reference, size_t n, | ||
515 | const KChunk& k_chunk, DataType type) { | ||
516 | // Calculate size, and allocate buffer | ||
517 | 22704 | const size_t row_stride = round_up_division(n * data_type_size_in_bits(type), 8); | |
518 | 22704 | const size_t dst_size = kernel.get_rhs_packed_size(n, k_chunk.count, k_chunk.length); | |
519 | 22704 | Buffer dst(dst_size); | |
520 | |||
521 | // Calculate offsets | ||
522 |
2/4✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
|
22704 | const size_t rhs_offset = kernel.get_rhs_offset(portion.start_col()); |
523 |
2/4✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
|
22704 | const size_t bias_offset = kernel.get_bias_offset(portion.start_col()); |
524 |
2/4✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
|
22704 | const size_t dst_offset = kernel.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); |
525 | |||
526 | // Perform actual packing | ||
527 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
45408 | kernel.pack( |
528 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | portion.width(), k_chunk.count, k_chunk.length, row_stride, // Dimensions |
529 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | reference.rhs.data() + rhs_offset, // RHS input |
530 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | reference.bias.data() + bias_offset, // Bias |
531 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | dst.data() + dst_offset); // Output |
532 | 22704 | return dst; | |
533 | 22704 | } | |
534 | |||
535 | /// Perform imatmul | ||
536 | /// | ||
537 | /// Note, this should not be aware of reference result, as to make it clear that | ||
538 | /// any produced result is strictly from the code under test | ||
539 | 22704 | Buffer imatmul( | |
540 | const MatMulIndirectKernel& kernel, const Rect& portion, const MatMulShape& shape, const KChunk& k_chunk, | ||
541 | const Buffer& lhs_packed, const Buffer& rhs_packed, Range<float> clamp_range, DataType type) { | ||
542 | // Calculate size, and allocate buffer | ||
543 | 22704 | const size_t dst_size = kernel.get_dst_size(shape.m, shape.n); | |
544 | 22704 | const size_t row_stride = round_up_division(shape.n * data_type_size_in_bits(type), 8); | |
545 | 22704 | Buffer dst(dst_size); | |
546 | |||
547 | // Calculate portion offsets | ||
548 |
2/4✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
|
22704 | const size_t lhs_offset = kernel.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); |
549 |
2/4✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
|
22704 | const size_t rhs_offset = kernel.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); |
550 |
3/6✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 22704 times.
✗ Branch 5 not taken.
|
22704 | const size_t dst_offset = kernel.get_dst_offset(portion.start_row(), portion.start_col(), row_stride); |
551 | |||
552 | // Call matmul kernel | ||
553 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
45408 | kernel.imatmul( |
554 |
2/4✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
|
22704 | portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions |
555 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | lhs_packed.data() + lhs_offset, // LHS |
556 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | rhs_packed.data() + rhs_offset, // RHS |
557 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | dst.data() + dst_offset, // DST |
558 | 22704 | row_stride, clamp_range.min, clamp_range.max); | |
559 | |||
560 | 22704 | return dst; | |
561 | 22704 | } | |
562 | |||
563 | } // namespace | ||
564 | |||
565 | /// End-to-end test for indirection matmul kernels | ||
566 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
80786 | TEST_P(IndirectMatMulTest, Output) { |
567 | 231264 | const auto& [method, shape, k_chunk_length, output_portion, clamp_rate] = GetParam(); | |
568 |
1/2✓ Branch 0 taken 26928 times.
✗ Branch 1 not taken.
|
26928 | if (not method.is_supported()) { |
569 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
570 | } | ||
571 | |||
572 | 80784 | const KChunk k_chunk{shape.k, k_chunk_length}; | |
573 | |||
574 | // Retrieve reference data | ||
575 | 161568 | const TestDataId test_id{shape, method.pack_shape, method.format, k_chunk_length, clamp_rate}; | |
576 | 26928 | const TestData& test_data = ReferenceGenerator::get_test_reference(test_id); | |
577 | 134640 | const Rect portion = output_portion.compute_portion(shape.m, shape.n, method.pack_shape.m, method.pack_shape.n); | |
578 | |||
579 |
4/4✓ Branch 0 taken 23364 times.
✓ Branch 1 taken 3564 times.
✓ Branch 2 taken 660 times.
✓ Branch 3 taken 22704 times.
|
26928 | if (portion.height() == 0 || portion.width() == 0) { |
580 |
9/18✓ Branch 0 taken 4224 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4224 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4224 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4224 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4224 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4224 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4224 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4224 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4224 times.
✗ Branch 17 not taken.
|
4224 | GTEST_SKIP() << "Empty dimension of matrix(" << portion.width() << "," << portion.height() << ")"; |
581 | } | ||
582 | |||
583 | // Call packing micro-kernels, and then imatmul kernel | ||
584 | 68112 | Buffer lhs_packed = pack_lhs(method.lhs, portion, test_data, shape.m, k_chunk); | |
585 |
5/10✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 22704 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 22704 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 22704 times.
✗ Branch 9 not taken.
|
90816 | Buffer rhs_packed = pack_rhs(method.rhs, portion, test_data, shape.n, k_chunk, method.format.rhs.data_type()); |
586 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
45408 | Buffer out = imatmul( |
587 | 45408 | method.imatmul, portion, shape, k_chunk, lhs_packed, rhs_packed, test_data.clamp_range, | |
588 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | method.format.out.data_type()); |
589 | |||
590 | // Compare the actual result with the reference result | ||
591 |
1/2✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
|
22704 | DefaultMismatchHandler handler(0, 0.1, 0, 0.05); |
592 | 45408 | const auto success = | |
593 |
7/14✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 22704 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 22704 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 22704 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 22704 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 22704 times.
✗ Branch 13 not taken.
|
22704 | compare(out.data(), test_data.out.data(), method.format.out.data_type(), shape.m, shape.n, portion, handler); |
594 |
4/16✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 22704 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 22704 times.
|
22704 | ASSERT_TRUE(success); |
595 | 26928 | } | |
596 | |||
597 | /// Name generator for test case | ||
598 | 53856 | [[maybe_unused]] static void PrintTo(const IndirectMatMulTestParams& param, std::ostream* os) { | |
599 | 323136 | const auto& [method, shape, k_chunk_length, portion, clamp_rate] = param; | |
600 | 107712 | *os << method.name << "__"; | |
601 | 53856 | PrintTo(shape, os); | |
602 | 107712 | *os << "__K_chunk_length_" << k_chunk_length; | |
603 | 107712 | *os << "__clamp_rate_" << static_cast<int>(clamp_rate * 100) << "__"; | |
604 | 53856 | PrintTo(portion, os); | |
605 | 53856 | } | |
606 | |||
607 | /// Test parameter listing | ||
608 |
18/60✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 time.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 time.
✗ Branch 33 not taken.
✓ Branch 34 taken 26928 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
|
26930 | INSTANTIATE_TEST_SUITE_P( |
609 | IndirectMatMul, IndirectMatMulTest, | ||
610 | testing::Combine( | ||
611 | testing::ValuesIn(get_indirect_matmul_methods()), // | ||
612 | testing::ValuesIn({ | ||
613 | // clang-format off | ||
614 | MatMulShape{ 1, 1, 1}, // | ||
615 | MatMulShape{ 1, 17, 4}, // | ||
616 | MatMulShape{ 1, 19, 24}, // | ||
617 | MatMulShape{ 1, 32, 4}, // | ||
618 | MatMulShape{ 1, 32, 32}, // | ||
619 | MatMulShape{ 1, 33, 200}, // | ||
620 | MatMulShape{ 1, 49, 21}, // | ||
621 | MatMulShape{ 1, 64, 4}, // | ||
622 | MatMulShape{ 1, 65, 4}, // | ||
623 | MatMulShape{ 3, 6, 6}, // | ||
624 | MatMulShape{ 3, 28, 25}, // | ||
625 | MatMulShape{ 4, 16, 4}, // | ||
626 | MatMulShape{ 4, 16, 27}, // | ||
627 | MatMulShape{ 6, 18, 31}, // | ||
628 | MatMulShape{ 6, 28, 1}, // | ||
629 | MatMulShape{ 6, 29, 24}, // | ||
630 | MatMulShape{ 8, 16, 16}, // | ||
631 | MatMulShape{ 16, 16, 4}, // | ||
632 | MatMulShape{ 16, 16, 16}, // | ||
633 | MatMulShape{ 20, 30, 40}, // | ||
634 | MatMulShape{ 23, 1, 43}, // | ||
635 | MatMulShape{ 32, 14, 1}, // | ||
636 | MatMulShape{ 32, 16, 27}, // | ||
637 | MatMulShape{ 32, 32, 3}, // | ||
638 | MatMulShape{ 32, 32, 4}, // | ||
639 | MatMulShape{ 33, 29, 24}, // | ||
640 | MatMulShape{ 64, 64, 3}, // | ||
641 | MatMulShape{ 64, 64, 4}, // | ||
642 | MatMulShape{ 96, 96, 3}, // | ||
643 | MatMulShape{ 96, 97, 3}, // | ||
644 | MatMulShape{ 97, 96, 3}, // | ||
645 | MatMulShape{123, 85, 45}, // | ||
646 | MatMulShape{128, 128, 3}, // | ||
647 | MatMulShape{130, 130, 6}, // | ||
648 | // clang-format on | ||
649 | }), | ||
650 | testing::ValuesIn(std::initializer_list<size_t>{1, 2, 3, 4, 8, 11, 16, 32, 33, 64, 65}), // | ||
651 | testing::ValuesIn({ | ||
652 | // clang-format off | ||
653 | // (Start row , start col , height , width) | ||
654 | MatrixPortion( 0 , 0 , 1 , 1 ), // Full matrix. | ||
655 | MatrixPortion( 0 , 0 , 1 , 0.5 ), // Left half | ||
656 | MatrixPortion( 0 , 0 , 0.5 , 1 ), // Upper half | ||
657 | MatrixPortion( 0 , 0.5 , 1 , 0.5 ), // Right half | ||
658 | MatrixPortion( 0.5 , 0 , 0.5 , 1 ), // Bottom half | ||
659 | MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3 ), // Center ninth | ||
660 | // clang-format on | ||
661 | }), | ||
662 | testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})), // | ||
663 | testing::PrintToStringParamName()); | ||
664 | |||
665 | } // namespace kai::test | ||
666 |