KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 96.4% 478 0 496
Functions: 97.1% 34 0 35
Branches: 38.3% 674 0 1762

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <functional>
14 #include <limits>
15 #include <random>
16 #include <sstream>
17 #include <string>
18 #include <string_view>
19
20 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h"
21 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h"
22 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h"
23 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"
24 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h"
25 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod.h"
26 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h"
27 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h"
28 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h"
29 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h"
30 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h"
31 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h"
32 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
33 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h"
34 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"
35 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h"
36 #include "test/common/buffer.hpp"
37 #include "test/common/compare.hpp"
38 #include "test/common/cpu_info.hpp"
39 #include "test/common/int4.hpp"
40 #include "test/common/matmul_test_common.hpp"
41 #include "test/common/matrix_portion.hpp"
42 #include "test/common/memory.hpp"
43 #include "test/common/round.hpp"
44 #include "test/common/test_suite.hpp"
45 #include "test/reference/cast.hpp"
46 #include "test/reference/fill.hpp"
47 #include "test/reference/matmul.hpp"
48 #include "test/reference/pad.hpp"
49 #include "test/reference/quantize.hpp"
50 #include "test/reference/transpose.hpp"
51
52 namespace kai::test {
53 /// Matrix multiplication test information.
54
55 enum class RhsPackType { NxK, KxN };
56
57 using ukernel_rhs_pack_function = std::function<decltype(kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>;
58 using ukernel_get_rhs_packed_size = std::function<decltype(kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>;
59 using ukernel_get_rhs_packed_offset = std::function<decltype(kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>;
60 using ukernel_get_rhs_offset = std::function<decltype(kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon)>;
61
62 template <typename T>
63 struct UkernelVariantCustom : public UkernelVariant<T> {
64 ukernel_rhs_pack_function run_rhs_pack;
65 ukernel_get_rhs_packed_size get_rhs_packed_size;
66 ukernel_get_rhs_packed_offset get_rhs_packed_offset;
67 ukernel_get_rhs_offset get_rhs_offset;
68 RhsPackType rhs_pack_type;
69
70 UkernelVariantCustom() = delete;
71
72 40 UkernelVariantCustom(
73 T interface, std::string_view name, const std::function<bool(void)>& fn_is_supported,
74 ukernel_rhs_pack_function run_rhs_pack, ukernel_get_rhs_packed_size get_rhs_packed_size,
75 ukernel_get_rhs_packed_offset get_rhs_packed_offset, ukernel_get_rhs_offset get_rhs_offset,
76 const RhsPackType pack_type) :
77 20 UkernelVariant<T>(interface, name, fn_is_supported),
78 20 run_rhs_pack(std::move(run_rhs_pack)),
79 20 get_rhs_packed_size(std::move(get_rhs_packed_size)),
80 20 get_rhs_packed_offset(std::move(get_rhs_packed_offset)),
81 20 get_rhs_offset(std::move(get_rhs_offset)),
82 40 rhs_pack_type(pack_type) {
83 40 }
84 };
85
86 1 static const std::array<UkernelVariantCustom<kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel>, 20>
87
0/2
✗ Branch 0 not taken.
✗ Branch 1 not taken.
1 variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp = {
88
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
21 {{UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa),
89
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa__RHS_NxK__", cpu_has_sme2,
90
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
91
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
92
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
93
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, RhsPackType::NxK},
94
95
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot),
96
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot__RHS_NxK__", cpu_has_sme2,
97
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
98
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
99
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
100
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, RhsPackType::NxK},
101
102
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod),
103
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod__RHS_NxK__", cpu_has_dotprod,
104
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
105
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
106 RhsPackType::NxK},
107
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod),
108
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod__RHS_KxN__", cpu_has_dotprod,
109
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
110
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
111 RhsPackType::KxN},
112
113
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod),
114
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod,
115
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
116
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
117 RhsPackType::NxK},
118
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod),
119
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod,
120
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
121
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
122 RhsPackType::KxN},
123
124
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod),
125
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod,
126
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
127
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
128 RhsPackType::NxK},
129
130
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod),
131
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod,
132
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
133
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
134 RhsPackType::KxN},
135
136
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod),
137
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod,
138
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
139
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
140 RhsPackType::NxK},
141
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod),
142
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod,
143
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
144
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
145 RhsPackType::KxN},
146
147
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod),
148
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod,
149
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
150
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
151 RhsPackType::NxK},
152
153
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod),
154
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod,
155
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
156
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
157 RhsPackType::KxN},
158
159
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm),
160
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm,
161
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
162
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
163 RhsPackType::NxK},
164
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm),
165
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm,
166
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
167
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
168 RhsPackType::KxN},
169
170
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm),
171
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm,
172
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
173
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
174 RhsPackType::NxK},
175
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm),
176
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm,
177
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
178
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
179 RhsPackType::KxN},
180
181
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm),
182
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm,
183
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
184
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
185 RhsPackType::NxK},
186
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm),
187
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm,
188
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
189
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
190 RhsPackType::KxN},
191
192
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm),
193
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm,
194
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
195
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
196 RhsPackType::NxK},
197
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
2 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm),
198
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm,
199
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
200
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
1 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
201 RhsPackType::KxN}}
202
203 };
204
205 class MatMulTest_f32_qai8dxp_qsi4cxp : public ::testing::TestWithParam<MatMulTestPortionedParams> {};
206
207
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
2402 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, Offset_RHS) {
208 4800 const auto& [variant_index, matmul_shape, portion] = GetParam();
209 1600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
210
211
2/4
✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
212 GTEST_SKIP() << "Unsupported CPU feature";
213 }
214
215 1600 const size_t M = matmul_shape.m;
216 1600 const size_t N = matmul_shape.n;
217 1600 const size_t K = matmul_shape.k;
218
219 800 auto m_step = ukernel_variant.interface.get_m_step();
220 800 auto n_step = ukernel_variant.interface.get_n_step();
221
222 1600 const auto rect = portion.compute_portion(M, N, m_step, n_step);
223
2/4
✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 800 times.
800 if (rect.height() == 0 || rect.width() == 0) {
224 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
225 }
226
227 800 const auto nr = ukernel_variant.interface.get_nr();
228 800 const auto kr = ukernel_variant.interface.get_kr();
229 800 const auto sr = ukernel_variant.interface.get_sr();
230
231 800 const auto rhs_start_row = rect.start_col();
232 800 auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr);
233 800 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
234
3/14
✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 800 times.
800 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
235 800 }
236
237
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
2402 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, Offset_LHS) {
238 4800 const auto& [variant_index, matmul_shape, portion] = GetParam();
239 1600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
240
241
2/4
✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
242 GTEST_SKIP() << "Unsupported CPU feature";
243 }
244
245 1600 const size_t M = matmul_shape.m;
246 1600 const size_t N = matmul_shape.n;
247 1600 const size_t K = matmul_shape.k;
248
249 800 auto m_step = ukernel_variant.interface.get_m_step();
250 800 auto n_step = ukernel_variant.interface.get_n_step();
251
252 1600 const auto rect = portion.compute_portion(M, N, m_step, n_step);
253
2/4
✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 800 times.
800 if (rect.height() == 0 || rect.width() == 0) {
254 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
255 }
256
257 800 const auto mr = ukernel_variant.interface.get_mr();
258 800 const auto kr = ukernel_variant.interface.get_kr();
259 800 const auto sr = ukernel_variant.interface.get_sr();
260
261 800 const auto lhs_start_row = rect.start_row();
262 800 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
263 800 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
264
265
3/14
✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 800 times.
800 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
266 800 }
267
268
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
2402 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsi4cx) {
269 3360 const auto& [variant_index, matmul_shape, portion] = GetParam();
270 1600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
271
272
2/4
✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
273 GTEST_SKIP() << "Unsupported CPU feature";
274 }
275
2/2
✓ Branch 0 taken 360 times.
✓ Branch 1 taken 440 times.
800 if (ukernel_variant.rhs_pack_type == RhsPackType::KxN) {
276
3/6
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
360 GTEST_SKIP() << "Wrong type. This test for NxK";
277 }
278
279 440 const uint32_t seed = 0;
280
281 880 const size_t M = matmul_shape.m;
282 880 const size_t N = matmul_shape.n;
283 880 const size_t K = matmul_shape.k;
284
285 440 const auto mr = ukernel_variant.interface.get_mr();
286 440 const auto nr = ukernel_variant.interface.get_nr();
287 440 const auto kr = ukernel_variant.interface.get_kr();
288 440 const auto sr = ukernel_variant.interface.get_sr();
289
290 // Generates input data.
291 440 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
292
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto ref_biases = fill_random<float>(N, seed + 2);
293
294
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 std::uniform_real_distribution<float> dist(-10.0, 1.0);
295
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 std::mt19937 rnd(seed + 1);
296
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
1890680 const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); });
297
298 // Runs the reference implementation.
299 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
300 // * Quantizes the RHS matrix using 4-bit symmetric quantization.
301 // * Performs GEMM.
302 1320 const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
303
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K);
304 2200 const auto [ref_rhs_qsi4, ref_rhs_scales] =
305
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K);
306
307
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
880 const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
308
6/12
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 440 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 440 times.
✗ Branch 11 not taken.
880 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
309
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(),
310 440 std::numeric_limits<float>::max());
311
312
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto m_step = ukernel_variant.interface.get_m_step();
313
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
440 ASSERT_TRUE(m_step % mr == 0);
314
315
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto n_step = ukernel_variant.interface.get_n_step();
316
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
440 ASSERT_TRUE(n_step % nr == 0);
317
318
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
880 const auto rect = portion.compute_portion(M, N, m_step, n_step);
319
4/8
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 440 times.
440 if (rect.height() == 0 || rect.width() == 0) {
320 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
321 }
322
323
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto lhs_start_row = rect.start_row();
324 440 size_t lhs_stride = K * sizeof(float);
325
326 // Runs the LHS packing micro-kernel.
327
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
328
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 Buffer imp_packed_lhs(imp_packed_lhs_size);
329
330
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
331
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
332
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
333
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 440 times.
✗ Branch 15 not taken.
440 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
334
335
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 kai_run_lhs_quant_pack_qai8dxp_f32(
336
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/,
337
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
338
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 imp_packed_lhs.data() + lhs_packed_offset);
339
340 // Runs the RHS packing micro-kernel.
341 // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
342 // * Packs the RHS matrix.
343
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
880 const auto ref_rhs_qsi4_padded = pad_row<Int4>(
344
4/8
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
440 ref_rhs_qsi4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2));
345
346
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
347
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto rhs_start_row = rect.start_col();
348
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr);
349
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
350
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
440 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
351
352
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 auto rhs_offset = ukernel_variant.get_rhs_offset(rhs_start_row, round_up_division(K, 2));
353 440 size_t bias_offset = rhs_start_row * sizeof(float);
354 440 size_t scale_offset = rhs_start_row * sizeof(float);
355
356
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 Buffer imp_packed_rhs(imp_packed_rhs_size);
357 440 kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{};
358 440 params.lhs_zero_point = 1;
359 440 params.rhs_zero_point = 0;
360
361
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
880 ukernel_variant.run_rhs_pack(
362
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 1, rect.width() /* n */, K, nr, kr, sr,
363
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data() + rhs_offset),
364
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 reinterpret_cast<const float*>(ref_biases.data() + bias_offset),
365
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 reinterpret_cast<const float*>(ref_rhs_scales.data() + scale_offset), imp_packed_rhs.data() + rhs_packed_offset,
366 0, &params);
367
368 440 const auto dst_stride = N * sizeof(float);
369
3/6
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
440 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
370
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
371
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
440 ASSERT_EQ(dst_offset, ref_dst_offset);
372
373 // Runs the GEMM micro-kernel.
374
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
375
5/18
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 440 times.
440 ASSERT_EQ(imp_dst_size, ref_dst.size());
376
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 Buffer imp_dst(imp_dst_size);
377
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
880 ukernel_variant.interface.run_matmul(
378
3/6
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
440 rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
379
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
380 440 N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
381
382 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
383 // tested.
384
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 11968 times.
✓ Branch 2 taken 11528 times.
✓ Branch 3 taken 440 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 440 times.
11968 for (size_t y = 0; y < rect.height(); ++y) {
385
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 848380 times.
✓ Branch 2 taken 836852 times.
✓ Branch 3 taken 11528 times.
✓ Branch 4 taken 11528 times.
✗ Branch 5 not taken.
848380 for (size_t x = 0; x < rect.width(); ++x) {
386 1673704 const auto imp_value =
387
4/8
✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 836852 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 836852 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 836852 times.
✗ Branch 7 not taken.
836852 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
388 1673704 const auto ref_value =
389
4/8
✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 836852 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 836852 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 836852 times.
✗ Branch 7 not taken.
836852 read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
390
1/2
✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
836852 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value;
391
392
1/2
✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
836852 if (rel_error > 0.0001F) {
393 ASSERT_EQ(imp_value, ref_value);
394 }
395 836852 }
396 11528 }
397 800 }
398
399
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
2402 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) {
400 3360 const auto& [variant_index, matmul_shape, portion] = GetParam();
401 1600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
402
403
2/4
✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
404 GTEST_SKIP() << "Unsupported CPU feature";
405 }
406
2/2
✓ Branch 0 taken 360 times.
✓ Branch 1 taken 440 times.
800 if (ukernel_variant.rhs_pack_type == RhsPackType::KxN) {
407
3/6
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
360 GTEST_SKIP() << "Wrong type. This test for NxK";
408 }
409
410 440 const uint32_t seed = 0;
411
412 880 const size_t M = matmul_shape.m;
413 880 const size_t N = matmul_shape.n;
414 880 const size_t K = matmul_shape.k;
415
416 440 const auto mr = ukernel_variant.interface.get_mr();
417 440 const auto nr = ukernel_variant.interface.get_nr();
418 440 const auto kr = ukernel_variant.interface.get_kr();
419 440 const auto sr = ukernel_variant.interface.get_sr();
420
421 // Generates input data.
422 440 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
423
424
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 std::uniform_real_distribution<float> dist(-10.0, 1.0);
425
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 std::mt19937 rnd(seed + 1);
426
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
1890680 const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); });
427
428
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto ref_biases = fill_random<float>(N, seed + 2);
429
430 // Runs the reference implementation.
431 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
432 // * Quantizes the RHS matrix using 4-bit symmetric quantization.
433 // * Performs GEMM.
434 1320 const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
435
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K);
436 2200 const auto [ref_rhs_qsi4, ref_rhs_scales] =
437
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K);
438
439
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
880 const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
440
6/12
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 440 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 440 times.
✗ Branch 11 not taken.
880 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
441
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(),
442 440 std::numeric_limits<float>::max());
443
444
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto m_step = ukernel_variant.interface.get_m_step();
445
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
440 ASSERT_TRUE(m_step % mr == 0);
446
447
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto n_step = ukernel_variant.interface.get_n_step();
448
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
440 ASSERT_TRUE(n_step % nr == 0);
449
450
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
880 const auto rect = portion.compute_portion(M, N, m_step, n_step);
451
4/8
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 440 times.
440 if (rect.height() == 0 || rect.width() == 0) {
452 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
453 }
454
455
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto lhs_start_row = rect.start_row();
456 440 size_t lhs_stride = K * sizeof(float);
457
458 // Runs the LHS packing micro-kernel.
459
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
460
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 Buffer imp_packed_lhs(imp_packed_lhs_size);
461
462
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
463
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
464
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
465
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 440 times.
✗ Branch 15 not taken.
440 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
466
467
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 kai_run_lhs_quant_pack_qai8dxp_f32(
468
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/,
469
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
470
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 imp_packed_lhs.data() + lhs_packed_offset);
471
472
3/6
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
880 const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K);
473 // Runs the RHS packing micro-kernel.
474 // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
475 // * Packs the RHS matrix.
476
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
880 const auto ref_rhs_qsu4_padded = pad_row<UInt4>(
477
4/8
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
440 ref_rhs_qsu4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2));
478
479
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
480
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto rhs_start_row = rect.start_col();
481
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr);
482
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
483
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
440 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
484
485
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 auto rhs_offset = ukernel_variant.get_rhs_offset(rhs_start_row, round_up_division(K, 2));
486 440 size_t bias_offset = rhs_start_row * sizeof(float);
487 440 size_t scale_offset = rhs_start_row * sizeof(float);
488
489
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 Buffer imp_packed_rhs(imp_packed_rhs_size);
490 440 kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{};
491 440 params.lhs_zero_point = 1;
492 440 params.rhs_zero_point = 8;
493
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
880 ukernel_variant.run_rhs_pack(
494
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 1, rect.width() /* n */, K, nr, kr, sr,
495
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_padded.data() + rhs_offset),
496
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 reinterpret_cast<const float*>(ref_biases.data() + bias_offset),
497
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 reinterpret_cast<const float*>(ref_rhs_scales.data() + scale_offset), imp_packed_rhs.data() + rhs_packed_offset,
498 0, &params);
499
500 440 const auto dst_stride = N * sizeof(float);
501
3/6
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
440 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
502
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
503
4/16
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
440 ASSERT_EQ(dst_offset, ref_dst_offset);
504 // Runs the GEMM micro-kernel.
505
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
506
5/18
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 440 times.
440 ASSERT_EQ(imp_dst_size, ref_dst.size());
507
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
440 Buffer imp_dst(imp_dst_size);
508
1/2
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
880 ukernel_variant.interface.run_matmul(
509
3/6
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
440 rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
510
2/4
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
440 imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
511 440 N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
512
513 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
514 // tested.
515
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 11968 times.
✓ Branch 2 taken 11528 times.
✓ Branch 3 taken 440 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 440 times.
11968 for (size_t y = 0; y < rect.height(); ++y) {
516
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 848380 times.
✓ Branch 2 taken 836852 times.
✓ Branch 3 taken 11528 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 11528 times.
848380 for (size_t x = 0; x < rect.width(); ++x) {
517 1673704 const auto imp_value =
518
4/8
✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 836852 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 836852 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 836852 times.
✗ Branch 7 not taken.
836852 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
519 1673704 const auto ref_value =
520
4/8
✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 836852 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 836852 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 836852 times.
✗ Branch 7 not taken.
836852 read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
521
1/2
✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
836852 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value;
522
523
1/2
✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
836852 if (rel_error > 0.0001F) {
524 ASSERT_EQ(imp_value, ref_value);
525 }
526 836852 }
527 11528 }
528 800 }
529
530
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
2402 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) {
531 3040 const auto& [variant_index, matmul_shape, portion] = GetParam();
532 1600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
533
534
2/4
✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
535 GTEST_SKIP() << "Unsupported CPU feature";
536 }
537
2/2
✓ Branch 0 taken 360 times.
✓ Branch 1 taken 440 times.
800 if (ukernel_variant.rhs_pack_type == RhsPackType::NxK) {
538
3/6
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
440 GTEST_SKIP() << "Wrong type. This test for KxN";
539 }
540
541 360 const uint32_t seed = 0;
542
543 720 const size_t M = matmul_shape.m;
544 720 const size_t N = matmul_shape.n;
545 720 const size_t K = matmul_shape.k;
546
547 360 const auto mr = ukernel_variant.interface.get_mr();
548 360 const auto nr = ukernel_variant.interface.get_nr();
549 360 const auto kr = ukernel_variant.interface.get_kr();
550 360 const auto sr = ukernel_variant.interface.get_sr();
551
552 // Generates input data.
553 360 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
554
555
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 std::uniform_real_distribution<float> dist(-10.0, 1.0);
556
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 std::mt19937 rnd(seed + 1);
557
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
1546920 const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); });
558
559
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto ref_biases = fill_random<float>(N, seed + 2);
560
561 // Transposed(nxk) RHS dimensions
562 360 const size_t ref_rhs_qsi4_nxk_stride = K;
563
564 // Non-Transposed(kxn) RHS dimensions
565
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2);
566
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(K * ref_rhs_qsi4_kxn_stride, 2);
567
568 // Runs the reference implementation.
569 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
570 // * Quantizes the RHS matrix using 4-bit symmetric quantization.
571 // * Performs GEMM.
572 1080 const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
573
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K);
574 1800 const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] =
575
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K);
576
577
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto ref_rhs_qsi4 = transpose_with_padding<Int4>(
578
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride,
579 360 ref_rhs_qsi4_kxn_size_bytes);
580
581
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
720 const auto ref_dst = matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
582
5/10
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 360 times.
✗ Branch 9 not taken.
720 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
583
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(),
584 360 std::numeric_limits<float>::max());
585
586
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto m_step = ukernel_variant.interface.get_m_step();
587
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
360 ASSERT_TRUE(m_step % mr == 0);
588
589
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto n_step = ukernel_variant.interface.get_n_step();
590
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
360 ASSERT_TRUE(n_step % nr == 0);
591
592
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
720 const auto rect = portion.compute_portion(M, N, m_step, n_step);
593
4/8
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 360 times.
360 if (rect.height() == 0 || rect.width() == 0) {
594 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
595 }
596
597
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto lhs_start_row = rect.start_row();
598 360 size_t lhs_stride = K * sizeof(float);
599
600 // Runs the LHS packing micro-kernel.
601
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
602
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 Buffer imp_packed_lhs(imp_packed_lhs_size);
603
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
604
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
605
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
606
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 360 times.
✗ Branch 15 not taken.
360 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
607
608
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 kai_run_lhs_quant_pack_qai8dxp_f32(
609
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/,
610
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
611
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 imp_packed_lhs.data() + lhs_packed_offset);
612
613 // Runs the RHS packing micro-kernel.
614 // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
615 // * Packs the RHS matrix.
616
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
720 const auto ref_rhs_qsi4_padded = pad_row<Int4>(
617
4/8
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
360 ref_rhs_qsi4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2));
618
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
619
620
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto rhs_start_row = rect.start_col();
621
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr);
622
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
623
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
360 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
624
625
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 Buffer imp_packed_rhs(imp_packed_rhs_size);
626 360 kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{};
627 360 params.lhs_zero_point = 1;
628 360 params.rhs_zero_point = 0;
629
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
720 ukernel_variant.run_rhs_pack(
630
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()),
631
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 reinterpret_cast<const float*>(ref_biases.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()),
632
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 imp_packed_rhs.data(), 0, &params);
633
634 360 const auto dst_stride = N * sizeof(float);
635
3/6
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
360 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
636
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
637
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
360 ASSERT_EQ(dst_offset, ref_dst_offset);
638
639 // Runs the GEMM micro-kernel.
640
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
641
5/18
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 360 times.
360 ASSERT_EQ(imp_dst_size, ref_dst.size());
642
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 Buffer imp_dst(imp_dst_size);
643
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
720 ukernel_variant.interface.run_matmul(
644
3/6
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
360 rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
645
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
646 360 N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
647
648 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
649 // tested.
650
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 9792 times.
✓ Branch 2 taken 9432 times.
✓ Branch 3 taken 360 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 360 times.
9792 for (size_t y = 0; y < rect.height(); ++y) {
651
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 667150 times.
✓ Branch 2 taken 657718 times.
✓ Branch 3 taken 9432 times.
✓ Branch 4 taken 9432 times.
✗ Branch 5 not taken.
667150 for (size_t x = 0; x < rect.width(); ++x) {
652 1315436 const auto imp_value =
653
4/8
✓ Branch 0 taken 657718 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 657718 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 657718 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 657718 times.
✗ Branch 7 not taken.
657718 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
654 1315436 const auto ref_value =
655
4/8
✓ Branch 0 taken 657718 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 657718 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 657718 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 657718 times.
✗ Branch 7 not taken.
657718 read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
656
1/2
✓ Branch 0 taken 657718 times.
✗ Branch 1 not taken.
657718 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value;
657
658
1/2
✓ Branch 0 taken 657718 times.
✗ Branch 1 not taken.
657718 if (rel_error > 0.0001F) {
659 ASSERT_EQ(imp_value, ref_value);
660 }
661 657718 }
662 9432 }
663 800 }
664
665
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
2402 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) {
666 3040 const auto& [variant_index, matmul_shape, portion] = GetParam();
667 1600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
668
669
2/4
✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
670 GTEST_SKIP() << "Unsupported CPU feature";
671 }
672
2/2
✓ Branch 0 taken 360 times.
✓ Branch 1 taken 440 times.
800 if (ukernel_variant.rhs_pack_type == RhsPackType::NxK) {
673
3/6
✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
440 GTEST_SKIP() << "Wrong type. This test for KxN";
674 }
675
676 360 const uint32_t seed = 0;
677
678 720 const size_t M = matmul_shape.m;
679 720 const size_t N = matmul_shape.n;
680 720 const size_t K = matmul_shape.k;
681
682 360 const auto mr = ukernel_variant.interface.get_mr();
683 360 const auto nr = ukernel_variant.interface.get_nr();
684 360 const auto kr = ukernel_variant.interface.get_kr();
685 360 const auto sr = ukernel_variant.interface.get_sr();
686
687 // Generates input data.
688 360 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
689
690
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 std::uniform_real_distribution<float> dist(-10.0, 1.0);
691
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 std::mt19937 rnd(seed + 1);
692
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
1546920 const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); });
693
694
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto ref_biases = fill_random<float>(N, seed + 2);
695
696 // Transposed(nxk) RHS dimensions
697 360 const size_t ref_rhs_qsi4_nxk_stride = K;
698
699 // Non-Transposed(kxn) RHS dimensions
700
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2);
701 360 const size_t ref_rhs_qsi4_kxn_size = K * ref_rhs_qsi4_kxn_stride;
702
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(ref_rhs_qsi4_kxn_size, 2);
703
704 // Runs the reference implementation.
705 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
706 // * Quantizes the RHS matrix using 4-bit symmetric quantization.
707 // * Performs GEMM.
708 1080 const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
709
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K);
710 1800 const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] =
711
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K);
712
713
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto ref_rhs_qsi4 = transpose_with_padding<Int4>(
714
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride,
715 360 ref_rhs_qsi4_kxn_size_bytes);
716
717
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
720 const auto ref_dst = matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
718
5/10
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 360 times.
✗ Branch 9 not taken.
720 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
719
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(),
720 360 std::numeric_limits<float>::max());
721
722
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto m_step = ukernel_variant.interface.get_m_step();
723
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
360 ASSERT_TRUE(m_step % mr == 0);
724
725
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto n_step = ukernel_variant.interface.get_n_step();
726
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
360 ASSERT_TRUE(n_step % nr == 0);
727
728
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
720 const auto rect = portion.compute_portion(M, N, m_step, n_step);
729
4/8
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 360 times.
360 if (rect.height() == 0 || rect.width() == 0) {
730 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
731 }
732
733
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto lhs_start_row = rect.start_row();
734 360 size_t lhs_stride = K * sizeof(float);
735
736 // Runs the LHS packing micro-kernel.
737
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
738
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 Buffer imp_packed_lhs(imp_packed_lhs_size);
739
740
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
741
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
742
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
743
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 360 times.
✗ Branch 15 not taken.
360 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
744
745
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 kai_run_lhs_quant_pack_qai8dxp_f32(
746
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/,
747
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
748
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 imp_packed_lhs.data() + lhs_packed_offset);
749
750 // Runs the RHS packing micro-kernel.
751 // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
752 // * Packs the RHS matrix.
753
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), ref_rhs_qsi4_kxn_size);
754
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
720 const auto ref_rhs_qsu4_padded = pad_row<UInt4>(
755
4/8
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
360 ref_rhs_qsu4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2));
756
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
757
758
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto rhs_start_row = rect.start_col();
759
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr);
760
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
761
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
360 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
762
763
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 Buffer imp_packed_rhs(imp_packed_rhs_size);
764 360 kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{};
765 360 params.lhs_zero_point = 1;
766 360 params.rhs_zero_point = 8;
767
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
720 ukernel_variant.run_rhs_pack(
768
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_padded.data()),
769
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 reinterpret_cast<const float*>(ref_biases.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()),
770
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 imp_packed_rhs.data(), 0, &params);
771
772 360 const auto dst_stride = N * sizeof(float);
773
3/6
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
360 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
774
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
775
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
360 ASSERT_EQ(dst_offset, ref_dst_offset);
776
777 // Runs the GEMM micro-kernel.
778
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
779
5/18
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 360 times.
360 ASSERT_EQ(imp_dst_size, ref_dst.size());
780
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 Buffer imp_dst(imp_dst_size);
781
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
720 ukernel_variant.interface.run_matmul(
782
3/6
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
360 rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
783
2/4
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
360 imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
784 360 N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
785
786 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
787 // tested.
788
1/2
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
360 DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
789 360 DataFormat dst_format = DataFormat(DataType::FP32);
790
3/6
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
360 const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
791
4/16
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
360 ASSERT_TRUE(success);
792 800 }
793
794
15/46
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 6 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 6 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 6 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 4800 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✓ Branch 44 taken 4800 times.
✗ Branch 45 not taken.
9607 INSTANTIATE_TEST_SUITE_P(
795 MatMul, MatMulTest_f32_qai8dxp_qsi4cxp,
796 testing::Combine(
797 testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.size()),
798 testing::Values(
799 MatMulShape{16, 32, 64}, //
800 MatMulShape{16, 32, 36}, //
801 MatMulShape{15, 35, 65}, //
802 MatMulShape{8, 32, 64}, //
803 MatMulShape{15, 31, 45}, //
804 MatMulShape{1, 35, 65}, //
805 MatMulShape{1, 128, 32}, //
806 MatMulShape{64, 128, 32}, //
807 MatMulShape{1, 225, 55}, //
808 MatMulShape{125, 200, 56}),
809 testing::Values(
810 MatrixPortion(0, 0, 1, 1), // Full matrix.
811 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
812 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
813 MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle
814 )),
815 [](const auto& info) {
816 const auto variant_idx = std::get<0>(info.param);
817 const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_idx).name};
818 const auto shape = std::get<MatMulShape>(info.param);
819 const auto portion = std::get<2>(info.param);
820
821 return test_description(name, shape, portion, true);
822 });
823
824 } // namespace kai::test
825