KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 96.8% 211 / 0 / 218
Functions: 100.0% 28 / 0 / 28
Branches: 37.8% 333 / 0 / 882

test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <limits>
14 #include <string>
15 #include <tuple>
16
17 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
18 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h"
19 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
20 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot.h"
21 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
22 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
23 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
24 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
25 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
26 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
27 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"
28 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
29 #include "test/common/abi_checker.hpp"
30 #include "test/common/buffer.hpp"
31 #include "test/common/cache.hpp"
32 #include "test/common/cpu_info.hpp"
33 #include "test/common/matmul_test_common.hpp"
34 #include "test/common/matrix_portion.hpp"
35 #include "test/common/memory.hpp"
36 #include "test/common/printer.hpp"
37 #include "test/common/seed.hpp"
38 #include "test/common/test_suite.hpp"
39 #include "test/reference/clamp.hpp"
40 #include "test/reference/fill.hpp"
41 #include "test/reference/matmul.hpp"
42 #include "test/reference/quantize.hpp"
43 #include "test/reference/transpose.hpp"
44
45 namespace kai::test {
46
47 using F32Qai8Qsi8CacheDataId = std::tuple<
48 MatMulShape, //
49 DataFormat, // lhs format
50 DataFormat, // rhs format
51 DataFormat, // bias format
52 float>;
53
54 struct F32Qai8Qsi8CacheData {
55 Buffer ref_dst_nt_t;
56 Buffer ref_dst_nt_nt;
57 Buffer ref_rhs_qsi8_nt_t;
58 Buffer ref_rhs_qsi8_nt_nt;
59 Buffer ref_rhs_scales;
60 Buffer ref_lhs;
61 Buffer ref_bias;
62 Range<float> clamp_range;
63 };
64
65 template <>
66 66 F32Qai8Qsi8CacheData ReferenceGenerator<F32Qai8Qsi8CacheDataId, F32Qai8Qsi8CacheData>::generate_reference(
67 const F32Qai8Qsi8CacheDataId& data_id) {
68 1056 auto [shape, lhs_format, rhs_format, bias_format, clamp_keep_ratio] = data_id;
69
70 132 const size_t M = shape.m;
71 132 const size_t N = shape.n;
72 132 const size_t K = shape.k;
73
74 // Seed the random generator.
75
8/16
✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 66 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 66 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 66 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 33 times.
✗ Branch 15 not taken.
198 const auto key = std::string("F32Qai8Qsi8_cache:") + std::to_string(M) + "x" + std::to_string(N) + "x" +
76
8/16
✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 66 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 66 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 66 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 33 times.
✗ Branch 15 not taken.
198 std::to_string(K) + ":" + std::to_string(static_cast<uint32_t>(lhs_format.data_type())) + ":" +
77
5/10
✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 33 times.
✗ Branch 9 not taken.
231 std::to_string(static_cast<uint32_t>(rhs_format.data_type())) + ":" +
78
7/14
✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 66 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 66 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 33 times.
✗ Branch 13 not taken.
165 std::to_string(static_cast<uint32_t>(bias_format.data_type())) + ":" + std::to_string(clamp_keep_ratio);
79
1/2
✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
66 auto& feed = seed_stream(key);
80
81
5/10
✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 66 times.
✗ Branch 9 not taken.
264 Buffer lhs = fill_matrix_random(shape.m, shape.k, lhs_format, feed());
82
5/10
✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 66 times.
✗ Branch 9 not taken.
264 Buffer rhs = fill_matrix_random(shape.k, shape.n, rhs_format, feed());
83
4/8
✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
198 Buffer bias = fill_matrix_random(1, shape.n, bias_format, feed());
84
85
1/2
✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
66 QuantizationInfo lhs_qinfo{};
86 lhs_qinfo.quant_width = K;
87 lhs_qinfo.dst_type = DataType::QAI8;
88 lhs_qinfo.scale_type = DataType::FP32;
89 lhs_qinfo.zero_point_type = DataType::I32;
90 const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(lhs.data(), DataType::FP32, M, K, lhs_qinfo);
91
92 QuantizationInfo rhs_qinfo{};
93 rhs_qinfo.quant_width = K;
94 rhs_qinfo.dst_type = DataType::QSI8;
95 rhs_qinfo.scale_type = DataType::FP32;
96 auto [ref_rhs_quant_t, rhs_qoutputs] = quantize_dynamic(rhs.data(), DataType::FP32, N, K, rhs_qinfo);
97
98 F32Qai8Qsi8CacheData out;
99
100 // Transposed RHS path.
101 const size_t ref_rhs_qsi8_nxk_stride = K;
102 const size_t ref_rhs_qsi8_kxn_stride = N;
103 const size_t ref_rhs_qsi8_kxn_size_bytes = K * ref_rhs_qsi8_kxn_stride;
104
105 // Non-Transposed(kxn) RHS dimensions
106 auto ref_rhs_qsi8 = transpose_with_padding<int8_t>(
107 ref_rhs_quant_t.data(), N, K, ref_rhs_qsi8_nxk_stride, ref_rhs_qsi8_kxn_stride, ref_rhs_qsi8_kxn_size_bytes);
108
109 auto ref_dst_nt_nt = matmul_clamp_nt_nt<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>(
110 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K,
111 ref_rhs_qsi8.data(), rhs_qoutputs.scales.data(), nullptr, K, bias.data(), std::numeric_limits<float>::lowest(),
112 std::numeric_limits<float>::max());
113
114 // Non-transposed RHS path.
115 auto ref_dst_nt_t = matmul_clamp_nt_t<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>(
116 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K,
117 ref_rhs_quant_t.data(), rhs_qoutputs.scales.data(), nullptr, K, bias.data(),
118 std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
119
120 // Only need to calculate range once for both, apply clamping
121 const auto [clamp_min, clamp_max] = find_clamp_range(DataType::FP32, ref_dst_nt_t.data(), M * N, clamp_keep_ratio);
122 auto ref_clamped_nt_t = clamp(DataType::FP32, ref_dst_nt_t.data(), M * N, clamp_min, clamp_max);
123 auto ref_clamped_nt_nt = clamp(DataType::FP32, ref_dst_nt_nt.data(), M * N, clamp_min, clamp_max);
124
125 out.ref_rhs_qsi8_nt_nt = std::move(ref_rhs_qsi8);
126 out.ref_rhs_qsi8_nt_t = std::move(ref_rhs_quant_t);
127 out.ref_dst_nt_nt = std::move(ref_clamped_nt_nt);
128 out.ref_dst_nt_t = std::move(ref_clamped_nt_t);
129 out.ref_lhs = std::move(lhs);
130 out.ref_bias = std::move(bias);
131 out.ref_rhs_scales = std::move(rhs_qoutputs.scales);
132 out.clamp_range = {clamp_min, clamp_max};
133
134 return out;
135 }
136
137 3 static const std::array<UkernelVariant<kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel>, 8>
138
0/4
✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
11 variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp = {{
139
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod),
140
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", cpu_has_dotprod},
141
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod),
142
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", cpu_has_dotprod},
143
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod),
144
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", cpu_has_dotprod},
145
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm),
146
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", cpu_has_i8mm},
147
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot),
148
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot", cpu_has_sme},
149
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa),
150
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa", cpu_has_sme},
151
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot),
152
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot", cpu_has_sme2},
153
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa),
154
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa", cpu_has_sme2},
155 }};
156
157 using MatMulClampTestPortionedParams = std::tuple<size_t, MatMulShape, MatrixPortion, float>;
158
159 class MatMulTest_f32_qai8dxp_qsi8cxp : public ::testing::TestWithParam<MatMulClampTestPortionedParams> {};
160
161
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
9246 TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, Offset_RHS) {
162 10164 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
163 7392 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index);
164
165
3/4
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✓ Branch 3 taken 924 times.
3696 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
166
3/6
✓ Branch 0 taken 924 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 924 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 924 times.
✗ Branch 5 not taken.
924 GTEST_SKIP() << "Unsupported CPU feature";
167 }
168
169 5544 const size_t K = matmul_shape.k;
170 2772 const auto nr = ukernel_variant.interface.get_nr();
171 2772 const auto kr = ukernel_variant.interface.get_kr();
172 2772 const auto sr = ukernel_variant.interface.get_sr();
173
174 2772 auto n_step = ukernel_variant.interface.get_n_step();
175
176 2772 auto rhs_packed_offset_kxn = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(n_step, K, nr, kr, sr);
177 2772 auto rhs_packed_offset_nxk = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(n_step, K, nr, kr, sr);
178
179
3/14
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
2772 ASSERT_EQ(rhs_packed_offset_kxn, rhs_packed_offset_nxk);
180
181 2772 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(n_step, K);
182
3/14
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
2772 ASSERT_EQ(rhs_packed_offset_kxn, rhs_matmul_offset);
183 3696 }
184
185
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
9246 TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, Offset_LHS) {
186 10164 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
187 7392 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index);
188
189
3/4
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✓ Branch 3 taken 924 times.
3696 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
190
3/6
✓ Branch 0 taken 924 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 924 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 924 times.
✗ Branch 5 not taken.
924 GTEST_SKIP() << "Unsupported CPU feature";
191 }
192
193 5544 const size_t K = matmul_shape.k;
194 2772 const auto mr = ukernel_variant.interface.get_mr();
195 2772 const auto kr = ukernel_variant.interface.get_kr();
196 2772 const auto sr = ukernel_variant.interface.get_sr();
197
198 2772 auto m_step = ukernel_variant.interface.get_m_step();
199
200 2772 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(m_step, K, mr, kr, sr);
201 2772 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(m_step, K);
202
203
3/14
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
2772 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
204 3696 }
205
206
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
9246 TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_nxk_qsi8cx) {
207 24024 auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
208 7392 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index);
209
210
3/4
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✓ Branch 3 taken 924 times.
3696 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
211
3/6
✓ Branch 0 taken 924 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 924 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 924 times.
✗ Branch 5 not taken.
924 GTEST_SKIP() << "Unsupported CPU feature";
212 }
213
214 5544 const size_t M = matmul_shape.m;
215 5544 const size_t N = matmul_shape.n;
216 5544 const size_t K = matmul_shape.k;
217
218 2772 const auto mr = ukernel_variant.interface.get_mr();
219 2772 const auto nr = ukernel_variant.interface.get_nr();
220 2772 const auto kr = ukernel_variant.interface.get_kr();
221 2772 const auto sr = ukernel_variant.interface.get_sr();
222
223 5544 const F32Qai8Qsi8CacheDataId testdata_id = {
224 matmul_shape, //
225 2772 DataFormat(DataType::FP32), //
226 2772 DataFormat(DataType::FP32), //
227 2772 DataFormat(DataType::FP32), clamp_keep_ratio};
228 2772 const F32Qai8Qsi8CacheData& testdata = getV<F32Qai8Qsi8CacheDataId, F32Qai8Qsi8CacheData>(testdata_id);
229
230 2772 const auto& ref_rhs_qsi8 = testdata.ref_rhs_qsi8_nt_t;
231 2772 const auto& ref_rhs_scales = testdata.ref_rhs_scales;
232 2772 const auto& ref_dst = testdata.ref_dst_nt_t;
233 2772 const auto& ref_bias = testdata.ref_bias;
234 2772 const auto& ref_lhs = testdata.ref_lhs;
235 2772 auto m_step = ukernel_variant.interface.get_m_step();
236
3/14
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
2772 ASSERT_TRUE(m_step % mr == 0);
237
238 2772 auto n_step = ukernel_variant.interface.get_n_step();
239
3/14
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
2772 ASSERT_TRUE(n_step % nr == 0);
240
241 5544 const auto rect = portion.compute_portion(M, N, m_step, n_step);
242
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2772 times.
2772 if (rect.height() == 0 || rect.width() == 0) {
243 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
244 }
245
246 // Runs the LHS packing micro-kernel.
247 2772 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
248 2772 Buffer imp_packed_lhs(imp_packed_lhs_size);
249
250
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 const auto lhs_start_row = rect.start_row();
251 2772 size_t lhs_stride = K * sizeof(float);
252
253
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
254
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
255
256
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 kai_run_lhs_quant_pack_qai8dxp_f32(
257
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 rect.height(), K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
258
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 imp_packed_lhs.data() + lhs_packed_offset);
259
260 // Runs the RHS packing micro-kernel.
261 // * Generates the 8-bit signed symmetric quantized input for the micro-kernel.
262 // * Packs the RHS matrix.
263
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr);
264
265
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 Buffer imp_packed_rhs(imp_packed_rhs_size);
266 2772 const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f};
267
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(
268
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 1, N, K, nr, kr, sr, reinterpret_cast<const int8_t*>(ref_rhs_qsi8.data()),
269
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 reinterpret_cast<const float*>(ref_bias.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()),
270
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 imp_packed_rhs.data(), 0, &params);
271
272
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 const auto packed_rhs_start_row = rect.start_col();
273 2772 auto rhs_packed_offset =
274
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(packed_rhs_start_row, K, nr, kr, sr);
275
276 2772 const auto dst_stride = N * sizeof(float);
277
3/6
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
2772 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
278
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
279
4/16
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
2772 ASSERT_EQ(dst_offset, ref_dst_offset);
280
281
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 const auto matmul_lhs_packed_offset = ukernel_variant.interface.get_lhs_packed_offset(rect.start_row(), K);
282
4/16
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
2772 ASSERT_EQ(lhs_packed_offset, matmul_lhs_packed_offset);
283
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 const auto matmul_rhs_packed_offset = ukernel_variant.interface.get_rhs_packed_offset(rect.start_col(), K);
284
4/16
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
2772 ASSERT_EQ(rhs_packed_offset, matmul_rhs_packed_offset);
285
286 // Runs the GEMM micro-kernel.
287
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
288
5/18
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2772 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2772 times.
2772 ASSERT_EQ(imp_dst_size, ref_dst.size());
289
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 Buffer imp_dst(imp_dst_size);
290
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 abi_check(
291
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K,
292
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 imp_packed_lhs.data() + matmul_lhs_packed_offset, imp_packed_rhs.data() + matmul_rhs_packed_offset,
293
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 reinterpret_cast<float*>(imp_dst.data() + dst_offset), N * sizeof(float), sizeof(float),
294 2772 testdata.clamp_range.min, testdata.clamp_range.max);
295
296 // Compares the output of the micro-kernels against the output of the reference implementation.
297
5/6
✓ Branch 0 taken 8880 times.
✓ Branch 1 taken 17760 times.
✓ Branch 2 taken 23868 times.
✓ Branch 3 taken 2772 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2772 times.
26640 for (size_t y = 0; y < rect.height(); ++y) {
298
5/6
✓ Branch 0 taken 294936 times.
✓ Branch 1 taken 759120 times.
✓ Branch 2 taken 1030188 times.
✓ Branch 3 taken 23868 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 23868 times.
1054056 for (size_t x = 0; x < rect.width(); ++x) {
299 2060376 const auto imp_value =
300
4/8
✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1030188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1030188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1030188 times.
✗ Branch 7 not taken.
1030188 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
301 2060376 const auto ref_value =
302
4/8
✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1030188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1030188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1030188 times.
✗ Branch 7 not taken.
1030188 read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
303
3/6
✓ Branch 0 taken 743208 times.
✓ Branch 1 taken 286980 times.
✓ Branch 2 taken 286980 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
1030188 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
304
305
1/2
✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
1030188 if (rel_error > 0.0001F) {
306 ASSERT_EQ(imp_value, ref_value);
307 }
308 1030188 }
309 23868 }
310 3696 }
311
312
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
9246 TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_kxn_qsi8cx) {
313 24024 auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
314 7392 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index);
315
316
3/4
✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✓ Branch 3 taken 924 times.
3696 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
317
3/6
✓ Branch 0 taken 924 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 924 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 924 times.
✗ Branch 5 not taken.
924 GTEST_SKIP() << "Unsupported CPU feature";
318 }
319
320 5544 const size_t M = matmul_shape.m;
321 5544 const size_t N = matmul_shape.n;
322 5544 const size_t K = matmul_shape.k;
323
324 2772 const auto mr = ukernel_variant.interface.get_mr();
325 2772 const auto nr = ukernel_variant.interface.get_nr();
326 2772 const auto kr = ukernel_variant.interface.get_kr();
327 2772 const auto sr = ukernel_variant.interface.get_sr();
328
329 5544 const F32Qai8Qsi8CacheDataId testdata_id = {
330 matmul_shape, //
331 2772 DataFormat(DataType::FP32), //
332 2772 DataFormat(DataType::FP32), //
333 2772 DataFormat(DataType::FP32), clamp_keep_ratio};
334 2772 const F32Qai8Qsi8CacheData& testdata = getV<F32Qai8Qsi8CacheDataId, F32Qai8Qsi8CacheData>(testdata_id);
335 2772 const auto& ref_rhs_qsi8 = testdata.ref_rhs_qsi8_nt_nt;
336 2772 const auto& ref_rhs_scales = testdata.ref_rhs_scales;
337 2772 const auto& ref_dst = testdata.ref_dst_nt_nt;
338 2772 const auto& ref_bias = testdata.ref_bias;
339 2772 const auto& ref_lhs = testdata.ref_lhs;
340
341 2772 auto m_step = ukernel_variant.interface.get_m_step();
342
3/14
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
2772 ASSERT_TRUE(m_step % mr == 0);
343
344 2772 auto n_step = ukernel_variant.interface.get_n_step();
345
3/14
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
2772 ASSERT_TRUE(n_step % nr == 0);
346
347 5544 const auto rect = portion.compute_portion(M, N, m_step, n_step);
348
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2772 times.
2772 if (rect.height() == 0 || rect.width() == 0) {
349 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
350 }
351
352 2772 const auto lhs_start_row = rect.start_row();
353 2772 size_t const lhs_stride = K * sizeof(float);
354
355 2772 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
356 2772 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
357
358 // Runs the LHS packing micro-kernel.
359 2772 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
360 2772 Buffer imp_packed_lhs(imp_packed_lhs_size);
361
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 kai_run_lhs_quant_pack_qai8dxp_f32(
362
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 rect.height(), K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), K * sizeof(float),
363
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 imp_packed_lhs.data() + lhs_packed_offset);
364
365 // Runs the RHS packing micro-kernel.
366 // * Generates the 8-bit signed symmetric quantized input for the micro-kernel.
367 // * Packs the RHS matrix.
368
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr);
369
370
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 Buffer imp_packed_rhs(imp_packed_rhs_size);
371 2772 const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f};
372
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(
373
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 1, N, K, nr, kr, sr, reinterpret_cast<const int8_t*>(ref_rhs_qsi8.data()),
374
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 reinterpret_cast<const float*>(ref_bias.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()),
375
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 imp_packed_rhs.data(), 0, &params);
376
377
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 const auto packed_rhs_start_row = rect.start_col();
378 2772 auto rhs_packed_offset =
379
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(packed_rhs_start_row, K, nr, kr, sr);
380
381 2772 const auto dst_stride = N * sizeof(float);
382
3/6
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
2772 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
383
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
384
4/16
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
2772 ASSERT_EQ(dst_offset, ref_dst_offset);
385
386
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 const auto matmul_lhs_packed_offset = ukernel_variant.interface.get_lhs_packed_offset(rect.start_row(), K);
387
4/16
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
2772 ASSERT_EQ(lhs_packed_offset, matmul_lhs_packed_offset);
388
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 const auto matmul_rhs_packed_offset = ukernel_variant.interface.get_rhs_packed_offset(rect.start_col(), K);
389
4/16
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
2772 ASSERT_EQ(rhs_packed_offset, matmul_rhs_packed_offset);
390
391 // Runs the GEMM micro-kernel.
392
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
393
5/18
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2772 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2772 times.
2772 ASSERT_EQ(imp_dst_size, ref_dst.size());
394
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 Buffer imp_dst(imp_dst_size);
395
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 abi_check(
396
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K,
397
2/4
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
2772 imp_packed_lhs.data() + matmul_lhs_packed_offset, imp_packed_rhs.data() + matmul_rhs_packed_offset,
398
1/2
✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
2772 reinterpret_cast<float*>(imp_dst.data() + dst_offset), N * sizeof(float), sizeof(float),
399 2772 testdata.clamp_range.min, testdata.clamp_range.max);
400
401 // Compares the output of the micro-kernels against the output of the reference implementation.
402
5/6
✓ Branch 0 taken 8880 times.
✓ Branch 1 taken 17760 times.
✓ Branch 2 taken 23868 times.
✓ Branch 3 taken 2772 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2772 times.
26640 for (size_t y = 0; y < rect.height(); ++y) {
403
5/6
✓ Branch 0 taken 294936 times.
✓ Branch 1 taken 759120 times.
✓ Branch 2 taken 1030188 times.
✓ Branch 3 taken 23868 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 23868 times.
1054056 for (size_t x = 0; x < rect.width(); ++x) {
404 2060376 const auto imp_value =
405
4/8
✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1030188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1030188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1030188 times.
✗ Branch 7 not taken.
1030188 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
406 2060376 const auto ref_value =
407
4/8
✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1030188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1030188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1030188 times.
✗ Branch 7 not taken.
1030188 read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
408
3/6
✓ Branch 0 taken 743208 times.
✓ Branch 1 taken 286980 times.
✓ Branch 2 taken 286980 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
1030188 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
409
410
1/2
✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
1030188 if (rel_error > 0.0001F) {
411 ASSERT_EQ(imp_value, ref_value);
412 }
413 1030188 }
414 23868 }
415 3696 }
416
417
35/118
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✓ Branch 18 taken 8 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✓ Branch 20 taken 8 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✓ Branch 22 taken 8 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 4 times.
✓ Branch 24 taken 8 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 4 times.
✓ Branch 26 taken 8 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✓ Branch 28 taken 8 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 4 times.
✓ Branch 30 taken 8 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 4 times.
✓ Branch 32 taken 8 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✓ Branch 34 taken 8 times.
✓ Branch 34 taken 7392 times.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✓ Branch 36 taken 14784 times.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✗ Branch 59 not taken.
✗ Branch 60 not taken.
✓ Branch 60 taken 7392 times.
✗ Branch 61 not taken.
✗ Branch 61 not taken.
✓ Branch 62 taken 14784 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 14784 times.
✗ Branch 65 not taken.
44367 INSTANTIATE_TEST_SUITE_P(
418 MatMul, MatMulTest_f32_qai8dxp_qsi8cxp,
419 testing::Combine(
420 testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.size()),
421 testing::Values(
422 MatMulShape{17, 33, 67}, //
423 MatMulShape{19, 35, 63}, //
424 MatMulShape{1, 27, 31}, //
425 MatMulShape{1, 65, 35}, //
426 MatMulShape{1, 64, 65}, //
427 MatMulShape{1, 63, 15}, //
428 MatMulShape{1, 130, 15}, //
429 MatMulShape{15, 65, 35}, //
430 MatMulShape{16, 64, 65}, //
431 MatMulShape{17, 63, 15}, //
432 MatMulShape{20, 130, 15}),
433 testing::Values(
434 MatrixPortion(0, 0, 1, 1), // Full matrix.
435 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
436 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
437 MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle
438 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
439 MatrixPortion(0.75, 0, 1, 1), // Partial rows
440 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
441 ),
442 testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio
443 [](const auto& info) {
444 const auto variant_idx = std::get<0>(info.param);
445 const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_idx).name};
446 const auto shape = std::get<MatMulShape>(info.param);
447 const auto portion = std::get<MatrixPortion>(info.param);
448 const auto clamp_keep_ratio = std::get<float>(info.param);
449
450 return test_description(name, shape, portion, true, clamp_keep_ratio);
451 });
452
453 } // namespace kai::test
454