KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 97.0% 657 / 4 / 681
Functions: 99.1% 105 / 0 / 106
Branches: 35.7% 884 / 10 / 2486

test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <ostream>
12 #include <sstream>
13 #include <string>
14 #include <tuple>
15 #include <utility>
16 #include <vector>
17
18 #include "kai/kai_common.h"
19 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h"
20 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
21 #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h"
22 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
23 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot.h"
24 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
25 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod.h"
26 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
27 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod.h"
28 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h"
29 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
30 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod.h"
31 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"
32 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"
33 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm.h"
34 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h"
35 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"
36 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h"
37 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
38 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h"
39 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon.h"
40 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"
41 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h"
42 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h"
43 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon.h"
44 #include "test/common/abi_checker.hpp"
45 #include "test/common/bfloat16.hpp"
46 #include "test/common/buffer.hpp"
47 #include "test/common/cache.hpp"
48 #include "test/common/compare.hpp"
49 #include "test/common/cpu_info.hpp"
50 #include "test/common/data_type.hpp"
51 #include "test/common/int4.hpp"
52 #include "test/common/matmul_test_common.hpp"
53 #include "test/common/matrix_portion.hpp"
54 #include "test/common/memory.hpp"
55 #include "test/common/round.hpp"
56 #include "test/common/seed.hpp"
57 #include "test/common/test_suite.hpp"
58 #include "test/reference/cast.hpp"
59 #include "test/reference/clamp.hpp"
60 #include "test/reference/fill.hpp"
61 #include "test/reference/matmul.hpp"
62 #include "test/reference/pad.hpp"
63 #include "test/reference/quantize.hpp"
64 #include "test/reference/transpose.hpp"
65
66 namespace kai::test {
67
68 namespace {
69
70 // LHS QAI8DXP
71 using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32);
72 using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32);
73 using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32);
74 using kai_run_lhs_pack_func_t = decltype(&kai_run_lhs_quant_pack_qai8dxp_f32);
75
76 // LHS QAI8DXP pack interface
77 struct kai_qai8dxp_pack_functions {
78 kai_get_lhs_packed_size_func_t packed_size;
79 kai_get_lhs_packed_offset_func_t get_packed_offset;
80 kai_get_lhs_offset_func_t get_offset;
81 kai_run_lhs_pack_func_t run_pack;
82 };
83
84 // RHS QSI4C32P (nxk, BF16 block scales; sums float, bias float)
85 using kai_get_rhs_packed_size_func_t = decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0);
86 using kai_get_rhs_packed_offset_func_t = decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0);
87 using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0);
88 using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0);
89
90 // RHS QSI4C32P pack interface
91 struct kai_qsi4c32p_pack_functions {
92 kai_get_rhs_packed_size_func_t packed_size;
93 kai_get_rhs_packed_offset_func_t get_packed_offset;
94 kai_get_rhs_offset_func_t get_offset;
95 kai_run_rhs_pack_func_t run_pack;
96 };
97
98 80972 const auto& get_f32_gemm_variants() noexcept {
99 using Variant = UkernelMatmulPackVariant<
100 kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>;
101
102
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 80969 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
80973 static const std::array<Variant, 12> variants = {{
103
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
104 clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32,
105 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0,
106 /*rhs_s0s1_input=*/false),
107
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
108 clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32,
109 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false),
110
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
111 clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32,
112 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false),
113
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
114 clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32,
115 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false),
116
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
117 clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32,
118 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false),
119
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
120 clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32,
121 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false),
122
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
123 clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32,
124 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false),
125
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
126 clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qai8dxp_f32,
127 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false),
128
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
129 clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qai8dxp_f32,
130 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false),
131
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
132 clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qai8dxp_f32,
133 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false),
134
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
135 clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qai8dxp_f32,
136 rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false),
137 // SME2 MOPA
138
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
139 clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, lhs_quant_pack_qai8dxp_f32,
140 rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon, false),
141 }};
142
143 80972 return variants;
144 }
145
146 1322 const auto& get_f32_gemv_variants() noexcept {
147 using Variant = UkernelMatmulPackVariant<
148 kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>;
149
150
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1319 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
1323 static const std::array<Variant, 1> variants = {{
151
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
3 UKERNEL_MATMUL_PACK_VARIANT(
152 clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qai8dxp_f32,
153 rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon, false),
154 }};
155
156 1322 return variants;
157 }
158
159 771 const auto& get_bf16_gemm_variants() noexcept {
160 using Variant = UkernelVariant<kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_ukernel>;
161
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 768 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
774 static const std::array<Variant, 2> variants = {
162
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
7 Variant{
163 3 UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod),
164
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod", cpu_has_dotprod_and_bf16},
165
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
5 Variant{
166 3 UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm),
167
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm", cpu_has_i8mm_and_bf16},
168 };
169 771 return variants;
170 }
171
172 // NEON/i8mm only (exclude SME2)
173 6 const auto& get_f32_neon_gemm_variants_only() {
174
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
6 static std::vector<UkernelMatmulPackVariant<
175 kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>>
176 3 filtered;
177
2/2
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
6 if (filtered.empty()) {
178 3 const auto& all = get_f32_gemm_variants();
179
2/2
✓ Branch 0 taken 36 times.
✓ Branch 1 taken 3 times.
39 for (const auto& v : all) {
180 36 const char* n = v.ukernel.name.data();
181
3/4
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✓ Branch 3 taken 33 times.
36 if (n == nullptr || std::strstr(n, "sme2") == nullptr) {
182 33 filtered.push_back(v);
183 33 }
184 36 }
185 3 }
186 6 return filtered;
187 }
188
189 enum class RhsPackType : std::uint8_t { NxK = 0, KxN = 1 };
190
191 3430 std::tuple<Buffer, size_t> pack_lhs_qai8dxp(
192 const kai_qai8dxp_pack_functions& pack_interface, const size_t M, const size_t K, const size_t mr, const size_t kr,
193 const size_t sr, const Buffer& lhs_values_f32, const size_t lhs_stride_bytes, const size_t rect_start_row,
194 const size_t rect_height) {
195 3430 const auto lhs_packed_size = pack_interface.packed_size(M, K, mr, kr, sr);
196 3430 Buffer lhs_packed(lhs_packed_size, 0);
197
198
1/2
✓ Branch 0 taken 3430 times.
✗ Branch 1 not taken.
3430 const auto lhs_offset = pack_interface.get_offset(rect_start_row, lhs_stride_bytes);
199
1/2
✓ Branch 0 taken 3430 times.
✗ Branch 1 not taken.
3430 const auto lhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, mr, kr, sr);
200
201
1/2
✓ Branch 0 taken 3430 times.
✗ Branch 1 not taken.
3430 abi_check(
202 3430 pack_interface.run_pack, rect_height, K, mr, kr, sr, 0,
203
1/2
✓ Branch 0 taken 3430 times.
✗ Branch 1 not taken.
3430 reinterpret_cast<const float*>(lhs_values_f32.data() + lhs_offset), lhs_stride_bytes,
204
1/2
✓ Branch 0 taken 3430 times.
✗ Branch 1 not taken.
3430 lhs_packed.data() + lhs_packed_offset);
205
206 3430 return {std::move(lhs_packed), lhs_packed_offset};
207 3430 }
208
209 // Executes the scalar RHS packing micro-kernel.
210 1250 std::tuple<Buffer, size_t> pack_rhs_qsi4c32pscalebf16(
211 // clang-format off
212 const size_t N,
213 const size_t K,
214 const size_t nr,
215 const size_t kr,
216 const size_t sr,
217 const size_t bl,
218 const Buffer& rhs_values_qsi4,
219 const Buffer& biases,
220 const size_t bias_offset,
221 const Buffer& rhs_scales,
222 const RhsPackType pack_type,
223 const size_t rect_start_row,
224 const size_t rect_width,
225 const bool use_ps1s0) {
226 // clang-format on
227
2/2
✓ Branch 0 taken 1154 times.
✓ Branch 1 taken 96 times.
1250 const size_t width = pack_type == RhsPackType::KxN ? N : K;
228
2/2
✓ Branch 0 taken 1154 times.
✓ Branch 1 taken 96 times.
1250 const size_t height = pack_type == RhsPackType::KxN ? K : N;
229 1250 constexpr kai_datatype scale_dt = kai_dt_bf16;
230
231 1250 const size_t rhs_stride = round_up_multiple(width, 2);
232 1250 const size_t rhs_stride_bytes = round_up_division(width, 2);
233 1250 const size_t scales_stride_bytes = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt);
234
235 KAI_ASSUME_ALWAYS(rhs_values_qsi4.size() == round_up_division(height * rhs_stride, 2));
236
237 1250 const auto rhs_values_qsu4 = cast_qsu4_qsi4(rhs_values_qsi4.data(), rhs_values_qsi4.size() * 2);
238
1/2
✓ Branch 0 taken 1250 times.
✗ Branch 1 not taken.
1250 const size_t dst_bytes_total = round_up_division(height * rhs_stride, 2);
239 1250 const size_t dst_bytes_total_safe = dst_bytes_total + rhs_stride_bytes + 8;
240 1250 const auto rhs_qsu4 =
241
1/2
✓ Branch 0 taken 1250 times.
✗ Branch 1 not taken.
1250 pad_row<UInt4>(rhs_values_qsu4.data(), height, width, width, rhs_stride_bytes * 2, dst_bytes_total_safe);
242
243 1250 const size_t scale_offset = rect_start_row * scales_stride_bytes;
244 1250 size_t rhs_offset = 0;
245 1250 size_t rhs_packed_offset = 0;
246 1250 size_t imp_packed_rhs_size = 0;
247
248
2/2
✓ Branch 0 taken 1154 times.
✓ Branch 1 taken 96 times.
1250 if (pack_type == RhsPackType::KxN) {
249
2/2
✓ Branch 0 taken 46 times.
✓ Branch 1 taken 1108 times.
1154 if (use_ps1s0) {
250 46 rhs_offset =
251
1/2
✓ Branch 0 taken 46 times.
✗ Branch 1 not taken.
46 kai_get_rhs_offset_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(rect_start_row, rhs_stride_bytes);
252
1/2
✓ Branch 0 taken 46 times.
✗ Branch 1 not taken.
46 rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
253 46 rect_start_row, K, nr, kr, sr, bl, scale_dt);
254 46 imp_packed_rhs_size =
255
1/2
✓ Branch 0 taken 46 times.
✗ Branch 1 not taken.
46 kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt);
256 46 } else {
257
1/2
✓ Branch 0 taken 1108 times.
✗ Branch 1 not taken.
1108 rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(rect_start_row, rhs_stride_bytes);
258
1/2
✓ Branch 0 taken 1108 times.
✗ Branch 1 not taken.
1108 rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
259 1108 rect_start_row, K, nr, kr, sr, bl, scale_dt);
260 1108 imp_packed_rhs_size =
261
1/2
✓ Branch 0 taken 1108 times.
✗ Branch 1 not taken.
1108 kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt);
262 }
263 1154 } else {
264
1/2
✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
96 rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rect_start_row, rhs_stride_bytes);
265 96 rhs_packed_offset =
266
1/2
✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
96 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rect_start_row, K, nr, kr, sr, bl, scale_dt);
267
1/2
✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
96 imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt);
268 }
269
270
1/2
✓ Branch 0 taken 1250 times.
✗ Branch 1 not taken.
1250 Buffer imp_packed_rhs(imp_packed_rhs_size);
271
2/2
✓ Branch 0 taken 1154 times.
✓ Branch 1 taken 96 times.
1250 if (pack_type == RhsPackType::KxN) {
272 1154 kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params{};
273 1154 params.lhs_zero_point = 1;
274 1154 params.rhs_zero_point = 8;
275 1154 params.scale_dt = scale_dt;
276
277
2/2
✓ Branch 0 taken 46 times.
✓ Branch 1 taken 1108 times.
1154 if (use_ps1s0) {
278 // clang-format off
279
1/2
✓ Branch 0 taken 46 times.
✗ Branch 1 not taken.
46 abi_check(
280 kai_run_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon,
281 46 1, // num_groups
282 rect_width, // n
283 K, // k
284 nr, kr, sr, bl, // packing args
285 46 reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset),
286 rhs_stride_bytes,
287 46 reinterpret_cast<const float*>(biases.data() + bias_offset),
288 46 reinterpret_cast<const void*>(rhs_scales.data() + scale_offset),
289 scales_stride_bytes,
290 46 static_cast<void*>(imp_packed_rhs.data() + rhs_packed_offset),
291 46 0,
292 46 &params);
293 // clang-format on
294 46 } else {
295 1108 kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params_kxn{};
296 1108 params_kxn.lhs_zero_point = 1;
297 1108 params_kxn.rhs_zero_point = 8;
298 1108 params_kxn.scale_dt = scale_dt;
299
300 // clang-format off
301
1/2
✓ Branch 0 taken 1108 times.
✗ Branch 1 not taken.
1108 abi_check(
302 kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0,
303 1108 1,
304 rect_width,
305 K,
306 nr, kr, sr, bl,
307 1108 reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset),
308 rhs_stride_bytes,
309 1108 reinterpret_cast<const float*>(biases.data() + bias_offset),
310 1108 reinterpret_cast<const void*>(rhs_scales.data() + scale_offset),
311 scales_stride_bytes,
312 1108 static_cast<void*>(imp_packed_rhs.data() + rhs_packed_offset),
313 1108 0,
314 1108 &params_kxn);
315 // clang-format on
316 1108 }
317 1154 } else {
318 96 kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{};
319 96 params.lhs_zero_point = 1;
320 96 params.rhs_zero_point = 8;
321 96 params.scale_dt = scale_dt;
322
323
1/2
✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
96 abi_check(
324 // clang-format off
325 kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0,
326 96 1,
327 rect_width,
328 K,
329 nr, kr, sr, bl,
330 96 reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset),
331 rhs_stride_bytes,
332 96 reinterpret_cast<const float*>(biases.data() + bias_offset),
333 96 reinterpret_cast<const void*>(rhs_scales.data() + scale_offset),
334 scales_stride_bytes,
335 96 static_cast<void*>(imp_packed_rhs.data() + rhs_packed_offset),
336 96 0,
337 96 &params);
338 // clang-format on
339 96 }
340
341 1250 return {std::move(imp_packed_rhs), rhs_packed_offset};
342 1250 }
343
344 /// Executes RHS NxK packing helper
345 1078 std::tuple<Buffer, size_t> pack_rhs_qsi4c32p_nxk(
346 const kai_qsi4c32p_pack_functions& pack_iface, const size_t N, const size_t K, const size_t nr, const size_t kr,
347 const size_t sr, const size_t bl, const Buffer& rhs_values_qsi4, const float* bias, const Buffer& rhs_scales,
348 const size_t rect_start_row, const size_t rect_width, const bool rhs_s0s1_input) {
349 // Convert signed int4 -> unsigned int4, preserving any row padding in the source buffer.
350 1078 const auto rhs_qsu4s1s0 = cast_qsu4_qsi4(rhs_values_qsi4.data(), rhs_values_qsi4.size() * 2);
351
352
1/2
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
1078 const auto rhs_packed_size = pack_iface.packed_size(N, K, nr, kr, sr, bl, kai_dt_bf16);
353
1/2
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
1078 Buffer rhs_packed(rhs_packed_size);
354
1/2
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
1078 const auto rhs_packed_offset = pack_iface.get_packed_offset(rect_start_row, K, nr, kr, sr, bl, kai_dt_bf16);
355
356
1/2
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
1078 const size_t rhs_stride_bytes = round_up_division(K, 2); // bytes per row
357
2/4
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1078 times.
✗ Branch 3 not taken.
1078 const size_t scales_stride_bytes = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(kai_dt_bf16);
358 1078 const size_t scale_offset = rect_start_row * scales_stride_bytes;
359
1/2
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
1078 const size_t rhs_offset = pack_iface.get_offset(rect_start_row, rhs_stride_bytes);
360
361 1078 kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{};
362 1078 params.lhs_zero_point = 1;
363 1078 params.rhs_zero_point = 8;
364 1078 params.scale_dt = kai_dt_bf16;
365
366 // Apply optional s0s1 -> s1s0 nibble swap.
367 1078 const Buffer* rhs_qsu4_ptr = &rhs_qsu4s1s0;
368 1078 Buffer rhs_qsu4_converted;
369
1/2
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
1078 if (rhs_s0s1_input) {
370 rhs_qsu4_converted = convert_s0s1_s1s0(rhs_qsu4s1s0);
371 rhs_qsu4_ptr = &rhs_qsu4_converted;
372 }
373
374
1/2
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
1078 abi_check(
375 1078 pack_iface.run_pack, 1, rect_width, K, nr, kr, sr, bl,
376 1078 reinterpret_cast<const uint8_t*>(rhs_qsu4_ptr->data() + rhs_offset), rhs_stride_bytes, bias,
377 1078 rhs_scales.data() + scale_offset, scales_stride_bytes, rhs_packed.data() + rhs_packed_offset, 0, &params);
378
379 1078 return {std::move(rhs_packed), rhs_packed_offset};
380 1078 }
381
382 // Executes F32-only RHS KxN packing helper (wrapper around BF16-scaled helper for clarity)
383 1058 std::tuple<Buffer, size_t> pack_rhs_qsi4c32p_kxn(
384 const size_t N, const size_t K, const size_t nr, const size_t kr, const size_t sr, const size_t bl,
385 const Buffer& rhs_values_qsi4, const Buffer& biases, const size_t bias_offset, const Buffer& rhs_scales,
386 const size_t rect_start_row, const size_t rect_width, const bool use_ps1s0) {
387 1058 return pack_rhs_qsi4c32pscalebf16(
388 1058 N, K, nr, kr, sr, bl, rhs_values_qsi4, biases, bias_offset, rhs_scales, RhsPackType::KxN, rect_start_row,
389 1058 rect_width, use_ps1s0);
390 }
391
392 /// Executes the vectorized RHS packing micro-kernels for block length of 4 bytes or 8 bytes
393 448 std::tuple<Buffer, size_t> pack_rhs_qsi4c32pscalebf16_neon(
394 const size_t N, const size_t K, const size_t nr, const size_t kr, const size_t sr, const size_t bl,
395 const Buffer& rhs_values_qsi4, const Buffer& biases, const size_t bias_offset, const Buffer& rhs_scales,
396 const RhsPackType pack_type, const size_t rect_start_row, const size_t rect_width) {
397 KAI_ASSUME_ALWAYS(kr / sr == 8 || kr / sr == 4);
398
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 448 times.
448 const size_t width = pack_type == RhsPackType::KxN ? N : K;
399
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 448 times.
448 const size_t height = pack_type == RhsPackType::KxN ? K : N;
400 448 constexpr kai_datatype scale_dt = kai_dt_bf16;
401
402 448 const size_t rhs_stride = round_up_multiple(width, 2);
403 448 const size_t rhs_stride_bytes = round_up_division(width, 2);
404 448 const size_t scales_stride_bytes = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt);
405
406 KAI_ASSUME_ALWAYS(rhs_values_qsi4.size() == round_up_division(height * rhs_stride, 2));
407
408 448 const auto rhs_values_qsu4 = cast_qsu4_qsi4(rhs_values_qsi4.data(), rhs_values_qsi4.size() * 2);
409
1/2
✓ Branch 0 taken 448 times.
✗ Branch 1 not taken.
448 const size_t dst_bytes_total = round_up_division(height * rhs_stride, 2);
410 448 const size_t dst_bytes_total_safe = dst_bytes_total + rhs_stride_bytes + 8;
411 448 const auto rhs_qsu4 =
412
1/2
✓ Branch 0 taken 448 times.
✗ Branch 1 not taken.
448 pad_row<UInt4>(rhs_values_qsu4.data(), height, width, width, rhs_stride_bytes * 2, dst_bytes_total_safe);
413
414 448 const size_t scale_offset = rect_start_row * scales_stride_bytes;
415
416 448 size_t imp_packed_rhs_size_neon = 0;
417 448 size_t rhs_packed_offset_neon = 0;
418 448 size_t rhs_offset_neon = 0;
419
420
2/2
✓ Branch 0 taken 128 times.
✓ Branch 1 taken 320 times.
448 if (kr / sr == 8) {
421 320 imp_packed_rhs_size_neon =
422
1/2
✓ Branch 0 taken 320 times.
✗ Branch 1 not taken.
320 kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt);
423
1/2
✓ Branch 0 taken 320 times.
✗ Branch 1 not taken.
320 rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
424 320 rect_start_row, K, nr, kr, sr, bl, scale_dt);
425 320 rhs_offset_neon =
426
1/2
✓ Branch 0 taken 320 times.
✗ Branch 1 not taken.
320 kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(rect_start_row, rhs_stride_bytes);
427 320 } else {
428 128 imp_packed_rhs_size_neon =
429
1/2
✓ Branch 0 taken 128 times.
✗ Branch 1 not taken.
128 kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt);
430
1/2
✓ Branch 0 taken 128 times.
✗ Branch 1 not taken.
128 rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
431 128 rect_start_row, K, nr, kr, sr, bl, scale_dt);
432 128 rhs_offset_neon =
433
1/2
✓ Branch 0 taken 128 times.
✗ Branch 1 not taken.
128 kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(rect_start_row, rhs_stride_bytes);
434 }
435
436 448 kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{};
437 448 params.lhs_zero_point = 1;
438 448 params.rhs_zero_point = 8;
439 448 params.scale_dt = scale_dt;
440
441
1/2
✓ Branch 0 taken 448 times.
✗ Branch 1 not taken.
448 Buffer imp_packed_rhs_neon(imp_packed_rhs_size_neon);
442
2/2
✓ Branch 0 taken 128 times.
✓ Branch 1 taken 320 times.
448 if (kr / sr == 8) {
443
1/2
✓ Branch 0 taken 320 times.
✗ Branch 1 not taken.
320 kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
444 320 1, rect_width /* n */, K, nr, kr, sr, bl,
445 320 reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset_neon), rhs_stride_bytes,
446 320 reinterpret_cast<const float*>(biases.data() + bias_offset), rhs_scales.data() + scale_offset,
447 320 scales_stride_bytes, imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, &params);
448 320 } else {
449
1/2
✓ Branch 0 taken 128 times.
✗ Branch 1 not taken.
128 kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
450 128 1, rect_width /* n */, K, nr, kr, sr, bl,
451 128 reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset_neon), rhs_stride_bytes,
452 128 reinterpret_cast<const float*>(biases.data() + bias_offset), rhs_scales.data() + scale_offset,
453 128 scales_stride_bytes, imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, &params);
454 }
455 448 return {std::move(imp_packed_rhs_neon), rhs_packed_offset_neon};
456 448 }
457
458 48840 std::string test_description(
459 const std::string& name, const RhsPackType rhs_pack_type, const MatMulShape& shape, const size_t bl,
460 const MatrixPortion& portion, const float clamp_keep_ratio) {
461 // Remove redundant prefix to make output easier to read
462 48840 std::string clean_name = name;
463
1/2
✓ Branch 0 taken 48840 times.
✗ Branch 1 not taken.
48840 const std::string prefix = "kai_matmul_clamp_";
464
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 48840 times.
48840 if (clean_name.rfind(prefix, 0) == 0) { // starts with prefix
465
1/2
✓ Branch 0 taken 48840 times.
✗ Branch 1 not taken.
48840 clean_name.erase(0, prefix.length());
466 48840 }
467
468
1/2
✓ Branch 0 taken 48840 times.
✗ Branch 1 not taken.
48840 std::ostringstream sstream;
469
5/10
✓ Branch 0 taken 48840 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 48840 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 48840 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 48840 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 48840 times.
✗ Branch 9 not taken.
97680 sstream << test_description(clean_name, shape, portion, /*bias=*/false, clamp_keep_ratio) << "__BL_" << bl << "__"
470
3/4
✓ Branch 0 taken 23472 times.
✓ Branch 1 taken 25368 times.
✓ Branch 2 taken 48840 times.
✗ Branch 3 not taken.
48840 << ((rhs_pack_type == RhsPackType::NxK) ? "NxK" : "KxN");
471
472
1/2
✓ Branch 0 taken 48840 times.
✗ Branch 1 not taken.
48840 return sstream.str();
473 48840 }
474
475 1884 struct TestData {
476 942 size_t M{}, N{}, K{}, bl{};
477
478 1884 Rect rect{0, 0, 0, 0};
479
480 Buffer lhs;
481 Buffer rhs;
482 Buffer bias;
483
484 Buffer rhs_quant;
485 Buffer rhs_scales;
486
487 Buffer lhs_packed;
488 942 size_t lhs_packed_offset{};
489
490 Buffer ref_dst_clamped;
491 Range<float> clamp;
492 };
493
494 using BF16QMatMulRefKey = std::tuple<
495 MatMulShape, // shape
496 size_t, // bl
497 size_t, // mr
498 size_t, // nr
499 size_t, // kr
500 size_t, // sr
501 size_t, size_t, size_t, size_t, // rect.start_row, rect.start_col, rect.height, rect.width
502 RhsPackType, // rhs_pack_type
503 float // clamp_keep_ratio
504 >;
505
506 384 struct BF16TestData {
507 192 size_t M{}, N{}, K{}, bl{};
508 384 Rect rect{0, 0, 0, 0};
509
510 Buffer lhs_bf16; // Original BF16 LHS (kept for completeness)
511 Buffer bias; // Biases (FP32)
512 Buffer rhs_quant; // QSI4 quantized RHS (possibly transposed to match pack type)
513 Buffer rhs_scales; // BF16 per-block scales
514
515 Buffer lhs_packed; // Packed LHS buffer (BF16 dynamic quant + pack)
516 192 size_t lhs_packed_offset{}; // Offset for rect.start_row
517
518 384 Range<float> clamp{}; // Clamp range used for matmul
519 Buffer ref_dst_bf16; // Reference DST in BF16 (clamped)
520 };
521
522 } // anonymous namespace
523
524 using QMatmulClampF32ParamT = std::tuple<size_t, bool, MatMulShape, size_t, MatrixPortion, RhsPackType, float>;
525
526 16088 class QMatMulClampF32Test : public ::testing::TestWithParam<QMatmulClampF32ParamT> {
527 struct TestParams {
528 const UkernelMatmulPackVariant<
529 kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>*
530 variant;
531 size_t variant_index;
532 MatMulShape matmul_shape;
533 size_t bl;
534 MatrixPortion portion;
535 RhsPackType rhs_pack_type;
536 Rect rect;
537 float clamp_keep_ratio;
538 bool is_sme2;
539
540 24132 TestParams() :
541 16088 variant(nullptr),
542 16088 variant_index(0),
543 16088 matmul_shape{0, 0, 0},
544 16088 bl(32),
545 16088 portion(0, 0, 1, 1),
546 16088 rhs_pack_type(RhsPackType::NxK),
547 16088 rect(0, 0, 0, 0),
548 16088 clamp_keep_ratio(0.8F),
549 24132 is_sme2(false) {
550 24132 }
551
552 TestParams(
553 const UkernelMatmulPackVariant<
554 kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>*
555 variant,
556 const size_t v_idx, const MatMulShape& shape, const size_t bl, const MatrixPortion p, const RhsPackType r,
557 const Rect& rect, const float clamp_keep_ratio) :
558 variant(variant),
559 variant_index(v_idx),
560 matmul_shape(shape),
561 bl(bl),
562 portion(p),
563 rhs_pack_type(r),
564 rect(rect),
565 clamp_keep_ratio(clamp_keep_ratio),
566 is_sme2(false) {
567 }
568 };
569
570 TestParams params;
571
572 protected:
573 static const TestData& test_data();
574 15304 void SetupCommonForParam() {
575 15304 TestWithParam::SetUp();
576
2/2
✓ Branch 0 taken 15164 times.
✓ Branch 1 taken 140 times.
15304 if (std::get<1>(GetParam())) { // is_gemm
577 15164 SetupCommon(get_f32_gemm_variants());
578 15164 } else {
579 140 SetupCommon(get_f32_gemv_variants());
580 }
581 15304 }
582
583 [[nodiscard]] const TestParams& GetParams() const {
584 return params;
585 }
586 30608 TestParams& GetParams() {
587 30608 return params;
588 }
589
590 16088 void SetUp() override {
591 // Gate CPU features before computing kernel interface params (which may touch unsupported instructions).
592 16088 const auto& param = GetParam();
593 16088 const size_t variant_index = std::get<0>(param);
594 16088 const bool is_gemm = std::get<1>(param);
595 32176 const auto& variant =
596
2/2
✓ Branch 0 taken 15808 times.
✓ Branch 1 taken 280 times.
16088 is_gemm ? get_f32_gemm_variants().at(variant_index) : get_f32_gemv_variants().at(variant_index);
597
598
3/4
✓ Branch 0 taken 16088 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 15304 times.
✓ Branch 3 taken 784 times.
16088 if (variant.ukernel.fn_is_supported && !variant.ukernel.fn_is_supported()) {
599
3/6
✓ Branch 0 taken 784 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 784 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 784 times.
✗ Branch 5 not taken.
784 GTEST_SKIP() << "Unsupported CPU feature";
600 return;
601 }
602
603 // Safe to compute aligned params/rect now.
604 15304 SetupCommonForParam();
605 15304 const auto& p = GetParams();
606
607 // GEMV vs GEMM constraints (after params are set)
608
2/2
✓ Branch 0 taken 15164 times.
✓ Branch 1 taken 140 times.
15304 if (!is_gemm) {
609
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 140 times.
140 if (p.matmul_shape.m != 1) {
610 GTEST_SKIP() << "GEMV requires M=1";
611 return;
612 }
613
2/4
✓ Branch 0 taken 140 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 140 times.
140 if (p.rect.height() != 1 || p.rect.start_row() != 0) {
614 GTEST_SKIP() << "GEMV portion invalid, rect height != 1 or start_row != 0";
615 return;
616 }
617 140 }
618 16088 }
619
620 template <size_t ArrN>
621 15304 void SetupCommon(
622 const std::array<
623 UkernelMatmulPackVariant<
624 kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>,
625 ArrN>& variants) {
626 153040 const auto& [variant_index, is_gemm, shape, bl, portion, rhs_dir, clamp_keep_ratio] = GetParam();
627 30608 const auto& variant = variants.at(variant_index);
628
629 15304 params.variant = &variant;
630 15304 params.variant_index = variant_index;
631
632 // Compute aligned portion rect once
633 15304 const size_t m_step = variant.ukernel.interface.get_m_step();
634 15304 const size_t n_step = variant.ukernel.interface.get_n_step();
635 45912 const Rect rect = portion.compute_portion(shape.m, shape.n, m_step, n_step);
636
637 15304 params.matmul_shape = shape;
638 15304 params.bl = bl;
639 15304 params.portion = portion;
640 15304 params.rhs_pack_type = rhs_dir;
641 15304 params.rect = rect;
642 15304 params.clamp_keep_ratio = clamp_keep_ratio;
643 15304 params.is_sme2 =
644
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 15164 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 140 times.
15304 (variant.ukernel.name.data() != nullptr && std::strstr(variant.ukernel.name.data(), "sme2") != nullptr);
645 15304 }
646 };
647
648 using F32QMatMulRefKey = std::tuple<
649 MatMulShape, // shape
650 size_t, // bl
651 size_t, // mr
652 size_t, // kr
653 size_t, // sr
654 size_t, // rect_start_row
655 size_t, // rect_start_col
656 size_t, // rect_height
657 size_t, // rect_width
658 RhsPackType, // rhs_pack_type
659 int, // clamp_pct
660 const void* // lhs_pack_key
661 >;
662
663 template <>
664 942 TestData ReferenceGenerator<F32QMatMulRefKey, TestData>::generate_reference(const F32QMatMulRefKey& test_id) {
665 1884 TestData ref{};
666
667 8478 const auto& [shape, bl, mr, kr, sr, rect_start_row, rect_start_col, rect_height, rect_width, rhs_pack_type, clamp_pct, lhs_pack_key] =
668 942 test_id;
669 KAI_UNUSED(lhs_pack_key);
670 1884 const float clamp_keep_ratio = static_cast<float>(clamp_pct) / 100.0F;
671
5/10
✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 942 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 942 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 942 times.
✗ Branch 9 not taken.
4710 const Rect rect(rect_start_row, rect_start_col, rect_height, rect_width);
672
673 942 ref.M = shape.m;
674 942 ref.N = shape.n;
675 942 ref.K = shape.k;
676 942 ref.bl = bl;
677 942 ref.rect = rect;
678
679 // Creates a unique seed for the test data.
680
8/16
✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 942 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 942 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 942 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 942 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 942 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 942 times.
✗ Branch 15 not taken.
2826 const auto key = std::string("F32QMatMulRefKey:") + std::to_string(ref.M) + "x" + std::to_string(ref.N) + "x" +
681
10/18
✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 942 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 942 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 942 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 942 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 462 times.
✓ Branch 13 taken 480 times.
✓ Branch 14 taken 942 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 942 times.
✗ Branch 17 not taken.
1884 std::to_string(ref.K) + "_" + std::to_string(bl) + "_" + ((rhs_pack_type == RhsPackType::NxK) ? "NxK" : "KxN") +
682
2/4
✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
942 "_" + std::to_string(clamp_keep_ratio);
683
1/2
✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
942 auto& feed = seed_stream(key);
684
685
2/4
✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
942 ref.lhs = fill_random<float>(ref.M * ref.K, feed());
686
2/4
✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
942 ref.rhs = fill_random<float>(ref.N * ref.K, feed());
687
2/4
✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
942 ref.bias = fill_random<float>(ref.N, feed());
688
689 // Dynamic LHS quantization (reference only).
690
1/2
✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
942 QuantizationInfo lhs_qinfo{};
691 lhs_qinfo.quant_width = ref.K;
692 lhs_qinfo.dst_type = DataType::QAI8;
693 lhs_qinfo.scale_type = DataType::FP32;
694 lhs_qinfo.zero_point_type = DataType::I32;
695 auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref.lhs.data(), DataType::FP32, ref.M, ref.K, lhs_qinfo);
696
697 // Dynamic RHS quantization to QSI4 with BF16 block scales.
698 QuantizationInfo rhs_qinfo{};
699 rhs_qinfo.quant_width = bl;
700 rhs_qinfo.dst_type = DataType::QSI4;
701 rhs_qinfo.scale_type = DataType::BF16;
702 auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref.rhs.data(), DataType::FP32, ref.N, ref.K, rhs_qinfo);
703
704 ref.rhs_quant = std::move(ref_rhs_quant);
705 ref.rhs_scales = std::move(rhs_qoutputs.scales);
706
707 const bool transposed = (rhs_pack_type == RhsPackType::NxK);
708 const size_t width = transposed ? ref.K : ref.N;
709 const size_t height = transposed ? ref.N : ref.K;
710
711 const size_t qsi4_stride = round_up_multiple(width, 2);
712 const size_t qsi4_size_bytes = round_up_division(height * qsi4_stride, 2);
713
714 if (!transposed) {
715 ref.rhs_quant =
716 transpose_with_padding<Int4>(ref.rhs_quant.data(), ref.N, ref.K, ref.K, qsi4_stride, qsi4_size_bytes);
717 }
718
719 Buffer ref_dst_noclamp;
720 if (transposed) {
721 ref_dst_noclamp =
722 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
723 ref.M, ref.N, ref.K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(),
724 1, ref.K, ref.rhs_quant.data(), ref.rhs_scales.data(), nullptr, 1, bl, ref.bias.data(), nullptr,
725 nullptr, 1);
726 } else {
727 ref_dst_noclamp = matmul_nt_nt_quantized<
728 int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
729 ref.M, ref.N, ref.K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), 1,
730 ref.K, ref.rhs_quant.data(), ref.rhs_scales.data(), nullptr, 1, bl, ref.bias.data(), nullptr, nullptr, 1);
731 }
732
733 const auto [cmin, cmax] = find_clamp_range<float>(ref_dst_noclamp.data(), ref.M * ref.N, clamp_keep_ratio);
734 ref.clamp = {cmin, cmax};
735 ref.ref_dst_clamped = clamp<float>(ref_dst_noclamp.data(), ref.M * ref.N, cmin, cmax);
736
737 // Pack LHS once for this key.
738 const size_t lhs_stride_bytes = ref.K * sizeof(float);
739 constexpr kai_qai8dxp_pack_functions lhs_iface{
740 kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32,
741 kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32,
742 kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
743 kai_run_lhs_quant_pack_qai8dxp_f32,
744 };
745
746 auto [lhs_packed, lhs_packed_offset] = pack_lhs_qai8dxp(
747 lhs_iface, ref.M, ref.K, mr, kr, sr, ref.lhs, lhs_stride_bytes, rect.start_row(), rect.height());
748
749 ref.lhs_packed = std::move(lhs_packed);
750 ref.lhs_packed_offset = lhs_packed_offset;
751
752 return ref;
753 }
754
755 48264 [[maybe_unused]] static void PrintTo(const QMatmulClampF32ParamT& param, std::ostream* os) {
756 193056 const auto& [variant_idx, is_gemm, shape, bl, portion, rhs_pack_type, clamp_keep_ratio] = param;
757
1/2
✓ Branch 0 taken 32176 times.
✗ Branch 1 not taken.
96528 const auto name = std::string(
758
2/2
✓ Branch 0 taken 47424 times.
✓ Branch 1 taken 840 times.
48264 (is_gemm ? get_f32_gemm_variants().at(variant_idx).ukernel.name
759 1680 : get_f32_gemv_variants().at(variant_idx).ukernel.name));
760
5/10
✓ Branch 0 taken 48264 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 48264 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 48264 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 48264 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 48264 times.
✗ Branch 9 not taken.
193056 *os << test_description(name, rhs_pack_type, shape, bl, portion, clamp_keep_ratio);
761 48264 }
762
763 2488 const TestData& QMatMulClampF32Test::test_data() {
764 2488 const auto& param = GetParam();
765 2488 const size_t variant_index = std::get<0>(param);
766 2488 const bool is_gemm = std::get<1>(param);
767 2488 const MatMulShape& shape = std::get<2>(param);
768 2488 const size_t bl = std::get<3>(param);
769 2488 const MatrixPortion& portion = std::get<4>(param);
770 2488 const RhsPackType rhs_pack_type = std::get<5>(param);
771 2488 const float clamp_keep_ratio = std::get<6>(param);
772
773 4976 const auto& variant =
774
2/2
✓ Branch 0 taken 2468 times.
✓ Branch 1 taken 20 times.
2488 is_gemm ? get_f32_gemm_variants().at(variant_index) : get_f32_gemv_variants().at(variant_index);
775 2488 const auto& iface = variant.ukernel.interface;
776
777 2488 const size_t mr = iface.get_mr();
778 2488 const size_t kr = iface.get_kr();
779 2488 const size_t sr = iface.get_sr();
780 2488 const size_t m_step = iface.get_m_step();
781 2488 const size_t n_step = iface.get_n_step();
782 2488 const Rect rect = portion.compute_portion(shape.m, shape.n, m_step, n_step);
783
784 2488 const int clamp_pct = static_cast<int>(clamp_keep_ratio * 100 + 0.5F);
785
786 4976 const F32QMatMulRefKey key{
787 2488 shape,
788 bl,
789 mr,
790 kr,
791 sr,
792 2488 rect.start_row(),
793 2488 rect.start_col(),
794 2488 rect.height(),
795 2488 rect.width(),
796 rhs_pack_type,
797 clamp_pct,
798 2488 reinterpret_cast<const void*>(variant.lhs_pack_interface.run_pack)};
799
800 4976 return getV<F32QMatMulRefKey, TestData>(key);
801 2488 }
802
803 using MatMulTestParams_withBL_withRHSPackType =
804 std::tuple<size_t, MatMulShape, size_t, MatrixPortion, RhsPackType, float>;
805
806 576 [[maybe_unused]] static void PrintTo(const MatMulTestParams_withBL_withRHSPackType& param, std::ostream* os) {
807 576 const size_t variant_idx = std::get<0>(param);
808 576 const MatMulShape shape = std::get<1>(param);
809 576 const size_t bl = std::get<2>(param);
810 576 const MatrixPortion portion = std::get<3>(param);
811 576 const RhsPackType rhs_pack_type = std::get<4>(param);
812
1/2
✓ Branch 0 taken 384 times.
✗ Branch 1 not taken.
576 const std::string name{get_bf16_gemm_variants().at(variant_idx).name};
813 576 const float clamp_keep_ratio = std::get<5>(param);
814
2/4
✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
576 *os << test_description(name, rhs_pack_type, shape, bl, portion, clamp_keep_ratio);
815 576 }
816
817 template <>
818 192 BF16TestData ReferenceGenerator<BF16QMatMulRefKey, BF16TestData>::generate_reference(const BF16QMatMulRefKey& test_id) {
819 576 BF16TestData ref{};
820
821 192 const MatMulShape shape = std::get<0>(test_id);
822 192 const size_t bl = std::get<1>(test_id);
823 192 const size_t mr = std::get<2>(test_id);
824 192 const size_t nr = std::get<3>(test_id);
825 KAI_UNUSED(nr);
826 192 const size_t kr = std::get<4>(test_id);
827 192 const size_t sr = std::get<5>(test_id);
828 192 const size_t rect_start_row = std::get<6>(test_id);
829 192 const size_t rect_start_col = std::get<7>(test_id);
830 192 const size_t rect_height = std::get<8>(test_id);
831 192 const size_t rect_width = std::get<9>(test_id);
832 192 const RhsPackType rhs_pack_type = std::get<10>(test_id);
833 192 const float clamp_keep_ratio = std::get<11>(test_id);
834
835 192 ref.M = shape.m;
836 192 ref.N = shape.n;
837 192 ref.K = shape.k;
838 192 ref.bl = bl;
839
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 ref.rect = Rect(rect_start_row, rect_start_col, rect_height, rect_width);
840
841 // Creates a unique seed for the test data.
842
8/16
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 192 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 192 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 192 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 192 times.
✗ Branch 15 not taken.
576 const auto key = std::string("BF16QMatMulRefKey:") + std::to_string(ref.M) + "x" + std::to_string(ref.N) + "x" +
843
9/16
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 192 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 96 times.
✓ Branch 11 taken 96 times.
✓ Branch 12 taken 192 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 192 times.
✗ Branch 15 not taken.
384 std::to_string(ref.K) + "_" + std::to_string(bl) + "_" + ((rhs_pack_type == RhsPackType::NxK) ? "NxK" : "KxN") +
844
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
192 "_" + std::to_string(clamp_keep_ratio);
845
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 auto& feed = seed_stream(key);
846
847 // Inputs
848
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
192 ref.lhs_bf16 = fill_random<BFloat16<false>>(ref.M * ref.K, feed());
849
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
192 Buffer const ref_rhs = fill_random<float>(ref.N * ref.K, feed());
850
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
192 ref.bias = fill_random<float>(ref.N, feed());
851
852 // Cast BF16 LHS to FP32 for reference quantization
853 192 const Buffer ref_lhs =
854
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 cast<float, BFloat16<false>>(ref.lhs_bf16.data(), ref.lhs_bf16.size() * 8 / size_in_bits<BFloat16<false>>);
855
856 // Reference quantizations for LHS and RHS
857
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 QuantizationInfo lhs_qinfo{};
858 lhs_qinfo.quant_width = ref.K;
859 lhs_qinfo.dst_type = DataType::QAI8;
860 lhs_qinfo.scale_type = DataType::FP32;
861 lhs_qinfo.zero_point_type = DataType::I32;
862 auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, ref.M, ref.K, lhs_qinfo);
863
864 QuantizationInfo rhs_qinfo{};
865 rhs_qinfo.quant_width = bl;
866 rhs_qinfo.dst_type = DataType::QSI4;
867 rhs_qinfo.scale_type = DataType::BF16;
868 auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, ref.N, ref.K, rhs_qinfo);
869
870 // Prepare RHS layout per pack type
871 const bool transposed = (rhs_pack_type == RhsPackType::NxK);
872 const size_t width = transposed ? ref.K : ref.N;
873 const size_t height = transposed ? ref.N : ref.K;
874
875 const size_t qsi4_stride = round_up_multiple(width, 2);
876 const size_t qsi4_size_bytes = round_up_division(height * qsi4_stride, 2);
877
878 ref.rhs_quant = std::move(ref_rhs_quant);
879 if (!transposed) {
880 ref.rhs_quant =
881 transpose_with_padding<Int4>(ref.rhs_quant.data(), ref.N, ref.K, ref.K, qsi4_stride, qsi4_size_bytes);
882 }
883 ref.rhs_scales = std::move(rhs_qoutputs.scales);
884
885 // Compute reference destination (float), clamp, and cast to BF16
886 Buffer ref_dst_noclamp;
887 if (transposed) {
888 ref_dst_noclamp =
889 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
890 ref.M, ref.N, ref.K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(),
891 1, ref.K, ref.rhs_quant.data(), ref.rhs_scales.data(), nullptr, 1, bl, ref.bias.data(), nullptr,
892 nullptr, 1);
893 } else {
894 ref_dst_noclamp = matmul_nt_nt_quantized<
895 int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
896 ref.M, ref.N, ref.K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), 1,
897 ref.K, ref.rhs_quant.data(), ref.rhs_scales.data(), nullptr, 1, bl, ref.bias.data(), nullptr, nullptr, 1);
898 }
899
900 const auto [clamp_min, clamp_max] =
901 find_clamp_range<float>(ref_dst_noclamp.data(), ref.M * ref.N, clamp_keep_ratio);
902 ref.clamp = {clamp_min, clamp_max};
903 const Buffer ref_dst_float = clamp<float>(ref_dst_noclamp.data(), ref.M * ref.N, clamp_min, clamp_max);
904 ref.ref_dst_bf16 =
905 cast<BFloat16<false>, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>);
906
907 // Pack LHS once (BF16 packer)
908 const size_t imp_packed_lhs_size =
909 kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(ref.M, ref.K, mr, kr, sr);
910 ref.lhs_packed = Buffer(imp_packed_lhs_size);
911
912 const size_t lhs_stride = ref.K * sizeof(uint16_t);
913 const size_t lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(rect_start_row, lhs_stride);
914 ref.lhs_packed_offset =
915 kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(rect_start_row, ref.K, mr, kr, sr);
916
917 kai_run_lhs_quant_pack_qai8dxp_bf16_neon(
918 rect_height, ref.K, mr, kr, sr, 0, ref.lhs_bf16.data() + lhs_offset, lhs_stride,
919 reinterpret_cast<uint8_t*>(ref.lhs_packed.data()) + ref.lhs_packed_offset);
920
921 return ref;
922 }
923
924 /// Verifies RHS packed offsets (KxN vs NxK) match each other and the matmul interface at n_step.
925
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
5514 TEST_P(QMatMulClampF32Test, OffsetRHS) {
926 2136 const auto& p = GetParams();
927 2136 const auto fn_supported = p.variant->ukernel.fn_is_supported;
928
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
2136 if (fn_supported && !fn_supported()) {
929 GTEST_SKIP() << "Unsupported CPU feature";
930 return;
931 }
932
933 2136 const auto& ukernel = p.variant->ukernel;
934 2136 const size_t K = p.matmul_shape.k;
935 2136 const size_t bl = p.bl;
936
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto nr = ukernel.interface.get_nr();
937
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto kr = ukernel.interface.get_kr();
938
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto sr = ukernel.interface.get_sr();
939
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto n_step = ukernel.interface.get_n_step();
940
941 2136 const auto rhs_packed_offset_kxn =
942
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n_step, K, nr, kr, sr, bl, kai_dt_bf16);
943
2/4
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
4272 const auto rhs_packed_offset_kxn_ps1s0 = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
944 2136 n_step, K, nr, kr, sr, bl, kai_dt_bf16);
945 2136 const auto rhs_packed_offset_nxk =
946
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n_step, K, nr, kr, sr, bl, kai_dt_bf16);
947 2136 const auto rhs_packed_offset_nxk_ps1s0_nrx4 =
948
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(
949 2136 n_step, K, nr, kr, sr, bl, kai_dt_bf16);
950
951
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(rhs_packed_offset_kxn, rhs_packed_offset_kxn_ps1s0);
952
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(rhs_packed_offset_kxn_ps1s0, rhs_packed_offset_nxk);
953
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(rhs_packed_offset_nxk, rhs_packed_offset_nxk_ps1s0_nrx4);
954
955
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto rhs_matmul_offset = ukernel.interface.get_rhs_packed_offset(n_step, K, bl);
956
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(rhs_packed_offset_kxn, rhs_matmul_offset);
957 2136 }
958
959 /// Verifies LHS packed offset matches the matmul interface at m_step.
960
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
5514 TEST_P(QMatMulClampF32Test, OffsetLHS) {
961 2136 const auto& p = GetParams();
962 2136 const auto fn_supported = p.variant->ukernel.fn_is_supported;
963
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
2136 if (fn_supported && !fn_supported()) {
964 GTEST_SKIP() << "Unsupported CPU feature";
965 return;
966 }
967
968 2136 const auto& ukernel = p.variant->ukernel;
969 2136 const size_t K = p.matmul_shape.k;
970
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto mr = ukernel.interface.get_mr();
971
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto kr = ukernel.interface.get_kr();
972
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto sr = ukernel.interface.get_sr();
973
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto m_step = ukernel.interface.get_m_step();
974
975
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(m_step, K, mr, kr, sr);
976
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto lhs_matmul_offset = ukernel.interface.get_lhs_packed_offset(m_step, K);
977
978
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
979 2136 }
980
981 /// Verifies the kernel’s get_dst_offset computes row/col addressing correctly at tile-aligned starts:
982
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
5514 TEST_P(QMatMulClampF32Test, OffsetDst) {
983 2136 const auto& p = GetParams();
984 2136 const auto fn_supported = p.variant->ukernel.fn_is_supported;
985
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
2136 if (fn_supported && !fn_supported()) {
986 GTEST_SKIP() << "Unsupported CPU feature";
987 return;
988 }
989
990 2136 const auto& ukernel = p.variant->ukernel;
991 2136 const size_t M = p.matmul_shape.m;
992 2136 const size_t N = p.matmul_shape.n;
993
994 2136 const auto dst_stride_row = N * sizeof(float);
995 2136 constexpr auto dst_stride_col = sizeof(float);
996
997
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto m_step = ukernel.interface.get_m_step();
998
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto n_step = ukernel.interface.get_n_step();
999
1000
5/18
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
2136 ASSERT_TRUE(m_step % ukernel.interface.get_mr() == 0);
1001
5/18
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
2136 ASSERT_TRUE(n_step % ukernel.interface.get_nr() == 0);
1002
1003
2/2
✓ Branch 0 taken 2116 times.
✓ Branch 1 taken 20 times.
2136 const size_t m_idx = (M > m_step) ? m_step : 0;
1004
2/2
✓ Branch 0 taken 2094 times.
✓ Branch 1 taken 42 times.
2136 const size_t n_idx = (N > n_step) ? n_step : 0;
1005
1006
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto off00 = ukernel.interface.get_dst_offset(0, 0, dst_stride_row);
1007
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(off00, 0U);
1008
1009
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto off10 = ukernel.interface.get_dst_offset(m_idx, 0, dst_stride_row);
1010
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(off10, m_idx * dst_stride_row);
1011
1012
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto off01 = ukernel.interface.get_dst_offset(0, n_idx, dst_stride_row);
1013
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(off01, n_idx * dst_stride_col);
1014
1015
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto off11 = ukernel.interface.get_dst_offset(m_idx, n_idx, dst_stride_row);
1016
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(off11, m_idx * dst_stride_row + n_idx * dst_stride_col);
1017 2136 }
1018
1019 /// Sanity-checks kernel interface parameters (mr/nr/kr/sr and step alignment).
1020
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
5514 TEST_P(QMatMulClampF32Test, KernelInvariants) {
1021 2136 const auto& p = GetParams();
1022 2136 const auto fn_supported = p.variant->ukernel.fn_is_supported;
1023
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
2136 if (fn_supported && !fn_supported()) {
1024 GTEST_SKIP() << "Unsupported CPU feature";
1025 return;
1026 }
1027
1028 2136 const auto& ukernel = p.variant->ukernel;
1029
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto mr = ukernel.interface.get_mr();
1030
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto nr = ukernel.interface.get_nr();
1031
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto kr = ukernel.interface.get_kr();
1032
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto sr = ukernel.interface.get_sr();
1033
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto m_step = ukernel.interface.get_m_step();
1034
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto n_step = ukernel.interface.get_n_step();
1035
1036
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_GT(mr, 0U);
1037
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_GT(nr, 0U);
1038
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_GT(kr, 0U);
1039
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_GT(sr, 0U);
1040
1041
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(m_step % mr, 0U);
1042
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(n_step % nr, 0U);
1043
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(kr % sr, 0U);
1044 2136 }
1045
1046 /// Verifies RHS row stride using difference of offsets equals the layout formula.
1047
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
5514 TEST_P(QMatMulClampF32Test, RhsStrideByDifference) {
1048 2136 const auto& p = GetParams();
1049 2136 const auto fn_supported = p.variant->ukernel.fn_is_supported;
1050
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
2136 if (fn_supported && !fn_supported()) {
1051 GTEST_SKIP() << "Unsupported CPU feature";
1052 return;
1053 }
1054
1055 2136 const auto& ukernel = p.variant->ukernel;
1056 2136 const size_t K = p.matmul_shape.k;
1057 2136 const size_t bl = p.bl;
1058
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto nr = ukernel.interface.get_nr();
1059
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto n_step = ukernel.interface.get_n_step();
1060
1061 // Stride by difference using kernel offsets at 0 and n_step.
1062
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const size_t off0 = ukernel.interface.get_rhs_packed_offset(0, K, bl);
1063
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const size_t off1 = ukernel.interface.get_rhs_packed_offset(n_step, K, bl);
1064 2136 const size_t stride_by_diff = off1 - off0;
1065
1066 // Expected stride formula for qsi4c32p with BF16 scales:
1067 // nr * ( num_blocks * (bl/2 + 2) + 4 /*rsum*/ + 4 /*bias*/ )
1068
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const size_t k_internal = round_up_multiple(K, 32);
1069
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const size_t num_blocks = round_up_division(k_internal, bl);
1070 2136 const size_t bytes_per_block = (bl / 2) + 2; // int4 values + BF16 scale
1071 2136 const size_t expected_stride = nr * (num_blocks * bytes_per_block) + nr * 4 + nr * 4;
1072
1073
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_EQ(stride_by_diff, expected_stride);
1074 2136 }
1075
1076 /// Validation of the packed group slice against a reconstructed reference.
1077
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
5514 TEST_P(QMatMulClampF32Test, LhsPackBufferMatchesReference) {
1078 2136 const auto& p = GetParams();
1079 2136 const auto fn_supported = p.variant->ukernel.fn_is_supported;
1080
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
2136 if (fn_supported && !fn_supported()) {
1081 GTEST_SKIP() << "Unsupported CPU feature";
1082 return;
1083 }
1084 2136 const auto& uk = p.variant->ukernel;
1085
1086 2136 const size_t M = p.matmul_shape.m;
1087 2136 const size_t K = p.matmul_shape.k;
1088
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const size_t mr = uk.interface.get_mr();
1089
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const size_t kr = uk.interface.get_kr();
1090
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const size_t sr = uk.interface.get_sr();
1091
1092 2136 const size_t k_block_len = kr / sr;
1093 2136 const size_t k_internal = ((K + 31) / 32) * 32;
1094
1095 2136 const size_t i8_region_bytes = mr * k_internal;
1096 2136 const size_t neg_zero_point_region_bytes = mr * sizeof(int32_t);
1097 2136 const size_t recip_scale_region_bytes = mr * sizeof(float);
1098 2136 const size_t group_stride = i8_region_bytes + neg_zero_point_region_bytes + recip_scale_region_bytes;
1099
1100 2136 constexpr size_t rect_start_row = 0;
1101 2136 constexpr size_t rect_height = 1;
1102
1103
4/8
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
2136 const auto ref_lhs = fill_random<float>(M * K, seed_stream(current_test_key())());
1104
1105 2136 const size_t lhs_stride = K * sizeof(float);
1106
2/4
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
4272 std::tuple<Buffer, size_t> pack_pair = pack_lhs_qai8dxp(
1107 2136 p.variant->lhs_pack_interface, M, K, mr, kr, sr, ref_lhs, lhs_stride, rect_start_row, rect_height);
1108
1109 2136 Buffer const lhs_packed = std::move(std::get<0>(pack_pair));
1110 2136 const size_t lhs_packed_off = std::get<1>(pack_pair);
1111
1112
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 QuantizationInfo lhs_qinfo{};
1113 lhs_qinfo.quant_width = K;
1114 lhs_qinfo.dst_type = DataType::QAI8;
1115 lhs_qinfo.scale_type = DataType::FP32;
1116 lhs_qinfo.zero_point_type = DataType::I32;
1117 auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo);
1118
1119 Buffer const expected(group_stride, 0);
1120 std::byte* expected_bytes = expected.data();
1121
1122 // Build reference layout into `expected`
1123 constexpr size_t lane_row_idx = rect_start_row;
1124 const size_t lane = lane_row_idx % mr;
1125 const size_t ref_row_base = lane_row_idx * K;
1126 const auto pad_val = read_array<int8_t>(ref_lhs_quant.data(), ref_row_base + (K - 1));
1127
1128 size_t ref_idx = 0;
1129 const size_t num_blocks_internal = k_internal / k_block_len;
1130
1131 for (size_t b = 0; b < num_blocks_internal; ++b) {
1132 const size_t block_base = b * mr * k_block_len;
1133 const size_t lane_offset = block_base + lane * k_block_len;
1134
1135 for (size_t i = 0; i < k_block_len; ++i) {
1136 const size_t dst_index = lane_offset + i;
1137 const bool in_range = ref_idx < K;
1138
1139 const int8_t val = in_range ? read_array<int8_t>(ref_lhs_quant.data(), ref_row_base + ref_idx) : pad_val;
1140
1141 write_array<int8_t>(expected_bytes, dst_index, val);
1142
1143 if (in_range) {
1144 ++ref_idx;
1145 }
1146 }
1147 }
1148
1149 // Header (per-lane): neg_zero_point, recip_scale
1150 const size_t neg_zero_point_elem_base = i8_region_bytes / sizeof(int32_t);
1151 const size_t recip_scale_elem_base = (i8_region_bytes + neg_zero_point_region_bytes) / sizeof(float);
1152
1153 write_array<int32_t>(
1154 expected_bytes, neg_zero_point_elem_base + lane,
1155 -read_array<int32_t>(lhs_qoutputs.zero_points.data(), lane_row_idx));
1156
1157 write_array<float>(
1158 expected_bytes, recip_scale_elem_base + lane, read_array<float>(lhs_qoutputs.scales.data(), lane_row_idx));
1159
1160 // Validate packed buffer vs reference
1161 KAI_ASSUME_ALWAYS(lhs_packed_off + group_stride <= lhs_packed.size());
1162
1163 // Int8 region: allow ±1 LSB
1164 for (size_t i = 0; i < i8_region_bytes; ++i) {
1165 const auto g = read_array<int8_t>(lhs_packed.data(), lhs_packed_off + i);
1166 const auto e = read_array<int8_t>(expected.data(), i);
1167 const int dq = static_cast<int>(g) - static_cast<int>(e);
1168 EXPECT_LE(std::abs(dq), 1) << "int8 mismatch at byte " << i << " (got=" << static_cast<int>(g)
1169 << ", exp=" << static_cast<int>(e) << ", dq=" << dq << ")";
1170 }
1171
1172 // Region offsets (in bytes)
1173 const size_t neg_zero_point_offset = i8_region_bytes;
1174 const size_t recip_scale_offset = neg_zero_point_offset + neg_zero_point_region_bytes;
1175
1176 // neg_zero_point (exact)
1177 for (size_t hdr_lane = 0; hdr_lane < mr; ++hdr_lane) {
1178 const auto gzp = read_array<int32_t>(
1179 lhs_packed.data(), lhs_packed_off / sizeof(int32_t) + (neg_zero_point_offset / sizeof(int32_t)) + hdr_lane);
1180 const auto ezp = read_array<int32_t>(expected.data(), (neg_zero_point_offset / sizeof(int32_t)) + hdr_lane);
1181 EXPECT_EQ(gzp, ezp) << "neg_zp mismatch at lane " << hdr_lane;
1182 }
1183
1184 // recip_scale (near-equal)
1185 for (size_t hdr_lane = 0; hdr_lane < mr; ++hdr_lane) {
1186 const auto gsc = read_array<float>(
1187 lhs_packed.data(), lhs_packed_off / sizeof(float) + (recip_scale_offset / sizeof(float)) + hdr_lane);
1188 const auto esc = read_array<float>(expected.data(), (recip_scale_offset / sizeof(float)) + hdr_lane);
1189 EXPECT_NEAR(gsc, esc, 1e-5F) << "recip_scale mismatch at lane " << hdr_lane;
1190 }
1191 }
1192
1193
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
5514 TEST_P(QMatMulClampF32Test, EndToEnd) {
1194 2136 const auto& p = GetParams();
1195 2136 const auto fn_supported = p.variant->ukernel.fn_is_supported;
1196
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
2136 if (fn_supported && !fn_supported()) {
1197 GTEST_SKIP() << "Unsupported CPU feature";
1198 return;
1199 }
1200 2136 const auto& ukernel = p.variant->ukernel;
1201
1202 2136 const size_t bl = p.bl;
1203 2136 const RhsPackType rhs_pack_type = p.rhs_pack_type;
1204
1205 KAI_ASSUME_ALWAYS(bl % 32 == 0);
1206
1207
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto nr = ukernel.interface.get_nr();
1208
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto kr = ukernel.interface.get_kr();
1209
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto sr = ukernel.interface.get_sr();
1210
1211
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto n_step = ukernel.interface.get_n_step();
1212
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_TRUE(n_step % nr == 0);
1213
1214 2136 const auto rect = p.rect;
1215
5/18
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
2136 ASSERT_GT(rect.height(), 0U);
1216
5/18
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
2136 ASSERT_GT(rect.width(), 0U);
1217
1218
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto& data = test_data();
1219
1220
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto rhs_start_col = rect.start_col();
1221 2136 const size_t bias_offset_bytes = rhs_start_col * sizeof(float);
1222
1223 2136 Buffer imp_packed_rhs;
1224 2136 size_t rhs_packed_offset = 0;
1225
2/2
✓ Branch 0 taken 1058 times.
✓ Branch 1 taken 1078 times.
2136 if (rhs_pack_type == RhsPackType::NxK) {
1226
1/2
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
1078 const float* bias_ptr = reinterpret_cast<const float*>(data.bias.data()) + rhs_start_col;
1227
1/2
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
1078 std::tie(imp_packed_rhs, rhs_packed_offset) = pack_rhs_qsi4c32p_nxk(
1228 1078 p.variant->rhs_pack_interface, data.N, data.K, nr, kr, sr, bl, data.rhs_quant, bias_ptr, data.rhs_scales,
1229
1/2
✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
1078 rhs_start_col, rect.width(), p.variant->rhs_s0s1_input);
1230 1078 } else {
1231
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1058 times.
1058 if ((rhs_start_col % 2) != 0) {
1232 GTEST_SKIP() << "KxN RHS pack requires even N-start index";
1233 return;
1234 }
1235
1/2
✓ Branch 0 taken 1058 times.
✗ Branch 1 not taken.
1058 std::tie(imp_packed_rhs, rhs_packed_offset) = pack_rhs_qsi4c32p_kxn(
1236 1058 data.N, data.K, nr, kr, sr, bl, data.rhs_quant, data.bias, bias_offset_bytes, data.rhs_scales,
1237
1/2
✓ Branch 0 taken 1058 times.
✗ Branch 1 not taken.
1058 rhs_start_col, rect.width(), p.is_sme2);
1238 }
1239
1240
5/18
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
2136 ASSERT_EQ(rhs_packed_offset, ukernel.interface.get_rhs_packed_offset(rhs_start_col, data.K, bl));
1241
1242 // Destination buffer and offsets
1243 2136 const auto dst_stride_row = data.N * sizeof(float);
1244 2136 constexpr auto dst_stride_col = sizeof(float);
1245
2/4
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
2136 const auto dst_offset = ukernel.interface.get_dst_offset(rect.start_row(), rhs_start_col, dst_stride_row);
1246
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 const auto imp_dst_size = ukernel.interface.get_dst_size(data.M, data.N);
1247
5/18
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
2136 ASSERT_EQ(imp_dst_size, data.ref_dst_clamped.size());
1248
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 Buffer const imp_dst(imp_dst_size);
1249
1250 // Run matmul
1251
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 abi_check(
1252
2/4
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
2136 ukernel.interface.run_matmul, rect.height(), rect.width(), data.K, bl,
1253
2/4
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
2136 data.lhs_packed.data() + data.lhs_packed_offset, imp_packed_rhs.data() + rhs_packed_offset,
1254
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 reinterpret_cast<float*>(imp_dst.data() + dst_offset), dst_stride_row, dst_stride_col, data.clamp.min,
1255 2136 data.clamp.max);
1256
1257
1/2
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
2136 DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
1258 2136 const auto dst_format = DataFormat(DataType::FP32);
1259 4272 const auto success =
1260
3/6
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
2136 compare(imp_dst.data(), data.ref_dst_clamped.data(), dst_format, data.M, data.N, rect, handler);
1261
4/16
✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
2136 ASSERT_TRUE(success);
1262 2136 }
1263
1264 /// RHS vectorised packer format is s16s0 this is not relevant for sme2 kernels
1265 class NeonRhsPackF32Test : public QMatMulClampF32Test {};
1266
1267
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
886 TEST_P(NeonRhsPackF32Test, EndToEndNeonRhsPack) {
1268 352 const auto& p = GetParams();
1269 352 const auto fn_supported = p.variant->ukernel.fn_is_supported;
1270
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 352 times.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
352 if (fn_supported && !fn_supported()) {
1271 GTEST_SKIP() << "Unsupported CPU feature";
1272 return;
1273 }
1274 352 const auto& ukernel = p.variant->ukernel;
1275
1276
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 const size_t mr = ukernel.interface.get_mr();
1277
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 const size_t nr = ukernel.interface.get_nr();
1278
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 const size_t kr = ukernel.interface.get_kr();
1279
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 const size_t sr = ukernel.interface.get_sr();
1280
5/18
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 352 times.
352 ASSERT_EQ(ukernel.interface.get_m_step() % mr, 0U);
1281
5/18
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 352 times.
352 ASSERT_EQ(ukernel.interface.get_n_step() % nr, 0U);
1282
1283
4/6
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 128 times.
✓ Branch 3 taken 224 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 128 times.
352 if (p.rhs_pack_type != RhsPackType::NxK || (kr / sr != 8 && kr / sr != 4)) {
1284 GTEST_SKIP() << "RHS packers not applicable";
1285 }
1286
5/18
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 352 times.
352 ASSERT_GT(p.rect.height(), 0U);
1287
5/18
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 352 times.
352 ASSERT_GT(p.rect.width(), 0U);
1288
1289
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 const auto& data = test_data();
1290
1291 // LHS pack
1292 352 const size_t lhs_stride_bytes = data.K * sizeof(float);
1293
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
1056 auto [imp_packed_lhs, lhs_packed_offset] = pack_lhs_qai8dxp(
1294
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 p.variant->lhs_pack_interface, data.M, data.K, mr, kr, sr, data.lhs, lhs_stride_bytes, p.rect.start_row(),
1295
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 p.rect.height());
1296
7/22
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 352 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 352 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 352 times.
704 ASSERT_EQ(lhs_packed_offset, ukernel.interface.get_lhs_packed_offset(p.rect.start_row(), data.K));
1297
1298 // RHS pack
1299
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 const size_t rhs_start_row = p.rect.start_col();
1300 352 const size_t bias_offset = rhs_start_row * sizeof(float);
1301
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
704 const auto [imp_packed_rhs_neon, rhs_packed_offset_neon] = pack_rhs_qsi4c32pscalebf16_neon(
1302 352 data.N, data.K, nr, kr, sr, p.bl, data.rhs_quant, data.bias, bias_offset, data.rhs_scales, p.rhs_pack_type,
1303
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 rhs_start_row, p.rect.width());
1304
1305
6/20
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 352 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✓ Branch 19 taken 352 times.
704 ASSERT_EQ(rhs_packed_offset_neon, ukernel.interface.get_rhs_packed_offset(rhs_start_row, data.K, p.bl));
1306
1307 352 const auto dst_stride_row = data.N * sizeof(float);
1308
2/4
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
352 Buffer const imp_dst(ukernel.interface.get_dst_size(data.M, data.N));
1309
2/4
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
352 const auto dst_offset = ukernel.interface.get_dst_offset(p.rect.start_row(), rhs_start_row, dst_stride_row);
1310
1311 // Run matmul
1312
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 abi_check(
1313
2/4
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
352 ukernel.interface.run_matmul, p.rect.height(), p.rect.width(), data.K, p.bl,
1314
4/8
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
704 imp_packed_lhs.data() + lhs_packed_offset, imp_packed_rhs_neon.data() + rhs_packed_offset_neon,
1315
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 reinterpret_cast<float*>(imp_dst.data() + dst_offset), dst_stride_row, sizeof(float), data.clamp.min,
1316 352 data.clamp.max);
1317
1318
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
352 DefaultMismatchHandler handler(0, 0.1, 0, 0.05);
1319 352 const DataFormat dst_format(DataType::FP32);
1320
7/22
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 352 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 352 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 352 times.
352 ASSERT_TRUE(compare(imp_dst.data(), data.ref_dst_clamped.data(), dst_format, data.M, data.N, p.rect, handler));
1321 352 }
1322
1323 class QMatMulClampBF16Test : public ::testing::TestWithParam<MatMulTestParams_withBL_withRHSPackType> {};
1324
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
486 TEST_P(QMatMulClampBF16Test, EndToEnd) {
1325 2592 const auto& [variant_index, matmul_shape, bl, portion, rhs_pack_type, clamp_keep_ratio] = GetParam();
1326 384 const auto& ukernel_variant = get_bf16_gemm_variants().at(variant_index);
1327
1328
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
192 if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
1329 GTEST_SKIP() << "Unsupported CPU feature";
1330 }
1331
1332 384 const size_t M = matmul_shape.m;
1333 384 const size_t N = matmul_shape.n;
1334 384 const size_t K = matmul_shape.k;
1335
1336 192 const auto mr = ukernel_variant.interface.get_mr();
1337 192 const auto nr = ukernel_variant.interface.get_nr();
1338 192 const auto kr = ukernel_variant.interface.get_kr();
1339 192 const auto sr = ukernel_variant.interface.get_sr();
1340
1341 192 const auto m_step = ukernel_variant.interface.get_m_step();
1342
3/14
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 192 times.
192 ASSERT_TRUE(m_step % mr == 0);
1343
1344 192 const auto n_step = ukernel_variant.interface.get_n_step();
1345
3/14
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 192 times.
192 ASSERT_TRUE(n_step % nr == 0);
1346
1347 384 const auto rect = portion.compute_portion(M, N, m_step, n_step);
1348
3/14
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 192 times.
192 ASSERT_GT(rect.height(), 0U);
1349
3/14
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 192 times.
192 ASSERT_GT(rect.width(), 0U);
1350
1351 // Cached reference and inputs
1352 384 const BF16QMatMulRefKey key{matmul_shape,
1353 bl,
1354 mr,
1355 nr,
1356 kr,
1357 sr,
1358 192 rect.start_row(),
1359 192 rect.start_col(),
1360 192 rect.height(),
1361 192 rect.width(),
1362 rhs_pack_type,
1363 clamp_keep_ratio};
1364 192 const BF16TestData& data = getV<BF16QMatMulRefKey, BF16TestData>(key);
1365
1366 // Verify LHS offsets match interface
1367 192 const auto lhs_start_row = rect.start_row();
1368 384 const auto lhs_packed_offset =
1369 192 kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr);
1370 192 const auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
1371
3/14
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 192 times.
192 ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
1372
1373 // RHS: pack using cached quant/scales/bias
1374 192 const size_t rhs_start_row = rect.start_col();
1375 192 const size_t bias_offset = rhs_start_row * sizeof(float);
1376
3/4
✓ Branch 0 taken 96 times.
✓ Branch 1 taken 96 times.
✓ Branch 2 taken 96 times.
✗ Branch 3 not taken.
192 if (rhs_pack_type == RhsPackType::KxN && (rhs_start_row % 2) != 0) {
1377 GTEST_SKIP() << "KxN RHS pack requires even N-start index";
1378 return;
1379 }
1380
1381 672 auto [imp_packed_rhs, rhs_packed_offset] = pack_rhs_qsi4c32pscalebf16(
1382 576 N, K, nr, kr, sr, bl, data.rhs_quant, data.bias, bias_offset, data.rhs_scales, rhs_pack_type, rhs_start_row,
1383 192 rect.width(), false);
1384
1385
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
384 const auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl);
1386
5/18
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 192 times.
384 ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
1387
1388 // Destination
1389 192 const auto dst_stride_row = N * sizeof(uint16_t);
1390 192 constexpr auto dst_stride_col = sizeof(uint16_t);
1391 384 const auto dst_offset =
1392
3/6
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
192 ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
1393
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
192 const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
1394
4/16
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 192 times.
192 ASSERT_EQ(dst_offset, ref_dst_offset);
1395
1396
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
1397
5/18
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 192 times.
192 ASSERT_EQ(imp_dst_size, data.ref_dst_bf16.size());
1398
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 Buffer imp_dst(imp_dst_size);
1399
1400 // Run matmul
1401
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 abi_check(
1402
2/4
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
192 ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, bl,
1403
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 reinterpret_cast<const uint8_t*>(data.lhs_packed.data()) + lhs_matmul_offset,
1404
3/6
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
384 reinterpret_cast<const uint8_t*>(imp_packed_rhs.data()) + rhs_matmul_offset, imp_dst.data() + dst_offset,
1405 192 dst_stride_row, dst_stride_col, data.clamp.min, data.clamp.max);
1406
1407
1/2
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
192 DefaultMismatchHandler handler(0, 0.02, 0, 0.05);
1408 192 auto dst_format = DataFormat(DataType::BF16);
1409
3/6
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
192 const auto success = compare(imp_dst.data(), data.ref_dst_bf16.data(), dst_format, M, N, rect, handler);
1410
4/16
✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 192 times.
192 ASSERT_TRUE(success);
1411
1412 // Test vectorized packing micro-kernels, if packing parameters allow
1413
3/6
✓ Branch 0 taken 96 times.
✓ Branch 1 taken 96 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 96 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
192 if (rhs_pack_type == RhsPackType::NxK && (kr / sr == 8 || kr / sr == 4)) {
1414
1/2
✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
96 const auto [imp_packed_rhs_neon, rhs_packed_offset_neon] = pack_rhs_qsi4c32pscalebf16_neon(
1415 288 N, K, nr, kr, sr, bl, data.rhs_quant, data.bias, bias_offset, data.rhs_scales, rhs_pack_type, rhs_start_row,
1416
1/2
✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
96 rect.width());
1417
5/18
✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 96 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 96 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 96 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 96 times.
192 ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset);
1418 96 }
1419 192 }
1420
1421 // clang-format off
1422
1423 /// Portion categories (GEMM/GEMV)
1424 static constexpr std::array gemm_portions{
1425 MatrixPortion(0, 0, 1, 1), // Full matrix
1426 MatrixPortion(0.4, 0.5, 0.6, 0.8), // Middle block
1427 };
1428 static constexpr std::array gemv_portions{
1429 MatrixPortion(0, 0, 1, 1), // Full width
1430 MatrixPortion(0, 0.5, 1, 0.5), // Right half
1431 };
1432
1433 /// Shape categories (GEMM/GEMV)
1434
1435 /// Small/Odd edge coverage (odd m/n, varied K)
1436 static constexpr std::array gemm_shapes_small_odd{
1437 MatMulShape{ 17, 25, 64},
1438 MatMulShape{ 31, 31, 64},
1439 MatMulShape{ 21, 53, 256},
1440 MatMulShape{ 35, 27, 320},
1441 };
1442
1443 /// Aligned squares (cache-friendly, power-of-two-ish)
1444 static constexpr std::array gemm_shapes_aligned{
1445 MatMulShape{ 32, 32, 128},
1446 MatMulShape{ 64, 64, 128},
1447 MatMulShape{128, 128, 256},
1448 MatMulShape{192, 192, 384},
1449 };
1450
1451 /// Rectangular (skinny/wide), varied K
1452 static constexpr std::array gemm_shapes_rect{
1453 MatMulShape{ 64, 128, 256}, // wide N
1454 MatMulShape{128, 64, 256}, // tall M
1455 MatMulShape{ 96, 192, 384},
1456 MatMulShape{160, 96, 320},
1457 };
1458
1459 /// Larger/stress (within reason for CI)
1460 static constexpr std::array gemm_shapes_large{
1461 MatMulShape{128, 160, 320},
1462 MatMulShape{160, 128, 320},
1463 MatMulShape{224, 160, 320},
1464 MatMulShape{160, 224, 320},
1465 };
1466
1467 /// GEMV shape categories (F32)
1468 /// M = 1, RHS NxK only in instantiation
1469
1470 /// Small/medium N, diverse K (aligned/odd N)
1471 static constexpr std::array gemv_shapes_small{
1472 MatMulShape{ 1, 16, 64},
1473 MatMulShape{ 1, 31, 64},
1474 MatMulShape{ 1, 128, 256},
1475 MatMulShape{ 1, 256, 256},
1476 MatMulShape{ 1, 320, 320},
1477 };
1478
1479 /// Larger N bands (bandwidth/cache stress)
1480 static constexpr std::array gemv_shapes_large{
1481 MatMulShape{ 1, 512, 256},
1482 MatMulShape{ 1, 640, 320},
1483 MatMulShape{ 1, 768, 384},
1484 MatMulShape{ 1, 1024, 256},
1485 MatMulShape{ 1, 896, 384},
1486 };
1487
1488 static constexpr std::array bf16_shapes {
1489 MatMulShape{ 32, 32, 64}, // small aligned
1490 MatMulShape{ 48, 64, 64}, // rectangular (tall K-block reuse)
1491 MatMulShape{ 64, 64, 128}, // aligned square
1492 MatMulShape{ 96, 96, 192}, // larger aligned
1493 MatMulShape{128, 64, 256}, // rectangular (tall M)
1494 MatMulShape{ 17, 25, 64}, // odd sizes (edge behavior)
1495 MatMulShape{ 33, 29, 192}, // odd sizes with larger K
1496 MatMulShape{128, 160, 320}, // larger rectangular
1497 };
1498
1499 /// Dedicated clamp sweep ratios
1500 static constexpr std::array<float, 3> clamp_keep_ratios_sweep{
1501 1.0F, // no clamp
1502 0.5F, // clamp away 50%
1503 0.1F, // clamp away 90%
1504 };
1505
1506
24/80
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 1344 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 2688 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
4056 INSTANTIATE_TEST_SUITE_P(
1507 MatMulGemm_SmallOdd, QMatMulClampF32Test,
1508 testing::Combine(
1509 testing::Range<size_t>(0, get_f32_gemm_variants().size()),
1510 testing::Values(true),
1511 testing::ValuesIn(gemm_shapes_small_odd),
1512 testing::Values(32),
1513 testing::ValuesIn(gemm_portions),
1514 testing::Values(RhsPackType::NxK, RhsPackType::KxN),
1515 testing::Values(0.5F)),
1516 testing::PrintToStringParamName());
1517
1518
24/80
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 1344 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 2688 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
4056 INSTANTIATE_TEST_SUITE_P(
1519 MatMulGemm_Aligned, QMatMulClampF32Test,
1520 testing::Combine(
1521 testing::Range<size_t>(0, get_f32_gemm_variants().size()),
1522 testing::Values(true),
1523 testing::ValuesIn(gemm_shapes_aligned),
1524 testing::Values(32),
1525 testing::ValuesIn(gemm_portions),
1526 testing::Values(RhsPackType::NxK, RhsPackType::KxN),
1527 testing::Values(0.5F)),
1528 testing::PrintToStringParamName());
1529
1530
24/80
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 2688 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 5376 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
8088 INSTANTIATE_TEST_SUITE_P(
1531 MatMulGemm_Rect, QMatMulClampF32Test,
1532 testing::Combine(
1533 testing::Range<size_t>(0, get_f32_gemm_variants().size()),
1534 testing::Values(true),
1535 testing::ValuesIn(gemm_shapes_rect),
1536 testing::Values(32, 64),
1537 testing::ValuesIn(gemm_portions),
1538 testing::Values(RhsPackType::NxK, RhsPackType::KxN),
1539 testing::Values(0.5F)),
1540 testing::PrintToStringParamName());
1541
1542
24/80
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 1344 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 2688 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
4056 INSTANTIATE_TEST_SUITE_P(
1543 MatMulGemm_Large, QMatMulClampF32Test,
1544 testing::Combine(
1545 testing::Range<size_t>(0, get_f32_gemm_variants().size()),
1546 testing::Values(true),
1547 testing::ValuesIn(gemm_shapes_large),
1548 testing::Values(32),
1549 testing::ValuesIn(gemm_portions),
1550 testing::Values(RhsPackType::NxK, RhsPackType::KxN),
1551 testing::Values(0.5F)),
1552 testing::PrintToStringParamName());
1553
1554
24/80
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 70 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 140 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
234 INSTANTIATE_TEST_SUITE_P(
1555 MatMulGemv_Small, QMatMulClampF32Test,
1556 testing::Combine(
1557 testing::Range<size_t>(0, get_f32_gemv_variants().size()),
1558 testing::Values(false),
1559 testing::ValuesIn(gemv_shapes_small),
1560 testing::Values(32),
1561 testing::ValuesIn(gemv_portions),
1562 testing::Values(RhsPackType::NxK),
1563 testing::Values(0.5F)),
1564 testing::PrintToStringParamName());
1565
1566
24/80
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 70 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 140 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
234 INSTANTIATE_TEST_SUITE_P(
1567 MatMulGemv_Large, QMatMulClampF32Test,
1568 testing::Combine(
1569 testing::Range<size_t>(0, get_f32_gemv_variants().size()),
1570 testing::Values(false),
1571 testing::ValuesIn(gemv_shapes_large),
1572 testing::Values(32),
1573 testing::ValuesIn(gemv_portions),
1574 testing::Values(RhsPackType::NxK),
1575 testing::Values(0.5F)),
1576 testing::PrintToStringParamName());
1577
1578
24/80
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 88 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 176 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
270 INSTANTIATE_TEST_SUITE_P(
1579 MatMulNeonRhsPackGemm_SmallOdd, NeonRhsPackF32Test,
1580 testing::Combine(
1581 testing::Range<size_t>(0, get_f32_neon_gemm_variants_only().size()),
1582 testing::Values(true),
1583 testing::ValuesIn(gemm_shapes_small_odd),
1584 testing::Values(32),
1585 testing::ValuesIn(gemm_portions),
1586 testing::Values(RhsPackType::NxK),
1587 testing::Values(0.5F)),
1588 testing::PrintToStringParamName());
1589
1590
24/80
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 88 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 176 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
270 INSTANTIATE_TEST_SUITE_P(
1591 MatMulNeonRhsPackGemm_Aligned, NeonRhsPackF32Test,
1592 testing::Combine(
1593 testing::Range<size_t>(0, get_f32_neon_gemm_variants_only().size()),
1594 testing::Values(true),
1595 testing::ValuesIn(gemm_shapes_aligned),
1596 testing::Values(32),
1597 testing::ValuesIn(gemm_portions),
1598 testing::Values(RhsPackType::NxK),
1599 testing::Values(0.5F)),
1600 testing::PrintToStringParamName());
1601
1602 static constexpr std::array clamp_sweep_shapes{
1603 MatMulShape{ 64, 64, 128 },
1604 MatMulShape{ 64, 128, 256 },
1605 };
1606
1607
26/88
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 7 times.
✓ Branch 26 taken 14 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 14 times.
✓ Branch 28 taken 1008 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✓ Branch 30 taken 2016 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
3048 INSTANTIATE_TEST_SUITE_P(
1608 MatMulGemm_ClampSweep, QMatMulClampF32Test,
1609 testing::Combine(
1610 testing::Range<size_t>(0, get_f32_gemm_variants().size()),
1611 testing::Values(true),
1612 testing::ValuesIn(clamp_sweep_shapes),
1613 testing::Values(32),
1614 testing::Values(MatrixPortion(0, 0, 1, 1)),
1615 testing::Values(RhsPackType::NxK, RhsPackType::KxN),
1616 testing::ValuesIn(clamp_keep_ratios_sweep)),
1617 testing::PrintToStringParamName());
1618
1619
24/80
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 96 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 192 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
294 INSTANTIATE_TEST_SUITE_P(
1620 MatMulBF16_SingleSet, QMatMulClampBF16Test,
1621 testing::Combine(
1622 testing::Range<size_t>(0, get_bf16_gemm_variants().size()),
1623 testing::ValuesIn(bf16_shapes),
1624 testing::Values(32),
1625 testing::Values(MatrixPortion(0, 0, 1, 1)),
1626 testing::Values(RhsPackType::NxK, RhsPackType::KxN),
1627 testing::ValuesIn(clamp_keep_ratios_sweep)
1628 ),
1629 testing::PrintToStringParamName());
1630
1631 } // namespace kai::test
1632