KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 97.3% 177 / 0 / 182
Functions: 100.0% 20 / 0 / 20
Branches: 36.1% 262 / 0 / 726

test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <limits>
14 #include <sstream>
15 #include <string>
16 #include <tuple>
17 #include <vector>
18
19 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h"
20 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h"
21 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h"
22 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h"
23 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h"
24 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"
25 #include "test/common/bfloat16.hpp"
26 #include "test/common/buffer.hpp"
27 #include "test/common/cache.hpp"
28 #include "test/common/compare.hpp"
29 #include "test/common/cpu_info.hpp"
30 #include "test/common/data_format.hpp"
31 #include "test/common/int4.hpp"
32 #include "test/common/matmul_test_common.hpp"
33 #include "test/common/matrix_portion.hpp"
34 #include "test/common/memory.hpp"
35 #include "test/common/round.hpp"
36 #include "test/common/seed.hpp"
37 #include "test/common/test_suite.hpp"
38 #include "test/reference/cast.hpp"
39 #include "test/reference/clamp.hpp"
40 #include "test/reference/fill.hpp"
41 #include "test/reference/matmul.hpp"
42 #include "test/reference/pack.hpp"
43 #include "test/reference/pad.hpp"
44 #include "test/reference/quantize.hpp"
45 #include "test/reference/transpose.hpp"
46
47 // Uses the truncating BFloat16 implementation (BFloat16<false>) to match the existing packing and inference paths.
48
49 namespace kai::test {
50
51 using Bf16Qai8Qsi4CacheDataId = std::tuple< //
52 MatMulShape, //
53 DataFormat, // lhs format
54 DataFormat, // rhs format
55 DataFormat, // bias format
56 float // clamp_keep_ratio
57 >;
58
59 struct Bf16Qai8Qsi4CacheData {
60 Buffer ref_dst_nt_t;
61 Buffer ref_dst_nt_nt;
62 Buffer ref_rhs_qsi4_nt_t;
63 Buffer ref_rhs_qsi4_nt_nt;
64 Buffer ref_rhs_scales;
65 Buffer ref_lhs_bf16;
66 Buffer ref_biases_buf;
67 Range<float> clamp_nt_nt;
68 Range<float> clamp_nt_t;
69 };
70
71 template <>
72 168 Bf16Qai8Qsi4CacheData ReferenceGenerator<Bf16Qai8Qsi4CacheDataId, Bf16Qai8Qsi4CacheData>::generate_reference(
73 const Bf16Qai8Qsi4CacheDataId& data_id) {
74 2184 auto [shape, lhs_format, rhs_format, bias_format, clamp_keep_ratio] = data_id;
75
76 336 size_t M = shape.m;
77 336 size_t N = shape.n;
78 336 size_t K = shape.k;
79
80 // Seed the random generator.
81
8/16
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 168 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 168 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 168 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 84 times.
✗ Branch 15 not taken.
504 const auto key = std::string("Bf16Qai8Qsi4_cache:") + std::to_string(M) + "x" + std::to_string(N) + "x" +
82
8/16
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 168 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 168 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 168 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 84 times.
✗ Branch 15 not taken.
504 std::to_string(K) + ":" + std::to_string(static_cast<uint32_t>(lhs_format.data_type())) + ":" +
83
5/10
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 84 times.
✗ Branch 9 not taken.
588 std::to_string(static_cast<uint32_t>(rhs_format.data_type())) + ":" +
84
7/14
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 168 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 168 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 84 times.
✗ Branch 13 not taken.
420 std::to_string(static_cast<uint32_t>(bias_format.data_type())) + ":" + std::to_string(clamp_keep_ratio);
85
1/2
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
168 auto& feed = seed_stream(key);
86
87
2/4
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
336 bool has_bias = bias_format.data_type() != DataType::UNKNOWN;
88
5/10
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 168 times.
✗ Branch 9 not taken.
672 Buffer lhs = fill_matrix_random(shape.m, shape.k, lhs_format, feed());
89
5/10
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 168 times.
✗ Branch 9 not taken.
672 Buffer ref_rhs = fill_matrix_random(shape.n, shape.k, rhs_format, feed());
90
5/8
✓ Branch 0 taken 84 times.
✓ Branch 1 taken 84 times.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
168 Buffer bias = has_bias ? fill_matrix_random(1, shape.n, bias_format, feed()) : Buffer();
91
92 168 Bf16Qai8Qsi4CacheData out;
93 // For the reference implementation, cast the BF16 input to FP32 and the FP32 output back to BF16, because the
94 // matmul implementation works with FP32 accumulation and casts the result to BF16.
95
1/2
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
336 const auto ref_lhs = cast<float, BFloat16<false>>(
96
1/2
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
168 lhs.data(), //
97
1/2
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
168 lhs.size() * 8 / size_in_bits<BFloat16<false>>);
98
99 // Transposed (NxK) RHS dimensions
100 168 const size_t ref_rhs_qsi4_nxk_stride = K;
101
102 // Non-transposed (KxN) RHS dimensions
103
1/2
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
168 const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2);
104
1/2
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
168 const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(K * ref_rhs_qsi4_kxn_stride, 2);
105
106
1/2
✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
168 QuantizationInfo lhs_qinfo{};
107 lhs_qinfo.quant_width = K;
108 lhs_qinfo.dst_type = DataType::QAI8;
109 lhs_qinfo.scale_type = DataType::FP32;
110 lhs_qinfo.zero_point_type = DataType::I32;
111 const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo);
112
113 QuantizationInfo rhs_qinfo{};
114 rhs_qinfo.quant_width = K;
115 rhs_qinfo.dst_type = DataType::QSI4;
116 rhs_qinfo.scale_type = DataType::FP32;
117 auto [ref_rhs_quant_t, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo);
118
119 auto ref_rhs_qsi4 = transpose_with_padding<Int4>(
120 ref_rhs_quant_t.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes);
121
122 const auto ref_dst_nt_nt = matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
123 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K,
124 ref_rhs_qsi4.data(), rhs_qoutputs.scales.data(), nullptr, K, has_bias ? bias.data() : nullptr,
125 std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
126
127 const auto [clamp_min_nt_nt, clamp_max_nt_nt] =
128 find_clamp_range<float>(ref_dst_nt_nt.data(), M * N, clamp_keep_ratio);
129 out.ref_rhs_qsi4_nt_nt = std::move(ref_rhs_qsi4);
130
131 const auto ref_dst_float_nt_nt = clamp<float>(ref_dst_nt_nt.data(), M * N, clamp_min_nt_nt, clamp_max_nt_nt);
132
133 auto ref_dst_nt_nt_bf16 =
134 cast<BFloat16<false>, float>(ref_dst_float_nt_nt.data(), ref_dst_float_nt_nt.size() * 8 / size_in_bits<float>);
135 out.ref_dst_nt_nt = std::move(ref_dst_nt_nt_bf16);
136
137 out.clamp_nt_nt = {clamp_min_nt_nt, clamp_max_nt_nt};
138
139 const auto ref_dst_nt_t =
140 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>(
141 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), 1, K,
142 ref_rhs_quant_t.data(), rhs_qoutputs.scales.data(), nullptr, 1, K, has_bias ? bias.data() : nullptr,
143 nullptr, nullptr, 1);
144
145 const auto [clamp_min_nt_t, clamp_max_nt_t] = find_clamp_range<float>(ref_dst_nt_t.data(), M * N, clamp_keep_ratio);
146 out.ref_rhs_qsi4_nt_t = std::move(ref_rhs_quant_t);
147 const auto ref_dst_nt_t_float = clamp<float>(ref_dst_nt_t.data(), M * N, clamp_min_nt_t, clamp_max_nt_t);
148
149 auto ref_dst_nt_t_bf16 =
150 cast<BFloat16<false>, float>(ref_dst_nt_t_float.data(), ref_dst_nt_t_float.size() * 8 / size_in_bits<float>);
151 out.ref_dst_nt_t = std::move(ref_dst_nt_t_bf16);
152 out.clamp_nt_t = {clamp_min_nt_t, clamp_max_nt_t};
153 out.ref_lhs_bf16 = std::move(lhs);
154 out.ref_biases_buf = std::move(bias);
155 out.ref_rhs_scales = std::move(rhs_qoutputs.scales);
156 return out;
157 }
158
159 3 static const std::array<UkernelVariant<kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel>, 2>
160
0/4
✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
5 variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp = {{
161
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod),
162
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", cpu_has_dotprod_and_bf16},
163
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm),
164
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", cpu_has_i8mm_and_bf16},
165 }};
166
167 using MatMulClampTestPortionedParamsWithBias = std::tuple<size_t, MatMulShape, MatrixPortion, float, bool>;
168 class MatMulTest_bf16_qai8dxp_qsi4cxp : public ::testing::TestWithParam<MatMulClampTestPortionedParamsWithBias> {};
169
170
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
5886 TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_NxK) {
171 21168 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio, has_bias] = GetParam();
172 4704 const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.at(variant_index);
173
174
2/4
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
2352 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
175 GTEST_SKIP() << "CPU features are not supported by current CPU";
176 }
177
178 4704 const size_t M = matmul_shape.m;
179 4704 const size_t N = matmul_shape.n;
180 4704 const size_t K = matmul_shape.k;
181
182 2352 const auto mr = ukernel_variant.interface.get_mr();
183 2352 const auto nr = ukernel_variant.interface.get_nr();
184 2352 const auto kr = ukernel_variant.interface.get_kr();
185 2352 const auto sr = ukernel_variant.interface.get_sr();
186
187 2352 const auto lhs_format = DataFormat(DataType::BF16);
188 2352 const auto rhs_format = DataFormat(DataType::FP32);
189
4/4
✓ Branch 0 taken 1176 times.
✓ Branch 1 taken 1176 times.
✓ Branch 2 taken 1176 times.
✓ Branch 3 taken 1176 times.
4704 const auto bias_format = has_bias ? DataFormat(DataType::FP32) : DataFormat(DataType::UNKNOWN);
190
191 2352 auto m_step = ukernel_variant.interface.get_m_step();
192
3/14
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2352 times.
2352 ASSERT_TRUE(m_step % mr == 0);
193
194 2352 auto n_step = ukernel_variant.interface.get_n_step();
195
3/14
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2352 times.
2352 ASSERT_TRUE(n_step % nr == 0);
196
197 4704 const auto rect = portion.compute_portion(M, N, m_step, n_step);
198
2/4
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2352 times.
2352 if (rect.height() == 0 || rect.width() == 0) {
199 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
200 }
201
202 4704 const Bf16Qai8Qsi4CacheDataId testdata_id = {matmul_shape, lhs_format, rhs_format, bias_format, clamp_keep_ratio};
203 2352 const Bf16Qai8Qsi4CacheData& testdata = getV<Bf16Qai8Qsi4CacheDataId, Bf16Qai8Qsi4CacheData>(testdata_id);
204
205 2352 const auto& ref_lhs_bf16 = testdata.ref_lhs_bf16;
206 2352 const auto& ref_rhs_qsi4 = testdata.ref_rhs_qsi4_nt_t;
207 2352 const auto& ref_biases_buf = testdata.ref_biases_buf;
208 2352 const auto& ref_rhs_scales = testdata.ref_rhs_scales;
209 2352 const auto& ref_dst = testdata.ref_dst_nt_t;
210 7056 auto [clamp_min, clamp_max] = testdata.clamp_nt_t;
211 2352 const auto lhs_start_row = rect.start_row();
212 2352 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr);
213 2352 Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size);
214
215 2352 auto lhs_stride = K * sizeof(uint16_t);
216
217
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride);
218
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr);
219
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
220
221
4/16
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 2352 times.
✗ Branch 15 not taken.
2352 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
222
223
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 kai_run_lhs_quant_pack_qai8dxp_bf16_neon(
224
2/4
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
2352 rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride,
225
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 reinterpret_cast<uint8_t*>(imp_packed_lhs_buf.data()) + lhs_packed_offset);
226
227
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
4704 const auto ref_rhs_qsi4_padded = pad_row<Int4>(
228
4/8
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2352 times.
✗ Branch 7 not taken.
2352 ref_rhs_qsi4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2));
229
230
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr);
231
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 Buffer imp_packed_rhs_buf = Buffer(imp_packed_rhs_size);
232
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 const auto rhs_start_row = rect.start_col();
233
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr);
234
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
235
4/16
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
2352 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
236 // Runs the RHS packing micro-kernel.
237 2352 kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{};
238 2352 params.lhs_zero_point = 1;
239 2352 params.rhs_zero_point = 0;
240
241
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
242
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()),
243
3/4
✓ Branch 0 taken 1176 times.
✓ Branch 1 taken 1176 times.
✓ Branch 2 taken 1176 times.
✗ Branch 3 not taken.
2352 has_bias ? reinterpret_cast<const float*>(ref_biases_buf.data()) : nullptr,
244
2/4
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
2352 reinterpret_cast<const float*>(ref_rhs_scales.data()), reinterpret_cast<uint8_t*>(imp_packed_rhs_buf.data()), 0,
245 &params);
246
247 2352 const auto dst_stride_row = N * sizeof(uint16_t);
248 2352 const auto dst_stride_col = sizeof(uint16_t);
249 4704 const auto dst_offset =
250
3/6
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
2352 ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
251
2/4
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
2352 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
252
4/16
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
2352 ASSERT_EQ(dst_offset, ref_dst_offset);
253
254 // Runs the GEMM micro-kernel.
255
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
256
5/18
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2352 times.
2352 ASSERT_EQ(imp_dst_size, ref_dst.size());
257
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 Buffer imp_dst_buf = Buffer(imp_dst_size);
258
259
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
4704 ukernel_variant.interface.run_matmul(
260
3/6
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
2352 rect.height(), rect.width(), K, reinterpret_cast<const uint8_t*>(imp_packed_lhs_buf.data()) + lhs_matmul_offset,
261
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 reinterpret_cast<const uint8_t*>(imp_packed_rhs_buf.data()) + rhs_matmul_offset,
262
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 reinterpret_cast<uint8_t*>(imp_dst_buf.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min,
263 2352 clamp_max);
264
265 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
266 // tested.
267
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
268 2352 DataFormat dst_format = DataFormat(DataType::BF16);
269 4704 const auto success =
270
3/6
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
2352 compare(reinterpret_cast<const uint8_t*>(imp_dst_buf.data()), ref_dst.data(), dst_format, M, N, rect, handler);
271
4/16
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
2352 ASSERT_TRUE(success);
272 2352 }
273
274
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
5886 TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_KxN) {
275 21168 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio, has_bias] = GetParam();
276 4704 const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.at(variant_index);
277
278
2/4
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
2352 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
279 GTEST_SKIP() << "CPU features are not supported by current CPU";
280 }
281
282 4704 const size_t M = matmul_shape.m;
283 4704 const size_t N = matmul_shape.n;
284 4704 const size_t K = matmul_shape.k;
285
286 2352 const auto mr = ukernel_variant.interface.get_mr();
287 2352 const auto nr = ukernel_variant.interface.get_nr();
288 2352 const auto kr = ukernel_variant.interface.get_kr();
289 2352 const auto sr = ukernel_variant.interface.get_sr();
290
291 2352 const auto lhs_format = DataFormat(DataType::BF16);
292 2352 const auto rhs_format = DataFormat(DataType::FP32);
293
4/4
✓ Branch 0 taken 1176 times.
✓ Branch 1 taken 1176 times.
✓ Branch 2 taken 1176 times.
✓ Branch 3 taken 1176 times.
4704 const auto bias_format = has_bias ? DataFormat(DataType::FP32) : DataFormat(DataType::UNKNOWN);
294
295 // Generates input data.
296 4704 const Bf16Qai8Qsi4CacheDataId testdata_id = {matmul_shape, lhs_format, rhs_format, bias_format, clamp_keep_ratio};
297 2352 const Bf16Qai8Qsi4CacheData& testdata = getV<Bf16Qai8Qsi4CacheDataId, Bf16Qai8Qsi4CacheData>(testdata_id);
298
299 2352 const auto& ref_lhs_bf16 = testdata.ref_lhs_bf16;
300 2352 const auto& ref_rhs_qsi4 = testdata.ref_rhs_qsi4_nt_nt;
301 2352 const auto& ref_biases_buf = testdata.ref_biases_buf;
302 2352 const auto& ref_rhs_scales = testdata.ref_rhs_scales;
303 2352 const auto& ref_dst = testdata.ref_dst_nt_nt;
304 7056 auto [clamp_min, clamp_max] = testdata.clamp_nt_nt;
305 2352 auto m_step = ukernel_variant.interface.get_m_step();
306
3/14
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2352 times.
2352 ASSERT_TRUE(m_step % mr == 0);
307
308 2352 auto n_step = ukernel_variant.interface.get_n_step();
309
3/14
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2352 times.
2352 ASSERT_TRUE(n_step % nr == 0);
310
311 4704 const auto rect = portion.compute_portion(M, N, m_step, n_step);
312
2/4
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2352 times.
2352 if (rect.height() == 0 || rect.width() == 0) {
313 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
314 }
315
316 2352 const auto lhs_start_row = rect.start_row();
317 2352 size_t lhs_stride = K * sizeof(uint16_t);
318
319 // Runs the LHS packing micro-kernel.
320 2352 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr);
321 2352 Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size);
322
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride);
323
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr);
324
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
325
4/16
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 2352 times.
✗ Branch 15 not taken.
2352 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
326
327
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 kai_run_lhs_quant_pack_qai8dxp_bf16_neon(
328
2/4
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
2352 rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, ref_lhs_bf16.data() + lhs_offset, lhs_stride,
329
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 reinterpret_cast<uint8_t*>(imp_packed_lhs_buf.data()) + lhs_packed_offset);
330
331 // Runs the RHS packing micro-kernel.
332 // * Pads the rows of the 4-bit symmetric quantized (QSI4) RHS input for the micro-kernel.
333 // * Packs the RHS matrix.
334
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
4704 const auto ref_rhs_qsi4_padded = pad_row<Int4>(
335
4/8
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2352 times.
✗ Branch 7 not taken.
2352 ref_rhs_qsi4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2));
336
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr);
337
338
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 const auto rhs_start_row = rect.start_col();
339
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr);
340
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
341
4/16
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
2352 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
342
343
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 Buffer imp_packed_rhs_buf = Buffer(imp_packed_rhs_size);
344 2352 kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{};
345 2352 params.lhs_zero_point = 1;
346 2352 params.rhs_zero_point = 0;
347
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(
348
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()),
349
3/4
✓ Branch 0 taken 1176 times.
✓ Branch 1 taken 1176 times.
✓ Branch 2 taken 1176 times.
✗ Branch 3 not taken.
2352 has_bias ? reinterpret_cast<const float*>(ref_biases_buf.data()) : nullptr,
350
2/4
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
2352 reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs_buf.data(), 0, &params);
351
352 2352 const auto dst_stride = N * sizeof(uint16_t);
353
3/6
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
2352 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
354
2/4
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
2352 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(uint16_t);
355
4/16
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
2352 ASSERT_EQ(dst_offset, ref_dst_offset);
356
357 // Runs the GEMM micro-kernel.
358
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
359
5/18
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2352 times.
2352 ASSERT_EQ(imp_dst_size, ref_dst.size());
360
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 Buffer imp_dst_buf = Buffer(imp_dst_size);
361
362 2352 const auto dst_stride_row = N * sizeof(uint16_t);
363 2352 const auto dst_stride_col = sizeof(uint16_t);
364
365
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
4704 ukernel_variant.interface.run_matmul(
366
3/6
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
2352 rect.height(), rect.width(), K, reinterpret_cast<const uint8_t*>(imp_packed_lhs_buf.data()) + lhs_matmul_offset,
367
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 reinterpret_cast<const uint8_t*>(imp_packed_rhs_buf.data()) + rhs_matmul_offset,
368
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 reinterpret_cast<uint8_t*>(imp_dst_buf.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min,
369 2352 clamp_max);
370
371 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
372 // tested.
373
1/2
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
2352 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
374 2352 DataFormat dst_format = DataFormat(DataType::BF16);
375 4704 const auto success =
376
3/6
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
2352 compare(reinterpret_cast<const uint8_t*>(imp_dst_buf.data()), ref_dst.data(), dst_format, M, N, rect, handler);
377
4/16
✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
2352 ASSERT_TRUE(success);
378 2352 }
379
380
34/120
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✓ Branch 18 taken 4 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✓ Branch 20 taken 4 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✓ Branch 24 taken 4 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 4 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✓ Branch 28 taken 4 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
✓ Branch 30 taken 4 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 2 times.
✓ Branch 32 taken 4 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✓ Branch 34 taken 4 times.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 4 times.
✓ Branch 36 taken 2352 times.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✓ Branch 38 taken 4704 times.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✗ Branch 59 not taken.
✗ Branch 60 not taken.
✗ Branch 60 not taken.
✗ Branch 61 not taken.
✗ Branch 61 not taken.
✗ Branch 62 not taken.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✗ Branch 63 not taken.
✗ Branch 64 not taken.
✗ Branch 65 not taken.
14121 INSTANTIATE_TEST_SUITE_P(
381 MatMul, MatMulTest_bf16_qai8dxp_qsi4cxp,
382 testing::Combine(
383 testing::Range<size_t>(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.size()),
384 testing::Values(
385 MatMulShape{1, 2, 32}, //
386 MatMulShape{1, 3, 32}, //
387 MatMulShape{1, 4, 32}, //
388 MatMulShape{1, 5, 32}, //
389 MatMulShape{3, 3, 32}, //
390 MatMulShape{4, 4, 32}, //
391 MatMulShape{5, 5, 32}, //
392 MatMulShape{32, 64, 64}, //
393 MatMulShape{16, 32, 64}, //
394 MatMulShape{8, 32, 64}, //
395 MatMulShape{15, 32, 32}, //
396 MatMulShape{77, 99, 64}, //
397 MatMulShape{77, 99, 66}, //
398 MatMulShape{77, 99, 31}),
399 testing::Values(
400 MatrixPortion(0, 0, 1, 1), // Full matrix.
401 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
402 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
403 MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle
404 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
405 MatrixPortion(0.75, 0, 1, 1), // Partial rows
406 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
407 ),
408 testing::ValuesIn(std::initializer_list<float>{1.0f, 0.9f, 0.5f}), //
409 testing::Bool()),
410 [](const auto& info) -> std::string {
411 const auto variant_idx = std::get<0>(info.param);
412 const auto& name = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp[variant_idx].name;
413 return test_description(
414 name, std::get<MatMulShape>(info.param), std::get<2>(info.param), std::get<4>(info.param),
415 std::get<3>(info.param));
416 });
417
418 } // namespace kai::test
419