KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_clamp_f16_qsi8d32p_qai4c32p_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.7% 148 0 150
Functions: 100.0% 15 0 15
Branches: 43.3% 231 0 534

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <sstream>
14 #include <string>
15 #include <tuple>
16
17 #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h"
18 #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h"
19 #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h"
20 #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.h"
21 #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h"
22 #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h"
23 #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p_qai4c32p_interface.h"
24 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f16_neon.h"
25 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.h"
26 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.h"
27 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon.h"
28 #include "test/common/buffer.hpp"
29 #include "test/common/compare.hpp"
30 #include "test/common/cpu_info.hpp"
31 #include "test/common/data_format.hpp"
32 #include "test/common/float16.hpp"
33 #include "test/common/int4.hpp"
34 #include "test/common/matmul_test_common.hpp"
35 #include "test/common/matrix_portion.hpp"
36 #include "test/common/memory.hpp"
37 #include "test/common/round.hpp"
38 #include "test/common/test_suite.hpp"
39 #include "test/reference/cast.hpp"
40 #include "test/reference/clamp.hpp"
41 #include "test/reference/fill.hpp"
42 #include "test/reference/matmul.hpp"
43 #include "test/reference/pack.hpp"
44 #include "test/reference/quantize.hpp"
45
46 namespace kai::test {
47
48 // Interface for the LHS and RHS packed size and packing micro-kernels
49 using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f16_neon);
50 using kai_get_rhs_packed_size_func_t =
51 decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon);
52 using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f16_neon);
53 using kai_get_rhs_packed_offset_func_t =
54 decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon);
55 using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f16_neon);
56 using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon);
57 using kai_run_lhs_pack_func_f16_t = decltype(&kai_run_lhs_quant_pack_qsi8d32pscalef32_f16_neon);
58 using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon);
59
60 // Micro-kernel interface
61 struct kai_qai4c32p_pack_functions {
62 kai_get_rhs_packed_size_func_t packed_size;
63 kai_get_rhs_packed_offset_func_t get_packed_offset;
64 kai_get_rhs_offset_func_t get_offset;
65 kai_run_rhs_pack_func_t run_pack;
66 };
67
68 struct kai_qsi8d32p_f16_pack_functions {
69 kai_get_lhs_packed_size_func_t packed_size;
70 kai_get_lhs_packed_offset_func_t get_packed_offset;
71 kai_get_lhs_offset_func_t get_offset;
72 kai_run_lhs_pack_func_f16_t run_pack;
73 };
74
75 static const std::array<
76 UkernelMatmulPackVariant<
77 kai_matmul_clamp_f16_qsi8d32p_qai4c32p_ukernel, kai_qsi8d32p_f16_pack_functions, kai_qai4c32p_pack_functions>,
78 8>
79 variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p = {
80 {UKERNEL_MATMUL_PACK_VARIANT(
81 clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod, cpu_has_dotprod,
82 lhs_quant_pack_qsi8d32pscalef32_f16_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true),
83 UKERNEL_MATMUL_PACK_VARIANT(
84 clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32pscalef32_f16_neon,
85 rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true),
86 UKERNEL_MATMUL_PACK_VARIANT(
87 clamp_f16_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod, cpu_has_dotprod,
88 lhs_quant_pack_qsi8d32pscalef32_f16_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true),
89 UKERNEL_MATMUL_PACK_VARIANT(
90 clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod,
91 lhs_quant_pack_qsi8d32pscalef32_f16_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true),
92 UKERNEL_MATMUL_PACK_VARIANT(
93 clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qsi8d32pscalef32_f16_neon,
94 rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon, false),
95 UKERNEL_MATMUL_PACK_VARIANT(
96 clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2,
97 lhs_quant_pack_qsi8d32pscalef32_f16_neon, rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon,
98 false),
99 UKERNEL_MATMUL_PACK_VARIANT(
100 clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qsi8d32pscalef32_f16_neon,
101 rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon, true),
102 UKERNEL_MATMUL_PACK_VARIANT(
103 clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2,
104 lhs_quant_pack_qsi8d32pscalef32_f16_neon, rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon,
105 true)}};
106
107 static const auto test_matmul_shapes = testing::Values(
108 MatMulShape{1, 64, 32}, //
109 MatMulShape{1, 63, 32}, //
110 MatMulShape{1, 65, 32}, //
111 MatMulShape{1, 64, 64}, //
112 MatMulShape{1, 64, 128}, //
113 MatMulShape{1, 128, 32}, //
114 MatMulShape{1, 128, 128}, //
115 MatMulShape{1, 2, 32}, //
116 MatMulShape{1, 3, 32}, //
117 MatMulShape{1, 4, 32}, //
118 MatMulShape{1, 5, 32}, //
119 MatMulShape{3, 3, 32}, //
120 MatMulShape{4, 4, 32}, //
121 MatMulShape{5, 5, 32}, //
122 MatMulShape{32, 128, 32}, //
123 MatMulShape{15, 64, 64}, //
124 MatMulShape{17, 64, 64}, //
125 MatMulShape{16, 63, 64}, //
126 MatMulShape{16, 64, 64}, //
127 MatMulShape{16, 65, 64}, //
128 MatMulShape{32, 64, 64}, //
129 MatMulShape{16, 32, 64}, //
130 MatMulShape{8, 32, 64}, //
131 MatMulShape{15, 32, 32}, //
132 MatMulShape{77, 99, 64} //
133 );
134
135 static const auto test_portions = testing::Values(
136 MatrixPortion(0, 0, 1, 1), // Full matrix.
137 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
138 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
139 MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle
140 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
141 MatrixPortion(0.75, 0, 1, 1), // Partial rows
142 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
143 );
144
145 static const auto test_block_lengths = testing::Values(32, 64);
146
147 // Executes the LHS packing micro-kernel.
148 10904 static inline Buffer pack_lhs_qsi8d32p_f16(
149 const kai_qsi8d32p_f16_pack_functions& pack_interface, size_t M, size_t K, size_t bl, size_t mr, size_t kr,
150 size_t sr, const Buffer& lhs_f16, size_t stride, size_t rect_start_row, size_t rect_height) {
151 10904 const auto imp_packed_lhs_size = pack_interface.packed_size(M, K, bl, mr, kr, sr);
152 10904 Buffer imp_packed_lhs(imp_packed_lhs_size, 0);
153
154
1/2
✓ Branch 0 taken 10904 times.
✗ Branch 1 not taken.
10904 auto lhs_offset = pack_interface.get_offset(rect_start_row, stride);
155
1/2
✓ Branch 0 taken 10904 times.
✗ Branch 1 not taken.
10904 auto lhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, bl, mr, kr, sr);
156
157
2/4
✓ Branch 0 taken 10904 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 10904 times.
✗ Branch 3 not taken.
21808 pack_interface.run_pack(
158 10904 rect_height, K, bl, mr, kr, sr, 0, reinterpret_cast<const uint8_t*>(lhs_f16.data() + lhs_offset), stride,
159 10904 imp_packed_lhs.data() + lhs_packed_offset);
160
161 10904 return (imp_packed_lhs);
162 10904 }
163
164 // Executes the RHS packing micro-kernel.
165 2616 static inline Buffer pack_rhs_qai4c32p(
166 const kai_qai4c32p_pack_functions& pack_interface, size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr,
167 const Buffer& rhs_values_qai4, const bool has_bias, const Buffer& biases, const Buffer& rhs_scales,
168 const Buffer& rhs_zp, bool s0s1_input) {
169 // Cast to unsigned int
170 2616 auto rhs_qau4s1s0 = cast_qsu4_qsi4(rhs_values_qai4.data(), N * K);
171
172
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 const auto imp_packed_rhs_size = pack_interface.packed_size(N, K, nr, kr, bl);
173
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 Buffer imp_packed_rhs(imp_packed_rhs_size);
174
175 // Runs the RHS packing micro-kernel.
176 2616 kai_rhs_pack_nxk_qai4c32p_params params{};
177 2616 params.lhs_zero_point = 1;
178 2616 params.rhs_zero_point = 8;
179
180
5/10
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 654 times.
✓ Branch 3 taken 1962 times.
✓ Branch 4 taken 654 times.
✓ Branch 5 taken 1962 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
5232 pack_interface.run_pack(
181 2616 1, N, K, nr, kr, sr, bl,
182
3/4
✓ Branch 0 taken 1962 times.
✓ Branch 1 taken 654 times.
✓ Branch 2 taken 1962 times.
✗ Branch 3 not taken.
2616 reinterpret_cast<const uint8_t*>(s0s1_input ? convert_s0s1_s1s0(rhs_qau4s1s0).data() : rhs_qau4s1s0.data()),
183
2/2
✓ Branch 0 taken 1308 times.
✓ Branch 1 taken 1308 times.
2616 rhs_zp.data(), has_bias ? biases.data() : nullptr, rhs_scales.data(), imp_packed_rhs.data(), 0, &params);
184
185 2616 return (imp_packed_rhs);
186 2616 }
187
188 class MatMulTest_f16_qsi8d32p_qai4c32p : public ::testing::TestWithParam<MatMulTestPortionedParamsWithBias_WithBL> {};
189
190
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
16802 TEST_P(MatMulTest_f16_qsi8d32p_qai4c32p, LhsPackedWithSameBlockdepth) {
191 // Verify LHS quant and pack int8 kernel behaves same for int4 and int8 matmul kernels,
192 // when the block-depth is same for different values of kr, sr.
193
194 4801776 const auto& [variant_index, matmul_shape, bl, portion, has_bias] = GetParam();
195 11200 const auto& ukernel_variant = variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p.at(variant_index);
196
197
2/4
✓ Branch 0 taken 5600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5600 times.
✗ Branch 3 not taken.
5600 if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
198 GTEST_SKIP() << "Unsupported CPU feature";
199 }
200
201 5600 const std::uint32_t seed = 0;
202
203 11200 const size_t M = matmul_shape.m;
204 11200 const size_t N = matmul_shape.n;
205 11200 const size_t K = matmul_shape.k;
206
207
4/4
✓ Branch 0 taken 1456 times.
✓ Branch 1 taken 4144 times.
✓ Branch 2 taken 1456 times.
✓ Branch 3 taken 4144 times.
11200 if (K % bl != 0) {
208
3/6
✓ Branch 0 taken 1456 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1456 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1456 times.
✗ Branch 5 not taken.
1456 GTEST_SKIP() << "K must be a multiple of bl";
209 }
210
211 4144 const auto mr = ukernel_variant.ukernel.interface.get_mr();
212 4144 const auto nr = ukernel_variant.ukernel.interface.get_nr();
213 4144 const auto kr = ukernel_variant.ukernel.interface.get_kr();
214 4144 const auto sr = ukernel_variant.ukernel.interface.get_sr();
215
216 4144 auto m_step = ukernel_variant.ukernel.interface.get_m_step();
217
3/14
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 4144 times.
4144 ASSERT_TRUE(m_step % mr == 0);
218
219 4144 auto n_step = ukernel_variant.ukernel.interface.get_n_step();
220
3/14
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 4144 times.
4144 ASSERT_TRUE(n_step % nr == 0);
221
222 8288 const auto rect = portion.compute_portion(M, N, m_step, n_step);
223
224 // Generates input data.
225 4144 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
226
227 // Runs the reference implementation.
228 // * Quantizes the LHS matrix using 8-bit symmetric quantization.
229 8288 const auto [ref_lhs_qvalues, ref_lhs_scales] =
230
3/6
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4144 times.
✗ Branch 5 not taken.
4144 quantize_symmetric_per_block_dynamic<float, int8_t, float>(ref_lhs.data(), M, K, bl);
231
232 // Runs the LHS packing micro-kernel.
233
1/2
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
4144 const auto lhs_start_row = rect.start_row();
234 4144 auto lhs_stride = K * sizeof(uint16_t);
235
236
1/2
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
4144 auto imp_packed_lhs = pack_lhs_qsi8d32p_f16(
237
2/4
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
8288 ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr, sr, ref_lhs, lhs_stride, lhs_start_row, rect.height());
238
2/4
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
8288 auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr);
239
240 4144 const size_t kr_qsi8 = kr / sr;
241 4144 const size_t sr_qsi8 = 1;
242
243
1/2
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
4144 auto imp_packed_lhs_qsi8 = pack_lhs_qsi8d32p_f16(
244 8288 ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr_qsi8, sr_qsi8, ref_lhs, lhs_stride, lhs_start_row,
245
1/2
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
4144 rect.height());
246 4144 auto lhs_qsi8_packed_offset =
247
2/4
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
8288 ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr_qsi8, sr_qsi8);
248
249
4/16
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4144 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 4144 times.
4144 ASSERT_EQ(lhs_qsi8_packed_offset, lhs_packed_offset);
250
251
1/2
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
4144 auto* imp_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs.data());
252
1/2
✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
4144 auto* imp_packed_lhs_qsi8_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs_qsi8.data());
253
5/8
✗ Branch 0 not taken.
✓ Branch 1 taken 4751600 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 4751600 times.
✓ Branch 4 taken 4144 times.
✓ Branch 5 taken 4747456 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4144 times.
4751600 for (size_t i = 0; i < ukernel_variant.lhs_pack_interface.packed_size(M, K, bl, mr, kr, sr); i++) {
254
4/16
✓ Branch 0 taken 4747456 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4747456 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4747456 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 4747456 times.
4747456 ASSERT_EQ(imp_packed_lhs_ptr[i], imp_packed_lhs_qsi8_ptr[i]);
255 4747456 }
256 5600 }
257
258
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
16802 TEST_P(MatMulTest_f16_qsi8d32p_qai4c32p, EndToEnd) {
259 67848 const auto& [variant_index, matmul_shape, bl, portion, has_bias] = GetParam();
260 11200 const auto& ukernel_variant = variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p.at(variant_index);
261
262
2/4
✓ Branch 0 taken 5600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5600 times.
✗ Branch 3 not taken.
5600 if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) {
263 GTEST_SKIP() << "Unsupported CPU feature";
264 }
265
266 5600 const std::uint32_t seed = 0;
267
268 11200 const size_t M = matmul_shape.m;
269 11200 const size_t N = matmul_shape.n;
270 11200 const size_t K = matmul_shape.k;
271
272
4/4
✓ Branch 0 taken 1456 times.
✓ Branch 1 taken 4144 times.
✓ Branch 2 taken 1456 times.
✓ Branch 3 taken 4144 times.
11200 if (K % bl != 0) {
273
3/6
✓ Branch 0 taken 1456 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1456 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1456 times.
✗ Branch 5 not taken.
1456 GTEST_SKIP() << "K must be a multiple of bl";
274 }
275
276 4144 const auto mr = ukernel_variant.ukernel.interface.get_mr();
277 4144 const auto nr = ukernel_variant.ukernel.interface.get_nr();
278 4144 const auto kr = ukernel_variant.ukernel.interface.get_kr();
279 4144 const auto sr = ukernel_variant.ukernel.interface.get_sr();
280
281
4/4
✓ Branch 0 taken 2072 times.
✓ Branch 1 taken 2072 times.
✓ Branch 2 taken 784 times.
✓ Branch 3 taken 1288 times.
4144 if (mr == 1 && M > 1) {
282
3/6
✓ Branch 0 taken 1288 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1288 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1288 times.
✗ Branch 5 not taken.
1288 GTEST_SKIP() << "Kernel does not support M != 1";
283 }
284
285 2856 auto m_step = ukernel_variant.ukernel.interface.get_m_step();
286
3/14
✓ Branch 0 taken 2856 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2856 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2856 times.
2856 ASSERT_TRUE(m_step % mr == 0);
287
288 2856 auto n_step = ukernel_variant.ukernel.interface.get_n_step();
289
3/14
✓ Branch 0 taken 2856 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2856 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2856 times.
2856 ASSERT_TRUE(n_step % nr == 0);
290
291 5712 const auto rect = portion.compute_portion(M, N, m_step, n_step);
292
4/4
✓ Branch 0 taken 2632 times.
✓ Branch 1 taken 224 times.
✓ Branch 2 taken 16 times.
✓ Branch 3 taken 2616 times.
2856 if (rect.height() == 0 || rect.width() == 0) {
293
9/18
✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 240 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 240 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 240 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 240 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 240 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 240 times.
✗ Branch 17 not taken.
240 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
294 }
295
296 // Generates input data.
297 2616 const auto ref_lhs_f16 = fill_random<Float16>(M * K, seed + 0);
298
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 const auto ref_rhs = fill_random<float>(N * K, seed + 1);
299 2616 Buffer ref_biases;
300
301
2/2
✓ Branch 0 taken 1308 times.
✓ Branch 1 taken 1308 times.
2616 if (has_bias) {
302
1/2
✓ Branch 0 taken 1308 times.
✗ Branch 1 not taken.
1308 ref_biases = fill_random<float>(N, seed + 2);
303 1308 }
304 // For reference implementation, Casting FP16 input to FP32 type and FP32 output back to FP16 because the matmul
305 // implementation works with FP32 accumulation and casts the result to FP16
306
3/6
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
2616 const auto ref_lhs = cast<float, Float16>(ref_lhs_f16.data(), ref_lhs_f16.size() * 8 / size_in_bits<Float16>);
307
308 // Runs the reference implementation.
309 // * Quantizes the LHS matrix using 8-bit symmetric quantization.
310 // * Quantizes the RHS matrix using 8-bit asymmetric quantization.
311 // * Performs GEMM.
312 7848 const auto [ref_lhs_qvalues, ref_lhs_scales] =
313
3/6
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
2616 quantize_symmetric_per_block_dynamic<float, int8_t, float>(ref_lhs.data(), M, K, bl);
314 275664 const auto [ref_rhs_qai4, ref_rhs_scales, ref_rhs_zero_points] =
315
3/6
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
2616 quantize_asymmetric_per_block_dynamic<float, Int4, float, int32_t>(ref_rhs.data(), N, K, bl);
316
317 2616 const auto ref_dst_no_clamp =
318
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>(
319
5/10
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2616 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2616 times.
✗ Branch 9 not taken.
5232 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, 1, bl, ref_rhs_qai4.data(),
320
7/10
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1308 times.
✓ Branch 5 taken 1308 times.
✓ Branch 6 taken 1308 times.
✓ Branch 7 taken 1308 times.
✓ Branch 8 taken 1308 times.
✗ Branch 9 not taken.
2616 ref_rhs_scales.data(), ref_rhs_zero_points.data(), 1, bl, has_bias ? ref_biases.data() : nullptr, nullptr,
321 nullptr, 1);
322
323 // Clamps the reference output.
324 2616 const auto clamp_ratio = 0.8F;
325
2/4
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
7848 const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), M * N, clamp_ratio);
326
4/8
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2616 times.
✗ Branch 7 not taken.
2616 const auto ref_dst_float = clamp<float>(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max);
327
328 // Cast the reference output to F16
329
3/6
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
2616 auto ref_dst = cast<Float16, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>);
330
331 // Runs the LHS packing micro-kernel.
332
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 const auto lhs_start_row = rect.start_row();
333
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 auto imp_packed_lhs = pack_lhs_qsi8d32p_f16(
334 5232 ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr, sr, ref_lhs_f16, K * sizeof(uint16_t), lhs_start_row,
335
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 rect.height());
336
2/4
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
5232 auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr);
337
2/4
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
5232 auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl);
338
339
4/16
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2616 times.
2616 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
340
341 // Prepare the offsets as the RHS packing micro-kernel expects the scaled zero-points in float.
342
2/4
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
5232 const size_t num_blocks_per_row = round_up_division(K, bl);
343 2616 const size_t ref_zp_size = N * num_blocks_per_row;
344 2616 const size_t ref_zp_size_in_bytes = ref_zp_size * sizeof(float);
345
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 Buffer ref_rhs_zp_f32(ref_zp_size_in_bytes);
346
2/2
✓ Branch 0 taken 262584 times.
✓ Branch 1 taken 2616 times.
265200 for (size_t i = 0; i < ref_zp_size; ++i) {
347
1/2
✓ Branch 0 taken 262584 times.
✗ Branch 1 not taken.
262584 reinterpret_cast<float*>(ref_rhs_zp_f32.data())[i] =
348
1/2
✓ Branch 0 taken 262584 times.
✗ Branch 1 not taken.
262584 -reinterpret_cast<const int32_t*>(ref_rhs_zero_points.data())[i] *
349
1/2
✓ Branch 0 taken 262584 times.
✗ Branch 1 not taken.
262584 reinterpret_cast<const float*>(ref_rhs_scales.data())[i];
350 262584 }
351
352
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 const auto rhs_start_row = rect.start_col();
353
2/4
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
5232 auto imp_packed_rhs = pack_rhs_qai4c32p(
354 7848 ukernel_variant.rhs_pack_interface, N, K, bl, nr, kr, sr, ref_rhs_qai4, has_bias, ref_biases, ref_rhs_scales,
355 2616 ref_rhs_zp_f32, ukernel_variant.rhs_s0s1_input);
356
2/4
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
5232 auto rhs_packed_offset = ukernel_variant.rhs_pack_interface.get_packed_offset(rhs_start_row, K, nr, kr, bl);
357
2/4
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
5232 auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
358
4/16
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2616 times.
2616 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
359
360 2616 const auto dst_stride_row = N * sizeof(uint16_t);
361 2616 const auto dst_stride_col = sizeof(uint16_t);
362 5232 const auto dst_offset =
363
3/6
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
2616 ukernel_variant.ukernel.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
364
2/4
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
2616 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
365
4/16
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2616 times.
2616 ASSERT_EQ(dst_offset, ref_dst_offset);
366
367 // Runs the GEMM micro-kernel.
368
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 const auto imp_dst_size = ukernel_variant.ukernel.interface.get_dst_size(M, N);
369
5/18
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2616 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2616 times.
2616 ASSERT_EQ(imp_dst_size, ref_dst.size());
370
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 Buffer imp_dst(imp_dst_size);
371
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
5232 ukernel_variant.ukernel.interface.run_matmul(
372
4/8
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2616 times.
✗ Branch 7 not taken.
2616 rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset,
373
2/4
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
2616 imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
374 7848 dst_stride_row, dst_stride_col, clamp_min, clamp_max);
375
376 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
377 // tested.
378
1/2
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
2616 DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
379 2616 DataFormat dst_format = DataFormat(DataType::FP16);
380
3/6
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
2616 const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
381
4/16
✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2616 times.
2616 ASSERT_TRUE(success);
382 5600 }
383
384
27/56
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11200 times.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 11200 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 11200 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 11200 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 11200 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 11200 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 11200 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 11200 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 5600 times.
✓ Branch 39 taken 5600 times.
✓ Branch 40 taken 5600 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 5600 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 11200 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 8400 times.
✓ Branch 47 taken 2800 times.
✓ Branch 48 taken 8400 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 2800 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 11200 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 11200 times.
✗ Branch 55 not taken.
30803 INSTANTIATE_TEST_SUITE_P(
385 MatMul, MatMulTest_f16_qsi8d32p_qai4c32p,
386 testing::Combine(
387 testing::Range<size_t>(0, variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p.size()), test_matmul_shapes,
388 test_block_lengths, test_portions, testing::Bool()),
389 [](const auto& info) {
390 const auto variant_idx = std::get<0>(info.param);
391 const std::string name{variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p.at(variant_idx).ukernel.name};
392 const auto shape = std::get<MatMulShape>(info.param);
393 const auto bl = std::get<2>(info.param);
394 const auto portion = std::get<3>(info.param);
395 const auto has_bias = std::get<4>(info.param);
396
397 std::ostringstream sstream;
398 sstream << name << "__";
399 PrintTo(shape, &sstream);
400 sstream << "__BL_" << bl << "_";
401 if (has_bias) {
402 sstream << "_withBias_";
403 } else {
404 sstream << "_noBias_";
405 }
406 if (variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p.at(variant_idx).rhs_s0s1_input) {
407 sstream << "_RHS_s0s1__";
408 } else {
409 sstream << "_RHS_s1s0__";
410 }
411 PrintTo(portion, &sstream);
412
413 return sstream.str();
414 });
415
416 } // namespace kai::test
417