Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <cstdlib> | ||
13 | #include <limits> | ||
14 | #include <sstream> | ||
15 | #include <string> | ||
16 | #include <tuple> | ||
17 | |||
18 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" | ||
19 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" | ||
20 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" | ||
21 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" | ||
22 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" | ||
23 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" | ||
24 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" | ||
25 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" | ||
26 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h" | ||
27 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.h" | ||
28 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.h" | ||
29 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" | ||
30 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" | ||
31 | #include "test/common/buffer.hpp" | ||
32 | #include "test/common/cpu_info.hpp" | ||
33 | #include "test/common/float16.hpp" | ||
34 | #include "test/common/int4.hpp" | ||
35 | #include "test/common/matmul_test_common.hpp" | ||
36 | #include "test/common/matrix_portion.hpp" | ||
37 | #include "test/common/memory.hpp" | ||
38 | #include "test/common/round.hpp" | ||
39 | #include "test/common/test_suite.hpp" | ||
40 | #include "test/reference/cast.hpp" | ||
41 | #include "test/reference/fill.hpp" | ||
42 | #include "test/reference/matmul.hpp" | ||
43 | #include "test/reference/pack.hpp" | ||
44 | #include "test/reference/quantize.hpp" | ||
45 | |||
46 | namespace kai::test { | ||
47 | |||
48 | // Interface for the LHS and RHS packed size and packing micro-kernels | ||
49 | using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32); | ||
50 | using kai_get_rhs_packed_size_func_t = decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0); | ||
51 | using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32); | ||
52 | using kai_get_rhs_packed_offset_func_t = | ||
53 | decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0); | ||
54 | using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32); | ||
55 | using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0); | ||
56 | using kai_run_lhs_pack_func_t = decltype(&kai_run_lhs_quant_pack_qsi8d32p_f32); | ||
57 | using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0); | ||
58 | |||
59 | // Micro-kernel interface | ||
60 | struct kai_qsi8d32p_pack_functions { | ||
61 | kai_get_lhs_packed_size_func_t packed_size; | ||
62 | kai_get_lhs_packed_offset_func_t get_packed_offset; | ||
63 | kai_get_lhs_offset_func_t get_offset; | ||
64 | kai_run_lhs_pack_func_t run_pack; | ||
65 | }; | ||
66 | struct kai_qsi4c32p_pack_functions { | ||
67 | kai_get_rhs_packed_size_func_t packed_size; | ||
68 | kai_get_rhs_packed_offset_func_t get_packed_offset; | ||
69 | kai_get_rhs_offset_func_t get_offset; | ||
70 | kai_run_rhs_pack_func_t run_pack; | ||
71 | }; | ||
72 | |||
73 | static const std::array< | ||
74 | UkernelMatmulPackVariant< | ||
75 | kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_ukernel, kai_qsi8d32p_pack_functions, kai_qsi4c32p_pack_functions>, | ||
76 | 8> | ||
77 | variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p = { | ||
78 | {UKERNEL_MATMUL_PACK_VARIANT( | ||
79 | clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32, | ||
80 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), | ||
81 | UKERNEL_MATMUL_PACK_VARIANT( | ||
82 | clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32, | ||
83 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), | ||
84 | UKERNEL_MATMUL_PACK_VARIANT( | ||
85 | 4x8sb_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||
86 | cpu_has_i8mm, lhs_quant_pack_qsi8d32p4x8sb_f32_neon, rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), | ||
87 | UKERNEL_MATMUL_PACK_VARIANT( | ||
88 | clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32, | ||
89 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), | ||
90 | UKERNEL_MATMUL_PACK_VARIANT( | ||
91 | clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32, | ||
92 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), | ||
93 | UKERNEL_MATMUL_PACK_VARIANT( | ||
94 | clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32, | ||
95 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), | ||
96 | UKERNEL_MATMUL_PACK_VARIANT( | ||
97 | clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, lhs_quant_pack_qsi8d32p_f32_neon, | ||
98 | rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, false), | ||
99 | UKERNEL_MATMUL_PACK_VARIANT( | ||
100 | clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, cpu_has_sme2, lhs_quant_pack_qsi8d32p_f32_neon, | ||
101 | rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, false)}}; | ||
102 | |||
103 | class MatMulTest_f32_qsi8d32p_qsi4c32p : public ::testing::TestWithParam<MatMulTestPortionedParams> {}; | ||
104 | |||
105 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
578 | TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, Offset_RHS) { |
106 | 1152 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
107 | 384 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index); | |
108 | |||
109 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
192 | if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { |
110 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
111 | } | ||
112 | |||
113 | 192 | const size_t bl = 32; | |
114 | 384 | const size_t M = matmul_shape.m; | |
115 | 384 | const size_t N = matmul_shape.n; | |
116 | 384 | const size_t K = matmul_shape.k; | |
117 | |||
118 | 192 | const auto nr = ukernel_variant.ukernel.interface.get_nr(); | |
119 | 192 | const auto kr = ukernel_variant.ukernel.interface.get_kr(); | |
120 | |||
121 | 192 | auto n_step = ukernel_variant.ukernel.interface.get_n_step(); | |
122 | 192 | auto m_step = ukernel_variant.ukernel.interface.get_m_step(); | |
123 | |||
124 | 384 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
125 |
3/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 184 times.
|
192 | if (rect.height() == 0 || rect.width() == 0) { |
126 |
9/18✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
|
8 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
127 | } | ||
128 | |||
129 | 184 | const auto rhs_start_row = rect.start_col(); | |
130 | 184 | auto rhs_packed_offset = ukernel_variant.rhs_pack_interface.get_packed_offset(rhs_start_row, K, nr, kr, bl); | |
131 | 184 | auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl); | |
132 |
3/14✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 184 times.
|
184 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
133 | 192 | } | |
134 | |||
135 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
578 | TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, Offset_LHS) { |
136 | 1152 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
137 | 384 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index); | |
138 | |||
139 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
192 | if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { |
140 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
141 | } | ||
142 | |||
143 | 192 | const size_t bl = 32; | |
144 | 384 | const size_t M = matmul_shape.m; | |
145 | 384 | const size_t N = matmul_shape.n; | |
146 | 384 | const size_t K = matmul_shape.k; | |
147 | |||
148 | 192 | const auto mr = ukernel_variant.ukernel.interface.get_mr(); | |
149 | 192 | const auto kr = ukernel_variant.ukernel.interface.get_kr(); | |
150 | 192 | const auto sr = ukernel_variant.ukernel.interface.get_sr(); | |
151 | |||
152 | 192 | auto m_step = ukernel_variant.ukernel.interface.get_m_step(); | |
153 | 192 | auto n_step = ukernel_variant.ukernel.interface.get_n_step(); | |
154 | |||
155 | 384 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
156 |
3/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 184 times.
|
192 | if (rect.height() == 0 || rect.width() == 0) { |
157 |
9/18✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
|
8 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
158 | } | ||
159 | |||
160 | 184 | const auto lhs_start_row = rect.start_row(); | |
161 | 184 | auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr); | |
162 | 184 | auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl); | |
163 | |||
164 |
3/14✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 184 times.
|
184 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
165 | 192 | } | |
166 | |||
167 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
578 | TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, EndToEnd) { |
168 | 1092 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
169 | 384 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index); | |
170 | |||
171 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
192 | if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { |
172 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
173 | } | ||
174 | |||
175 | 192 | const std::uint32_t seed = 0; | |
176 | |||
177 | 384 | const size_t M = matmul_shape.m; | |
178 | 384 | const size_t N = matmul_shape.n; | |
179 | 384 | const size_t K = matmul_shape.k; | |
180 | 192 | const size_t bl = 32; | |
181 | |||
182 | 192 | const auto mr = ukernel_variant.ukernel.interface.get_mr(); | |
183 | 192 | const auto nr = ukernel_variant.ukernel.interface.get_nr(); | |
184 | 192 | const auto kr = ukernel_variant.ukernel.interface.get_kr(); | |
185 | 192 | const auto sr = ukernel_variant.ukernel.interface.get_sr(); | |
186 | |||
187 |
4/4✓ Branch 0 taken 72 times.
✓ Branch 1 taken 120 times.
✓ Branch 2 taken 12 times.
✓ Branch 3 taken 60 times.
|
192 | if (mr == 1 && M > 1) { |
188 |
3/6✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
|
60 | GTEST_SKIP() << "Kernel does not support M != 1"; |
189 | } | ||
190 | |||
191 | 132 | auto m_step = ukernel_variant.ukernel.interface.get_m_step(); | |
192 |
3/14✓ Branch 0 taken 132 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 132 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 132 times.
|
132 | ASSERT_TRUE(m_step % mr == 0); |
193 | |||
194 | 132 | auto n_step = ukernel_variant.ukernel.interface.get_n_step(); | |
195 |
3/14✓ Branch 0 taken 132 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 132 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 132 times.
|
132 | ASSERT_TRUE(n_step % nr == 0); |
196 | |||
197 | 264 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
198 |
3/4✓ Branch 0 taken 132 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 124 times.
|
132 | if (rect.height() == 0 || rect.width() == 0) { |
199 |
9/18✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
|
8 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
200 | } | ||
201 | // Generates input data. | ||
202 | 124 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
203 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | const auto ref_rhs = fill_random<float>(N * K, seed + 1); |
204 | |||
205 | // Runs the reference implementation. | ||
206 | 372 | const auto [ref_lhs_qvalues, ref_lhs_scales] = | |
207 |
2/4✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
|
124 | quantize_symmetric_per_block_dynamic<float, int8_t, Float16>(ref_lhs.data(), M, K, bl); |
208 | 620 | const auto [ref_rhs_qsi4, ref_rhs_scales] = | |
209 |
2/4✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
|
124 | quantize_symmetric_per_block_dynamic<float, Int4, Float16>(ref_rhs.data(), N, K, bl); |
210 | |||
211 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
248 | const auto ref_dst = matmul_clamp_nt_t<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>( |
212 |
5/10✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 124 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 124 times.
✗ Branch 9 not taken.
|
248 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, bl, ref_rhs_qsi4.data(), ref_rhs_scales.data(), |
213 | 124 | nullptr, bl, nullptr, std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | |
214 | |||
215 | // Runs the LHS packing micro-kernel. | ||
216 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | const auto lhs_start_row = rect.start_row(); |
217 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | const auto imp_packed_lhs_size = ukernel_variant.lhs_pack_interface.packed_size(M, K, bl, mr, kr, sr); |
218 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | Buffer imp_packed_lhs(imp_packed_lhs_size); |
219 | |||
220 | 124 | auto lhs_stride = K * sizeof(float); | |
221 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | auto lhs_offset = ukernel_variant.lhs_pack_interface.get_offset(lhs_start_row, lhs_stride); |
222 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr); |
223 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl); |
224 | |||
225 |
4/16✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 124 times.
✗ Branch 15 not taken.
|
124 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
226 | |||
227 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
248 | ukernel_variant.lhs_pack_interface.run_pack( |
228 |
2/4✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
|
124 | rect.height() /* m */, K, bl, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), |
229 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | lhs_stride, imp_packed_lhs.data() + lhs_packed_offset); |
230 | |||
231 | // Runs the RHS packing micro-kernel. | ||
232 |
3/6✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
|
248 | const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); |
233 | 124 | const auto ref_rhs_qsu4_scale_f16 = | |
234 |
3/6✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
|
124 | pack_data_scales_interleave_block<UInt4, Float16>(ref_rhs_qsu4.data(), ref_rhs_scales.data(), N, K, bl); |
235 | |||
236 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | const auto imp_packed_rhs_size = ukernel_variant.rhs_pack_interface.packed_size(N, K, nr, kr, bl); |
237 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
238 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | const auto rhs_start_row = rect.start_col(); |
239 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | auto rhs_packed_offset = ukernel_variant.rhs_pack_interface.get_packed_offset(rhs_start_row, K, nr, kr, bl); |
240 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl); |
241 |
4/16✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 124 times.
|
124 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
242 | |||
243 | 124 | const kai_rhs_pack_qs4cxs1s0_param params{.lhs_zero_point = 1, .rhs_zero_point = 8}; | |
244 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
248 | ukernel_variant.rhs_pack_interface.run_pack( |
245 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | 1, N, K, nr, kr, sr, bl, reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_scale_f16.data()), nullptr, |
246 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | imp_packed_rhs.data(), 0, &params); |
247 | |||
248 | 124 | const auto dst_stride_row = N * sizeof(float); | |
249 | 124 | const auto dst_stride_col = sizeof(float); | |
250 | 248 | const auto dst_offset = | |
251 |
3/6✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
|
124 | ukernel_variant.ukernel.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); |
252 |
2/4✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
|
124 | const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; |
253 |
4/16✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 124 times.
|
124 | ASSERT_EQ(dst_offset, ref_dst_offset); |
254 | |||
255 | // Runs the GEMM micro-kernel. | ||
256 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | const auto imp_dst_size = ukernel_variant.ukernel.interface.get_dst_size(M, N); |
257 |
5/18✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 124 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 124 times.
|
124 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
258 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
124 | Buffer imp_dst(imp_dst_size); |
259 |
1/2✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
|
248 | ukernel_variant.ukernel.interface.run_matmul( |
260 |
3/6✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 124 times.
✗ Branch 5 not taken.
|
124 | rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset, |
261 |
2/4✓ Branch 0 taken 124 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
|
124 | imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), |
262 | 124 | dst_stride_row, dst_stride_col, std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | |
263 | |||
264 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
265 | // tested. | ||
266 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 3108 times.
✓ Branch 2 taken 2984 times.
✓ Branch 3 taken 124 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 124 times.
|
3108 | for (size_t y = 0; y < rect.height(); ++y) { |
267 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 128593 times.
✓ Branch 2 taken 125609 times.
✓ Branch 3 taken 2984 times.
✓ Branch 4 taken 2984 times.
✗ Branch 5 not taken.
|
128593 | for (size_t x = 0; x < rect.width(); ++x) { |
268 | 251218 | const auto imp_value = | |
269 |
4/8✓ Branch 0 taken 125609 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125609 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125609 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125609 times.
✗ Branch 7 not taken.
|
125609 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
270 | 251218 | const auto ref_value = | |
271 |
4/8✓ Branch 0 taken 125609 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 125609 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 125609 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125609 times.
✗ Branch 7 not taken.
|
125609 | read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
272 |
1/2✓ Branch 0 taken 125609 times.
✗ Branch 1 not taken.
|
125609 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value; |
273 | |||
274 |
1/2✓ Branch 0 taken 125609 times.
✗ Branch 1 not taken.
|
125609 | if (rel_error > 0.0001F) { |
275 | ✗ | ASSERT_EQ(imp_value, ref_value); | |
276 | ✗ | } | |
277 | 125609 | } | |
278 | 2984 | } | |
279 | 192 | } | |
280 | |||
281 |
15/46✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 3 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 576 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✓ Branch 44 taken 576 times.
✗ Branch 45 not taken.
|
1156 | INSTANTIATE_TEST_SUITE_P( |
282 | MatMul, MatMulTest_f32_qsi8d32p_qsi4c32p, | ||
283 | testing::Combine( | ||
284 | testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.size()), | ||
285 | testing::Values( | ||
286 | MatMulShape{1, 2, 32}, // | ||
287 | MatMulShape{32, 64, 64}, // | ||
288 | MatMulShape{16, 32, 64}, // | ||
289 | MatMulShape{8, 32, 64}, // | ||
290 | MatMulShape{15, 32, 32}, // | ||
291 | MatMulShape{77, 99, 64}), | ||
292 | testing::Values( | ||
293 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
294 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
295 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
296 | MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle | ||
297 | )), | ||
298 | [](const auto& info) { | ||
299 | const auto variant_idx = std::get<0>(info.param); | ||
300 | const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_idx).ukernel.name}; | ||
301 | const auto shape = std::get<MatMulShape>(info.param); | ||
302 | const auto portion = std::get<2>(info.param); | ||
303 | |||
304 | return test_description(name, shape, portion, true); | ||
305 | }); | ||
306 | |||
307 | } // namespace kai::test | ||
308 |