KleidiAI Coverage Report


Directory: ./
Coverage legend: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%

            Coverage   Exec / Excl / Total
Lines:        96.0%     243 /    0 /   253
Functions:    75.9%      22 /    0 /    29
Branches:     47.2%     109 /    0 /   231

test/nextgen/operators/matmul/matmul_tb.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/nextgen/operators/matmul/matmul_tb.hpp"
8
9 #include <algorithm>
10 #include <array>
11 #include <cstddef>
12 #include <tuple>
13 #include <utility>
14 #include <vector>
15
16 #include "test/common/assert.hpp"
17 #include "test/common/buffer.hpp"
18 #include "test/common/compare.hpp"
19 #include "test/common/data_type.hpp"
20 #include "test/nextgen/common/poly.hpp"
21 #include "test/nextgen/common/random.hpp"
22 #include "test/nextgen/format/format.hpp"
23 #include "test/nextgen/format/plain_format.hpp"
24 #include "test/nextgen/harness/kernel_wrapper.hpp"
25 #include "test/nextgen/operators/matmul/matmul_config.hpp"
26 #include "test/nextgen/operators/matmul/matmul_slots.hpp"
27 #include "test/nextgen/quantization/quantizer.hpp"
28 #include "test/nextgen/reference/binary_elementwise.hpp"
29 #include "test/nextgen/reference/clamp.hpp"
30 #include "test/nextgen/reference/matmul.hpp"
31 #include "test/nextgen/reference/reduce.hpp"
32 #include "test/nextgen/reference/unary_elementwise.hpp"
33 #include "test/reference/transpose.hpp"
34
35 namespace kai::test {
36
37 600 MatMulTb::MatMulTb(
38 size_t shape_m, size_t shape_n, size_t shape_k, MatMulBiasMode bias_mode, float clamp_ratio,
39 const MatMulOperator* op) :
40 200 m_shape_m(shape_m),
41 200 m_shape_n(shape_n),
42 200 m_shape_k(shape_k),
43 200 m_bias_mode(bias_mode),
44 200 m_clamp_ratio(clamp_ratio),
45 200 m_op(op),
46 400 m_tensors_required() {
47 1/2 200 std::fill(m_tensors_required.begin(), m_tensors_required.end(), false);
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
48 400 }
49
50 200 void MatMulTb::generate_test_data(Rng& rng) {
51 200 populate_config();
52 200 determine_required_tensors();
53
54 // Populates the constant information.
55 200 m_op->matmul->populate_constant_info(m_tensors);
56
57 1/2 200 if (m_op->pack_lhs.has_value()) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
58 200 const KernelWrapper& pack_lhs = *m_op->pack_lhs.value();
59 200 pack_lhs.populate_constant_info(m_tensors);
60 200 }
61
62 1/2 200 if (m_op->pack_rhs.has_value()) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
63 200 const KernelWrapper& pack_rhs = *m_op->pack_rhs.value();
64 200 pack_rhs.populate_constant_info(m_tensors);
65 200 }
66
67 // Generates the raw test data.
68 200 generate_lhs_raw(rng);
69 200 generate_rhs_raw(rng);
70 200 generate_bias_raw(rng);
71
72 200 compute_rhs_t_raw(); // The transposed RHS data is always needed for reference packing.
73
74 // Quantizes the input data.
75 1/2 200 if (m_op->lhs_quant.has_value()) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
76 200 quantize_lhs();
77 200 }
78
79 1/2 200 if (m_op->rhs_quant.has_value()) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
80 200 quantize_rhs_t();
81 200 }
82
83 1/2 200 if (m_op->bias_quant.has_value()) {
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
84 quantize_bias();
85 }
86
87 1/2 200 if (m_tensors_required.at(MATMUL_SLOT_LHS_QZP_NEG)) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
88 200 compute_lhs_qzp_neg();
89 200 }
90
91 1/2 200 if (m_tensors_required.at(MATMUL_SLOT_RHS_T_QDATA_SIGN)) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
92 200 compute_rhs_t_qdata_sign();
93 200 }
94
95 1/2 200 if (m_tensors_required.at(MATMUL_SLOT_RHS_T_QDATA_SIGN_SUM)) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
96 200 compute_rhs_t_qdata_sign_sum();
97 200 }
98
99 // Generates reference output.
100 1/2 200 if (m_op->pack_lhs.has_value()) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
101 200 compute_ref_packed_lhs();
102 200 }
103
104 1/2 200 if (m_op->pack_rhs.has_value()) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
105 200 compute_ref_packed_rhs();
106 200 }
107
108 200 compute_ref_matmul();
109 200 }
110
111 200 void MatMulTb::populate_config() {
112 200 m_tensors.at(MATMUL_SLOT_CONFIG).set_value(MatMulConfig{m_bias_mode});
113 200 }
114
115 200 void MatMulTb::determine_required_tensors() {
116 0/2 200 std::vector<const KernelWrapper*> kernels{m_op->matmul.get()};
    ✗ Branch 0 not taken.
    ✗ Branch 1 not taken.
117
118 1/2 200 if (m_op->pack_lhs.has_value()) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
119 2/4 200 kernels.emplace_back(m_op->pack_lhs.value().get());
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
120 200 }
121
122 1/2 200 if (m_op->pack_rhs.has_value()) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 200 times.
123 2/4 200 kernels.emplace_back(m_op->pack_rhs.value().get());
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
124 200 }
125
126 2/2 800 for (const KernelWrapper* kernel : kernels) {
    ✓ Branch 0 taken 200 times.
    ✓ Branch 1 taken 600 times.
127 1/2 600 if (kernel != nullptr) {
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 600 times.
128 1/2 600 const std::vector<size_t> run_inputs = kernel->run_inputs(m_tensors);
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 600 times.
129 1/2 600 const std::vector<size_t> ref_inputs = kernel->ref_inputs(m_tensors);
    ✗ Branch 0 not taken.
    ✓ Branch 1 taken 600 times.
130
131 2/2 1945 for (const size_t id : run_inputs) {
    ✓ Branch 0 taken 1345 times.
    ✓ Branch 1 taken 600 times.
132 1/2 1345 m_tensors_required.at(id) = true;
    ✓ Branch 0 taken 1345 times.
    ✗ Branch 1 not taken.
133 1345 }
134
135 2/2 3145 for (const size_t id : ref_inputs) {
    ✓ Branch 0 taken 600 times.
    ✓ Branch 1 taken 2545 times.
136 1/2 2545 m_tensors_required.at(id) = true;
    ✓ Branch 0 taken 2545 times.
    ✗ Branch 1 not taken.
137 2545 }
138 600 }
139 600 }
140 200 }
141
142 200 void MatMulTb::generate_lhs_raw(Rng& rng) {
143 200 const std::array shape{m_shape_m, m_shape_k};
144 200 const Poly<Format> format(std::in_place_type<PlainFormat>, DataType::FP32);
145 1/2 200 Tensor& tensor = m_tensors.at(MATMUL_SLOT_LHS_RAW);
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
146
147 5/10 200 tensor.set_shape(shape).set_format(format).set_data(format->generate_random(shape, rng));
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 200 times.
    ✗ Branch 5 not taken.
    ✓ Branch 6 taken 200 times.
    ✗ Branch 7 not taken.
    ✓ Branch 8 taken 200 times.
    ✗ Branch 9 not taken.
148 200 }
149
150 200 void MatMulTb::generate_rhs_raw(Rng& rng) {
151 200 const std::array shape{m_shape_k, m_shape_n};
152 200 const Poly<Format> format(std::in_place_type<PlainFormat>, DataType::FP32);
153 1/2 200 Tensor& tensor = m_tensors.at(MATMUL_SLOT_RHS_RAW);
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
154
155 5/10 200 tensor.set_shape(shape).set_format(format).set_data(format->generate_random(shape, rng));
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 200 times.
    ✗ Branch 5 not taken.
    ✓ Branch 6 taken 200 times.
    ✗ Branch 7 not taken.
    ✓ Branch 8 taken 200 times.
    ✗ Branch 9 not taken.
156 200 }
157
158 200 void MatMulTb::generate_bias_raw(Rng& rng) {
159 200 const std::array shape{m_shape_n};
160 200 const Poly<Format> format(std::in_place_type<PlainFormat>, DataType::FP32);
161 1/2 200 Tensor& tensor = m_tensors.at(MATMUL_SLOT_BIAS_RAW);
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
162
163 5/10 200 tensor.set_shape(shape).set_format(format).set_data(format->generate_random(shape, rng));
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 200 times.
    ✗ Branch 5 not taken.
    ✓ Branch 6 taken 200 times.
    ✗ Branch 7 not taken.
    ✓ Branch 8 taken 200 times.
    ✗ Branch 9 not taken.
164 200 }
165
166 200 void MatMulTb::compute_rhs_t_raw() {
167 200 const std::array shape{m_shape_n, m_shape_k};
168 200 const Poly<Format> format(std::in_place_type<PlainFormat>, DataType::FP32);
169 1/2 200 Tensor& rhs_t_raw = m_tensors.at(MATMUL_SLOT_RHS_T_RAW);
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
170 1/2 200 const Tensor& rhs_raw = m_tensors.at(MATMUL_SLOT_RHS_RAW);
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
171
172 5/10 200 rhs_t_raw.set_shape(shape).set_format(format).set_data(transpose<float>(rhs_raw.data_ptr(), m_shape_k, m_shape_n));
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 200 times.
    ✗ Branch 5 not taken.
    ✓ Branch 6 taken 200 times.
    ✗ Branch 7 not taken.
    ✓ Branch 8 taken 200 times.
    ✗ Branch 9 not taken.
173 200 }
174
175 200 void MatMulTb::quantize_lhs() {
176 200 const Quantizer& lhs_quant = *m_op->lhs_quant.value();
177
178 200 const std::array lhs_shape{m_shape_m, m_shape_k};
179 200 const Tensor& lhs_raw = m_tensors.at(MATMUL_SLOT_LHS_RAW);
180 200 Tensor& lhs_qdata = m_tensors.at(MATMUL_SLOT_LHS_QDATA);
181 200 Tensor& lhs_qscale = m_tensors.at(MATMUL_SLOT_LHS_QSCALE);
182 200 Tensor& lhs_qzp = m_tensors.at(MATMUL_SLOT_LHS_QZP);
183
184 200 lhs_quant.dynamic_quantize(DataType::FP32, lhs_shape, lhs_raw.data(), lhs_qdata, lhs_qscale, lhs_qzp);
185 200 }
186
187 200 void MatMulTb::quantize_rhs_t() {
188 200 const Quantizer& rhs_quant = *m_op->rhs_quant.value();
189
190 200 const std::array rhs_t_shape{m_shape_n, m_shape_k};
191 200 const Tensor& rhs_t_raw = m_tensors.at(MATMUL_SLOT_RHS_T_RAW);
192 200 Tensor& rhs_t_qdata = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA);
193 200 Tensor& rhs_t_qscale = m_tensors.at(MATMUL_SLOT_RHS_T_QSCALE);
194 200 Tensor& rhs_t_qzp = m_tensors.at(MATMUL_SLOT_RHS_T_QZP);
195
196 200 rhs_quant.dynamic_quantize(DataType::FP32, rhs_t_shape, rhs_t_raw.data(), rhs_t_qdata, rhs_t_qscale, rhs_t_qzp);
197 200 }
198
199 void MatMulTb::quantize_bias() {
200 KAI_TEST_ERROR("Not supported.");
201 }
202
203 200 void MatMulTb::compute_lhs_qzp_neg() {
204 200 const Tensor& lhs_qzp = m_tensors.at(MATMUL_SLOT_LHS_QZP);
205 200 Tensor& lhs_qzp_neg = m_tensors.at(MATMUL_SLOT_LHS_QZP_NEG);
206
207 200 const Span<const size_t> shape = lhs_qzp.shape();
208 200 const Poly<Format>& format = lhs_qzp.format();
209
210 200 const UnaryElementwiseFn fn = make_negate(format->dtype());
211 200 Buffer data = fn(shape, lhs_qzp.data());
212
213 3/6 200 lhs_qzp_neg.set_shape(shape).set_format(format).set_data(std::move(data));
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 200 times.
    ✗ Branch 5 not taken.
214 200 }
215
216 200 void MatMulTb::compute_rhs_t_qdata_sign() {
217 200 const Tensor& rhs_t_qdata = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA);
218 200 Tensor& rhs_t_qdata_sign = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA_SIGN);
219
220 200 const Span<const size_t> shape = rhs_t_qdata.shape();
221 200 const Poly<Format>& format = rhs_t_qdata.format();
222
223 200 const UnaryElementwiseFn fn = make_change_signedness(format->dtype());
224 200 Buffer data = fn(shape, rhs_t_qdata.data());
225
226 3/6 200 rhs_t_qdata_sign.set_shape(shape).set_format(format).set_data(std::move(data));
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 200 times.
    ✗ Branch 5 not taken.
