Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <cstdlib> | ||
13 | #include <functional> | ||
14 | #include <optional> | ||
15 | #include <sstream> | ||
16 | #include <string> | ||
17 | #include <string_view> | ||
18 | #include <tuple> | ||
19 | #include <unordered_map> | ||
20 | |||
21 | #include "kai/kai_common.h" | ||
22 | #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa.h" | ||
23 | #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" | ||
24 | #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h" | ||
25 | #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" | ||
26 | #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h" | ||
27 | #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa.h" | ||
28 | #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" | ||
29 | #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_interface.h" | ||
30 | #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h" | ||
31 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h" | ||
32 | #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" | ||
33 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" | ||
34 | #include "test/common/buffer.hpp" | ||
35 | #include "test/common/cpu_info.hpp" | ||
36 | #include "test/common/matmul_test_common.hpp" | ||
37 | #include "test/common/matrix_portion.hpp" | ||
38 | #include "test/common/memory.hpp" | ||
39 | #include "test/common/rect.hpp" | ||
40 | #include "test/common/sme.hpp" | ||
41 | #include "test/reference/binary_elementwise.hpp" | ||
42 | #include "test/reference/clamp.hpp" | ||
43 | #include "test/reference/fill.hpp" | ||
44 | #include "test/reference/matmul.hpp" | ||
45 | #include "test/reference/matmul_pack.hpp" | ||
46 | #include "test/reference/quantize.hpp" | ||
47 | #include "test/reference/reduce.hpp" | ||
48 | #include "test/reference/reorder.hpp" | ||
49 | #include "test/reference/transpose.hpp" | ||
50 | |||
51 | namespace kai::test { | ||
52 | |||
53 | // Ensure internal linkage for all functionality local to this test file | ||
54 | namespace { | ||
55 | |||
56 | struct KChunk { | ||
57 | size_t count; | ||
58 | size_t length; | ||
59 | }; | ||
60 | |||
61 |
3/6✓ Branch 0 taken 3336 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3336 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3336 times.
✗ Branch 5 not taken.
|
3336 | struct LhsPackKernel { |
62 | std::function<size_t(size_t mr)> get_m_step; | ||
63 | std::function<size_t(size_t m_idx, size_t lhs_stride)> get_lhs_offset; | ||
64 | std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)> get_packed_lhs_offset; | ||
65 | std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> get_packed_lhs_size; | ||
66 | std::function<void( | ||
67 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||
68 | void* lhs_packed)> | ||
69 | pack; | ||
70 | }; | ||
71 | |||
72 |
2/4✓ Branch 0 taken 46624 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 46624 times.
✗ Branch 3 not taken.
|
46624 | struct LhsPackIndirectKernel { |
73 | std::function<size_t()> get_m_step; | ||
74 | std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_packed_lhs_offset; | ||
75 | std::function<size_t(size_t m, size_t k_chunk_count, size_t k_chunk_length)> get_packed_lhs_size; | ||
76 | std::function<void( | ||
77 | size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, | ||
78 | const void* zero, void* packed_lhs)> | ||
79 | pack; | ||
80 | }; | ||
81 | |||
82 |
5/10✓ Branch 0 taken 4176 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4176 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4176 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4176 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4176 times.
✗ Branch 9 not taken.
|
4176 | struct RhsPackKernel { |
83 | std::function<size_t()> get_n_step; | ||
84 | std::function<size_t(size_t n_idx)> get_rhs_offset; | ||
85 | std::function<size_t(size_t n_idx)> get_bias_offset; | ||
86 | std::function<size_t(size_t n_idx)> get_scale_offset; | ||
87 | std::function<size_t(size_t n_idx, size_t k)> get_packed_rhs_offset; | ||
88 | std::function<size_t(size_t n, size_t k)> get_packed_rhs_size; | ||
89 | std::function<void( | ||
90 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, | ||
91 | const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, | ||
92 | const struct kai_rhs_pack_qsi8cx_params* params)> | ||
93 | pack; | ||
94 | }; | ||
95 | |||
96 |
5/10✓ Branch 0 taken 46624 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 46624 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 46624 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 46624 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 46624 times.
✗ Branch 9 not taken.
|
46624 | struct RhsPackIndirectKernel { |
97 | std::function<size_t()> get_n_step; | ||
98 | std::function<size_t(size_t n_idx)> get_rhs_offset; | ||
99 | std::function<size_t(size_t n_idx)> get_bias_offset; | ||
100 | std::function<size_t(size_t n_idx)> get_scale_offset; | ||
101 | std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_packed_rhs_offset; | ||
102 | std::function<size_t(size_t n, size_t k_chunk_count, size_t k_chunk_length)> get_packed_rhs_size; | ||
103 | std::function<void( | ||
104 | size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride, const void* rhs, const void* bias, | ||
105 | const void* scale, void* rhs_packed, const kai_rhs_pack_qsi8cx_params* params)> | ||
106 | pack; | ||
107 | }; | ||
108 | |||
109 |
9/18✓ Branch 0 taken 4176 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4176 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4176 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4176 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4176 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4176 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4176 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4176 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4176 times.
✗ Branch 17 not taken.
|
4176 | struct MatMulKernel { |
110 | std::function<size_t(void)> get_m_step; | ||
111 | std::function<size_t(void)> get_n_step; | ||
112 | std::function<size_t(void)> get_mr; | ||
113 | std::function<size_t(void)> get_nr; | ||
114 | std::function<size_t(void)> get_kr; | ||
115 | std::function<size_t(void)> get_sr; | ||
116 | std::function<size_t(size_t m_idx, size_t k)> get_packed_lhs_offset; | ||
117 | std::function<size_t(size_t n_idx, size_t k)> get_packed_rhs_offset; | ||
118 | std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride)> get_dst_offset; | ||
119 | std::function<size_t(size_t m, size_t n)> get_dst_size; | ||
120 | std::function<void( | ||
121 | size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, | ||
122 | size_t dst_stride_col, const kai_matmul_requantize32_params* params)> | ||
123 | matmul; | ||
124 | }; | ||
125 | |||
126 |
5/10✓ Branch 0 taken 46624 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 46624 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 46624 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 46624 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 46624 times.
✗ Branch 9 not taken.
|
46624 | struct MatMulIndirectKernel { |
127 | std::function<size_t(void)> get_m_step; | ||
128 | std::function<size_t(void)> get_n_step; | ||
129 | std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_offset; | ||
130 | std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_offset; | ||
131 | std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride_row)> get_dst_offset; | ||
132 | std::function<size_t(size_t m, size_t n)> get_dst_size; | ||
133 | std::function<void( | ||
134 | size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_lenght, const void* lhs_packed, const void* rhs_packed, | ||
135 | void* dst, size_t dst_stride_row, const kai_matmul_requantize32_params* params)> | ||
136 | imatmul; | ||
137 | }; | ||
138 | |||
139 | /// Make sure that interface matches for qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa | ||
140 | const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& | ||
141 | 1 | get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface() { | |
142 | static kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel ukernel; | ||
143 | |||
144 | 1 | ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
145 | 1 | ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
146 | 1 | ukernel.get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
147 | 1 | ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
148 | 1 | ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
149 | 1 | ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
150 | 1 | ukernel.get_lhs_packed_offset = | |
151 | kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | ||
152 | 1 | ukernel.get_rhs_packed_offset = | |
153 | kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | ||
154 | 1 | ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
155 | 1 | ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
156 | 1 | ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
157 | |||
158 | 1 | return ukernel; | |
159 | } | ||
160 | |||
161 | /// Make sure that interface matches for qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa | ||
162 | const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& | ||
163 | 1 | get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface() { | |
164 | static kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel ukernel; | ||
165 | |||
166 | 1 | ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
167 | 1 | ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
168 | 1 | ukernel.get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
169 | 1 | ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
170 | 1 | ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
171 | 1 | ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
172 | 1 | ukernel.get_lhs_packed_offset = | |
173 | kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | ||
174 | 1 | ukernel.get_rhs_packed_offset = | |
175 | kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | ||
176 | 1 | ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
177 | 1 | ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
178 | 1 | ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
179 | |||
180 | 1 | return ukernel; | |
181 | } | ||
182 | |||
183 | /// Make sure that interface matches for qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot | ||
184 | const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& | ||
185 | 1 | get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface() { | |
186 | static kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel; | ||
187 | |||
188 | 1 | ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
189 | 1 | ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
190 | 1 | ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
191 | 1 | ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
192 | 1 | ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
193 | 1 | ukernel.get_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
194 | 1 | ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
195 | 1 | ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
196 | 1 | ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
197 | 1 | ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
198 | |||
199 | 1 | return ukernel; | |
200 | }; | ||
201 | |||
202 | /// Make sure that interface matches qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa | ||
203 | const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& | ||
204 | 1 | get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface() { | |
205 | static kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel; | ||
206 | |||
207 | 1 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
208 | 1 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
209 | 1 | ukernel.get_lhs_packed_offset = | |
210 | kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | ||
211 | 1 | ukernel.get_rhs_packed_offset = | |
212 | kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | ||
213 | 1 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
214 | 1 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
215 | 1 | ukernel.run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
216 | |||
217 | 1 | return ukernel; | |
218 | }; | ||
219 | |||
220 | /// Make sure that interface matches qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa | ||
221 | const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& | ||
222 | 1 | get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface() { | |
223 | static kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel; | ||
224 | |||
225 | 1 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
226 | 1 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
227 | 1 | ukernel.get_lhs_packed_offset = | |
228 | kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | ||
229 | 1 | ukernel.get_rhs_packed_offset = | |
230 | kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | ||
231 | 1 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
232 | 1 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
233 | 1 | ukernel.run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
234 | |||
235 | 1 | return ukernel; | |
236 | }; | ||
237 | |||
238 | 3 | const RhsPackKernel& get_rhs_pack() { | |
239 |
3/4✓ Branch 0 taken 1 time.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
3 | static RhsPackKernel ukernel; |
240 | |||
241 | 3 | ukernel.get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
242 | 3 | ukernel.get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
243 | 3 | ukernel.get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
244 | 3 | ukernel.get_scale_offset = kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
245 | 3 | ukernel.get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
246 | 3 | ukernel.get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
247 | 3 | ukernel.pack = kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
248 | |||
249 | 3 | return ukernel; | |
250 | } | ||
251 | |||
252 | 2 | const LhsPackKernel& get_lhs_pack() { | |
253 |
3/4✓ Branch 0 taken 1 time.
✓ Branch 1 taken 1 time.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
2 | static LhsPackKernel ukernel; |
254 | |||
255 | 2 | ukernel.get_m_step = kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme; | |
256 | 2 | ukernel.get_lhs_offset = kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme; | |
257 | 2 | ukernel.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme; | |
258 | 2 | ukernel.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme; | |
259 | 2 | ukernel.pack = kai_run_lhs_pack_x8p2vlx4_x8_sme; | |
260 | |||
261 | 2 | return ukernel; | |
262 | } | ||
263 | |||
264 |
2/4✓ Branch 0 taken 4176 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4176 times.
✗ Branch 3 not taken.
|
4176 | struct MatMulVariant { |
265 | std::string_view name; ///< Test identification | ||
266 | MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr) | ||
267 | MatMulShape acc_step; ///< Accumulator shape for matmul (stepping) | ||
268 | |||
269 | std::function<bool(void)> is_supported; ///< HW support check | ||
270 | |||
271 | std::optional<LhsPackKernel> lhs_pack; ///< LHS packing micro-kernel interface | ||
272 | RhsPackKernel rhs_pack; ///< RHS packing micro-kernel interface | ||
273 | MatMulKernel matmul; ///< Matmul kernel interface | ||
274 | }; | ||
275 | |||
276 |
2/4✓ Branch 0 taken 46624 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 46624 times.
✗ Branch 3 not taken.
|
46624 | struct IndirectMatMulVariant { |
277 | std::string_view name; ///< Test identification | ||
278 | MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr) | ||
279 | MatMulShape acc_step; ///< Accumulator shape for matmul (stepping) | ||
280 | |||
281 | std::function<bool(void)> is_supported; ///< HW support check | ||
282 | |||
283 | LhsPackIndirectKernel lhs_pack; ///< LHS packing micro-kernel interface | ||
284 | RhsPackIndirectKernel rhs_pack; ///< RHS packing micro-kernel interface | ||
285 | MatMulIndirectKernel matmul; ///< Matmul kernel interface | ||
286 | }; | ||
287 | |||
288 | 1 | const std::array<MatMulVariant, 2>& get_gemm_variants() { | |
289 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
1 | static std::array<MatMulVariant, 2> variants; |
290 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
2 | static const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& ukernel_sme2 = |
291 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface(); |
292 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
2 | static const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& ukernel_sme = |
293 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 time.
|
1 | get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface(); |
294 | |||
295 | 1 | variants[0].name = "matmul_qai8_qai8p_qsi8cxp_sme"; | |
296 | 1 | variants[0].acc_pack.m = 2 * get_sme_vector_length<int32_t>(); | |
297 | 1 | variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>(); | |
298 | 1 | variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); | |
299 | 1 | variants[0].acc_step.m = 2 * get_sme_vector_length<int32_t>(); | |
300 | 1 | variants[0].acc_step.n = 2 * get_sme_vector_length<int32_t>(); | |
301 | 1 | variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t); | |
302 | 1 | variants[0].is_supported = cpu_has_sme; | |
303 | 1 | variants[0].lhs_pack = get_lhs_pack(); | |
304 | 1 | variants[0].rhs_pack = get_rhs_pack(); | |
305 | 1 | variants[0].matmul.get_m_step = ukernel_sme.get_m_step; | |
306 | 1 | variants[0].matmul.get_n_step = ukernel_sme.get_n_step; | |
307 | 1 | variants[0].matmul.get_mr = ukernel_sme.get_mr; | |
308 | 1 | variants[0].matmul.get_nr = ukernel_sme.get_nr; | |
309 | 1 | variants[0].matmul.get_kr = ukernel_sme.get_kr; | |
310 | 1 | variants[0].matmul.get_sr = ukernel_sme.get_sr; | |
311 | 1 | variants[0].matmul.get_packed_lhs_offset = ukernel_sme.get_lhs_packed_offset; | |
312 | 1 | variants[0].matmul.get_packed_rhs_offset = ukernel_sme.get_rhs_packed_offset; | |
313 | 1 | variants[0].matmul.get_dst_offset = ukernel_sme.get_dst_offset; | |
314 | 1 | variants[0].matmul.get_dst_size = ukernel_sme.get_dst_size; | |
315 | 1 | variants[0].matmul.matmul = ukernel_sme.run_matmul; | |
316 | |||
317 | 1 | variants[1].name = "matmul_qai8_qai8p_qsi8cxp_sme2"; | |
318 | 1 | variants[1].acc_pack.m = 2 * get_sme_vector_length<int32_t>(); | |
319 | 1 | variants[1].acc_pack.n = 2 * get_sme_vector_length<int32_t>(); | |
320 | 1 | variants[1].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); | |
321 | 1 | variants[1].acc_step.m = 2 * get_sme_vector_length<int32_t>(); | |
322 | 1 | variants[1].acc_step.n = 2 * get_sme_vector_length<int32_t>(); | |
323 | 1 | variants[1].acc_step.k = sizeof(int32_t) / sizeof(int8_t); | |
324 | 1 | variants[1].is_supported = cpu_has_sme2; | |
325 | 1 | variants[1].lhs_pack = get_lhs_pack(); | |
326 | 1 | variants[1].rhs_pack = get_rhs_pack(); | |
327 | 1 | variants[1].matmul.get_m_step = ukernel_sme2.get_m_step; | |
328 | 1 | variants[1].matmul.get_n_step = ukernel_sme2.get_n_step; | |
329 | 1 | variants[1].matmul.get_mr = ukernel_sme2.get_mr; | |
330 | 1 | variants[1].matmul.get_nr = ukernel_sme2.get_nr; | |
331 | 1 | variants[1].matmul.get_kr = ukernel_sme2.get_kr; | |
332 | 1 | variants[1].matmul.get_sr = ukernel_sme2.get_sr; | |
333 | 1 | variants[1].matmul.get_packed_lhs_offset = ukernel_sme2.get_lhs_packed_offset; | |
334 | 1 | variants[1].matmul.get_packed_rhs_offset = ukernel_sme2.get_rhs_packed_offset; | |
335 | 1 | variants[1].matmul.get_dst_offset = ukernel_sme2.get_dst_offset; | |
336 | 1 | variants[1].matmul.get_dst_size = ukernel_sme2.get_dst_size; | |
337 | 1 | variants[1].matmul.matmul = ukernel_sme2.run_matmul; | |
338 | |||
339 | 1 | return variants; | |
340 | ✗ | } | |
341 | |||
342 | 1 | const std::array<IndirectMatMulVariant, 2>& get_indirect_gemm_variants() { | |
343 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
1 | static std::array<IndirectMatMulVariant, 2> variants; |
344 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
2 | static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel_sme = |
345 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface(); |
346 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
2 | static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel_sme2 = |
347 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 time.
|
1 | get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface(); |
348 | |||
349 | 1 | variants[0].name = "imatmul_qai8_qai8p_qsi8cxp_sme"; | |
350 | 1 | variants[0].acc_pack.m = 2 * get_sme_vector_length<int32_t>(); | |
351 | 1 | variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>(); | |
352 | 1 | variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); | |
353 | 1 | variants[0].acc_step.m = 2 * get_sme_vector_length<int32_t>(); | |
354 | 1 | variants[0].acc_step.n = 2 * get_sme_vector_length<int32_t>(); | |
355 | 1 | variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t); | |
356 | 1 | variants[0].is_supported = cpu_has_sme; | |
357 | 1 | variants[0].lhs_pack.get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
358 | 1 | variants[0].lhs_pack.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
359 | 1 | variants[0].lhs_pack.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
360 | 1 | variants[0].lhs_pack.pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
361 | 1 | variants[0].rhs_pack.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
362 | 1 | variants[0].rhs_pack.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
363 | 1 | variants[0].rhs_pack.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
364 | 1 | variants[0].rhs_pack.get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
365 | 1 | variants[0].rhs_pack.get_packed_rhs_offset = | |
366 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | ||
367 | 1 | variants[0].rhs_pack.get_packed_rhs_size = | |
368 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | ||
369 | 1 | variants[0].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
370 | 1 | variants[0].matmul.get_m_step = ukernel_sme.get_m_step; | |
371 | 1 | variants[0].matmul.get_n_step = ukernel_sme.get_n_step; | |
372 | 1 | variants[0].matmul.get_lhs_packed_offset = ukernel_sme.get_lhs_packed_offset; | |
373 | 1 | variants[0].matmul.get_rhs_packed_offset = ukernel_sme.get_rhs_packed_offset; | |
374 | 1 | variants[0].matmul.get_dst_offset = ukernel_sme.get_dst_offset; | |
375 | 1 | variants[0].matmul.get_dst_size = ukernel_sme.get_dst_size; | |
376 | 1 | variants[0].matmul.imatmul = ukernel_sme.run_imatmul; | |
377 | |||
378 | 1 | variants[1].name = "imatmul_qai8_qai8p_qsi8cxp_sme2"; | |
379 | 1 | variants[1].acc_pack.m = 2 * get_sme_vector_length<int32_t>(); | |
380 | 1 | variants[1].acc_pack.n = 2 * get_sme_vector_length<int32_t>(); | |
381 | 1 | variants[1].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); | |
382 | 1 | variants[1].acc_step.m = 2 * get_sme_vector_length<int32_t>(); | |
383 | 1 | variants[1].acc_step.n = 2 * get_sme_vector_length<int32_t>(); | |
384 | 1 | variants[1].acc_step.k = sizeof(int32_t) / sizeof(int8_t); | |
385 | 1 | variants[1].is_supported = cpu_has_sme2; | |
386 | 1 | variants[1].lhs_pack.get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
387 | 1 | variants[1].lhs_pack.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
388 | 1 | variants[1].lhs_pack.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
389 | 1 | variants[1].lhs_pack.pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
390 | 1 | variants[1].rhs_pack.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
391 | 1 | variants[1].rhs_pack.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
392 | 1 | variants[1].rhs_pack.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
393 | 1 | variants[1].rhs_pack.get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
394 | 1 | variants[1].rhs_pack.get_packed_rhs_offset = | |
395 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | ||
396 | 1 | variants[1].rhs_pack.get_packed_rhs_size = | |
397 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | ||
398 | 1 | variants[1].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
399 | 1 | variants[1].matmul.get_m_step = ukernel_sme2.get_m_step; | |
400 | 1 | variants[1].matmul.get_n_step = ukernel_sme2.get_n_step; | |
401 | 1 | variants[1].matmul.get_lhs_packed_offset = ukernel_sme2.get_lhs_packed_offset; | |
402 | 1 | variants[1].matmul.get_rhs_packed_offset = ukernel_sme2.get_rhs_packed_offset; | |
403 | 1 | variants[1].matmul.get_dst_offset = ukernel_sme2.get_dst_offset; | |
404 | 1 | variants[1].matmul.get_dst_size = ukernel_sme2.get_dst_size; | |
405 | 1 | variants[1].matmul.imatmul = ukernel_sme2.run_imatmul; | |
406 | |||
407 | 1 | return variants; | |
408 | ✗ | } | |
409 | |||
410 | 1 | const std::array<MatMulVariant, 1>& get_gemv_variants() { | |
411 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
1 | static std::array<MatMulVariant, 1> variants; |
412 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
2 | static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel = |
413 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
1 | get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface(); |
414 | |||
415 | 1 | variants[0].name = "matmul_qai8_qai8_qsi8cxp"; | |
416 | 1 | variants[0].acc_pack.m = 1; | |
417 | 1 | variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>(); | |
418 | 1 | variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); | |
419 | 1 | variants[0].acc_step.m = 1; | |
420 | 1 | variants[0].acc_step.n = 16 * get_sme_vector_length<int32_t>(); | |
421 | 1 | variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t); | |
422 | 1 | variants[0].is_supported = cpu_has_sme2; | |
423 | 1 | variants[0].lhs_pack = std::nullopt; | |
424 | 1 | variants[0].rhs_pack = get_rhs_pack(); | |
425 | 1 | variants[0].matmul.get_m_step = ukernel.get_m_step; | |
426 | 1 | variants[0].matmul.get_n_step = ukernel.get_n_step; | |
427 | 169 | variants[0].matmul.get_mr = []() -> size_t { return 1; }; | |
428 | 1 | variants[0].matmul.get_nr = ukernel.get_nr; | |
429 | 1 | variants[0].matmul.get_kr = ukernel.get_kr; | |
430 | 1 | variants[0].matmul.get_sr = ukernel.get_sr; | |
431 | 1 | variants[0].matmul.get_packed_lhs_offset = nullptr; | |
432 | 1 | variants[0].matmul.get_packed_rhs_offset = ukernel.get_rhs_packed_offset; | |
433 | 1 | variants[0].matmul.get_dst_offset = ukernel.get_dst_offset; | |
434 | 1 | variants[0].matmul.get_dst_size = ukernel.get_dst_size; | |
435 | 1 | variants[0].matmul.matmul = ukernel.run_matmul; | |
436 | |||
437 | 1 | return variants; | |
438 | ✗ | } | |
439 | |||
440 | constexpr uint32_t seed = 0; ///< Random seed used for tests | ||
441 | |||
442 | /// Quantization parameters | ||
443 | struct Quant { | ||
444 | float scale; | ||
445 | int32_t zero_point; | ||
446 | }; | ||
447 | |||
448 | /// Reference test data | ||
449 | struct TestReference { | ||
450 | Range<int8_t> clamp; | ||
451 | |||
452 | Quant qa_lhs; | ||
453 | Quant qa_dst; | ||
454 | |||
455 | Buffer lhs_qai8; | ||
456 | Buffer lhs_qai8_scales; | ||
457 | Buffer lhs_qai8_zero_points; | ||
458 | Buffer lhs_qai8_indirect; | ||
459 | Buffer lhs_qai8_indirect_packed; | ||
460 | Buffer lhs_qai8_indirect_padding; | ||
461 | size_t lhs_qai8_indirect_offset; | ||
462 | |||
463 | Buffer rhs_qsi8; | ||
464 | Buffer rhs_scales; | ||
465 | |||
466 | Buffer bias_qsi32; | ||
467 | |||
468 | Buffer dst_qsi8_clamped; | ||
469 | |||
470 | Buffer packed_lhs; | ||
471 | Buffer packed_rhs; | ||
472 | }; | ||
473 | |||
/// Value every element of the LHS padding chunk is filled with.
constexpr int8_t padding_value{0};
475 | |||
476 | // Functionality for hashing generated test data. | ||
477 | // This is particularly useful for portion testing | ||
478 | // which reuses the exact same data for all portions | ||
479 | struct TestDataId { | ||
480 | MatMulShape shape; | ||
481 | MatMulShape shape_pack; | ||
482 | size_t chunk_len; | ||
483 | bool pad_testing; | ||
484 | float clamp_ratio; | ||
485 | |||
486 | struct Hash { | ||
487 | 11085 | size_t operator()(const TestDataId& id) const { | |
488 | 11085 | return // | |
489 | 22170 | (MatMulShape::Hash{}(id.shape) << 0) ^ // | |
490 | 22170 | (MatMulShape::Hash{}(id.shape_pack) << 1) ^ // | |
491 | 22170 | (std::hash<size_t>{}(id.chunk_len) << 2) ^ // | |
492 | 22170 | (std::hash<bool>{}(id.pad_testing) << 3) ^ // | |
493 | 11085 | (std::hash<float>{}(id.clamp_ratio) << 4); | |
494 | } | ||
495 | }; | ||
496 | |||
497 | private: | ||
498 | 13342 | friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { | |
499 | 13342 | return // | |
500 |
2/2✓ Branch 0 taken 9267 times.
✓ Branch 1 taken 4075 times.
|
13342 | lhs.shape == rhs.shape && // |
501 |
1/2✓ Branch 0 taken 9267 times.
✗ Branch 1 not taken.
|
9267 | lhs.shape_pack == rhs.shape_pack && // |
502 |
2/2✓ Branch 0 taken 9231 times.
✓ Branch 1 taken 36 times.
|
9267 | lhs.chunk_len == rhs.chunk_len && // |
503 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 9231 times.
|
9231 | lhs.pad_testing == rhs.pad_testing && // |
504 | 9231 | lhs.clamp_ratio == rhs.clamp_ratio; | |
505 | } | ||
506 | }; | ||
507 | |||
508 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
509 | 1 | std::unordered_map<TestDataId, TestReference, TestDataId::Hash> g_data; | |
510 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
511 | |||
512 | /// Generate test reference data | ||
513 | 10158 | const TestReference& get_test_reference(const TestDataId& test_data_id) { | |
514 | // ============================================================ | ||
515 | // Generates input and reference output data | ||
516 | // ============================================================ | ||
517 | |||
518 | // Attempt to find test data in cache | ||
519 | 10158 | const auto data_it = g_data.find(test_data_id); | |
520 |
2/2✓ Branch 0 taken 9231 times.
✓ Branch 1 taken 927 times.
|
10158 | if (data_it != g_data.end()) { |
521 | 9231 | return data_it->second; | |
522 | } | ||
523 | |||
524 | 839526 | const auto& [shape, pack_shape, k_chunk_len, pad_testing, clamp_ratio] = test_data_id; | |
525 | |||
526 | // Generates the input data in floating-point. | ||
527 | 2781 | Buffer lhs_f32 = fill_random<float>(shape.m * shape.k, seed); | |
528 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
2781 | const Buffer rhs_f32 = fill_random<float>(shape.k * shape.n, seed); |
529 |
2/4✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
|
1854 | const Buffer bias_f32 = fill_random<float>(shape.n, seed); |
530 | |||
531 | // Quantizes the input data. | ||
532 | // * LHS: 8-bit asymmetric per-matrix quantization. | ||
533 | // * RHS: 8-bit symmetric per-channel quantization. | ||
534 | // * Bias: 32-bit symmetric per-channel quantization. | ||
535 | // | ||
536 | // Treat entire LHS as one row vector, to calculate one single pair of <scale, zero_point> | ||
537 | 8343 | auto [lhs_qai8, lhs_qai8_scales, lhs_qai8_zero_points] = | |
538 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>( |
539 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | lhs_f32.data(), 1, shape.m * shape.k, shape.m * shape.k); |
540 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
1854 | const auto lhs_scale = read_array<float>(lhs_qai8_scales.data(), 0); |
541 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
1854 | const auto lhs_zero_point = read_array<int32_t>(lhs_qai8_zero_points.data(), 0); |
542 | |||
543 | 2781 | const size_t k_chunk_count = shape.k / k_chunk_len; | |
544 | assert(k_chunk_count * k_chunk_len == shape.k); | ||
545 | |||
546 | // Setup an indirection buffer, where each "row" contains `k_chunk_count` | ||
547 | // pointers to chunks of length `k_chunk_len` in the input_buffer | ||
548 |
2/4✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
|
1854 | Buffer lhs_qai8_indirect(shape.m * k_chunk_count * sizeof(void*)); |
549 |
2/4✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
|
1854 | Buffer lhs_padding(k_chunk_len, padding_value); |
550 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | auto* lhs_qai8_indirect_ptr = reinterpret_cast<uint8_t**>(lhs_qai8_indirect.data()); |
551 |
4/4✓ Branch 0 taken 21087 times.
✓ Branch 1 taken 927 times.
✓ Branch 2 taken 21087 times.
✓ Branch 3 taken 927 times.
|
22014 | for (size_t m_i = 0; m_i < shape.m; ++m_i) { |
552 |
2/2✓ Branch 0 taken 271911 times.
✓ Branch 1 taken 21087 times.
|
292998 | for (size_t k_chunk_idx = 0; k_chunk_idx < k_chunk_count; ++k_chunk_idx) { |
553 | 271911 | const size_t idx = m_i * k_chunk_count + k_chunk_idx; | |
554 |
4/4✓ Branch 0 taken 262143 times.
✓ Branch 1 taken 9768 times.
✓ Branch 2 taken 250089 times.
✓ Branch 3 taken 12054 times.
|
271911 | if (pad_testing and m_i == 0) { |
555 | // Push padding pointers for first row | ||
556 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12054 times.
|
12054 | lhs_qai8_indirect_ptr[idx] = reinterpret_cast<uint8_t*>(lhs_padding.data()); |
557 | 12054 | } else { | |
558 | 779571 | uintptr_t offset = m_i * shape.k + k_chunk_idx * k_chunk_len; | |
559 | 259857 | lhs_qai8_indirect_ptr[idx] = reinterpret_cast<uint8_t*>(offset); | |
560 | 259857 | } | |
561 | 271911 | } | |
562 | 21087 | } | |
563 |
2/4✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
|
1854 | const auto indirection_base = reinterpret_cast<uintptr_t>(lhs_qai8.data()); |
564 | |||
565 | // Reorder indirection pointers to layout the packing micro-kernel expects | ||
566 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | Buffer lhs_qai8_indirect_packed = reorder_block<const void*>( |
567 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | reinterpret_cast<const void*>(lhs_qai8_indirect.data()), shape.m, k_chunk_count, pack_shape.m, 1); |
568 | |||
569 | // Transpose, then quantize symmetrically, then transpose back. This will give one | ||
570 | // quantization value for each column | ||
571 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
927 | const auto rhs_f32_t = transpose<float>(rhs_f32.data(), shape.k, shape.n); |
572 | 6489 | auto [rhs_qsi8_t, rhs_scales] = | |
573 |
4/8✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 927 times.
✗ Branch 7 not taken.
|
927 | quantize_symmetric_per_block_dynamic<float, int8_t, float>(rhs_f32_t.data(), shape.n, shape.k, shape.k); |
574 |
4/8✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 927 times.
✗ Branch 7 not taken.
|
1854 | auto rhs_qsi8 = transpose<int8_t>(rhs_qsi8_t.data(), shape.n, shape.k); |
575 | |||
576 | // Multiply all bias values with the LHS scale | ||
577 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
1854 | const auto bias_scales = mul<float>(&lhs_scale, 1, 1, rhs_scales.data(), 1, shape.n); |
578 | // Calculate quantized bias values, by treating bias as column, and | ||
579 | // scale using RHS scales. This will scale each bias value indiviually | ||
580 | 927 | auto bias_qsi32 = | |
581 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
927 | quantize_symmetric_per_block<float, int32_t, float>(bias_f32.data(), bias_scales.data(), shape.n, 1, 1); |
582 | |||
583 | // Runs the reference implementation of matmul to produce floating-point result. | ||
584 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | const void* const* lhs_iptr = reinterpret_cast<const void* const*>(lhs_qai8_indirect.data()); |
585 | 927 | const auto ref_dst_f32 = | |
586 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | indirect_matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>( |
587 | 2781 | shape.m, shape.n, k_chunk_count, k_chunk_len, // matmul shape | |
588 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | lhs_iptr, indirection_base, lhs_padding.data(), // LHS indirection, offset and padding |
589 | &lhs_scale, &lhs_zero_point, // LHS, scaling factor and zero point | ||
590 | 1854 | shape.m, shape.k, // LHS quantization window shape | |
591 |
2/4✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
|
927 | rhs_qsi8_t.data(), rhs_scales.data(), nullptr, // RHS scaling factors |
592 | 927 | 1, shape.k, // RHS quantization window shape | |
593 |
2/4✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
|
927 | bias_qsi32.data(), bias_scales.data(), nullptr, // Bias, scaling and zero points |
594 | 1 // Bias quantization window shape | ||
595 | ); | ||
596 | |||
597 | // Computes the output quantization information and clamping limits. | ||
598 | // | ||
599 | // To get a realistic value for the output quantization information and clamping limits | ||
600 | // and avoid uncontrolled saturation problem, these information will be calculated | ||
601 | // based on the reference floating-point output. | ||
602 | // | ||
603 | // The clamping limits will be slightly narrower than the actual range of the output | ||
604 | // so that a portion of the output will be clampped. | ||
605 | 3708 | const auto [dst_scales, dst_zero_points] = | |
606 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | compute_asymmetric_per_block_quantization_info<float, int8_t, float, int32_t>( |
607 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | ref_dst_f32.data(), 1, shape.m * shape.n, shape.m * shape.n); |
608 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
1854 | const auto dst_scale = read_array<float>(dst_scales.data(), 0); |
609 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
1854 | const auto dst_zero_point = read_array<int32_t>(dst_zero_points.data(), 0); |
610 | |||
611 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
927 | const auto ref_dst_f32_min = reduce_min<float>(ref_dst_f32.data(), shape.m * shape.n); |
612 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
927 | const auto ref_dst_f32_max = reduce_max<float>(ref_dst_f32.data(), shape.m * shape.n); |
613 | 927 | const auto ref_dst_f32_range = ref_dst_f32_max - ref_dst_f32_min; | |
614 | |||
615 | 1854 | const auto ref_dst_f32_clamp_min = ref_dst_f32_min + ref_dst_f32_range * clamp_ratio / 2; | |
616 | 1854 | const auto ref_dst_f32_clamp_max = ref_dst_f32_max - ref_dst_f32_range * clamp_ratio / 2; | |
617 | 927 | const auto dst_qai8_clamp_min = | |
618 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | quantize_asymmetric<float, int8_t, int32_t>(ref_dst_f32_clamp_min, dst_scale, dst_zero_point); |
619 | 927 | const auto dst_qai8_clamp_max = | |
620 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | quantize_asymmetric<float, int8_t, int32_t>(ref_dst_f32_clamp_max, dst_scale, dst_zero_point); |
621 | |||
622 | // Clamps and quantizes the reference output matrix. | ||
623 | 927 | const auto ref_dst_f32_clamped = | |
624 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
927 | clamp<float>(ref_dst_f32.data(), shape.m * shape.n, ref_dst_f32_clamp_min, ref_dst_f32_clamp_max); |
625 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | auto ref_dst_qsi8_clamped = quantize_asymmetric_per_block<float, int8_t, float, int32_t>( |
626 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | ref_dst_f32_clamped.data(), &dst_scale, &dst_zero_point, // values, scales, zero point |
627 | 1854 | 1, shape.m * shape.n, // data shape | |
628 | 1854 | shape.m * shape.n // quantization window width | |
629 | ); | ||
630 | |||
631 | // Runs the reference implementation of the packing micro-kernels. | ||
632 | // | ||
633 | // The reference packing micro-kernels cannot be executed earlier | ||
634 | // because we need the reference floating-point output first to have | ||
635 | // the quantization information. | ||
636 |
6/12✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 927 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 927 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 927 times.
✗ Branch 11 not taken.
|
1854 | auto packed_lhs = reorder_block<int8_t>(lhs_qai8.data(), shape.m, shape.k, pack_shape.m, pack_shape.k); |
637 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
1854 | auto packed_rhs = matmul_pack_rhs_nxk_static_quantized<int8_t, float, int32_t>( |
638 |
3/6✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
|
927 | rhs_qsi8_t.data(), rhs_scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, shape.n, shape.k, |
639 | 1854 | pack_shape.n, pack_shape.k); | |
640 | |||
641 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | TestReference& reference = g_data[test_data_id]; |
642 | 927 | reference.clamp.min = dst_qai8_clamp_min; | |
643 | 927 | reference.clamp.max = dst_qai8_clamp_max; | |
644 | 927 | reference.qa_lhs.scale = lhs_scale; | |
645 | 927 | reference.qa_lhs.zero_point = lhs_zero_point; | |
646 | 927 | reference.qa_dst.scale = dst_scale; | |
647 | 927 | reference.qa_dst.zero_point = dst_zero_point; | |
648 | 927 | reference.lhs_qai8 = std::move(lhs_qai8); | |
649 | 927 | reference.lhs_qai8_scales = std::move(lhs_qai8_scales); | |
650 | 927 | reference.lhs_qai8_zero_points = std::move(lhs_qai8_zero_points); | |
651 | 927 | reference.lhs_qai8_indirect = std::move(lhs_qai8_indirect); | |
652 | 927 | reference.lhs_qai8_indirect_packed = std::move(lhs_qai8_indirect_packed); | |
653 | 927 | reference.lhs_qai8_indirect_padding = std::move(lhs_padding); | |
654 | 927 | reference.lhs_qai8_indirect_offset = indirection_base; | |
655 | 927 | reference.rhs_qsi8 = std::move(rhs_qsi8); | |
656 | 927 | reference.rhs_scales = std::move(rhs_scales); | |
657 | 927 | reference.bias_qsi32 = std::move(bias_qsi32); | |
658 | 927 | reference.dst_qsi8_clamped = std::move(ref_dst_qsi8_clamped); | |
659 | 927 | reference.packed_lhs = std::move(packed_lhs); | |
660 | 927 | reference.packed_rhs = std::move(packed_rhs); | |
661 | |||
662 | 927 | return reference; | |
663 | 10158 | } | |
664 | |||
665 | /// Test LHS packing | ||
666 | 666 | void test_lhs_pack( | |
667 | const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { | ||
668 | − | KAI_ASSUME(variant.lhs_pack.has_value()); | |
669 | |||
670 | 1332 | const auto imp_packed_lhs_size = | |
671 | 666 | variant.lhs_pack->get_packed_lhs_size(shape.m, shape.k, variant.acc_pack.m, variant.acc_pack.k, 1); | |
672 |
2/12✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 666 times.
|
666 | ASSERT_EQ(imp_packed_lhs_size, reference.packed_lhs.size()); |
673 | |||
674 | 666 | Buffer imp_packed_lhs(imp_packed_lhs_size, 0); | |
675 |
2/4✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 666 times.
✗ Branch 3 not taken.
|
666 | const auto imp_lhs_offset = variant.lhs_pack->get_lhs_offset(output_area.start_row(), shape.k * sizeof(int8_t)); |
676 |
1/2✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
|
666 | const auto imp_packed_lhs_offset = variant.lhs_pack->get_packed_lhs_offset( |
677 |
1/2✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
|
666 | output_area.start_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1); |
678 | |||
679 |
1/2✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
|
666 | variant.lhs_pack->pack( |
680 |
1/2✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
|
666 | output_area.height(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1, 0, |
681 | 666 | reference.lhs_qai8.data() + imp_lhs_offset, shape.k * sizeof(int8_t), | |
682 | 666 | imp_packed_lhs.data() + imp_packed_lhs_offset); | |
683 | |||
684 |
3/4✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 132 times.
✓ Branch 3 taken 534 times.
|
798 | const auto imp_packed_lhs_end_offset = output_area.end_row() < shape.m |
685 |
1/2✓ Branch 0 taken 132 times.
✗ Branch 1 not taken.
|
132 | ? variant.lhs_pack->get_packed_lhs_offset( |
686 |
1/2✓ Branch 0 taken 132 times.
✗ Branch 1 not taken.
|
132 | output_area.end_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1) |
687 | 534 | : imp_packed_lhs_size; | |
688 | |||
689 | 666 | const auto* imp_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs.data()); | |
690 | 666 | const auto* ref_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(reference.packed_lhs.data()); | |
691 | |||
692 |
4/6✓ Branch 0 taken 680346 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 679680 times.
✓ Branch 3 taken 666 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 666 times.
|
680346 | for (size_t i = 0; i < reference.packed_lhs.size(); ++i) { |
693 |
4/4✓ Branch 0 taken 651264 times.
✓ Branch 1 taken 28416 times.
✓ Branch 2 taken 120576 times.
✓ Branch 3 taken 530688 times.
|
679680 | if (i >= imp_packed_lhs_offset && i < imp_packed_lhs_end_offset) { |
694 |
3/14✗ Branch 0 not taken.
✓ Branch 1 taken 530688 times.
✓ Branch 2 taken 530688 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 530688 times.
|
530688 | ASSERT_EQ(imp_packed_lhs_ptr[i], ref_packed_lhs_ptr[i]); |
695 | 530688 | } else { | |
696 |
3/14✗ Branch 0 not taken.
✓ Branch 1 taken 148992 times.
✓ Branch 2 taken 148992 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 148992 times.
|
148992 | ASSERT_EQ(imp_packed_lhs_ptr[i], 0); |
697 | } | ||
698 | 679680 | } | |
699 | 666 | } | |
700 | |||
701 | /// Test RHS packing | ||
702 | 834 | void test_rhs_pack( | |
703 | const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { | ||
704 | 834 | const auto imp_packed_rhs_size = variant.rhs_pack.get_packed_rhs_size(shape.n, shape.k); | |
705 |
2/12✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 834 times.
|
834 | ASSERT_EQ(imp_packed_rhs_size, reference.packed_rhs.size()); |
706 | 834 | Buffer imp_packed_rhs(imp_packed_rhs_size, 0); | |
707 | |||
708 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | const auto imp_rhs_offset = variant.rhs_pack.get_rhs_offset(output_area.start_col()); |
709 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | const auto imp_bias_offset = variant.rhs_pack.get_bias_offset(output_area.start_col()); |
710 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | const auto imp_scale_offset = variant.rhs_pack.get_scale_offset(output_area.start_col()); |
711 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | const auto imp_packed_rhs_offset = variant.rhs_pack.get_packed_rhs_offset(output_area.start_col(), shape.k); |
712 | |||
713 | 834 | kai_rhs_pack_qsi8cx_params imp_pack_rhs_params{}; | |
714 | 834 | imp_pack_rhs_params.lhs_zero_point = reference.qa_lhs.zero_point; | |
715 | 834 | imp_pack_rhs_params.scale_multiplier = reference.qa_lhs.scale / reference.qa_dst.scale; | |
716 | |||
717 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
834 | variant.rhs_pack.pack( |
718 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
834 | 1, output_area.width(), shape.k, variant.acc_pack.n, variant.acc_pack.k, 1, shape.n * sizeof(int8_t), |
719 | 834 | reference.rhs_qsi8.data() + imp_rhs_offset, reference.bias_qsi32.data() + imp_bias_offset, | |
720 | 834 | reference.rhs_scales.data() + imp_scale_offset, imp_packed_rhs.data() + imp_packed_rhs_offset, 0, | |
721 | &imp_pack_rhs_params); | ||
722 | |||
723 |
3/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 174 times.
✓ Branch 3 taken 660 times.
|
1008 | const auto imp_packed_rhs_end_offset = output_area.end_col() < shape.n |
724 |
2/4✓ Branch 0 taken 174 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 174 times.
✗ Branch 3 not taken.
|
174 | ? variant.rhs_pack.get_packed_rhs_offset(output_area.end_col(), shape.k) |
725 | 660 | : imp_packed_rhs_size; | |
726 | |||
727 | 834 | size_t mismatches = 0; | |
728 | 834 | const auto* imp_packed_rhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_rhs.data()); | |
729 | 834 | const auto* ref_packed_rhs_ptr = reinterpret_cast<const uint8_t*>(reference.packed_rhs.data()); | |
730 | |||
731 |
2/2✓ Branch 0 taken 3380736 times.
✓ Branch 1 taken 834 times.
|
3381570 | for (size_t i = 0; i < reference.packed_rhs.size(); ++i) { |
732 |
4/4✓ Branch 0 taken 2832384 times.
✓ Branch 1 taken 548352 times.
✓ Branch 2 taken 713088 times.
✓ Branch 3 taken 2119296 times.
|
3380736 | if (i >= imp_packed_rhs_offset && i < imp_packed_rhs_end_offset) { |
733 |
1/2✓ Branch 0 taken 2119296 times.
✗ Branch 1 not taken.
|
2119296 | if (imp_packed_rhs_ptr[i] != ref_packed_rhs_ptr[i]) { |
734 | ✗ | mismatches += 1; | |
735 | ✗ | } | |
736 | 2119296 | } else { | |
737 |
1/2✓ Branch 0 taken 1261440 times.
✗ Branch 1 not taken.
|
1261440 | if (imp_packed_rhs_ptr[i] != 0) { |
738 | ✗ | mismatches += 1; | |
739 | ✗ | } | |
740 | } | ||
741 | 3380736 | } | |
742 |
3/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
834 | ASSERT_EQ(mismatches, 0) << "There are an unexpected amount of mismatches in RHS packing"; |
743 | 834 | } | |
744 | |||
745 | 10158 | void compare_matmul_result( | |
746 | const MatMulShape& shape, const Rect& output_area, const Buffer& actual, const Buffer& reference) { | ||
747 | 10158 | size_t mismatches = 0; | |
748 | 10158 | bool printed_row = false; | |
749 | 10158 | std::ostringstream sstream; | |
750 |
2/2✓ Branch 0 taken 236958 times.
✓ Branch 1 taken 10158 times.
|
247116 | for (size_t m_i = 0; m_i < shape.m; ++m_i) { |
751 |
2/2✓ Branch 0 taken 19177308 times.
✓ Branch 1 taken 236958 times.
|
19414266 | for (size_t n_i = 0; n_i < shape.n; ++n_i) { |
752 | 19177308 | const auto i = m_i * shape.n + n_i; | |
753 |
6/8✓ Branch 0 taken 19177308 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 17235804 times.
✓ Branch 3 taken 1941504 times.
✓ Branch 4 taken 17235804 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 14725092 times.
✓ Branch 7 taken 2510712 times.
|
33902400 | const auto in_area = m_i >= output_area.start_row() && m_i < output_area.end_row() && |
754 |
4/6✓ Branch 0 taken 14725092 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13058724 times.
✓ Branch 3 taken 1666368 times.
✓ Branch 4 taken 13058724 times.
✗ Branch 5 not taken.
|
14725092 | n_i >= output_area.start_col() && n_i < output_area.end_col(); |
755 | |||
756 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 19177308 times.
|
19177308 | const auto imp_value = read_array<int8_t>(actual.data(), i); |
757 |
3/4✓ Branch 0 taken 11567574 times.
✓ Branch 1 taken 7609734 times.
✓ Branch 2 taken 11567574 times.
✗ Branch 3 not taken.
|
19177308 | const auto ref_value = in_area ? read_array<int8_t>(reference.data(), i) : 0; |
758 | 19177308 | const auto error = std::abs(imp_value - ref_value); | |
759 | 19177308 | const auto threshold = in_area ? 1 : 0; | |
760 | 19177308 | const bool mismatch = error > threshold; | |
761 |
1/2✓ Branch 0 taken 19177308 times.
✗ Branch 1 not taken.
|
19177308 | if (mismatch) { |
762 | ✗ | if (not printed_row) { | |
763 | ✗ | sstream << " row=" << m_i << ", columns: "; | |
764 | ✗ | printed_row = true; | |
765 | ✗ | } | |
766 | ✗ | sstream << n_i << ", "; | |
767 | ✗ | } | |
768 | 19177308 | mismatches += static_cast<size_t>(mismatch); | |
769 | 19177308 | } | |
770 |
1/2✓ Branch 0 taken 236958 times.
✗ Branch 1 not taken.
|
236958 | if (printed_row) { |
771 | ✗ | sstream << "\n"; | |
772 | ✗ | } | |
773 | 236958 | printed_row = false; | |
774 | 236958 | } | |
775 |
3/20✓ Branch 0 taken 10158 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 10158 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✓ Branch 19 taken 10158 times.
|
10158 | ASSERT_EQ(mismatches, 0) << "Mismatches between reference result and actual result:\n" << sstream.str(); |
776 | 10158 | } | |
777 | |||
778 | /// Test MatMul of GEMM/GEMV like kernel | ||
779 | 834 | void test_matmul( | |
780 | const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { | ||
781 | 834 | const auto imp_dst_size = variant.matmul.get_dst_size(shape.m, shape.n); | |
782 |
2/12✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 834 times.
|
834 | ASSERT_EQ(imp_dst_size, reference.dst_qsi8_clamped.size()); |
783 | |||
784 | 834 | Buffer imp_dst(imp_dst_size, 0); | |
785 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
3336 | const auto [imp_lhs_offset, lhs_data] = [&]() -> std::tuple<size_t, const Buffer&> { |
786 |
2/2✓ Branch 0 taken 666 times.
✓ Branch 1 taken 168 times.
|
834 | if (variant.lhs_pack.has_value()) { |
787 | 666 | return {variant.matmul.get_packed_lhs_offset(output_area.start_row(), shape.k), reference.packed_lhs}; | |
788 | } | ||
789 | 168 | return {output_area.start_row() * shape.k, reference.lhs_qai8}; | |
790 | 834 | }(); | |
791 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | const size_t imp_packed_rhs_offset = variant.matmul.get_packed_rhs_offset(output_area.start_col(), shape.k); |
792 | 1668 | const size_t imp_dst_offset = | |
793 |
3/6✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
|
834 | variant.matmul.get_dst_offset(output_area.start_row(), output_area.start_col(), shape.n * sizeof(int8_t)); |
794 |
5/18✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 834 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 834 times.
|
834 | ASSERT_EQ(imp_dst_offset, output_area.start_row() * shape.n + output_area.start_col()); |
795 | |||
796 | 834 | kai_matmul_requantize32_params imp_main_params{}; | |
797 | 834 | imp_main_params.min_value = reference.clamp.min; | |
798 | 834 | imp_main_params.max_value = reference.clamp.max; | |
799 | 834 | imp_main_params.output_zero_point = reference.qa_dst.zero_point; | |
800 | |||
801 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
1668 | variant.matmul.matmul( |
802 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | output_area.height(), output_area.width(), shape.k, lhs_data.data() + imp_lhs_offset, |
803 | 834 | reference.packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t), | |
804 | sizeof(int8_t), &imp_main_params); | ||
805 | |||
806 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
834 | compare_matmul_result(shape, output_area, imp_dst, reference.dst_qsi8_clamped); |
807 | 834 | } | |
808 | |||
809 | } // namespace | ||
810 | |||
811 | using MatMulQuantizedTest = testing::TestWithParam<std::tuple<MatMulVariant, MatMulShape, MatrixPortion, float>>; | ||
812 | using IndirectMatMulQuantizedTest = | ||
813 | testing::TestWithParam<std::tuple<IndirectMatMulVariant, MatMulShape, MatrixPortion, size_t, float>>; | ||
814 | |||
815 | 834 | static std::string test_description( | |
816 | const MatMulVariant& variant, // | ||
817 | const MatMulShape& shape, // | ||
818 | const MatrixPortion& portion, float clamp_ratio) { | ||
819 | 834 | std::ostringstream sstream; | |
820 | |||
821 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
1668 | sstream << test_description(variant.name, shape, portion, true) // |
822 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | << "__clamp_ratio_" << static_cast<int>(clamp_ratio * 100); |
823 | |||
824 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
834 | return sstream.str(); |
825 | 834 | }; | |
826 | |||
827 | 9324 | static std::string test_description( | |
828 | const IndirectMatMulVariant& variant, // | ||
829 | const MatMulShape& shape, // | ||
830 | const MatrixPortion& portion, size_t k_chunk_len, float clamp_ratio) { | ||
831 | 9324 | std::ostringstream sstream; | |
832 | |||
833 |
8/16✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9324 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 9324 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 9324 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9324 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 9324 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 9324 times.
✗ Branch 15 not taken.
|
9324 | sstream << variant.name << "__M_" << shape.m << "__N_" << shape.n << "__k_chunk_count_" << shape.k << "__"; |
834 |
1/2✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
|
9324 | PrintTo(portion, &sstream); |
835 |
4/8✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9324 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 9324 times.
✗ Branch 7 not taken.
|
9324 | sstream << "__k_chunk_len_" << k_chunk_len << "__clamp_ratio_" << static_cast<int>(clamp_ratio * 100); |
836 | |||
837 |
1/2✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
|
9324 | return sstream.str(); |
838 | 9324 | }; | |
839 | |||
840 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
2504 | TEST_P(MatMulQuantizedTest, EndToEnd) { |
841 | 13344 | const auto& [variant, shape, output_portion, clamp_ratio] = GetParam(); | |
842 | |||
843 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
834 | if (!variant.is_supported()) { |
844 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
845 | } | ||
846 | |||
847 | 4170 | TestDataId test_data_id{shape, variant.acc_pack, shape.k, false, clamp_ratio}; | |
848 | 834 | const TestReference& reference = get_test_reference(test_data_id); | |
849 | |||
850 | // Check scheduling parameters | ||
851 | 1668 | const auto imp_mr = variant.matmul.get_mr(); | |
852 | 1668 | const auto imp_nr = variant.matmul.get_nr(); | |
853 | 1668 | const auto imp_kr = variant.matmul.get_kr(); | |
854 | 1668 | const auto imp_sr = variant.matmul.get_sr(); | |
855 | |||
856 |
4/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
1668 | ASSERT_EQ(imp_mr, variant.acc_pack.m); |
857 |
4/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
1668 | ASSERT_EQ(imp_nr, variant.acc_pack.n); |
858 |
4/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
1668 | ASSERT_EQ(imp_kr, variant.acc_pack.k); |
859 |
3/14✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 834 times.
|
834 | ASSERT_EQ(imp_sr, 1); |
860 | |||
861 | // Check that stepping is a multiple of accumulation | ||
862 | 1668 | const auto imp_m_step = variant.matmul.get_m_step(); | |
863 | 1668 | const auto imp_n_step = variant.matmul.get_n_step(); | |
864 |
4/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
1668 | ASSERT_EQ(imp_m_step, variant.acc_step.m); |
865 |
4/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
1668 | ASSERT_EQ(imp_n_step, variant.acc_step.n); |
866 | |||
867 | // Test kernels. Note that packing and actual stepping might not be the same | ||
868 | 4170 | const auto pack_portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_pack.m, variant.acc_pack.n); | |
869 | 834 | const auto matmul_portion = | |
870 | 3336 | output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n); | |
871 |
2/2✓ Branch 0 taken 168 times.
✓ Branch 1 taken 666 times.
|
834 | if (variant.lhs_pack.has_value()) { |
872 | 666 | test_lhs_pack(shape, variant, pack_portion, reference); | |
873 | 666 | } | |
874 | 834 | test_rhs_pack(shape, variant, pack_portion, reference); | |
875 | 834 | test_matmul(shape, variant, matmul_portion, reference); | |
876 | 834 | } | |
877 | |||
878 | namespace imatmul { | ||
879 | |||
880 | /// Perform LHS IMATMUL packing | ||
881 | 9324 | static Buffer lhs_pack( | |
882 | const LhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t m, | ||
883 | const KChunk& k_chunk) { | ||
884 | 18648 | const void* const* indirection_pointer = | |
885 | 9324 | reinterpret_cast<const void* const*>(reference.lhs_qai8_indirect_packed.data()); | |
886 | |||
887 | // Allocate buffer | ||
888 | 9324 | const size_t dst_size = variant.get_packed_lhs_size(m, k_chunk.count, k_chunk.length); | |
889 | 9324 | Buffer packed(dst_size); | |
890 | |||
891 | // Calculate offsets | ||
892 |
1/2✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
|
9324 | const size_t input_offset = portion.start_row() * k_chunk.count; |
893 |
2/4✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
|
9324 | const size_t dst_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); |
894 | |||
895 |
1/2✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
|
9324 | variant.pack( |
896 |
1/2✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
|
9324 | portion.height(), k_chunk.count, k_chunk.length, // Dimensions |
897 | 9324 | indirection_pointer + input_offset, // Indirection input | |
898 | 9324 | reference.lhs_qai8_indirect_offset, // chunk offset | |
899 | 9324 | reference.lhs_qai8_indirect_padding.data(), // padding pointer | |
900 | 9324 | packed.data() + dst_offset); | |
901 | |||
902 | 9324 | return packed; | |
903 | 9324 | } | |
904 | |||
905 | /// Perform RHS IMATMUL packing | ||
906 | 9324 | static Buffer rhs_pack( | |
907 | const RhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t n, | ||
908 | const KChunk& k_chunk) { | ||
909 | // Allocate output buffer | ||
910 | 9324 | const size_t dst_size = variant.get_packed_rhs_size(n, k_chunk.count, k_chunk.length); | |
911 | 9324 | Buffer packed(dst_size); | |
912 | |||
913 | // Calculate effective quantization parameters | ||
914 | 27972 | const kai_rhs_pack_qsi8cx_params quantization{ | |
915 | 9324 | reference.qa_lhs.zero_point, | |
916 | 9324 | reference.qa_lhs.scale / reference.qa_dst.scale, | |
917 | }; | ||
918 | |||
919 | // Calculate offsets | ||
920 |
2/4✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
|
9324 | const size_t rhs_offset = variant.get_rhs_offset(portion.start_col()); |
921 |
2/4✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
|
9324 | const size_t bias_offset = variant.get_bias_offset(portion.start_col()); |
922 |
2/4✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
|
9324 | const size_t scale_offset = variant.get_scale_offset(portion.start_col()); |
923 |
2/4✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
|
9324 | const size_t dst_offset = variant.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); |
924 | |||
925 | // Pack | ||
926 |
1/2✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
|
9324 | variant.pack( |
927 |
1/2✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
|
9324 | portion.width(), k_chunk.count, k_chunk.length, |
928 | 9324 | n * sizeof(uint8_t), // Dimensions, row stride | |
929 | 9324 | reference.rhs_qsi8.data() + rhs_offset, // RHS matrix | |
930 | 9324 | reference.bias_qsi32.data() + bias_offset, // Bias | |
931 | 9324 | reference.rhs_scales.data() + scale_offset, // Scales | |
932 | 9324 | packed.data() + dst_offset, // Output | |
933 | &quantization); | ||
934 | |||
935 | 9324 | return packed; | |
936 | 9324 | } | |
937 | |||
938 | /// Calculate the matmul result from IMATMUL kernels | ||
939 | 9324 | static Buffer matmul( | |
940 | const MatMulIndirectKernel& variant, const Rect& portion, const TestReference& reference, const Buffer& packed_lhs, | ||
941 | const Buffer& packed_rhs, const MatMulShape& shape, const KChunk& k_chunk) { | ||
942 | // Calculate portion offsets. | ||
943 | 9324 | size_t dst_offset = variant.get_dst_offset(portion.start_row(), portion.start_col(), shape.n); | |
944 | 9324 | size_t lhs_offset = variant.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); | |
945 | 9324 | size_t rhs_offset = variant.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); | |
946 | |||
947 | // Allocate output buffer | ||
948 | 9324 | const size_t dst_size = variant.get_dst_size(shape.m, shape.n); | |
949 | 9324 | Buffer dst(dst_size, 0); | |
950 | |||
951 | // Calculate effective quantization parameters | ||
952 | 9324 | kai_matmul_requantize32_params requantization{}; | |
953 | 9324 | requantization.min_value = reference.clamp.min; | |
954 | 9324 | requantization.max_value = reference.clamp.max; | |
955 | 9324 | requantization.output_zero_point = reference.qa_dst.zero_point; | |
956 | |||
957 | // Call matmul kernel | ||
958 |
1/2✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
|
18648 | variant.imatmul( |
959 |
2/4✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
|
9324 | portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions |
960 | 9324 | packed_lhs.data() + lhs_offset, // LHS | |
961 | 9324 | packed_rhs.data() + rhs_offset, // RHS | |
962 | 9324 | dst.data() + dst_offset, // DST | |
963 | 9324 | shape.n * sizeof(uint8_t), &requantization); | |
964 | |||
965 | 9324 | return dst; | |
966 | 9324 | } | |
967 | } // namespace imatmul | ||
968 | |||
969 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
27974 | TEST_P(IndirectMatMulQuantizedTest, EndToEnd) { |
970 | /* This is a bit special, as shape.k must be k_chunk_len * k_chunk_count | ||
971 | * so instead of inventing a new special kind of shape, simply multiply | ||
972 | * with `k_chunk_len` here */ | ||
973 | 55944 | const auto& [variant, shape_k_chunk, output_portion, k_chunk_len, clamp_ratio] = GetParam(); | |
974 | 27972 | const KChunk k_chunk{shape_k_chunk.k, k_chunk_len}; | |
975 | 27972 | MatMulShape shape{shape_k_chunk.m, shape_k_chunk.n, k_chunk.count * k_chunk.length}; | |
976 | |||
977 |
1/2✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
|
9324 | if (!variant.is_supported()) { |
978 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
979 | } | ||
980 | |||
981 | // Toggle padding tests when LHS has more than one row | ||
982 | 27972 | TestDataId test_data_id{shape, variant.acc_pack, k_chunk.length, shape.m > 1, clamp_ratio}; | |
983 | 9324 | const TestReference& reference = get_test_reference(test_data_id); | |
984 | 37296 | const Rect portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n); | |
985 | |||
986 | 18648 | Buffer packed_lhs = imatmul::lhs_pack(variant.lhs_pack, portion, reference, shape.m, k_chunk); | |
987 |
2/4✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
|
18648 | Buffer packed_rhs = imatmul::rhs_pack(variant.rhs_pack, portion, reference, shape.n, k_chunk); |
988 |
2/4✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
|
18648 | Buffer impl_result = imatmul::matmul(variant.matmul, portion, reference, packed_lhs, packed_rhs, shape, k_chunk); |
989 |
1/2✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
|
9324 | compare_matmul_result(shape, portion, impl_result, reference.dst_qsi8_clamped); |
990 | 9324 | } | |
991 | |||
992 | static constexpr std::array shapes{ | ||
993 | // clang-format off | ||
994 | MatMulShape{ 1, 1, 1}, | ||
995 | MatMulShape{ 1, 16, 4}, | ||
996 | MatMulShape{ 1, 16, 16}, | ||
997 | MatMulShape{ 1, 17, 4}, | ||
998 | MatMulShape{ 1, 19, 24}, | ||
999 | MatMulShape{ 1, 32, 4}, | ||
1000 | MatMulShape{ 1, 32, 32}, | ||
1001 | MatMulShape{ 1, 33,200}, | ||
1002 | MatMulShape{ 1, 49, 21}, | ||
1003 | MatMulShape{ 1, 64, 4}, | ||
1004 | MatMulShape{ 1, 65, 4}, | ||
1005 | MatMulShape{ 1, 300, 10}, | ||
1006 | MatMulShape{ 1, 512, 4}, | ||
1007 | MatMulShape{ 1, 1523, 10}, | ||
1008 | MatMulShape{ 2, 195, 50}, | ||
1009 | MatMulShape{ 3, 6, 6}, | ||
1010 | MatMulShape{ 3, 28, 25}, | ||
1011 | MatMulShape{ 3, 184,177}, | ||
1012 | MatMulShape{ 4, 16, 27}, | ||
1013 | MatMulShape{ 5, 136, 23}, | ||
1014 | MatMulShape{ 6, 18, 31}, | ||
1015 | MatMulShape{ 6, 28, 1}, | ||
1016 | MatMulShape{ 6, 29, 24}, | ||
1017 | MatMulShape{ 16, 16, 4}, | ||
1018 | MatMulShape{ 20, 30, 40}, | ||
1019 | MatMulShape{ 23, 1, 43}, | ||
1020 | MatMulShape{ 32, 14, 1}, | ||
1021 | MatMulShape{ 32, 16, 27}, | ||
1022 | MatMulShape{ 32, 32, 3}, | ||
1023 | MatMulShape{ 32, 32, 4}, | ||
1024 | MatMulShape{ 33, 29, 24}, | ||
1025 | MatMulShape{ 64, 64, 3}, | ||
1026 | MatMulShape{ 64, 64, 4}, | ||
1027 | MatMulShape{ 96, 96, 3}, | ||
1028 | MatMulShape{123, 85, 45}, | ||
1029 | MatMulShape{128, 128, 3}, | ||
1030 | MatMulShape{130, 130, 6}, | ||
1031 | // clang-format on | ||
1032 | }; | ||
1033 | |||
1034 |
14/44✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 666 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
|
1334 | INSTANTIATE_TEST_SUITE_P( |
1035 | matmul_clamp_qai8_qai8p_qsi8cxp, MatMulQuantizedTest, | ||
1036 | testing::Combine( | ||
1037 | testing::ValuesIn(get_gemm_variants()), // | ||
1038 | testing::ValuesIn(shapes), // | ||
1039 | testing::ValuesIn({ | ||
1040 | // clang-format off | ||
1041 | MatrixPortion( 0, 0, 1, 1), // Full matrix. | ||
1042 | MatrixPortion( 0, 0, 0.25, 0.25), // Top-left corner. | ||
1043 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
1044 | // clang-format on | ||
1045 | }), | ||
1046 | testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})), | ||
1047 | [](const auto& info) -> std::string { | ||
1048 | return test_description( | ||
1049 | std::get<MatMulVariant>(info.param), // | ||
1050 | std::get<MatMulShape>(info.param), // | ||
1051 | std::get<MatrixPortion>(info.param), // | ||
1052 | std::get<float>(info.param)); | ||
1053 | }); | ||
1054 | |||
1055 |
15/48✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 168 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
|
338 | INSTANTIATE_TEST_SUITE_P( |
1056 | matmul_clamp_qai8_qai8_qsi8cxp, MatMulQuantizedTest, | ||
1057 | testing::Combine( | ||
1058 | testing::ValuesIn(get_gemv_variants()), | ||
1059 | testing::ValuesIn({ | ||
1060 | // clang-format off | ||
1061 | MatMulShape{ 1, 1, 1}, | ||
1062 | MatMulShape{ 1, 16, 4}, | ||
1063 | MatMulShape{ 1, 16, 16}, | ||
1064 | MatMulShape{ 1, 17, 4}, | ||
1065 | MatMulShape{ 1, 19, 24}, | ||
1066 | MatMulShape{ 1, 32, 4}, | ||
1067 | MatMulShape{ 1, 32, 32}, | ||
1068 | MatMulShape{ 1, 33,200}, | ||
1069 | MatMulShape{ 1, 49, 21}, | ||
1070 | MatMulShape{ 1, 64, 4}, | ||
1071 | MatMulShape{ 1, 65, 4}, | ||
1072 | MatMulShape{ 1, 300, 10}, | ||
1073 | MatMulShape{ 1, 512, 4}, | ||
1074 | MatMulShape{ 1, 1523, 10}, | ||
1075 | // clang-format on | ||
1076 | }), | ||
1077 | testing::ValuesIn({ | ||
1078 | // clang-format off | ||
1079 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
1080 | MatrixPortion(0, .5, 1, .5), // Right half | ||
1081 | MatrixPortion(0, 0, 1, .5), // Left half | ||
1082 | MatrixPortion(0, .25, 1, .5) // Middle half | ||
1083 | // clang-format on | ||
1084 | }), | ||
1085 | // Clamp range | ||
1086 | testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})), | ||
1087 | [](const auto& info) -> std::string { | ||
1088 | return test_description( | ||
1089 | std::get<MatMulVariant>(info.param), // | ||
1090 | std::get<MatMulShape>(info.param), // | ||
1091 | std::get<MatrixPortion>(info.param), // | ||
1092 | std::get<float>(info.param)); | ||
1093 | }); | ||
1094 | |||
1095 |
18/60✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 time.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 time.
✗ Branch 33 not taken.
✓ Branch 34 taken 9324 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
|
18650 | INSTANTIATE_TEST_SUITE_P( |
1096 | indirect_matmul_clamp_qai8_qai8p_qsi8cxp, IndirectMatMulQuantizedTest, | ||
1097 | testing::Combine( | ||
1098 | testing::ValuesIn(get_indirect_gemm_variants()), // | ||
1099 | testing::ValuesIn(shapes), // | ||
1100 | testing::ValuesIn({ | ||
1101 | // clang-format off | ||
1102 | // (Start row , start col , height , width) | ||
1103 | MatrixPortion( 0 , 0 , 1 , 1) , // Full matrix. | ||
1104 | MatrixPortion( 0 , 0 , 1 , 0.5) , // Left half | ||
1105 | MatrixPortion( 0 , 0 , 0.5 , 1) , // Upper half | ||
1106 | MatrixPortion( 0 , 0.5 , 1 , 0.5) , // Right half | ||
1107 | MatrixPortion( 0.5 , 0 , 0.5 , 1) , // Bottom half | ||
1108 | MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3) , // Center ninth | ||
1109 | // clang-format on | ||
1110 | }), | ||
1111 | // k_chunk_len | ||
1112 | testing::ValuesIn(std::initializer_list<size_t>{1, 2, 3, 4, 8, 11, 32}), | ||
1113 | // Clamp range | ||
1114 | testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})), | ||
1115 | [](const auto& info) -> std::string { | ||
1116 | return test_description( | ||
1117 | std::get<IndirectMatMulVariant>(info.param), // | ||
1118 | std::get<MatMulShape>(info.param), // | ||
1119 | std::get<MatrixPortion>(info.param), // | ||
1120 | std::get<size_t>(info.param), // | ||
1121 | std::get<float>(info.param)); | ||
1122 | }); | ||
1123 | } // namespace kai::test | ||
1124 |