Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <cstdlib> | ||
13 | #include <functional> | ||
14 | #include <limits> | ||
15 | #include <random> | ||
16 | #include <sstream> | ||
17 | #include <string> | ||
18 | #include <string_view> | ||
19 | |||
20 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h" | ||
21 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h" | ||
22 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h" | ||
23 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" | ||
24 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" | ||
25 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod.h" | ||
26 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h" | ||
27 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" | ||
28 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" | ||
29 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" | ||
30 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" | ||
31 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" | ||
32 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" | ||
33 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h" | ||
34 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" | ||
35 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h" | ||
36 | #include "test/common/buffer.hpp" | ||
37 | #include "test/common/compare.hpp" | ||
38 | #include "test/common/cpu_info.hpp" | ||
39 | #include "test/common/int4.hpp" | ||
40 | #include "test/common/matmul_test_common.hpp" | ||
41 | #include "test/common/matrix_portion.hpp" | ||
42 | #include "test/common/memory.hpp" | ||
43 | #include "test/common/round.hpp" | ||
44 | #include "test/common/test_suite.hpp" | ||
45 | #include "test/reference/cast.hpp" | ||
46 | #include "test/reference/fill.hpp" | ||
47 | #include "test/reference/matmul.hpp" | ||
48 | #include "test/reference/pad.hpp" | ||
49 | #include "test/reference/quantize.hpp" | ||
50 | #include "test/reference/transpose.hpp" | ||
51 | |||
52 | namespace kai::test { | ||
53 | /// Matrix multiplication test information. | ||
54 | |||
55 | enum class RhsPackType { NxK, KxN }; | ||
56 | |||
57 | using ukernel_rhs_pack_function = std::function<decltype(kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>; | ||
58 | using ukernel_get_rhs_packed_size = std::function<decltype(kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>; | ||
59 | using ukernel_get_rhs_packed_offset = std::function<decltype(kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>; | ||
60 | using ukernel_get_rhs_offset = std::function<decltype(kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon)>; | ||
61 | |||
62 | template <typename T> | ||
63 | struct UkernelVariantCustom : public UkernelVariant<T> { | ||
64 | ukernel_rhs_pack_function run_rhs_pack; | ||
65 | ukernel_get_rhs_packed_size get_rhs_packed_size; | ||
66 | ukernel_get_rhs_packed_offset get_rhs_packed_offset; | ||
67 | ukernel_get_rhs_offset get_rhs_offset; | ||
68 | RhsPackType rhs_pack_type; | ||
69 | |||
70 | UkernelVariantCustom() = delete; | ||
71 | |||
72 | 40 | UkernelVariantCustom( | |
73 | T interface, std::string_view name, const std::function<bool(void)>& fn_is_supported, | ||
74 | ukernel_rhs_pack_function run_rhs_pack, ukernel_get_rhs_packed_size get_rhs_packed_size, | ||
75 | ukernel_get_rhs_packed_offset get_rhs_packed_offset, ukernel_get_rhs_offset get_rhs_offset, | ||
76 | const RhsPackType pack_type) : | ||
77 | 20 | UkernelVariant<T>(interface, name, fn_is_supported), | |
78 | 20 | run_rhs_pack(std::move(run_rhs_pack)), | |
79 | 20 | get_rhs_packed_size(std::move(get_rhs_packed_size)), | |
80 | 20 | get_rhs_packed_offset(std::move(get_rhs_packed_offset)), | |
81 | 20 | get_rhs_offset(std::move(get_rhs_offset)), | |
82 | 40 | rhs_pack_type(pack_type) { | |
83 | 40 | } | |
84 | }; | ||
85 | |||
86 | 1 | static const std::array<UkernelVariantCustom<kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel>, 20> | |
87 |
0/2✗ Branch 0 not taken.
✗ Branch 1 not taken.
|
1 | variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp = { |
88 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
21 | {{UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa), |
89 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa__RHS_NxK__", cpu_has_sme2, |
90 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
91 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
92 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
93 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, RhsPackType::NxK}, |
94 | |||
95 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot), |
96 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot__RHS_NxK__", cpu_has_sme2, |
97 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
98 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
99 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
100 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, RhsPackType::NxK}, |
101 | |||
102 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod), |
103 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod__RHS_NxK__", cpu_has_dotprod, |
104 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
105 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
106 | RhsPackType::NxK}, | ||
107 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod), |
108 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod__RHS_KxN__", cpu_has_dotprod, |
109 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
110 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
111 | RhsPackType::KxN}, | ||
112 | |||
113 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod), |
114 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod, |
115 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
116 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
117 | RhsPackType::NxK}, | ||
118 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod), |
119 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod, |
120 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
121 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
122 | RhsPackType::KxN}, | ||
123 | |||
124 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod), |
125 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod, |
126 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
127 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
128 | RhsPackType::NxK}, | ||
129 | |||
130 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod), |
131 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod, |
132 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
133 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
134 | RhsPackType::KxN}, | ||
135 | |||
136 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod), |
137 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod, |
138 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
139 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
140 | RhsPackType::NxK}, | ||
141 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod), |
142 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod, |
143 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
144 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
145 | RhsPackType::KxN}, | ||
146 | |||
147 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod), |
148 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod, |
149 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
150 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
151 | RhsPackType::NxK}, | ||
152 | |||
153 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod), |
154 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod, |
155 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
156 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
157 | RhsPackType::KxN}, | ||
158 | |||
159 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm), |
160 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm, |
161 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
162 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
163 | RhsPackType::NxK}, | ||
164 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm), |
165 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm, |
166 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
167 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
168 | RhsPackType::KxN}, | ||
169 | |||
170 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm), |
171 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm, |
172 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
173 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
174 | RhsPackType::NxK}, | ||
175 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm), |
176 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm, |
177 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
178 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
179 | RhsPackType::KxN}, | ||
180 | |||
181 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm), |
182 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm, |
183 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
184 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
185 | RhsPackType::NxK}, | ||
186 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm), |
187 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm, |
188 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
189 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
190 | RhsPackType::KxN}, | ||
191 | |||
192 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm), |
193 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm, |
194 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
195 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
196 | RhsPackType::NxK}, | ||
197 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm), |
198 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm, |
199 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
200 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
201 | RhsPackType::KxN}} | ||
202 | |||
203 | }; | ||
204 | |||
205 | class MatMulTest_f32_qai8dxp_qsi4cxp : public ::testing::TestWithParam<MatMulTestPortionedParams> {}; | ||
206 | |||
207 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
2402 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, Offset_RHS) { |
208 | 4800 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
209 | 1600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
210 | |||
211 |
2/4✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
|
800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
212 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
213 | } | ||
214 | |||
215 | 1600 | const size_t M = matmul_shape.m; | |
216 | 1600 | const size_t N = matmul_shape.n; | |
217 | 1600 | const size_t K = matmul_shape.k; | |
218 | |||
219 | 800 | auto m_step = ukernel_variant.interface.get_m_step(); | |
220 | 800 | auto n_step = ukernel_variant.interface.get_n_step(); | |
221 | |||
222 | 1600 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
223 |
2/4✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 800 times.
|
800 | if (rect.height() == 0 || rect.width() == 0) { |
224 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
225 | } | ||
226 | |||
227 | 800 | const auto nr = ukernel_variant.interface.get_nr(); | |
228 | 800 | const auto kr = ukernel_variant.interface.get_kr(); | |
229 | 800 | const auto sr = ukernel_variant.interface.get_sr(); | |
230 | |||
231 | 800 | const auto rhs_start_row = rect.start_col(); | |
232 | 800 | auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr); | |
233 | 800 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); | |
234 |
3/14✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 800 times.
|
800 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
235 | 800 | } | |
236 | |||
237 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
2402 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, Offset_LHS) { |
238 | 4800 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
239 | 1600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
240 | |||
241 |
2/4✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
|
800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
242 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
243 | } | ||
244 | |||
245 | 1600 | const size_t M = matmul_shape.m; | |
246 | 1600 | const size_t N = matmul_shape.n; | |
247 | 1600 | const size_t K = matmul_shape.k; | |
248 | |||
249 | 800 | auto m_step = ukernel_variant.interface.get_m_step(); | |
250 | 800 | auto n_step = ukernel_variant.interface.get_n_step(); | |
251 | |||
252 | 1600 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
253 |
2/4✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 800 times.
|
800 | if (rect.height() == 0 || rect.width() == 0) { |
254 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
255 | } | ||
256 | |||
257 | 800 | const auto mr = ukernel_variant.interface.get_mr(); | |
258 | 800 | const auto kr = ukernel_variant.interface.get_kr(); | |
259 | 800 | const auto sr = ukernel_variant.interface.get_sr(); | |
260 | |||
261 | 800 | const auto lhs_start_row = rect.start_row(); | |
262 | 800 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); | |
263 | 800 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); | |
264 | |||
265 |
3/14✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 800 times.
|
800 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
266 | 800 | } | |
267 | |||
268 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
2402 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsi4cx) { |
269 | 3360 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
270 | 1600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
271 | |||
272 |
2/4✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
|
800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
273 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
274 | } | ||
275 |
2/2✓ Branch 0 taken 360 times.
✓ Branch 1 taken 440 times.
|
800 | if (ukernel_variant.rhs_pack_type == RhsPackType::KxN) { |
276 |
3/6✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
|
360 | GTEST_SKIP() << "Wrong type. This test for NxK"; |
277 | } | ||
278 | |||
279 | 440 | const uint32_t seed = 0; | |
280 | |||
281 | 880 | const size_t M = matmul_shape.m; | |
282 | 880 | const size_t N = matmul_shape.n; | |
283 | 880 | const size_t K = matmul_shape.k; | |
284 | |||
285 | 440 | const auto mr = ukernel_variant.interface.get_mr(); | |
286 | 440 | const auto nr = ukernel_variant.interface.get_nr(); | |
287 | 440 | const auto kr = ukernel_variant.interface.get_kr(); | |
288 | 440 | const auto sr = ukernel_variant.interface.get_sr(); | |
289 | |||
290 | // Generates input data. | ||
291 | 440 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
292 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto ref_biases = fill_random<float>(N, seed + 2); |
293 | |||
294 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | std::uniform_real_distribution<float> dist(-10.0, 1.0); |
295 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | std::mt19937 rnd(seed + 1); |
296 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
1890680 | const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); }); |
297 | |||
298 | // Runs the reference implementation. | ||
299 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
300 | // * Quantizes the RHS matrix using 4-bit symmetric quantization. | ||
301 | // * Performs GEMM. | ||
302 | 1320 | const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
303 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K); |
304 | 2200 | const auto [ref_rhs_qsi4, ref_rhs_scales] = | |
305 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K); |
306 | |||
307 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
880 | const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( |
308 |
6/12✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 440 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 440 times.
✗ Branch 11 not taken.
|
880 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), |
309 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(), |
310 | 440 | std::numeric_limits<float>::max()); | |
311 | |||
312 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto m_step = ukernel_variant.interface.get_m_step(); |
313 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
|
440 | ASSERT_TRUE(m_step % mr == 0); |
314 | |||
315 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto n_step = ukernel_variant.interface.get_n_step(); |
316 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
|
440 | ASSERT_TRUE(n_step % nr == 0); |
317 | |||
318 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
880 | const auto rect = portion.compute_portion(M, N, m_step, n_step); |
319 |
4/8✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 440 times.
|
440 | if (rect.height() == 0 || rect.width() == 0) { |
320 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
321 | } | ||
322 | |||
323 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto lhs_start_row = rect.start_row(); |
324 | 440 | size_t lhs_stride = K * sizeof(float); | |
325 | |||
326 | // Runs the LHS packing micro-kernel. | ||
327 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); |
328 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | Buffer imp_packed_lhs(imp_packed_lhs_size); |
329 | |||
330 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); |
331 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); |
332 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
333 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 440 times.
✗ Branch 15 not taken.
|
440 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
334 | |||
335 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | kai_run_lhs_quant_pack_qai8dxp_f32( |
336 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, |
337 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride, |
338 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | imp_packed_lhs.data() + lhs_packed_offset); |
339 | |||
340 | // Runs the RHS packing micro-kernel. | ||
341 | // * Generates the 4-bit signed symmetric quantized input for the micro-kernel. | |
342 | // * Packs the RHS matrix. | ||
343 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
880 | const auto ref_rhs_qsi4_padded = pad_row<Int4>( |
344 |
4/8✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
|
440 | ref_rhs_qsi4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2)); |
345 | |||
346 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr); |
347 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto rhs_start_row = rect.start_col(); |
348 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr); |
349 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); |
350 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
|
440 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
351 | |||
352 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | auto rhs_offset = ukernel_variant.get_rhs_offset(rhs_start_row, round_up_division(K, 2)); |
353 | 440 | size_t bias_offset = rhs_start_row * sizeof(float); | |
354 | 440 | size_t scale_offset = rhs_start_row * sizeof(float); | |
355 | |||
356 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
357 | 440 | kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{}; | |
358 | 440 | params.lhs_zero_point = 1; | |
359 | 440 | params.rhs_zero_point = 0; | |
360 | |||
361 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
880 | ukernel_variant.run_rhs_pack( |
362 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | 1, rect.width() /* n */, K, nr, kr, sr, |
363 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data() + rhs_offset), |
364 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | reinterpret_cast<const float*>(ref_biases.data() + bias_offset), |
365 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | reinterpret_cast<const float*>(ref_rhs_scales.data() + scale_offset), imp_packed_rhs.data() + rhs_packed_offset, |
366 | 0, &params); | |
367 | |||
368 | 440 | const auto dst_stride = N * sizeof(float); | |
369 |
3/6✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
|
440 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
370 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); |
371 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
|
440 | ASSERT_EQ(dst_offset, ref_dst_offset); |
372 | |||
373 | // Runs the GEMM micro-kernel. | ||
374 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
375 |
5/18✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 440 times.
|
440 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
376 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | Buffer imp_dst(imp_dst_size); |
377 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
880 | ukernel_variant.interface.run_matmul( |
378 |
3/6✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
|
440 | rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, |
379 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), |
380 | 440 | N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | |
381 | |||
382 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
383 | // tested. | ||
384 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 11968 times.
✓ Branch 2 taken 11528 times.
✓ Branch 3 taken 440 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 440 times.
|
11968 | for (size_t y = 0; y < rect.height(); ++y) { |
385 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 848380 times.
✓ Branch 2 taken 836852 times.
✓ Branch 3 taken 11528 times.
✓ Branch 4 taken 11528 times.
✗ Branch 5 not taken.
|
848380 | for (size_t x = 0; x < rect.width(); ++x) { |
386 | 1673704 | const auto imp_value = | |
387 |
4/8✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 836852 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 836852 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 836852 times.
✗ Branch 7 not taken.
|
836852 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
388 | 1673704 | const auto ref_value = | |
389 |
4/8✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 836852 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 836852 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 836852 times.
✗ Branch 7 not taken.
|
836852 | read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
390 |
1/2✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
|
836852 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value; |
391 | |||
392 |
1/2✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
|
836852 | if (rel_error > 0.0001F) { |
393 | ✗ | ASSERT_EQ(imp_value, ref_value); | |
394 | ✗ | } | |
395 | 836852 | } | |
396 | 11528 | } | |
397 | 800 | } | |
398 | |||
399 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
2402 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) { |
400 | 3360 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
401 | 1600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
402 | |||
403 |
2/4✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
|
800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
404 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
405 | } | ||
406 |
2/2✓ Branch 0 taken 360 times.
✓ Branch 1 taken 440 times.
|
800 | if (ukernel_variant.rhs_pack_type == RhsPackType::KxN) { |
407 |
3/6✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
|
360 | GTEST_SKIP() << "Wrong type. This test for NxK"; |
408 | } | ||
409 | |||
410 | 440 | const uint32_t seed = 0; | |
411 | |||
412 | 880 | const size_t M = matmul_shape.m; | |
413 | 880 | const size_t N = matmul_shape.n; | |
414 | 880 | const size_t K = matmul_shape.k; | |
415 | |||
416 | 440 | const auto mr = ukernel_variant.interface.get_mr(); | |
417 | 440 | const auto nr = ukernel_variant.interface.get_nr(); | |
418 | 440 | const auto kr = ukernel_variant.interface.get_kr(); | |
419 | 440 | const auto sr = ukernel_variant.interface.get_sr(); | |
420 | |||
421 | // Generates input data. | ||
422 | 440 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
423 | |||
424 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | std::uniform_real_distribution<float> dist(-10.0, 1.0); |
425 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | std::mt19937 rnd(seed + 1); |
426 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
1890680 | const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); }); |
427 | |||
428 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto ref_biases = fill_random<float>(N, seed + 2); |
429 | |||
430 | // Runs the reference implementation. | ||
431 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
432 | // * Quantizes the RHS matrix using 4-bit symmetric quantization. | ||
433 | // * Performs GEMM. | ||
434 | 1320 | const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
435 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K); |
436 | 2200 | const auto [ref_rhs_qsi4, ref_rhs_scales] = | |
437 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K); |
438 | |||
439 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
880 | const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( |
440 |
6/12✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 440 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 440 times.
✗ Branch 11 not taken.
|
880 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), |
441 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(), |
442 | 440 | std::numeric_limits<float>::max()); | |
443 | |||
444 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto m_step = ukernel_variant.interface.get_m_step(); |
445 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
|
440 | ASSERT_TRUE(m_step % mr == 0); |
446 | |||
447 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto n_step = ukernel_variant.interface.get_n_step(); |
448 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
|
440 | ASSERT_TRUE(n_step % nr == 0); |
449 | |||
450 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
880 | const auto rect = portion.compute_portion(M, N, m_step, n_step); |
451 |
4/8✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 440 times.
|
440 | if (rect.height() == 0 || rect.width() == 0) { |
452 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
453 | } | ||
454 | |||
455 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto lhs_start_row = rect.start_row(); |
456 | 440 | size_t lhs_stride = K * sizeof(float); | |
457 | |||
458 | // Runs the LHS packing micro-kernel. | ||
459 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); |
460 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | Buffer imp_packed_lhs(imp_packed_lhs_size); |
461 | |||
462 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); |
463 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); |
464 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
465 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 440 times.
✗ Branch 15 not taken.
|
440 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
466 | |||
467 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | kai_run_lhs_quant_pack_qai8dxp_f32( |
468 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, |
469 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride, |
470 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | imp_packed_lhs.data() + lhs_packed_offset); |
471 | |||
472 |
3/6✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
|
880 | const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); |
473 | // Runs the RHS packing micro-kernel. | ||
474 | // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. | ||
475 | // * Packs the RHS matrix. | ||
476 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
880 | const auto ref_rhs_qsu4_padded = pad_row<UInt4>( |
477 |
4/8✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
|
440 | ref_rhs_qsu4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2)); |
478 | |||
479 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr); |
480 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto rhs_start_row = rect.start_col(); |
481 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr); |
482 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); |
483 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
|
440 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
484 | |||
485 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | auto rhs_offset = ukernel_variant.get_rhs_offset(rhs_start_row, round_up_division(K, 2)); |
486 | 440 | size_t bias_offset = rhs_start_row * sizeof(float); | |
487 | 440 | size_t scale_offset = rhs_start_row * sizeof(float); | |
488 | |||
489 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
490 | 440 | kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{}; | |
491 | 440 | params.lhs_zero_point = 1; | |
492 | 440 | params.rhs_zero_point = 8; | |
493 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
880 | ukernel_variant.run_rhs_pack( |
494 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | 1, rect.width() /* n */, K, nr, kr, sr, |
495 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_padded.data() + rhs_offset), |
496 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | reinterpret_cast<const float*>(ref_biases.data() + bias_offset), |
497 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | reinterpret_cast<const float*>(ref_rhs_scales.data() + scale_offset), imp_packed_rhs.data() + rhs_packed_offset, |
498 | 0, &params); | |
499 | |||
500 | 440 | const auto dst_stride = N * sizeof(float); | |
501 |
3/6✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
|
440 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
502 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); |
503 |
4/16✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 440 times.
|
440 | ASSERT_EQ(dst_offset, ref_dst_offset); |
504 | // Runs the GEMM micro-kernel. | ||
505 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
506 |
5/18✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 440 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 440 times.
|
440 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
507 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
440 | Buffer imp_dst(imp_dst_size); |
508 |
1/2✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
|
880 | ukernel_variant.interface.run_matmul( |
509 |
3/6✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
|
440 | rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, |
510 |
2/4✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
|
440 | imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), |
511 | 440 | N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | |
512 | |||
513 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
514 | // tested. | ||
515 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 11968 times.
✓ Branch 2 taken 11528 times.
✓ Branch 3 taken 440 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 440 times.
|
11968 | for (size_t y = 0; y < rect.height(); ++y) { |
516 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 848380 times.
✓ Branch 2 taken 836852 times.
✓ Branch 3 taken 11528 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 11528 times.
|
848380 | for (size_t x = 0; x < rect.width(); ++x) { |
517 | 1673704 | const auto imp_value = | |
518 |
4/8✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 836852 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 836852 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 836852 times.
✗ Branch 7 not taken.
|
836852 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
519 | 1673704 | const auto ref_value = | |
520 |
4/8✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 836852 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 836852 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 836852 times.
✗ Branch 7 not taken.
|
836852 | read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
521 |
1/2✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
|
836852 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value; |
522 | |||
523 |
1/2✓ Branch 0 taken 836852 times.
✗ Branch 1 not taken.
|
836852 | if (rel_error > 0.0001F) { |
524 | ✗ | ASSERT_EQ(imp_value, ref_value); | |
525 | ✗ | } | |
526 | 836852 | } | |
527 | 11528 | } | |
528 | 800 | } | |
529 | |||
530 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
2402 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) { |
531 | 3040 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
532 | 1600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
533 | |||
534 |
2/4✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
|
800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
535 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
536 | } | ||
537 |
2/2✓ Branch 0 taken 360 times.
✓ Branch 1 taken 440 times.
|
800 | if (ukernel_variant.rhs_pack_type == RhsPackType::NxK) { |
538 |
3/6✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
|
440 | GTEST_SKIP() << "Wrong type. This test for KxN"; |
539 | } | ||
540 | |||
541 | 360 | const uint32_t seed = 0; | |
542 | |||
543 | 720 | const size_t M = matmul_shape.m; | |
544 | 720 | const size_t N = matmul_shape.n; | |
545 | 720 | const size_t K = matmul_shape.k; | |
546 | |||
547 | 360 | const auto mr = ukernel_variant.interface.get_mr(); | |
548 | 360 | const auto nr = ukernel_variant.interface.get_nr(); | |
549 | 360 | const auto kr = ukernel_variant.interface.get_kr(); | |
550 | 360 | const auto sr = ukernel_variant.interface.get_sr(); | |
551 | |||
552 | // Generates input data. | ||
553 | 360 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
554 | |||
555 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | std::uniform_real_distribution<float> dist(-10.0, 1.0); |
556 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | std::mt19937 rnd(seed + 1); |
557 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
1546920 | const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); }); |
558 | |||
559 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto ref_biases = fill_random<float>(N, seed + 2); |
560 | |||
561 | // Transposed(nxk) RHS dimensions | ||
562 | 360 | const size_t ref_rhs_qsi4_nxk_stride = K; | |
563 | |||
564 | // Non-Transposed(kxn) RHS dimensions | ||
565 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2); |
566 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(K * ref_rhs_qsi4_kxn_stride, 2); |
567 | |||
568 | // Runs the reference implementation. | ||
569 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
570 | // * Quantizes the RHS matrix using 4-bit symmetric quantization. | ||
571 | // * Performs GEMM. | ||
572 | 1080 | const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
573 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K); |
574 | 1800 | const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] = | |
575 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K); |
576 | |||
577 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto ref_rhs_qsi4 = transpose_with_padding<Int4>( |
578 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, |
579 | 360 | ref_rhs_qsi4_kxn_size_bytes); | |
580 | |||
581 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
720 | const auto ref_dst = matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( |
582 |
5/10✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 360 times.
✗ Branch 9 not taken.
|
720 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), |
583 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(), |
584 | 360 | std::numeric_limits<float>::max()); | |
585 | |||
586 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto m_step = ukernel_variant.interface.get_m_step(); |
587 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
|
360 | ASSERT_TRUE(m_step % mr == 0); |
588 | |||
589 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto n_step = ukernel_variant.interface.get_n_step(); |
590 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
|
360 | ASSERT_TRUE(n_step % nr == 0); |
591 | |||
592 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
720 | const auto rect = portion.compute_portion(M, N, m_step, n_step); |
593 |
4/8✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 360 times.
|
360 | if (rect.height() == 0 || rect.width() == 0) { |
594 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
595 | } | ||
596 | |||
597 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto lhs_start_row = rect.start_row(); |
598 | 360 | size_t lhs_stride = K * sizeof(float); | |
599 | |||
600 | // Runs the LHS packing micro-kernel. | ||
601 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); |
602 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | Buffer imp_packed_lhs(imp_packed_lhs_size); |
603 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); |
604 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); |
605 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
606 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 360 times.
✗ Branch 15 not taken.
|
360 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
607 | |||
608 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | kai_run_lhs_quant_pack_qai8dxp_f32( |
609 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, |
610 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride, |
611 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | imp_packed_lhs.data() + lhs_packed_offset); |
612 | |||
613 | // Runs the RHS packing micro-kernel. | ||
614 | // * Generates the 4-bit signed symmetric quantized input for the micro-kernel. | ||
615 | // * Packs the RHS matrix. | ||
616 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
720 | const auto ref_rhs_qsi4_padded = pad_row<Int4>( |
617 |
4/8✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
|
360 | ref_rhs_qsi4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2)); |
618 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr); |
619 | |||
620 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto rhs_start_row = rect.start_col(); |
621 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr); |
622 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); |
623 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
|
360 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
624 | |||
625 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
626 | 360 | kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{}; | |
627 | 360 | params.lhs_zero_point = 1; | |
628 | 360 | params.rhs_zero_point = 0; | |
629 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
720 | ukernel_variant.run_rhs_pack( |
630 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()), |
631 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | reinterpret_cast<const float*>(ref_biases.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()), |
632 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | imp_packed_rhs.data(), 0, &params); |
633 | |||
634 | 360 | const auto dst_stride = N * sizeof(float); | |
635 |
3/6✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
|
360 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
636 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); |
637 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
|
360 | ASSERT_EQ(dst_offset, ref_dst_offset); |
638 | |||
639 | // Runs the GEMM micro-kernel. | ||
640 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
641 |
5/18✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 360 times.
|
360 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
642 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | Buffer imp_dst(imp_dst_size); |
643 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
720 | ukernel_variant.interface.run_matmul( |
644 |
3/6✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
|
360 | rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, |
645 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), |
646 | 360 | N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | |
647 | |||
648 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
649 | // tested. | ||
650 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 9792 times.
✓ Branch 2 taken 9432 times.
✓ Branch 3 taken 360 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 360 times.
|
9792 | for (size_t y = 0; y < rect.height(); ++y) { |
651 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 667150 times.
✓ Branch 2 taken 657718 times.
✓ Branch 3 taken 9432 times.
✓ Branch 4 taken 9432 times.
✗ Branch 5 not taken.
|
667150 | for (size_t x = 0; x < rect.width(); ++x) { |
652 | 1315436 | const auto imp_value = | |
653 |
4/8✓ Branch 0 taken 657718 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 657718 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 657718 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 657718 times.
✗ Branch 7 not taken.
|
657718 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
654 | 1315436 | const auto ref_value = | |
655 |
4/8✓ Branch 0 taken 657718 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 657718 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 657718 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 657718 times.
✗ Branch 7 not taken.
|
657718 | read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
656 |
1/2✓ Branch 0 taken 657718 times.
✗ Branch 1 not taken.
|
657718 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value; |
657 | |||
658 |
1/2✓ Branch 0 taken 657718 times.
✗ Branch 1 not taken.
|
657718 | if (rel_error > 0.0001F) { |
659 | ✗ | ASSERT_EQ(imp_value, ref_value); | |
660 | ✗ | } | |
661 | 657718 | } | |
662 | 9432 | } | |
663 | 800 | } | |
664 | |||
665 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
2402 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) { |
666 | 3040 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
667 | 1600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
668 | |||
669 |
2/4✓ Branch 0 taken 800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 800 times.
✗ Branch 3 not taken.
|
800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
670 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
671 | } | ||
672 |
2/2✓ Branch 0 taken 360 times.
✓ Branch 1 taken 440 times.
|
800 | if (ukernel_variant.rhs_pack_type == RhsPackType::NxK) { |
673 |
3/6✓ Branch 0 taken 440 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 5 not taken.
|
440 | GTEST_SKIP() << "Wrong type. This test for KxN"; |
674 | } | ||
675 | |||
676 | 360 | const uint32_t seed = 0; | |
677 | |||
678 | 720 | const size_t M = matmul_shape.m; | |
679 | 720 | const size_t N = matmul_shape.n; | |
680 | 720 | const size_t K = matmul_shape.k; | |
681 | |||
682 | 360 | const auto mr = ukernel_variant.interface.get_mr(); | |
683 | 360 | const auto nr = ukernel_variant.interface.get_nr(); | |
684 | 360 | const auto kr = ukernel_variant.interface.get_kr(); | |
685 | 360 | const auto sr = ukernel_variant.interface.get_sr(); | |
686 | |||
687 | // Generates input data. | ||
688 | 360 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
689 | |||
690 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | std::uniform_real_distribution<float> dist(-10.0, 1.0); |
691 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | std::mt19937 rnd(seed + 1); |
692 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
1546920 | const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); }); |
693 | |||
694 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto ref_biases = fill_random<float>(N, seed + 2); |
695 | |||
696 | // Transposed(nxk) RHS dimensions | ||
697 | 360 | const size_t ref_rhs_qsi4_nxk_stride = K; | |
698 | |||
699 | // Non-Transposed(kxn) RHS dimensions | ||
700 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2); |
701 | 360 | const size_t ref_rhs_qsi4_kxn_size = K * ref_rhs_qsi4_kxn_stride; | |
702 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(ref_rhs_qsi4_kxn_size, 2); |
703 | |||
704 | // Runs the reference implementation. | ||
705 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
706 | // * Quantizes the RHS matrix using 4-bit symmetric quantization. | ||
707 | // * Performs GEMM. | ||
708 | 1080 | const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
709 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K); |
710 | 1800 | const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] = | |
711 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K); |
712 | |||
713 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto ref_rhs_qsi4 = transpose_with_padding<Int4>( |
714 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, |
715 | 360 | ref_rhs_qsi4_kxn_size_bytes); | |
716 | |||
717 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
720 | const auto ref_dst = matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( |
718 |
5/10✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 360 times.
✗ Branch 9 not taken.
|
720 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), |
719 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(), |
720 | 360 | std::numeric_limits<float>::max()); | |
721 | |||
722 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto m_step = ukernel_variant.interface.get_m_step(); |
723 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
|
360 | ASSERT_TRUE(m_step % mr == 0); |
724 | |||
725 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto n_step = ukernel_variant.interface.get_n_step(); |
726 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
|
360 | ASSERT_TRUE(n_step % nr == 0); |
727 | |||
728 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
720 | const auto rect = portion.compute_portion(M, N, m_step, n_step); |
729 |
4/8✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 360 times.
|
360 | if (rect.height() == 0 || rect.width() == 0) { |
730 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
731 | } | ||
732 | |||
733 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto lhs_start_row = rect.start_row(); |
734 | 360 | size_t lhs_stride = K * sizeof(float); | |
735 | |||
736 | // Runs the LHS packing micro-kernel. | ||
737 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); |
738 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | Buffer imp_packed_lhs(imp_packed_lhs_size); |
739 | |||
740 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); |
741 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); |
742 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
743 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 360 times.
✗ Branch 15 not taken.
|
360 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
744 | |||
745 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | kai_run_lhs_quant_pack_qai8dxp_f32( |
746 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, |
747 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride, |
748 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | imp_packed_lhs.data() + lhs_packed_offset); |
749 | |||
750 | // Runs the RHS packing micro-kernel. | ||
751 | // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. | ||
752 | // * Packs the RHS matrix. | ||
753 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), ref_rhs_qsi4_kxn_size); |
754 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
720 | const auto ref_rhs_qsu4_padded = pad_row<UInt4>( |
755 |
4/8✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
|
360 | ref_rhs_qsu4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2)); |
756 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr); |
757 | |||
758 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto rhs_start_row = rect.start_col(); |
759 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr); |
760 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); |
761 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
|
360 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
762 | |||
763 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
764 | 360 | kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{}; | |
765 | 360 | params.lhs_zero_point = 1; | |
766 | 360 | params.rhs_zero_point = 8; | |
767 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
720 | ukernel_variant.run_rhs_pack( |
768 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_padded.data()), |
769 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | reinterpret_cast<const float*>(ref_biases.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()), |
770 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | imp_packed_rhs.data(), 0, &params); |
771 | |||
772 | 360 | const auto dst_stride = N * sizeof(float); | |
773 |
3/6✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
|
360 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
774 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); |
775 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
|
360 | ASSERT_EQ(dst_offset, ref_dst_offset); |
776 | |||
777 | // Runs the GEMM micro-kernel. | ||
778 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
779 |
5/18✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 360 times.
|
360 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
780 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | Buffer imp_dst(imp_dst_size); |
781 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
720 | ukernel_variant.interface.run_matmul( |
782 |
3/6✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
|
360 | rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, |
783 |
2/4✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
|
360 | imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), |
784 | 360 | N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | |
785 | |||
786 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
787 | // tested. | ||
788 |
1/2✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
|
360 | DefaultMismatchHandler handler(0, 0.1, 0, 0.05); |
789 | 360 | DataFormat dst_format = DataFormat(DataType::FP32); | |
790 |
3/6✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
|
360 | const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); |
791 |
4/16✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 360 times.
|
360 | ASSERT_TRUE(success); |
792 | 800 | } | |
793 | |||
794 |
15/46✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 6 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 6 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 6 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 4800 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✓ Branch 44 taken 4800 times.
✗ Branch 45 not taken.
|
9607 | INSTANTIATE_TEST_SUITE_P( |
795 | MatMul, MatMulTest_f32_qai8dxp_qsi4cxp, | ||
796 | testing::Combine( | ||
797 | testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.size()), | ||
798 | testing::Values( | ||
799 | MatMulShape{16, 32, 64}, // | ||
800 | MatMulShape{16, 32, 36}, // | ||
801 | MatMulShape{15, 35, 65}, // | ||
802 | MatMulShape{8, 32, 64}, // | ||
803 | MatMulShape{15, 31, 45}, // | ||
804 | MatMulShape{1, 35, 65}, // | ||
805 | MatMulShape{1, 128, 32}, // | ||
806 | MatMulShape{64, 128, 32}, // | ||
807 | MatMulShape{1, 225, 55}, // | ||
808 | MatMulShape{125, 200, 56}), | ||
809 | testing::Values( | ||
810 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
811 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
812 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
813 | MatrixPortion(0, 0.5, 1, 0.8) // Somewhere in the middle. | ||
814 | )), | ||
815 | [](const auto& info) { | ||
816 | const auto variant_idx = std::get<0>(info.param); | ||
817 | const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_idx).name}; | ||
818 | const auto shape = std::get<MatMulShape>(info.param); | ||
819 | const auto portion = std::get<2>(info.param); | ||
820 | |||
821 | return test_description(name, shape, portion, true); | ||
822 | }); | ||
823 | |||
824 | } // namespace kai::test | ||
825 |