227 200 }
228
229 200 void MatMulTb::compute_rhs_t_qdata_sign_sum() {
230 200 const Tensor& rhs_t_qdata_sign = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA_SIGN);
231 200 Tensor& rhs_t_qdata_sign_sum = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA_SIGN_SUM);
232
233 200 const std::array rhs_t_shape = {m_shape_n, m_shape_k};
234 200 const std::array rhs_t_rowsum_shape = {m_shape_n};
235 200 const DataType src_dtype = rhs_t_qdata_sign.format()->dtype();
236 200 const DataType dst_dtype = rhs_t_qdata_sign_sum.format()->dtype();
237
238 200 const ReduceFn fn = make_reduce_add(src_dtype, dst_dtype);
239 200 Buffer data = fn(0, rhs_t_shape, rhs_t_qdata_sign.data());
240
241 2/4 200 rhs_t_qdata_sign_sum.set_shape(rhs_t_rowsum_shape).set_data(std::move(data));
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
242 200 }
243
244 200 void MatMulTb::compute_ref_packed_lhs() {
245 200 const KernelWrapper& pack_lhs = *m_op->pack_lhs.value();
246 200 const std::array lhs_shape{m_shape_m, m_shape_k};
247 200 pack_lhs.compute_reference(lhs_shape, m_tensors);
248 200 }
249
250 200 void MatMulTb::compute_ref_packed_rhs() {
251 200 const KernelWrapper& pack_rhs = *m_op->pack_rhs.value();
252 200 const std::array rhs_t_shape{m_shape_n, m_shape_k};
253 200 pack_rhs.compute_reference(rhs_t_shape, m_tensors);
254 200 }
255
256 200 void MatMulTb::compute_ref_matmul() {
257 200 const MatMulConfig& config = m_tensors.at(MATMUL_SLOT_CONFIG).value<MatMulConfig>();
258 200 const Tensor& lhs_qdata = m_tensors.at(MATMUL_SLOT_LHS_QDATA);
259 200 const Tensor& lhs_qscale = m_tensors.at(MATMUL_SLOT_LHS_QSCALE);
260 200 const Tensor& lhs_qzp = m_tensors.at(MATMUL_SLOT_LHS_QZP);
261 200 const Tensor& rhs_t_qdata = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA);
262 200 const Tensor& rhs_t_qscale = m_tensors.at(MATMUL_SLOT_RHS_T_QSCALE);
263 200 const Tensor& bias_raw = m_tensors.at(MATMUL_SLOT_BIAS_RAW);
264 200 Tensor& kernel_args = m_tensors.at(MATMUL_SLOT_MATMUL_ARGS);
265 200 Tensor& ref_dst_data = m_tensors.at(MATMUL_SLOT_REF_DST_DATA);
266
267 2/4 200 ref_dst_data.set_shape({m_shape_m, m_shape_n}).set_format(make_poly<PlainFormat>(m_op->dst_dtype));
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
268
269 // REVISIT: Assumes that the LHS and RHS are both quantized.
270 200 const Quantizer& lhs_quant = *m_op->lhs_quant.value();
271 200 const Quantizer& rhs_quant = *m_op->rhs_quant.value();
272
273 1200 const Buffer lhs_data = lhs_quant.dequantize(
274 1000 m_op->acc_dtype, {m_shape_m, m_shape_k}, lhs_qdata.data(), lhs_qscale.data(), lhs_qzp.data());
275 200 const Buffer rhs_t_data =
276 3/6 200 rhs_quant.dequantize(m_op->acc_dtype, {m_shape_n, m_shape_k}, rhs_t_qdata.data(), rhs_t_qscale.data(), {});
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 200 times.
    ✗ Branch 5 not taken.
