KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_clamp_f16_qai8dxp_qsi8cxp_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.8% 81 0 82
Functions: 100.0% 9 0 9
Branches: 40.4% 147 0 364

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <limits>
14 #include <sstream>
15 #include <string>
16 #include <tuple>
17
18 #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
19 #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
20 #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
21 #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
22 #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h"
23 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.h"
24 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"
25 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
26 #include "test/common/buffer.hpp"
27 #include "test/common/compare.hpp"
28 #include "test/common/cpu_info.hpp"
29 #include "test/common/data_format.hpp"
30 #include "test/common/float16.hpp"
31 #include "test/common/matmul_test_common.hpp"
32 #include "test/common/matrix_portion.hpp"
33 #include "test/common/memory.hpp"
34 #include "test/common/round.hpp"
35 #include "test/common/test_suite.hpp"
36 #include "test/reference/cast.hpp"
37 #include "test/reference/clamp.hpp"
38 #include "test/reference/fill.hpp"
39 #include "test/reference/matmul.hpp"
40 #include "test/reference/pack.hpp"
41 #include "test/reference/pad.hpp"
42 #include "test/reference/quantize.hpp"
43
44 namespace kai::test {
45
46 static const std::array<UkernelVariant<kai_matmul_clamp_f16_qai8dxp_qsi8cxp_ukernel>, 4>
47 variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp = {{
48 {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod),
49 "kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", cpu_has_dotprod_and_fp16},
50 {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod),
51 "kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", cpu_has_dotprod_and_fp16},
52 {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod),
53 "kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", cpu_has_dotprod_and_fp16},
54 {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm),
55 "kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", cpu_has_i8mm_and_fp16},
56 }};
57
58 class MatMulTest_f16_qai8dxp_qsi8cxp : public ::testing::TestWithParam<MatMulTestPortionedParamsWithBias> {};
59
60
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
2690 TEST_P(MatMulTest_f16_qai8dxp_qsi8cxp, EndToEnd) {
61 6776 const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam();
62 1792 const auto& ukernel_variant = variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp.at(variant_index);
63
64
2/4
✓ Branch 0 taken 896 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 896 times.
✗ Branch 3 not taken.
896 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
65 GTEST_SKIP() << "Unsupported CPU feature";
66 }
67
68 896 const std::uint32_t seed = 0;
69
70 1792 const size_t M = matmul_shape.m;
71 1792 const size_t N = matmul_shape.n;
72 1792 const size_t K = matmul_shape.k;
73
74 896 const auto mr = ukernel_variant.interface.get_mr();
75 896 const auto nr = ukernel_variant.interface.get_nr();
76 896 const auto kr = ukernel_variant.interface.get_kr();
77 896 const auto sr = ukernel_variant.interface.get_sr();
78
79
4/4
✓ Branch 0 taken 448 times.
✓ Branch 1 taken 448 times.
✓ Branch 2 taken 168 times.
✓ Branch 3 taken 280 times.
896 if (mr == 1 && M > 1) {
80
3/6
✓ Branch 0 taken 280 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 280 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 280 times.
✗ Branch 5 not taken.
280 GTEST_SKIP() << "Kernel does not support M != 1";
81 }
82
83 616 auto m_step = ukernel_variant.interface.get_m_step();
84
3/14
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 616 times.
616 ASSERT_TRUE(m_step % mr == 0);
85
86 616 auto n_step = ukernel_variant.interface.get_n_step();
87
3/14
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 616 times.
616 ASSERT_TRUE(n_step % nr == 0);
88
89 1232 const auto rect = portion.compute_portion(M, N, m_step, n_step);
90
4/4
✓ Branch 0 taken 568 times.
✓ Branch 1 taken 48 times.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 560 times.
616 if (rect.height() == 0 || rect.width() == 0) {
91
9/18
✓ Branch 0 taken 56 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 56 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 56 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 56 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 56 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 56 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 56 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 56 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 56 times.
✗ Branch 17 not taken.
56 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
92 }
93
94 // Generates input data.
95 560 const auto ref_lhs_f16 = fill_random<Float16>(M * K, seed + 0);
96
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 const auto ref_rhs = fill_random<float>(N * K, seed + 1);
97 560 Buffer ref_biases;
98
99
2/2
✓ Branch 0 taken 280 times.
✓ Branch 1 taken 280 times.
560 if (has_bias) {
100
1/2
✓ Branch 0 taken 280 times.
✗ Branch 1 not taken.
280 ref_biases = fill_random<float>(N, seed + 2);
101 280 }
102 // For the reference implementation, cast the FP16 input to FP32 and the FP32 output back to FP16, because the
103 // matmul implementation works with FP32 accumulation and casts the result to FP16.
104
3/6
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
560 const auto ref_lhs = cast<float, Float16>(ref_lhs_f16.data(), ref_lhs_f16.size() * 8 / size_in_bits<Float16>);
105
106 // Runs the reference implementation.
107 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
108 // * Quantizes the RHS matrix using 8-bit symmetric quantization.
109 // * Performs GEMM.
110 1680 const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
111
2/4
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
560 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K);
112 2800 const auto [ref_rhs_qsi8, ref_rhs_scales] =
113
2/4
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
560 quantize_symmetric_per_block_dynamic<float, int8_t, float>(ref_rhs.data(), N, K, K);
114
115 560 const auto ref_dst_no_clamp =
116
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, float, float, int32_t, float>(
117
4/8
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 560 times.
✗ Branch 7 not taken.
1120 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K,
118
7/10
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 280 times.
✓ Branch 5 taken 280 times.
✓ Branch 6 taken 280 times.
✓ Branch 7 taken 280 times.
✓ Branch 8 taken 280 times.
✗ Branch 9 not taken.
560 ref_rhs_qsi8.data(), ref_rhs_scales.data(), nullptr, 1, K, has_bias ? ref_biases.data() : nullptr, nullptr,
119 nullptr, 1);
120
121 // Clamps the reference output.
122 560 const auto clamp_ratio = 0.8F;
123
2/4
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
1680 const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), M * N, clamp_ratio);
124
4/8
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 560 times.
✗ Branch 7 not taken.
560 const auto ref_dst_float = clamp<float>(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max);
125
126 // Cast the reference output to F16
127
3/6
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
560 auto ref_dst = cast<Float16, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>);
128
129 // Runs the LHS packing micro-kernel.
130
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 const auto lhs_start_row = rect.start_row();
131
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(M, K, mr, kr, sr);
132
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 Buffer imp_packed_lhs(imp_packed_lhs_size);
133
134 560 auto lhs_stride = K * sizeof(uint16_t);
135
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(lhs_start_row, lhs_stride);
136
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon(lhs_start_row, K, mr, kr, sr);
137
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
138
139
4/16
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 560 times.
✗ Branch 15 not taken.
560 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
140
141
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 kai_run_lhs_quant_pack_qai8dxp_f16_neon(
142
2/4
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
560 rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_f16.data() + lhs_offset, lhs_stride,
143
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 imp_packed_lhs.data() + lhs_packed_offset);
144
145
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr);
146
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 Buffer imp_packed_rhs(imp_packed_rhs_size);
147
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 const auto rhs_start_row = rect.start_col();
148
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(rhs_start_row, K, nr, kr, sr);
149
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
150
4/16
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 560 times.
560 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
151
152 // Runs the RHS packing micro-kernel.
153 560 const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f};
154
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(
155
2/4
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
1120 1, N, K, nr, kr, sr, reinterpret_cast<const int8_t*>(ref_rhs_qsi8.data()),
156
3/4
✓ Branch 0 taken 280 times.
✓ Branch 1 taken 280 times.
✓ Branch 2 taken 280 times.
✗ Branch 3 not taken.
560 has_bias ? reinterpret_cast<const float*>(ref_biases.data()) : nullptr,
157
2/4
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
560 reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params);
158
159 560 const auto dst_stride_row = N * sizeof(uint16_t);
160 560 const auto dst_stride_col = sizeof(uint16_t);
161 1120 const auto dst_offset =
162
3/6
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
560 ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
163
2/4
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
560 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
164
4/16
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 560 times.
560 ASSERT_EQ(dst_offset, ref_dst_offset);
165
166 // Runs the GEMM micro-kernel.
167
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
168
5/18
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 560 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 560 times.
560 ASSERT_EQ(imp_dst_size, ref_dst.size());
169
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 Buffer imp_dst(imp_dst_size);
170
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
1120 ukernel_variant.interface.run_matmul(
171
3/6
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
560 rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
172
2/4
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
560 imp_packed_rhs.data() + rhs_matmul_offset, imp_dst.data() + dst_offset, dst_stride_row, dst_stride_col,
173 1120 clamp_min, clamp_max);
174
175 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
176 // tested.
177
1/2
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
560 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
178 560 DataFormat dst_format = DataFormat(DataType::FP16);
179
3/6
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
560 const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
180
4/16
✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 560 times.
560 ASSERT_TRUE(success);
181 896 }
182
19/62
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 time.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 time.
✗ Branch 33 not taken.
✓ Branch 34 taken 896 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✓ Branch 60 taken 896 times.
✗ Branch 61 not taken.
1794 INSTANTIATE_TEST_SUITE_P(
183 MatMul, MatMulTest_f16_qai8dxp_qsi8cxp,
184 testing::Combine(
185 testing::Range<size_t>(0, variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp.size()),
186 testing::Values(
187 MatMulShape{1, 2, 32}, //
188 MatMulShape{1, 3, 32}, //
189 MatMulShape{1, 4, 32}, //
190 MatMulShape{1, 5, 31}, //
191 MatMulShape{3, 3, 32}, //
192 MatMulShape{4, 4, 32}, //
193 MatMulShape{5, 5, 31}, //
194 MatMulShape{16, 32, 64}, //
195 MatMulShape{16, 32, 36}, //
196 MatMulShape{15, 35, 65}, //
197 MatMulShape{8, 32, 64}, //
198 MatMulShape{15, 31, 45}, //
199 MatMulShape{1, 35, 65}, //
200 MatMulShape{1, 128, 32}, //
201 MatMulShape{64, 128, 32}, //
202 MatMulShape{77, 99, 64}),
203 testing::Values(
204 MatrixPortion(0, 0, 1, 1), // Full matrix.
205 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
206 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
207 MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle
208 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
209 MatrixPortion(0.75, 0, 1, 1), // Partial rows
210 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
211 ),
212 testing::Bool()),
213 [](const auto& info) {
214 const auto variant_idx = std::get<0>(info.param);
215 const std::string name{variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp.at(variant_idx).name};
216 const auto shape = std::get<MatMulShape>(info.param);
217 const auto portion = std::get<2>(info.param);
218 const auto has_bias = std::get<3>(info.param);
219
220 return test_description(name, shape, portion, has_bias);
221 });
222
223 } // namespace kai::test
224