KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 97.4% 226 / 0 / 232
Functions: 100.0% 38 / 0 / 38
Branches: 45.0% 327 / 0 / 726

test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <functional>
14 #include <limits>
15 #include <random>
16 #include <sstream>
17 #include <string>
18 #include <string_view>
19
20 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h"
21 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h"
22 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h"
23 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"
24 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h"
25 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod.h"
26 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h"
27 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h"
28 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h"
29 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h"
30 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h"
31 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h"
32 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
33 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h"
34 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"
35 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h"
36 #include "test/common/abi_checker.hpp"
37 #include "test/common/buffer.hpp"
38 #include "test/common/compare.hpp"
39 #include "test/common/cpu_info.hpp"
40 #include "test/common/int4.hpp"
41 #include "test/common/matmul_test_common.hpp"
42 #include "test/common/matrix_portion.hpp"
43 #include "test/common/memory.hpp"
44 #include "test/common/round.hpp"
45 #include "test/common/test_suite.hpp"
46 #include "test/reference/cast.hpp"
47 #include "test/reference/clamp.hpp"
48 #include "test/reference/fill.hpp"
49 #include "test/reference/matmul.hpp"
50 #include "test/reference/pad.hpp"
51 #include "test/reference/quantize.hpp"
52 #include "test/reference/transpose.hpp"
53
54 namespace kai::test {
55 /// Matrix multiplication test information.
56
57 enum class RhsPackType { NxK, KxN };
58
59 using ukernel_rhs_pack_function = std::function<decltype(kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>;
60 using ukernel_get_rhs_packed_size = std::function<decltype(kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>;
61 using ukernel_get_rhs_packed_offset = std::function<decltype(kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>;
62 using ukernel_get_rhs_offset = std::function<decltype(kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon)>;
63
64 template <typename T>
65 struct UkernelVariantCustom : public UkernelVariant<T> {
66 ukernel_rhs_pack_function run_rhs_pack;
67 ukernel_get_rhs_packed_size get_rhs_packed_size;
68 ukernel_get_rhs_packed_offset get_rhs_packed_offset;
69 ukernel_get_rhs_offset get_rhs_offset;
70 RhsPackType rhs_pack_type;
71
72 UkernelVariantCustom() = delete;
73
74 80 UkernelVariantCustom(
75 T interface, std::string_view name, const std::function<bool(void)>& fn_is_supported,
76 ukernel_rhs_pack_function run_rhs_pack, ukernel_get_rhs_packed_size get_rhs_packed_size,
77 ukernel_get_rhs_packed_offset get_rhs_packed_offset, ukernel_get_rhs_offset get_rhs_offset,
78 const RhsPackType pack_type) :
79 60 UkernelVariant<T>(interface, name, fn_is_supported),
80 60 run_rhs_pack(std::move(run_rhs_pack)),
81 60 get_rhs_packed_size(std::move(get_rhs_packed_size)),
82 60 get_rhs_packed_offset(std::move(get_rhs_packed_offset)),
83 60 get_rhs_offset(std::move(get_rhs_offset)),
84 80 rhs_pack_type(pack_type) {
85 80 }
86 };
87
88 3 static const std::array<UkernelVariantCustom<kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel>, 20>
89
0/4
✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
3 variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp = {
90
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
25 {{UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa),
91
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa__RHS_NxK__", cpu_has_sme2,
92
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
93
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
3 kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
94
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
95
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
3 kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, RhsPackType::NxK},
96
97
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot),
98
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot__RHS_NxK__", cpu_has_sme2,
99
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
100
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
3 kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
101
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon,
102
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
3 kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, RhsPackType::NxK},
103
104
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod),
105
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod__RHS_NxK__", cpu_has_dotprod,
106
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
107
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
108 RhsPackType::NxK},
109
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod),
110
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod__RHS_KxN__", cpu_has_dotprod,
111
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
112
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
113 RhsPackType::KxN},
114
115
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod),
116
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod,
117
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
118
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
119 RhsPackType::NxK},
120
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod),
121
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod,
122
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
123
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
124 RhsPackType::KxN},
125
126
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod),
127
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod,
128
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
129
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
130 RhsPackType::NxK},
131
132
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod),
133
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod,
134
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
135
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
136 RhsPackType::KxN},
137
138
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod),
139
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod,
140
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
141
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
142 RhsPackType::NxK},
143
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod),
144
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod,
145
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
146
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
147 RhsPackType::KxN},
148
149
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod),
150
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod,
151
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
152
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
153 RhsPackType::NxK},
154
155
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod),
156
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod,
157
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
158
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
159 RhsPackType::KxN},
160
161
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm),
162
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm,
163
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
164
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
165 RhsPackType::NxK},
166
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm),
167
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm,
168
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
169
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
170 RhsPackType::KxN},
171
172
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm),
173
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm,
174
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
175
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
176 RhsPackType::NxK},
177
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm),
178
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm,
179
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
180
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
181 RhsPackType::KxN},
182
183
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm),
184
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm,
185
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
186
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
187 RhsPackType::NxK},
188
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm),
189
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm,
190
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
191
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
192 RhsPackType::KxN},
193
194
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm),
195
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm,
196
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
197
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0,
198 RhsPackType::NxK},
199
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
6 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm),
200
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm,
201
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
202
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0,
203 RhsPackType::KxN}}
204
205 };
206
207 using MatMulClampTestPortionedParams = std::tuple<size_t, MatMulShape, MatrixPortion, float>;
208 class MatMulTest_f32_qai8dxp_qsi4cxp : public ::testing::TestWithParam<MatMulClampTestPortionedParams> {};
209
210
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
12006 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, Offset_RHS) {
211 27840 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
212 9600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
213
214
3/4
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
4800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
215
3/6
✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
240 GTEST_SKIP() << "Unsupported CPU feature";
216 }
217
218 9120 const size_t M = matmul_shape.m;
219 9120 const size_t N = matmul_shape.n;
220 9120 const size_t K = matmul_shape.k;
221
222 4560 auto m_step = ukernel_variant.interface.get_m_step();
223 4560 auto n_step = ukernel_variant.interface.get_n_step();
224
225 9120 const auto rect = portion.compute_portion(M, N, m_step, n_step);
226
2/4
✓ Branch 0 taken 4560 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4560 times.
4560 if (rect.height() == 0 || rect.width() == 0) {
227 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
228 }
229
230 4560 const auto nr = ukernel_variant.interface.get_nr();
231 4560 const auto kr = ukernel_variant.interface.get_kr();
232 4560 const auto sr = ukernel_variant.interface.get_sr();
233
234 4560 const auto rhs_start_row = rect.start_col();
235 4560 auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr);
236 4560 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
237
3/14
✓ Branch 0 taken 4560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 4560 times.
4560 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
238 4800 }
239
240
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
12006 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, Offset_LHS) {
241 27840 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
242 9600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
243
244
3/4
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
4800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
245
3/6
✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
240 GTEST_SKIP() << "Unsupported CPU feature";
246 }
247
248 9120 const size_t M = matmul_shape.m;
249 9120 const size_t N = matmul_shape.n;
250 9120 const size_t K = matmul_shape.k;
251
252 4560 auto m_step = ukernel_variant.interface.get_m_step();
253 4560 auto n_step = ukernel_variant.interface.get_n_step();
254
255 9120 const auto rect = portion.compute_portion(M, N, m_step, n_step);
256
2/4
✓ Branch 0 taken 4560 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4560 times.
4560 if (rect.height() == 0 || rect.width() == 0) {
257 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
258 }
259
260 4560 const auto mr = ukernel_variant.interface.get_mr();
261 4560 const auto kr = ukernel_variant.interface.get_kr();
262 4560 const auto sr = ukernel_variant.interface.get_sr();
263
264 4560 const auto lhs_start_row = rect.start_row();
265 4560 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
266 4560 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
267
268
3/14
✓ Branch 0 taken 4560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 4560 times.
4560 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
269 4800 }
270
271
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
12006 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsi4cx) {
272 16800 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
273 9600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
274
275
3/4
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
4800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
276
3/6
✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
240 GTEST_SKIP() << "Unsupported CPU feature";
277 }
278
2/2
✓ Branch 0 taken 2160 times.
✓ Branch 1 taken 2400 times.
4560 if (ukernel_variant.rhs_pack_type == RhsPackType::KxN) {
279
3/6
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2160 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2160 times.
✗ Branch 5 not taken.
2160 GTEST_SKIP() << "Wrong type. This test for NxK";
280 }
281
282 2400 const uint32_t seed = 0;
283
284 4800 const size_t M = matmul_shape.m;
285 4800 const size_t N = matmul_shape.n;
286 4800 const size_t K = matmul_shape.k;
287
288 2400 const auto mr = ukernel_variant.interface.get_mr();
289 2400 const auto nr = ukernel_variant.interface.get_nr();
290 2400 const auto kr = ukernel_variant.interface.get_kr();
291 2400 const auto sr = ukernel_variant.interface.get_sr();
292
293 // Generates input data.
294 2400 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
295
1/2
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
2400 const auto ref_biases = fill_random<float>(N, seed + 2);
296
297
1/2
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
2400 std::uniform_real_distribution<float> dist(-10.0, 1.0);
298
1/2
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
2400 std::mt19937 rnd(seed + 1);
299
1/2
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
10312800 const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); });
300
301 // Runs the reference implementation.
302 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
303 // * Quantizes the RHS matrix using 4-bit symmetric quantization.
304 // * Performs GEMM.
305
1/2
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
2400 QuantizationInfo lhs_qinfo{};
306 lhs_qinfo.quant_width = K;
307 lhs_qinfo.dst_type = DataType::QAI8;
308 lhs_qinfo.scale_type = DataType::FP32;
309 lhs_qinfo.zero_point_type = DataType::I32;
310 const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo);
311
312 QuantizationInfo rhs_qinfo{};
313 rhs_qinfo.quant_width = K;
314 rhs_qinfo.dst_type = DataType::QSI4;
315 rhs_qinfo.scale_type = DataType::FP32;
316 const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo);
317
318 const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
319 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K,
320 ref_rhs_quant.data(), rhs_qoutputs.scales.data(), nullptr, K, ref_biases.data(),
321 std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
322
323 const auto [clamp_min, clamp_max] =
324 find_clamp_range(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_keep_ratio);
325
326 auto ref_clamped = clamp(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_min, clamp_max);
327
328 auto m_step = ukernel_variant.interface.get_m_step();
329 ASSERT_TRUE(m_step % mr == 0);
330
331 auto n_step = ukernel_variant.interface.get_n_step();
332 ASSERT_TRUE(n_step % nr == 0);
333
334 const auto rect = portion.compute_portion(M, N, m_step, n_step);
335 if (rect.height() == 0 || rect.width() == 0) {
336 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
337 }
338
339 const auto lhs_start_row = rect.start_row();
340 size_t lhs_stride = K * sizeof(float);
341
342 // Runs the LHS packing micro-kernel.
343 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
344 Buffer imp_packed_lhs(imp_packed_lhs_size);
345
346 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
347 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
348 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
349 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
350
351 kai_run_lhs_quant_pack_qai8dxp_f32(
352 rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/,
353 reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
354 imp_packed_lhs.data() + lhs_packed_offset);
355
356 // Runs the RHS packing micro-kernel.
357 // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
358 // * Packs the RHS matrix.
359 const auto ref_rhs_qsi4_padded = pad_row<Int4>(
360 ref_rhs_quant.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2));
361
362 const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
363 const auto rhs_start_row = rect.start_col();
364 auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr);
365 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
366 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
367
368 auto rhs_offset = ukernel_variant.get_rhs_offset(rhs_start_row, round_up_division(K, 2));
369 size_t bias_offset = rhs_start_row * sizeof(float);
370 size_t scale_offset = rhs_start_row * sizeof(float);
371
372 Buffer imp_packed_rhs(imp_packed_rhs_size);
373 kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{};
374 params.lhs_zero_point = 1;
375 params.rhs_zero_point = 0;
376
377 abi_check(
378 ukernel_variant.run_rhs_pack, 1, rect.width() /* n */, K, nr, kr, sr,
379 reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data() + rhs_offset),
380 reinterpret_cast<const float*>(ref_biases.data() + bias_offset),
381 reinterpret_cast<const float*>(rhs_qoutputs.scales.data() + scale_offset),
382 imp_packed_rhs.data() + rhs_packed_offset, 0, &params);
383
384 const auto dst_stride = N * sizeof(float);
385 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
386 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
387 ASSERT_EQ(dst_offset, ref_dst_offset);
388
389 // Runs the GEMM micro-kernel.
390 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
391 ASSERT_EQ(imp_dst_size, ref_dst.size());
392 Buffer imp_dst(imp_dst_size);
393 abi_check(
394 ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
395 imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
396 N * sizeof(float), sizeof(float), clamp_min, clamp_max);
397
398 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
399 // tested.
400 for (size_t y = 0; y < rect.height(); ++y) {
401 for (size_t x = 0; x < rect.width(); ++x) {
402 const auto imp_value =
403 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
404 const auto ref_value =
405 read_array<float>(ref_clamped.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
406 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value;
407
408 if (rel_error > 0.0001F) {
409 ASSERT_EQ(imp_value, ref_value);
410 }
411 }
412 }
413 }
414
415
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
12006 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) {
// End-to-end test for variants whose RHS is packed from an NxK matrix of unsigned 4-bit values.
// Steps: generate random inputs, quantize them, compute a clamped reference matmul, run the LHS packing,
// RHS packing and matmul micro-kernels on the selected portion, then compare against the reference.
// The test is skipped when the CPU feature is unavailable, when the variant packs KxN, or when the
// selected portion is empty.
416 16800 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
417 9600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
418
419
3/4
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
4800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
420
3/6
✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
240 GTEST_SKIP() << "Unsupported CPU feature";
421 }
422
2/2
✓ Branch 0 taken 2160 times.
✓ Branch 1 taken 2400 times.
4560 if (ukernel_variant.rhs_pack_type == RhsPackType::KxN) {
423
3/6
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2160 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2160 times.
✗ Branch 5 not taken.
2160 GTEST_SKIP() << "Wrong type. This test for NxK";
424 }
425
426 2400 const uint32_t seed = 0;
427
428 4800 const size_t M = matmul_shape.m;
429 4800 const size_t N = matmul_shape.n;
430 4800 const size_t K = matmul_shape.k;
431
432 2400 const auto mr = ukernel_variant.interface.get_mr();
433 2400 const auto nr = ukernel_variant.interface.get_nr();
434 2400 const auto kr = ukernel_variant.interface.get_kr();
435 2400 const auto sr = ukernel_variant.interface.get_sr();
436
437 // Generates input data.
438 2400 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
439
440
1/2
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
2400 std::uniform_real_distribution<float> dist(-10.0, 1.0);
441
1/2
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
2400 std::mt19937 rnd(seed + 1);
442
1/2
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
10312800 const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); });
443
444
1/2
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
2400 const auto ref_biases = fill_random<float>(N, seed + 2);
445
446 // Runs the reference implementation.
447 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
448 // * Quantizes the RHS matrix using 4-bit symmetric quantization.
449 // * Performs GEMM.
450
1/2
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
2400 QuantizationInfo lhs_qinfo{};
451 lhs_qinfo.quant_width = K;
452 lhs_qinfo.dst_type = DataType::QAI8;
453 lhs_qinfo.scale_type = DataType::FP32;
454 lhs_qinfo.zero_point_type = DataType::I32;
455 const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo);
456
457 QuantizationInfo rhs_qinfo{};
458 rhs_qinfo.quant_width = K;
459 rhs_qinfo.dst_type = DataType::QSI4;
460 rhs_qinfo.scale_type = DataType::FP32;
461 const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo);
462
// The reference is first computed without clamping (lowest()/max() bounds); the clamp range is then derived
// from the unclamped result and clamp_keep_ratio.
463 const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
464 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K,
465 ref_rhs_quant.data(), rhs_qoutputs.scales.data(), nullptr, K, ref_biases.data(),
466 std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
467
468 const auto [clamp_min, clamp_max] =
469 find_clamp_range(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_keep_ratio);
470
471 auto ref_clamped = clamp(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_min, clamp_max);
472
473 auto m_step = ukernel_variant.interface.get_m_step();
474 ASSERT_TRUE(m_step % mr == 0);
475
476 auto n_step = ukernel_variant.interface.get_n_step();
477 ASSERT_TRUE(n_step % nr == 0);
478
479 const auto rect = portion.compute_portion(M, N, m_step, n_step);
480 if (rect.height() == 0 || rect.width() == 0) {
481 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
482 }
483
484 const auto lhs_start_row = rect.start_row();
485 size_t lhs_stride = K * sizeof(float);
486
487 // Runs the LHS packing micro-kernel.
488 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
489 Buffer imp_packed_lhs(imp_packed_lhs_size);
490
491 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
492 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
493 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
494 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
495
496 kai_run_lhs_quant_pack_qai8dxp_f32(
497 rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/,
498 reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
499 imp_packed_lhs.data() + lhs_packed_offset);
500
501 const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_quant.data(), N * K);
502 // Runs the RHS packing micro-kernel.
503 // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
504 // * Packs the RHS matrix.
505 const auto ref_rhs_qsu4_padded = pad_row<UInt4>(
506 ref_rhs_qsu4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2));
507
508 const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
// The RHS is handled as N rows of K values here, so the first column of the tested portion selects the
// first RHS row to pack.
509 const auto rhs_start_row = rect.start_col();
510 auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr);
511 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
512 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
513
514 auto rhs_offset = ukernel_variant.get_rhs_offset(rhs_start_row, round_up_division(K, 2));
515 size_t bias_offset = rhs_start_row * sizeof(float);
516 size_t scale_offset = rhs_start_row * sizeof(float);
517
518 Buffer imp_packed_rhs(imp_packed_rhs_size);
519 kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{};
// NOTE(review): rhs_zero_point = 8 presumably matches the unsigned encoding produced by cast_qsu4_qsi4 - confirm.
520 params.lhs_zero_point = 1;
521 params.rhs_zero_point = 8;
522 abi_check(
523 ukernel_variant.run_rhs_pack, 1, rect.width() /* n */, K, nr, kr, sr,
524 reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_padded.data() + rhs_offset),
525 reinterpret_cast<const float*>(ref_biases.data() + bias_offset),
526 reinterpret_cast<const float*>(rhs_qoutputs.scales.data() + scale_offset),
527 imp_packed_rhs.data() + rhs_packed_offset, 0, &params);
528
529 const auto dst_stride = N * sizeof(float);
530 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
531 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
532 ASSERT_EQ(dst_offset, ref_dst_offset);
533 // Runs the GEMM micro-kernel.
534 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
535 ASSERT_EQ(imp_dst_size, ref_dst.size());
536 Buffer imp_dst(imp_dst_size);
537 abi_check(
538 ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
539 imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
540 N * sizeof(float), sizeof(float), clamp_min, clamp_max);
541
542 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
543 // tested.
544 for (size_t y = 0; y < rect.height(); ++y) {
545 for (size_t x = 0; x < rect.width(); ++x) {
546 const auto imp_value =
547 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
548 const auto ref_value =
549 read_array<float>(ref_clamped.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
// When the reference is exactly zero a relative error is undefined, so fall back to the absolute error.
// The magnitude must be used: a signed negative value would never exceed the tolerance below.
550 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
551
552 if (rel_error > 0.0001F) {
553 ASSERT_EQ(imp_value, ref_value);
554 }
555 }
556 }
557 }
558
559
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
12006 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) {
// End-to-end test for variants whose RHS is packed from a KxN matrix of signed 4-bit values.
// Steps: generate random inputs, quantize them, transpose the quantized RHS to KxN, compute a clamped
// reference matmul, run the LHS packing, RHS packing and matmul micro-kernels, then compare the selected
// portion against the reference.
// The test is skipped when the CPU feature is unavailable, when the variant packs NxK, or when the
// selected portion is empty.
560 16080 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
561 9600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
562
563
3/4
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
4800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
564
3/6
✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
240 GTEST_SKIP() << "Unsupported CPU feature";
565 }
566
2/2
✓ Branch 0 taken 2160 times.
✓ Branch 1 taken 2400 times.
4560 if (ukernel_variant.rhs_pack_type == RhsPackType::NxK) {
567
3/6
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2400 times.
✗ Branch 5 not taken.
2400 GTEST_SKIP() << "Wrong type. This test for KxN";
568 }
569
570 2160 const uint32_t seed = 0;
571
572 4320 const size_t M = matmul_shape.m;
573 4320 const size_t N = matmul_shape.n;
574 4320 const size_t K = matmul_shape.k;
575
576 2160 const auto mr = ukernel_variant.interface.get_mr();
577 2160 const auto nr = ukernel_variant.interface.get_nr();
578 2160 const auto kr = ukernel_variant.interface.get_kr();
579 2160 const auto sr = ukernel_variant.interface.get_sr();
580
581 // Generates input data.
582 2160 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
583
584
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 std::uniform_real_distribution<float> dist(-10.0, 1.0);
585
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 std::mt19937 rnd(seed + 1);
586
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
9281520 const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); });
587
588
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 const auto ref_biases = fill_random<float>(N, seed + 2);
589
590 // Transposed(nxk) RHS dimensions
591 2160 const size_t ref_rhs_qsi4_nxk_stride = K;
592
593 // Non-Transposed(kxn) RHS dimensions
594
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2);
595
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(K * ref_rhs_qsi4_kxn_stride, 2);
596
597 // Runs the reference implementation.
598 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
599 // * Quantizes the RHS matrix using 4-bit symmetric quantization.
600 // * Performs GEMM.
601
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 QuantizationInfo lhs_qinfo{};
602 lhs_qinfo.quant_width = K;
603 lhs_qinfo.dst_type = DataType::QAI8;
604 lhs_qinfo.scale_type = DataType::FP32;
605 lhs_qinfo.zero_point_type = DataType::I32;
606 const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo);
607
608 QuantizationInfo rhs_qinfo{};
609 rhs_qinfo.quant_width = K;
610 rhs_qinfo.dst_type = DataType::QSI4;
611 rhs_qinfo.scale_type = DataType::FP32;
612 const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo);
613
// The RHS is quantized as an N x K matrix and then transposed to K x N for the KxN packing path.
614 const auto ref_rhs_qsi4 = transpose_with_padding<Int4>(
615 ref_rhs_quant.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes);
616
// The reference is first computed without clamping (lowest()/max() bounds); the clamp range is then derived
// from the unclamped result and clamp_keep_ratio.
617 const auto ref_dst = matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
618 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K,
619 ref_rhs_qsi4.data(), rhs_qoutputs.scales.data(), nullptr, K, ref_biases.data(),
620 std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
621
622 const auto [clamp_min, clamp_max] =
623 find_clamp_range(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_keep_ratio);
624
625 auto ref_clamped = clamp(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_min, clamp_max);
626
627 auto m_step = ukernel_variant.interface.get_m_step();
628 ASSERT_TRUE(m_step % mr == 0);
629
630 auto n_step = ukernel_variant.interface.get_n_step();
631 ASSERT_TRUE(n_step % nr == 0);
632
633 const auto rect = portion.compute_portion(M, N, m_step, n_step);
634 if (rect.height() == 0 || rect.width() == 0) {
635 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
636 }
637
638 const auto lhs_start_row = rect.start_row();
639 size_t lhs_stride = K * sizeof(float);
640
641 // Runs the LHS packing micro-kernel.
642 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
643 Buffer imp_packed_lhs(imp_packed_lhs_size);
644 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
645 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
646 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
647 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
648
649 kai_run_lhs_quant_pack_qai8dxp_f32(
650 rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/,
651 reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
652 imp_packed_lhs.data() + lhs_packed_offset);
653
654 // Runs the RHS packing micro-kernel.
655 // * Generates the 4-bit signed symmetric quantized input for the micro-kernel.
656 // * Packs the RHS matrix.
657 const auto ref_rhs_qsi4_padded = pad_row<Int4>(
658 ref_rhs_qsi4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2));
659 const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
660
661 const auto rhs_start_row = rect.start_col();
662 auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr);
663 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
664 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
665
// Unlike the NxK tests, the whole RHS (all N columns) is packed from offset 0; only the matmul call below
// is restricted to the tested portion.
666 Buffer imp_packed_rhs(imp_packed_rhs_size);
667 kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{};
668 params.lhs_zero_point = 1;
669 params.rhs_zero_point = 0;
670 abi_check(
671 ukernel_variant.run_rhs_pack, 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()),
672 reinterpret_cast<const float*>(ref_biases.data()), reinterpret_cast<const float*>(rhs_qoutputs.scales.data()),
673 imp_packed_rhs.data(), 0, &params);
674
675 const auto dst_stride = N * sizeof(float);
676 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
677 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
678 ASSERT_EQ(dst_offset, ref_dst_offset);
679
680 // Runs the GEMM micro-kernel.
681 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
682 ASSERT_EQ(imp_dst_size, ref_dst.size());
683 Buffer imp_dst(imp_dst_size);
684 abi_check(
685 ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
686 imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
687 N * sizeof(float), sizeof(float), clamp_min, clamp_max);
688
689 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
690 // tested.
691 for (size_t y = 0; y < rect.height(); ++y) {
692 for (size_t x = 0; x < rect.width(); ++x) {
693 const auto imp_value =
694 read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
695 const auto ref_value =
696 read_array<float>(ref_clamped.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
// When the reference is exactly zero a relative error is undefined, so fall back to the absolute error.
// The magnitude must be used: a signed negative value would never exceed the tolerance below.
697 const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
698
699 if (rel_error > 0.0001F) {
700 ASSERT_EQ(imp_value, ref_value);
701 }
702 }
703 }
704 }
705
706
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
12006 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) {
// End-to-end test for variants whose RHS is packed from a KxN matrix of unsigned 4-bit values.
// Steps: generate random inputs, quantize them, transpose the quantized RHS to KxN, compute a clamped
// reference matmul, convert the RHS with cast_qsu4_qsi4, run the LHS packing, RHS packing and matmul
// micro-kernels, then compare the selected portion against the reference with compare().
// The test is skipped when the CPU feature is unavailable, when the variant packs NxK, or when the
// selected portion is empty.
707 16080 const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam();
708 9600 const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
709
710
3/4
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
4800 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
711
3/6
✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
240 GTEST_SKIP() << "Unsupported CPU feature";
712 }
713
2/2
✓ Branch 0 taken 2160 times.
✓ Branch 1 taken 2400 times.
4560 if (ukernel_variant.rhs_pack_type == RhsPackType::NxK) {
714
3/6
✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2400 times.
✗ Branch 5 not taken.
2400 GTEST_SKIP() << "Wrong type. This test for KxN";
715 }
716
717 2160 const uint32_t seed = 0;
718
719 4320 const size_t M = matmul_shape.m;
720 4320 const size_t N = matmul_shape.n;
721 4320 const size_t K = matmul_shape.k;
722
723 2160 const auto mr = ukernel_variant.interface.get_mr();
724 2160 const auto nr = ukernel_variant.interface.get_nr();
725 2160 const auto kr = ukernel_variant.interface.get_kr();
726 2160 const auto sr = ukernel_variant.interface.get_sr();
727
728 // Generates input data.
729 2160 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
730
731
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 std::uniform_real_distribution<float> dist(-10.0, 1.0);
732
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 std::mt19937 rnd(seed + 1);
733
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
9281520 const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); });
734
735
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 const auto ref_biases = fill_random<float>(N, seed + 2);
736
737 // Transposed(nxk) RHS dimensions
738 2160 const size_t ref_rhs_qsi4_nxk_stride = K;
739
740 // Non-Transposed(kxn) RHS dimensions
741
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2);
742 2160 const size_t ref_rhs_qsi4_kxn_size = K * ref_rhs_qsi4_kxn_stride;
743
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(ref_rhs_qsi4_kxn_size, 2);
744
745 // Runs the reference implementation.
746 // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
747 // * Quantizes the RHS matrix using 4-bit symmetric quantization.
748 // * Performs GEMM.
749
1/2
✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
2160 QuantizationInfo lhs_qinfo{};
750 lhs_qinfo.quant_width = K;
751 lhs_qinfo.dst_type = DataType::QAI8;
752 lhs_qinfo.scale_type = DataType::FP32;
753 lhs_qinfo.zero_point_type = DataType::I32;
754 const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo);
755
756 QuantizationInfo rhs_qinfo{};
757 rhs_qinfo.quant_width = K;
758 rhs_qinfo.dst_type = DataType::QSI4;
759 rhs_qinfo.scale_type = DataType::FP32;
760 const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo);
761
// The RHS is quantized as an N x K matrix and then transposed to K x N for the KxN packing path.
762 const auto ref_rhs_qsi4 = transpose_with_padding<Int4>(
763 ref_rhs_quant.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes);
764
// The reference is first computed without clamping (lowest()/max() bounds); the clamp range is then derived
// from the unclamped result and clamp_keep_ratio.
765 const auto ref_dst = matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
766 M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K,
767 ref_rhs_qsi4.data(), rhs_qoutputs.scales.data(), nullptr, K, ref_biases.data(),
768 std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
769
770 const auto [clamp_min, clamp_max] =
771 find_clamp_range(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_keep_ratio);
772
773 auto ref_clamped = clamp(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_min, clamp_max);
774
775 auto m_step = ukernel_variant.interface.get_m_step();
776 ASSERT_TRUE(m_step % mr == 0);
777
778 auto n_step = ukernel_variant.interface.get_n_step();
779 ASSERT_TRUE(n_step % nr == 0);
780
781 const auto rect = portion.compute_portion(M, N, m_step, n_step);
782 if (rect.height() == 0 || rect.width() == 0) {
783 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
784 }
785
786 const auto lhs_start_row = rect.start_row();
787 size_t lhs_stride = K * sizeof(float);
788
789 // Runs the LHS packing micro-kernel.
790 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
791 Buffer imp_packed_lhs(imp_packed_lhs_size);
792
793 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
794 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
795 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
796 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
797
798 kai_run_lhs_quant_pack_qai8dxp_f32(
799 rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/,
800 reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
801 imp_packed_lhs.data() + lhs_packed_offset);
802
803 // Runs the RHS packing micro-kernel.
804 // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
805 // * Packs the RHS matrix.
806 const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), ref_rhs_qsi4_kxn_size);
807 const auto ref_rhs_qsu4_padded = pad_row<UInt4>(
808 ref_rhs_qsu4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2));
809 const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
810
811 const auto rhs_start_row = rect.start_col();
// Use the variant's own packed-offset function (as the other KxN test does) rather than a hard-coded NxK
// helper, so the check below validates the packing routine this variant actually uses.
812 auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr);
813 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
814 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
815
// The whole RHS (all N columns) is packed from offset 0; only the matmul call below is restricted to the
// tested portion.
816 Buffer imp_packed_rhs(imp_packed_rhs_size);
817 kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{};
// NOTE(review): rhs_zero_point = 8 presumably matches the unsigned encoding produced by cast_qsu4_qsi4 - confirm.
818 params.lhs_zero_point = 1;
819 params.rhs_zero_point = 8;
820 abi_check(
821 ukernel_variant.run_rhs_pack, 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_padded.data()),
822 reinterpret_cast<const float*>(ref_biases.data()), reinterpret_cast<const float*>(rhs_qoutputs.scales.data()),
823 imp_packed_rhs.data(), 0, &params);
824
825 const auto dst_stride = N * sizeof(float);
826 const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
827 const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
828 ASSERT_EQ(dst_offset, ref_dst_offset);
829
830 // Runs the GEMM micro-kernel.
831 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
832 ASSERT_EQ(imp_dst_size, ref_dst.size());
833 Buffer imp_dst(imp_dst_size);
834 abi_check(
835 ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
836 imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
837 N * sizeof(float), sizeof(float), clamp_min, clamp_max);
838
839 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
840 // tested.
841 DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
842 DataFormat dst_format = DataFormat(DataType::FP32);
843 const auto success = compare(imp_dst.data(), ref_clamped.data(), dst_format, M, N, rect, handler);
844 ASSERT_TRUE(success);
845 }
846
847
29/94
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 6 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 6 times.
✓ Branch 12 taken 12 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 6 times.
✓ Branch 14 taken 12 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✓ Branch 16 taken 12 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 6 times.
✓ Branch 18 taken 12 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 6 times.
✓ Branch 20 taken 12 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✓ Branch 22 taken 12 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 6 times.
✓ Branch 24 taken 12 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 6 times.
✓ Branch 26 taken 12 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 12 times.
✓ Branch 28 taken 14400 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✓ Branch 30 taken 28800 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 48 taken 14400 times.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✓ Branch 50 taken 28800 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 28800 times.
✗ Branch 53 not taken.
86421 INSTANTIATE_TEST_SUITE_P(
// Instantiates MatMulTest_f32_qai8dxp_qsi4cxp for the cartesian product of:
//   * every index into variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp,
//   * the matmul shapes listed below,
//   * four portions of the output matrix,
//   * three clamp_keep_ratio values.
// The trailing lambda builds the test-case name from those parameters.
848 MatMul, MatMulTest_f32_qai8dxp_qsi4cxp,
849 testing::Combine(
850 testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.size()),
851 testing::Values(
852 MatMulShape{16, 32, 64}, //
853 MatMulShape{16, 32, 36}, //
854 MatMulShape{15, 35, 65}, //
855 MatMulShape{8, 32, 64}, //
856 MatMulShape{15, 31, 45}, //
857 MatMulShape{1, 35, 65}, //
858 MatMulShape{1, 128, 32}, //
859 MatMulShape{64, 128, 32}, //
860 MatMulShape{1, 225, 55}, //
861 MatMulShape{125, 200, 56}),
862 testing::Values(
863 MatrixPortion(0, 0, 1, 1), // Full matrix.
864 MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.
865 MatrixPortion(0, 0.75, 1, 1), // Rightmost portion.
866 MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle
867 ),
868 testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio
869 [](const auto& info) {
// Tuple layout follows the Combine() order above: 0 = variant index, 1 = shape, 2 = portion,
// 3 = clamp_keep_ratio.
870 const auto variant_idx = std::get<0>(info.param);
871 const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_idx).name};
872 const auto shape = std::get<MatMulShape>(info.param);
873 const auto portion = std::get<2>(info.param);
874 const auto clamp_keep_ratio = std::get<3>(info.param);
875
// NOTE(review): the meaning of the literal `true` argument is not visible here - confirm against
// test_description's declaration.
876 return test_description(name, shape, portion, true, clamp_keep_ratio);
877 });
878
879 } // namespace kai::test
880