277
278 1/2 200 const MatMulFn matmul_fn = make_matmul_nt_t(m_op->acc_dtype);
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
279 3/6 200 Buffer dst = matmul_fn(m_shape_m, m_shape_n, m_shape_k, lhs_data, rhs_t_data);
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 200 times.
    ✗ Branch 5 not taken.
280
281 2/3 200 switch (config.bias_mode) {
    ✓ Branch 0 taken 55 times.
    ✓ Branch 1 taken 145 times.
    ✗ Branch 2 not taken.
282 case MatMulBiasMode::NO_BIAS:
283 break;
284
285 case MatMulBiasMode::PER_N: {
286 1/2 145 const BinaryElementwiseFn add_fn = make_add_2d(m_op->acc_dtype);
    ✓ Branch 0 taken 145 times.
    ✗ Branch 1 not taken.
287 3/6 145 dst = add_fn(m_shape_m, m_shape_n, dst, 1, m_shape_n, bias_raw.data());
    ✓ Branch 0 taken 145 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 145 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 145 times.
    ✗ Branch 5 not taken.
288 break;
289 145 }
290
291 default:
292 KAI_TEST_ERROR("Not supported.");
293 }
294
295 1/2 200 const DynamicClampFn dynamic_clamp_fn = make_dynamic_clamp(m_op->acc_dtype);
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
296 2/4 600 auto [clamp_args, clampped_dst] = dynamic_clamp_fn(m_clamp_ratio, {m_shape_m, m_shape_n}, dst);
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
297
298 5/10 400 kernel_args.set_shape({clamp_args.size()}).set_data(std::move(clamp_args));
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 200 times.
    ✗ Branch 5 not taken.
    ✓ Branch 6 taken 200 times.
    ✗ Branch 7 not taken.
    ✓ Branch 8 taken 200 times.
    ✗ Branch 9 not taken.
