Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <cstdlib> | ||
13 | #include <functional> | ||
14 | #include <sstream> | ||
15 | #include <string> | ||
16 | #include <string_view> | ||
17 | #include <tuple> | ||
18 | #include <utility> | ||
19 | |||
20 | #include "kai/kai_common.h" | ||
21 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h" | ||
22 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h" | ||
23 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h" | ||
24 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h" | ||
25 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod.h" | ||
26 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" | ||
27 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod.h" | ||
28 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" | ||
29 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h" | ||
30 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod.h" | ||
31 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" | ||
32 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" | ||
33 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm.h" | ||
34 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" | ||
35 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" | ||
36 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h" | ||
37 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" | ||
38 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" | ||
39 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" | ||
40 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h" | ||
41 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h" | ||
42 | #include "test/common/bfloat16.hpp" | ||
43 | #include "test/common/buffer.hpp" | ||
44 | #include "test/common/compare.hpp" | ||
45 | #include "test/common/cpu_info.hpp" | ||
46 | #include "test/common/int4.hpp" | ||
47 | #include "test/common/matmul_test_common.hpp" | ||
48 | #include "test/common/matrix_portion.hpp" | ||
49 | #include "test/common/memory.hpp" | ||
50 | #include "test/common/round.hpp" | ||
51 | #include "test/common/test_suite.hpp" | ||
52 | #include "test/reference/cast.hpp" | ||
53 | #include "test/reference/clamp.hpp" | ||
54 | #include "test/reference/fill.hpp" | ||
55 | #include "test/reference/matmul.hpp" | ||
56 | #include "test/reference/pad.hpp" | ||
57 | #include "test/reference/quantize.hpp" | ||
58 | #include "test/reference/transpose.hpp" | ||
59 | |||
60 | // Using BFloat truncate implementation (BFloat16<false>) to match existing packing/inference | ||
61 | |||
62 | namespace kai::test { | ||
63 | |||
64 | enum class RhsPackType { NxK, KxN }; | ||
65 | |||
66 | static const std::array<UkernelVariant<kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel>, 11> | ||
67 | variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p = { | ||
68 | {{UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod), | ||
69 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod", cpu_has_dotprod}, | ||
70 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod), | ||
71 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod", cpu_has_dotprod}, | ||
72 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod), | ||
73 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod", cpu_has_dotprod}, | ||
74 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod), | ||
75 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod", cpu_has_dotprod}, | ||
76 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod), | ||
77 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod", cpu_has_dotprod}, | ||
78 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod), | ||
79 | "kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod", cpu_has_dotprod}, | ||
80 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod), | ||
81 | "kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod", cpu_has_dotprod}, | ||
82 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm), | ||
83 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", cpu_has_i8mm}, | ||
84 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm), | ||
85 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", cpu_has_i8mm}, | ||
86 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm), | ||
87 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", cpu_has_i8mm}, | ||
88 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm), | ||
89 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm", cpu_has_i8mm}}}; | ||
90 | |||
91 | static const std::array<UkernelVariant<kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_ukernel>, 2> | ||
92 | variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p = {{ | ||
93 | {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod), | ||
94 | "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod", cpu_has_dotprod_and_bf16}, | ||
95 | {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm), | ||
96 | "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm", cpu_has_i8mm_and_bf16}, | ||
97 | }}; | ||
98 | |||
99 | static const auto test_matmul_shapes = testing::Values( | ||
100 | MatMulShape{1, 1, 64}, // | ||
101 | MatMulShape{16, 32, 64}, // | ||
102 | MatMulShape{8, 32, 128}, // | ||
103 | MatMulShape{17, 25, 64}, // | ||
104 | MatMulShape{15, 31, 128}, // | ||
105 | MatMulShape{1, 25, 64}, // | ||
106 | MatMulShape{101, 253, 256} // | ||
107 | ); | ||
108 | |||
109 | static const auto test_portions = testing::Values( | ||
110 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
111 | MatrixPortion(0, 0, 1, 0.25f), // Leftmost portion. | ||
112 | MatrixPortion(0, 0.75f, 1, 1), // Rightmost portion. | ||
113 | MatrixPortion(0, 0.5f, 1, 0.8f) // | ||
114 | ); | ||
115 | |||
116 | static const auto test_block_lengths = testing::Values(32, 64); | ||
117 | |||
118 | // Executes the scalar RHS packing micro-kernel. | ||
119 | 1404 | static inline std::tuple<Buffer, size_t> pack_rhs_qsi4c32pscalebf16( | |
120 | size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr, const Buffer& rhs_values_qsi4, const Buffer& biases, | ||
121 | size_t bias_offset, const Buffer& rhs_scales, RhsPackType pack_type, size_t rect_start_row, size_t rect_width) { | ||
122 |
2/2✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
|
1404 | const size_t width = pack_type == RhsPackType::KxN ? N : K; |
123 |
2/2✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
|
1404 | const size_t height = pack_type == RhsPackType::KxN ? K : N; |
124 | 1404 | kai_datatype scale_dt = kai_datatype::kai_dt_bf16; | |
125 | |||
126 | 1404 | const size_t rhs_stride = round_up_multiple(width, 2); | |
127 | 1404 | const size_t rhs_stride_bytes = round_up_division(width, 2); | |
128 | 1404 | const size_t scales_stride_bytes = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt); | |
129 | |||
130 | − | KAI_ASSUME(rhs_values_qsi4.size() == round_up_division(height * rhs_stride, 2)); | |
131 | |||
132 | 1404 | const auto rhs_values_qsu4 = cast_qsu4_qsi4(rhs_values_qsi4.data(), rhs_values_qsi4.size() * 2); | |
133 | 1404 | auto rhs_qsu4 = | |
134 |
2/4✓ Branch 0 taken 1404 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1404 times.
✗ Branch 3 not taken.
|
1404 | pad_row<UInt4>(rhs_values_qsu4.data(), height, width, width, rhs_stride_bytes * 2, rhs_values_qsi4.size()); |
135 | |||
136 | 1404 | const size_t scale_offset = rect_start_row * scales_stride_bytes; | |
137 | 1404 | size_t rhs_offset, rhs_packed_offset, imp_packed_rhs_size; | |
138 |
2/2✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
|
1404 | if (pack_type == RhsPackType::KxN) { |
139 |
1/2✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
|
702 | rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(rect_start_row, rhs_stride_bytes); |
140 | 702 | rhs_packed_offset = | |
141 |
1/2✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
|
702 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(rect_start_row, K, nr, kr, sr, bl, scale_dt); |
142 |
1/2✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
|
702 | imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt); |
143 | 702 | } else { | |
144 |
1/2✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
|
702 | rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rect_start_row, rhs_stride_bytes); |
145 | 702 | rhs_packed_offset = | |
146 |
1/2✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
|
702 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rect_start_row, K, nr, kr, sr, bl, scale_dt); |
147 |
1/2✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
|
702 | imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt); |
148 | } | ||
149 | |||
150 |
1/2✓ Branch 0 taken 1404 times.
✗ Branch 1 not taken.
|
1404 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
151 |
2/2✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
|
1404 | if (pack_type == RhsPackType::KxN) { |
152 | 702 | kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params{}; | |
153 | 702 | params.lhs_zero_point = 1; | |
154 | 702 | params.rhs_zero_point = 8; | |
155 | 702 | params.scale_dt = scale_dt; | |
156 | |||
157 |
1/2✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
|
702 | kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( |
158 | 702 | 1, rect_width, K, nr, kr, sr, bl, reinterpret_cast<uint8_t*>(rhs_qsu4.data() + rhs_offset), | |
159 | 702 | rhs_stride_bytes, reinterpret_cast<float*>(biases.data() + bias_offset), rhs_scales.data() + scale_offset, | |
160 | 702 | scales_stride_bytes, imp_packed_rhs.data() + rhs_packed_offset, 0, ¶ms); | |
161 | 702 | } else { | |
162 | 702 | kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{}; | |
163 | 702 | params.lhs_zero_point = 1; | |
164 | 702 | params.rhs_zero_point = 8; | |
165 | 702 | params.scale_dt = scale_dt; | |
166 | |||
167 |
1/2✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
|
702 | kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( |
168 | 702 | 1, rect_width, K, nr, kr, sr, bl, reinterpret_cast<uint8_t*>(rhs_qsu4.data() + rhs_offset), | |
169 | 702 | rhs_stride_bytes, reinterpret_cast<const float*>(biases.data() + bias_offset), | |
170 | 702 | rhs_scales.data() + scale_offset, scales_stride_bytes, imp_packed_rhs.data() + rhs_packed_offset, 0, | |
171 | ¶ms); | ||
172 | 702 | } | |
173 | 1404 | return {std::move(imp_packed_rhs), rhs_packed_offset}; | |
174 | 1404 | } | |
175 | |||
176 | // Executes the vectorized RHS packing micro-kernels for block length of 4 bytes or 8 bytes | ||
177 | 702 | static inline std::tuple<Buffer, size_t> pack_rhs_qsi4c32pscalebf16_neon( | |
178 | size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr, const Buffer& rhs_values_qsi4, const Buffer& biases, | ||
179 | size_t bias_offset, const Buffer& rhs_scales, RhsPackType pack_type, size_t rect_start_row, size_t rect_width) { | ||
180 | − | KAI_ASSUME(kr / sr == 8 || kr / sr == 4); | |
181 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 702 times.
|
702 | const size_t width = pack_type == RhsPackType::KxN ? N : K; |
182 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 702 times.
|
702 | const size_t height = pack_type == RhsPackType::KxN ? K : N; |
183 | 702 | kai_datatype scale_dt = kai_datatype::kai_dt_bf16; | |
184 | |||
185 | 702 | const size_t rhs_stride = round_up_multiple(width, 2); | |
186 | 702 | const size_t rhs_stride_bytes = round_up_division(width, 2); | |
187 | 702 | const size_t scales_stride_bytes = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt); | |
188 | |||
189 | − | KAI_ASSUME(rhs_values_qsi4.size() == round_up_division(height * rhs_stride, 2)); | |
190 | |||
191 | 702 | const auto rhs_values_qsu4 = cast_qsu4_qsi4(rhs_values_qsi4.data(), rhs_values_qsi4.size() * 2); | |
192 | 702 | auto rhs_qsu4 = | |
193 |
1/2✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
|
702 | pad_row<UInt4>(rhs_values_qsu4.data(), height, width, width, rhs_stride_bytes * 2, rhs_values_qsi4.size()); |
194 | |||
195 | 702 | size_t scale_offset = rect_start_row * scales_stride_bytes; | |
196 | |||
197 | 702 | size_t imp_packed_rhs_size_neon, rhs_packed_offset_neon, rhs_offset_neon; | |
198 |
2/2✓ Branch 0 taken 216 times.
✓ Branch 1 taken 486 times.
|
702 | if (kr / sr == 8) { |
199 | 486 | imp_packed_rhs_size_neon = | |
200 |
1/2✓ Branch 0 taken 486 times.
✗ Branch 1 not taken.
|
486 | kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt); |
201 |
1/2✓ Branch 0 taken 486 times.
✗ Branch 1 not taken.
|
486 | rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( |
202 | 486 | rect_start_row, K, nr, kr, sr, bl, scale_dt); | |
203 | 486 | rhs_offset_neon = | |
204 |
1/2✓ Branch 0 taken 486 times.
✗ Branch 1 not taken.
|
486 | kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(rect_start_row, rhs_stride_bytes); |
205 | 486 | } else { | |
206 | 216 | imp_packed_rhs_size_neon = | |
207 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt); |
208 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( |
209 | 216 | rect_start_row, K, nr, kr, sr, bl, scale_dt); | |
210 | 216 | rhs_offset_neon = | |
211 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(rect_start_row, rhs_stride_bytes); |
212 | } | ||
213 | |||
214 | 702 | kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{}; | |
215 | 702 | params.lhs_zero_point = 1; | |
216 | 702 | params.rhs_zero_point = 8; | |
217 | 702 | params.scale_dt = scale_dt; | |
218 | |||
219 |
1/2✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
|
702 | Buffer imp_packed_rhs_neon(imp_packed_rhs_size_neon); |
220 |
2/2✓ Branch 0 taken 216 times.
✓ Branch 1 taken 486 times.
|
702 | if (kr / sr == 8) { |
221 |
1/2✓ Branch 0 taken 486 times.
✗ Branch 1 not taken.
|
486 | kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( |
222 | 486 | 1, rect_width /* n */, K, nr, kr, sr, bl, | |
223 | 486 | reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset_neon), rhs_stride_bytes, | |
224 | 486 | reinterpret_cast<const float*>(biases.data() + bias_offset), | |
225 | 486 | reinterpret_cast<const float*>(rhs_scales.data() + scale_offset), scales_stride_bytes, | |
226 | 486 | imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, ¶ms); | |
227 | 486 | } else { | |
228 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( |
229 | 216 | 1, rect_width /* n */, K, nr, kr, sr, bl, | |
230 | 216 | reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset_neon), rhs_stride_bytes, | |
231 | 216 | reinterpret_cast<const float*>(biases.data() + bias_offset), | |
232 | 216 | reinterpret_cast<const float*>(rhs_scales.data() + scale_offset), scales_stride_bytes, | |
233 | 216 | imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, ¶ms); | |
234 | } | ||
235 | 702 | return {std::move(imp_packed_rhs_neon), rhs_packed_offset_neon}; | |
236 | 702 | } | |
237 | |||
238 | using MatMulTestParams_withBL_withRHSPackType = std::tuple<size_t, MatMulShape, size_t, MatrixPortion, RhsPackType>; | ||
239 | |||
240 | class MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p | ||
241 | : public ::testing::TestWithParam<MatMulTestParams_withBL_withRHSPackType> {}; | ||
242 | class MatMulTest_qmatmul_clamp_bf16_qai8dxp_qsi4c32p | ||
243 | : public ::testing::TestWithParam<MatMulTestParams_withBL_withRHSPackType> {}; | ||
244 | |||
245 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
3698 | TEST_P(MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd) { |
246 | 19910 | auto& [variant_index, matmul_shape, bl, portion, rhs_pack_type] = GetParam(); | |
247 | 2464 | auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index); | |
248 | |||
249 |
2/4✓ Branch 0 taken 1232 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1232 times.
✗ Branch 3 not taken.
|
1232 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
250 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
251 | } | ||
252 | |||
253 | 1232 | const uint32_t seed = 0; | |
254 | |||
255 | 2464 | size_t M = matmul_shape.m; | |
256 | 2464 | size_t N = matmul_shape.n; | |
257 | 2464 | size_t K = matmul_shape.k; | |
258 | |||
259 | − | KAI_ASSUME((K % bl) == 0); | |
260 | − | KAI_ASSUME((bl % 32) == 0); | |
261 | |||
262 | 1232 | auto mr = ukernel_variant.interface.get_mr(); | |
263 | 1232 | auto nr = ukernel_variant.interface.get_nr(); | |
264 | 1232 | auto kr = ukernel_variant.interface.get_kr(); | |
265 | 1232 | auto sr = ukernel_variant.interface.get_sr(); | |
266 | |||
267 | 1232 | auto m_step = ukernel_variant.interface.get_m_step(); | |
268 |
3/14✓ Branch 0 taken 1232 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1232 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1232 times.
|
1232 | ASSERT_TRUE(m_step % mr == 0); |
269 | |||
270 | 1232 | auto n_step = ukernel_variant.interface.get_n_step(); | |
271 |
3/14✓ Branch 0 taken 1232 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1232 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1232 times.
|
1232 | ASSERT_TRUE(n_step % nr == 0); |
272 | |||
273 | 2464 | auto rect = portion.compute_portion(M, N, m_step, n_step); | |
274 |
3/4✓ Branch 0 taken 1232 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 44 times.
✓ Branch 3 taken 1188 times.
|
1232 | if (rect.height() == 0 || rect.width() == 0) { |
275 |
9/18✓ Branch 0 taken 44 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 44 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 44 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 44 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 44 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 44 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 44 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 44 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 44 times.
✗ Branch 17 not taken.
|
44 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
276 | } | ||
277 | |||
278 | // Generates input data. | ||
279 | 1188 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
280 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | const auto ref_rhs = fill_random<float>(N * K, seed + 1); |
281 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | const auto ref_biases = fill_random<float>(N, seed + 2); |
282 | |||
283 | // Runs the reference implementation. | ||
284 | // * Quantizes the LHS matrix using 8-bit symmetric quantization. | ||
285 | // * Quantizes the RHS matrix using 8-bit asymmetric quantization. | ||
286 | // * Performs GEMM. | ||
287 | 3564 | auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
288 |
2/4✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
|
1188 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K); |
289 | 5940 | auto [ref_rhs_values_qsi4, ref_rhs_scales] = | |
290 |
3/6✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
|
3564 | quantize_rhs_qsi4c32p<float, BFloat16<false>>(N, K, bl, ref_rhs, rhs_pack_type == RhsPackType::NxK); |
291 | |||
292 | 1188 | Buffer ref_dst_noclamp; | |
293 |
2/2✓ Branch 0 taken 594 times.
✓ Branch 1 taken 594 times.
|
1188 | if (rhs_pack_type == RhsPackType::NxK) { |
294 | 594 | ref_dst_noclamp = | |
295 |
1/2✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
|
1188 | matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( |
296 |
4/8✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 594 times.
✗ Branch 7 not taken.
|
1188 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K, |
297 |
3/6✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
|
594 | ref_rhs_values_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr, |
298 | 1); | ||
299 | 594 | } else { | |
300 |
1/2✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
|
1188 | ref_dst_noclamp = matmul_nt_nt_quantized< |
301 | int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( | ||
302 |
4/8✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 594 times.
✗ Branch 7 not taken.
|
1188 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K, |
303 |
3/6✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
|
594 | ref_rhs_values_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr, 1); |
304 | } | ||
305 | |||
306 | // Clamps the reference output. | ||
307 | 1188 | const auto clamp_ratio = 0.8F; | |
308 |
2/4✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
|
4752 | const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_noclamp.data(), M * N, clamp_ratio); |
309 |
4/8✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1188 times.
✗ Branch 7 not taken.
|
1188 | auto ref_dst = clamp<float>(ref_dst_noclamp.data(), M * N, clamp_min, clamp_max); |
310 | |||
311 | // Runs the LHS packing micro-kernel. | ||
312 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | const auto lhs_start_row = rect.start_row(); |
313 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); |
314 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | Buffer imp_packed_lhs(imp_packed_lhs_size); |
315 | |||
316 | 1188 | const auto lhs_stride = K * sizeof(float); | |
317 | |||
318 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); |
319 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); |
320 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
321 |
4/16✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1188 times.
✗ Branch 15 not taken.
|
1188 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
322 | |||
323 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | kai_run_lhs_quant_pack_qai8dxp_f32( |
324 |
2/4✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
|
1188 | rect.height() /* m */, K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), |
325 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | lhs_stride, reinterpret_cast<uint8_t*>(imp_packed_lhs.data()) + lhs_packed_offset); |
326 | |||
327 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | const auto rhs_start_row = rect.start_col(); |
328 | 1188 | size_t bias_offset = rhs_start_row * sizeof(float); | |
329 | |||
330 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
2970 | auto [imp_packed_rhs, rhs_packed_offset] = pack_rhs_qsi4c32pscalebf16( |
331 | 4752 | N, K, bl, nr, kr, sr, ref_rhs_values_qsi4, ref_biases, bias_offset, ref_rhs_scales, rhs_pack_type, | |
332 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | rhs_start_row, rect.width()); |
333 | |||
334 |
2/4✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
|
2376 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl); |
335 |
5/18✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1188 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 1188 times.
|
2376 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
336 | |||
337 | 1188 | const auto dst_stride_row = N * sizeof(float); | |
338 | 1188 | const auto dst_stride_col = sizeof(float); | |
339 | 2376 | const auto dst_offset = | |
340 |
3/6✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
|
1188 | ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); |
341 |
2/4✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
|
1188 | const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; |
342 |
4/16✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1188 times.
|
1188 | ASSERT_EQ(dst_offset, ref_dst_offset); |
343 | |||
344 | // Runs the GEMM micro-kernel. | ||
345 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
346 |
5/18✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1188 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 1188 times.
|
1188 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
347 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | Buffer imp_dst(imp_dst_size); |
348 | |||
349 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
2376 | ukernel_variant.interface.run_matmul( |
350 |
4/8✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1188 times.
✗ Branch 7 not taken.
|
1188 | rect.height(), rect.width(), K, bl, reinterpret_cast<const uint8_t*>(imp_packed_lhs.data()) + lhs_matmul_offset, |
351 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | reinterpret_cast<const uint8_t*>(imp_packed_rhs.data()) + rhs_matmul_offset, |
352 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | reinterpret_cast<float*>(imp_dst.data() + dst_offset), dst_stride_row, dst_stride_col, clamp_min, clamp_max); |
353 | |||
354 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
355 | // tested. | ||
356 |
1/2✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
|
1188 | DefaultMismatchHandler handler(0, 0.1, 0, 0.05); |
357 | 1188 | DataFormat dst_format = DataFormat(DataType::FP32); | |
358 | 2376 | const auto success = | |
359 |
3/6✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
|
1188 | compare(reinterpret_cast<const uint8_t*>(imp_dst.data()), ref_dst.data(), dst_format, M, N, rect, handler); |
360 |
4/16✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1188 times.
|
1188 | ASSERT_TRUE(success); |
361 | |||
362 | // Test vectorized packing micro-kernels, if packing parameters allow | ||
363 |
5/6✓ Branch 0 taken 594 times.
✓ Branch 1 taken 594 times.
✓ Branch 2 taken 216 times.
✓ Branch 3 taken 378 times.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
|
1188 | if (rhs_pack_type == RhsPackType::NxK && (kr / sr == 8 || kr / sr == 4)) { |
364 |
1/2✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
|
1188 | const auto [imp_packed_rhs_neon, rhs_packed_offset_neon] = pack_rhs_qsi4c32pscalebf16_neon( |
365 | 2376 | N, K, bl, nr, kr, sr, ref_rhs_values_qsi4, ref_biases, bias_offset, ref_rhs_scales, rhs_pack_type, | |
366 |
1/2✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
|
594 | rhs_start_row, rect.width()); |
367 |
5/18✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 594 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 594 times.
|
1188 | ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset); |
368 | |||
369 |
1/2✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
|
1188 | ukernel_variant.interface.run_matmul( |
370 |
4/8✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 594 times.
✗ Branch 7 not taken.
|
594 | rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset, |
371 |
2/4✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
|
594 | imp_packed_rhs_neon.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), |
372 | 1782 | dst_stride_row, dst_stride_col, clamp_min, clamp_max); | |
373 | |||
374 |
3/6✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
|
594 | const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); |
375 |
4/16✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 594 times.
|
594 | ASSERT_TRUE(success); |
376 | 594 | } | |
377 | 1232 | } | |
378 | |||
379 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
674 | TEST_P(MatMulTest_qmatmul_clamp_bf16_qai8dxp_qsi4c32p, EndToEnd) { |
380 | 3396 | auto& [variant_index, matmul_shape, bl, portion, rhs_pack_type] = GetParam(); | |
381 | 448 | auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_index); | |
382 | |||
383 |
2/4✓ Branch 0 taken 224 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 224 times.
✗ Branch 3 not taken.
|
224 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
384 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
385 | } | ||
386 | |||
387 | 224 | const uint32_t seed = 0; | |
388 | |||
389 | 448 | size_t M = matmul_shape.m; | |
390 | 448 | size_t N = matmul_shape.n; | |
391 | 448 | size_t K = matmul_shape.k; | |
392 | |||
393 | 224 | auto mr = ukernel_variant.interface.get_mr(); | |
394 | 224 | auto nr = ukernel_variant.interface.get_nr(); | |
395 | 224 | auto kr = ukernel_variant.interface.get_kr(); | |
396 | 224 | auto sr = ukernel_variant.interface.get_sr(); | |
397 | |||
398 | 224 | auto m_step = ukernel_variant.interface.get_m_step(); | |
399 |
3/14✓ Branch 0 taken 224 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 224 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 224 times.
|
224 | ASSERT_TRUE(m_step % mr == 0); |
400 | |||
401 | 224 | auto n_step = ukernel_variant.interface.get_n_step(); | |
402 |
3/14✓ Branch 0 taken 224 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 224 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 224 times.
|
224 | ASSERT_TRUE(n_step % nr == 0); |
403 | |||
404 | 448 | auto rect = portion.compute_portion(M, N, m_step, n_step); | |
405 |
3/4✓ Branch 0 taken 224 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 216 times.
|
224 | if (rect.height() == 0 || rect.width() == 0) { |
406 |
9/18✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
|
8 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
407 | } | ||
408 | |||
409 | // Generates input data. | ||
410 | 216 | const auto ref_lhs_bf16 = fill_random<BFloat16<false>>(M * K, seed + 0); | |
411 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | const auto ref_rhs = fill_random<float>(N * K, seed + 1); |
412 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | const auto ref_biases = fill_random<float>(N, seed + 2); |
413 | |||
414 | // For reference implementation, Casting BF16 input to FP32 type and FP32 output back to BF16 because the matmul | ||
415 | // implementation works with FP32 accumulation and casts the result to BF16 | ||
416 | 216 | const auto ref_lhs = | |
417 |
3/6✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
|
216 | cast<float, BFloat16<false>>(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits<BFloat16<false>>); |
418 | |||
419 | // Runs the reference implementation. | ||
420 | // * Quantizes the LHS matrix using 8-bit symmetric quantization. | ||
421 | // * Quantizes the RHS matrix using 8-bit asymmetric quantization. | ||
422 | // * Performs GEMM. | ||
423 | 648 | auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
424 |
2/4✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
|
216 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K); |
425 | 1080 | auto [ref_rhs_values_qsi4, ref_rhs_scales] = | |
426 |
3/6✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
|
648 | quantize_rhs_qsi4c32p<float, BFloat16<false>>(N, K, bl, ref_rhs, rhs_pack_type == RhsPackType::NxK); |
427 | |||
428 | 216 | Buffer ref_dst_noclamp; | |
429 |
2/2✓ Branch 0 taken 108 times.
✓ Branch 1 taken 108 times.
|
216 | if (rhs_pack_type == RhsPackType::NxK) { |
430 | 108 | ref_dst_noclamp = | |
431 |
1/2✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
|
216 | matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( |
432 |
4/8✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
|
216 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K, |
433 |
3/6✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
|
108 | ref_rhs_values_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr, |
434 | 1); | ||
435 | 108 | } else { | |
436 |
1/2✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
|
216 | ref_dst_noclamp = matmul_nt_nt_quantized< |
437 | int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( | ||
438 |
4/8✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
|
216 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K, |
439 |
3/6✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
|
108 | ref_rhs_values_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr, 1); |
440 | } | ||
441 | |||
442 | // Clamps the reference output. | ||
443 | 216 | const auto clamp_ratio = 0.8F; | |
444 |
2/4✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
|
864 | const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_noclamp.data(), M * N, clamp_ratio); |
445 |
4/8✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
|
216 | auto ref_dst_float = clamp<float>(ref_dst_noclamp.data(), M * N, clamp_min, clamp_max); |
446 | |||
447 | // Cast the reference output to BF16 | ||
448 |
3/6✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
|
216 | auto ref_dst = cast<BFloat16<false>, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>); |
449 | |||
450 | // Runs the LHS packing micro-kernel. | ||
451 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | const auto lhs_start_row = rect.start_row(); |
452 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); |
453 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | Buffer imp_packed_lhs(imp_packed_lhs_size); |
454 | |||
455 | 216 | const auto lhs_stride = K * sizeof(uint16_t); | |
456 | |||
457 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); |
458 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); |
459 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
460 |
4/16✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 216 times.
✗ Branch 15 not taken.
|
216 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
461 | |||
462 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | kai_run_lhs_quant_pack_qai8dxp_bf16_neon( |
463 |
2/4✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
|
216 | rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride, |
464 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | reinterpret_cast<uint8_t*>(imp_packed_lhs.data()) + lhs_packed_offset); |
465 | |||
466 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | const auto rhs_start_row = rect.start_col(); |
467 | 216 | size_t bias_offset = rhs_start_row * sizeof(float); | |
468 | |||
469 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
540 | auto [imp_packed_rhs, rhs_packed_offset] = pack_rhs_qsi4c32pscalebf16( |
470 | 864 | N, K, bl, nr, kr, sr, ref_rhs_values_qsi4, ref_biases, bias_offset, ref_rhs_scales, rhs_pack_type, | |
471 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | rhs_start_row, rect.width()); |
472 | |||
473 |
2/4✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
|
432 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl); |
474 |
5/18✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 216 times.
|
432 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
475 | |||
476 | 216 | const auto dst_stride_row = N * sizeof(uint16_t); | |
477 | 216 | const auto dst_stride_col = sizeof(uint16_t); | |
478 | 432 | const auto dst_offset = | |
479 |
3/6✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
|
216 | ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); |
480 |
2/4✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
|
216 | const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; |
481 |
4/16✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 216 times.
|
216 | ASSERT_EQ(dst_offset, ref_dst_offset); |
482 | |||
483 | // Runs the GEMM micro-kernel. | ||
484 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
485 |
5/18✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 216 times.
|
216 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
486 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | Buffer imp_dst(imp_dst_size); |
487 | |||
488 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
432 | ukernel_variant.interface.run_matmul( |
489 |
4/8✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
|
216 | rect.height(), rect.width(), K, bl, reinterpret_cast<const uint8_t*>(imp_packed_lhs.data()) + lhs_matmul_offset, |
490 |
2/4✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
|
216 | reinterpret_cast<const uint8_t*>(imp_packed_rhs.data()) + rhs_matmul_offset, imp_dst.data() + dst_offset, |
491 | 648 | dst_stride_row, dst_stride_col, clamp_min, clamp_max); | |
492 | |||
493 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
494 | // tested. | ||
495 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | DefaultMismatchHandler handler(0, 0.02, 0, 0.05); |
496 | 216 | DataFormat dst_format = DataFormat(DataType::BF16); | |
497 | 432 | const auto success = | |
498 |
3/6✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
|
216 | compare(reinterpret_cast<const uint8_t*>(imp_dst.data()), ref_dst.data(), dst_format, M, N, rect, handler); |
499 |
4/16✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 216 times.
|
216 | ASSERT_TRUE(success); |
500 | |||
501 | // Test vectorized packing micro-kernels, if packing parameters allow | ||
502 |
3/6✓ Branch 0 taken 108 times.
✓ Branch 1 taken 108 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 108 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
216 | if (rhs_pack_type == RhsPackType::NxK && (kr / sr == 8 || kr / sr == 4)) { |
503 |
1/2✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
|
216 | const auto [imp_packed_rhs_neon, rhs_packed_offset_neon] = pack_rhs_qsi4c32pscalebf16_neon( |
504 | 432 | N, K, bl, nr, kr, sr, ref_rhs_values_qsi4, ref_biases, bias_offset, ref_rhs_scales, rhs_pack_type, | |
505 |
1/2✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
|
108 | rhs_start_row, rect.width()); |
506 |
5/18✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 108 times.
|
216 | ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset); |
507 | |||
508 |
1/2✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
|
216 | ukernel_variant.interface.run_matmul( |
509 |
4/8✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
|
108 | rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset, |
510 |
2/4✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
|
108 | imp_packed_rhs_neon.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), |
511 | 324 | dst_stride_row, dst_stride_col, clamp_min, clamp_max); | |
512 | |||
513 |
3/6✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
|
108 | const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); |
514 |
4/16✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 108 times.
|
108 | ASSERT_TRUE(success); |
515 | 108 | } | |
516 | 224 | } | |
517 | |||
518 |
21/46✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1232 times.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1232 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1232 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 616 times.
✓ Branch 29 taken 616 times.
✓ Branch 30 taken 1232 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1232 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1232 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 1232 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 1232 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1232 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 1232 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 1232 times.
✗ Branch 45 not taken.
|
2466 | INSTANTIATE_TEST_SUITE_P( |
519 | MatMul, MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p, | ||
520 | testing::Combine( | ||
521 | testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.size()), test_matmul_shapes, | ||
522 | test_block_lengths, test_portions, testing::Values(RhsPackType::NxK, RhsPackType::KxN)), | ||
523 | [](const auto& info) { | ||
524 | const auto variant_idx = std::get<0>(info.param); | ||
525 | const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_idx).name}; | ||
526 | const auto shape = std::get<MatMulShape>(info.param); | ||
527 | const auto bl = std::get<2>(info.param); | ||
528 | const auto portion = std::get<3>(info.param); | ||
529 | const RhsPackType rhs_pack_type = std::get<4>(info.param); | ||
530 | |||
531 | std::ostringstream sstream; | ||
532 | sstream << name << ((rhs_pack_type == RhsPackType::NxK) ? "__NxK" : "__KxN") << "__"; | ||
533 | PrintTo(shape, &sstream); | ||
534 | sstream << "__BL_" << bl << "__"; | ||
535 | PrintTo(portion, &sstream); | ||
536 | |||
537 | return sstream.str(); | ||
538 | }); | ||
539 | |||
540 |
21/46✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 224 times.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 224 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 224 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 112 times.
✓ Branch 29 taken 112 times.
✓ Branch 30 taken 224 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 224 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 224 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 224 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 224 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 224 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 224 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 224 times.
✗ Branch 45 not taken.
|
450 | INSTANTIATE_TEST_SUITE_P( |
541 | MatMul, MatMulTest_qmatmul_clamp_bf16_qai8dxp_qsi4c32p, | ||
542 | testing::Combine( | ||
543 | testing::Range<size_t>(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.size()), test_matmul_shapes, | ||
544 | test_block_lengths, test_portions, testing::Values(RhsPackType::NxK, RhsPackType::KxN)), | ||
545 | [](const auto& info) { | ||
546 | const auto variant_idx = std::get<0>(info.param); | ||
547 | const std::string name{variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_idx).name}; | ||
548 | const auto shape = std::get<MatMulShape>(info.param); | ||
549 | const auto bl = std::get<2>(info.param); | ||
550 | const auto portion = std::get<3>(info.param); | ||
551 | const RhsPackType rhs_pack_type = std::get<4>(info.param); | ||
552 | |||
553 | std::ostringstream sstream; | ||
554 | sstream << name << ((rhs_pack_type == RhsPackType::NxK) ? "__NxK" : "__KxN") << "__"; | ||
555 | PrintTo(shape, &sstream); | ||
556 | sstream << "__BL_" << bl << "__"; | ||
557 | PrintTo(portion, &sstream); | ||
558 | |||
559 | return sstream.str(); | ||
560 | }); | ||
561 | |||
562 | } // namespace kai::test | ||
563 |