KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 96.6% 458 / 1 / 475
Functions: 97.5% 79 / 0 / 81
Branches: 39.4% 371 / 2 / 944

test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <functional>
14 #include <optional>
15 #include <sstream>
16 #include <string>
17 #include <string_view>
18 #include <tuple>
19 #include <unordered_map>
20
21 #include "kai/kai_common.h"
22 #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa.h"
23 #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h"
24 #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h"
25 #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h"
26 #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h"
27 #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa.h"
28 #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h"
29 #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_interface.h"
30 #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h"
31 #include "kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h"
32 #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h"
33 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h"
34 #include "test/common/abi_checker.hpp"
35 #include "test/common/buffer.hpp"
36 #include "test/common/cpu_info.hpp"
37 #include "test/common/matmul_test_common.hpp"
38 #include "test/common/matrix_portion.hpp"
39 #include "test/common/memory.hpp"
40 #include "test/common/rect.hpp"
41 #include "test/common/seed.hpp"
42 #include "test/common/sme.hpp"
43 #include "test/reference/binary_elementwise.hpp"
44 #include "test/reference/clamp.hpp"
45 #include "test/reference/fill.hpp"
46 #include "test/reference/matmul.hpp"
47 #include "test/reference/matmul_pack.hpp"
48 #include "test/reference/quantize.hpp"
49 #include "test/reference/reduce.hpp"
50 #include "test/reference/reorder.hpp"
51 #include "test/reference/transpose.hpp"
52
53 namespace kai::test {
54
55 // Ensure static linkage for all functionality local to this test file
56 namespace {
57
/// Describes the K dimension as `count` chunks of `length` elements each
/// (presumably fed to the indirect kernels' k_chunk_count / k_chunk_length parameters — confirm at call sites).
///
/// Members are value-initialized so a default-constructed KChunk is {0, 0} rather than indeterminate.
/// The struct remains an aggregate, so `KChunk{count, length}` keeps working.
struct KChunk {
    size_t count{};   ///< Number of K chunks
    size_t length{};  ///< Number of K elements per chunk
};
62
/// Type-erased interface to an LHS packing micro-kernel (direct, non-indirect LHS).
///
/// get_lhs_pack() binds every entry to the corresponding kai_*_lhs_pack_x8p2vlx4_x8_sme function.
/// An empty std::function means the entry is not bound.
struct LhsPackKernel {
    /// M step of the packing kernel for a given mr.
    std::function<size_t(size_t mr)> get_m_step;

    /// Offset into the unpacked LHS for row m_idx.
    std::function<size_t(size_t m_idx, size_t lhs_stride)> get_lhs_offset;

    /// Offset into the packed LHS for row m_idx.
    std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)> get_packed_lhs_offset;

    /// Size of the packed LHS buffer.
    std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> get_packed_lhs_size;

    /// Runs the packing.
    std::function<void(
        size_t m, size_t k, size_t mr, size_t kr, size_t sr,  //
        size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed)>
        pack;
};
73
/// Type-erased interface to an indirect (imatmul) LHS packing micro-kernel.
///
/// The LHS is supplied as an array of row pointers. get_indirect_gemm_variants() binds the entries
/// to the kai_*_lhs_imatmul_pack_x8p2vlx4_x8p_sme functions.
struct LhsPackIndirectKernel {
    /// M step of the packing kernel.
    std::function<size_t()> get_m_step;

    /// Offset into the packed LHS for row m_idx.
    std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_packed_lhs_offset;

    /// Size of the packed LHS buffer.
    std::function<size_t(size_t m, size_t k_chunk_count, size_t k_chunk_length)> get_packed_lhs_size;

    /// Runs the packing.
    std::function<void(
        size_t m, size_t k_chunk_count, size_t k_chunk_length,  //
        const void* const* lhs_ptrs, size_t lhs_ptr_offset, const void* zero, void* packed_lhs)>
        pack;
};
83
/// Type-erased interface to an RHS packing micro-kernel (non-indirect matmul).
///
/// get_rhs_pack() binds every entry to the corresponding
/// kai_*_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme function.
struct RhsPackKernel {
    /// N step of the packing kernel.
    std::function<size_t()> get_n_step;

    /// Offsets into the unpacked RHS, bias and scale inputs for column n_idx.
    std::function<size_t(size_t n_idx)> get_rhs_offset;
    std::function<size_t(size_t n_idx)> get_bias_offset;
    std::function<size_t(size_t n_idx)> get_scale_offset;

    /// Offset into / total size of the packed RHS buffer.
    std::function<size_t(size_t n_idx, size_t k)> get_packed_rhs_offset;
    std::function<size_t(size_t n, size_t k)> get_packed_rhs_size;

    /// Runs the packing.
    std::function<void(
        size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr,  //
        size_t rhs_stride, const void* rhs, const void* bias, const void* scale,  //
        void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qsi8cx_params* params)>
        pack;
};
97
/// Type-erased interface to an indirect (imatmul) RHS packing micro-kernel.
///
/// K is described by k_chunk_count / k_chunk_length instead of a single k. The entries are bound to
/// the kai_*_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme functions in get_indirect_gemm_variants().
struct RhsPackIndirectKernel {
    /// N step of the packing kernel.
    std::function<size_t()> get_n_step;

    /// Offsets into the unpacked RHS, bias and scale inputs for column n_idx.
    std::function<size_t(size_t n_idx)> get_rhs_offset;
    std::function<size_t(size_t n_idx)> get_bias_offset;
    std::function<size_t(size_t n_idx)> get_scale_offset;

    /// Offset into / total size of the packed RHS buffer.
    std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_packed_rhs_offset;
    std::function<size_t(size_t n, size_t k_chunk_count, size_t k_chunk_length)> get_packed_rhs_size;

    /// Runs the packing.
    std::function<void(
        size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride,  //
        const void* rhs, const void* bias, const void* scale, void* rhs_packed,    //
        const struct kai_rhs_pack_qsi8cx_params* params)>
        pack;
};
110
111
9/18
✓ Branch 0 taken 12528 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12528 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12528 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 12528 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 12528 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12528 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 12528 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 12528 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 12528 times.
✗ Branch 17 not taken.
12528 struct MatMulKernel {
112 std::function<size_t(void)> get_m_step;
113 std::function<size_t(void)> get_n_step;
114 std::function<size_t(void)> get_mr;
115 std::function<size_t(void)> get_nr;
116 std::function<size_t(void)> get_kr;
117 std::function<size_t(void)> get_sr;
118 std::function<size_t(size_t m_idx, size_t k)> get_packed_lhs_offset;
119 std::function<size_t(size_t n_idx, size_t k)> get_packed_rhs_offset;
120 std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride)> get_dst_offset;
121 std::function<size_t(size_t m, size_t n)> get_dst_size;
122 std::function<void(
123 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
124 size_t dst_stride_col, const kai_matmul_requantize32_params* params)>
125 matmul;
126 };
127
128
5/10
✓ Branch 0 taken 49986 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 49986 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 49986 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 49986 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 49986 times.
✗ Branch 9 not taken.
49986 struct MatMulIndirectKernel {
129 std::function<size_t(void)> get_m_step;
130 std::function<size_t(void)> get_n_step;
131 std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_offset;
132 std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_offset;
133 std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride_row)> get_dst_offset;
134 std::function<size_t(size_t m, size_t n)> get_dst_size;
135 std::function<void(
136 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_lenght, const void* lhs_packed, const void* rhs_packed,
137 void* dst, size_t dst_stride_row, const kai_matmul_requantize32_params* params)>
138 imatmul;
139 };
140
141 /// Make sure that interface matches for qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa
142 const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel&
143 3 get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface() {
144 static kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel ukernel;
145
146 3 ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
147 3 ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
148 3 ukernel.get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
149 3 ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
150 3 ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
151 3 ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
152 3 ukernel.get_lhs_packed_offset =
153 kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
154 3 ukernel.get_rhs_packed_offset =
155 kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
156 3 ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
157 3 ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
158 3 ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
159
160 3 return ukernel;
161 }
162
163 /// Make sure that interface matches for qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme_mopa
164 const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel&
165 3 get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface() {
166 static kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel ukernel;
167
168 3 ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
169 3 ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
170 3 ukernel.get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
171 3 ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
172 3 ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
173 3 ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
174 3 ukernel.get_lhs_packed_offset =
175 kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
176 3 ukernel.get_rhs_packed_offset =
177 kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
178 3 ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
179 3 ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
180 3 ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
181
182 3 return ukernel;
183 }
184
185 /// Make sure that interface matches for qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot
186 const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel&
187 3 get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface() {
188 static kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel;
189
190 3 ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
191 3 ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
192 3 ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
193 3 ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
194 3 ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
195 3 ukernel.get_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
196 3 ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
197 3 ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
198 3 ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
199 3 ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
200
201 3 return ukernel;
202 };
203
204 /// Make sure that interface matches qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa
205 const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel&
206 3 get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface() {
207 static kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel;
208
209 3 ukernel.get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
210 3 ukernel.get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
211 3 ukernel.get_lhs_packed_offset =
212 kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
213 3 ukernel.get_rhs_packed_offset =
214 kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
215 3 ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
216 3 ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
217 3 ukernel.run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
218
219 3 return ukernel;
220 };
221
222 /// Make sure that interface matches qai8_qai8p2vlx4_qsi8cxps2vlx4b_2vlx2vl_sme_mopa
223 const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel&
224 3 get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface() {
225 static kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel;
226
227 3 ukernel.get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
228 3 ukernel.get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
229 3 ukernel.get_lhs_packed_offset =
230 kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
231 3 ukernel.get_rhs_packed_offset =
232 kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
233 3 ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
234 3 ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
235 3 ukernel.run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
236
237 3 return ukernel;
238 };
239
240 9 const RhsPackKernel& get_rhs_pack() {
241
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
9 static RhsPackKernel ukernel;
242
243 9 ukernel.get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
244 9 ukernel.get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
245 9 ukernel.get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
246 9 ukernel.get_scale_offset = kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
247 9 ukernel.get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
248 9 ukernel.get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
249 9 ukernel.pack = kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
250
251 9 return ukernel;
252 }
253
254 6 const LhsPackKernel& get_lhs_pack() {
255
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
6 static LhsPackKernel ukernel;
256
257 6 ukernel.get_m_step = kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme;
258 6 ukernel.get_lhs_offset = kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme;
259 6 ukernel.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme;
260 6 ukernel.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme;
261 6 ukernel.pack = kai_run_lhs_pack_x8p2vlx4_x8_sme;
262
263 6 return ukernel;
264 }
265
266
2/4
✓ Branch 0 taken 12528 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12528 times.
✗ Branch 3 not taken.
12528 struct MatMulVariant {
267 std::string_view name; ///< Test identification
268 MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr)
269 MatMulShape acc_step; ///< Accumulator shape for matmul (stepping)
270
271 std::function<bool(void)> is_supported; ///< HW support check
272
273 std::optional<LhsPackKernel> lhs_pack; ///< LHS packing micro-kernel interface
274 RhsPackKernel rhs_pack; ///< RHS packing micro-kernel interface
275 MatMulKernel matmul; ///< Matmul kernel interface
276 };
277
278
2/4
✓ Branch 0 taken 49986 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 49986 times.
✗ Branch 3 not taken.
49986 struct IndirectMatMulVariant {
279 std::string_view name; ///< Test identification
280 MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr)
281 MatMulShape acc_step; ///< Accumulator shape for matmul (stepping)
282
283 std::function<bool(void)> is_supported; ///< HW support check
284
285 LhsPackIndirectKernel lhs_pack; ///< LHS packing micro-kernel interface
286 RhsPackIndirectKernel rhs_pack; ///< RHS packing micro-kernel interface
287 MatMulIndirectKernel matmul; ///< Matmul kernel interface
288 };
289
290 3 const auto& get_gemm_variants() {
291
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
3 static std::array<MatMulVariant, 2> variants;
292
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
6 static const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& ukernel_sme2 =
293
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface();
294
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
6 static const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& ukernel_sme =
295
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
3 get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface();
296
297 3 variants[0].name = "matmul_qai8_qai8p_qsi8cxp_sme";
298 3 variants[0].acc_pack.m = 2 * get_sme_vector_length<int32_t>();
299 3 variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>();
300 3 variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t);
301 3 variants[0].acc_step.m = 2 * get_sme_vector_length<int32_t>();
302 3 variants[0].acc_step.n = 2 * get_sme_vector_length<int32_t>();
303 3 variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t);
304 3 variants[0].is_supported = cpu_has_sme;
305 3 variants[0].lhs_pack = get_lhs_pack();
306 3 variants[0].rhs_pack = get_rhs_pack();
307 3 variants[0].matmul.get_m_step = ukernel_sme.get_m_step;
308 3 variants[0].matmul.get_n_step = ukernel_sme.get_n_step;
309 3 variants[0].matmul.get_mr = ukernel_sme.get_mr;
310 3 variants[0].matmul.get_nr = ukernel_sme.get_nr;
311 3 variants[0].matmul.get_kr = ukernel_sme.get_kr;
312 3 variants[0].matmul.get_sr = ukernel_sme.get_sr;
313 3 variants[0].matmul.get_packed_lhs_offset = ukernel_sme.get_lhs_packed_offset;
314 3 variants[0].matmul.get_packed_rhs_offset = ukernel_sme.get_rhs_packed_offset;
315 3 variants[0].matmul.get_dst_offset = ukernel_sme.get_dst_offset;
316 3 variants[0].matmul.get_dst_size = ukernel_sme.get_dst_size;
317 3 variants[0].matmul.matmul = ukernel_sme.run_matmul;
318
319 3 variants[1].name = "matmul_qai8_qai8p_qsi8cxp_sme2";
320 3 variants[1].acc_pack.m = 2 * get_sme_vector_length<int32_t>();
321 3 variants[1].acc_pack.n = 2 * get_sme_vector_length<int32_t>();
322 3 variants[1].acc_pack.k = sizeof(int32_t) / sizeof(int8_t);
323 3 variants[1].acc_step.m = 2 * get_sme_vector_length<int32_t>();
324 3 variants[1].acc_step.n = 2 * get_sme_vector_length<int32_t>();
325 3 variants[1].acc_step.k = sizeof(int32_t) / sizeof(int8_t);
326 3 variants[1].is_supported = cpu_has_sme2;
327 3 variants[1].lhs_pack = get_lhs_pack();
328 3 variants[1].rhs_pack = get_rhs_pack();
329 3 variants[1].matmul.get_m_step = ukernel_sme2.get_m_step;
330 3 variants[1].matmul.get_n_step = ukernel_sme2.get_n_step;
331 3 variants[1].matmul.get_mr = ukernel_sme2.get_mr;
332 3 variants[1].matmul.get_nr = ukernel_sme2.get_nr;
333 3 variants[1].matmul.get_kr = ukernel_sme2.get_kr;
334 3 variants[1].matmul.get_sr = ukernel_sme2.get_sr;
335 3 variants[1].matmul.get_packed_lhs_offset = ukernel_sme2.get_lhs_packed_offset;
336 3 variants[1].matmul.get_packed_rhs_offset = ukernel_sme2.get_rhs_packed_offset;
337 3 variants[1].matmul.get_dst_offset = ukernel_sme2.get_dst_offset;
338 3 variants[1].matmul.get_dst_size = ukernel_sme2.get_dst_size;
339 3 variants[1].matmul.matmul = ukernel_sme2.run_matmul;
340
341 3 return variants;
342 }
343
344 9 const auto& get_indirect_gemm_variants() {
345
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
9 static std::array<IndirectMatMulVariant, 2> variants;
346
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
12 static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel_sme =
347
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface();
348
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
12 static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel_sme2 =
349
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
3 get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface();
350
351 9 variants[0].name = "imatmul_qai8_qai8p_qsi8cxp_sme";
352 9 variants[0].acc_pack.m = 2 * get_sme_vector_length<int32_t>();
353 9 variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>();
354 9 variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t);
355 9 variants[0].acc_step.m = 2 * get_sme_vector_length<int32_t>();
356 9 variants[0].acc_step.n = 2 * get_sme_vector_length<int32_t>();
357 9 variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t);
358 9 variants[0].is_supported = cpu_has_sme;
359 9 variants[0].lhs_pack.get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
360 9 variants[0].lhs_pack.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
361 9 variants[0].lhs_pack.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
362 9 variants[0].lhs_pack.pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
363 9 variants[0].rhs_pack.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
364 9 variants[0].rhs_pack.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
365 9 variants[0].rhs_pack.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
366 9 variants[0].rhs_pack.get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
367 9 variants[0].rhs_pack.get_packed_rhs_offset =
368 kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
369 9 variants[0].rhs_pack.get_packed_rhs_size =
370 kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
371 9 variants[0].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
372 9 variants[0].matmul.get_m_step = ukernel_sme.get_m_step;
373 9 variants[0].matmul.get_n_step = ukernel_sme.get_n_step;
374 9 variants[0].matmul.get_lhs_packed_offset = ukernel_sme.get_lhs_packed_offset;
375 9 variants[0].matmul.get_rhs_packed_offset = ukernel_sme.get_rhs_packed_offset;
376 9 variants[0].matmul.get_dst_offset = ukernel_sme.get_dst_offset;
377 9 variants[0].matmul.get_dst_size = ukernel_sme.get_dst_size;
378 9 variants[0].matmul.imatmul = ukernel_sme.run_imatmul;
379
380 9 variants[1].name = "imatmul_qai8_qai8p_qsi8cxp_sme2";
381 9 variants[1].acc_pack.m = 2 * get_sme_vector_length<int32_t>();
382 9 variants[1].acc_pack.n = 2 * get_sme_vector_length<int32_t>();
383 9 variants[1].acc_pack.k = sizeof(int32_t) / sizeof(int8_t);
384 9 variants[1].acc_step.m = 2 * get_sme_vector_length<int32_t>();
385 9 variants[1].acc_step.n = 2 * get_sme_vector_length<int32_t>();
386 9 variants[1].acc_step.k = sizeof(int32_t) / sizeof(int8_t);
387 9 variants[1].is_supported = cpu_has_sme2;
388 9 variants[1].lhs_pack.get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
389 9 variants[1].lhs_pack.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
390 9 variants[1].lhs_pack.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
391 9 variants[1].lhs_pack.pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
392 9 variants[1].rhs_pack.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
393 9 variants[1].rhs_pack.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
394 9 variants[1].rhs_pack.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
395 9 variants[1].rhs_pack.get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
396 9 variants[1].rhs_pack.get_packed_rhs_offset =
397 kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
398 9 variants[1].rhs_pack.get_packed_rhs_size =
399 kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
400 9 variants[1].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
401 9 variants[1].matmul.get_m_step = ukernel_sme2.get_m_step;
402 9 variants[1].matmul.get_n_step = ukernel_sme2.get_n_step;
403 9 variants[1].matmul.get_lhs_packed_offset = ukernel_sme2.get_lhs_packed_offset;
404 9 variants[1].matmul.get_rhs_packed_offset = ukernel_sme2.get_rhs_packed_offset;
405 9 variants[1].matmul.get_dst_offset = ukernel_sme2.get_dst_offset;
406 9 variants[1].matmul.get_dst_size = ukernel_sme2.get_dst_size;
407 9 variants[1].matmul.imatmul = ukernel_sme2.run_imatmul;
408
409 9 return variants;
410 }
411
412 3 const auto& get_gemv_variants() {
413
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
3 static std::array<MatMulVariant, 1> variants;
414
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
6 static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel =
415
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface();
416
417 3 variants[0].name = "matmul_qai8_qai8_qsi8cxp";
418 3 variants[0].acc_pack.m = 1;
419 3 variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>();
420 3 variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t);
421 3 variants[0].acc_step.m = 1;
422 3 variants[0].acc_step.n = 16 * get_sme_vector_length<int32_t>();
423 3 variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t);
424 3 variants[0].is_supported = cpu_has_sme2;
425 3 variants[0].lhs_pack = std::nullopt;
426 3 variants[0].rhs_pack = get_rhs_pack();
427 3 variants[0].matmul.get_m_step = ukernel.get_m_step;
428 3 variants[0].matmul.get_n_step = ukernel.get_n_step;
429 171 variants[0].matmul.get_mr = []() -> size_t { return 1; };
430 3 variants[0].matmul.get_nr = ukernel.get_nr;
431 3 variants[0].matmul.get_kr = ukernel.get_kr;
432 3 variants[0].matmul.get_sr = ukernel.get_sr;
433 3 variants[0].matmul.get_packed_lhs_offset = nullptr;
434 3 variants[0].matmul.get_packed_rhs_offset = ukernel.get_rhs_packed_offset;
435 3 variants[0].matmul.get_dst_offset = ukernel.get_dst_offset;
436 3 variants[0].matmul.get_dst_size = ukernel.get_dst_size;
437 3 variants[0].matmul.matmul = ukernel.run_matmul;
438
439 3 return variants;
440 }
441
/// Quantization parameters
///
/// Members are value-initialized so that a default-constructed Quant is {0.0f, 0} rather than
/// indeterminate. The struct remains an aggregate, so `Quant{scale, zero_point}` keeps working.
struct Quant {
    float scale{};
    int32_t zero_point{};
};
447
448 /// Reference test data
449 struct TestReference {
450 Range<int8_t> clamp;
451
452 Quant qa_lhs;
453 Quant qa_dst;
454
455 Buffer lhs_qai8;
456 Buffer lhs_qai8_scales;
457 Buffer lhs_qai8_zero_points;
458 Buffer lhs_qai8_indirect;
459 Buffer lhs_qai8_indirect_packed;
460 Buffer lhs_qai8_indirect_padding;
461 size_t lhs_qai8_indirect_offset;
462
463 Buffer rhs_qsi8;
464 Buffer rhs_scales;
465
466 Buffer bias_qsi32;
467
468 Buffer dst_qsi8_clamped;
469
470 Buffer packed_lhs;
471 Buffer packed_rhs;
472 };
473
/// Padding byte value (presumably used for the indirect-LHS padding data when pad_testing is set — confirm at use sites).
constexpr int8_t padding_value{0};
475
476 // Functionality for hashing generated test data.
477 // This is particularly useful for portion testing
478 // which reuses the exact same data for all portions
479 struct TestDataId {
480 MatMulShape shape;
481 MatMulShape shape_pack;
482 size_t chunk_len;
483 bool pad_testing;
484 float clamp_keep_ratio;
485
486 struct Hash {
487 4685 size_t operator()(const TestDataId& id) const {
488 4685 return //
489 9370 (MatMulShape::Hash{}(id.shape) << 0) ^ //
490 9370 (MatMulShape::Hash{}(id.shape_pack) << 1) ^ //
491 9370 (std::hash<size_t>{}(id.chunk_len) << 2) ^ //
492 9370 (std::hash<bool>{}(id.pad_testing) << 3) ^ //
493 4685 (std::hash<float>{}(id.clamp_keep_ratio) << 4);
494 }
495 };
496
497 private:
498 4489 friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) {
499 4489 return //
500
2/2
✓ Branch 0 taken 3659 times.
✓ Branch 1 taken 830 times.
4489 lhs.shape == rhs.shape && //
501
1/2
✓ Branch 0 taken 3659 times.
✗ Branch 1 not taken.
3659 lhs.shape_pack == rhs.shape_pack && //
502
2/2
✓ Branch 0 taken 3643 times.
✓ Branch 1 taken 16 times.
3659 lhs.chunk_len == rhs.chunk_len && //
503
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 3643 times.
3643 lhs.pad_testing == rhs.pad_testing && //
504 3643 lhs.clamp_keep_ratio == rhs.clamp_keep_ratio;
505 }
506 };
507
508 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
509 3 std::unordered_map<TestDataId, TestReference, TestDataId::Hash> g_data;
510 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
511
512 /// Generate test reference data
513 4164 const TestReference& get_test_reference(const TestDataId& test_data_id) {
514 // ============================================================
515 // Generates input and reference output data
516 // ============================================================
517
518 // Attempt to find test data in cache
519 4164 const auto data_it = g_data.find(test_data_id);
520
2/2
✓ Branch 0 taken 3643 times.
✓ Branch 1 taken 521 times.
4164 if (data_it != g_data.end()) {
521 3643 return data_it->second;
522 }
523
524 5210 const auto& [shape, pack_shape, k_chunk_len, pad_testing, clamp_keep_ratio] = test_data_id;
525
526 // Seed the random generator.
527
8/18
✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 521 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 521 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 521 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 521 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 521 times.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
1563 const auto key = std::string("Qai8Qai8Qsi8_cache:") + std::to_string(shape.m) + "x" + std::to_string(shape.n) +
528
7/14
✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 521 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 521 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 521 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 521 times.
✗ Branch 13 not taken.
1042 "x" + std::to_string(shape.k) + ":" + std::to_string(clamp_keep_ratio);
529
1/2
✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
521 auto& feed = seed_stream(key);
530
531 // Generates the input data in floating-point.
532
4/8
✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 521 times.
✗ Branch 7 not taken.
1563 Buffer lhs_f32 = fill_random<float>(shape.m * shape.k, feed());
533
4/8
✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 521 times.
✗ Branch 7 not taken.
1563 const Buffer rhs_f32 = fill_random<float>(shape.k * shape.n, feed());
534
3/6
✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
1042 const Buffer bias_f32 = fill_random<float>(shape.n, feed());
535
536 // Quantizes the input data.
537 // * LHS: 8-bit asymmetric per-matrix quantization.
538 // * RHS: 8-bit symmetric per-channel quantization.
539 // * Bias: 32-bit symmetric per-channel quantization.
540
541
1/2
✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
521 QuantizationInfo lhs_qinfo{};
542 lhs_qinfo.quant_width = shape.m * shape.k;
543 lhs_qinfo.dst_type = DataType::QAI8;
544 lhs_qinfo.scale_type = DataType::FP32;
545 lhs_qinfo.zero_point_type = DataType::I32;
546 auto [lhs_ref_quant, lhs_qoutputs] =
547 quantize_dynamic(lhs_f32.data(), DataType::FP32, 1, shape.m * shape.k, lhs_qinfo);
548 const auto lhs_scale = read_array<float>(lhs_qoutputs.scales.data(), 0);
549 const auto lhs_zero_point = read_array<int32_t>(lhs_qoutputs.zero_points.data(), 0);
550
551 const size_t k_chunk_count = shape.k / k_chunk_len;
552 assert(k_chunk_count * k_chunk_len == shape.k);
553
554 // Setup an indirection buffer, where each "row" contains `k_chunk_count`
555 // pointers to chunks of length `k_chunk_len` in the input_buffer
556 Buffer lhs_qai8_indirect(shape.m * k_chunk_count * sizeof(void*));
557 Buffer lhs_padding(k_chunk_len, padding_value);
558 auto* lhs_qai8_indirect_ptr = reinterpret_cast<uint8_t**>(lhs_qai8_indirect.data());
559 for (size_t m_i = 0; m_i < shape.m; ++m_i) {
560 for (size_t k_chunk_idx = 0; k_chunk_idx < k_chunk_count; ++k_chunk_idx) {
561 const size_t idx = m_i * k_chunk_count + k_chunk_idx;
562 if (pad_testing and m_i == 0) {
563 // Push padding pointers for first row
564 lhs_qai8_indirect_ptr[idx] = reinterpret_cast<uint8_t*>(lhs_padding.data());
565 } else {
566 uintptr_t offset = m_i * shape.k + k_chunk_idx * k_chunk_len;
567 lhs_qai8_indirect_ptr[idx] = reinterpret_cast<uint8_t*>(offset);
568 }
569 }
570 }
571 const auto indirection_base = reinterpret_cast<uintptr_t>(lhs_ref_quant.data());
572
573 // Reorder indirection pointers to layout the packing micro-kernel expects
574 Buffer lhs_qai8_indirect_packed = reorder_block<const void*>(
575 reinterpret_cast<const void*>(lhs_qai8_indirect.data()), shape.m, k_chunk_count, pack_shape.m, 1);
576
577 // Transpose, then quantize symmetrically, then transpose back. This will give one
578 // quantization value for each column
579 const auto rhs_f32_t = transpose<float>(rhs_f32.data(), shape.k, shape.n);
580
581 QuantizationInfo rhs_qinfo{};
582 rhs_qinfo.quant_width = shape.k;
583 rhs_qinfo.dst_type = DataType::QSI8;
584 rhs_qinfo.scale_type = DataType::FP32;
585 auto [rhs_ref_quant_t, rhs_qoutputs] =
586 quantize_dynamic(rhs_f32_t.data(), DataType::FP32, shape.n, shape.k, rhs_qinfo);
587 auto rhs_qsi8 = transpose<int8_t>(rhs_ref_quant_t.data(), shape.n, shape.k);
588
589 // Multiply all bias values with the LHS scale
590 const auto bias_scales = mul<float>(&lhs_scale, 1, 1, rhs_qoutputs.scales.data(), 1, shape.n);
591 // Calculate quantized bias values, by treating bias as column, and
592 // scale using RHS scales. This will scale each bias value indiviually
593 auto bias_qsi32 =
594 quantize_symmetric_per_block<float, int32_t, float>(bias_f32.data(), bias_scales.data(), shape.n, 1, 1);
595
596 // Runs the reference implementation of matmul to produce floating-point result.
597 const void* const* lhs_iptr = reinterpret_cast<const void* const*>(lhs_qai8_indirect.data());
598 const auto ref_dst_f32 =
599 indirect_matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>(
600 shape.m, shape.n, k_chunk_count, k_chunk_len, // matmul shape
601 lhs_iptr, indirection_base, lhs_padding.data(), // LHS indirection, offset and padding
602 &lhs_scale, &lhs_zero_point, // LHS, scaling factor and zero point
603 shape.m, shape.k, // LHS quantization window shape
604 rhs_ref_quant_t.data(), rhs_qoutputs.scales.data(), nullptr, // RHS scaling factors
605 1, shape.k, // RHS quantization window shape
606 bias_qsi32.data(), bias_scales.data(), nullptr, // Bias, scaling and zero points
607 1 // Bias quantization window shape
608 );
609
610 // Computes the output quantization information and clamping limits.
611 //
612 // To get a realistic value for the output quantization information and clamping limits
613 // and avoid uncontrolled saturation problem, these information will be calculated
614 // based on the reference floating-point output.
615 //
616 // The clamping limits will be slightly narrower than the actual range of the output
617 // so that a portion of the output will be clampped.
618 const auto [dst_scales, dst_zero_points] =
619 compute_asymmetric_per_block_quantization_info<float, int8_t, float, int32_t>(
620 ref_dst_f32.data(), 1, shape.m * shape.n, shape.m * shape.n);
621 const auto dst_scale = read_array<float>(dst_scales.data(), 0);
622 const auto dst_zero_point = read_array<int32_t>(dst_zero_points.data(), 0);
623
624 const auto ref_dst_f32_min = reduce_min<float>(ref_dst_f32.data(), shape.m * shape.n);
625 const auto ref_dst_f32_max = reduce_max<float>(ref_dst_f32.data(), shape.m * shape.n);
626 const auto ref_dst_f32_range = ref_dst_f32_max - ref_dst_f32_min;
627
628 const auto ref_dst_f32_clamp_min = ref_dst_f32_min + ref_dst_f32_range * (1.0F - clamp_keep_ratio) / 2;
629 const auto ref_dst_f32_clamp_max = ref_dst_f32_max - ref_dst_f32_range * (1.0F - clamp_keep_ratio) / 2;
630 const auto dst_qai8_clamp_min =
631 quantize_asymmetric<float, int8_t, int32_t>(ref_dst_f32_clamp_min, dst_scale, dst_zero_point);
632 const auto dst_qai8_clamp_max =
633 quantize_asymmetric<float, int8_t, int32_t>(ref_dst_f32_clamp_max, dst_scale, dst_zero_point);
634
635 // Clamps and quantizes the reference output matrix.
636 const auto ref_dst_f32_clamped =
637 clamp<float>(ref_dst_f32.data(), shape.m * shape.n, ref_dst_f32_clamp_min, ref_dst_f32_clamp_max);
638 auto ref_dst_qsi8_clamped = quantize_asymmetric_per_block<float, int8_t, float, int32_t>(
639 ref_dst_f32_clamped.data(), &dst_scale, &dst_zero_point, // values, scales, zero point
640 1, shape.m * shape.n, // data shape
641 shape.m * shape.n // quantization window width
642 );
643
644 // Runs the reference implementation of the packing micro-kernels.
645 //
646 // The reference packing micro-kernels cannot be executed earlier
647 // because we need the reference floating-point output first to have
648 // the quantization information.
649 auto packed_lhs = reorder_block<int8_t>(lhs_ref_quant.data(), shape.m, shape.k, pack_shape.m, pack_shape.k);
650 auto packed_rhs = matmul_pack_rhs_nxk_static_quantized<int8_t, float, int32_t>(
651 rhs_ref_quant_t.data(), rhs_qoutputs.scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point,
652 shape.n, shape.k, pack_shape.n, pack_shape.k);
653
654 TestReference& reference = g_data[test_data_id];
655 reference.clamp.min = dst_qai8_clamp_min;
656 reference.clamp.max = dst_qai8_clamp_max;
657 reference.qa_lhs.scale = lhs_scale;
658 reference.qa_lhs.zero_point = lhs_zero_point;
659 reference.qa_dst.scale = dst_scale;
660 reference.qa_dst.zero_point = dst_zero_point;
661 reference.lhs_qai8 = std::move(lhs_ref_quant);
662 reference.lhs_qai8_scales = std::move(lhs_qoutputs.scales);
663 reference.lhs_qai8_zero_points = std::move(lhs_qoutputs.zero_points);
664 reference.lhs_qai8_indirect = std::move(lhs_qai8_indirect);
665 reference.lhs_qai8_indirect_packed = std::move(lhs_qai8_indirect_packed);
666 reference.lhs_qai8_indirect_padding = std::move(lhs_padding);
667 reference.lhs_qai8_indirect_offset = indirection_base;
668 reference.rhs_qsi8 = std::move(rhs_qsi8);
669 reference.rhs_scales = std::move(rhs_qoutputs.scales);
670 reference.bias_qsi32 = std::move(bias_qsi32);
671 reference.dst_qsi8_clamped = std::move(ref_dst_qsi8_clamped);
672 reference.packed_lhs = std::move(packed_lhs);
673 reference.packed_rhs = std::move(packed_rhs);
674
675 return reference;
676 }
677
678 /// Test LHS packing
679 666 void test_lhs_pack(
680 const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) {
681 KAI_ASSUME_ALWAYS(variant.lhs_pack.has_value());
682
683 1332 const auto imp_packed_lhs_size =
684 666 variant.lhs_pack->get_packed_lhs_size(shape.m, shape.k, variant.acc_pack.m, variant.acc_pack.k, 1);
685
2/12
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 666 times.
666 ASSERT_EQ(imp_packed_lhs_size, reference.packed_lhs.size());
686
687 666 Buffer imp_packed_lhs(imp_packed_lhs_size, 0);
688
2/4
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 666 times.
✗ Branch 3 not taken.
666 const auto imp_lhs_offset = variant.lhs_pack->get_lhs_offset(output_area.start_row(), shape.k * sizeof(int8_t));
689
1/2
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
666 const auto imp_packed_lhs_offset = variant.lhs_pack->get_packed_lhs_offset(
690
1/2
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
666 output_area.start_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1);
691
692
1/2
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
666 abi_check(
693
1/2
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
666 variant.lhs_pack->pack, output_area.height(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1, 0,
694 666 reference.lhs_qai8.data() + imp_lhs_offset, shape.k * sizeof(int8_t),
695 666 imp_packed_lhs.data() + imp_packed_lhs_offset);
696
697
3/4
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 42 times.
✓ Branch 3 taken 624 times.
708 const auto imp_packed_lhs_end_offset = output_area.end_row() < shape.m
698
1/2
✓ Branch 0 taken 42 times.
✗ Branch 1 not taken.
42 ? variant.lhs_pack->get_packed_lhs_offset(
699
1/2
✓ Branch 0 taken 42 times.
✗ Branch 1 not taken.
42 output_area.end_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1)
700 624 : imp_packed_lhs_size;
701
702 666 const auto* imp_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs.data());
703 666 const auto* ref_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(reference.packed_lhs.data());
704
705
4/6
✓ Branch 0 taken 680346 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 679680 times.
✓ Branch 3 taken 666 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 666 times.
680346 for (size_t i = 0; i < reference.packed_lhs.size(); ++i) {
706
4/4
✓ Branch 0 taken 651264 times.
✓ Branch 1 taken 28416 times.
✓ Branch 2 taken 42240 times.
✓ Branch 3 taken 609024 times.
679680 if (i >= imp_packed_lhs_offset && i < imp_packed_lhs_end_offset) {
707
3/14
✗ Branch 0 not taken.
✓ Branch 1 taken 609024 times.
✓ Branch 2 taken 609024 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 609024 times.
609024 ASSERT_EQ(imp_packed_lhs_ptr[i], ref_packed_lhs_ptr[i]);
708 609024 } else {
709
3/14
✗ Branch 0 not taken.
✓ Branch 1 taken 70656 times.
✓ Branch 2 taken 70656 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 70656 times.
70656 ASSERT_EQ(imp_packed_lhs_ptr[i], 0);
710 }
711 679680 }
712 666 }
713
714 /// Test RHS packing
715 834 void test_rhs_pack(
716 const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) {
717 834 const auto imp_packed_rhs_size = variant.rhs_pack.get_packed_rhs_size(shape.n, shape.k);
718
2/12
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 834 times.
834 ASSERT_EQ(imp_packed_rhs_size, reference.packed_rhs.size());
719 834 Buffer imp_packed_rhs(imp_packed_rhs_size, 0);
720
721
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 const auto imp_rhs_offset = variant.rhs_pack.get_rhs_offset(output_area.start_col());
722
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 const auto imp_bias_offset = variant.rhs_pack.get_bias_offset(output_area.start_col());
723
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 const auto imp_scale_offset = variant.rhs_pack.get_scale_offset(output_area.start_col());
724
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 const auto imp_packed_rhs_offset = variant.rhs_pack.get_packed_rhs_offset(output_area.start_col(), shape.k);
725
726 834 kai_rhs_pack_qsi8cx_params imp_pack_rhs_params{};
727 834 imp_pack_rhs_params.lhs_zero_point = reference.qa_lhs.zero_point;
728 834 imp_pack_rhs_params.scale_multiplier = reference.qa_lhs.scale / reference.qa_dst.scale;
729
730
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
834 abi_check(
731
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
834 variant.rhs_pack.pack, 1, output_area.width(), shape.k, variant.acc_pack.n, variant.acc_pack.k, 1,
732 834 shape.n * sizeof(int8_t), reference.rhs_qsi8.data() + imp_rhs_offset,
733 834 reference.bias_qsi32.data() + imp_bias_offset, reference.rhs_scales.data() + imp_scale_offset,
734 834 imp_packed_rhs.data() + imp_packed_rhs_offset, 0, &imp_pack_rhs_params);
735
736
3/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 150 times.
✓ Branch 3 taken 684 times.
984 const auto imp_packed_rhs_end_offset = output_area.end_col() < shape.n
737
2/4
✓ Branch 0 taken 150 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 150 times.
✗ Branch 3 not taken.
150 ? variant.rhs_pack.get_packed_rhs_offset(output_area.end_col(), shape.k)
738 684 : imp_packed_rhs_size;
739
740 834 size_t mismatches = 0;
741 834 const auto* imp_packed_rhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_rhs.data());
742 834 const auto* ref_packed_rhs_ptr = reinterpret_cast<const uint8_t*>(reference.packed_rhs.data());
743
744
2/2
✓ Branch 0 taken 3380736 times.
✓ Branch 1 taken 834 times.
3381570 for (size_t i = 0; i < reference.packed_rhs.size(); ++i) {
745
4/4
✓ Branch 0 taken 2843136 times.
✓ Branch 1 taken 537600 times.
✓ Branch 2 taken 690816 times.
✓ Branch 3 taken 2152320 times.
3380736 if (i >= imp_packed_rhs_offset && i < imp_packed_rhs_end_offset) {
746
1/2
✓ Branch 0 taken 2152320 times.
✗ Branch 1 not taken.
2152320 if (imp_packed_rhs_ptr[i] != ref_packed_rhs_ptr[i]) {
747 mismatches += 1;
748 }
749 2152320 } else {
750
1/2
✓ Branch 0 taken 1228416 times.
✗ Branch 1 not taken.
1228416 if (imp_packed_rhs_ptr[i] != 0) {
751 mismatches += 1;
752 }
753 }
754 3380736 }
755
3/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
834 ASSERT_EQ(mismatches, 0) << "There are an unexpected amount of mismatches in RHS packing";
756 834 }
757
758 4164 void compare_matmul_result(
759 const MatMulShape& shape, const Rect& output_area, const Buffer& actual, const Buffer& reference) {
760 4164 size_t mismatches = 0;
761 4164 bool printed_row = false;
762 4164 std::ostringstream sstream;
763
2/2
✓ Branch 0 taken 94884 times.
✓ Branch 1 taken 4164 times.
99048 for (size_t m_i = 0; m_i < shape.m; ++m_i) {
764
2/2
✓ Branch 0 taken 7690212 times.
✓ Branch 1 taken 94884 times.
7785096 for (size_t n_i = 0; n_i < shape.n; ++n_i) {
765 7690212 const auto i = m_i * shape.n + n_i;
766
6/8
✓ Branch 0 taken 7690212 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6881252 times.
✓ Branch 3 taken 808960 times.
✓ Branch 4 taken 6881252 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6015836 times.
✓ Branch 7 taken 865416 times.
13706048 const auto in_area = m_i >= output_area.start_row() && m_i < output_area.end_row() &&
767
4/6
✓ Branch 0 taken 6015836 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5388508 times.
✓ Branch 3 taken 627328 times.
✓ Branch 4 taken 5388508 times.
✗ Branch 5 not taken.
6015836 n_i >= output_area.start_col() && n_i < output_area.end_col();
768
769
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 7690212 times.
7690212 const auto imp_value = read_array<int8_t>(actual.data(), i);
770
3/4
✓ Branch 0 taken 4807837 times.
✓ Branch 1 taken 2882375 times.
✓ Branch 2 taken 4807837 times.
✗ Branch 3 not taken.
7690212 const auto ref_value = in_area ? read_array<int8_t>(reference.data(), i) : 0;
771 7690212 const auto error = std::abs(imp_value - ref_value);
772 7690212 const auto threshold = in_area ? 1 : 0;
773 7690212 const bool mismatch = error > threshold;
774
1/2
✓ Branch 0 taken 7690212 times.
✗ Branch 1 not taken.
7690212 if (mismatch) {
775 if (not printed_row) {
776 sstream << " row=" << m_i << ", columns: ";
777 printed_row = true;
778 }
779 sstream << n_i << ", ";
780 }
781 7690212 mismatches += static_cast<size_t>(mismatch);
782 7690212 }
783
1/2
✓ Branch 0 taken 94884 times.
✗ Branch 1 not taken.
94884 if (printed_row) {
784 sstream << "\n";
785 }
786 94884 printed_row = false;
787 94884 }
788
3/20
✓ Branch 0 taken 4164 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4164 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✓ Branch 19 taken 4164 times.
4164 ASSERT_EQ(mismatches, 0) << "Mismatches between reference result and actual result:\n" << sstream.str();
789 4164 }
790
791 /// Test MatMul of GEMM/GEMV like kernel
792 834 void test_matmul(
793 const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) {
794 834 const auto imp_dst_size = variant.matmul.get_dst_size(shape.m, shape.n);
795
2/12
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 834 times.
834 ASSERT_EQ(imp_dst_size, reference.dst_qsi8_clamped.size());
796
797 834 Buffer imp_dst(imp_dst_size, 0);
798
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
3336 const auto [imp_lhs_offset, lhs_data] = [&]() -> std::tuple<size_t, const Buffer&> {
799
2/2
✓ Branch 0 taken 666 times.
✓ Branch 1 taken 168 times.
834 if (variant.lhs_pack.has_value()) {
800 666 return {variant.matmul.get_packed_lhs_offset(output_area.start_row(), shape.k), reference.packed_lhs};
801 }
802 168 return {output_area.start_row() * shape.k, reference.lhs_qai8};
803 834 }();
804
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 const size_t imp_packed_rhs_offset = variant.matmul.get_packed_rhs_offset(output_area.start_col(), shape.k);
805 1668 const size_t imp_dst_offset =
806
3/6
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
834 variant.matmul.get_dst_offset(output_area.start_row(), output_area.start_col(), shape.n * sizeof(int8_t));
807
5/18
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 834 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 834 times.
834 ASSERT_EQ(imp_dst_offset, output_area.start_row() * shape.n + output_area.start_col());
808
809 834 kai_matmul_requantize32_params imp_main_params{};
810 834 imp_main_params.min_value = reference.clamp.min;
811 834 imp_main_params.max_value = reference.clamp.max;
812 834 imp_main_params.output_zero_point = reference.qa_dst.zero_point;
813
814
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
834 abi_check(
815
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 variant.matmul.matmul, output_area.height(), output_area.width(), shape.k, lhs_data.data() + imp_lhs_offset,
816 834 reference.packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t),
817 834 sizeof(int8_t), &imp_main_params);
818
819
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
834 compare_matmul_result(shape, output_area, imp_dst, reference.dst_qsi8_clamped);
820 834 }
821
822 } // namespace
823
824 using MatMulQuantizedTest = testing::TestWithParam<std::tuple<MatMulVariant, MatMulShape, MatrixPortion, float>>;
825 using IndirectMatMulQuantizedTestParams = std::tuple<IndirectMatMulVariant, MatMulShape, size_t, MatrixPortion, float>;
826 using IndirectMatMulQuantizedTest = testing::TestWithParam<IndirectMatMulQuantizedTestParams>;
827
828 2502 static std::string test_description(
829 const MatMulVariant& variant, //
830 const MatMulShape& shape, //
831 const MatrixPortion& portion, float clamp_keep_ratio) {
832 2502 std::ostringstream sstream;
833
834
2/4
✓ Branch 0 taken 2502 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2502 times.
✗ Branch 3 not taken.
2502 sstream << test_description(variant.name, shape, portion, true, clamp_keep_ratio);
835
836
1/2
✓ Branch 0 taken 2502 times.
✗ Branch 1 not taken.
2502 return sstream.str();
837 2502 };
838
839 19980 [[maybe_unused]] static void PrintTo(const IndirectMatMulQuantizedTestParams& param, std::ostream* os) {
840 119880 const auto& [variant, shape, k_chunk_length, portion, clamp_keep_ratio] = param;
841
842 39960 *os << variant.name << "__";
843 19980 PrintTo(shape, os);
844 39960 *os << "__K_chunk_length_" << k_chunk_length;
845 39960 *os << "__clamp_keep_ratio_" << static_cast<int>(clamp_keep_ratio * 100) << "__";
846 19980 PrintTo(portion, os);
847 19980 };
848
849
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
4176 TEST_P(MatMulQuantizedTest, EndToEnd) {
850 14178 const auto& [variant, shape, output_portion, clamp_keep_ratio] = GetParam();
851
852
2/2
✓ Branch 0 taken 834 times.
✓ Branch 1 taken 834 times.
1668 if (!variant.is_supported()) {
853
3/6
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
834 GTEST_SKIP() << "Unsupported CPU feature";
854 }
855
856 4170 TestDataId test_data_id{shape, variant.acc_pack, shape.k, false, clamp_keep_ratio};
857 834 const TestReference& reference = get_test_reference(test_data_id);
858
859 // Check scheduling parameters
860 1668 const auto imp_mr = variant.matmul.get_mr();
861 1668 const auto imp_nr = variant.matmul.get_nr();
862 1668 const auto imp_kr = variant.matmul.get_kr();
863 1668 const auto imp_sr = variant.matmul.get_sr();
864
865
4/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
1668 ASSERT_EQ(imp_mr, variant.acc_pack.m);
866
4/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
1668 ASSERT_EQ(imp_nr, variant.acc_pack.n);
867
4/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
1668 ASSERT_EQ(imp_kr, variant.acc_pack.k);
868
3/14
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 834 times.
834 ASSERT_EQ(imp_sr, 1);
869
870 // Check that stepping is a multiple of accumulation
871 1668 const auto imp_m_step = variant.matmul.get_m_step();
872 1668 const auto imp_n_step = variant.matmul.get_n_step();
873
4/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
1668 ASSERT_EQ(imp_m_step, variant.acc_step.m);
874
4/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
1668 ASSERT_EQ(imp_n_step, variant.acc_step.n);
875
876 // Test kernels. Note that packing and actual stepping might not be the same
877 4170 const auto pack_portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_pack.m, variant.acc_pack.n);
878 834 const auto matmul_portion =
879 3336 output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n);
880
2/2
✓ Branch 0 taken 168 times.
✓ Branch 1 taken 666 times.
834 if (variant.lhs_pack.has_value()) {
881 666 test_lhs_pack(shape, variant, pack_portion, reference);
882 666 }
883 834 test_rhs_pack(shape, variant, pack_portion, reference);
884 834 test_matmul(shape, variant, matmul_portion, reference);
885 1668 }
886
887 namespace imatmul {
888
889 /// Perform LHS IMATMUL packing
890 3330 static Buffer lhs_pack(
891 const LhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t m,
892 const KChunk& k_chunk) {
893 6660 const void* const* indirection_pointer =
894 3330 reinterpret_cast<const void* const*>(reference.lhs_qai8_indirect_packed.data());
895
896 // Allocate buffer
897 3330 const size_t dst_size = variant.get_packed_lhs_size(m, k_chunk.count, k_chunk.length);
898 3330 Buffer packed(dst_size);
899
900 // Calculate offsets
901
1/2
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
3330 const size_t input_offset = portion.start_row() * k_chunk.count;
902
2/4
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
3330 const size_t dst_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length);
903
904
1/2
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
3330 abi_check(
905 3330 variant.pack, // Kernel
906
1/2
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
3330 portion.height(), k_chunk.count, k_chunk.length, // Dimensions
907 3330 indirection_pointer + input_offset, // Indirection input
908 3330 reference.lhs_qai8_indirect_offset, // chunk offset
909 3330 reference.lhs_qai8_indirect_padding.data(), // padding pointer
910 3330 packed.data() + dst_offset);
911
912 3330 return packed;
913 3330 }
914
915 /// Perform RHS IMATMUL packing
916 3330 static Buffer rhs_pack(
917 const RhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t n,
918 const KChunk& k_chunk) {
919 // Allocate output buffer
920 3330 const size_t dst_size = variant.get_packed_rhs_size(n, k_chunk.count, k_chunk.length);
921 3330 Buffer packed(dst_size);
922
923 // Caluclate effective quantization parameters
924 9990 const kai_rhs_pack_qsi8cx_params quantization{
925 3330 reference.qa_lhs.zero_point,
926 3330 reference.qa_lhs.scale / reference.qa_dst.scale,
927 };
928
929 // Calculate offsets
930
2/4
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
3330 const size_t rhs_offset = variant.get_rhs_offset(portion.start_col());
931
2/4
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
3330 const size_t bias_offset = variant.get_bias_offset(portion.start_col());
932
2/4
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
3330 const size_t scale_offset = variant.get_scale_offset(portion.start_col());
933
2/4
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
3330 const size_t dst_offset = variant.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length);
934
935 // Pack
936
1/2
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
3330 abi_check(
937 3330 variant.pack, // Kernel
938
1/2
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
3330 portion.width(), k_chunk.count, k_chunk.length, // Dimensions
939 3330 n * sizeof(uint8_t), // Row stride
940 3330 reference.rhs_qsi8.data() + rhs_offset, // RHS matrix
941 3330 reference.bias_qsi32.data() + bias_offset, // Bias
942 3330 reference.rhs_scales.data() + scale_offset, // Scales
943 3330 packed.data() + dst_offset, // Output
944 3330 &quantization);
945
946 3330 return packed;
947 3330 }
948
949 /// Calculate the matmul result from IMATMUL kernels
950 3330 static Buffer matmul(
951 const MatMulIndirectKernel& variant, const Rect& portion, const TestReference& reference, const Buffer& packed_lhs,
952 const Buffer& packed_rhs, const MatMulShape& shape, const KChunk& k_chunk) {
953 // Calculate portion offsets.
954 3330 size_t dst_offset = variant.get_dst_offset(portion.start_row(), portion.start_col(), shape.n);
955 3330 size_t lhs_offset = variant.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length);
956 3330 size_t rhs_offset = variant.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length);
957
958 // Allocate output buffer
959 3330 const size_t dst_size = variant.get_dst_size(shape.m, shape.n);
960 3330 Buffer dst(dst_size, 0);
961
962 // Calculate geffective uantization parameters
963 3330 kai_matmul_requantize32_params requantization{};
964 3330 requantization.min_value = reference.clamp.min;
965 3330 requantization.max_value = reference.clamp.max;
966 3330 requantization.output_zero_point = reference.qa_dst.zero_point;
967
968 // Call matmul kernel
969
1/2
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
3330 abi_check(
970 3330 variant.imatmul, // Kernel
971
2/4
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
3330 portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions
972 3330 packed_lhs.data() + lhs_offset, // LHS
973 3330 packed_rhs.data() + rhs_offset, // RHS
974 3330 dst.data() + dst_offset, // DST
975 3330 shape.n * sizeof(uint8_t), &requantization);
976
977 3330 return dst;
978 3330 }
979 } // namespace imatmul
980
981
8/16
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
16656 TEST_P(IndirectMatMulQuantizedTest, EndToEnd) {
982 /* This is a bit special, as shape.k must be k_chunk_len * k_chunk_count
983 * so instead of inventing a new special kind of shape, simply multiply
984 * with `k_chunk_len` here */
985 39960 const auto& [variant, shape_k_chunk, k_chunk_len, output_portion, clamp_keep_ratio] = GetParam();
986 19980 const KChunk k_chunk{shape_k_chunk.k, k_chunk_len};
987 19980 MatMulShape shape{shape_k_chunk.m, shape_k_chunk.n, k_chunk.count * k_chunk.length};
988
989
2/2
✓ Branch 0 taken 3330 times.
✓ Branch 1 taken 3330 times.
6660 if (!variant.is_supported()) {
990
3/6
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3330 times.
✗ Branch 5 not taken.
3330 GTEST_SKIP() << "Unsupported CPU feature";
991 }
992
993 // Toggle padding testst when LHS has more than one row
994 9990 TestDataId test_data_id{shape, variant.acc_pack, k_chunk.length, shape.m > 1, clamp_keep_ratio};
995 3330 const TestReference& reference = get_test_reference(test_data_id);
996 13320 const Rect portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n);
997
998 6660 Buffer packed_lhs = imatmul::lhs_pack(variant.lhs_pack, portion, reference, shape.m, k_chunk);
999
2/4
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
6660 Buffer packed_rhs = imatmul::rhs_pack(variant.rhs_pack, portion, reference, shape.n, k_chunk);
1000
2/4
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
6660 Buffer impl_result = imatmul::matmul(variant.matmul, portion, reference, packed_lhs, packed_rhs, shape, k_chunk);
1001
1/2
✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
3330 compare_matmul_result(shape, portion, impl_result, reference.dst_qsi8_clamped);
1002 6660 }
1003
1004 static constexpr std::array shapes{
1005 // clang-format off
1006 MatMulShape{ 1, 1, 1},
1007 MatMulShape{ 1, 16, 4},
1008 MatMulShape{ 1, 16, 16},
1009 MatMulShape{ 1, 17, 4},
1010 MatMulShape{ 1, 19, 24},
1011 MatMulShape{ 1, 32, 4},
1012 MatMulShape{ 1, 32, 32},
1013 MatMulShape{ 1, 33,200},
1014 MatMulShape{ 1, 49, 21},
1015 MatMulShape{ 1, 64, 4},
1016 MatMulShape{ 1, 65, 4},
1017 MatMulShape{ 1, 300, 10},
1018 MatMulShape{ 1, 512, 4},
1019 MatMulShape{ 1, 1523, 10},
1020 MatMulShape{ 2, 195, 50},
1021 MatMulShape{ 3, 6, 6},
1022 MatMulShape{ 3, 28, 25},
1023 MatMulShape{ 3, 184,177},
1024 MatMulShape{ 4, 16, 27},
1025 MatMulShape{ 5, 136, 23},
1026 MatMulShape{ 6, 18, 31},
1027 MatMulShape{ 6, 28, 1},
1028 MatMulShape{ 6, 29, 24},
1029 MatMulShape{ 16, 16, 4},
1030 MatMulShape{ 20, 30, 40},
1031 MatMulShape{ 23, 1, 43},
1032 MatMulShape{ 32, 14, 1},
1033 MatMulShape{ 32, 16, 27},
1034 MatMulShape{ 32, 32, 3},
1035 MatMulShape{ 32, 32, 4},
1036 MatMulShape{ 33, 29, 24},
1037 MatMulShape{ 64, 64, 3},
1038 MatMulShape{ 64, 64, 4},
1039 MatMulShape{ 96, 96, 3},
1040 MatMulShape{123, 85, 45},
1041 MatMulShape{128, 128, 3},
1042 MatMulShape{130, 130, 6},
1043 // clang-format on
1044 };
1045
1046 static constexpr std::array portions{
1047 // clang-format off
1048 // (Start row , start col , height , width)
1049 MatrixPortion( 0 , 0 , 1 , 1) , // Full matrix.
1050 MatrixPortion( 0 , 0 , 1 , 0.5) , // Left half
1051 MatrixPortion( 0 , 0 , 0.5 , 1) , // Upper half
1052 MatrixPortion( 0 , 0.5 , 1 , 0.5) , // Right half
1053 MatrixPortion( 0.5 , 0 , 0.5 , 1) , // Bottom half
1054 MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3) , // Center ninth
1055 // clang-format on
1056 };
1057
1058
18/56
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✓ Branch 20 taken 666 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✓ Branch 22 taken 1332 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
4002 INSTANTIATE_TEST_SUITE_P(
1059 matmul_clamp_qai8_qai8p_qsi8cxp, MatMulQuantizedTest,
1060 testing::Combine(
1061 testing::ValuesIn(get_gemm_variants()), //
1062 testing::ValuesIn(shapes), //
1063 testing::ValuesIn({
1064 // clang-format off
1065 MatrixPortion( 0, 0, 1, 1), // Full matrix.
1066 MatrixPortion( 0, 0, 0.25, 0.25), // Top-left corner.
1067 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
1068 // clang-format on
1069 }),
1070 testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})),
1071 [](const auto& info) -> std::string {
1072 return test_description(
1073 std::get<MatMulVariant>(info.param), //
1074 std::get<MatMulShape>(info.param), //
1075 std::get<MatrixPortion>(info.param), //
1076 std::get<float>(info.param));
1077 });
1078
1079
18/56
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✓ Branch 20 taken 168 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✓ Branch 22 taken 336 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
1014 INSTANTIATE_TEST_SUITE_P(
1080 matmul_clamp_qai8_qai8_qsi8cxp, MatMulQuantizedTest,
1081 testing::Combine(
1082 testing::ValuesIn(get_gemv_variants()),
1083 testing::ValuesIn({
1084 // clang-format off
1085 MatMulShape{ 1, 1, 1},
1086 MatMulShape{ 1, 16, 4},
1087 MatMulShape{ 1, 16, 16},
1088 MatMulShape{ 1, 17, 4},
1089 MatMulShape{ 1, 19, 24},
1090 MatMulShape{ 1, 32, 4},
1091 MatMulShape{ 1, 32, 32},
1092 MatMulShape{ 1, 33,200},
1093 MatMulShape{ 1, 49, 21},
1094 MatMulShape{ 1, 64, 4},
1095 MatMulShape{ 1, 65, 4},
1096 MatMulShape{ 1, 300, 10},
1097 MatMulShape{ 1, 512, 4},
1098 MatMulShape{ 1, 1523, 10},
1099 // clang-format on
1100 }),
1101 testing::ValuesIn({
1102 // clang-format off
1103 MatrixPortion(0, 0, 1, 1), // Full matrix.
1104 MatrixPortion(0, .5, 1, .5), // Right half
1105 MatrixPortion(0, 0, 1, .5), // Left half
1106 MatrixPortion(0, .25, 1, .5) // Middle half
1107 // clang-format on
1108 }),
1109 // Clamp range
1110 testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio
1111 [](const auto& info) -> std::string {
1112 return test_description(
1113 std::get<MatMulVariant>(info.param), //
1114 std::get<MatMulShape>(info.param), //
1115 std::get<MatrixPortion>(info.param), //
1116 std::get<float>(info.param));
1117 });
1118
1119
20/64
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 2664 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 5328 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
7998 INSTANTIATE_TEST_SUITE_P(
1120 ShapesSmallKC, IndirectMatMulQuantizedTest,
1121 testing::Combine(
1122 testing::ValuesIn(get_indirect_gemm_variants()), //
1123 testing::ValuesIn(shapes), //
1124 // k_chunk_len
1125 testing::ValuesIn(std::initializer_list<size_t>{1, 2, 3, 4, 8, 11}), //
1126 testing::ValuesIn(portions), //
1127 // Clamp range
1128 testing::Values(0.1F)),
1129 testing::PrintToStringParamName());
1130
1131
20/64
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 444 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 888 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
1338 INSTANTIATE_TEST_SUITE_P(
1132 ShapesKC32, IndirectMatMulQuantizedTest,
1133 testing::Combine(
1134 testing::ValuesIn(get_indirect_gemm_variants()), //
1135 testing::ValuesIn(shapes), //
1136 // k_chunk_len
1137 testing::ValuesIn(std::initializer_list<size_t>{32}), //
1138 testing::ValuesIn(portions), //
1139 // Clamp range
1140 testing::Values(0.1F)),
1141 testing::PrintToStringParamName());
1142
1143
22/72
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✓ Branch 24 taken 222 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✓ Branch 26 taken 444 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
672 INSTANTIATE_TEST_SUITE_P(
1144 Clamp, IndirectMatMulQuantizedTest,
1145 testing::Combine(
1146 testing::ValuesIn(get_indirect_gemm_variants()), //
1147 testing::ValuesIn(shapes), //
1148 // k_chunk_len
1149 testing::ValuesIn(std::initializer_list<size_t>{1}), //
1150 testing::Values(MatrixPortion(0, 0, 1, 1)), //
1151 // Clamp range
1152 testing::ValuesIn(std::initializer_list<float>{1.0f, 0.9f, 0.5f})), // clamp_keep_ratio
1153 testing::PrintToStringParamName());
1154
1155 } // namespace kai::test
1156