KleidiAI Coverage Report


Directory: ./
File: test/tests/imatmul_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 99.7% 312 0 313
Functions: 100.0% 33 0 33
Branches: 48.5% 180 0 371

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <string_view>
12 #include <tuple>
13 #include <unordered_map>
14
15 #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h"
16 #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h"
17 #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h"
18 #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
19 #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
20 #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h"
21 #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h"
22 #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h"
23 #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
24 #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h"
25 #include "test/common/buffer.hpp"
26 #include "test/common/compare.hpp"
27 #include "test/common/cpu_info.hpp"
28 #include "test/common/matmul_test_common.hpp"
29 #include "test/common/matrix_portion.hpp"
30 #include "test/common/memory.hpp"
31 #include "test/common/round.hpp"
32 #include "test/common/sme.hpp"
33 #include "test/reference/clamp.hpp"
34 #include "test/reference/fill.hpp"
35 #include "test/reference/matmul.hpp"
36 #include "test/reference/reorder.hpp"
37
38 namespace kai::test {
39
40 // Ensure static linkage for all functionality local to this test file
41 namespace {
42
43 /// Convenience wrapper for K-chunk handling
44 struct KChunk {
45 size_t count;
46 size_t length;
47 };
48
49 /// Interface for indirect matmul LHS packing micro-kernel
50
2/4
✓ Branch 0 taken 134648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 134648 times.
✗ Branch 3 not taken.
134648 struct LhsPackIndirectKernel {
51 std::function<size_t()> get_m_step;
52 std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_offset;
53 std::function<size_t(size_t m, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_size;
54 std::function<void(
55 size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset,
56 const void* zero, void* lhs_packed)>
57 pack;
58 };
59
60 /// Interface for indirect matmul RHS packing micro-kernel
61
4/8
✓ Branch 0 taken 134648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 134648 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 134648 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 134648 times.
✗ Branch 7 not taken.
134648 struct RhsPackIndirectKernel {
62 std::function<size_t()> get_n_step;
63 std::function<size_t(size_t n_idx)> get_rhs_offset;
64 std::function<size_t(size_t n_idx)> get_bias_offset;
65 std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_offset;
66 std::function<size_t(size_t n, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_size;
67 std::function<void(
68 size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias,
69 void* rhs_packed)>
70 pack;
71 };
72
73 /// Interface for indirect matmul kernel
74
8/16
✓ Branch 0 taken 134648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 134648 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 134648 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 134648 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 134648 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 134648 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 134648 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 134648 times.
✗ Branch 15 not taken.
134648 struct MatMulIndirectKernel {
75 std::function<size_t(void)> get_m_step;
76 std::function<size_t(void)> get_n_step;
77 std::function<size_t(void)> get_mr;
78 std::function<size_t(void)> get_nr;
79 std::function<size_t(void)> get_kr;
80 std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_offset;
81 std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_offset;
82 std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride_row)> get_dst_offset;
83 std::function<size_t(size_t m, size_t n)> get_dst_size;
84 std::function<void(
85 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed,
86 void* dst, size_t dst_stride_row, float clamp_min, float clamp_max)>
87 imatmul;
88 };
89
90 /// Description of a Indirect Matmul kernel set
91
2/4
✓ Branch 0 taken 134648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 134648 times.
✗ Branch 3 not taken.
134648 struct IndirectMatMul {
92 std::string_view name;
93 std::function<bool(void)> is_supported;
94
95 MatMulShape pack_shape;
96 struct Format {
97 DataFormat lhs;
98 DataFormat rhs;
99 DataFormat bias;
100 DataFormat out;
101
102 struct Hash {
103 29172 size_t operator()(const Format& format) const {
104 29172 return //
105 58344 (DataFormat::Hash{}(format.lhs) << 0) ^ //
106 58344 (DataFormat::Hash{}(format.rhs) << 1) ^ //
107 58344 (DataFormat::Hash{}(format.bias) << 2) ^ //
108 29172 (DataFormat::Hash{}(format.out) << 3);
109 }
110 };
111
112 private:
113 24684 friend bool operator==(const Format& lhs, const Format& rhs) {
114 24684 return //
115
1/2
✓ Branch 0 taken 24684 times.
✗ Branch 1 not taken.
24684 lhs.lhs == rhs.lhs && //
116
1/2
✓ Branch 0 taken 24684 times.
✗ Branch 1 not taken.
24684 lhs.rhs == rhs.rhs && //
117
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 24684 times.
24684 lhs.bias == rhs.bias && //
118 24684 lhs.out == rhs.out;
119 }
120 } format;
121
122 LhsPackIndirectKernel lhs;
123 RhsPackIndirectKernel rhs;
124 MatMulIndirectKernel imatmul;
125 };
126
127 /// Convenience type for test list
128 using IndirectMatMulArray = std::array<IndirectMatMul, 4>;
129
130 /// Test parameter bundle type
131 using IndirectMatMulTestParams = std::tuple<IndirectMatMul, MatMulShape, size_t, MatrixPortion, float>;
132
133 /// Test type
134 using IndirectMatMulTest = testing::TestWithParam<IndirectMatMulTestParams>;
135
136 /// Use interface for matmul kernel
137 1 const kai_imatmul_clamp_f16_f16p_f16p_ukernel& get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() {
138 static kai_imatmul_clamp_f16_f16p_f16p_ukernel ukernel;
139 1 ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
140 1 ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
141 1 ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
142 1 ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
143 1 ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
144 1 ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
145 1 ukernel.run_imatmul = kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa;
146 1 return ukernel;
147 }
148
149 1 const kai_imatmul_clamp_f16_f16p_f16p_ukernel& get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa() {
150 static kai_imatmul_clamp_f16_f16p_f16p_ukernel ukernel;
151 1 ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
152 1 ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
153 1 ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
154 1 ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
155 1 ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
156 1 ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
157 1 ukernel.run_imatmul = kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa;
158 1 return ukernel;
159 }
160
161 /// Use interface for matmul kernel
162 1 const kai_imatmul_clamp_f32_f32p_f32p_ukernel& get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() {
163 static kai_imatmul_clamp_f32_f32p_f32p_ukernel ukernel;
164 1 ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa;
165 1 ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa;
166 1 ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa;
167 1 ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa;
168 1 ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa;
169 1 ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa;
170 1 ukernel.run_imatmul = kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa;
171 1 return ukernel;
172 }
173
174 1 const kai_imatmul_clamp_f32_f32p_f32p_ukernel& get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa() {
175 static kai_imatmul_clamp_f32_f32p_f32p_ukernel ukernel;
176 1 ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
177 1 ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
178 1 ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
179 1 ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
180 1 ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
181 1 ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
182 1 ukernel.run_imatmul = kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa;
183 1 return ukernel;
184 }
185
186 /// Retreive the test list
187 1 const IndirectMatMulArray& get_indirect_matmul_methods() {
188
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
1 static IndirectMatMulArray indirect_matmul_methods{};
189
190 // F16 IMATMUL SME2 ///////////////////////////////////////////////////////
191 1 indirect_matmul_methods[0].name = "imatmul_f16_f16p_f16p_2vlx2vl_sme2_mopa";
192 1 indirect_matmul_methods[0].is_supported = cpu_has_sme2;
193 1 indirect_matmul_methods[0].pack_shape.m = 2 * get_sme_vector_length<int32_t>();
194 1 indirect_matmul_methods[0].pack_shape.n = 2 * get_sme_vector_length<int32_t>();
195 1 indirect_matmul_methods[0].pack_shape.k = sizeof(int32_t);
196 1 indirect_matmul_methods[0].format.lhs = DataFormat(DataType::FP16);
197 1 indirect_matmul_methods[0].format.rhs = DataFormat(DataType::FP16);
198 1 indirect_matmul_methods[0].format.bias = DataFormat(DataType::FP16);
199 1 indirect_matmul_methods[0].format.out = DataFormat(DataType::FP16);
200
201 // LHS
202 1 indirect_matmul_methods[0].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme;
203 1 indirect_matmul_methods[0].lhs.get_lhs_packed_offset =
204 kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme;
205 1 indirect_matmul_methods[0].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme;
206 1 indirect_matmul_methods[0].lhs.pack = kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme;
207
208 // RHS
209 1 indirect_matmul_methods[0].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
210 1 indirect_matmul_methods[0].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
211 1 indirect_matmul_methods[0].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
212 1 indirect_matmul_methods[0].rhs.get_rhs_packed_offset =
213 kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
214 1 indirect_matmul_methods[0].rhs.get_rhs_packed_size =
215 kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
216 1 indirect_matmul_methods[0].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
217
218 // IMATMUL
219 2 const kai_imatmul_clamp_f16_f16p_f16p_ukernel& ukernel_f16_sme2 =
220 1 get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa();
221 1 indirect_matmul_methods[0].imatmul.get_m_step = ukernel_f16_sme2.get_m_step;
222 1 indirect_matmul_methods[0].imatmul.get_n_step = ukernel_f16_sme2.get_n_step;
223 1 indirect_matmul_methods[0].imatmul.get_lhs_packed_offset = ukernel_f16_sme2.get_lhs_packed_offset;
224 1 indirect_matmul_methods[0].imatmul.get_rhs_packed_offset = ukernel_f16_sme2.get_rhs_packed_offset;
225 1 indirect_matmul_methods[0].imatmul.get_dst_offset = ukernel_f16_sme2.get_dst_offset;
226 1 indirect_matmul_methods[0].imatmul.get_dst_size = ukernel_f16_sme2.get_dst_size;
227 1 indirect_matmul_methods[0].imatmul.imatmul = ukernel_f16_sme2.run_imatmul;
228
229 // F32 IMATMUL SME2 ///////////////////////////////////////////////////////
230 1 indirect_matmul_methods[1].name = "imatmul_f32_f32p_f32p_2vlx2vl_sme2_mopa";
231 1 indirect_matmul_methods[1].is_supported = cpu_has_sme2;
232 1 indirect_matmul_methods[1].pack_shape.m = 2 * get_sme_vector_length<int32_t>();
233 1 indirect_matmul_methods[1].pack_shape.n = 2 * get_sme_vector_length<int32_t>();
234 1 indirect_matmul_methods[1].pack_shape.k = sizeof(int32_t);
235 1 indirect_matmul_methods[1].format.lhs = DataFormat(DataType::FP32);
236 1 indirect_matmul_methods[1].format.rhs = DataFormat(DataType::FP32);
237 1 indirect_matmul_methods[1].format.bias = DataFormat(DataType::FP32);
238 1 indirect_matmul_methods[1].format.out = DataFormat(DataType::FP32);
239
240 // LHS
241 1 indirect_matmul_methods[1].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme;
242 1 indirect_matmul_methods[1].lhs.get_lhs_packed_offset =
243 kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme;
244 1 indirect_matmul_methods[1].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme;
245 1 indirect_matmul_methods[1].lhs.pack = kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme;
246
247 // RHS
248 1 indirect_matmul_methods[1].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
249 1 indirect_matmul_methods[1].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
250 1 indirect_matmul_methods[1].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
251 1 indirect_matmul_methods[1].rhs.get_rhs_packed_offset =
252 kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
253 1 indirect_matmul_methods[1].rhs.get_rhs_packed_size =
254 kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
255 1 indirect_matmul_methods[1].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
256
257 // IMATMUL
258 2 const kai_imatmul_clamp_f32_f32p_f32p_ukernel& ukernel_f32_sme2 =
259 1 get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
260 1 indirect_matmul_methods[1].imatmul.get_m_step = ukernel_f32_sme2.get_m_step;
261 1 indirect_matmul_methods[1].imatmul.get_n_step = ukernel_f32_sme2.get_n_step;
262 1 indirect_matmul_methods[1].imatmul.get_lhs_packed_offset = ukernel_f32_sme2.get_lhs_packed_offset;
263 1 indirect_matmul_methods[1].imatmul.get_rhs_packed_offset = ukernel_f32_sme2.get_rhs_packed_offset;
264 1 indirect_matmul_methods[1].imatmul.get_dst_offset = ukernel_f32_sme2.get_dst_offset;
265 1 indirect_matmul_methods[1].imatmul.get_dst_size = ukernel_f32_sme2.get_dst_size;
266 1 indirect_matmul_methods[1].imatmul.imatmul = ukernel_f32_sme2.run_imatmul;
267
268 // F16 IMATMUL SME ////////////////////////////////////////////////////////
269 1 indirect_matmul_methods[2].name = "imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa";
270 1 indirect_matmul_methods[2].is_supported = cpu_has_sme;
271 1 indirect_matmul_methods[2].pack_shape.m = 2 * get_sme_vector_length<int32_t>();
272 1 indirect_matmul_methods[2].pack_shape.n = 2 * get_sme_vector_length<int32_t>();
273 1 indirect_matmul_methods[2].pack_shape.k = sizeof(int32_t);
274 1 indirect_matmul_methods[2].format.lhs = DataFormat(DataType::FP16);
275 1 indirect_matmul_methods[2].format.rhs = DataFormat(DataType::FP16);
276 1 indirect_matmul_methods[2].format.bias = DataFormat(DataType::FP16);
277 1 indirect_matmul_methods[2].format.out = DataFormat(DataType::FP16);
278
279 // LHS
280 1 indirect_matmul_methods[2].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme;
281 1 indirect_matmul_methods[2].lhs.get_lhs_packed_offset =
282 kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme;
283 1 indirect_matmul_methods[2].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme;
284 1 indirect_matmul_methods[2].lhs.pack = kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme;
285
286 // RHS
287 1 indirect_matmul_methods[2].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
288 1 indirect_matmul_methods[2].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
289 1 indirect_matmul_methods[2].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
290 1 indirect_matmul_methods[2].rhs.get_rhs_packed_offset =
291 kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
292 1 indirect_matmul_methods[2].rhs.get_rhs_packed_size =
293 kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
294 1 indirect_matmul_methods[2].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme;
295
296 // IMATMUL
297 2 const kai_imatmul_clamp_f16_f16p_f16p_ukernel& ukernel_f16_sme =
298 1 get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa();
299 1 indirect_matmul_methods[2].imatmul.get_m_step = ukernel_f16_sme.get_m_step;
300 1 indirect_matmul_methods[2].imatmul.get_n_step = ukernel_f16_sme.get_n_step;
301 1 indirect_matmul_methods[2].imatmul.get_lhs_packed_offset = ukernel_f16_sme.get_lhs_packed_offset;
302 1 indirect_matmul_methods[2].imatmul.get_rhs_packed_offset = ukernel_f16_sme.get_rhs_packed_offset;
303 1 indirect_matmul_methods[2].imatmul.get_dst_offset = ukernel_f16_sme.get_dst_offset;
304 1 indirect_matmul_methods[2].imatmul.get_dst_size = ukernel_f16_sme.get_dst_size;
305 1 indirect_matmul_methods[2].imatmul.imatmul = ukernel_f16_sme.run_imatmul;
306
307 // F32 IMATMUL SME ////////////////////////////////////////////////////////
308 1 indirect_matmul_methods[3].name = "imatmul_f32_f32p_f32p_2vlx2vl_sme_mopa";
309 1 indirect_matmul_methods[3].is_supported = cpu_has_sme;
310 1 indirect_matmul_methods[3].pack_shape.m = 2 * get_sme_vector_length<int32_t>();
311 1 indirect_matmul_methods[3].pack_shape.n = 2 * get_sme_vector_length<int32_t>();
312 1 indirect_matmul_methods[3].pack_shape.k = sizeof(int32_t);
313 1 indirect_matmul_methods[3].format.lhs = DataFormat(DataType::FP32);
314 1 indirect_matmul_methods[3].format.rhs = DataFormat(DataType::FP32);
315 1 indirect_matmul_methods[3].format.bias = DataFormat(DataType::FP32);
316 1 indirect_matmul_methods[3].format.out = DataFormat(DataType::FP32);
317
318 // LHS
319 1 indirect_matmul_methods[3].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme;
320 1 indirect_matmul_methods[3].lhs.get_lhs_packed_offset =
321 kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme;
322 1 indirect_matmul_methods[3].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme;
323 1 indirect_matmul_methods[3].lhs.pack = kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme;
324
325 // RHS
326 1 indirect_matmul_methods[3].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
327 1 indirect_matmul_methods[3].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
328 1 indirect_matmul_methods[3].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
329 1 indirect_matmul_methods[3].rhs.get_rhs_packed_offset =
330 kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
331 1 indirect_matmul_methods[3].rhs.get_rhs_packed_size =
332 kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
333 1 indirect_matmul_methods[3].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme;
334
335 // IMATMUL
336 2 const kai_imatmul_clamp_f32_f32p_f32p_ukernel& ukernel_f32_sme =
337 1 get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
338 1 indirect_matmul_methods[3].imatmul.get_m_step = ukernel_f32_sme.get_m_step;
339 1 indirect_matmul_methods[3].imatmul.get_n_step = ukernel_f32_sme.get_n_step;
340 1 indirect_matmul_methods[3].imatmul.get_lhs_packed_offset = ukernel_f32_sme.get_lhs_packed_offset;
341 1 indirect_matmul_methods[3].imatmul.get_rhs_packed_offset = ukernel_f32_sme.get_rhs_packed_offset;
342 1 indirect_matmul_methods[3].imatmul.get_dst_offset = ukernel_f32_sme.get_dst_offset;
343 1 indirect_matmul_methods[3].imatmul.get_dst_size = ukernel_f32_sme.get_dst_size;
344 1 indirect_matmul_methods[3].imatmul.imatmul = ukernel_f32_sme.run_imatmul;
345
346 1 return indirect_matmul_methods;
347 1 }
348
349 /// Test reference identification
350 struct TestDataId {
351 MatMulShape shape;
352 MatMulShape pack_shape;
353 IndirectMatMul::Format format;
354 size_t k_chunk_length;
355 float clamp_rate;
356
357 struct Hash {
358 29172 size_t operator()(const TestDataId& test_id) const {
359 29172 return //
360 58344 (MatMulShape::Hash{}(test_id.shape) << 0) ^ //
361 58344 (MatMulShape::Hash{}(test_id.pack_shape) << 1) ^ //
362 58344 (IndirectMatMul::Format::Hash{}(test_id.format) << 2) ^ //
363 58344 (std::hash<size_t>{}(test_id.k_chunk_length) << 3) ^ //
364 29172 (std::hash<float>{}(test_id.clamp_rate) << 4); //
365 }
366 };
367
368 private:
369 31212 friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) {
370 31212 return //
371
2/2
✓ Branch 0 taken 24684 times.
✓ Branch 1 taken 6528 times.
31212 lhs.shape == rhs.shape && //
372
1/2
✓ Branch 0 taken 24684 times.
✗ Branch 1 not taken.
24684 lhs.pack_shape == rhs.pack_shape && //
373
1/2
✓ Branch 0 taken 24684 times.
✗ Branch 1 not taken.
24684 lhs.format == rhs.format && //
374
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 24684 times.
24684 lhs.k_chunk_length == rhs.k_chunk_length && //
375 24684 lhs.clamp_rate == rhs.clamp_rate;
376 }
377 };
378
379 /// Test reference data
380 struct TestData {
381 Buffer lhs; ///< LHS input matrix
382 Buffer rhs; ///< RHS input matrix
383 Buffer bias; ///< Bias vector
384 Buffer out; ///< Reference imatmul result
385 Buffer indirection; ///< LHS indirection buffer
386 uintptr_t indirection_offset; ///< LHS indirection buffer offset
387 Buffer padding; ///< Padding buffer
388 Range<float> clamp_range; ///< Clamp range
389 };
390
391 /// Reference data generator
392 ///
393 /// Uses test id to generate reference data, and caches it.
394 struct ReferenceGenerator {
395 /// Retrieve reference data for the provided test identification
396 26928 static const TestData& get_test_reference(const TestDataId& test_id) {
397
3/4
✓ Branch 0 taken 1 time.
✓ Branch 1 taken 26927 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
26928 static std::unordered_map<TestDataId, TestData, TestDataId::Hash> m_data;
398
4/5
✓ Branch 0 taken 24684 times.
✓ Branch 1 taken 2244 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 24684 times.
✓ Branch 4 taken 2244 times.
51612 if (const auto itr = m_data.find(test_id); itr != end(m_data)) {
399 24684 return itr->second;
400 }
401
402
1/2
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
2244 return m_data[test_id] = generate_reference(test_id);
403 26928 }
404
405 private:
406 /// Return incremented seed value
407 6732 static size_t get_seed() {
408 static size_t seed = 0;
409 6732 return seed++;
410 }
411
412 /// Generate reference data. Not intended to be called
413 /// directly, as this would bypass caching mechanism.
414 2244 static TestData generate_reference(const TestDataId& test_id) {
415 1097496 const auto& [chunked_shape, pack_shape, format, k_chunk_length, clamp_rate] = test_id;
416
417 // The LHS matrix will be split into several chunks in the K dimension
418 4488 const size_t k_chunk_count = chunked_shape.k;
419 8976 MatMulShape shape = {chunked_shape.m, chunked_shape.n, k_chunk_count * k_chunk_length};
420
421 // Generate random input data
422 4488 Buffer lhs = fill_matrix_random(shape.m, shape.k, format.lhs, get_seed());
423
3/6
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
4488 Buffer rhs = fill_matrix_random(shape.k, shape.n, format.rhs, get_seed());
424
3/6
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
4488 Buffer bias = fill_matrix_random(1, shape.n, format.bias, get_seed());
425
426 // Data types used
427
2/4
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
4488 const DataType lhs_dt = format.lhs.data_type();
428
2/4
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
4488 const DataType rhs_dt = format.rhs.data_type();
429
2/4
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
4488 const DataType out_dt = format.out.data_type();
430
2/4
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
4488 const DataType bias_dt = format.bias.data_type();
431
432 // Create a padding chunk
433
3/6
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
4488 const size_t k_chunk_size = round_up_division(k_chunk_length * data_type_size_in_bits(lhs_dt), 8);
434 2244 const size_t row_size = k_chunk_count * k_chunk_size;
435
1/2
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
2244 Buffer lhs_padding(k_chunk_size);
436
4/4
✓ Branch 0 taken 48756 times.
✓ Branch 1 taken 2244 times.
✓ Branch 2 taken 48756 times.
✓ Branch 3 taken 2244 times.
51000 for (size_t i = 0; i < k_chunk_length; i += 1) {
437 static constexpr double padding_value = 0;
438
2/4
✓ Branch 0 taken 48756 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 48756 times.
✗ Branch 3 not taken.
48756 write_array(lhs_dt, lhs_padding.data(), i, padding_value);
439 48756 }
440
441 // Set up indirection buffer
442
1/2
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
2244 const uintptr_t indirection_offset = reinterpret_cast<uintptr_t>(lhs.data());
443
3/6
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
6732 std::vector<const void*> indirection(chunked_shape.m * chunked_shape.k);
444
4/4
✓ Branch 0 taken 71478 times.
✓ Branch 1 taken 2244 times.
✓ Branch 2 taken 71478 times.
✓ Branch 3 taken 2244 times.
73722 for (size_t i_m = 0; i_m < chunked_shape.m; i_m += 1) {
445
4/4
✓ Branch 0 taken 71478 times.
✓ Branch 1 taken 858660 times.
✓ Branch 2 taken 71478 times.
✓ Branch 3 taken 858660 times.
930138 for (size_t i_k = 0; i_k < chunked_shape.k; i_k += 1) {
446 1717320 const size_t idx = i_m * chunked_shape.k + i_k;
447 // Test padding pointers using first LHS row for shapes where M > 1
448
4/4
✓ Branch 0 taken 839256 times.
✓ Branch 1 taken 19404 times.
✓ Branch 2 taken 815100 times.
✓ Branch 3 taken 24156 times.
858660 if (chunked_shape.m > 1 && i_m == 0) {
449
2/4
✓ Branch 0 taken 24156 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 24156 times.
✗ Branch 3 not taken.
24156 indirection.at(idx) = lhs_padding.data();
450 24156 } else {
451 834504 uintptr_t offset = i_m * row_size + i_k * k_chunk_size;
452
1/2
✓ Branch 0 taken 834504 times.
✗ Branch 1 not taken.
834504 indirection.at(idx) = reinterpret_cast<const void*>(offset);
453 834504 }
454 858660 }
455 71478 }
456
457 // Pack indirection buffer
458
2/4
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
4488 Buffer indirection_packed = reorder_block<const void*>(
459 8976 reinterpret_cast<const void* const*>(indirection.data()), chunked_shape.m, chunked_shape.k, pack_shape.m,
460 1);
461
462
1/2
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
4488 Buffer out = indirect_matmul( //
463
1/2
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
2244 indirection.data(), indirection_offset, lhs_padding.data(), nullptr, nullptr, lhs_dt, // LHS
464
1/2
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
2244 rhs.data(), nullptr, nullptr, rhs_dt, // RHS
465
1/2
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
2244 bias.data(), nullptr, nullptr, bias_dt, // Bias
466 2244 out_dt, // Out
467 8976 chunked_shape.m, chunked_shape.n, chunked_shape.k, k_chunk_length);
468
469 // Calculate clamping range based on full range of values, and then clamp values
470
3/6
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
2244 const auto [min, max] = find_clamp_range(out_dt, out.data(), shape.m * shape.n, 1.0F - clamp_rate);
471
4/8
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2244 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2244 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2244 times.
✗ Branch 7 not taken.
2244 Buffer out_clamped = clamp(out_dt, out.data(), shape.m * shape.n, min, max);
472
473 // Populate reference data
474 2244 TestData test_reference;
475 2244 test_reference.lhs = std::move(lhs);
476 2244 test_reference.rhs = std::move(rhs);
477 2244 test_reference.bias = std::move(bias);
478 2244 test_reference.padding = std::move(lhs_padding);
479 2244 test_reference.out = std::move(out_clamped);
480 2244 test_reference.indirection_offset = indirection_offset;
481 2244 test_reference.indirection = std::move(indirection_packed);
482 6732 test_reference.clamp_range = {min, max};
483
484 2244 return test_reference;
485
1/2
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
2244 };
486 };
487
488 /// Perform LHS packing for indirect matmul
489 22704 Buffer pack_lhs(
490 const LhsPackIndirectKernel& kernel, const Rect& portion, const TestData& reference, size_t m,
491 const KChunk& k_chunk) {
492 22704 const void* const* indirection_pointer = reinterpret_cast<const void* const*>(reference.indirection.data());
493
494 // Calculate size, and allocate buffer
495 22704 const size_t dst_size = kernel.get_lhs_packed_size(m, k_chunk.count, k_chunk.length);
496 22704 Buffer dst(dst_size);
497
498 // Calculate portion offsets
499
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 const size_t input_offset = portion.start_row() * k_chunk.count;
500
2/4
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
22704 const size_t dst_offset = kernel.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length);
501
502 // Perform packing
503
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
45408 kernel.pack(
504
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 portion.height(), k_chunk.count, k_chunk.length, // Dimensions
505 22704 indirection_pointer + input_offset, // Indirection input
506 22704 reference.indirection_offset, // Chunk offset
507
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 reference.padding.data(), // Padding pointer
508
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 dst.data() + dst_offset);
509 22704 return dst;
510 22704 }
511
512 /// Perform RHS packign for indirect matmul
513 22704 Buffer pack_rhs(
514 const RhsPackIndirectKernel& kernel, const Rect& portion, const TestData& reference, size_t n,
515 const KChunk& k_chunk, DataType type) {
516 // Calculate size, and allocate buffer
517 22704 const size_t row_stride = round_up_division(n * data_type_size_in_bits(type), 8);
518 22704 const size_t dst_size = kernel.get_rhs_packed_size(n, k_chunk.count, k_chunk.length);
519 22704 Buffer dst(dst_size);
520
521 // Calculate offsets
522
2/4
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
22704 const size_t rhs_offset = kernel.get_rhs_offset(portion.start_col());
523
2/4
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
22704 const size_t bias_offset = kernel.get_bias_offset(portion.start_col());
524
2/4
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
22704 const size_t dst_offset = kernel.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length);
525
526 // Perform actual packing
527
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
45408 kernel.pack(
528
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 portion.width(), k_chunk.count, k_chunk.length, row_stride, // Dimensions
529
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 reference.rhs.data() + rhs_offset, // RHS input
530
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 reference.bias.data() + bias_offset, // Bias
531
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 dst.data() + dst_offset); // Output
532 22704 return dst;
533 22704 }
534
535 /// Perform imatmul
536 ///
537 /// Note, this should not be aware of reference result, as to make it clear that
538 /// any produced result is strictly from the code under test
539 22704 Buffer imatmul(
540 const MatMulIndirectKernel& kernel, const Rect& portion, const MatMulShape& shape, const KChunk& k_chunk,
541 const Buffer& lhs_packed, const Buffer& rhs_packed, Range<float> clamp_range, DataType type) {
542 // Calculate size, and allocate buffer
543 22704 const size_t dst_size = kernel.get_dst_size(shape.m, shape.n);
544 22704 const size_t row_stride = round_up_division(shape.n * data_type_size_in_bits(type), 8);
545 22704 Buffer dst(dst_size);
546
547 // Calculate portion offsets
548
2/4
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
22704 const size_t lhs_offset = kernel.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length);
549
2/4
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
22704 const size_t rhs_offset = kernel.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length);
550
3/6
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 22704 times.
✗ Branch 5 not taken.
22704 const size_t dst_offset = kernel.get_dst_offset(portion.start_row(), portion.start_col(), row_stride);
551
552 // Call matmul kernel
553
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
45408 kernel.imatmul(
554
2/4
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
22704 portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions
555
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 lhs_packed.data() + lhs_offset, // LHS
556
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 rhs_packed.data() + rhs_offset, // RHS
557
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 dst.data() + dst_offset, // DST
558 22704 row_stride, clamp_range.min, clamp_range.max);
559
560 22704 return dst;
561 22704 }
562
563 } // namespace
564
565 /// End-to-end test for indirection matmul kernels
566
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
80786 TEST_P(IndirectMatMulTest, Output) {
567 231264 const auto& [method, shape, k_chunk_length, output_portion, clamp_rate] = GetParam();
568
1/2
✓ Branch 0 taken 26928 times.
✗ Branch 1 not taken.
26928 if (not method.is_supported()) {
569 GTEST_SKIP() << "Unsupported CPU feature";
570 }
571
572 80784 const KChunk k_chunk{shape.k, k_chunk_length};
573
574 // Retrieve reference data
575 161568 const TestDataId test_id{shape, method.pack_shape, method.format, k_chunk_length, clamp_rate};
576 26928 const TestData& test_data = ReferenceGenerator::get_test_reference(test_id);
577 134640 const Rect portion = output_portion.compute_portion(shape.m, shape.n, method.pack_shape.m, method.pack_shape.n);
578
579
4/4
✓ Branch 0 taken 23364 times.
✓ Branch 1 taken 3564 times.
✓ Branch 2 taken 660 times.
✓ Branch 3 taken 22704 times.
26928 if (portion.height() == 0 || portion.width() == 0) {
580
9/18
✓ Branch 0 taken 4224 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4224 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4224 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4224 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4224 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4224 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4224 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4224 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4224 times.
✗ Branch 17 not taken.
4224 GTEST_SKIP() << "Empty dimension of matrix(" << portion.width() << "," << portion.height() << ")";
581 }
582
583 // Call packing micro-kernels, and then imatmul kernel
584 68112 Buffer lhs_packed = pack_lhs(method.lhs, portion, test_data, shape.m, k_chunk);
585
5/10
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 22704 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 22704 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 22704 times.
✗ Branch 9 not taken.
90816 Buffer rhs_packed = pack_rhs(method.rhs, portion, test_data, shape.n, k_chunk, method.format.rhs.data_type());
586
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
45408 Buffer out = imatmul(
587 45408 method.imatmul, portion, shape, k_chunk, lhs_packed, rhs_packed, test_data.clamp_range,
588
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 method.format.out.data_type());
589
590 // Compare the actual result with the reference result
591
1/2
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
22704 DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
592 45408 const auto success =
593
7/14
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 22704 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 22704 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 22704 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 22704 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 22704 times.
✗ Branch 13 not taken.
22704 compare(out.data(), test_data.out.data(), method.format.out.data_type(), shape.m, shape.n, portion, handler);
594
4/16
✓ Branch 0 taken 22704 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22704 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 22704 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 22704 times.
22704 ASSERT_TRUE(success);
595 26928 }
596
597 /// Name generator for test case
598 53856 [[maybe_unused]] static void PrintTo(const IndirectMatMulTestParams& param, std::ostream* os) {
599 323136 const auto& [method, shape, k_chunk_length, portion, clamp_rate] = param;
600 107712 *os << method.name << "__";
601 53856 PrintTo(shape, os);
602 107712 *os << "__K_chunk_length_" << k_chunk_length;
603 107712 *os << "__clamp_rate_" << static_cast<int>(clamp_rate * 100) << "__";
604 53856 PrintTo(portion, os);
605 53856 }
606
607 /// Test parameter listing
608
18/60
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 time.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 time.
✗ Branch 33 not taken.
✓ Branch 34 taken 26928 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
26930 INSTANTIATE_TEST_SUITE_P(
609 IndirectMatMul, IndirectMatMulTest,
610 testing::Combine(
611 testing::ValuesIn(get_indirect_matmul_methods()), //
612 testing::ValuesIn({
613 // clang-format off
614 MatMulShape{ 1, 1, 1}, //
615 MatMulShape{ 1, 17, 4}, //
616 MatMulShape{ 1, 19, 24}, //
617 MatMulShape{ 1, 32, 4}, //
618 MatMulShape{ 1, 32, 32}, //
619 MatMulShape{ 1, 33, 200}, //
620 MatMulShape{ 1, 49, 21}, //
621 MatMulShape{ 1, 64, 4}, //
622 MatMulShape{ 1, 65, 4}, //
623 MatMulShape{ 3, 6, 6}, //
624 MatMulShape{ 3, 28, 25}, //
625 MatMulShape{ 4, 16, 4}, //
626 MatMulShape{ 4, 16, 27}, //
627 MatMulShape{ 6, 18, 31}, //
628 MatMulShape{ 6, 28, 1}, //
629 MatMulShape{ 6, 29, 24}, //
630 MatMulShape{ 8, 16, 16}, //
631 MatMulShape{ 16, 16, 4}, //
632 MatMulShape{ 16, 16, 16}, //
633 MatMulShape{ 20, 30, 40}, //
634 MatMulShape{ 23, 1, 43}, //
635 MatMulShape{ 32, 14, 1}, //
636 MatMulShape{ 32, 16, 27}, //
637 MatMulShape{ 32, 32, 3}, //
638 MatMulShape{ 32, 32, 4}, //
639 MatMulShape{ 33, 29, 24}, //
640 MatMulShape{ 64, 64, 3}, //
641 MatMulShape{ 64, 64, 4}, //
642 MatMulShape{ 96, 96, 3}, //
643 MatMulShape{ 96, 97, 3}, //
644 MatMulShape{ 97, 96, 3}, //
645 MatMulShape{123, 85, 45}, //
646 MatMulShape{128, 128, 3}, //
647 MatMulShape{130, 130, 6}, //
648 // clang-format on
649 }),
650 testing::ValuesIn(std::initializer_list<size_t>{1, 2, 3, 4, 8, 11, 16, 32, 33, 64, 65}), //
651 testing::ValuesIn({
652 // clang-format off
653 // (Start row , start col , height , width)
654 MatrixPortion( 0 , 0 , 1 , 1 ), // Full matrix.
655 MatrixPortion( 0 , 0 , 1 , 0.5 ), // Left half
656 MatrixPortion( 0 , 0 , 0.5 , 1 ), // Upper half
657 MatrixPortion( 0 , 0.5 , 1 , 0.5 ), // Right half
658 MatrixPortion( 0.5 , 0 , 0.5 , 1 ), // Bottom half
659 MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3 ), // Center ninth
660 // clang-format on
661 }),
662 testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})), //
663 testing::PrintToStringParamName());
664
665 } // namespace kai::test
666