KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 99.3% 272 5 279
Functions: 100.0% 20 0 20
Branches: 41.8% 359 14 872

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <functional>
14 #include <sstream>
15 #include <string>
16 #include <string_view>
17 #include <tuple>
18 #include <utility>
19
20 #include "kai/kai_common.h"
21 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h"
22 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
23 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h"
24 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
25 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod.h"
26 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
27 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod.h"
28 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h"
29 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
30 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod.h"
31 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"
32 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"
33 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm.h"
34 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h"
35 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"
36 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h"
37 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
38 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h"
39 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"
40 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h"
41 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h"
42 #include "test/common/bfloat16.hpp"
43 #include "test/common/buffer.hpp"
44 #include "test/common/compare.hpp"
45 #include "test/common/cpu_info.hpp"
46 #include "test/common/int4.hpp"
47 #include "test/common/matmul_test_common.hpp"
48 #include "test/common/matrix_portion.hpp"
49 #include "test/common/memory.hpp"
50 #include "test/common/round.hpp"
51 #include "test/common/test_suite.hpp"
52 #include "test/reference/cast.hpp"
53 #include "test/reference/clamp.hpp"
54 #include "test/reference/fill.hpp"
55 #include "test/reference/matmul.hpp"
56 #include "test/reference/pad.hpp"
57 #include "test/reference/quantize.hpp"
58 #include "test/reference/transpose.hpp"
59
60 // Using BFloat truncate implementation (BFloat16<false>) to match existing packing/inference
61
62 namespace kai::test {
63
64 enum class RhsPackType { NxK, KxN };
65
66 static const std::array<UkernelVariant<kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel>, 11>
67 variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p = {
68 {{UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod),
69 "kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod", cpu_has_dotprod},
70 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod),
71 "kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod", cpu_has_dotprod},
72 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod),
73 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod", cpu_has_dotprod},
74 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod),
75 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod", cpu_has_dotprod},
76 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod),
77 "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod", cpu_has_dotprod},
78 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod),
79 "kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod", cpu_has_dotprod},
80 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod),
81 "kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod", cpu_has_dotprod},
82 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm),
83 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", cpu_has_i8mm},
84 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm),
85 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", cpu_has_i8mm},
86 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm),
87 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", cpu_has_i8mm},
88 {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm),
89 "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm", cpu_has_i8mm}}};
90
91 static const std::array<UkernelVariant<kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_ukernel>, 2>
92 variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p = {{
93 {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod),
94 "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod", cpu_has_dotprod_and_bf16},
95 {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm),
96 "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm", cpu_has_i8mm_and_bf16},
97 }};
98
99 static const auto test_matmul_shapes = testing::Values(
100 MatMulShape{1, 1, 64}, //
101 MatMulShape{16, 32, 64}, //
102 MatMulShape{8, 32, 128}, //
103 MatMulShape{17, 25, 64}, //
104 MatMulShape{15, 31, 128}, //
105 MatMulShape{1, 25, 64}, //
106 MatMulShape{101, 253, 256} //
107 );
108
109 static const auto test_portions = testing::Values(
110 MatrixPortion(0, 0, 1, 1), // Full matrix.
111 MatrixPortion(0, 0, 1, 0.25f), // Leftmost portion.
112 MatrixPortion(0, 0.75f, 1, 1), // Rightmost portion.
113 MatrixPortion(0, 0.5f, 1, 0.8f) //
114 );
115
116 static const auto test_block_lengths = testing::Values(32, 64);
117
118 // Executes the scalar RHS packing micro-kernel.
119 1404 static inline std::tuple<Buffer, size_t> pack_rhs_qsi4c32pscalebf16(
120 size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr, const Buffer& rhs_values_qsi4, const Buffer& biases,
121 size_t bias_offset, const Buffer& rhs_scales, RhsPackType pack_type, size_t rect_start_row, size_t rect_width) {
122
2/2
✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
1404 const size_t width = pack_type == RhsPackType::KxN ? N : K;
123
2/2
✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
1404 const size_t height = pack_type == RhsPackType::KxN ? K : N;
124 1404 kai_datatype scale_dt = kai_datatype::kai_dt_bf16;
125
126 1404 const size_t rhs_stride = round_up_multiple(width, 2);
127 1404 const size_t rhs_stride_bytes = round_up_division(width, 2);
128 1404 const size_t scales_stride_bytes = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt);
129
130 KAI_ASSUME(rhs_values_qsi4.size() == round_up_division(height * rhs_stride, 2));
131
132 1404 const auto rhs_values_qsu4 = cast_qsu4_qsi4(rhs_values_qsi4.data(), rhs_values_qsi4.size() * 2);
133 1404 auto rhs_qsu4 =
134
2/4
✓ Branch 0 taken 1404 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1404 times.
✗ Branch 3 not taken.
1404 pad_row<UInt4>(rhs_values_qsu4.data(), height, width, width, rhs_stride_bytes * 2, rhs_values_qsi4.size());
135
136 1404 const size_t scale_offset = rect_start_row * scales_stride_bytes;
137 1404 size_t rhs_offset, rhs_packed_offset, imp_packed_rhs_size;
138
2/2
✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
1404 if (pack_type == RhsPackType::KxN) {
139
1/2
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
702 rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(rect_start_row, rhs_stride_bytes);
140 702 rhs_packed_offset =
141
1/2
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
702 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(rect_start_row, K, nr, kr, sr, bl, scale_dt);
142
1/2
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
702 imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt);
143 702 } else {
144
1/2
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
702 rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rect_start_row, rhs_stride_bytes);
145 702 rhs_packed_offset =
146
1/2
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
702 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rect_start_row, K, nr, kr, sr, bl, scale_dt);
147
1/2
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
702 imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt);
148 }
149
150
1/2
✓ Branch 0 taken 1404 times.
✗ Branch 1 not taken.
1404 Buffer imp_packed_rhs(imp_packed_rhs_size);
151
2/2
✓ Branch 0 taken 702 times.
✓ Branch 1 taken 702 times.
1404 if (pack_type == RhsPackType::KxN) {
152 702 kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params{};
153 702 params.lhs_zero_point = 1;
154 702 params.rhs_zero_point = 8;
155 702 params.scale_dt = scale_dt;
156
157
1/2
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
702 kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
158 702 1, rect_width, K, nr, kr, sr, bl, reinterpret_cast<uint8_t*>(rhs_qsu4.data() + rhs_offset),
159 702 rhs_stride_bytes, reinterpret_cast<float*>(biases.data() + bias_offset), rhs_scales.data() + scale_offset,
160 702 scales_stride_bytes, imp_packed_rhs.data() + rhs_packed_offset, 0, &params);
161 702 } else {
162 702 kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{};
163 702 params.lhs_zero_point = 1;
164 702 params.rhs_zero_point = 8;
165 702 params.scale_dt = scale_dt;
166
167
1/2
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
702 kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
168 702 1, rect_width, K, nr, kr, sr, bl, reinterpret_cast<uint8_t*>(rhs_qsu4.data() + rhs_offset),
169 702 rhs_stride_bytes, reinterpret_cast<const float*>(biases.data() + bias_offset),
170 702 rhs_scales.data() + scale_offset, scales_stride_bytes, imp_packed_rhs.data() + rhs_packed_offset, 0,
171 &params);
172 702 }
173 1404 return {std::move(imp_packed_rhs), rhs_packed_offset};
174 1404 }
175
176 // Executes the vectorized RHS packing micro-kernels for block length of 4 bytes or 8 bytes
177 702 static inline std::tuple<Buffer, size_t> pack_rhs_qsi4c32pscalebf16_neon(
178 size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr, const Buffer& rhs_values_qsi4, const Buffer& biases,
179 size_t bias_offset, const Buffer& rhs_scales, RhsPackType pack_type, size_t rect_start_row, size_t rect_width) {
180 KAI_ASSUME(kr / sr == 8 || kr / sr == 4);
181
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 702 times.
702 const size_t width = pack_type == RhsPackType::KxN ? N : K;
182
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 702 times.
702 const size_t height = pack_type == RhsPackType::KxN ? K : N;
183 702 kai_datatype scale_dt = kai_datatype::kai_dt_bf16;
184
185 702 const size_t rhs_stride = round_up_multiple(width, 2);
186 702 const size_t rhs_stride_bytes = round_up_division(width, 2);
187 702 const size_t scales_stride_bytes = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt);
188
189 KAI_ASSUME(rhs_values_qsi4.size() == round_up_division(height * rhs_stride, 2));
190
191 702 const auto rhs_values_qsu4 = cast_qsu4_qsi4(rhs_values_qsi4.data(), rhs_values_qsi4.size() * 2);
192 702 auto rhs_qsu4 =
193
1/2
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
702 pad_row<UInt4>(rhs_values_qsu4.data(), height, width, width, rhs_stride_bytes * 2, rhs_values_qsi4.size());
194
195 702 size_t scale_offset = rect_start_row * scales_stride_bytes;
196
197 702 size_t imp_packed_rhs_size_neon, rhs_packed_offset_neon, rhs_offset_neon;
198
2/2
✓ Branch 0 taken 216 times.
✓ Branch 1 taken 486 times.
702 if (kr / sr == 8) {
199 486 imp_packed_rhs_size_neon =
200
1/2
✓ Branch 0 taken 486 times.
✗ Branch 1 not taken.
486 kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt);
201
1/2
✓ Branch 0 taken 486 times.
✗ Branch 1 not taken.
486 rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
202 486 rect_start_row, K, nr, kr, sr, bl, scale_dt);
203 486 rhs_offset_neon =
204
1/2
✓ Branch 0 taken 486 times.
✗ Branch 1 not taken.
486 kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(rect_start_row, rhs_stride_bytes);
205 486 } else {
206 216 imp_packed_rhs_size_neon =
207
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt);
208
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
209 216 rect_start_row, K, nr, kr, sr, bl, scale_dt);
210 216 rhs_offset_neon =
211
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(rect_start_row, rhs_stride_bytes);
212 }
213
214 702 kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{};
215 702 params.lhs_zero_point = 1;
216 702 params.rhs_zero_point = 8;
217 702 params.scale_dt = scale_dt;
218
219
1/2
✓ Branch 0 taken 702 times.
✗ Branch 1 not taken.
702 Buffer imp_packed_rhs_neon(imp_packed_rhs_size_neon);
220
2/2
✓ Branch 0 taken 216 times.
✓ Branch 1 taken 486 times.
702 if (kr / sr == 8) {
221
1/2
✓ Branch 0 taken 486 times.
✗ Branch 1 not taken.
486 kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
222 486 1, rect_width /* n */, K, nr, kr, sr, bl,
223 486 reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset_neon), rhs_stride_bytes,
224 486 reinterpret_cast<const float*>(biases.data() + bias_offset),
225 486 reinterpret_cast<const float*>(rhs_scales.data() + scale_offset), scales_stride_bytes,
226 486 imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, &params);
227 486 } else {
228
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
229 216 1, rect_width /* n */, K, nr, kr, sr, bl,
230 216 reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset_neon), rhs_stride_bytes,
231 216 reinterpret_cast<const float*>(biases.data() + bias_offset),
232 216 reinterpret_cast<const float*>(rhs_scales.data() + scale_offset), scales_stride_bytes,
233 216 imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, &params);
234 }
235 702 return {std::move(imp_packed_rhs_neon), rhs_packed_offset_neon};
236 702 }
237
238 using MatMulTestParams_withBL_withRHSPackType = std::tuple<size_t, MatMulShape, size_t, MatrixPortion, RhsPackType>;
239
240 class MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p
241 : public ::testing::TestWithParam<MatMulTestParams_withBL_withRHSPackType> {};
242 class MatMulTest_qmatmul_clamp_bf16_qai8dxp_qsi4c32p
243 : public ::testing::TestWithParam<MatMulTestParams_withBL_withRHSPackType> {};
244
245
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
3698 TEST_P(MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd) {
246 19910 auto& [variant_index, matmul_shape, bl, portion, rhs_pack_type] = GetParam();
247 2464 auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index);
248
249
2/4
✓ Branch 0 taken 1232 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1232 times.
✗ Branch 3 not taken.
1232 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
250 GTEST_SKIP() << "Unsupported CPU feature";
251 }
252
253 1232 const uint32_t seed = 0;
254
255 2464 size_t M = matmul_shape.m;
256 2464 size_t N = matmul_shape.n;
257 2464 size_t K = matmul_shape.k;
258
259 KAI_ASSUME((K % bl) == 0);
260 KAI_ASSUME((bl % 32) == 0);
261
262 1232 auto mr = ukernel_variant.interface.get_mr();
263 1232 auto nr = ukernel_variant.interface.get_nr();
264 1232 auto kr = ukernel_variant.interface.get_kr();
265 1232 auto sr = ukernel_variant.interface.get_sr();
266
267 1232 auto m_step = ukernel_variant.interface.get_m_step();
268
3/14
✓ Branch 0 taken 1232 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1232 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1232 times.
1232 ASSERT_TRUE(m_step % mr == 0);
269
270 1232 auto n_step = ukernel_variant.interface.get_n_step();
271
3/14
✓ Branch 0 taken 1232 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1232 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1232 times.
1232 ASSERT_TRUE(n_step % nr == 0);
272
273 2464 auto rect = portion.compute_portion(M, N, m_step, n_step);
274
3/4
✓ Branch 0 taken 1232 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 44 times.
✓ Branch 3 taken 1188 times.
1232 if (rect.height() == 0 || rect.width() == 0) {
275
9/18
✓ Branch 0 taken 44 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 44 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 44 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 44 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 44 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 44 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 44 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 44 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 44 times.
✗ Branch 17 not taken.
44 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
276 }
277
278 // Generates input data.
279 1188 const auto ref_lhs = fill_random<float>(M * K, seed + 0);
280
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 const auto ref_rhs = fill_random<float>(N * K, seed + 1);
281
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 const auto ref_biases = fill_random<float>(N, seed + 2);
282
283 // Runs the reference implementation.
284 // * Quantizes the LHS matrix using 8-bit symmetric quantization.
285 // * Quantizes the RHS matrix using 8-bit asymmetric quantization.
286 // * Performs GEMM.
287 3564 auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
288
2/4
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
1188 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K);
289 5940 auto [ref_rhs_values_qsi4, ref_rhs_scales] =
290
3/6
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
3564 quantize_rhs_qsi4c32p<float, BFloat16<false>>(N, K, bl, ref_rhs, rhs_pack_type == RhsPackType::NxK);
291
292 1188 Buffer ref_dst_noclamp;
293
2/2
✓ Branch 0 taken 594 times.
✓ Branch 1 taken 594 times.
1188 if (rhs_pack_type == RhsPackType::NxK) {
294 594 ref_dst_noclamp =
295
1/2
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
1188 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
296
4/8
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 594 times.
✗ Branch 7 not taken.
1188 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K,
297
3/6
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
594 ref_rhs_values_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr,
298 1);
299 594 } else {
300
1/2
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
1188 ref_dst_noclamp = matmul_nt_nt_quantized<
301 int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
302
4/8
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 594 times.
✗ Branch 7 not taken.
1188 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K,
303
3/6
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
594 ref_rhs_values_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr, 1);
304 }
305
306 // Clamps the reference output.
307 1188 const auto clamp_ratio = 0.8F;
308
2/4
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
4752 const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_noclamp.data(), M * N, clamp_ratio);
309
4/8
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1188 times.
✗ Branch 7 not taken.
1188 auto ref_dst = clamp<float>(ref_dst_noclamp.data(), M * N, clamp_min, clamp_max);
310
311 // Runs the LHS packing micro-kernel.
312
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 const auto lhs_start_row = rect.start_row();
313
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
314
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 Buffer imp_packed_lhs(imp_packed_lhs_size);
315
316 1188 const auto lhs_stride = K * sizeof(float);
317
318
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
319
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
320
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
321
4/16
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1188 times.
✗ Branch 15 not taken.
1188 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
322
323
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 kai_run_lhs_quant_pack_qai8dxp_f32(
324
2/4
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
1188 rect.height() /* m */, K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset),
325
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 lhs_stride, reinterpret_cast<uint8_t*>(imp_packed_lhs.data()) + lhs_packed_offset);
326
327
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 const auto rhs_start_row = rect.start_col();
328 1188 size_t bias_offset = rhs_start_row * sizeof(float);
329
330
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
2970 auto [imp_packed_rhs, rhs_packed_offset] = pack_rhs_qsi4c32pscalebf16(
331 4752 N, K, bl, nr, kr, sr, ref_rhs_values_qsi4, ref_biases, bias_offset, ref_rhs_scales, rhs_pack_type,
332
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 rhs_start_row, rect.width());
333
334
2/4
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
2376 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
335
5/18
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1188 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 1188 times.
2376 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
336
337 1188 const auto dst_stride_row = N * sizeof(float);
338 1188 const auto dst_stride_col = sizeof(float);
339 2376 const auto dst_offset =
340
3/6
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
1188 ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
341
2/4
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
1188 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
342
4/16
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1188 times.
1188 ASSERT_EQ(dst_offset, ref_dst_offset);
343
344 // Runs the GEMM micro-kernel.
345
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
346
5/18
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1188 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 1188 times.
1188 ASSERT_EQ(imp_dst_size, ref_dst.size());
347
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 Buffer imp_dst(imp_dst_size);
348
349
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
2376 ukernel_variant.interface.run_matmul(
350
4/8
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1188 times.
✗ Branch 7 not taken.
1188 rect.height(), rect.width(), K, bl, reinterpret_cast<const uint8_t*>(imp_packed_lhs.data()) + lhs_matmul_offset,
351
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 reinterpret_cast<const uint8_t*>(imp_packed_rhs.data()) + rhs_matmul_offset,
352
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 reinterpret_cast<float*>(imp_dst.data() + dst_offset), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
353
354 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
355 // tested.
356
1/2
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
1188 DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
357 1188 DataFormat dst_format = DataFormat(DataType::FP32);
358 2376 const auto success =
359
3/6
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
1188 compare(reinterpret_cast<const uint8_t*>(imp_dst.data()), ref_dst.data(), dst_format, M, N, rect, handler);
360
4/16
✓ Branch 0 taken 1188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1188 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1188 times.
1188 ASSERT_TRUE(success);
361
362 // Test vectorized packing micro-kernels, if packing parameters allow
363
5/6
✓ Branch 0 taken 594 times.
✓ Branch 1 taken 594 times.
✓ Branch 2 taken 216 times.
✓ Branch 3 taken 378 times.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
1188 if (rhs_pack_type == RhsPackType::NxK && (kr / sr == 8 || kr / sr == 4)) {
364
1/2
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
1188 const auto [imp_packed_rhs_neon, rhs_packed_offset_neon] = pack_rhs_qsi4c32pscalebf16_neon(
365 2376 N, K, bl, nr, kr, sr, ref_rhs_values_qsi4, ref_biases, bias_offset, ref_rhs_scales, rhs_pack_type,
366
1/2
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
594 rhs_start_row, rect.width());
367
5/18
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 594 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 594 times.
1188 ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset);
368
369
1/2
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
1188 ukernel_variant.interface.run_matmul(
370
4/8
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 594 times.
✗ Branch 7 not taken.
594 rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset,
371
2/4
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
594 imp_packed_rhs_neon.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
372 1782 dst_stride_row, dst_stride_col, clamp_min, clamp_max);
373
374
3/6
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
594 const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
375
4/16
✓ Branch 0 taken 594 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 594 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 594 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 594 times.
594 ASSERT_TRUE(success);
376 594 }
377 1232 }
378
379
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
674 TEST_P(MatMulTest_qmatmul_clamp_bf16_qai8dxp_qsi4c32p, EndToEnd) {
380 3396 auto& [variant_index, matmul_shape, bl, portion, rhs_pack_type] = GetParam();
381 448 auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_index);
382
383
2/4
✓ Branch 0 taken 224 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 224 times.
✗ Branch 3 not taken.
224 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
384 GTEST_SKIP() << "Unsupported CPU feature";
385 }
386
387 224 const uint32_t seed = 0;
388
389 448 size_t M = matmul_shape.m;
390 448 size_t N = matmul_shape.n;
391 448 size_t K = matmul_shape.k;
392
393 224 auto mr = ukernel_variant.interface.get_mr();
394 224 auto nr = ukernel_variant.interface.get_nr();
395 224 auto kr = ukernel_variant.interface.get_kr();
396 224 auto sr = ukernel_variant.interface.get_sr();
397
398 224 auto m_step = ukernel_variant.interface.get_m_step();
399
3/14
✓ Branch 0 taken 224 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 224 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 224 times.
224 ASSERT_TRUE(m_step % mr == 0);
400
401 224 auto n_step = ukernel_variant.interface.get_n_step();
402
3/14
✓ Branch 0 taken 224 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 224 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 224 times.
224 ASSERT_TRUE(n_step % nr == 0);
403
404 448 auto rect = portion.compute_portion(M, N, m_step, n_step);
405
3/4
✓ Branch 0 taken 224 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 216 times.
224 if (rect.height() == 0 || rect.width() == 0) {
406
9/18
✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
8 GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")";
407 }
408
409 // Generates input data.
410 216 const auto ref_lhs_bf16 = fill_random<BFloat16<false>>(M * K, seed + 0);
411
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 const auto ref_rhs = fill_random<float>(N * K, seed + 1);
412
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 const auto ref_biases = fill_random<float>(N, seed + 2);
413
414 // For reference implementation, Casting BF16 input to FP32 type and FP32 output back to BF16 because the matmul
415 // implementation works with FP32 accumulation and casts the result to BF16
416 216 const auto ref_lhs =
417
3/6
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
216 cast<float, BFloat16<false>>(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits<BFloat16<false>>);
418
419 // Runs the reference implementation.
420 // * Quantizes the LHS matrix using 8-bit symmetric quantization.
421 // * Quantizes the RHS matrix using 8-bit asymmetric quantization.
422 // * Performs GEMM.
423 648 auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
424
2/4
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
216 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K);
425 1080 auto [ref_rhs_values_qsi4, ref_rhs_scales] =
426
3/6
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
648 quantize_rhs_qsi4c32p<float, BFloat16<false>>(N, K, bl, ref_rhs, rhs_pack_type == RhsPackType::NxK);
427
428 216 Buffer ref_dst_noclamp;
429
2/2
✓ Branch 0 taken 108 times.
✓ Branch 1 taken 108 times.
216 if (rhs_pack_type == RhsPackType::NxK) {
430 108 ref_dst_noclamp =
431
1/2
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
216 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
432
4/8
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
216 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K,
433
3/6
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
108 ref_rhs_values_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr,
434 1);
435 108 } else {
436
1/2
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
216 ref_dst_noclamp = matmul_nt_nt_quantized<
437 int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
438
4/8
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
216 M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K,
439
3/6
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
108 ref_rhs_values_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, bl, ref_biases.data(), nullptr, nullptr, 1);
440 }
441
442 // Clamps the reference output.
443 216 const auto clamp_ratio = 0.8F;
444
2/4
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
864 const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_noclamp.data(), M * N, clamp_ratio);
445
4/8
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
216 auto ref_dst_float = clamp<float>(ref_dst_noclamp.data(), M * N, clamp_min, clamp_max);
446
447 // Cast the reference output to BF16
448
3/6
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
216 auto ref_dst = cast<BFloat16<false>, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>);
449
450 // Runs the LHS packing micro-kernel.
451
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 const auto lhs_start_row = rect.start_row();
452
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr);
453
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 Buffer imp_packed_lhs(imp_packed_lhs_size);
454
455 216 const auto lhs_stride = K * sizeof(uint16_t);
456
457
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride);
458
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr);
459
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
460
4/16
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 216 times.
✗ Branch 15 not taken.
216 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
461
462
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 kai_run_lhs_quant_pack_qai8dxp_bf16_neon(
463
2/4
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
216 rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride,
464
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 reinterpret_cast<uint8_t*>(imp_packed_lhs.data()) + lhs_packed_offset);
465
466
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 const auto rhs_start_row = rect.start_col();
467 216 size_t bias_offset = rhs_start_row * sizeof(float);
468
469
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
540 auto [imp_packed_rhs, rhs_packed_offset] = pack_rhs_qsi4c32pscalebf16(
470 864 N, K, bl, nr, kr, sr, ref_rhs_values_qsi4, ref_biases, bias_offset, ref_rhs_scales, rhs_pack_type,
471
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 rhs_start_row, rect.width());
472
473
2/4
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
432 auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
474
5/18
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 216 times.
432 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
475
476 216 const auto dst_stride_row = N * sizeof(uint16_t);
477 216 const auto dst_stride_col = sizeof(uint16_t);
478 432 const auto dst_offset =
479
3/6
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
216 ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
480
2/4
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
216 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
481
4/16
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 216 times.
216 ASSERT_EQ(dst_offset, ref_dst_offset);
482
483 // Runs the GEMM micro-kernel.
484
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
485
5/18
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 216 times.
216 ASSERT_EQ(imp_dst_size, ref_dst.size());
486
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 Buffer imp_dst(imp_dst_size);
487
488
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
432 ukernel_variant.interface.run_matmul(
489
4/8
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
216 rect.height(), rect.width(), K, bl, reinterpret_cast<const uint8_t*>(imp_packed_lhs.data()) + lhs_matmul_offset,
490
2/4
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
216 reinterpret_cast<const uint8_t*>(imp_packed_rhs.data()) + rhs_matmul_offset, imp_dst.data() + dst_offset,
491 648 dst_stride_row, dst_stride_col, clamp_min, clamp_max);
492
493 // Compares the output of the micro-kernels against the output of the reference implementation for the portion
494 // tested.
495
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
496 216 DataFormat dst_format = DataFormat(DataType::BF16);
497 432 const auto success =
498
3/6
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
216 compare(reinterpret_cast<const uint8_t*>(imp_dst.data()), ref_dst.data(), dst_format, M, N, rect, handler);
499
4/16
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 216 times.
216 ASSERT_TRUE(success);
500
501 // Test vectorized packing micro-kernels, if packing parameters allow
502
3/6
✓ Branch 0 taken 108 times.
✓ Branch 1 taken 108 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 108 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
216 if (rhs_pack_type == RhsPackType::NxK && (kr / sr == 8 || kr / sr == 4)) {
503
1/2
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
216 const auto [imp_packed_rhs_neon, rhs_packed_offset_neon] = pack_rhs_qsi4c32pscalebf16_neon(
504 432 N, K, bl, nr, kr, sr, ref_rhs_values_qsi4, ref_biases, bias_offset, ref_rhs_scales, rhs_pack_type,
505
1/2
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
108 rhs_start_row, rect.width());
506
5/18
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 108 times.
216 ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset);
507
508
1/2
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
216 ukernel_variant.interface.run_matmul(
509
4/8
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 108 times.
✗ Branch 7 not taken.
108 rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_matmul_offset,
510
2/4
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
108 imp_packed_rhs_neon.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
511 324 dst_stride_row, dst_stride_col, clamp_min, clamp_max);
512
513
3/6
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
108 const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler);
514
4/16
✓ Branch 0 taken 108 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 108 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 108 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 108 times.
108 ASSERT_TRUE(success);
515 108 }
516 224 }
517
518
21/46
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1232 times.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1232 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1232 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 616 times.
✓ Branch 29 taken 616 times.
✓ Branch 30 taken 1232 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1232 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1232 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 1232 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 1232 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1232 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 1232 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 1232 times.
✗ Branch 45 not taken.
2466 INSTANTIATE_TEST_SUITE_P(
519 MatMul, MatMulTest_qmatmul_clamp_f32_qai8dxp_qsi4c32p,
520 testing::Combine(
521 testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.size()), test_matmul_shapes,
522 test_block_lengths, test_portions, testing::Values(RhsPackType::NxK, RhsPackType::KxN)),
523 [](const auto& info) {
524 const auto variant_idx = std::get<0>(info.param);
525 const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_idx).name};
526 const auto shape = std::get<MatMulShape>(info.param);
527 const auto bl = std::get<2>(info.param);
528 const auto portion = std::get<3>(info.param);
529 const RhsPackType rhs_pack_type = std::get<4>(info.param);
530
531 std::ostringstream sstream;
532 sstream << name << ((rhs_pack_type == RhsPackType::NxK) ? "__NxK" : "__KxN") << "__";
533 PrintTo(shape, &sstream);
534 sstream << "__BL_" << bl << "__";
535 PrintTo(portion, &sstream);
536
537 return sstream.str();
538 });
539
540
21/46
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 224 times.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 224 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 224 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 112 times.
✓ Branch 29 taken 112 times.
✓ Branch 30 taken 224 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 224 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 224 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 224 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 224 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 224 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 224 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 224 times.
✗ Branch 45 not taken.
450 INSTANTIATE_TEST_SUITE_P(
541 MatMul, MatMulTest_qmatmul_clamp_bf16_qai8dxp_qsi4c32p,
542 testing::Combine(
543 testing::Range<size_t>(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.size()), test_matmul_shapes,
544 test_block_lengths, test_portions, testing::Values(RhsPackType::NxK, RhsPackType::KxN)),
545 [](const auto& info) {
546 const auto variant_idx = std::get<0>(info.param);
547 const std::string name{variants_kai_matmul_clamp_bf16_qai8dxp_qsi4c32p.at(variant_idx).name};
548 const auto shape = std::get<MatMulShape>(info.param);
549 const auto bl = std::get<2>(info.param);
550 const auto portion = std::get<3>(info.param);
551 const RhsPackType rhs_pack_type = std::get<4>(info.param);
552
553 std::ostringstream sstream;
554 sstream << name << ((rhs_pack_type == RhsPackType::NxK) ? "__NxK" : "__KxN") << "__";
555 PrintTo(shape, &sstream);
556 sstream << "__BL_" << bl << "__";
557 PrintTo(portion, &sstream);
558
559 return sstream.str();
560 });
561
562 } // namespace kai::test
563