KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 97.1% 101 / 0 / 104
Functions: 100.0% 16 / 0 / 16
Branches: 37.9% 194 / 0 / 512

test/tests/matmul_clamp_f16_qai8dxp_qsi4cxp_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <sstream>
14 #include <string>
15 #include <tuple>
16
17 #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h"
18 #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.h"
19 #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod.h"
20 #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm.h"
21 #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp_qsi4cxp_interface.h"
22 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.h"
23 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"
24 #include "test/common/buffer.hpp"
25 #include "test/common/cache.hpp"
26 #include "test/common/compare.hpp"
27 #include "test/common/cpu_info.hpp"
28 #include "test/common/data_format.hpp"
29 #include "test/common/float16.hpp"
30 #include "test/common/int4.hpp"
31 #include "test/common/matmul_test_common.hpp"
32 #include "test/common/matrix_portion.hpp"
33 #include "test/common/memory.hpp"
34 #include "test/common/round.hpp"
35 #include "test/common/seed.hpp"
36 #include "test/common/test_suite.hpp"
37 #include "test/reference/cast.hpp"
38 #include "test/reference/clamp.hpp"
39 #include "test/reference/fill.hpp"
40 #include "test/reference/matmul.hpp"
41 #include "test/reference/pad.hpp"
42 #include "test/reference/quantize.hpp"
43
44 namespace kai::test {
45
46 using F16Qai8Qsi4CacheDataId = std::tuple<
47 MatMulShape, //
48 DataFormat, // lhs format
49 DataFormat, // rhs format
50 DataFormat, // bias format
51 float // clamp_keep_ratio
52 >;
53
54 struct F16Qai8Qsi4CacheData {
55 Buffer ref_dst;
56 Buffer ref_rhs_qsi4;
57 Buffer ref_rhs_scales;
58 Buffer ref_lhs_f16;
59 Buffer ref_biases;
60 Range<float> clamp;
61 };
62
63 template <>
64 192 F16Qai8Qsi4CacheData ReferenceGenerator<F16Qai8Qsi4CacheDataId, F16Qai8Qsi4CacheData>::generate_reference(
65 const F16Qai8Qsi4CacheDataId& data_id) {
66 2496 auto [shape, lhs_format, rhs_format, bias_format, clamp_keep_ratio] = data_id;
67
68 384 const size_t M = shape.m;
69 384 const size_t N = shape.n;
70 384 const size_t K = shape.k;
71
72 // Seed the random generator.
73
8/16
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 192 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 192 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 192 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 96 times.
✗ Branch 15 not taken.
576 const auto key = std::string("F16Qai8Qsi4_cache:") + std::to_string(M) + "x" + std::to_string(N) + "x" +
74
8/16
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 192 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 192 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 192 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 96 times.
✗ Branch 15 not taken.
576 std::to_string(K) + ":" + std::to_string(static_cast<uint32_t>(lhs_format.data_type())) + ":" +
75
5/10
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 96 times.
✗ Branch 9 not taken.
672 std::to_string(static_cast<uint32_t>(rhs_format.data_type())) + ":" +
76
7/14
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 192 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 192 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 96 times.
✗ Branch 13 not taken.
480 std::to_string(static_cast<uint32_t>(bias_format.data_type())) + ":" + std::to_string(clamp_keep_ratio);
77
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 auto& feed = seed_stream(key);
78
79
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
384 bool has_bias = bias_format.data_type() != DataType::UNKNOWN;
80
81
5/10
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 192 times.
✗ Branch 9 not taken.
768 Buffer lhs = fill_matrix_random(shape.m, shape.k, lhs_format, feed());
82
5/10
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 192 times.
✗ Branch 9 not taken.
768 Buffer rhs = fill_matrix_random(shape.n, shape.k, rhs_format, feed());
83
5/8
✓ Branch 0 taken 96 times.
✓ Branch 1 taken 96 times.
✓ Branch 2 taken 96 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 96 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 96 times.
✗ Branch 7 not taken.
192 Buffer bias = has_bias ? fill_matrix_random(1, shape.n, bias_format, feed()) : Buffer();
84
85
3/6
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
192 const auto ref_lhs = cast<float, Float16>(lhs.data(), lhs.size() * 8 / size_in_bits<Float16>);
86
87
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 QuantizationInfo lhs_qinfo{};
88 lhs_qinfo.quant_width = K;
89 lhs_qinfo.dst_type = DataType::QAI8;
90 lhs_qinfo.scale_type = DataType::FP32;
91 lhs_qinfo.zero_point_type = DataType::I32;
92 const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo);
93
94 QuantizationInfo rhs_qinfo{};
95 rhs_qinfo.quant_width = K;
96 rhs_qinfo.dst_type = DataType::QSI4;
97 rhs_qinfo.scale_type = DataType::FP32;
98 auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(rhs.data(), DataType::FP32, N, K, rhs_qinfo);
99
100 const auto ref_dst_no_clamp =
101 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>(
102 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), 1, K,
103 ref_rhs_quant.data(), rhs_qoutputs.scales.data(), nullptr, 1, K, has_bias ? bias.data() : nullptr, nullptr,
104 nullptr, 1);
105
106 const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), M * N, clamp_keep_ratio);
107 const auto ref_dst_float = clamp<float>(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max);
108 auto ref_dst = cast<Float16, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>);
109
110 F16Qai8Qsi4CacheData out;
111 out.ref_dst = std::move(ref_dst);
112 out.ref_rhs_qsi4 = std::move(ref_rhs_quant);
113 out.ref_rhs_scales = std::move(rhs_qoutputs.scales);
114 out.ref_lhs_f16 = std::move(lhs);
115 out.ref_biases = std::move(bias);
116 out.clamp = {clamp_min, clamp_max};
117
118 return out;
119 }
120
121 3 static const std::array<UkernelVariant<kai_matmul_clamp_f16_qai8dxp_qsi4cxp_ukernel>, 4>
122
0/4
✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
7 variants_kai_matmul_clamp_f16_qai8dxp_qsi4cxp = {{
123
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod),
124
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f16_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod", cpu_has_dotprod_and_fp16},
125
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod),
126
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f16_qai8dxp4x4_qsi4cxp4x4_16x4_neon_dotprod", cpu_has_dotprod_and_fp16},
127
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod),
128
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod", cpu_has_dotprod_and_fp16},
129
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm),
130
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f16_qai8dxp4x8_qsi4cxp4x8_16x4_neon_i8mm", cpu_has_i8mm_and_fp16},
131 }};
132
133 class MatMulTest_f16_qai8dxp_qsi4cxp : public ::testing::TestWithParam<MatMulClampTestPortionedParamsWithBias> {};
134
135
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
13446 TEST_P(MatMulTest_f16_qai8dxp_qsi4cxp, EndToEnd) {
136 41664 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio, has_bias] = GetParam();
137 10752 const auto& ukernel_variant = variants_kai_matmul_clamp_f16_qai8dxp_qsi4cxp.at(variant_index);
138
139
2/4
✓ Branch 0 taken 5376 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5376 times.
✗ Branch 3 not taken.
5376 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
140 GTEST_SKIP() << "Unsupported CPU feature";
141 }
142
143 10752 const size_t M = matmul_shape.m;
144 10752 const size_t N = matmul_shape.n;
145 10752 const size_t K = matmul_shape.k;
146
147 5376 const auto mr = ukernel_variant.interface.get_mr();
148 5376 const auto nr = ukernel_variant.interface.get_nr();
149 5376 const auto kr = ukernel_variant.interface.get_kr();
150 5376 const auto sr = ukernel_variant.interface.get_sr();
151
152
4/4
✓ Branch 0 taken 2688 times.
✓ Branch 1 taken 2688 times.
✓ Branch 2 taken 1008 times.
✓ Branch 3 taken 1680 times.
5376 if (mr == 1 && M > 1) {
153
3/6
✓ Branch 0 taken 1680 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1680 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1680 times.
✗ Branch 5 not taken.
1680 GTEST_SKIP() << "Kernel does not support M != 1";
154 }
155
156 3696 auto m_step = ukernel_variant.interface.get_m_step();
157
3/14
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 3696 times.
3696 ASSERT_TRUE(m_step % mr == 0);
158
159 3696 auto n_step = ukernel_variant.interface.get_n_step();
160
3/14
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 3696 times.
3696 ASSERT_TRUE(n_step % nr == 0);
161
162 7392 const auto rect = portion.compute_portion(M, N, m_step, n_step);
163
2/4
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3696 times.
3696 if (rect.height() == 0 || rect.width() == 0) {
164 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
165 }
166
167 3696 const auto lhs_format = DataFormat(DataType::FP16);
168 3696 const auto rhs_format = DataFormat(DataType::FP32);
169
4/4
✓ Branch 0 taken 1848 times.
✓ Branch 1 taken 1848 times.
✓ Branch 2 taken 1848 times.
✓ Branch 3 taken 1848 times.
7392 const auto bias_format = has_bias ? DataFormat(DataType::FP32) : DataFormat(DataType::UNKNOWN);
170
171 7392 const F16Qai8Qsi4CacheDataId testdata_id = {matmul_shape, lhs_format, rhs_format, bias_format, clamp_keep_ratio};
172 3696 const F16Qai8Qsi4CacheData& testdata = getV<F16Qai8Qsi4CacheDataId, F16Qai8Qsi4CacheData>(testdata_id);
173
174 3696 const auto& ref_lhs_f16 = testdata.ref_lhs_f16;
175 3696 const auto& ref_rhs_qsi4 = testdata.ref_rhs_qsi4;
176 3696 const auto& ref_biases = testdata.ref_biases;
177 3696 const auto& ref_rhs_scales = testdata.ref_rhs_scales;
178 3696 const auto& ref_dst = testdata.ref_dst;
179 11088 auto [clamp_min, clamp_max] = testdata.clamp;
180 // Runs the LHS packing micro-kernel.
181 3696 const auto lhs_start_row = rect.start_row();
182 3696 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(M, K, mr, kr, sr);
183 3696 Buffer imp_packed_lhs(imp_packed_lhs_size);
184
185 3696 auto lhs_stride = K * sizeof(uint16_t);
186
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(lhs_start_row, lhs_stride);
187
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon(lhs_start_row, K, mr, kr, sr);
188
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
189
190
4/16
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3696 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 3696 times.
✗ Branch 15 not taken.
3696 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
191
192
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 kai_run_lhs_quant_pack_qai8dxp_f16_neon(
193
2/4
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
3696 rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_f16.data() + lhs_offset, lhs_stride,
194
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 imp_packed_lhs.data() + lhs_packed_offset);
195
196
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
7392 const auto ref_rhs_qsi4_padded = pad_row<Int4>(
197
4/8
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3696 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3696 times.
✗ Branch 7 not taken.
3696 ref_rhs_qsi4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2));
198
199
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr);
200
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 Buffer imp_packed_rhs(imp_packed_rhs_size);
201
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 const auto rhs_start_row = rect.start_col();
202
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr);
203
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
204
4/16
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3696 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 3696 times.
3696 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
205
206 // Runs the RHS packing micro-kernel.
207 3696 kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{};
208 3696 params.lhs_zero_point = 1;
209 3696 params.rhs_zero_point = 0;
210
211
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
212
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()),
213
3/4
✓ Branch 0 taken 1848 times.
✓ Branch 1 taken 1848 times.
✓ Branch 2 taken 1848 times.
✗ Branch 3 not taken.
3696 has_bias ? reinterpret_cast<const float*>(ref_biases.data()) : nullptr,
214
2/4
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
3696 reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params);
215
216 3696 const auto dst_stride_row = N * sizeof(uint16_t);
217 3696 const auto dst_stride_col = sizeof(uint16_t);
218 7392 const auto dst_offset =
219
3/6
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3696 times.
✗ Branch 5 not taken.
3696 ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
220
2/4
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
3696 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
221
4/16
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3696 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 3696 times.
3696 ASSERT_EQ(dst_offset, ref_dst_offset);
222
223 // Runs the GEMM micro-kernel.
224
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
225
5/18
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3696 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3696 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 3696 times.
3696 ASSERT_EQ(imp_dst_size, ref_dst.size());
226
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 Buffer imp_dst(imp_dst_size);
227
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
7392 ukernel_variant.interface.run_matmul(
228
3/6
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3696 times.
✗ Branch 5 not taken.
3696 rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
229
2/4
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
3696 imp_packed_rhs.data() + rhs_matmul_offset, imp_dst.data() + dst_offset, dst_stride_row, dst_stride_col,
230 7392 clamp_min, clamp_max);
231
232 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
233 // tested.
234
1/2
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
3696 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
235 3696 DataFormat dst_format = DataFormat(DataType::FP16);
236
3/6
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3696 times.
✗ Branch 5 not taken.
3696 const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
237
4/16
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3696 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 3696 times.
3696 ASSERT_TRUE(success);
238 5376 }
239
37/126
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✓ Branch 28 taken 2 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 time.
✓ Branch 30 taken 2 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 time.
✓ Branch 32 taken 2 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 time.
✓ Branch 34 taken 2 times.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 2 times.
✓ Branch 36 taken 2688 times.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✓ Branch 38 taken 5376 times.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✗ Branch 59 not taken.
✗ Branch 60 not taken.
✗ Branch 60 not taken.
✗ Branch 61 not taken.
✗ Branch 61 not taken.
✗ Branch 62 not taken.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✗ Branch 63 not taken.
✗ Branch 64 not taken.
✓ Branch 64 taken 2688 times.
✗ Branch 65 not taken.
✗ Branch 65 not taken.
✓ Branch 66 taken 5376 times.
✗ Branch 67 not taken.
✓ Branch 68 taken 5376 times.
✗ Branch 69 not taken.
16134 INSTANTIATE_TEST_SUITE_P(
240 MatMul, MatMulTest_f16_qai8dxp_qsi4cxp,
241 testing::Combine(
242 testing::Range<size_t>(0, variants_kai_matmul_clamp_f16_qai8dxp_qsi4cxp.size()),
243 testing::Values(
244 MatMulShape{1, 2, 32}, //
245 MatMulShape{1, 3, 32}, //
246 MatMulShape{1, 4, 32}, //
247 MatMulShape{1, 5, 31}, //
248 MatMulShape{3, 3, 32}, //
249 MatMulShape{4, 4, 32}, //
250 MatMulShape{5, 5, 31}, //
251 MatMulShape{16, 32, 64}, //
252 MatMulShape{16, 32, 36}, //
253 MatMulShape{15, 35, 65}, //
254 MatMulShape{8, 32, 64}, //
255 MatMulShape{15, 31, 45}, //
256 MatMulShape{1, 35, 65}, //
257 MatMulShape{1, 128, 32}, //
258 MatMulShape{64, 128, 32}, //
259 MatMulShape{77, 99, 64}),
260 testing::Values(
261 MatrixPortion(0, 0, 1, 1), // Full matrix.
262 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
263 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
264 MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle
265 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
266 MatrixPortion(0.75, 0, 1, 1), // Partial rows
267 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
268 ),
269 testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f})), // clamp_keep_ratio
270 testing::Bool()),
271 [](const auto& info) {
272 const auto variant_idx = std::get<0>(info.param);
273 const std::string name{variants_kai_matmul_clamp_f16_qai8dxp_qsi4cxp.at(variant_idx).name};
274 const auto shape = std::get<MatMulShape>(info.param);
275 const auto portion = std::get<2>(info.param);
276 const auto clamp_keep_ratio = std::get<3>(info.param);
277 const auto has_bias = std::get<4>(info.param);
278
279 return test_description(name, shape, portion, has_bias, clamp_keep_ratio);
280 });
281
282 } // namespace kai::test
283