Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <cstdlib> | ||
13 | #include <sstream> | ||
14 | #include <string> | ||
15 | #include <tuple> | ||
16 | |||
17 | #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h" | ||
18 | #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h" | ||
19 | #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h" | ||
20 | #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.h" | ||
21 | #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h" | ||
22 | #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h" | ||
23 | #include "kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p_qai4c32p_interface.h" | ||
24 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f16_neon.h" | ||
25 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.h" | ||
26 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.h" | ||
27 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon.h" | ||
28 | #include "test/common/buffer.hpp" | ||
29 | #include "test/common/compare.hpp" | ||
30 | #include "test/common/cpu_info.hpp" | ||
31 | #include "test/common/data_format.hpp" | ||
32 | #include "test/common/float16.hpp" | ||
33 | #include "test/common/int4.hpp" | ||
34 | #include "test/common/matmul_test_common.hpp" | ||
35 | #include "test/common/matrix_portion.hpp" | ||
36 | #include "test/common/memory.hpp" | ||
37 | #include "test/common/round.hpp" | ||
38 | #include "test/common/test_suite.hpp" | ||
39 | #include "test/reference/cast.hpp" | ||
40 | #include "test/reference/clamp.hpp" | ||
41 | #include "test/reference/fill.hpp" | ||
42 | #include "test/reference/matmul.hpp" | ||
43 | #include "test/reference/pack.hpp" | ||
44 | #include "test/reference/quantize.hpp" | ||
45 | |||
46 | namespace kai::test { | ||
47 | |||
48 | // Interface for the LHS and RHS packed size and packing micro-kernels | ||
49 | using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f16_neon); | ||
50 | using kai_get_rhs_packed_size_func_t = | ||
51 | decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); | ||
52 | using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f16_neon); | ||
53 | using kai_get_rhs_packed_offset_func_t = | ||
54 | decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); | ||
55 | using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f16_neon); | ||
56 | using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); | ||
57 | using kai_run_lhs_pack_func_f16_t = decltype(&kai_run_lhs_quant_pack_qsi8d32pscalef32_f16_neon); | ||
58 | using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); | ||
59 | |||
60 | // Micro-kernel interface | ||
61 | struct kai_qai4c32p_pack_functions { | ||
62 | kai_get_rhs_packed_size_func_t packed_size; | ||
63 | kai_get_rhs_packed_offset_func_t get_packed_offset; | ||
64 | kai_get_rhs_offset_func_t get_offset; | ||
65 | kai_run_rhs_pack_func_t run_pack; | ||
66 | }; | ||
67 | |||
68 | struct kai_qsi8d32p_f16_pack_functions { | ||
69 | kai_get_lhs_packed_size_func_t packed_size; | ||
70 | kai_get_lhs_packed_offset_func_t get_packed_offset; | ||
71 | kai_get_lhs_offset_func_t get_offset; | ||
72 | kai_run_lhs_pack_func_f16_t run_pack; | ||
73 | }; | ||
74 | |||
75 | static const std::array< | ||
76 | UkernelMatmulPackVariant< | ||
77 | kai_matmul_clamp_f16_qsi8d32p_qai4c32p_ukernel, kai_qsi8d32p_f16_pack_functions, kai_qai4c32p_pack_functions>, | ||
78 | 8> | ||
79 | variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p = { | ||
80 | {UKERNEL_MATMUL_PACK_VARIANT( | ||
81 | clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod, cpu_has_dotprod, | ||
82 | lhs_quant_pack_qsi8d32pscalef32_f16_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), | ||
83 | UKERNEL_MATMUL_PACK_VARIANT( | ||
84 | clamp_f16_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32pscalef32_f16_neon, | ||
85 | rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), | ||
86 | UKERNEL_MATMUL_PACK_VARIANT( | ||
87 | clamp_f16_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod, cpu_has_dotprod, | ||
88 | lhs_quant_pack_qsi8d32pscalef32_f16_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), | ||
89 | UKERNEL_MATMUL_PACK_VARIANT( | ||
90 | clamp_f16_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod, | ||
91 | lhs_quant_pack_qsi8d32pscalef32_f16_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), | ||
92 | UKERNEL_MATMUL_PACK_VARIANT( | ||
93 | clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qsi8d32pscalef32_f16_neon, | ||
94 | rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon, false), | ||
95 | UKERNEL_MATMUL_PACK_VARIANT( | ||
96 | clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, | ||
97 | lhs_quant_pack_qsi8d32pscalef32_f16_neon, rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon, | ||
98 | false), | ||
99 | UKERNEL_MATMUL_PACK_VARIANT( | ||
100 | clamp_f16_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qsi8d32pscalef32_f16_neon, | ||
101 | rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon, true), | ||
102 | UKERNEL_MATMUL_PACK_VARIANT( | ||
103 | clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, | ||
104 | lhs_quant_pack_qsi8d32pscalef32_f16_neon, rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon, | ||
105 | true)}}; | ||
106 | |||
107 | static const auto test_matmul_shapes = testing::Values( | ||
108 | MatMulShape{1, 64, 32}, // | ||
109 | MatMulShape{1, 63, 32}, // | ||
110 | MatMulShape{1, 65, 32}, // | ||
111 | MatMulShape{1, 64, 64}, // | ||
112 | MatMulShape{1, 64, 128}, // | ||
113 | MatMulShape{1, 128, 32}, // | ||
114 | MatMulShape{1, 128, 128}, // | ||
115 | MatMulShape{1, 2, 32}, // | ||
116 | MatMulShape{1, 3, 32}, // | ||
117 | MatMulShape{1, 4, 32}, // | ||
118 | MatMulShape{1, 5, 32}, // | ||
119 | MatMulShape{3, 3, 32}, // | ||
120 | MatMulShape{4, 4, 32}, // | ||
121 | MatMulShape{5, 5, 32}, // | ||
122 | MatMulShape{32, 128, 32}, // | ||
123 | MatMulShape{15, 64, 64}, // | ||
124 | MatMulShape{17, 64, 64}, // | ||
125 | MatMulShape{16, 63, 64}, // | ||
126 | MatMulShape{16, 64, 64}, // | ||
127 | MatMulShape{16, 65, 64}, // | ||
128 | MatMulShape{32, 64, 64}, // | ||
129 | MatMulShape{16, 32, 64}, // | ||
130 | MatMulShape{8, 32, 64}, // | ||
131 | MatMulShape{15, 32, 32}, // | ||
132 | MatMulShape{77, 99, 64} // | ||
133 | ); | ||
134 | |||
135 | static const auto test_portions = testing::Values( | ||
136 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
137 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
138 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
139 | MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle | ||
140 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
141 | MatrixPortion(0.75, 0, 1, 1), // Partial rows | ||
142 | MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle | ||
143 | ); | ||
144 | |||
145 | static const auto test_block_lengths = testing::Values(32, 64); | ||
146 | |||
147 | // Executes the LHS packing micro-kernel. | ||
148 | 10904 | static inline Buffer pack_lhs_qsi8d32p_f16( | |
149 | const kai_qsi8d32p_f16_pack_functions& pack_interface, size_t M, size_t K, size_t bl, size_t mr, size_t kr, | ||
150 | size_t sr, const Buffer& lhs_f16, size_t stride, size_t rect_start_row, size_t rect_height) { | ||
151 | 10904 | const auto imp_packed_lhs_size = pack_interface.packed_size(M, K, bl, mr, kr, sr); | |
152 | 10904 | Buffer imp_packed_lhs(imp_packed_lhs_size, 0); | |
153 | |||
154 |
1/2✓ Branch 0 taken 10904 times.
✗ Branch 1 not taken.
|
10904 | auto lhs_offset = pack_interface.get_offset(rect_start_row, stride); |
155 |
1/2✓ Branch 0 taken 10904 times.
✗ Branch 1 not taken.
|
10904 | auto lhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, bl, mr, kr, sr); |
156 | |||
157 |
2/4✓ Branch 0 taken 10904 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 10904 times.
✗ Branch 3 not taken.
|
21808 | pack_interface.run_pack( |
158 | 10904 | rect_height, K, bl, mr, kr, sr, 0, reinterpret_cast<const uint8_t*>(lhs_f16.data() + lhs_offset), stride, | |
159 | 10904 | imp_packed_lhs.data() + lhs_packed_offset); | |
160 | |||
161 | 10904 | return (imp_packed_lhs); | |
162 | 10904 | } | |
163 | |||
164 | // Executes the RHS packing micro-kernel. | ||
165 | 2616 | static inline Buffer pack_rhs_qai4c32p( | |
166 | const kai_qai4c32p_pack_functions& pack_interface, size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr, | ||
167 | const Buffer& rhs_values_qai4, const bool has_bias, const Buffer& biases, const Buffer& rhs_scales, | ||
168 | const Buffer& rhs_zp, bool s0s1_input) { | ||
169 | // Cast to unsigned int | ||
170 | 2616 | auto rhs_qau4s1s0 = cast_qsu4_qsi4(rhs_values_qai4.data(), N * K); | |
171 | |||
172 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | const auto imp_packed_rhs_size = pack_interface.packed_size(N, K, nr, kr, bl); |
173 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
174 | |||
175 | // Runs the RHS packing micro-kernel. | ||
176 | 2616 | kai_rhs_pack_nxk_qai4c32p_params params{}; | |
177 | 2616 | params.lhs_zero_point = 1; | |
178 | 2616 | params.rhs_zero_point = 8; | |
179 | |||
180 |
5/10✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 654 times.
✓ Branch 3 taken 1962 times.
✓ Branch 4 taken 654 times.
✓ Branch 5 taken 1962 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
5232 | pack_interface.run_pack( |
181 | 2616 | 1, N, K, nr, kr, sr, bl, | |
182 |
3/4✓ Branch 0 taken 1962 times.
✓ Branch 1 taken 654 times.
✓ Branch 2 taken 1962 times.
✗ Branch 3 not taken.
|
2616 | reinterpret_cast<const uint8_t*>(s0s1_input ? convert_s0s1_s1s0(rhs_qau4s1s0).data() : rhs_qau4s1s0.data()), |
183 |
2/2✓ Branch 0 taken 1308 times.
✓ Branch 1 taken 1308 times.
|
2616 | rhs_zp.data(), has_bias ? biases.data() : nullptr, rhs_scales.data(), imp_packed_rhs.data(), 0, ¶ms); |
184 | |||
185 | 2616 | return (imp_packed_rhs); | |
186 | 2616 | } | |
187 | |||
188 | class MatMulTest_f16_qsi8d32p_qai4c32p : public ::testing::TestWithParam<MatMulTestPortionedParamsWithBias_WithBL> {}; | ||
189 | |||
190 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
16802 | TEST_P(MatMulTest_f16_qsi8d32p_qai4c32p, LhsPackedWithSameBlockdepth) { |
191 | // Verify LHS quant and pack int8 kernel behaves same for int4 and int8 matmul kernels, | ||
192 | // when the block-depth is same for different values of kr, sr. | ||
193 | |||
194 | 4801776 | const auto& [variant_index, matmul_shape, bl, portion, has_bias] = GetParam(); | |
195 | 11200 | const auto& ukernel_variant = variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p.at(variant_index); | |
196 | |||
197 |
2/4✓ Branch 0 taken 5600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5600 times.
✗ Branch 3 not taken.
|
5600 | if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { |
198 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
199 | } | ||
200 | |||
201 | 5600 | const std::uint32_t seed = 0; | |
202 | |||
203 | 11200 | const size_t M = matmul_shape.m; | |
204 | 11200 | const size_t N = matmul_shape.n; | |
205 | 11200 | const size_t K = matmul_shape.k; | |
206 | |||
207 |
4/4✓ Branch 0 taken 1456 times.
✓ Branch 1 taken 4144 times.
✓ Branch 2 taken 1456 times.
✓ Branch 3 taken 4144 times.
|
11200 | if (K % bl != 0) { |
208 |
3/6✓ Branch 0 taken 1456 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1456 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1456 times.
✗ Branch 5 not taken.
|
1456 | GTEST_SKIP() << "K must be a multiple of bl"; |
209 | } | ||
210 | |||
211 | 4144 | const auto mr = ukernel_variant.ukernel.interface.get_mr(); | |
212 | 4144 | const auto nr = ukernel_variant.ukernel.interface.get_nr(); | |
213 | 4144 | const auto kr = ukernel_variant.ukernel.interface.get_kr(); | |
214 | 4144 | const auto sr = ukernel_variant.ukernel.interface.get_sr(); | |
215 | |||
216 | 4144 | auto m_step = ukernel_variant.ukernel.interface.get_m_step(); | |
217 |
3/14✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 4144 times.
|
4144 | ASSERT_TRUE(m_step % mr == 0); |
218 | |||
219 | 4144 | auto n_step = ukernel_variant.ukernel.interface.get_n_step(); | |
220 |
3/14✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 4144 times.
|
4144 | ASSERT_TRUE(n_step % nr == 0); |
221 | |||
222 | 8288 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
223 | |||
224 | // Generates input data. | ||
225 | 4144 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
226 | |||
227 | // Runs the reference implementation. | ||
228 | // * Quantizes the LHS matrix using 8-bit symmetric quantization. | ||
229 | 8288 | const auto [ref_lhs_qvalues, ref_lhs_scales] = | |
230 |
3/6✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4144 times.
✗ Branch 5 not taken.
|
4144 | quantize_symmetric_per_block_dynamic<float, int8_t, float>(ref_lhs.data(), M, K, bl); |
231 | |||
232 | // Runs the LHS packing micro-kernel. | ||
233 |
1/2✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
|
4144 | const auto lhs_start_row = rect.start_row(); |
234 | 4144 | auto lhs_stride = K * sizeof(uint16_t); | |
235 | |||
236 |
1/2✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
|
4144 | auto imp_packed_lhs = pack_lhs_qsi8d32p_f16( |
237 |
2/4✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
|
8288 | ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr, sr, ref_lhs, lhs_stride, lhs_start_row, rect.height()); |
238 |
2/4✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
|
8288 | auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr); |
239 | |||
240 | 4144 | const size_t kr_qsi8 = kr / sr; | |
241 | 4144 | const size_t sr_qsi8 = 1; | |
242 | |||
243 |
1/2✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
|
4144 | auto imp_packed_lhs_qsi8 = pack_lhs_qsi8d32p_f16( |
244 | 8288 | ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr_qsi8, sr_qsi8, ref_lhs, lhs_stride, lhs_start_row, | |
245 |
1/2✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
|
4144 | rect.height()); |
246 | 4144 | auto lhs_qsi8_packed_offset = | |
247 |
2/4✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
|
8288 | ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr_qsi8, sr_qsi8); |
248 | |||
249 |
4/16✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4144 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4144 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 4144 times.
|
4144 | ASSERT_EQ(lhs_qsi8_packed_offset, lhs_packed_offset); |
250 | |||
251 |
1/2✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
|
4144 | auto* imp_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs.data()); |
252 |
1/2✓ Branch 0 taken 4144 times.
✗ Branch 1 not taken.
|
4144 | auto* imp_packed_lhs_qsi8_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs_qsi8.data()); |
253 |
5/8✗ Branch 0 not taken.
✓ Branch 1 taken 4751600 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 4751600 times.
✓ Branch 4 taken 4144 times.
✓ Branch 5 taken 4747456 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4144 times.
|
4751600 | for (size_t i = 0; i < ukernel_variant.lhs_pack_interface.packed_size(M, K, bl, mr, kr, sr); i++) { |
254 |
4/16✓ Branch 0 taken 4747456 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4747456 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4747456 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 4747456 times.
|
4747456 | ASSERT_EQ(imp_packed_lhs_ptr[i], imp_packed_lhs_qsi8_ptr[i]); |
255 | 4747456 | } | |
256 | 5600 | } | |
257 | |||
258 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
16802 | TEST_P(MatMulTest_f16_qsi8d32p_qai4c32p, EndToEnd) { |
259 | 67848 | const auto& [variant_index, matmul_shape, bl, portion, has_bias] = GetParam(); | |
260 | 11200 | const auto& ukernel_variant = variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p.at(variant_index); | |
261 | |||
262 |
2/4✓ Branch 0 taken 5600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5600 times.
✗ Branch 3 not taken.
|
5600 | if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { |
263 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
264 | } | ||
265 | |||
266 | 5600 | const std::uint32_t seed = 0; | |
267 | |||
268 | 11200 | const size_t M = matmul_shape.m; | |
269 | 11200 | const size_t N = matmul_shape.n; | |
270 | 11200 | const size_t K = matmul_shape.k; | |
271 | |||
272 |
4/4✓ Branch 0 taken 1456 times.
✓ Branch 1 taken 4144 times.
✓ Branch 2 taken 1456 times.
✓ Branch 3 taken 4144 times.
|
11200 | if (K % bl != 0) { |
273 |
3/6✓ Branch 0 taken 1456 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1456 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1456 times.
✗ Branch 5 not taken.
|
1456 | GTEST_SKIP() << "K must be a multiple of bl"; |
274 | } | ||
275 | |||
276 | 4144 | const auto mr = ukernel_variant.ukernel.interface.get_mr(); | |
277 | 4144 | const auto nr = ukernel_variant.ukernel.interface.get_nr(); | |
278 | 4144 | const auto kr = ukernel_variant.ukernel.interface.get_kr(); | |
279 | 4144 | const auto sr = ukernel_variant.ukernel.interface.get_sr(); | |
280 | |||
281 |
4/4✓ Branch 0 taken 2072 times.
✓ Branch 1 taken 2072 times.
✓ Branch 2 taken 784 times.
✓ Branch 3 taken 1288 times.
|
4144 | if (mr == 1 && M > 1) { |
282 |
3/6✓ Branch 0 taken 1288 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1288 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1288 times.
✗ Branch 5 not taken.
|
1288 | GTEST_SKIP() << "Kernel does not support M != 1"; |
283 | } | ||
284 | |||
285 | 2856 | auto m_step = ukernel_variant.ukernel.interface.get_m_step(); | |
286 |
3/14✓ Branch 0 taken 2856 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2856 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2856 times.
|
2856 | ASSERT_TRUE(m_step % mr == 0); |
287 | |||
288 | 2856 | auto n_step = ukernel_variant.ukernel.interface.get_n_step(); | |
289 |
3/14✓ Branch 0 taken 2856 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2856 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2856 times.
|
2856 | ASSERT_TRUE(n_step % nr == 0); |
290 | |||
291 | 5712 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
292 |
4/4✓ Branch 0 taken 2632 times.
✓ Branch 1 taken 224 times.
✓ Branch 2 taken 16 times.
✓ Branch 3 taken 2616 times.
|
2856 | if (rect.height() == 0 || rect.width() == 0) { |
293 |
9/18✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 240 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 240 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 240 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 240 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 240 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 240 times.
✗ Branch 17 not taken.
|
240 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
294 | } | ||
295 | |||
296 | // Generates input data. | ||
297 | 2616 | const auto ref_lhs_f16 = fill_random<Float16>(M * K, seed + 0); | |
298 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | const auto ref_rhs = fill_random<float>(N * K, seed + 1); |
299 | 2616 | Buffer ref_biases; | |
300 | |||
301 |
2/2✓ Branch 0 taken 1308 times.
✓ Branch 1 taken 1308 times.
|
2616 | if (has_bias) { |
302 |
1/2✓ Branch 0 taken 1308 times.
✗ Branch 1 not taken.
|
1308 | ref_biases = fill_random<float>(N, seed + 2); |
303 | 1308 | } | |
304 | // For reference implementation, Casting FP16 input to FP32 type and FP32 output back to FP16 because the matmul | ||
305 | // implementation works with FP32 accumulation and casts the result to FP16 | ||
306 |
3/6✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
|
2616 | const auto ref_lhs = cast<float, Float16>(ref_lhs_f16.data(), ref_lhs_f16.size() * 8 / size_in_bits<Float16>); |
307 | |||
308 | // Runs the reference implementation. | ||
309 | // * Quantizes the LHS matrix using 8-bit symmetric quantization. | ||
310 | // * Quantizes the RHS matrix using 4-bit asymmetric quantization. | ||
311 | // * Performs GEMM. | ||
312 | 7848 | const auto [ref_lhs_qvalues, ref_lhs_scales] = | |
313 |
3/6✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
|
2616 | quantize_symmetric_per_block_dynamic<float, int8_t, float>(ref_lhs.data(), M, K, bl); |
314 | 275664 | const auto [ref_rhs_qai4, ref_rhs_scales, ref_rhs_zero_points] = | |
315 |
3/6✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
|
2616 | quantize_asymmetric_per_block_dynamic<float, Int4, float, int32_t>(ref_rhs.data(), N, K, bl); |
316 | |||
317 | 2616 | const auto ref_dst_no_clamp = | |
318 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>( |
319 |
5/10✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2616 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2616 times.
✗ Branch 9 not taken.
|
5232 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, 1, bl, ref_rhs_qai4.data(), |
320 |
7/10✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1308 times.
✓ Branch 5 taken 1308 times.
✓ Branch 6 taken 1308 times.
✓ Branch 7 taken 1308 times.
✓ Branch 8 taken 1308 times.
✗ Branch 9 not taken.
|
2616 | ref_rhs_scales.data(), ref_rhs_zero_points.data(), 1, bl, has_bias ? ref_biases.data() : nullptr, nullptr, |
321 | nullptr, 1); | ||
322 | |||
323 | // Clamps the reference output. | ||
324 | 2616 | const auto clamp_ratio = 0.8F; | |
325 |
2/4✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
|
7848 | const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), M * N, clamp_ratio); |
326 |
4/8✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2616 times.
✗ Branch 7 not taken.
|
2616 | const auto ref_dst_float = clamp<float>(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max); |
327 | |||
328 | // Cast the reference output to F16 | ||
329 |
3/6✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
|
2616 | auto ref_dst = cast<Float16, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>); |
330 | |||
331 | // Runs the LHS packing micro-kernel. | ||
332 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | const auto lhs_start_row = rect.start_row(); |
333 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | auto imp_packed_lhs = pack_lhs_qsi8d32p_f16( |
334 | 5232 | ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr, sr, ref_lhs_f16, K * sizeof(uint16_t), lhs_start_row, | |
335 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | rect.height()); |
336 |
2/4✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
|
5232 | auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr); |
337 |
2/4✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
|
5232 | auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl); |
338 | |||
339 |
4/16✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2616 times.
|
2616 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
340 | |||
341 | // Prepare the offsets as the RHS packing micro-kernel expects the scaled zero-points in float. | ||
342 |
2/4✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
|
5232 | const size_t num_blocks_per_row = round_up_division(K, bl); |
343 | 2616 | const size_t ref_zp_size = N * num_blocks_per_row; | |
344 | 2616 | const size_t ref_zp_size_in_bytes = ref_zp_size * sizeof(float); | |
345 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | Buffer ref_rhs_zp_f32(ref_zp_size_in_bytes); |
346 |
2/2✓ Branch 0 taken 262584 times.
✓ Branch 1 taken 2616 times.
|
265200 | for (size_t i = 0; i < ref_zp_size; ++i) { |
347 |
1/2✓ Branch 0 taken 262584 times.
✗ Branch 1 not taken.
|
262584 | reinterpret_cast<float*>(ref_rhs_zp_f32.data())[i] = |
348 |
1/2✓ Branch 0 taken 262584 times.
✗ Branch 1 not taken.
|
262584 | -reinterpret_cast<const int32_t*>(ref_rhs_zero_points.data())[i] * |
349 |
1/2✓ Branch 0 taken 262584 times.
✗ Branch 1 not taken.
|
262584 | reinterpret_cast<const float*>(ref_rhs_scales.data())[i]; |
350 | 262584 | } | |
351 | |||
352 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | const auto rhs_start_row = rect.start_col(); |
353 |
2/4✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
|
5232 | auto imp_packed_rhs = pack_rhs_qai4c32p( |
354 | 7848 | ukernel_variant.rhs_pack_interface, N, K, bl, nr, kr, sr, ref_rhs_qai4, has_bias, ref_biases, ref_rhs_scales, | |
355 | 2616 | ref_rhs_zp_f32, ukernel_variant.rhs_s0s1_input); | |
356 |
2/4✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
|
5232 | auto rhs_packed_offset = ukernel_variant.rhs_pack_interface.get_packed_offset(rhs_start_row, K, nr, kr, bl); |
357 |
2/4✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
|
5232 | auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl); |
358 |
4/16✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2616 times.
|
2616 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
359 | |||
360 | 2616 | const auto dst_stride_row = N * sizeof(uint16_t); | |
361 | 2616 | const auto dst_stride_col = sizeof(uint16_t); | |
362 | 5232 | const auto dst_offset = | |
363 |
3/6✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
|
2616 | ukernel_variant.ukernel.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); |
364 |
2/4✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
|
2616 | const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; |
365 |
4/16✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2616 times.
|
2616 | ASSERT_EQ(dst_offset, ref_dst_offset); |
366 | |||
367 | // Runs the GEMM micro-kernel. | ||
368 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | const auto imp_dst_size = ukernel_variant.ukernel.interface.get_dst_size(M, N); |
369 |
5/18✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2616 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2616 times.
|
2616 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
370 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | Buffer imp_dst(imp_dst_size); |
371 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
5232 | ukernel_variant.ukernel.interface.run_matmul( |
372 |
4/8✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2616 times.
✗ Branch 7 not taken.
|
2616 | rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset, |
373 |
2/4✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
|
2616 | imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), |
374 | 7848 | dst_stride_row, dst_stride_col, clamp_min, clamp_max); | |
375 | |||
376 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
377 | // tested. | ||
378 |
1/2✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
|
2616 | DefaultMismatchHandler handler(0, 0.1, 0, 0.05); |
379 | 2616 | DataFormat dst_format = DataFormat(DataType::FP16); | |
380 |
3/6✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
|
2616 | const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); |
381 |
4/16✓ Branch 0 taken 2616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2616 times.
|
2616 | ASSERT_TRUE(success); |
382 | 5600 | } | |
383 | |||
384 |
27/56✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11200 times.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 11200 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 11200 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 11200 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 11200 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 11200 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 11200 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 11200 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 5600 times.
✓ Branch 39 taken 5600 times.
✓ Branch 40 taken 5600 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 5600 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 11200 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 8400 times.
✓ Branch 47 taken 2800 times.
✓ Branch 48 taken 8400 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 2800 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 11200 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 11200 times.
✗ Branch 55 not taken.
|
30803 | INSTANTIATE_TEST_SUITE_P( |
385 | MatMul, MatMulTest_f16_qsi8d32p_qai4c32p, | ||
386 | testing::Combine( | ||
387 | testing::Range<size_t>(0, variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p.size()), test_matmul_shapes, | ||
388 | test_block_lengths, test_portions, testing::Bool()), | ||
389 | [](const auto& info) { | ||
390 | const auto variant_idx = std::get<0>(info.param); | ||
391 | const std::string name{variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p.at(variant_idx).ukernel.name}; | ||
392 | const auto shape = std::get<MatMulShape>(info.param); | ||
393 | const auto bl = std::get<2>(info.param); | ||
394 | const auto portion = std::get<3>(info.param); | ||
395 | const auto has_bias = std::get<4>(info.param); | ||
396 | |||
397 | std::ostringstream sstream; | ||
398 | sstream << name << "__"; | ||
399 | PrintTo(shape, &sstream); | ||
400 | sstream << "__BL_" << bl << "_"; | ||
401 | if (has_bias) { | ||
402 | sstream << "_withBias_"; | ||
403 | } else { | ||
404 | sstream << "_noBias_"; | ||
405 | } | ||
406 | if (variants_kai_matmul_clamp_f16_qsi8d32p_qai4c32p.at(variant_idx).rhs_s0s1_input) { | ||
407 | sstream << "_RHS_s0s1__"; | ||
408 | } else { | ||
409 | sstream << "_RHS_s1s0__"; | ||
410 | } | ||
411 | PrintTo(portion, &sstream); | ||
412 | |||
413 | return sstream.str(); | ||
414 | }); | ||
415 | |||
416 | } // namespace kai::test | ||
417 |