299
300 1/4 200 KAI_TEST_ASSERT_MSG(
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✗ Branch 2 not taken.
    ✗ Branch 3 not taken.
301 m_op->dst_dtype == m_op->acc_dtype, "Only support the accumulator and output type being the same.");
302 2/4 400 ref_dst_data.set_data(std::move(clampped_dst));
    ✓ Branch 0 taken 200 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 200 times.
    ✗ Branch 3 not taken.
303 200 }
304
305 bool MatMulTb::has_lhs_packing() const {
306 return m_op->pack_lhs.has_value();
307 }
308
309 600 std::tuple<size_t, size_t> MatMulTb::lhs_packing_steps() const {
310 600 const KernelWrapper& pack_lhs = *m_op->pack_lhs.value();
311 600 const std::vector<size_t> steps = pack_lhs.steps({m_shape_m, m_shape_k}, m_tensors);
312 2/4 600 return {steps.at(0), steps.at(1)};
    ✓ Branch 0 taken 600 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 600 times.
    ✗ Branch 3 not taken.
313 600 }
314
315 600 void MatMulTb::test_lhs_packing(size_t start_m, size_t start_k, size_t size_m, size_t size_k) {
316 600 const KernelWrapper& pack_lhs = *m_op->pack_lhs.value();
317
318 600 const std::array full_shape{m_shape_m, m_shape_k};
319 600 const std::array tile_coords{start_m, start_k};
320 600 const std::array tile_shape{size_m, size_k};
321
322 600 pack_lhs.run(full_shape, tile_coords, tile_shape, m_tensors);
323
324 600 const Tensor& ref_packed_lhs = m_tensors.at(MATMUL_SLOT_REF_LHS_PACKED);
325 600 const Tensor& imp_packed_lhs = m_tensors.at(MATMUL_SLOT_IMP_LHS_PACKED);
326 600 const Format& format = *ref_packed_lhs.format();
327
328 600 DefaultMismatchHandler handler(0.0F, 0.0F, 0, 0.0F);
329 1200 const bool ok =
330 3/6 600 format.compare(full_shape, tile_coords, tile_shape, imp_packed_lhs.data(), ref_packed_lhs.data(), handler);
    ✓ Branch 0 taken 600 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 600 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 600 times.
    ✗ Branch 5 not taken.
331 1/6 600 KAI_TEST_ASSERT(ok);
    ✓ Branch 0 taken 600 times.
    ✗ Branch 1 not taken.
    ✗ Branch 2 not taken.
    ✗ Branch 3 not taken.
    ✗ Branch 4 not taken.
    ✗ Branch 5 not taken.
332 600 }
333
334 bool MatMulTb::has_rhs_packing() const {
335 return m_op->pack_rhs.has_value();
336 }
337
338 600 std::tuple<size_t, size_t> MatMulTb::rhs_packing_steps() const {
339 600 const KernelWrapper& pack_rhs = *m_op->pack_rhs.value();
340 600 const std::vector<size_t> steps = pack_rhs.steps({m_shape_n, m_shape_k}, m_tensors);
341 2/4 600 return {steps.at(0), steps.at(1)};
    ✓ Branch 0 taken 600 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 600 times.
    ✗ Branch 3 not taken.
342 600 }
343
344 600 void MatMulTb::test_rhs_packing(size_t start_n, size_t start_k, size_t size_n, size_t size_k) {
345 600 const KernelWrapper& pack_rhs = *m_op->pack_rhs.value();
346
347 600 const std::array full_shape{m_shape_n, m_shape_k};
348 600 const std::array tile_coords{start_n, start_k};
349 600 const std::array tile_shape{size_n, size_k};
350
351 600 pack_rhs.run(full_shape, tile_coords, tile_shape, m_tensors);
352
353 600 const Tensor& ref_packed_rhs = m_tensors.at(MATMUL_SLOT_REF_RHS_PACKED);
354 600 const Tensor& imp_packed_rhs = m_tensors.at(MATMUL_SLOT_IMP_RHS_PACKED);
355 600 const Format& format = *ref_packed_rhs.format();
356
357 600 DefaultMismatchHandler handler(0.0F, 0.0F, 0, 0.0F);
358 1200 const bool ok =
359 3/6 600 format.compare(full_shape, tile_coords, tile_shape, imp_packed_rhs.data(), ref_packed_rhs.data(), handler);
    ✓ Branch 0 taken 600 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 600 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 600 times.
    ✗ Branch 5 not taken.
360 1/6 600 KAI_TEST_ASSERT(ok);
    ✓ Branch 0 taken 600 times.
    ✗ Branch 1 not taken.
    ✗ Branch 2 not taken.
    ✗ Branch 3 not taken.
    ✗ Branch 4 not taken.
    ✗ Branch 5 not taken.
361 600 }
362
363 600 std::tuple<size_t, size_t> MatMulTb::matmul_steps() const {
364 600 const std::vector<size_t> steps = m_op->matmul->steps({m_shape_m, m_shape_n, m_shape_k}, m_tensors);
365 2/4 600 return {steps.at(0), steps.at(1)};
    ✓ Branch 0 taken 600 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 600 times.
    ✗ Branch 3 not taken.
366 600 }
367
368 600 void MatMulTb::test_matmul(size_t start_m, size_t start_n, size_t size_m, size_t size_n) {
369 600 const std::array matmul_full_shape{m_shape_m, m_shape_n, m_shape_k};
370 600 const std::array matmul_tile_coords{start_m, start_n, static_cast<size_t>(0)};
371 600 const std::array matmul_tile_shape{size_m, size_n, m_shape_k};
372
373 600 const std::array dst_full_shape{m_shape_m, m_shape_n};
374 600 const std::array dst_tile_coords{start_m, start_n};
375 600 const std::array dst_tile_shape{size_m, size_n};
376
377 600 m_op->matmul->run(matmul_full_shape, matmul_tile_coords, matmul_tile_shape, m_tensors);
378
379 600 const Tensor& ref_dst_data = m_tensors.at(MATMUL_SLOT_REF_DST_DATA);
380 600 const Tensor& imp_dst_data = m_tensors.at(MATMUL_SLOT_IMP_DST_DATA);
381 600 const Format& format = *ref_dst_data.format();
382
383 600 DefaultMismatchHandler handler(1e-3, 1e-3, 0, 0.0F);
384 1/2 2400 const bool ok = format.compare(
    ✓ Branch 0 taken 600 times.
    ✗ Branch 1 not taken.
385 6/12 2400 dst_full_shape, dst_tile_coords, dst_tile_shape, imp_dst_data.data(), ref_dst_data.data(), handler);
    ✓ Branch 0 taken 600 times.
    ✗ Branch 1 not taken.
    ✓ Branch 2 taken 600 times.
    ✗ Branch 3 not taken.
    ✓ Branch 4 taken 600 times.
    ✗ Branch 5 not taken.
    ✓ Branch 6 taken 600 times.
    ✗ Branch 7 not taken.
    ✓ Branch 8 taken 600 times.
    ✗ Branch 9 not taken.
    ✓ Branch 10 taken 600 times.
    ✗ Branch 11 not taken.
386 1/6 600 KAI_TEST_ASSERT(ok);
    ✓ Branch 0 taken 600 times.
    ✗ Branch 1 not taken.
    ✗ Branch 2 not taken.
    ✗ Branch 3 not taken.
    ✗ Branch 4 not taken.
    ✗ Branch 5 not taken.
387 600 }
388
389 } // namespace kai::test
390