KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 94.6% 88 / 1 / 94
Functions: 100.0% 28 / 0 / 28
Branches: 34.4% 175 / 4 / 512

test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <limits>
14 #include <sstream>
15 #include <string>
16 #include <tuple>
17
18 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
19 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
20 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
21 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p8x4_1x8_sve_dotprod.h"
22 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
23 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
24 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
25 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
26 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"
27 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
28 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
29 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
30 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.h"
31 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
32 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
33 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
34 #include "test/common/abi_checker.hpp"
35 #include "test/common/buffer.hpp"
36 #include "test/common/compare.hpp"
37 #include "test/common/cpu_info.hpp"
38 #include "test/common/float16.hpp"
39 #include "test/common/int4.hpp"
40 #include "test/common/matmul_test_common.hpp"
41 #include "test/common/matrix_portion.hpp"
42 #include "test/common/memory.hpp"
43 #include "test/common/round.hpp"
44 #include "test/common/test_suite.hpp"
45 #include "test/reference/cast.hpp"
46 #include "test/reference/clamp.hpp"
47 #include "test/reference/fill.hpp"
48 #include "test/reference/matmul.hpp"
49 #include "test/reference/pack.hpp"
50 #include "test/reference/quantize.hpp"
51
52 namespace kai::test {
53
54 // Interface for the LHS and RHS packed size and packing micro-kernels
55 using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32);
56 using kai_get_rhs_packed_size_func_t = decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0);
57 using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32);
58 using kai_get_rhs_packed_offset_func_t =
59 decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0);
60 using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32);
61 using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0);
62 using kai_run_lhs_pack_func_t = decltype(&kai_run_lhs_quant_pack_qsi8d32p_f32);
63 using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0);
64
65 // Micro-kernel interface
66 struct kai_qsi8d32p_pack_functions {
67 kai_get_lhs_packed_size_func_t packed_size;
68 kai_get_lhs_packed_offset_func_t get_packed_offset;
69 kai_get_lhs_offset_func_t get_offset;
70 kai_run_lhs_pack_func_t run_pack;
71 };
72 struct kai_qsi4c32p_pack_functions {
73 kai_get_rhs_packed_size_func_t packed_size;
74 kai_get_rhs_packed_offset_func_t get_packed_offset;
75 kai_get_rhs_offset_func_t get_offset;
76 kai_run_rhs_pack_func_t run_pack;
77 };
78
79 struct UKernelVariants {
80 UkernelMatmulPackVariant<
81 kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_ukernel, kai_qsi8d32p_pack_functions, kai_qsi4c32p_pack_functions>
82 variant;
83 bool clamp_support;
84 };
85
86 // clang-format off
87 static const int num_non_clamping_kernels = 4;
88 static const std::array<UKernelVariants, 11>
89
0/4
✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
3 variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p = {
90 14 {
91 // NOTE: The following kernels do not support clamping despite their names.
92
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
93 clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32,
94 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), false},
95
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
96 clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32,
97 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), false},
98
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
99 clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, lhs_quant_pack_qsi8d32p_f32_neon,
100 rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, false), false},
101
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
102 clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, cpu_has_sme2, lhs_quant_pack_qsi8d32p_f32_neon,
103 rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, false), false},
104
105 // The kernels below this point will run clamping tests
106
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
107 clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32,
108 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true},
109
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
110 clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32,
111 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true},
112
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
113 4x8sb_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
114 cpu_has_i8mm, lhs_quant_pack_qsi8d32p4x8sb_f32_neon, rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true},
115
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
116 clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32,
117 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true},
118
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
119 clamp_f32_qsi8d32p1x4_qsi4c32p8x4_1x8_sve_dotprod, (cpu_check<cpu_has_sve_vl256, cpu_has_dotprod>), lhs_quant_pack_qsi8d32p_f32,
120 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true},
121
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
122 clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod, (cpu_check<cpu_has_sve_vl256, cpu_has_dotprod>), lhs_quant_pack_qsi8d32p_f32,
123 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true},
124
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 {UKERNEL_MATMUL_PACK_VARIANT(
125 clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm, (cpu_check<cpu_has_sve_vl256, cpu_has_i8mm>), lhs_quant_pack_qsi8d32p_f32,
126 3 rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true}}};
127 // clang-format on
128
129 class MatMulTest_f32_qsi8d32p_qsi4c32p
130 : public ::testing::TestWithParam<std::tuple<size_t, MatMulShape, MatrixPortion, float>> {};
131
132 // Ensure non-clamping tests are marked correctly.
133
9/18
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
12 TEST(KernelClampingCheck, SanityCheck) {
134
3/4
✓ Branch 0 taken 22 times.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
24 for (size_t i = 0; i < variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.size(); i++) {
135
3/14
✓ Branch 0 taken 22 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 22 times.
22 ASSERT_EQ(variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(i).clamp_support, !(i < num_non_clamping_kernels));
136 22 }
137 2 }
138
139
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
4418 TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, Offset_RHS) {
140 8456 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
141 3640 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index).variant;
142
143
3/4
✓ Branch 0 taken 1820 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1204 times.
✓ Branch 3 taken 616 times.
1820 if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
144
3/6
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
616 GTEST_SKIP() << "Unsupported CPU feature";
145 }
146
147 1204 const size_t bl = 32;
148 2408 const size_t M = matmul_shape.m;
149 2408 const size_t N = matmul_shape.n;
150 2408 const size_t K = matmul_shape.k;
151
152 1204 const auto nr = ukernel_variant.ukernel.interface.get_nr();
153 1204 const auto kr = ukernel_variant.ukernel.interface.get_kr();
154
155 1204 auto n_step = ukernel_variant.ukernel.interface.get_n_step();
156 1204 auto m_step = ukernel_variant.ukernel.interface.get_m_step();
157
158 2408 const auto rect = portion.compute_portion(M, N, m_step, n_step);
159
2/4
✓ Branch 0 taken 1204 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1204 times.
1204 if (rect.height() == 0 || rect.width() == 0) {
160 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
161 }
162
163 1204 const auto rhs_start_row = rect.start_col();
164 1204 auto rhs_packed_offset = ukernel_variant.rhs_pack_interface.get_packed_offset(rhs_start_row, K, nr, kr, bl);
165 1204 auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
166
3/14
✓ Branch 0 taken 1204 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1204 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1204 times.
1204 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
167 1820 }
168
169
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
4418 TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, Offset_LHS) {
170 8456 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
171 3640 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index).variant;
172
173
3/4
✓ Branch 0 taken 1820 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1204 times.
✓ Branch 3 taken 616 times.
1820 if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
174
3/6
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
616 GTEST_SKIP() << "Unsupported CPU feature";
175 }
176
177 1204 const size_t bl = 32;
178 2408 const size_t M = matmul_shape.m;
179 2408 const size_t N = matmul_shape.n;
180 2408 const size_t K = matmul_shape.k;
181
182 1204 const auto mr = ukernel_variant.ukernel.interface.get_mr();
183 1204 const auto kr = ukernel_variant.ukernel.interface.get_kr();
184 1204 const auto sr = ukernel_variant.ukernel.interface.get_sr();
185
186 1204 auto m_step = ukernel_variant.ukernel.interface.get_m_step();
187 1204 auto n_step = ukernel_variant.ukernel.interface.get_n_step();
188
189 2408 const auto rect = portion.compute_portion(M, N, m_step, n_step);
190
2/4
✓ Branch 0 taken 1204 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1204 times.
1204 if (rect.height() == 0 || rect.width() == 0) {
191 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
192 }
193
194 1204 const auto lhs_start_row = rect.start_row();
195 1204 auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr);
196 1204 auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl);
197
198
3/14
✓ Branch 0 taken 1204 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1204 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1204 times.
1204 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
199 1820 }
200
201
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
4418 TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, EndToEnd) {
202 9376 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
203 3640 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index).variant;
204
205
3/4
✓ Branch 0 taken 1820 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1204 times.
✓ Branch 3 taken 616 times.
1820 if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
206
3/6
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
616 GTEST_SKIP() << "Unsupported CPU feature";
207 }
208
209 // NOTE: Workaround - some kernels despite being called matmul_clamp do not support clamping.
210 2408 const bool clamp_support = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index).clamp_support;
211 1204 const std::uint32_t seed = 0;
212
213 2408 const size_t M = matmul_shape.m;
214 2408 const size_t N = matmul_shape.n;
215 2408 const size_t K = matmul_shape.k;
216 1204 const size_t bl = 32;
217
218 1204 const auto mr = ukernel_variant.ukernel.interface.get_mr();
219 1204 const auto nr = ukernel_variant.ukernel.interface.get_nr();
220 1204 const auto kr = ukernel_variant.ukernel.interface.get_kr();
221 1204 const auto sr = ukernel_variant.ukernel.interface.get_sr();
222
223 // Skip tests on clamping when kernel does not support it.
224 KAI_ASSERT_ALWAYS_IF_MSG(clamp_keep_ratio != 1.0F, clamp_support, "Clamping not supported by this kernel");
225
226
4/4
✓ Branch 0 taken 464 times.
✓ Branch 1 taken 740 times.
✓ Branch 2 taken 180 times.
✓ Branch 3 taken 284 times.
1204 if (mr == 1 && M > 1) {
227
3/6
✓ Branch 0 taken 284 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 284 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 284 times.
✗ Branch 5 not taken.
284 GTEST_SKIP() << "Kernel does not support M != 1";
228 }
229
230 920 auto m_step = ukernel_variant.ukernel.interface.get_m_step();
231
3/14
✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 920 times.
920 ASSERT_TRUE(m_step % mr == 0);
232
233 920 auto n_step = ukernel_variant.ukernel.interface.get_n_step();
234
3/14
✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 920 times.
920 ASSERT_TRUE(n_step % nr == 0);
235
236 1840 const auto rect = portion.compute_portion(M, N, m_step, n_step);
237
2/4
✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 920 times.
920 if (rect.height() == 0 || rect.width() == 0) {
238 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
239 }
240 // Generates input data.
241 920 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
242
1/2
✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
920 const auto ref_rhs = fill_random<float>(N * K, seed + 1);
243
244 // Runs the reference implementation.
245
1/2
✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
920 QuantizationInfo lhs_qinfo{};
246 lhs_qinfo.quant_width = bl;
247 lhs_qinfo.dst_type = DataType::QSI8;
248 lhs_qinfo.scale_type = DataType::FP16;
249 const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo);
250
251 QuantizationInfo rhs_qinfo{};
252 rhs_qinfo.quant_width = bl;
253 rhs_qinfo.dst_type = DataType::QSI4;
254 rhs_qinfo.scale_type = DataType::FP16;
255 const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo);
256
257 const auto ref_dst = matmul_clamp_nt_t<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>(
258 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), nullptr, bl, ref_rhs_quant.data(),
259 rhs_qoutputs.scales.data(), nullptr, bl, nullptr, std::numeric_limits<float>::lowest(),
260 std::numeric_limits<float>::max());
261
262 // Clamp reference output
263 const auto [min, max] = find_clamp_range<float>(ref_dst.data(), M * N, clamp_keep_ratio);
264 const auto out_clamped = clamp<float>(ref_dst.data(), M * N, min, max);
265
266 // Runs the LHS packing micro-kernel.
267 const auto lhs_start_row = rect.start_row();
268 const auto imp_packed_lhs_size = ukernel_variant.lhs_pack_interface.packed_size(M, K, bl, mr, kr, sr);
269 Buffer imp_packed_lhs(imp_packed_lhs_size);
270
271 auto lhs_stride = K * sizeof(float);
272 auto lhs_offset = ukernel_variant.lhs_pack_interface.get_offset(lhs_start_row, lhs_stride);
273 auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr);
274 auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl);
275
276 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
277
278 abi_check(
279 ukernel_variant.lhs_pack_interface.run_pack, rect.height() /* m */, K, bl, mr, kr, sr, 0,
280 reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
281 imp_packed_lhs.data() + lhs_packed_offset);
282
283 // Runs the RHS packing micro-kernel.
284 const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_quant.data(), N * K);
285 const auto ref_rhs_qsu4_scale_f16 =
286 pack_data_scales_interleave_block<UInt4, Float16>(ref_rhs_qsu4.data(), rhs_qoutputs.scales.data(), N, K, bl);
287
288 const auto imp_packed_rhs_size = ukernel_variant.rhs_pack_interface.packed_size(N, K, nr, kr, bl);
289 Buffer imp_packed_rhs(imp_packed_rhs_size);
290 const auto rhs_start_row = rect.start_col();
291 auto rhs_packed_offset = ukernel_variant.rhs_pack_interface.get_packed_offset(rhs_start_row, K, nr, kr, bl);
292 auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
293 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
294
295 const kai_rhs_pack_qs4cxs1s0_param params{.lhs_zero_point = 1, .rhs_zero_point = 8};
296 abi_check(
297 ukernel_variant.rhs_pack_interface.run_pack, 1, N, K, nr, kr, sr, bl,
298 reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_scale_f16.data()), nullptr, imp_packed_rhs.data(), 0, &params);
299
300 const auto dst_stride_row = N * sizeof(float);
301 const auto dst_stride_col = sizeof(float);
302 const auto dst_offset =
303 ukernel_variant.ukernel.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
304
305 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
306 ASSERT_EQ(dst_offset, ref_dst_offset);
307
308 // Runs the GEMM micro-kernel.
309 const auto imp_dst_size = ukernel_variant.ukernel.interface.get_dst_size(M, N);
310 ASSERT_EQ(imp_dst_size, ref_dst.size());
311 Buffer imp_dst(imp_dst_size);
312 abi_check(
313 ukernel_variant.ukernel.interface.run_matmul, rect.height(), rect.width(), K, bl,
314 imp_packed_lhs.data() + lhs_matmul_offset, imp_packed_rhs.data() + rhs_matmul_offset,
315 reinterpret_cast<float*>(imp_dst.data() + dst_offset), dst_stride_row, dst_stride_col, min, max);
316
317 DefaultMismatchHandler handler(0, 0.0001, 0, 0.0001);
318 const auto success = compare(imp_dst.data(), out_clamped.data(), DataType::FP32, M, N, rect, handler);
319
320 ASSERT_TRUE(success);
321 }
322
323 // Test all kernels without clamping
324
29/94
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✓ Branch 12 taken 6 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✓ Branch 14 taken 6 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✓ Branch 16 taken 6 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✓ Branch 18 taken 6 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✓ Branch 20 taken 6 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✓ Branch 22 taken 6 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 3 times.
✓ Branch 24 taken 6 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 3 times.
✓ Branch 26 taken 6 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✓ Branch 28 taken 1056 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✓ Branch 30 taken 2112 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 48 taken 1056 times.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✓ Branch 50 taken 2112 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 2112 times.
✗ Branch 53 not taken.
6348 INSTANTIATE_TEST_SUITE_P(
325 MatMul, MatMulTest_f32_qsi8d32p_qsi4c32p,
326 testing::Combine(
327 testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.size()),
328 testing::Values(
329 MatMulShape{1, 2, 32}, //
330 MatMulShape{1, 40, 32}, //
331 MatMulShape{1, 33, 32}, //
332 MatMulShape{32, 64, 64}, //
333 MatMulShape{16, 32, 64}, //
334 MatMulShape{8, 32, 64}, //
335 MatMulShape{15, 32, 32}, //
336 MatMulShape{77, 99, 64}),
337 testing::Values(
338 MatrixPortion(0, 0, 1, 1), // Full matrix.
339 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
340 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
341 MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle
342 ),
343 testing::ValuesIn(std::initializer_list<float>{1.0F})), // We keep 100% of values - no clamping
344 [](const auto& info) {
345 const auto variant_idx = std::get<0>(info.param);
346 const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_idx).variant.ukernel.name};
347 const auto shape = std::get<MatMulShape>(info.param);
348 const auto portion = std::get<2>(info.param);
349 const auto clamp_keep_ratio = std::get<3>(info.param);
350
351 return test_description(name, shape, portion, true, clamp_keep_ratio);
352 });
353
354 // Test supported matmul kernels with clamping support.
355
29/94
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✓ Branch 12 taken 6 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✓ Branch 14 taken 6 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✓ Branch 16 taken 6 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✓ Branch 18 taken 6 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✓ Branch 20 taken 6 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✓ Branch 22 taken 6 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 3 times.
✓ Branch 24 taken 6 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 3 times.
✓ Branch 26 taken 6 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✓ Branch 28 taken 1260 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✓ Branch 30 taken 2520 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 48 taken 1260 times.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✓ Branch 50 taken 2520 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 2520 times.
✗ Branch 53 not taken.
7572 INSTANTIATE_TEST_SUITE_P(
356 MatMulClamped, MatMulTest_f32_qsi8d32p_qsi4c32p,
357 testing::Combine(
358 testing::Range<size_t>(num_non_clamping_kernels, variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.size()),
359 testing::Values(
360 MatMulShape{1, 2, 32}, //
361 MatMulShape{1, 33, 32}, //
362 MatMulShape{17, 32, 64}, //
363 MatMulShape{32, 64, 64}, //
364 MatMulShape{77, 99, 64}),
365 testing::Values(
366 MatrixPortion(0, 0, 1, 1), // Full matrix.
367 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
368 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
369 MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle
370 ),
371 testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio
372 [](const auto& info) {
373 const auto variant_idx = std::get<0>(info.param);
374 const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_idx).variant.ukernel.name};
375 const auto shape = std::get<MatMulShape>(info.param);
376 const auto portion = std::get<2>(info.param);
377 const auto clamp_keep_ratio = std::get<3>(info.param);
378
379 return test_description(name, shape, portion, true, clamp_keep_ratio);
380 });
381
382 } // namespace kai::test
383