KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 97.1% 101 / 0 / 104
Functions: 100.0% 17 / 1 / 18
Branches: 41.1% 217 / 0 / 528

test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <sstream>
14 #include <string>
15 #include <tuple>
16
17 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h"
18 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h"
19 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h"
20 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.h"
21 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h"
22 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h"
23 #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p_qai4c32p_interface.h"
24 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h"
25 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.h"
26 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.h"
27 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon.h"
28 #include "test/common/abi_checker.hpp"
29 #include "test/common/buffer.hpp"
30 #include "test/common/compare.hpp"
31 #include "test/common/cpu_info.hpp"
32 #include "test/common/int4.hpp"
33 #include "test/common/matmul_test_common.hpp"
34 #include "test/common/matrix_portion.hpp"
35 #include "test/common/memory.hpp"
36 #include "test/common/round.hpp"
37 #include "test/common/seed.hpp"
38 #include "test/common/test_suite.hpp"
39 #include "test/reference/cast.hpp"
40 #include "test/reference/clamp.hpp"
41 #include "test/reference/fill.hpp"
42 #include "test/reference/matmul.hpp"
43 #include "test/reference/pack.hpp"
44 #include "test/reference/quantize.hpp"
45
46 namespace kai::test {
47 // Interface for the LHS and RHS packed size and packing micro-kernels
48 using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon);
49 using kai_get_rhs_packed_size_func_t =
50 decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon);
51 using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon);
52 using kai_get_rhs_packed_offset_func_t =
53 decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon);
54 using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon);
55 using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon);
56 using kai_run_lhs_pack_func_t = decltype(&kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon);
57 using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon);
58
59 // Micro-kernel interface
60 struct kai_qai4c32p_pack_functions {
61 kai_get_rhs_packed_size_func_t packed_size;
62 kai_get_rhs_packed_offset_func_t get_packed_offset;
63 kai_get_rhs_offset_func_t get_offset;
64 kai_run_rhs_pack_func_t run_pack;
65 };
66
67 struct kai_qsi8d32p_pack_functions {
68 kai_get_lhs_packed_size_func_t packed_size;
69 kai_get_lhs_packed_offset_func_t get_packed_offset;
70 kai_get_lhs_offset_func_t get_offset;
71 kai_run_lhs_pack_func_t run_pack;
72 };
73
74 static const std::array<
75 UkernelMatmulPackVariant<
76 kai_matmul_clamp_f32_qsi8d32p_qai4c32p_ukernel, kai_qsi8d32p_pack_functions, kai_qai4c32p_pack_functions>,
77 8>
78
0/4
✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
3 variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p = {
79
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
10 {UKERNEL_MATMUL_PACK_VARIANT(
80 clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod, cpu_has_dotprod,
81 lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true),
82
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
83 clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32pscalef32_f32_neon,
84 rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true),
85
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
86 clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod, cpu_has_dotprod,
87 lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true),
88
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
89 clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod,
90 lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true),
91
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
92 clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qsi8d32pscalef32_f32_neon,
93 rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon, false),
94
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
95 clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2,
96 lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon,
97 false),
98
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
99 clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qsi8d32pscalef32_f32_neon,
100 rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon, true),
101
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
102 clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2,
103 lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon,
104 true)}};
105
106 // Executes the LHS packing micro-kernel.
107 50148 static inline std::tuple<Buffer, size_t> pack_lhs_qsi8d32p(
108 const kai_qsi8d32p_pack_functions& pack_interface, size_t M, size_t K, size_t bl, size_t mr, size_t kr, size_t sr,
109 const Buffer& lhs_values_qsi8, size_t stride, size_t rect_start_row, size_t rect_height) {
110 50148 const auto imp_packed_lhs_size = pack_interface.packed_size(M, K, bl, mr, kr, sr);
111 50148 Buffer imp_packed_lhs(imp_packed_lhs_size, 0);
112
113
1/2
✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
50148 auto lhs_offset = pack_interface.get_offset(rect_start_row, stride);
114
1/2
✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
50148 auto lhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, bl, mr, kr, sr);
115
116
1/2
✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
50148 abi_check(
117 50148 pack_interface.run_pack, rect_height, K, bl, mr, kr, sr, 0,
118
1/2
✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
50148 reinterpret_cast<const float*>(lhs_values_qsi8.data() + lhs_offset), stride,
119
1/2
✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
50148 imp_packed_lhs.data() + lhs_packed_offset);
120
121 50148 return {std::move(imp_packed_lhs), lhs_packed_offset};
122 50148 }
123
124 // Executes the RHS packing micro-kernel.
125 12852 static inline std::tuple<Buffer, size_t> pack_rhs_qai4c32p(
126 const kai_qai4c32p_pack_functions& pack_interface, size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr,
127 const Buffer& rhs_values_qai4, const bool has_bias, const Buffer& biases, const Buffer& rhs_scales,
128 const Buffer& rhs_zp, bool s0s1_input, size_t rect_start_row) {
129 // Cast to unsigned int
130 12852 auto rhs_qau4s1s0 = cast_qsu4_qsi4(rhs_values_qai4.data(), N * K);
131
132
1/2
✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
12852 const auto imp_packed_rhs_size = pack_interface.packed_size(N, K, nr, kr, bl);
133
1/2
✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
12852 Buffer imp_packed_rhs(imp_packed_rhs_size);
134
1/2
✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
12852 auto rhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, nr, kr, bl);
135
136 // Runs the RHS packing micro-kernel.
137 12852 kai_rhs_pack_nxk_qai4c32p_params params{};
138 12852 params.lhs_zero_point = 1;
139 12852 params.rhs_zero_point = 8;
140
141
5/10
✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2142 times.
✓ Branch 3 taken 10710 times.
✓ Branch 4 taken 2142 times.
✓ Branch 5 taken 10710 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
12852 abi_check(
142 12852 pack_interface.run_pack, 1, N, K, nr, kr, sr, bl,
143
3/4
✓ Branch 0 taken 10710 times.
✓ Branch 1 taken 2142 times.
✓ Branch 2 taken 10710 times.
✗ Branch 3 not taken.
12852 reinterpret_cast<const uint8_t*>(s0s1_input ? convert_s0s1_s1s0(rhs_qau4s1s0).data() : rhs_qau4s1s0.data()),
144
2/2
✓ Branch 0 taken 6426 times.
✓ Branch 1 taken 6426 times.
12852 rhs_zp.data(), has_bias ? biases.data() : nullptr, rhs_scales.data(), imp_packed_rhs.data(), 0, &params);
145
146 12852 return {std::move(imp_packed_rhs), rhs_packed_offset};
147 12852 }
148
149 using MatMulTestClampPortionedParamsWithBias_WithBL =
150 std::tuple<size_t, MatMulShape, size_t, MatrixPortion, float, bool>;
151 class MatMulTest_f32_qsi8d32p_qai4c32p
152 : public ::testing::TestWithParam<MatMulTestClampPortionedParamsWithBias_WithBL> {};
153
154
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
84006 TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, LhsPackedWithSameBlockdepth) {
155 // Verify LHS quant and pack int8 kernel behaves same for int4 and int8 matmul kernels,
156 // when the block-depth is same for different values of kr, sr.
157
158 20880384 const auto& [variant_index, matmul_shape, bl, portion, clamp_keep_ratio, has_bias] = GetParam();
159 67200 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index);
160
161
3/4
✓ Branch 0 taken 33600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 25200 times.
✓ Branch 3 taken 8400 times.
33600 if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
162
3/6
✓ Branch 0 taken 8400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8400 times.
✗ Branch 5 not taken.
8400 GTEST_SKIP() << "Unsupported CPU feature";
163 }
164
165 50400 const size_t M = matmul_shape.m;
166 50400 const size_t N = matmul_shape.n;
167 50400 const size_t K = matmul_shape.k;
168
169
4/4
✓ Branch 0 taken 6552 times.
✓ Branch 1 taken 18648 times.
✓ Branch 2 taken 6552 times.
✓ Branch 3 taken 18648 times.
50400 if (K % bl != 0) {
170
3/6
✓ Branch 0 taken 6552 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6552 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6552 times.
✗ Branch 5 not taken.
6552 GTEST_SKIP() << "K must be a multiple of bl";
171 }
172
173 18648 const auto mr = ukernel_variant.ukernel.interface.get_mr();
174 18648 const auto nr = ukernel_variant.ukernel.interface.get_nr();
175 18648 const auto kr = ukernel_variant.ukernel.interface.get_kr();
176 18648 const auto sr = ukernel_variant.ukernel.interface.get_sr();
177
178 18648 auto m_step = ukernel_variant.ukernel.interface.get_m_step();
179
3/14
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 18648 times.
18648 ASSERT_TRUE(m_step % mr == 0);
180
181 18648 auto n_step = ukernel_variant.ukernel.interface.get_n_step();
182
3/14
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 18648 times.
18648 ASSERT_TRUE(n_step % nr == 0);
183
184 37296 const auto rect = portion.compute_portion(M, N, m_step, n_step);
185
186 // Generates input data.
187
3/6
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18648 times.
✗ Branch 5 not taken.
18648 const auto ref_lhs = fill_random<float>(M * K, seed_stream(current_test_key())());
188
189 // Runs the LHS packing micro-kernel.
190
1/2
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
18648 const auto lhs_start_row = rect.start_row();
191 18648 auto lhs_stride = K * sizeof(float);
192
193
1/2
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
55944 auto [imp_packed_lhs, lhs_packed_offset] = pack_lhs_qsi8d32p(
194
2/4
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
37296 ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr, sr, ref_lhs, lhs_stride, lhs_start_row, rect.height());
195
196 18648 const size_t kr_qsi8 = kr / sr;
197 18648 const size_t sr_qsi8 = 1;
198
199
1/2
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
37296 auto [imp_packed_lhs_qsi8, lhs_qsi8_packed_offset] = pack_lhs_qsi8d32p(
200 37296 ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr_qsi8, sr_qsi8, ref_lhs, lhs_stride, lhs_start_row,
201
1/2
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
18648 rect.height());
202
203
5/18
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18648 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 18648 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 18648 times.
37296 ASSERT_EQ(lhs_qsi8_packed_offset, lhs_packed_offset);
204
205
2/4
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
37296 auto* imp_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs.data());
206
2/4
✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
37296 auto* imp_packed_lhs_qsi8_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs_qsi8.data());
207
5/8
✗ Branch 0 not taken.
✓ Branch 1 taken 20656440 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 20656440 times.
✓ Branch 4 taken 18648 times.
✓ Branch 5 taken 20637792 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18648 times.
20656440 for (size_t i = 0; i < ukernel_variant.lhs_pack_interface.packed_size(M, K, bl, mr, kr, sr); i++) {
208
4/16
✓ Branch 0 taken 20637792 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20637792 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20637792 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 20637792 times.
20637792 ASSERT_EQ(imp_packed_lhs_ptr[i], imp_packed_lhs_qsi8_ptr[i]);
209 20637792 }
210 33600 }
211
212
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
84006 TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, EndToEnd) {
213 193704 const auto& [variant_index, matmul_shape, bl, portion, clamp_keep_ratio, has_bias] = GetParam();
214 67200 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index);
215
216
3/4
✓ Branch 0 taken 33600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 25200 times.
✓ Branch 3 taken 8400 times.
33600 if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
217
3/6
✓ Branch 0 taken 8400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8400 times.
✗ Branch 5 not taken.
8400 GTEST_SKIP() << "Unsupported CPU feature";
218 }
219
220 50400 const size_t M = matmul_shape.m;
221 50400 const size_t N = matmul_shape.n;
222 50400 const size_t K = matmul_shape.k;
223
224
4/4
✓ Branch 0 taken 6552 times.
✓ Branch 1 taken 18648 times.
✓ Branch 2 taken 6552 times.
✓ Branch 3 taken 18648 times.
50400 if (K % bl != 0) {
225
3/6
✓ Branch 0 taken 6552 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6552 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6552 times.
✗ Branch 5 not taken.
6552 GTEST_SKIP() << "K must be a multiple of bl";
226 }
227
228 18648 const auto mr = ukernel_variant.ukernel.interface.get_mr();
229 18648 const auto nr = ukernel_variant.ukernel.interface.get_nr();
230 18648 const auto kr = ukernel_variant.ukernel.interface.get_kr();
231 18648 const auto sr = ukernel_variant.ukernel.interface.get_sr();
232
233
4/4
✓ Branch 0 taken 9324 times.
✓ Branch 1 taken 9324 times.
✓ Branch 2 taken 3528 times.
✓ Branch 3 taken 5796 times.
18648 if (mr == 1 && M > 1) {
234
3/6
✓ Branch 0 taken 5796 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5796 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5796 times.
✗ Branch 5 not taken.
5796 GTEST_SKIP() << "Kernel does not support M != 1";
235 }
236
237 12852 auto m_step = ukernel_variant.ukernel.interface.get_m_step();
238
3/14
✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12852 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 12852 times.
12852 ASSERT_TRUE(m_step % mr == 0);
239
240 12852 auto n_step = ukernel_variant.ukernel.interface.get_n_step();
241
3/14
✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12852 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 12852 times.
12852 ASSERT_TRUE(n_step % nr == 0);
242
243 25704 const auto rect = portion.compute_portion(M, N, m_step, n_step);
244
2/4
✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 12852 times.
12852 if (rect.height() == 0 || rect.width() == 0) {
245 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
246 }
247
248 // Seed the random generator.
249
1/2
✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
12852 auto& feed = seed_stream(current_test_key());
250
251 // Generates input data.
252 12852 const auto ref_lhs = fill_random<float>(M * K, feed());
253
2/4
✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12852 times.
✗ Branch 3 not taken.
12852 const auto ref_rhs = fill_random<float>(N * K, feed());
254 12852 Buffer ref_biases;
255
256
2/2
✓ Branch 0 taken 6426 times.
✓ Branch 1 taken 6426 times.
12852 if (has_bias) {
257
2/4
✓ Branch 0 taken 6426 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6426 times.
✗ Branch 3 not taken.
6426 ref_biases = fill_random<float>(N, feed());
258 6426 }
259 // Runs the reference implementation.
260 // * Quantizes the LHS matrix using 8-bit symmetric quantization.
261 // * Quantizes the RHS matrix using 4-bit asymmetric quantization.
262 // * Performs GEMM.
263
1/2
✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
12852 QuantizationInfo lhs_qinfo{};
264 lhs_qinfo.quant_width = bl;
265 lhs_qinfo.dst_type = DataType::QSI8;
266 lhs_qinfo.scale_type = DataType::FP32;
267 const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo);
268
269 QuantizationInfo rhs_qinfo{};
270 rhs_qinfo.quant_width = bl;
271 rhs_qinfo.dst_type = DataType::QAI4;
272 rhs_qinfo.scale_type = DataType::FP32;
273 rhs_qinfo.zero_point_type = DataType::I32;
274 const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo);
275
276 const auto ref_dst_no_clamp =
277 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>(
278 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), nullptr, 1, bl, ref_rhs_quant.data(),
279 rhs_qoutputs.scales.data(), rhs_qoutputs.zero_points.data(), 1, bl, has_bias ? ref_biases.data() : nullptr,
280 nullptr, nullptr, 1);
281
282 // Clamps the reference output.
283 const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), M * N, clamp_keep_ratio);
284 const auto ref_dst = clamp<float>(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max);
285
286 // Runs the LHS packing micro-kernel.
287 const auto lhs_start_row = rect.start_row();
288 auto [imp_packed_lhs, lhs_packed_offset] = pack_lhs_qsi8d32p(
289 ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr, sr, ref_lhs, K * sizeof(float), lhs_start_row,
290 rect.height());
291 auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl);
292
293 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
294
295 // Prepare the offsets as the RHS packing micro-kernel expects the scaled zero-points in float.
296 const size_t num_blocks_per_row = round_up_division(K, bl);
297 const size_t ref_zp_size = N * num_blocks_per_row;
298 const size_t ref_zp_size_in_bytes = ref_zp_size * sizeof(float);
299 Buffer ref_rhs_zp_f32(ref_zp_size_in_bytes);
300 for (size_t i = 0; i < ref_zp_size; ++i) {
301 reinterpret_cast<float*>(ref_rhs_zp_f32.data())[i] =
302 -reinterpret_cast<const int32_t*>(rhs_qoutputs.zero_points.data())[i] *
303 reinterpret_cast<const float*>(rhs_qoutputs.scales.data())[i];
304 }
305
306 const auto rhs_start_row = rect.start_col();
307 auto [imp_packed_rhs, rhs_packed_offset] = pack_rhs_qai4c32p(
308 ukernel_variant.rhs_pack_interface, N, K, bl, nr, kr, sr, ref_rhs_quant, has_bias, ref_biases,
309 rhs_qoutputs.scales, ref_rhs_zp_f32, ukernel_variant.rhs_s0s1_input, rhs_start_row);
310
311 auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
312 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
313
314 const auto dst_stride_row = N * sizeof(float);
315 const auto dst_stride_col = sizeof(float);
316 const auto dst_offset =
317 ukernel_variant.ukernel.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
318 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
319 ASSERT_EQ(dst_offset, ref_dst_offset);
320
321 // Runs the GEMM micro-kernel.
322 const auto imp_dst_size = ukernel_variant.ukernel.interface.get_dst_size(M, N);
323 ASSERT_EQ(imp_dst_size, ref_dst.size());
324 Buffer imp_dst(imp_dst_size);
325 abi_check(
326 ukernel_variant.ukernel.interface.run_matmul, rect.height(), rect.width(), K, bl,
327 imp_packed_lhs.data() + lhs_matmul_offset, imp_packed_rhs.data() + rhs_matmul_offset,
328 reinterpret_cast<float*>(imp_dst.data() + dst_offset), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
329
330 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
331 // tested.
332 DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
333 DataFormat dst_format = DataFormat(DataType::FP32);
334 const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
335 ASSERT_TRUE(success);
336 }
337
77/202
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✓ Branch 18 taken 4 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✓ Branch 20 taken 4 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✓ Branch 24 taken 4 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 4 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✓ Branch 28 taken 4 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
✓ Branch 30 taken 4 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 2 times.
✓ Branch 32 taken 4 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✓ Branch 34 taken 4 times.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 2 times.
✓ Branch 36 taken 4 times.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✓ Branch 38 taken 4 times.
✓ Branch 38 taken 33600 times.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✓ Branch 40 taken 67200 times.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✗ Branch 59 not taken.
✗ Branch 60 not taken.
✗ Branch 60 not taken.
✗ Branch 61 not taken.
✗ Branch 61 not taken.
✗ Branch 62 not taken.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✗ Branch 63 not taken.
✗ Branch 64 not taken.
✗ Branch 64 not taken.
✗ Branch 65 not taken.
✗ Branch 65 not taken.
✗ Branch 66 not taken.
✗ Branch 66 not taken.
✗ Branch 67 not taken.
✗ Branch 67 not taken.
✗ Branch 68 not taken.
✓ Branch 68 taken 33600 times.
✗ Branch 69 not taken.
✗ Branch 69 not taken.
✓ Branch 70 taken 33600 times.
✓ Branch 70 taken 67200 times.
✗ Branch 71 not taken.
✗ Branch 71 not taken.
✓ Branch 72 taken 33600 times.
✓ Branch 72 taken 67200 times.
✗ Branch 73 not taken.
✗ Branch 73 not taken.
✓ Branch 74 taken 33600 times.
✓ Branch 74 taken 67200 times.
✗ Branch 75 not taken.
✗ Branch 75 not taken.
✓ Branch 76 taken 33600 times.
✓ Branch 76 taken 67200 times.
✗ Branch 77 not taken.
✗ Branch 77 not taken.
✓ Branch 78 taken 33600 times.
✓ Branch 78 taken 67200 times.
✗ Branch 79 not taken.
✗ Branch 79 not taken.
✓ Branch 80 taken 33600 times.
✓ Branch 80 taken 67200 times.
✗ Branch 81 not taken.
✗ Branch 81 not taken.
✓ Branch 82 taken 33600 times.
✓ Branch 82 taken 67200 times.
✗ Branch 83 not taken.
✗ Branch 83 not taken.
✓ Branch 84 taken 33600 times.
✓ Branch 84 taken 67200 times.
✗ Branch 85 not taken.
✗ Branch 85 not taken.
✓ Branch 86 taken 16800 times.
✓ Branch 86 taken 67200 times.
✓ Branch 87 taken 16800 times.
✗ Branch 87 not taken.
✓ Branch 88 taken 16800 times.
✓ Branch 88 taken 67200 times.
✗ Branch 89 not taken.
✗ Branch 89 not taken.
✓ Branch 90 taken 16800 times.
✓ Branch 90 taken 33600 times.
✗ Branch 91 not taken.
✓ Branch 91 taken 33600 times.
✓ Branch 92 taken 33600 times.
✓ Branch 92 taken 33600 times.
✗ Branch 93 not taken.
✗ Branch 93 not taken.
✓ Branch 94 taken 25200 times.
✓ Branch 94 taken 33600 times.
✓ Branch 95 taken 8400 times.
✗ Branch 95 not taken.
✓ Branch 96 taken 25200 times.
✓ Branch 96 taken 67200 times.
✗ Branch 97 not taken.
✗ Branch 97 not taken.
✓ Branch 98 taken 8400 times.
✓ Branch 98 taken 50400 times.
✗ Branch 99 not taken.
✓ Branch 99 taken 16800 times.
✓ Branch 100 taken 33600 times.
✓ Branch 100 taken 50400 times.
✗ Branch 101 not taken.
✗ Branch 101 not taken.
✓ Branch 102 taken 33600 times.
✓ Branch 102 taken 16800 times.
✗ Branch 103 not taken.
✗ Branch 103 not taken.
✓ Branch 104 taken 67200 times.
✗ Branch 105 not taken.
✓ Branch 106 taken 67200 times.
✗ Branch 107 not taken.
277209 INSTANTIATE_TEST_SUITE_P(
338 MatMul, MatMulTest_f32_qsi8d32p_qai4c32p,
339 testing::Combine(
340 testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.size()),
341 testing::Values(
342 MatMulShape{1, 64, 32}, //
343 MatMulShape{1, 63, 32}, //
344 MatMulShape{1, 65, 32}, //
345 MatMulShape{1, 64, 64}, //
346 MatMulShape{1, 64, 128}, //
347 MatMulShape{1, 128, 32}, //
348 MatMulShape{1, 128, 128}, //
349 MatMulShape{1, 2, 32}, //
350 MatMulShape{1, 3, 32}, //
351 MatMulShape{1, 4, 32}, //
352 MatMulShape{1, 5, 32}, //
353 MatMulShape{3, 3, 32}, //
354 MatMulShape{4, 4, 32}, //
355 MatMulShape{5, 5, 32}, //
356 MatMulShape{32, 128, 32}, //
357 MatMulShape{15, 64, 64}, //
358 MatMulShape{17, 64, 64}, //
359 MatMulShape{16, 63, 64}, //
360 MatMulShape{16, 64, 64}, //
361 MatMulShape{16, 65, 64}, //
362 MatMulShape{32, 64, 64}, //
363 MatMulShape{16, 32, 64}, //
364 MatMulShape{8, 32, 64}, //
365 MatMulShape{15, 32, 32}, //
366 MatMulShape{77, 99, 64}),
367 testing::Values(32, 64),
368 testing::Values(
369 MatrixPortion(0, 0, 1, 1), // Full matrix.
370 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
371 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
372 MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle
373 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
374 MatrixPortion(0.75, 0, 1, 1), // Partial rows
375 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
376 ),
377 testing::ValuesIn(std::initializer_list<float>{1.0f, 0.9f, 0.5f}), //
378 testing::Bool()),
379 [](const auto& info) {
380 const auto variant_idx = std::get<0>(info.param);
381 const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_idx).ukernel.name};
382 const auto shape = std::get<MatMulShape>(info.param);
383 const auto bl = std::get<2>(info.param);
384 const auto portion = std::get<3>(info.param);
385 const auto clamp_keep_ratio = std::get<4>(info.param);
386 const auto has_bias = std::get<5>(info.param);
387
388 std::ostringstream sstream;
389 sstream << name << "__";
390 PrintTo(shape, &sstream);
391 sstream << "__BL_" << bl << "_";
392 sstream << "__clamp_keep_ratio_" << static_cast<int>(clamp_keep_ratio * 100);
393
394 if (has_bias) {
395 sstream << "_withBias_";
396 } else {
397 sstream << "_noBias_";
398 }
399 if (variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_idx).rhs_s0s1_input) {
400 sstream << "_RHS_s0s1__";
401 } else {
402 sstream << "_RHS_s1s0__";
403 }
404 PrintTo(portion, &sstream);
405
406 return sstream.str();
407 });
408
409 } // namespace kai::test
410