KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 96.5% 219 0 227
Functions: 100.0% 25 0 25
Branches: 38.3% 307 0 802

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <limits>
14 #include <string>
15 #include <tuple>
16
17 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
18 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h"
19 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
20 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot.h"
21 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
22 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
23 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
24 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
25 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
26 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
27 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"
28 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
29 #include "test/common/buffer.hpp"
30 #include "test/common/cache.hpp"
31 #include "test/common/cpu_info.hpp"
32 #include "test/common/matmul_test_common.hpp"
33 #include "test/common/matrix_portion.hpp"
34 #include "test/common/memory.hpp"
35 #include "test/common/printer.hpp"
36 #include "test/common/test_suite.hpp"
37 #include "test/reference/fill.hpp"
38 #include "test/reference/matmul.hpp"
39 #include "test/reference/quantize.hpp"
40 #include "test/reference/transpose.hpp"
41
42 namespace kai::test {
43 using CacheDataId = std::tuple<MatMulShape, DataFormat, DataFormat, DataFormat>;
44
45 struct CacheData {
46 Buffer lhs;
47 Buffer rhs;
48 Buffer bias;
49 };
50
51 template <>
52 11 CacheData ReferenceGenerator<CacheDataId, CacheData>::generate_reference(const CacheDataId& k) {
53 11 MatMulShape shape = std::get<0>(k);
54 11 DataFormat lhs_format = std::get<1>(k);
55 11 DataFormat rhs_format = std::get<2>(k);
56 11 DataFormat bias_format = std::get<3>(k);
57
58 static size_t seed = 1;
59 11 Buffer lhs = fill_matrix_random(shape.m, shape.k, lhs_format, seed++);
60
1/2
✓ Branch 0 taken 11 times.
✗ Branch 1 not taken.
11 Buffer rhs = fill_matrix_random(shape.k, shape.n, rhs_format, seed++);
61
1/2
✓ Branch 0 taken 11 times.
✗ Branch 1 not taken.
11 Buffer bias = fill_matrix_random(1, shape.n, bias_format, seed++);
62
63 11 CacheData test_reference;
64 11 test_reference.lhs = std::move(lhs);
65 11 test_reference.rhs = std::move(rhs);
66 11 test_reference.bias = std::move(bias);
67
68 11 return test_reference;
69 11 }
70
71 1 static const std::array<UkernelVariant<kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel>, 8>
72
0/2
✗ Branch 0 not taken.
✗ Branch 1 not taken.
9 variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp = {{
73
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod),
74
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", cpu_has_dotprod},
75
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod),
76
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", cpu_has_dotprod},
77
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod),
78
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", cpu_has_dotprod},
79
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm),
80
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", cpu_has_i8mm},
81
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot),
82
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot", cpu_has_sme},
83
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa),
84
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa", cpu_has_sme},
85
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot),
86
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot", cpu_has_sme2},
87
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa),
88
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa", cpu_has_sme2},
89 }};
90
91 class MatMulTest_f32_qai8dxp_qsi8cxp : public ::testing::TestWithParam<MatMulTestPortionedParams> {};
92
93
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1850 TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, Offset_RHS) {
94 1848 const auto& [variant_index, matmul_shape, portion] = GetParam();
95 1232 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index);
96
97
2/4
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
616 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
98 GTEST_SKIP() << "Unsupported CPU feature";
99 }
100
101 1232 const size_t K = matmul_shape.k;
102 616 const auto nr = ukernel_variant.interface.get_nr();
103 616 const auto kr = ukernel_variant.interface.get_kr();
104 616 const auto sr = ukernel_variant.interface.get_sr();
105
106 616 auto n_step = ukernel_variant.interface.get_n_step();
107
108 616 auto rhs_packed_offset_kxn = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(n_step, K, nr, kr, sr);
109 616 auto rhs_packed_offset_nxk = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(n_step, K, nr, kr, sr);
110
111
3/14
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 616 times.
616 ASSERT_EQ(rhs_packed_offset_kxn, rhs_packed_offset_nxk);
112
113 616 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(n_step, K);
114
3/14
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 616 times.
616 ASSERT_EQ(rhs_packed_offset_kxn, rhs_matmul_offset);
115 616 }
116
117
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1850 TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, Offset_LHS) {
118 1848 const auto& [variant_index, matmul_shape, portion] = GetParam();
119 1232 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index);
120
121
2/4
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
616 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
122 GTEST_SKIP() << "Unsupported CPU feature";
123 }
124
125 1232 const size_t K = matmul_shape.k;
126 616 const auto mr = ukernel_variant.interface.get_mr();
127 616 const auto kr = ukernel_variant.interface.get_kr();
128 616 const auto sr = ukernel_variant.interface.get_sr();
129
130 616 auto m_step = ukernel_variant.interface.get_m_step();
131
132 616 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(m_step, K, mr, kr, sr);
133 616 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(m_step, K);
134
135
3/14
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 616 times.
616 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
136 616 }
137
138
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1850 TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_nxk_qsi8cx) {
139 4312 auto& [variant_index, matmul_shape, portion] = GetParam();
140 1232 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index);
141
142
2/4
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
616 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
143 GTEST_SKIP() << "Unsupported CPU feature";
144 }
145
146 1232 const size_t M = matmul_shape.m;
147 1232 const size_t N = matmul_shape.n;
148 1232 const size_t K = matmul_shape.k;
149
150 616 const auto mr = ukernel_variant.interface.get_mr();
151 616 const auto nr = ukernel_variant.interface.get_nr();
152 616 const auto kr = ukernel_variant.interface.get_kr();
153 616 const auto sr = ukernel_variant.interface.get_sr();
154
155 // Generates input data.
156 1232 const CacheDataId id = {
157 matmul_shape, //
158 616 DataFormat(DataType::FP32), //
159 616 DataFormat(DataType::FP32), //
160 616 DataFormat(DataType::FP32)};
161 616 const CacheData& test_data = getV<CacheDataId, CacheData>(id);
162
163 // Runs the reference implementation.
164 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
165 // * Quantizes the RHS matrix using 8-bit symmetric quantization.
166 // * Performs GEMM.
167 1848 const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
168 616 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(test_data.lhs.data(), M, K, K);
169 3000 const auto [ref_rhs_qsi8, ref_rhs_scales] =
170
2/4
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
616 quantize_symmetric_per_block_dynamic<float, int8_t, float>(test_data.rhs.data(), N, K, K);
171
172
1/2
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
1232 const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>(
173
6/12
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 616 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 616 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 616 times.
✗ Branch 11 not taken.
1232 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi8.data(),
174
2/4
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
616 ref_rhs_scales.data(), nullptr, K, test_data.bias.data(), std::numeric_limits<float>::lowest(),
175 616 std::numeric_limits<float>::max());
176
177
1/2
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
616 auto m_step = ukernel_variant.interface.get_m_step();
178
4/16
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 616 times.
616 ASSERT_TRUE(m_step % mr == 0);
179
180
1/2
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
616 auto n_step = ukernel_variant.interface.get_n_step();
181
4/16
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 616 times.
616 ASSERT_TRUE(n_step % nr == 0);
182
183
2/4
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
1232 const auto rect = portion.compute_portion(M, N, m_step, n_step);
184
5/8
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✓ Branch 3 taken 40 times.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 576 times.
616 if (rect.height() == 0 || rect.width() == 0) {
185
10/20
✓ Branch 0 taken 40 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 40 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 40 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 40 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 40 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 40 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 40 times.
✗ Branch 19 not taken.
40 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
186 }
187
188 // Runs the LHS packing micro-kernel.
189
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
190
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 Buffer imp_packed_lhs(imp_packed_lhs_size);
191
192
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 const auto lhs_start_row = rect.start_row();
193 576 size_t lhs_stride = K * sizeof(float);
194
195
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
196
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
197
198
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 kai_run_lhs_quant_pack_qai8dxp_f32(
199
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 rect.height(), K, mr, kr, sr, 0, reinterpret_cast<const float*>(test_data.lhs.data() + lhs_offset), lhs_stride,
200
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 imp_packed_lhs.data() + lhs_packed_offset);
201
202 // Runs the RHS packing micro-kernel.
203 // * Generates the 8-bit signed symmetric quantized input for the micro-kernel.
204 // * Packs the RHS matrix.
205
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr);
206
207
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 Buffer imp_packed_rhs(imp_packed_rhs_size);
208 576 const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f};
209
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(
210
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
1152 1, N, K, nr, kr, sr, reinterpret_cast<const int8_t*>(ref_rhs_qsi8.data()),
211
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 reinterpret_cast<const float*>(test_data.bias.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()),
212
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 imp_packed_rhs.data(), 0, &params);
213
214
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 const auto packed_rhs_start_row = rect.start_col();
215 576 auto rhs_packed_offset =
216
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(packed_rhs_start_row, K, nr, kr, sr);
217
218 576 const auto dst_stride = N * sizeof(float);
219
3/6
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
576 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
220
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
221
4/16
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
576 ASSERT_EQ(dst_offset, ref_dst_offset);
222
223
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 const auto matmul_lhs_packed_offset = ukernel_variant.interface.get_lhs_packed_offset(rect.start_row(), K);
224
4/16
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
576 ASSERT_EQ(lhs_packed_offset, matmul_lhs_packed_offset);
225
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 const auto matmul_rhs_packed_offset = ukernel_variant.interface.get_rhs_packed_offset(rect.start_col(), K);
226
4/16
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
576 ASSERT_EQ(rhs_packed_offset, matmul_rhs_packed_offset);
227
228 // Runs the GEMM micro-kernel.
229
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
230
5/18
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 576 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 576 times.
576 ASSERT_EQ(imp_dst_size, ref_dst.size());
231
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 Buffer imp_dst(imp_dst_size);
232
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
1152 ukernel_variant.interface.run_matmul(
233
3/6
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
576 rect.height(), rect.width(), K, imp_packed_lhs.data() + matmul_lhs_packed_offset,
234
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 imp_packed_rhs.data() + matmul_rhs_packed_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
235 576 N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
236
237 // Compares the output of the micro-kernels against the output of the reference implementation.
238
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 5812 times.
✓ Branch 2 taken 5236 times.
✓ Branch 3 taken 576 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 576 times.
5812 for (size_t y = 0; y < rect.height(); ++y) {
239
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 248408 times.
✓ Branch 2 taken 243172 times.
✓ Branch 3 taken 5236 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5236 times.
248408 for (size_t x = 0; x < rect.width(); ++x) {
240 486344 const auto imp_value =
241
4/8
✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 243172 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 243172 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 243172 times.
✗ Branch 7 not taken.
243172 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
242 486344 const auto ref_value =
243
4/8
✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 243172 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 243172 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 243172 times.
✗ Branch 7 not taken.
243172 read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
244
1/2
✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
243172 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
245
246
1/2
✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
243172 if (rel_error > 0.0001F) {
247 ASSERT_EQ(imp_value, ref_value);
248 }
249 243172 }
250 5236 }
251 616 }
252
253
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
1850 TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_kxn_qsi8cx) {
254 4312 auto& [variant_index, matmul_shape, portion] = GetParam();
255 1232 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index);
256
257
2/4
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
616 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
258 GTEST_SKIP() << "Unsupported CPU feature";
259 }
260
261 1232 const size_t M = matmul_shape.m;
262 1232 const size_t N = matmul_shape.n;
263 1232 const size_t K = matmul_shape.k;
264
265 616 const auto mr = ukernel_variant.interface.get_mr();
266 616 const auto nr = ukernel_variant.interface.get_nr();
267 616 const auto kr = ukernel_variant.interface.get_kr();
268 616 const auto sr = ukernel_variant.interface.get_sr();
269
270 // Generates input data.
271 1232 const CacheDataId id = {
272 matmul_shape, //
273 616 DataFormat(DataType::FP32), //
274 616 DataFormat(DataType::FP32), //
275 616 DataFormat(DataType::FP32)};
276 616 const CacheData& test_data = getV<CacheDataId, CacheData>(id);
277
278 // Transposed(nxk) RHS dimensions
279 616 const size_t ref_rhs_qsi8_nxk_stride = K;
280
281 // Non-Transposed(kxn) RHS dimensions
282 616 const size_t ref_rhs_qsi8_kxn_stride = N;
283 616 const size_t ref_rhs_qsi8_kxn_size_bytes = K * ref_rhs_qsi8_kxn_stride;
284
285 // Runs the reference implementation.
286 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
287 // * Quantizes the RHS matrix using 8-bit symmetric quantization.
288 // * Performs GEMM.
289 1848 const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
290 616 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(test_data.lhs.data(), M, K, K);
291 3040 const auto [ref_rhs_qsi8_transposed, ref_rhs_scales] =
292
2/4
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
616 quantize_symmetric_per_block_dynamic<float, int8_t, float>(test_data.rhs.data(), N, K, K);
293
294
1/2
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
616 const auto ref_rhs_qsi8 = transpose_with_padding<int8_t>(
295
1/2
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
616 ref_rhs_qsi8_transposed.data(), N, K, ref_rhs_qsi8_nxk_stride, ref_rhs_qsi8_kxn_stride,
296 616 ref_rhs_qsi8_kxn_size_bytes);
297
298
1/2
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
1232 const auto ref_dst = matmul_clamp_nt_nt<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>(
299
5/10
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 616 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 616 times.
✗ Branch 9 not taken.
1232 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi8.data(),
300
2/4
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
616 ref_rhs_scales.data(), nullptr, K, test_data.bias.data(), std::numeric_limits<float>::lowest(),
301 616 std::numeric_limits<float>::max());
302
303
1/2
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
616 auto m_step = ukernel_variant.interface.get_m_step();
304
4/16
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 616 times.
616 ASSERT_TRUE(m_step % mr == 0);
305
306
1/2
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
616 auto n_step = ukernel_variant.interface.get_n_step();
307
4/16
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 616 times.
616 ASSERT_TRUE(n_step % nr == 0);
308
309
2/4
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
1232 const auto rect = portion.compute_portion(M, N, m_step, n_step);
310
5/8
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✓ Branch 3 taken 40 times.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 576 times.
616 if (rect.height() == 0 || rect.width() == 0) {
311
10/20
✓ Branch 0 taken 40 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 40 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 40 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 40 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 40 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 40 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 40 times.
✗ Branch 19 not taken.
40 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
312 }
313
314
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 const auto lhs_start_row = rect.start_row();
315 576 size_t const lhs_stride = K * sizeof(float);
316
317
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
318
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
319
320 // Runs the LHS packing micro-kernel.
321
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
322
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 Buffer imp_packed_lhs(imp_packed_lhs_size);
323
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 kai_run_lhs_quant_pack_qai8dxp_f32(
324
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 rect.height(), K, mr, kr, sr, 0, reinterpret_cast<const float*>(test_data.lhs.data() + lhs_offset),
325
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 K * sizeof(float), imp_packed_lhs.data() + lhs_packed_offset);
326
327 // Runs the RHS packing micro-kernel.
328 // * Generates the 8-bit signed symmetric quantized input for the micro-kernel.
329 // * Packs the RHS matrix.
330
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr);
331
332
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 Buffer imp_packed_rhs(imp_packed_rhs_size);
333 576 const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f};
334
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(
335
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 1, N, K, nr, kr, sr, reinterpret_cast<const int8_t*>(ref_rhs_qsi8.data()),
336
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 reinterpret_cast<const float*>(test_data.bias.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()),
337
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 imp_packed_rhs.data(), 0, &params);
338
339
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 const auto packed_rhs_start_row = rect.start_col();
340 576 auto rhs_packed_offset =
341
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(packed_rhs_start_row, K, nr, kr, sr);
342
343 576 const auto dst_stride = N * sizeof(float);
344
3/6
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
576 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
345
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
346
4/16
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
576 ASSERT_EQ(dst_offset, ref_dst_offset);
347
348
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 const auto matmul_lhs_packed_offset = ukernel_variant.interface.get_lhs_packed_offset(rect.start_row(), K);
349
4/16
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
576 ASSERT_EQ(lhs_packed_offset, matmul_lhs_packed_offset);
350
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 const auto matmul_rhs_packed_offset = ukernel_variant.interface.get_rhs_packed_offset(rect.start_col(), K);
351
4/16
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
576 ASSERT_EQ(rhs_packed_offset, matmul_rhs_packed_offset);
352
353 // Runs the GEMM micro-kernel.
354
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
355
5/18
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 576 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 576 times.
576 ASSERT_EQ(imp_dst_size, ref_dst.size());
356
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
576 Buffer imp_dst(imp_dst_size);
357
1/2
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
1152 ukernel_variant.interface.run_matmul(
358
3/6
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
576 rect.height(), rect.width(), K, imp_packed_lhs.data() + matmul_lhs_packed_offset,
359
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 imp_packed_rhs.data() + matmul_rhs_packed_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
360 576 N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
361
362 // Compares the output of the micro-kernels against the output of the reference implementation.
363
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 5812 times.
✓ Branch 2 taken 5236 times.
✓ Branch 3 taken 576 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 576 times.
5812 for (size_t y = 0; y < rect.height(); ++y) {
364
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 248408 times.
✓ Branch 2 taken 243172 times.
✓ Branch 3 taken 5236 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5236 times.
248408 for (size_t x = 0; x < rect.width(); ++x) {
365 486344 const auto imp_value =
366
4/8
✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 243172 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 243172 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 243172 times.
✗ Branch 7 not taken.
243172 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
367 486344 const auto ref_value =
368
4/8
✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 243172 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 243172 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 243172 times.
✗ Branch 7 not taken.
243172 read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
369
1/2
✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
243172 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
370
371
1/2
✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
243172 if (rel_error > 0.0001F) {
372 ASSERT_EQ(imp_value, ref_value);
373 }
374 243172 }
375 5236 }
376 616 }
377
378
18/58
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 4 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 4 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 2464 times.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✓ Branch 56 taken 2464 times.
✗ Branch 57 not taken.
4933 INSTANTIATE_TEST_SUITE_P(
379 MatMul, MatMulTest_f32_qai8dxp_qsi8cxp,
380 testing::Combine(
381 testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.size()),
382 testing::Values(
383 MatMulShape{17, 33, 67}, //
384 MatMulShape{19, 35, 63}, //
385 MatMulShape{1, 27, 31}, //
386 MatMulShape{1, 65, 35}, //
387 MatMulShape{1, 64, 65}, //
388 MatMulShape{1, 63, 15}, //
389 MatMulShape{1, 130, 15}, //
390 MatMulShape{15, 65, 35}, //
391 MatMulShape{16, 64, 65}, //
392 MatMulShape{17, 63, 15}, //
393 MatMulShape{20, 130, 15}),
394 testing::Values(
395 MatrixPortion(0, 0, 1, 1), // Full matrix.
396 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
397 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
398 MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle
399 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
400 MatrixPortion(0.75, 0, 1, 1), // Partial rows
401 MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle
402 )),
403 [](const auto& info) {
404 const auto variant_idx = std::get<0>(info.param);
405 const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_idx).name};
406 const auto shape = std::get<MatMulShape>(info.param);
407 const auto portion = std::get<MatrixPortion>(info.param);
408
409 return test_description(name, shape, portion, true);
410 });
411
412 } // namespace kai::test
413