KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 96.0% 121 0 126
Functions: 100.0% 17 0 17
Branches: 39.1% 180 0 460

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <limits>
14 #include <sstream>
15 #include <string>
16 #include <tuple>
17
18 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
19 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
20 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
21 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
22 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
23 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
24 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"
25 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
26 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
27 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.h"
28 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
29 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
30 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
31 #include "test/common/buffer.hpp"
32 #include "test/common/cpu_info.hpp"
33 #include "test/common/float16.hpp"
34 #include "test/common/int4.hpp"
35 #include "test/common/matmul_test_common.hpp"
36 #include "test/common/matrix_portion.hpp"
37 #include "test/common/memory.hpp"
38 #include "test/common/round.hpp"
39 #include "test/common/test_suite.hpp"
40 #include "test/reference/cast.hpp"
41 #include "test/reference/fill.hpp"
42 #include "test/reference/matmul.hpp"
43 #include "test/reference/pack.hpp"
44 #include "test/reference/quantize.hpp"
45
46 namespace kai::test {
47
48 // Interface for the LHS and RHS packed size and packing micro-kernels
49 using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32);
50 using kai_get_rhs_packed_size_func_t = decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0);
51 using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32);
52 using kai_get_rhs_packed_offset_func_t =
53 decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0);
54 using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32);
55 using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0);
56 using kai_run_lhs_pack_func_t = decltype(&kai_run_lhs_quant_pack_qsi8d32p_f32);
57 using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0);
58
59 // Micro-kernel interface
60 struct kai_qsi8d32p_pack_functions {
61 kai_get_lhs_packed_size_func_t packed_size;
62 kai_get_lhs_packed_offset_func_t get_packed_offset;
63 kai_get_lhs_offset_func_t get_offset;
64 kai_run_lhs_pack_func_t run_pack;
65 };
66 struct kai_qsi4c32p_pack_functions {
67 kai_get_rhs_packed_size_func_t packed_size;
68 kai_get_rhs_packed_offset_func_t get_packed_offset;
69 kai_get_rhs_offset_func_t get_offset;
70 kai_run_rhs_pack_func_t run_pack;
71 };
72
73 static const std::array<
74 UkernelMatmulPackVariant<
75 kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_ukernel, kai_qsi8d32p_pack_functions, kai_qsi4c32p_pack_functions>,
76 8>
77 variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p = {
78 {UKERNEL_MATMUL_PACK_VARIANT(
79 clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32,
80 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false),
81 UKERNEL_MATMUL_PACK_VARIANT(
82 clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32,
83 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false),
84 UKERNEL_MATMUL_PACK_VARIANT(
85 4x8sb_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
86 cpu_has_i8mm, lhs_quant_pack_qsi8d32p4x8sb_f32_neon, rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false),
87 UKERNEL_MATMUL_PACK_VARIANT(
88 clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32,
89 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false),
90 UKERNEL_MATMUL_PACK_VARIANT(
91 clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32,
92 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false),
93 UKERNEL_MATMUL_PACK_VARIANT(
94 clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32,
95 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false),
96 UKERNEL_MATMUL_PACK_VARIANT(
97 clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, lhs_quant_pack_qsi8d32p_f32_neon,
98 rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, false),
99 UKERNEL_MATMUL_PACK_VARIANT(
100 clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, cpu_has_sme2, lhs_quant_pack_qsi8d32p_f32_neon,
101 rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, false)}};
102
103 class MatMulTest_f32_qsi8d32p_qsi4c32p : public ::testing::TestWithParam<MatMulTestPortionedParams> {};
104
105
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
578 TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, Offset_RHS) {
106 1152 const auto& [variant_index, matmul_shape, portion] = GetParam();
107 384 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index);
108
109
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
192 if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
110 GTEST_SKIP() << "Unsupported CPU feature";
111 }
112
113 192 const size_t bl = 32;
114 384 const size_t M = matmul_shape.m;
115 384 const size_t N = matmul_shape.n;
116 384 const size_t K = matmul_shape.k;
117
118 192 const auto nr = ukernel_variant.ukernel.interface.get_nr();
119 192 const auto kr = ukernel_variant.ukernel.interface.get_kr();
120
121 192 auto n_step = ukernel_variant.ukernel.interface.get_n_step();
122 192 auto m_step = ukernel_variant.ukernel.interface.get_m_step();
123
124 384 const auto rect = portion.compute_portion(M, N, m_step, n_step);
125
3/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 184 times.
192 if (rect.height() == 0 || rect.width() == 0) {
126
9/18
✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
8 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
127 }
128
129 184 const auto rhs_start_row = rect.start_col();
130 184 auto rhs_packed_offset = ukernel_variant.rhs_pack_interface.get_packed_offset(rhs_start_row, K, nr, kr, bl);
131 184 auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
132
3/14
✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 184 times.
184 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
133 192 }
134
135
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
578 TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, Offset_LHS) {
136 1152 const auto& [variant_index, matmul_shape, portion] = GetParam();
137 384 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index);
138
139
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
192 if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
140 GTEST_SKIP() << "Unsupported CPU feature";
141 }
142
143 192 const size_t bl = 32;
144 384 const size_t M = matmul_shape.m;
145 384 const size_t N = matmul_shape.n;
146 384 const size_t K = matmul_shape.k;
147
148 192 const auto mr = ukernel_variant.ukernel.interface.get_mr();
149 192 const auto kr = ukernel_variant.ukernel.interface.get_kr();
150 192 const auto sr = ukernel_variant.ukernel.interface.get_sr();
151
152 192 auto m_step = ukernel_variant.ukernel.interface.get_m_step();
153 192 auto n_step = ukernel_variant.ukernel.interface.get_n_step();
154
155 384 const auto rect = portion.compute_portion(M, N, m_step, n_step);
156
3/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 184 times.
192 if (rect.height() == 0 || rect.width() == 0) {
157
9/18
✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
8 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
158 }
159
160 184 const auto lhs_start_row = rect.start_row();
161 184 auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr);
162 184 auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl);
163
164
3/14
✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 184 times.
184 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
165 192 }
166
167
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
578 TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, EndToEnd) {
168 1092 const auto& [variant_index, matmul_shape, portion] = GetParam();
169 384 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index);
170
171
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
192 if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
172 GTEST_SKIP() << "Unsupported CPU feature";
173 }
174
175 192 const std::uint32_t seed = 0;
176
177 384 const size_t M = matmul_shape.m;
178 384 const size_t N = matmul_shape.n;
179 384 const size_t K = matmul_shape.k;
180 192 const size_t bl = 32;
181
182 192 const auto mr = ukernel_variant.ukernel.interface.get_mr();
183 192 const auto nr = ukernel_variant.ukernel.interface.get_nr();
184 192 const auto kr = ukernel_variant.ukernel.interface.get_kr();
185 192 const auto sr = ukernel_variant.ukernel.interface.get_sr();
186
187
4/4
✓ Branch 0 taken 72 times.
✓ Branch 1 taken 120 times.
✓ Branch 2 taken 12 times.
✓ Branch 3 taken 60 times.
192 if (mr == 1 && M > 1) {
188
3/6
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
60 GTEST_SKIP() << "Kernel does not support M != 1";
189 }
190
191 132 auto m_step = ukernel_variant.ukernel.interface.get_m_step();
192
3/14
✓ Branch 0 taken 132 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 132 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 132 times.
132 ASSERT_TRUE(m_step % mr == 0);
193
194 132 auto n_step = ukernel_variant.ukernel.interface.get_n_step();
195
3/14
✓ Branch 0 taken 132 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 132 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 132 times.
132 ASSERT_TRUE(n_step % nr == 0);
196
197 264 const auto rect = portion.compute_portion(M, N, m_step, n_step);
198
3/4
✓ Branch 0 taken 132 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 124 times.
132 if (rect.height() == 0 || rect.width() == 0) {
199
9/18
✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
8 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
200 }
201 // Generates input data.
202 124 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
203
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 const auto ref_rhs = fill_random<float>(N * K, seed + 1);
204
205 // Runs the reference implementation.
206 372 const auto [ref_lhs_qvalues, ref_lhs_scales] =
207
2/4
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
124 quantize_symmetric_per_block_dynamic<float, int8_t, Float16>(ref_lhs.data(), M, K, bl);
208 620 const auto [ref_rhs_qsi4, ref_rhs_scales] =
209
2/4
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
124 quantize_symmetric_per_block_dynamic<float, Int4, Float16>(ref_rhs.data(), N, K, bl);
210
211
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
248 const auto ref_dst = matmul_clamp_nt_t<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>(
212
5/10
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 124 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 124 times.
✗ Branch 9 not taken.
248 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, bl, ref_rhs_qsi4.data(), ref_rhs_scales.data(),
213 124 nullptr, bl, nullptr, std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
214
215 // Runs the LHS packing micro-kernel.
216
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 const auto lhs_start_row = rect.start_row();
217
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 const auto imp_packed_lhs_size = ukernel_variant.lhs_pack_interface.packed_size(M, K, bl, mr, kr, sr);
218
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 Buffer imp_packed_lhs(imp_packed_lhs_size);
219
220 124 auto lhs_stride = K * sizeof(float);
221
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 auto lhs_offset = ukernel_variant.lhs_pack_interface.get_offset(lhs_start_row, lhs_stride);
222
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr);
223
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl);
224
225
4/16
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 124 times.
✗ Branch 15 not taken.
124 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
226
227
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
248 ukernel_variant.lhs_pack_interface.run_pack(
228
2/4
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
124 rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset),
229
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 lhs_stride, imp_packed_lhs.data() + lhs_packed_offset);
230
231 // Runs the RHS packing micro-kernel.
232
3/6
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
248 const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K);
233 124 const auto ref_rhs_qsu4_scale_f16 =
234
3/6
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
124 pack_data_scales_interleave_block<UInt4, Float16>(ref_rhs_qsu4.data(), ref_rhs_scales.data(), N, K, bl);
235
236
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 const auto imp_packed_rhs_size = ukernel_variant.rhs_pack_interface.packed_size(N, K, nr, kr, bl);
237
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 Buffer imp_packed_rhs(imp_packed_rhs_size);
238
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 const auto rhs_start_row = rect.start_col();
239
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 auto rhs_packed_offset = ukernel_variant.rhs_pack_interface.get_packed_offset(rhs_start_row, K, nr, kr, bl);
240
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
241
4/16
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 124 times.
124 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
242
243 124 const kai_rhs_pack_qs4cxs1s0_param params{.lhs_zero_point = 1, .rhs_zero_point = 8};
244
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
248 ukernel_variant.rhs_pack_interface.run_pack(
245
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 1, N, K, nr, kr, sr, bl, reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_scale_f16.data()), nullptr,
246
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 imp_packed_rhs.data(), 0, &params);
247
248 124 const auto dst_stride_row = N * sizeof(float);
249 124 const auto dst_stride_col = sizeof(float);
250 248 const auto dst_offset =
251
3/6
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
124 ukernel_variant.ukernel.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
252
2/4
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
124 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
253
4/16
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 124 times.
124 ASSERT_EQ(dst_offset, ref_dst_offset);
254
255 // Runs the GEMM micro-kernel.
256
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 const auto imp_dst_size = ukernel_variant.ukernel.interface.get_dst_size(M, N);
257
5/18
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 124 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 124 times.
124 ASSERT_EQ(imp_dst_size, ref_dst.size());
258
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
124 Buffer imp_dst(imp_dst_size);
259
1/2
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
248 ukernel_variant.ukernel.interface.run_matmul(
260
3/6
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
124 rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset,
261
2/4
✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
124 imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
262 124 dst_stride_row, dst_stride_col, std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
263
264 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
265 // tested.
266
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 3108 times.
✓ Branch 2 taken 2984 times.
✓ Branch 3 taken 124 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 124 times.
3108 for (size_t y = 0; y < rect.height(); ++y) {
267
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 128593 times.
✓ Branch 2 taken 125609 times.
✓ Branch 3 taken 2984 times.
✓ Branch 4 taken 2984 times.
✗ Branch 5 not taken.
128593 for (size_t x = 0; x < rect.width(); ++x) {
268 251218 const auto imp_value =
269
4/8
✓ Branch 0 taken 125609 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125609 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125609 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125609 times.
✗ Branch 7 not taken.
125609 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
270 251218 const auto ref_value =
271
4/8
✓ Branch 0 taken 125609 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125609 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125609 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125609 times.
✗ Branch 7 not taken.
125609 read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
272
1/2
✓ Branch 0 taken 125609 times.
✗ Branch 1 not taken.
125609 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value;
273
274
1/2
✓ Branch 0 taken 125609 times.
✗ Branch 1 not taken.
125609 if (rel_error > 0.0001F) {
275 ASSERT_EQ(imp_value, ref_value);
276 }
277 125609 }
278 2984 }
279 192 }
280
281
15/46
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 3 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 576 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✓ Branch 44 taken 576 times.
✗ Branch 45 not taken.
1156 INSTANTIATE_TEST_SUITE_P(
282 MatMul, MatMulTest_f32_qsi8d32p_qsi4c32p,
283 testing::Combine(
284 testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.size()),
285 testing::Values(
286 MatMulShape{1, 2, 32}, //
287 MatMulShape{32, 64, 64}, //
288 MatMulShape{16, 32, 64}, //
289 MatMulShape{8, 32, 64}, //
290 MatMulShape{15, 32, 32}, //
291 MatMulShape{77, 99, 64}),
292 testing::Values(
293 MatrixPortion(0, 0, 1, 1), // Full matrix.
294 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
295 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
296 MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle
297 )),
298 [](const auto& info) {
299 const auto variant_idx = std::get<0>(info.param);
300 const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_idx).ukernel.name};
301 const auto shape = std::get<MatMulShape>(info.param);
302 const auto portion = std::get<2>(info.param);
303
304 return test_description(name, shape, portion, true);
305 });
306
307 } // namespace kai::test
308