KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.9% 178 0 180
Functions: 100.0% 13 0 13
Branches: 41.0% 281 0 686

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <limits>
14 #include <sstream>
15 #include <string>
16 #include <tuple>
17 #include <vector>
18
19 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h"
20 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h"
21 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h"
22 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h"
23 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h"
24 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"
25 #include "test/common/bfloat16.hpp"
26 #include "test/common/buffer.hpp"
27 #include "test/common/compare.hpp"
28 #include "test/common/cpu_info.hpp"
29 #include "test/common/data_format.hpp"
30 #include "test/common/int4.hpp"
31 #include "test/common/matmul_test_common.hpp"
32 #include "test/common/matrix_portion.hpp"
33 #include "test/common/memory.hpp"
34 #include "test/common/round.hpp"
35 #include "test/common/test_suite.hpp"
36 #include "test/reference/cast.hpp"
37 #include "test/reference/clamp.hpp"
38 #include "test/reference/fill.hpp"
39 #include "test/reference/matmul.hpp"
40 #include "test/reference/pack.hpp"
41 #include "test/reference/pad.hpp"
42 #include "test/reference/quantize.hpp"
43 #include "test/reference/transpose.hpp"
44
45 // Using BFloat truncate implementation (BFloat16<false>) to match existing packing/inference
46
47 namespace kai::test {
48
49 static const std::array<UkernelVariant<kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel>, 2>
50 variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp = {{
51 {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod),
52 "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", cpu_has_dotprod_and_bf16},
53 {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm),
54 "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", cpu_has_i8mm_and_bf16},
55 }};
56
57 class MatMulTest_bf16_qai8dxp_qsi4cxp : public ::testing::TestWithParam<MatMulTestPortionedParamsWithBias> {};
58
59
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1178 TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_NxK) {
60 3468 const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam();
61 784 const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.at(variant_index);
62
63
2/4
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
392 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
64 GTEST_SKIP() << "CPU features are not supported by current CPU";
65 }
66
67 392 const std::uint32_t seed = 0;
68
69 784 const size_t M = matmul_shape.m;
70 784 const size_t N = matmul_shape.n;
71 784 const size_t K = matmul_shape.k;
72
73 392 const auto mr = ukernel_variant.interface.get_mr();
74 392 const auto nr = ukernel_variant.interface.get_nr();
75 392 const auto kr = ukernel_variant.interface.get_kr();
76 392 const auto sr = ukernel_variant.interface.get_sr();
77
78 392 auto m_step = ukernel_variant.interface.get_m_step();
79
3/14
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 392 times.
392 ASSERT_TRUE(m_step % mr == 0);
80
81 392 auto n_step = ukernel_variant.interface.get_n_step();
82
3/14
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 392 times.
392 ASSERT_TRUE(n_step % nr == 0);
83
84 784 const auto rect = portion.compute_portion(M, N, m_step, n_step);
85
4/4
✓ Branch 0 taken 376 times.
✓ Branch 1 taken 16 times.
✓ Branch 2 taken 4 times.
✓ Branch 3 taken 372 times.
392 if (rect.height() == 0 || rect.width() == 0) {
86
9/18
✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 20 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 20 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 20 times.
✗ Branch 17 not taken.
20 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
87 }
88
89 // Generates input data.
90 372 const auto ref_lhs_bf16 = fill_random<BFloat16<false>>(M * K, seed + 0);
91
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto ref_rhs = fill_random<float>(N * K, seed + 1);
92
93 372 Buffer ref_biases_buf;
94
2/2
✓ Branch 0 taken 186 times.
✓ Branch 1 taken 186 times.
372 if (has_bias) {
95
1/2
✓ Branch 0 taken 186 times.
✗ Branch 1 not taken.
186 ref_biases_buf = Buffer(fill_random<float>(N, seed + 2));
96 186 }
97
98 // For reference implementation, Casting BF16 input to FP32 type and FP32 output back to BFP16 because the matmul
99 // implementation works with FP32 accumulation and casts the result to BFP16
100 372 const auto ref_lhs =
101
3/6
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
372 cast<float, BFloat16<false>>(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits<BFloat16<false>>);
102
103 // Runs the reference implementation.
104 // * Quantizes the LHS matrix using 8-bit symmetric quantization.
105 // * Quantizes the RHS matrix using 8-bit asymmetric quantization.
106 // * Performs GEMM.
107 1116 const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
108
2/4
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
372 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K);
109 1860 const auto [ref_rhs_qsi4, ref_rhs_scales] =
110
2/4
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
372 quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K);
111 372 const auto ref_dst_no_clamp =
112
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>(
113
4/8
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
744 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K,
114
7/10
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 186 times.
✓ Branch 5 taken 186 times.
✓ Branch 6 taken 186 times.
✓ Branch 7 taken 186 times.
✓ Branch 8 taken 186 times.
✗ Branch 9 not taken.
372 ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, K, has_bias ? ref_biases_buf.data() : nullptr,
115 nullptr, nullptr, 1);
116
117 // Clamps the reference output.
118 372 const auto clamp_ratio = 0.8F;
119
2/4
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
1116 const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), M * N, clamp_ratio);
120
4/8
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
372 const auto ref_dst_float = clamp<float>(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max);
121
122 // Cast the reference output to BF16
123
3/6
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
372 auto ref_dst = cast<BFloat16<false>, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>);
124
125 // Runs the LHS packing micro-kernel.
126
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto lhs_start_row = rect.start_row();
127
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr);
128
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size);
129
130 372 auto lhs_stride = K * sizeof(uint16_t);
131
132
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride);
133
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr);
134
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
135
136
4/16
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 372 times.
✗ Branch 15 not taken.
372 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
137
138
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 kai_run_lhs_quant_pack_qai8dxp_bf16_neon(
139
2/4
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
372 rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride,
140
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 reinterpret_cast<uint8_t*>(imp_packed_lhs_buf.data()) + lhs_packed_offset);
141
142
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
744 const auto ref_rhs_qsi4_padded = pad_row<Int4>(
143
4/8
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
372 ref_rhs_qsi4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2));
144
145
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr);
146
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 Buffer imp_packed_rhs_buf = Buffer(imp_packed_rhs_size);
147
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto rhs_start_row = rect.start_col();
148
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr);
149
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
150
4/16
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
372 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
151 // Runs the RHS packing micro-kernel.
152 372 kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{};
153 372 params.lhs_zero_point = 1;
154 372 params.rhs_zero_point = 0;
155
156
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
157
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()),
158
3/4
✓ Branch 0 taken 186 times.
✓ Branch 1 taken 186 times.
✓ Branch 2 taken 186 times.
✗ Branch 3 not taken.
372 has_bias ? reinterpret_cast<const float*>(ref_biases_buf.data()) : nullptr,
159
2/4
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
372 reinterpret_cast<const float*>(ref_rhs_scales.data()), reinterpret_cast<uint8_t*>(imp_packed_rhs_buf.data()), 0,
160 &params);
161
162 372 const auto dst_stride_row = N * sizeof(uint16_t);
163 372 const auto dst_stride_col = sizeof(uint16_t);
164 744 const auto dst_offset =
165
3/6
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
372 ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
166
2/4
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
372 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
167
4/16
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
372 ASSERT_EQ(dst_offset, ref_dst_offset);
168
169 // Runs the GEMM micro-kernel.
170
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
171
5/18
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 372 times.
372 ASSERT_EQ(imp_dst_size, ref_dst.size());
172
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 Buffer imp_dst_buf = Buffer(imp_dst_size);
173
174
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
744 ukernel_variant.interface.run_matmul(
175
3/6
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
372 rect.height(), rect.width(), K, reinterpret_cast<const uint8_t*>(imp_packed_lhs_buf.data()) + lhs_matmul_offset,
176
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 reinterpret_cast<const uint8_t*>(imp_packed_rhs_buf.data()) + rhs_matmul_offset,
177
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 reinterpret_cast<uint8_t*>(imp_dst_buf.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min,
178 372 clamp_max);
179
180 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
181 // tested.
182
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
183 372 DataFormat dst_format = DataFormat(DataType::BF16);
184 744 const auto success =
185
3/6
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
372 compare(reinterpret_cast<const uint8_t*>(imp_dst_buf.data()), ref_dst.data(), dst_format, M, N, rect, handler);
186
4/16
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
372 ASSERT_TRUE(success);
187 392 }
188
189
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1178 TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_KxN) {
190 3116 const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam();
191 784 const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.at(variant_index);
192
193
2/4
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
392 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
194 GTEST_SKIP() << "CPU features are not supported by current CPU";
195 }
196
197 392 const uint32_t seed = 0;
198
199 784 const size_t M = matmul_shape.m;
200 784 const size_t N = matmul_shape.n;
201 784 const size_t K = matmul_shape.k;
202
203 392 const auto mr = ukernel_variant.interface.get_mr();
204 392 const auto nr = ukernel_variant.interface.get_nr();
205 392 const auto kr = ukernel_variant.interface.get_kr();
206 392 const auto sr = ukernel_variant.interface.get_sr();
207
208 // Generates input data.
209 392 const auto ref_lhs_bf16 = fill_random<BFloat16<false>>(M * K, seed + 0);
210
1/2
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
392 const auto ref_rhs = fill_random<float>(N * K, seed + 1);
211 392 Buffer ref_biases_buf;
212
2/2
✓ Branch 0 taken 196 times.
✓ Branch 1 taken 196 times.
392 if (has_bias) {
213
1/2
✓ Branch 0 taken 196 times.
✗ Branch 1 not taken.
196 ref_biases_buf = Buffer(fill_random<float>(N, seed + 2));
214 196 }
215
216 392 const auto ref_lhs =
217
3/6
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
392 cast<float, BFloat16<false>>(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits<BFloat16<false>>);
218
219 // Transposed(nxk) RHS dimensions
220 392 const size_t ref_rhs_qsi4_nxk_stride = K;
221
222 // Non-Transposed(kxn) RHS dimensions
223
1/2
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
392 const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2);
224
1/2
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
392 const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(K * ref_rhs_qsi4_kxn_stride, 2);
225
226 // Runs the reference implementation.
227 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
228 // * Quantizes the RHS matrix using 4-bit symmetric quantization.
229 // * Performs GEMM.
230 1176 const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
231
2/4
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
392 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K);
232 1940 const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] =
233
2/4
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
392 quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K);
234
235
1/2
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
392 const auto ref_rhs_qsi4 = transpose_with_padding<Int4>(
236
1/2
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
392 ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride,
237 392 ref_rhs_qsi4_kxn_size_bytes);
238
239 392 const auto ref_dst_fp32_clamp =
240
1/2
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
392 matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
241
5/10
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 392 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 392 times.
✗ Branch 9 not taken.
784 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
242
6/8
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 196 times.
✓ Branch 3 taken 196 times.
✓ Branch 4 taken 196 times.
✓ Branch 5 taken 196 times.
✓ Branch 6 taken 196 times.
✗ Branch 7 not taken.
392 ref_rhs_scales.data(), nullptr, K, has_bias ? ref_biases_buf.data() : nullptr,
243 392 std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
244
245 // Clamps the reference output.
246 392 const auto clamp_ratio = 0.8F;
247
2/4
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
1136 const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_fp32_clamp.data(), M * N, clamp_ratio);
248
4/8
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 392 times.
✗ Branch 7 not taken.
392 const auto ref_dst_float = clamp<float>(ref_dst_fp32_clamp.data(), M * N, clamp_min, clamp_max);
249
250 // Cast the reference output to BF16
251
3/6
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
392 auto ref_dst = cast<BFloat16<false>, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>);
252
253
1/2
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
392 auto m_step = ukernel_variant.interface.get_m_step();
254
4/16
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 392 times.
392 ASSERT_TRUE(m_step % mr == 0);
255
256
1/2
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
392 auto n_step = ukernel_variant.interface.get_n_step();
257
4/16
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 392 times.
392 ASSERT_TRUE(n_step % nr == 0);
258
259
2/4
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
784 const auto rect = portion.compute_portion(M, N, m_step, n_step);
260
6/8
✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 376 times.
✓ Branch 3 taken 16 times.
✓ Branch 4 taken 376 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✓ Branch 7 taken 372 times.
392 if (rect.height() == 0 || rect.width() == 0) {
261
10/20
✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 20 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 20 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 20 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 20 times.
✗ Branch 19 not taken.
20 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
262 }
263
264
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto lhs_start_row = rect.start_row();
265 372 size_t lhs_stride = K * sizeof(uint16_t);
266
267 // Runs the LHS packing micro-kernel.
268
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr);
269
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size);
270
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride);
271
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr);
272
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
273
4/16
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 372 times.
✗ Branch 15 not taken.
372 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
274
275
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 kai_run_lhs_quant_pack_qai8dxp_bf16_neon(
276
2/4
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
372 rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, ref_lhs_bf16.data() + lhs_offset, lhs_stride,
277
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 reinterpret_cast<uint8_t*>(imp_packed_lhs_buf.data()) + lhs_packed_offset);
278
279 // Runs the RHS packing micro-kernel.
280 // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
281 // * Packs the RHS matrix.
282
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
744 const auto ref_rhs_qsi4_padded = pad_row<Int4>(
283
4/8
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
372 ref_rhs_qsi4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2));
284
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr);
285
286
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto rhs_start_row = rect.start_col();
287
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr);
288
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
289
4/16
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
372 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
290
291
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 Buffer imp_packed_rhs_buf = Buffer(imp_packed_rhs_size);
292 372 kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{};
293 372 params.lhs_zero_point = 1;
294 372 params.rhs_zero_point = 0;
295
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(
296
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()),
297
3/4
✓ Branch 0 taken 186 times.
✓ Branch 1 taken 186 times.
✓ Branch 2 taken 186 times.
✗ Branch 3 not taken.
372 has_bias ? reinterpret_cast<const float*>(ref_biases_buf.data()) : nullptr,
298
2/4
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
372 reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs_buf.data(), 0, &params);
299
300 372 const auto dst_stride = N * sizeof(uint16_t);
301
3/6
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
372 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
302
2/4
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
372 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(uint16_t);
303
4/16
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
372 ASSERT_EQ(dst_offset, ref_dst_offset);
304
305 // Runs the GEMM micro-kernel.
306
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
307
5/18
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 372 times.
372 ASSERT_EQ(imp_dst_size, ref_dst.size());
308
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 Buffer imp_dst_buf = Buffer(imp_dst_size);
309
310 372 const auto dst_stride_row = N * sizeof(uint16_t);
311 372 const auto dst_stride_col = sizeof(uint16_t);
312
313
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
744 ukernel_variant.interface.run_matmul(
314
3/6
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
372 rect.height(), rect.width(), K, reinterpret_cast<const uint8_t*>(imp_packed_lhs_buf.data()) + lhs_matmul_offset,
315
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 reinterpret_cast<const uint8_t*>(imp_packed_rhs_buf.data()) + rhs_matmul_offset,
316
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 reinterpret_cast<uint8_t*>(imp_dst_buf.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min,
317 372 clamp_max);
318
319 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
320 // tested.
321
1/2
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
372 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
322 372 DataFormat dst_format = DataFormat(DataType::BF16);
323 744 const auto success =
324
3/6
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
372 compare(reinterpret_cast<const uint8_t*>(imp_dst_buf.data()), ref_dst.data(), dst_format, M, N, rect, handler);
325
4/16
✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
372 ASSERT_TRUE(success);
326 392 }
327
328
18/60
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 784 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
1571 INSTANTIATE_TEST_SUITE_P(
329 MatMul, MatMulTest_bf16_qai8dxp_qsi4cxp,
330 testing::Combine(
331 testing::Range<size_t>(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.size()),
332 testing::Values(
333 MatMulShape{1, 2, 32}, //
334 MatMulShape{1, 3, 32}, //
335 MatMulShape{1, 4, 32}, //
336 MatMulShape{1, 5, 32}, //
337 MatMulShape{3, 3, 32}, //
338 MatMulShape{4, 4, 32}, //
339 MatMulShape{5, 5, 32}, //
340 MatMulShape{32, 64, 64}, //
341 MatMulShape{16, 32, 64}, //
342 MatMulShape{8, 32, 64}, //
343 MatMulShape{15, 32, 32}, //
344 MatMulShape{77, 99, 64}, //
345 MatMulShape{77, 99, 66}, //
346 MatMulShape{77, 99, 31}),
347 testing::Values(
348 MatrixPortion(0, 0, 1, 1), // Full matrix.
349 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
350 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
351 MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle
352 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
353 MatrixPortion(0.75, 0, 1, 1), // Partial rows
354 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
355 ),
356 testing::Bool()),
357 [](const auto& info) -> std::string {
358 const auto variant_idx = std::get<0>(info.param);
359 const auto& name = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp[variant_idx].name;
360 return test_description(
361 name, std::get<MatMulShape>(info.param), std::get<2>(info.param), std::get<3>(info.param));
362 });
363
364 } // namespace kai::test
365