KleidiAI Coverage Report


Directory: ./
File: test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 96.9% 529 1 547
Functions: 97.1% 67 0 69
Branches: 41.7% 384 2 922

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <gtest/gtest.h>
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <cstdlib>
13 #include <functional>
14 #include <optional>
15 #include <sstream>
16 #include <string>
17 #include <string_view>
18 #include <tuple>
19 #include <unordered_map>
20
21 #include "kai/kai_common.h"
22 #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa.h"
23 #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h"
24 #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h"
25 #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h"
26 #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h"
27 #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa.h"
28 #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h"
29 #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_interface.h"
30 #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h"
31 #include "kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h"
32 #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h"
33 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h"
34 #include "test/common/buffer.hpp"
35 #include "test/common/cpu_info.hpp"
36 #include "test/common/matmul_test_common.hpp"
37 #include "test/common/matrix_portion.hpp"
38 #include "test/common/memory.hpp"
39 #include "test/common/rect.hpp"
40 #include "test/common/sme.hpp"
41 #include "test/reference/binary_elementwise.hpp"
42 #include "test/reference/clamp.hpp"
43 #include "test/reference/fill.hpp"
44 #include "test/reference/matmul.hpp"
45 #include "test/reference/matmul_pack.hpp"
46 #include "test/reference/quantize.hpp"
47 #include "test/reference/reduce.hpp"
48 #include "test/reference/reorder.hpp"
49 #include "test/reference/transpose.hpp"
50
51 namespace kai::test {
52
53 // Ensure static linkage for all functionality local to this test file
54 namespace {
55
56 struct KChunk {
57 size_t count;
58 size_t length;
59 };
60
61
3/6
✓ Branch 0 taken 3336 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3336 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3336 times.
✗ Branch 5 not taken.
3336 struct LhsPackKernel {
62 std::function<size_t(size_t mr)> get_m_step;
63 std::function<size_t(size_t m_idx, size_t lhs_stride)> get_lhs_offset;
64 std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)> get_packed_lhs_offset;
65 std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> get_packed_lhs_size;
66 std::function<void(
67 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
68 void* lhs_packed)>
69 pack;
70 };
71
72
2/4
✓ Branch 0 taken 46624 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 46624 times.
✗ Branch 3 not taken.
46624 struct LhsPackIndirectKernel {
73 std::function<size_t()> get_m_step;
74 std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_packed_lhs_offset;
75 std::function<size_t(size_t m, size_t k_chunk_count, size_t k_chunk_length)> get_packed_lhs_size;
76 std::function<void(
77 size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset,
78 const void* zero, void* packed_lhs)>
79 pack;
80 };
81
82
5/10
✓ Branch 0 taken 4176 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4176 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4176 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4176 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4176 times.
✗ Branch 9 not taken.
4176 struct RhsPackKernel {
83 std::function<size_t()> get_n_step;
84 std::function<size_t(size_t n_idx)> get_rhs_offset;
85 std::function<size_t(size_t n_idx)> get_bias_offset;
86 std::function<size_t(size_t n_idx)> get_scale_offset;
87 std::function<size_t(size_t n_idx, size_t k)> get_packed_rhs_offset;
88 std::function<size_t(size_t n, size_t k)> get_packed_rhs_size;
89 std::function<void(
90 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
91 const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
92 const struct kai_rhs_pack_qsi8cx_params* params)>
93 pack;
94 };
95
96
5/10
✓ Branch 0 taken 46624 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 46624 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 46624 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 46624 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 46624 times.
✗ Branch 9 not taken.
46624 struct RhsPackIndirectKernel {
97 std::function<size_t()> get_n_step;
98 std::function<size_t(size_t n_idx)> get_rhs_offset;
99 std::function<size_t(size_t n_idx)> get_bias_offset;
100 std::function<size_t(size_t n_idx)> get_scale_offset;
101 std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_packed_rhs_offset;
102 std::function<size_t(size_t n, size_t k_chunk_count, size_t k_chunk_length)> get_packed_rhs_size;
103 std::function<void(
104 size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride, const void* rhs, const void* bias,
105 const void* scale, void* rhs_packed, const kai_rhs_pack_qsi8cx_params* params)>
106 pack;
107 };
108
109
9/18
✓ Branch 0 taken 4176 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4176 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4176 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4176 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4176 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4176 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4176 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4176 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4176 times.
✗ Branch 17 not taken.
4176 struct MatMulKernel {
110 std::function<size_t(void)> get_m_step;
111 std::function<size_t(void)> get_n_step;
112 std::function<size_t(void)> get_mr;
113 std::function<size_t(void)> get_nr;
114 std::function<size_t(void)> get_kr;
115 std::function<size_t(void)> get_sr;
116 std::function<size_t(size_t m_idx, size_t k)> get_packed_lhs_offset;
117 std::function<size_t(size_t n_idx, size_t k)> get_packed_rhs_offset;
118 std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride)> get_dst_offset;
119 std::function<size_t(size_t m, size_t n)> get_dst_size;
120 std::function<void(
121 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
122 size_t dst_stride_col, const kai_matmul_requantize32_params* params)>
123 matmul;
124 };
125
126
5/10
✓ Branch 0 taken 46624 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 46624 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 46624 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 46624 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 46624 times.
✗ Branch 9 not taken.
46624 struct MatMulIndirectKernel {
127 std::function<size_t(void)> get_m_step;
128 std::function<size_t(void)> get_n_step;
129 std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_offset;
130 std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_offset;
131 std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride_row)> get_dst_offset;
132 std::function<size_t(size_t m, size_t n)> get_dst_size;
133 std::function<void(
134 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_lenght, const void* lhs_packed, const void* rhs_packed,
135 void* dst, size_t dst_stride_row, const kai_matmul_requantize32_params* params)>
136 imatmul;
137 };
138
139 /// Make sure that interface matches for qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa
140 const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel&
141 1 get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface() {
142 static kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel ukernel;
143
144 1 ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
145 1 ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
146 1 ukernel.get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
147 1 ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
148 1 ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
149 1 ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
150 1 ukernel.get_lhs_packed_offset =
151 kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
152 1 ukernel.get_rhs_packed_offset =
153 kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
154 1 ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
155 1 ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
156 1 ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
157
158 1 return ukernel;
159 }
160
161 /// Make sure that interface matches for qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme_mopa
162 const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel&
163 1 get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface() {
164 static kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel ukernel;
165
166 1 ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
167 1 ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
168 1 ukernel.get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
169 1 ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
170 1 ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
171 1 ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
172 1 ukernel.get_lhs_packed_offset =
173 kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
174 1 ukernel.get_rhs_packed_offset =
175 kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
176 1 ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
177 1 ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
178 1 ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
179
180 1 return ukernel;
181 }
182
183 /// Make sure that interface matches for qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot
184 const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel&
185 1 get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface() {
186 static kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel;
187
188 1 ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
189 1 ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
190 1 ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
191 1 ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
192 1 ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
193 1 ukernel.get_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
194 1 ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
195 1 ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
196 1 ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
197 1 ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot;
198
199 1 return ukernel;
200 };
201
202 /// Make sure that interface matches qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa
203 const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel&
204 1 get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface() {
205 static kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel;
206
207 1 ukernel.get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
208 1 ukernel.get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
209 1 ukernel.get_lhs_packed_offset =
210 kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
211 1 ukernel.get_rhs_packed_offset =
212 kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
213 1 ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
214 1 ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
215 1 ukernel.run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa;
216
217 1 return ukernel;
218 };
219
220 /// Make sure that interface matches qai8_qai8p2vlx4_qsi8cxps2vlx4b_2vlx2vl_sme_mopa
221 const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel&
222 1 get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface() {
223 static kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel;
224
225 1 ukernel.get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
226 1 ukernel.get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
227 1 ukernel.get_lhs_packed_offset =
228 kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
229 1 ukernel.get_rhs_packed_offset =
230 kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
231 1 ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
232 1 ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
233 1 ukernel.run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa;
234
235 1 return ukernel;
236 };
237
238 3 const RhsPackKernel& get_rhs_pack() {
239
3/4
✓ Branch 0 taken 1 time.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
3 static RhsPackKernel ukernel;
240
241 3 ukernel.get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
242 3 ukernel.get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
243 3 ukernel.get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
244 3 ukernel.get_scale_offset = kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
245 3 ukernel.get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
246 3 ukernel.get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
247 3 ukernel.pack = kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
248
249 3 return ukernel;
250 }
251
252 2 const LhsPackKernel& get_lhs_pack() {
253
3/4
✓ Branch 0 taken 1 time.
✓ Branch 1 taken 1 time.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
2 static LhsPackKernel ukernel;
254
255 2 ukernel.get_m_step = kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme;
256 2 ukernel.get_lhs_offset = kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme;
257 2 ukernel.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme;
258 2 ukernel.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme;
259 2 ukernel.pack = kai_run_lhs_pack_x8p2vlx4_x8_sme;
260
261 2 return ukernel;
262 }
263
264
2/4
✓ Branch 0 taken 4176 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4176 times.
✗ Branch 3 not taken.
4176 struct MatMulVariant {
265 std::string_view name; ///< Test identification
266 MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr)
267 MatMulShape acc_step; ///< Accumulator shape for matmul (stepping)
268
269 std::function<bool(void)> is_supported; ///< HW support check
270
271 std::optional<LhsPackKernel> lhs_pack; ///< LHS packing micro-kernel interface
272 RhsPackKernel rhs_pack; ///< RHS packing micro-kernel interface
273 MatMulKernel matmul; ///< Matmul kernel interface
274 };
275
276
2/4
✓ Branch 0 taken 46624 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 46624 times.
✗ Branch 3 not taken.
46624 struct IndirectMatMulVariant {
277 std::string_view name; ///< Test identification
278 MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr)
279 MatMulShape acc_step; ///< Accumulator shape for matmul (stepping)
280
281 std::function<bool(void)> is_supported; ///< HW support check
282
283 LhsPackIndirectKernel lhs_pack; ///< LHS packing micro-kernel interface
284 RhsPackIndirectKernel rhs_pack; ///< RHS packing micro-kernel interface
285 MatMulIndirectKernel matmul; ///< Matmul kernel interface
286 };
287
288 1 const std::array<MatMulVariant, 2>& get_gemm_variants() {
289
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
1 static std::array<MatMulVariant, 2> variants;
290
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
2 static const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& ukernel_sme2 =
291
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface();
292
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
2 static const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& ukernel_sme =
293
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 time.
1 get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface();
294
295 1 variants[0].name = "matmul_qai8_qai8p_qsi8cxp_sme";
296 1 variants[0].acc_pack.m = 2 * get_sme_vector_length<int32_t>();
297 1 variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>();
298 1 variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t);
299 1 variants[0].acc_step.m = 2 * get_sme_vector_length<int32_t>();
300 1 variants[0].acc_step.n = 2 * get_sme_vector_length<int32_t>();
301 1 variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t);
302 1 variants[0].is_supported = cpu_has_sme;
303 1 variants[0].lhs_pack = get_lhs_pack();
304 1 variants[0].rhs_pack = get_rhs_pack();
305 1 variants[0].matmul.get_m_step = ukernel_sme.get_m_step;
306 1 variants[0].matmul.get_n_step = ukernel_sme.get_n_step;
307 1 variants[0].matmul.get_mr = ukernel_sme.get_mr;
308 1 variants[0].matmul.get_nr = ukernel_sme.get_nr;
309 1 variants[0].matmul.get_kr = ukernel_sme.get_kr;
310 1 variants[0].matmul.get_sr = ukernel_sme.get_sr;
311 1 variants[0].matmul.get_packed_lhs_offset = ukernel_sme.get_lhs_packed_offset;
312 1 variants[0].matmul.get_packed_rhs_offset = ukernel_sme.get_rhs_packed_offset;
313 1 variants[0].matmul.get_dst_offset = ukernel_sme.get_dst_offset;
314 1 variants[0].matmul.get_dst_size = ukernel_sme.get_dst_size;
315 1 variants[0].matmul.matmul = ukernel_sme.run_matmul;
316
317 1 variants[1].name = "matmul_qai8_qai8p_qsi8cxp_sme2";
318 1 variants[1].acc_pack.m = 2 * get_sme_vector_length<int32_t>();
319 1 variants[1].acc_pack.n = 2 * get_sme_vector_length<int32_t>();
320 1 variants[1].acc_pack.k = sizeof(int32_t) / sizeof(int8_t);
321 1 variants[1].acc_step.m = 2 * get_sme_vector_length<int32_t>();
322 1 variants[1].acc_step.n = 2 * get_sme_vector_length<int32_t>();
323 1 variants[1].acc_step.k = sizeof(int32_t) / sizeof(int8_t);
324 1 variants[1].is_supported = cpu_has_sme2;
325 1 variants[1].lhs_pack = get_lhs_pack();
326 1 variants[1].rhs_pack = get_rhs_pack();
327 1 variants[1].matmul.get_m_step = ukernel_sme2.get_m_step;
328 1 variants[1].matmul.get_n_step = ukernel_sme2.get_n_step;
329 1 variants[1].matmul.get_mr = ukernel_sme2.get_mr;
330 1 variants[1].matmul.get_nr = ukernel_sme2.get_nr;
331 1 variants[1].matmul.get_kr = ukernel_sme2.get_kr;
332 1 variants[1].matmul.get_sr = ukernel_sme2.get_sr;
333 1 variants[1].matmul.get_packed_lhs_offset = ukernel_sme2.get_lhs_packed_offset;
334 1 variants[1].matmul.get_packed_rhs_offset = ukernel_sme2.get_rhs_packed_offset;
335 1 variants[1].matmul.get_dst_offset = ukernel_sme2.get_dst_offset;
336 1 variants[1].matmul.get_dst_size = ukernel_sme2.get_dst_size;
337 1 variants[1].matmul.matmul = ukernel_sme2.run_matmul;
338
339 1 return variants;
340 }
341
342 1 const std::array<IndirectMatMulVariant, 2>& get_indirect_gemm_variants() {
343
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
1 static std::array<IndirectMatMulVariant, 2> variants;
344
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
2 static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel_sme =
345
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface();
346
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
2 static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel_sme2 =
347
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 time.
1 get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface();
348
349 1 variants[0].name = "imatmul_qai8_qai8p_qsi8cxp_sme";
350 1 variants[0].acc_pack.m = 2 * get_sme_vector_length<int32_t>();
351 1 variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>();
352 1 variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t);
353 1 variants[0].acc_step.m = 2 * get_sme_vector_length<int32_t>();
354 1 variants[0].acc_step.n = 2 * get_sme_vector_length<int32_t>();
355 1 variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t);
356 1 variants[0].is_supported = cpu_has_sme;
357 1 variants[0].lhs_pack.get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
358 1 variants[0].lhs_pack.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
359 1 variants[0].lhs_pack.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
360 1 variants[0].lhs_pack.pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
361 1 variants[0].rhs_pack.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
362 1 variants[0].rhs_pack.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
363 1 variants[0].rhs_pack.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
364 1 variants[0].rhs_pack.get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
365 1 variants[0].rhs_pack.get_packed_rhs_offset =
366 kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
367 1 variants[0].rhs_pack.get_packed_rhs_size =
368 kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
369 1 variants[0].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
370 1 variants[0].matmul.get_m_step = ukernel_sme.get_m_step;
371 1 variants[0].matmul.get_n_step = ukernel_sme.get_n_step;
372 1 variants[0].matmul.get_lhs_packed_offset = ukernel_sme.get_lhs_packed_offset;
373 1 variants[0].matmul.get_rhs_packed_offset = ukernel_sme.get_rhs_packed_offset;
374 1 variants[0].matmul.get_dst_offset = ukernel_sme.get_dst_offset;
375 1 variants[0].matmul.get_dst_size = ukernel_sme.get_dst_size;
376 1 variants[0].matmul.imatmul = ukernel_sme.run_imatmul;
377
378 1 variants[1].name = "imatmul_qai8_qai8p_qsi8cxp_sme2";
379 1 variants[1].acc_pack.m = 2 * get_sme_vector_length<int32_t>();
380 1 variants[1].acc_pack.n = 2 * get_sme_vector_length<int32_t>();
381 1 variants[1].acc_pack.k = sizeof(int32_t) / sizeof(int8_t);
382 1 variants[1].acc_step.m = 2 * get_sme_vector_length<int32_t>();
383 1 variants[1].acc_step.n = 2 * get_sme_vector_length<int32_t>();
384 1 variants[1].acc_step.k = sizeof(int32_t) / sizeof(int8_t);
385 1 variants[1].is_supported = cpu_has_sme2;
386 1 variants[1].lhs_pack.get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
387 1 variants[1].lhs_pack.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
388 1 variants[1].lhs_pack.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
389 1 variants[1].lhs_pack.pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme;
390 1 variants[1].rhs_pack.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
391 1 variants[1].rhs_pack.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
392 1 variants[1].rhs_pack.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
393 1 variants[1].rhs_pack.get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
394 1 variants[1].rhs_pack.get_packed_rhs_offset =
395 kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
396 1 variants[1].rhs_pack.get_packed_rhs_size =
397 kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
398 1 variants[1].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme;
399 1 variants[1].matmul.get_m_step = ukernel_sme2.get_m_step;
400 1 variants[1].matmul.get_n_step = ukernel_sme2.get_n_step;
401 1 variants[1].matmul.get_lhs_packed_offset = ukernel_sme2.get_lhs_packed_offset;
402 1 variants[1].matmul.get_rhs_packed_offset = ukernel_sme2.get_rhs_packed_offset;
403 1 variants[1].matmul.get_dst_offset = ukernel_sme2.get_dst_offset;
404 1 variants[1].matmul.get_dst_size = ukernel_sme2.get_dst_size;
405 1 variants[1].matmul.imatmul = ukernel_sme2.run_imatmul;
406
407 1 return variants;
408 }
409
410 1 const std::array<MatMulVariant, 1>& get_gemv_variants() {
411
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
1 static std::array<MatMulVariant, 1> variants;
412
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
2 static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel =
413
1/2
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
1 get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface();
414
415 1 variants[0].name = "matmul_qai8_qai8_qsi8cxp";
416 1 variants[0].acc_pack.m = 1;
417 1 variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>();
418 1 variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t);
419 1 variants[0].acc_step.m = 1;
420 1 variants[0].acc_step.n = 16 * get_sme_vector_length<int32_t>();
421 1 variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t);
422 1 variants[0].is_supported = cpu_has_sme2;
423 1 variants[0].lhs_pack = std::nullopt;
424 1 variants[0].rhs_pack = get_rhs_pack();
425 1 variants[0].matmul.get_m_step = ukernel.get_m_step;
426 1 variants[0].matmul.get_n_step = ukernel.get_n_step;
427 169 variants[0].matmul.get_mr = []() -> size_t { return 1; };
428 1 variants[0].matmul.get_nr = ukernel.get_nr;
429 1 variants[0].matmul.get_kr = ukernel.get_kr;
430 1 variants[0].matmul.get_sr = ukernel.get_sr;
431 1 variants[0].matmul.get_packed_lhs_offset = nullptr;
432 1 variants[0].matmul.get_packed_rhs_offset = ukernel.get_rhs_packed_offset;
433 1 variants[0].matmul.get_dst_offset = ukernel.get_dst_offset;
434 1 variants[0].matmul.get_dst_size = ukernel.get_dst_size;
435 1 variants[0].matmul.matmul = ukernel.run_matmul;
436
437 1 return variants;
438 }
439
440 constexpr uint32_t seed = 0; ///< Random seed used for tests
441
442 /// Quantization parameters
443 struct Quant {
444 float scale;
445 int32_t zero_point;
446 };
447
448 /// Reference test data
449 struct TestReference {
450 Range<int8_t> clamp;
451
452 Quant qa_lhs;
453 Quant qa_dst;
454
455 Buffer lhs_qai8;
456 Buffer lhs_qai8_scales;
457 Buffer lhs_qai8_zero_points;
458 Buffer lhs_qai8_indirect;
459 Buffer lhs_qai8_indirect_packed;
460 Buffer lhs_qai8_indirect_padding;
461 size_t lhs_qai8_indirect_offset;
462
463 Buffer rhs_qsi8;
464 Buffer rhs_scales;
465
466 Buffer bias_qsi32;
467
468 Buffer dst_qsi8_clamped;
469
470 Buffer packed_lhs;
471 Buffer packed_rhs;
472 };
473
474 constexpr int8_t padding_value = 0;
475
476 // Functionality for hashing generated test data.
477 // This is particularly useful for portion testing
478 // which reuses the exact same data for all portions
479 struct TestDataId {
480 MatMulShape shape;
481 MatMulShape shape_pack;
482 size_t chunk_len;
483 bool pad_testing;
484 float clamp_ratio;
485
486 struct Hash {
487 11085 size_t operator()(const TestDataId& id) const {
488 11085 return //
489 22170 (MatMulShape::Hash{}(id.shape) << 0) ^ //
490 22170 (MatMulShape::Hash{}(id.shape_pack) << 1) ^ //
491 22170 (std::hash<size_t>{}(id.chunk_len) << 2) ^ //
492 22170 (std::hash<bool>{}(id.pad_testing) << 3) ^ //
493 11085 (std::hash<float>{}(id.clamp_ratio) << 4);
494 }
495 };
496
497 private:
498 13342 friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) {
499 13342 return //
500
2/2
✓ Branch 0 taken 9267 times.
✓ Branch 1 taken 4075 times.
13342 lhs.shape == rhs.shape && //
501
1/2
✓ Branch 0 taken 9267 times.
✗ Branch 1 not taken.
9267 lhs.shape_pack == rhs.shape_pack && //
502
2/2
✓ Branch 0 taken 9231 times.
✓ Branch 1 taken 36 times.
9267 lhs.chunk_len == rhs.chunk_len && //
503
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 9231 times.
9231 lhs.pad_testing == rhs.pad_testing && //
504 9231 lhs.clamp_ratio == rhs.clamp_ratio;
505 }
506 };
507
508 // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
509 1 std::unordered_map<TestDataId, TestReference, TestDataId::Hash> g_data;
510 // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
511
512 /// Generate test reference data
513 10158 const TestReference& get_test_reference(const TestDataId& test_data_id) {
514 // ============================================================
515 // Generates input and reference output data
516 // ============================================================
517
518 // Attempt to find test data in cache
519 10158 const auto data_it = g_data.find(test_data_id);
520
2/2
✓ Branch 0 taken 9231 times.
✓ Branch 1 taken 927 times.
10158 if (data_it != g_data.end()) {
521 9231 return data_it->second;
522 }
523
524 839526 const auto& [shape, pack_shape, k_chunk_len, pad_testing, clamp_ratio] = test_data_id;
525
526 // Generates the input data in floating-point.
527 2781 Buffer lhs_f32 = fill_random<float>(shape.m * shape.k, seed);
528
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
2781 const Buffer rhs_f32 = fill_random<float>(shape.k * shape.n, seed);
529
2/4
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
1854 const Buffer bias_f32 = fill_random<float>(shape.n, seed);
530
531 // Quantizes the input data.
532 // * LHS: 8-bit asymmetric per-matrix quantization.
533 // * RHS: 8-bit symmetric per-channel quantization.
534 // * Bias: 32-bit symmetric per-channel quantization.
535 //
536 // Treat entire LHS as one row vector, to calculate one single pair of <scale, zero_point>
537 8343 auto [lhs_qai8, lhs_qai8_scales, lhs_qai8_zero_points] =
538
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(
539
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 lhs_f32.data(), 1, shape.m * shape.k, shape.m * shape.k);
540
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
1854 const auto lhs_scale = read_array<float>(lhs_qai8_scales.data(), 0);
541
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
1854 const auto lhs_zero_point = read_array<int32_t>(lhs_qai8_zero_points.data(), 0);
542
543 2781 const size_t k_chunk_count = shape.k / k_chunk_len;
544 assert(k_chunk_count * k_chunk_len == shape.k);
545
546 // Setup an indirection buffer, where each "row" contains `k_chunk_count`
547 // pointers to chunks of length `k_chunk_len` in the input_buffer
548
2/4
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
1854 Buffer lhs_qai8_indirect(shape.m * k_chunk_count * sizeof(void*));
549
2/4
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
1854 Buffer lhs_padding(k_chunk_len, padding_value);
550
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 auto* lhs_qai8_indirect_ptr = reinterpret_cast<uint8_t**>(lhs_qai8_indirect.data());
551
4/4
✓ Branch 0 taken 21087 times.
✓ Branch 1 taken 927 times.
✓ Branch 2 taken 21087 times.
✓ Branch 3 taken 927 times.
22014 for (size_t m_i = 0; m_i < shape.m; ++m_i) {
552
2/2
✓ Branch 0 taken 271911 times.
✓ Branch 1 taken 21087 times.
292998 for (size_t k_chunk_idx = 0; k_chunk_idx < k_chunk_count; ++k_chunk_idx) {
553 271911 const size_t idx = m_i * k_chunk_count + k_chunk_idx;
554
4/4
✓ Branch 0 taken 262143 times.
✓ Branch 1 taken 9768 times.
✓ Branch 2 taken 250089 times.
✓ Branch 3 taken 12054 times.
271911 if (pad_testing and m_i == 0) {
555 // Push padding pointers for first row
556
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 12054 times.
12054 lhs_qai8_indirect_ptr[idx] = reinterpret_cast<uint8_t*>(lhs_padding.data());
557 12054 } else {
558 779571 uintptr_t offset = m_i * shape.k + k_chunk_idx * k_chunk_len;
559 259857 lhs_qai8_indirect_ptr[idx] = reinterpret_cast<uint8_t*>(offset);
560 259857 }
561 271911 }
562 21087 }
563
2/4
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
1854 const auto indirection_base = reinterpret_cast<uintptr_t>(lhs_qai8.data());
564
565 // Reorder indirection pointers to layout the packing micro-kernel expects
566
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 Buffer lhs_qai8_indirect_packed = reorder_block<const void*>(
567
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 reinterpret_cast<const void*>(lhs_qai8_indirect.data()), shape.m, k_chunk_count, pack_shape.m, 1);
568
569 // Transpose, then quantize symmetrically, then transpose back. This will give one
570 // quantization value for each column
571
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
927 const auto rhs_f32_t = transpose<float>(rhs_f32.data(), shape.k, shape.n);
572 6489 auto [rhs_qsi8_t, rhs_scales] =
573
4/8
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 927 times.
✗ Branch 7 not taken.
927 quantize_symmetric_per_block_dynamic<float, int8_t, float>(rhs_f32_t.data(), shape.n, shape.k, shape.k);
574
4/8
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 927 times.
✗ Branch 7 not taken.
1854 auto rhs_qsi8 = transpose<int8_t>(rhs_qsi8_t.data(), shape.n, shape.k);
575
576 // Multiply all bias values with the LHS scale
577
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
1854 const auto bias_scales = mul<float>(&lhs_scale, 1, 1, rhs_scales.data(), 1, shape.n);
578 // Calculate quantized bias values, by treating bias as column, and
579 // scale using RHS scales. This will scale each bias value indiviually
580 927 auto bias_qsi32 =
581
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
927 quantize_symmetric_per_block<float, int32_t, float>(bias_f32.data(), bias_scales.data(), shape.n, 1, 1);
582
583 // Runs the reference implementation of matmul to produce floating-point result.
584
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 const void* const* lhs_iptr = reinterpret_cast<const void* const*>(lhs_qai8_indirect.data());
585 927 const auto ref_dst_f32 =
586
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 indirect_matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>(
587 2781 shape.m, shape.n, k_chunk_count, k_chunk_len, // matmul shape
588
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 lhs_iptr, indirection_base, lhs_padding.data(), // LHS indirection, offset and padding
589 &lhs_scale, &lhs_zero_point, // LHS, scaling factor and zero point
590 1854 shape.m, shape.k, // LHS quantization window shape
591
2/4
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
927 rhs_qsi8_t.data(), rhs_scales.data(), nullptr, // RHS scaling factors
592 927 1, shape.k, // RHS quantization window shape
593
2/4
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
927 bias_qsi32.data(), bias_scales.data(), nullptr, // Bias, scaling and zero points
594 1 // Bias quantization window shape
595 );
596
597 // Computes the output quantization information and clamping limits.
598 //
599 // To get a realistic value for the output quantization information and clamping limits
600 // and avoid uncontrolled saturation problem, these information will be calculated
601 // based on the reference floating-point output.
602 //
603 // The clamping limits will be slightly narrower than the actual range of the output
604 // so that a portion of the output will be clampped.
605 3708 const auto [dst_scales, dst_zero_points] =
606
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 compute_asymmetric_per_block_quantization_info<float, int8_t, float, int32_t>(
607
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 ref_dst_f32.data(), 1, shape.m * shape.n, shape.m * shape.n);
608
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
1854 const auto dst_scale = read_array<float>(dst_scales.data(), 0);
609
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
1854 const auto dst_zero_point = read_array<int32_t>(dst_zero_points.data(), 0);
610
611
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
927 const auto ref_dst_f32_min = reduce_min<float>(ref_dst_f32.data(), shape.m * shape.n);
612
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
927 const auto ref_dst_f32_max = reduce_max<float>(ref_dst_f32.data(), shape.m * shape.n);
613 927 const auto ref_dst_f32_range = ref_dst_f32_max - ref_dst_f32_min;
614
615 1854 const auto ref_dst_f32_clamp_min = ref_dst_f32_min + ref_dst_f32_range * clamp_ratio / 2;
616 1854 const auto ref_dst_f32_clamp_max = ref_dst_f32_max - ref_dst_f32_range * clamp_ratio / 2;
617 927 const auto dst_qai8_clamp_min =
618
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 quantize_asymmetric<float, int8_t, int32_t>(ref_dst_f32_clamp_min, dst_scale, dst_zero_point);
619 927 const auto dst_qai8_clamp_max =
620
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 quantize_asymmetric<float, int8_t, int32_t>(ref_dst_f32_clamp_max, dst_scale, dst_zero_point);
621
622 // Clamps and quantizes the reference output matrix.
623 927 const auto ref_dst_f32_clamped =
624
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
927 clamp<float>(ref_dst_f32.data(), shape.m * shape.n, ref_dst_f32_clamp_min, ref_dst_f32_clamp_max);
625
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 auto ref_dst_qsi8_clamped = quantize_asymmetric_per_block<float, int8_t, float, int32_t>(
626
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 ref_dst_f32_clamped.data(), &dst_scale, &dst_zero_point, // values, scales, zero point
627 1854 1, shape.m * shape.n, // data shape
628 1854 shape.m * shape.n // quantization window width
629 );
630
631 // Runs the reference implementation of the packing micro-kernels.
632 //
633 // The reference packing micro-kernels cannot be executed earlier
634 // because we need the reference floating-point output first to have
635 // the quantization information.
636
6/12
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 927 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 927 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 927 times.
✗ Branch 11 not taken.
1854 auto packed_lhs = reorder_block<int8_t>(lhs_qai8.data(), shape.m, shape.k, pack_shape.m, pack_shape.k);
637
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
1854 auto packed_rhs = matmul_pack_rhs_nxk_static_quantized<int8_t, float, int32_t>(
638
3/6
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
927 rhs_qsi8_t.data(), rhs_scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, shape.n, shape.k,
639 1854 pack_shape.n, pack_shape.k);
640
641
1/2
✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
927 TestReference& reference = g_data[test_data_id];
642 927 reference.clamp.min = dst_qai8_clamp_min;
643 927 reference.clamp.max = dst_qai8_clamp_max;
644 927 reference.qa_lhs.scale = lhs_scale;
645 927 reference.qa_lhs.zero_point = lhs_zero_point;
646 927 reference.qa_dst.scale = dst_scale;
647 927 reference.qa_dst.zero_point = dst_zero_point;
648 927 reference.lhs_qai8 = std::move(lhs_qai8);
649 927 reference.lhs_qai8_scales = std::move(lhs_qai8_scales);
650 927 reference.lhs_qai8_zero_points = std::move(lhs_qai8_zero_points);
651 927 reference.lhs_qai8_indirect = std::move(lhs_qai8_indirect);
652 927 reference.lhs_qai8_indirect_packed = std::move(lhs_qai8_indirect_packed);
653 927 reference.lhs_qai8_indirect_padding = std::move(lhs_padding);
654 927 reference.lhs_qai8_indirect_offset = indirection_base;
655 927 reference.rhs_qsi8 = std::move(rhs_qsi8);
656 927 reference.rhs_scales = std::move(rhs_scales);
657 927 reference.bias_qsi32 = std::move(bias_qsi32);
658 927 reference.dst_qsi8_clamped = std::move(ref_dst_qsi8_clamped);
659 927 reference.packed_lhs = std::move(packed_lhs);
660 927 reference.packed_rhs = std::move(packed_rhs);
661
662 927 return reference;
663 10158 }
664
665 /// Test LHS packing
666 666 void test_lhs_pack(
667 const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) {
668 KAI_ASSUME(variant.lhs_pack.has_value());
669
670 1332 const auto imp_packed_lhs_size =
671 666 variant.lhs_pack->get_packed_lhs_size(shape.m, shape.k, variant.acc_pack.m, variant.acc_pack.k, 1);
672
2/12
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 666 times.
666 ASSERT_EQ(imp_packed_lhs_size, reference.packed_lhs.size());
673
674 666 Buffer imp_packed_lhs(imp_packed_lhs_size, 0);
675
2/4
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 666 times.
✗ Branch 3 not taken.
666 const auto imp_lhs_offset = variant.lhs_pack->get_lhs_offset(output_area.start_row(), shape.k * sizeof(int8_t));
676
1/2
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
666 const auto imp_packed_lhs_offset = variant.lhs_pack->get_packed_lhs_offset(
677
1/2
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
666 output_area.start_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1);
678
679
1/2
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
666 variant.lhs_pack->pack(
680
1/2
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
666 output_area.height(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1, 0,
681 666 reference.lhs_qai8.data() + imp_lhs_offset, shape.k * sizeof(int8_t),
682 666 imp_packed_lhs.data() + imp_packed_lhs_offset);
683
684
3/4
✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 132 times.
✓ Branch 3 taken 534 times.
798 const auto imp_packed_lhs_end_offset = output_area.end_row() < shape.m
685
1/2
✓ Branch 0 taken 132 times.
✗ Branch 1 not taken.
132 ? variant.lhs_pack->get_packed_lhs_offset(
686
1/2
✓ Branch 0 taken 132 times.
✗ Branch 1 not taken.
132 output_area.end_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1)
687 534 : imp_packed_lhs_size;
688
689 666 const auto* imp_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs.data());
690 666 const auto* ref_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(reference.packed_lhs.data());
691
692
4/6
✓ Branch 0 taken 680346 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 679680 times.
✓ Branch 3 taken 666 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 666 times.
680346 for (size_t i = 0; i < reference.packed_lhs.size(); ++i) {
693
4/4
✓ Branch 0 taken 651264 times.
✓ Branch 1 taken 28416 times.
✓ Branch 2 taken 120576 times.
✓ Branch 3 taken 530688 times.
679680 if (i >= imp_packed_lhs_offset && i < imp_packed_lhs_end_offset) {
694
3/14
✗ Branch 0 not taken.
✓ Branch 1 taken 530688 times.
✓ Branch 2 taken 530688 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 530688 times.
530688 ASSERT_EQ(imp_packed_lhs_ptr[i], ref_packed_lhs_ptr[i]);
695 530688 } else {
696
3/14
✗ Branch 0 not taken.
✓ Branch 1 taken 148992 times.
✓ Branch 2 taken 148992 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 148992 times.
148992 ASSERT_EQ(imp_packed_lhs_ptr[i], 0);
697 }
698 679680 }
699 666 }
700
701 /// Test RHS packing
702 834 void test_rhs_pack(
703 const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) {
704 834 const auto imp_packed_rhs_size = variant.rhs_pack.get_packed_rhs_size(shape.n, shape.k);
705
2/12
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 834 times.
834 ASSERT_EQ(imp_packed_rhs_size, reference.packed_rhs.size());
706 834 Buffer imp_packed_rhs(imp_packed_rhs_size, 0);
707
708
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 const auto imp_rhs_offset = variant.rhs_pack.get_rhs_offset(output_area.start_col());
709
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 const auto imp_bias_offset = variant.rhs_pack.get_bias_offset(output_area.start_col());
710
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 const auto imp_scale_offset = variant.rhs_pack.get_scale_offset(output_area.start_col());
711
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 const auto imp_packed_rhs_offset = variant.rhs_pack.get_packed_rhs_offset(output_area.start_col(), shape.k);
712
713 834 kai_rhs_pack_qsi8cx_params imp_pack_rhs_params{};
714 834 imp_pack_rhs_params.lhs_zero_point = reference.qa_lhs.zero_point;
715 834 imp_pack_rhs_params.scale_multiplier = reference.qa_lhs.scale / reference.qa_dst.scale;
716
717
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
834 variant.rhs_pack.pack(
718
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
834 1, output_area.width(), shape.k, variant.acc_pack.n, variant.acc_pack.k, 1, shape.n * sizeof(int8_t),
719 834 reference.rhs_qsi8.data() + imp_rhs_offset, reference.bias_qsi32.data() + imp_bias_offset,
720 834 reference.rhs_scales.data() + imp_scale_offset, imp_packed_rhs.data() + imp_packed_rhs_offset, 0,
721 &imp_pack_rhs_params);
722
723
3/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 174 times.
✓ Branch 3 taken 660 times.
1008 const auto imp_packed_rhs_end_offset = output_area.end_col() < shape.n
724
2/4
✓ Branch 0 taken 174 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 174 times.
✗ Branch 3 not taken.
174 ? variant.rhs_pack.get_packed_rhs_offset(output_area.end_col(), shape.k)
725 660 : imp_packed_rhs_size;
726
727 834 size_t mismatches = 0;
728 834 const auto* imp_packed_rhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_rhs.data());
729 834 const auto* ref_packed_rhs_ptr = reinterpret_cast<const uint8_t*>(reference.packed_rhs.data());
730
731
2/2
✓ Branch 0 taken 3380736 times.
✓ Branch 1 taken 834 times.
3381570 for (size_t i = 0; i < reference.packed_rhs.size(); ++i) {
732
4/4
✓ Branch 0 taken 2832384 times.
✓ Branch 1 taken 548352 times.
✓ Branch 2 taken 713088 times.
✓ Branch 3 taken 2119296 times.
3380736 if (i >= imp_packed_rhs_offset && i < imp_packed_rhs_end_offset) {
733
1/2
✓ Branch 0 taken 2119296 times.
✗ Branch 1 not taken.
2119296 if (imp_packed_rhs_ptr[i] != ref_packed_rhs_ptr[i]) {
734 mismatches += 1;
735 }
736 2119296 } else {
737
1/2
✓ Branch 0 taken 1261440 times.
✗ Branch 1 not taken.
1261440 if (imp_packed_rhs_ptr[i] != 0) {
738 mismatches += 1;
739 }
740 }
741 3380736 }
742
3/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
834 ASSERT_EQ(mismatches, 0) << "There are an unexpected amount of mismatches in RHS packing";
743 834 }
744
745 10158 void compare_matmul_result(
746 const MatMulShape& shape, const Rect& output_area, const Buffer& actual, const Buffer& reference) {
747 10158 size_t mismatches = 0;
748 10158 bool printed_row = false;
749 10158 std::ostringstream sstream;
750
2/2
✓ Branch 0 taken 236958 times.
✓ Branch 1 taken 10158 times.
247116 for (size_t m_i = 0; m_i < shape.m; ++m_i) {
751
2/2
✓ Branch 0 taken 19177308 times.
✓ Branch 1 taken 236958 times.
19414266 for (size_t n_i = 0; n_i < shape.n; ++n_i) {
752 19177308 const auto i = m_i * shape.n + n_i;
753
6/8
✓ Branch 0 taken 19177308 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 17235804 times.
✓ Branch 3 taken 1941504 times.
✓ Branch 4 taken 17235804 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 14725092 times.
✓ Branch 7 taken 2510712 times.
33902400 const auto in_area = m_i >= output_area.start_row() && m_i < output_area.end_row() &&
754
4/6
✓ Branch 0 taken 14725092 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13058724 times.
✓ Branch 3 taken 1666368 times.
✓ Branch 4 taken 13058724 times.
✗ Branch 5 not taken.
14725092 n_i >= output_area.start_col() && n_i < output_area.end_col();
755
756
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 19177308 times.
19177308 const auto imp_value = read_array<int8_t>(actual.data(), i);
757
3/4
✓ Branch 0 taken 11567574 times.
✓ Branch 1 taken 7609734 times.
✓ Branch 2 taken 11567574 times.
✗ Branch 3 not taken.
19177308 const auto ref_value = in_area ? read_array<int8_t>(reference.data(), i) : 0;
758 19177308 const auto error = std::abs(imp_value - ref_value);
759 19177308 const auto threshold = in_area ? 1 : 0;
760 19177308 const bool mismatch = error > threshold;
761
1/2
✓ Branch 0 taken 19177308 times.
✗ Branch 1 not taken.
19177308 if (mismatch) {
762 if (not printed_row) {
763 sstream << " row=" << m_i << ", columns: ";
764 printed_row = true;
765 }
766 sstream << n_i << ", ";
767 }
768 19177308 mismatches += static_cast<size_t>(mismatch);
769 19177308 }
770
1/2
✓ Branch 0 taken 236958 times.
✗ Branch 1 not taken.
236958 if (printed_row) {
771 sstream << "\n";
772 }
773 236958 printed_row = false;
774 236958 }
775
3/20
✓ Branch 0 taken 10158 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 10158 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✓ Branch 19 taken 10158 times.
10158 ASSERT_EQ(mismatches, 0) << "Mismatches between reference result and actual result:\n" << sstream.str();
776 10158 }
777
778 /// Test MatMul of GEMM/GEMV like kernel
779 834 void test_matmul(
780 const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) {
781 834 const auto imp_dst_size = variant.matmul.get_dst_size(shape.m, shape.n);
782
2/12
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 834 times.
834 ASSERT_EQ(imp_dst_size, reference.dst_qsi8_clamped.size());
783
784 834 Buffer imp_dst(imp_dst_size, 0);
785
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
3336 const auto [imp_lhs_offset, lhs_data] = [&]() -> std::tuple<size_t, const Buffer&> {
786
2/2
✓ Branch 0 taken 666 times.
✓ Branch 1 taken 168 times.
834 if (variant.lhs_pack.has_value()) {
787 666 return {variant.matmul.get_packed_lhs_offset(output_area.start_row(), shape.k), reference.packed_lhs};
788 }
789 168 return {output_area.start_row() * shape.k, reference.lhs_qai8};
790 834 }();
791
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 const size_t imp_packed_rhs_offset = variant.matmul.get_packed_rhs_offset(output_area.start_col(), shape.k);
792 1668 const size_t imp_dst_offset =
793
3/6
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
834 variant.matmul.get_dst_offset(output_area.start_row(), output_area.start_col(), shape.n * sizeof(int8_t));
794
5/18
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 834 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 834 times.
834 ASSERT_EQ(imp_dst_offset, output_area.start_row() * shape.n + output_area.start_col());
795
796 834 kai_matmul_requantize32_params imp_main_params{};
797 834 imp_main_params.min_value = reference.clamp.min;
798 834 imp_main_params.max_value = reference.clamp.max;
799 834 imp_main_params.output_zero_point = reference.qa_dst.zero_point;
800
801
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
1668 variant.matmul.matmul(
802
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 output_area.height(), output_area.width(), shape.k, lhs_data.data() + imp_lhs_offset,
803 834 reference.packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t),
804 sizeof(int8_t), &imp_main_params);
805
806
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
834 compare_matmul_result(shape, output_area, imp_dst, reference.dst_qsi8_clamped);
807 834 }
808
809 } // namespace
810
811 using MatMulQuantizedTest = testing::TestWithParam<std::tuple<MatMulVariant, MatMulShape, MatrixPortion, float>>;
812 using IndirectMatMulQuantizedTest =
813 testing::TestWithParam<std::tuple<IndirectMatMulVariant, MatMulShape, MatrixPortion, size_t, float>>;
814
815 834 static std::string test_description(
816 const MatMulVariant& variant, //
817 const MatMulShape& shape, //
818 const MatrixPortion& portion, float clamp_ratio) {
819 834 std::ostringstream sstream;
820
821
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
1668 sstream << test_description(variant.name, shape, portion, true) //
822
2/4
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
834 << "__clamp_ratio_" << static_cast<int>(clamp_ratio * 100);
823
824
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
834 return sstream.str();
825 834 };
826
827 9324 static std::string test_description(
828 const IndirectMatMulVariant& variant, //
829 const MatMulShape& shape, //
830 const MatrixPortion& portion, size_t k_chunk_len, float clamp_ratio) {
831 9324 std::ostringstream sstream;
832
833
8/16
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9324 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 9324 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 9324 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9324 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 9324 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 9324 times.
✗ Branch 15 not taken.
9324 sstream << variant.name << "__M_" << shape.m << "__N_" << shape.n << "__k_chunk_count_" << shape.k << "__";
834
1/2
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
9324 PrintTo(portion, &sstream);
835
4/8
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9324 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 9324 times.
✗ Branch 7 not taken.
9324 sstream << "__k_chunk_len_" << k_chunk_len << "__clamp_ratio_" << static_cast<int>(clamp_ratio * 100);
836
837
1/2
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
9324 return sstream.str();
838 9324 };
839
840
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
2504 TEST_P(MatMulQuantizedTest, EndToEnd) {
841 13344 const auto& [variant, shape, output_portion, clamp_ratio] = GetParam();
842
843
1/2
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
834 if (!variant.is_supported()) {
844 GTEST_SKIP() << "Unsupported CPU feature";
845 }
846
847 4170 TestDataId test_data_id{shape, variant.acc_pack, shape.k, false, clamp_ratio};
848 834 const TestReference& reference = get_test_reference(test_data_id);
849
850 // Check scheduling parameters
851 1668 const auto imp_mr = variant.matmul.get_mr();
852 1668 const auto imp_nr = variant.matmul.get_nr();
853 1668 const auto imp_kr = variant.matmul.get_kr();
854 1668 const auto imp_sr = variant.matmul.get_sr();
855
856
4/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
1668 ASSERT_EQ(imp_mr, variant.acc_pack.m);
857
4/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
1668 ASSERT_EQ(imp_nr, variant.acc_pack.n);
858
4/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
1668 ASSERT_EQ(imp_kr, variant.acc_pack.k);
859
3/14
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 834 times.
834 ASSERT_EQ(imp_sr, 1);
860
861 // Check that stepping is a multiple of accumulation
862 1668 const auto imp_m_step = variant.matmul.get_m_step();
863 1668 const auto imp_n_step = variant.matmul.get_n_step();
864
4/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
1668 ASSERT_EQ(imp_m_step, variant.acc_step.m);
865
4/16
✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
1668 ASSERT_EQ(imp_n_step, variant.acc_step.n);
866
867 // Test kernels. Note that packing and actual stepping might not be the same
868 4170 const auto pack_portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_pack.m, variant.acc_pack.n);
869 834 const auto matmul_portion =
870 3336 output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n);
871
2/2
✓ Branch 0 taken 168 times.
✓ Branch 1 taken 666 times.
834 if (variant.lhs_pack.has_value()) {
872 666 test_lhs_pack(shape, variant, pack_portion, reference);
873 666 }
874 834 test_rhs_pack(shape, variant, pack_portion, reference);
875 834 test_matmul(shape, variant, matmul_portion, reference);
876 834 }
877
878 namespace imatmul {
879
880 /// Perform LHS IMATMUL packing
881 9324 static Buffer lhs_pack(
882 const LhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t m,
883 const KChunk& k_chunk) {
884 18648 const void* const* indirection_pointer =
885 9324 reinterpret_cast<const void* const*>(reference.lhs_qai8_indirect_packed.data());
886
887 // Allocate buffer
888 9324 const size_t dst_size = variant.get_packed_lhs_size(m, k_chunk.count, k_chunk.length);
889 9324 Buffer packed(dst_size);
890
891 // Calculate offsets
892
1/2
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
9324 const size_t input_offset = portion.start_row() * k_chunk.count;
893
2/4
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
9324 const size_t dst_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length);
894
895
1/2
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
9324 variant.pack(
896
1/2
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
9324 portion.height(), k_chunk.count, k_chunk.length, // Dimensions
897 9324 indirection_pointer + input_offset, // Indirection input
898 9324 reference.lhs_qai8_indirect_offset, // chunk offset
899 9324 reference.lhs_qai8_indirect_padding.data(), // padding pointer
900 9324 packed.data() + dst_offset);
901
902 9324 return packed;
903 9324 }
904
905 /// Perform RHS IMATMUL packing
906 9324 static Buffer rhs_pack(
907 const RhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t n,
908 const KChunk& k_chunk) {
909 // Allocate output buffer
910 9324 const size_t dst_size = variant.get_packed_rhs_size(n, k_chunk.count, k_chunk.length);
911 9324 Buffer packed(dst_size);
912
913 // Calculate effective quantization parameters
914 27972 const kai_rhs_pack_qsi8cx_params quantization{
915 9324 reference.qa_lhs.zero_point,
916 9324 reference.qa_lhs.scale / reference.qa_dst.scale,
917 };
918
919 // Calculate offsets
920
2/4
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
9324 const size_t rhs_offset = variant.get_rhs_offset(portion.start_col());
921
2/4
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
9324 const size_t bias_offset = variant.get_bias_offset(portion.start_col());
922
2/4
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
9324 const size_t scale_offset = variant.get_scale_offset(portion.start_col());
923
2/4
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
9324 const size_t dst_offset = variant.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length);
924
925 // Pack
926
1/2
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
9324 variant.pack(
927
1/2
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
9324 portion.width(), k_chunk.count, k_chunk.length,
928 9324 n * sizeof(uint8_t), // Dimensions, row stride
929 9324 reference.rhs_qsi8.data() + rhs_offset, // RHS matrix
930 9324 reference.bias_qsi32.data() + bias_offset, // Bias
931 9324 reference.rhs_scales.data() + scale_offset, // Scales
932 9324 packed.data() + dst_offset, // Output
933 &quantization);
934
935 9324 return packed;
936 9324 }
937
938 /// Calculate the matmul result from IMATMUL kernels
939 9324 static Buffer matmul(
940 const MatMulIndirectKernel& variant, const Rect& portion, const TestReference& reference, const Buffer& packed_lhs,
941 const Buffer& packed_rhs, const MatMulShape& shape, const KChunk& k_chunk) {
942 // Calculate portion offsets.
943 9324 size_t dst_offset = variant.get_dst_offset(portion.start_row(), portion.start_col(), shape.n);
944 9324 size_t lhs_offset = variant.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length);
945 9324 size_t rhs_offset = variant.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length);
946
947 // Allocate output buffer
948 9324 const size_t dst_size = variant.get_dst_size(shape.m, shape.n);
949 9324 Buffer dst(dst_size, 0);
950
951 // Calculate effective quantization parameters
952 9324 kai_matmul_requantize32_params requantization{};
953 9324 requantization.min_value = reference.clamp.min;
954 9324 requantization.max_value = reference.clamp.max;
955 9324 requantization.output_zero_point = reference.qa_dst.zero_point;
956
957 // Call matmul kernel
958
1/2
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
18648 variant.imatmul(
959
2/4
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
9324 portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions
960 9324 packed_lhs.data() + lhs_offset, // LHS
961 9324 packed_rhs.data() + rhs_offset, // RHS
962 9324 dst.data() + dst_offset, // DST
963 9324 shape.n * sizeof(uint8_t), &requantization);
964
965 9324 return dst;
966 9324 }
967 } // namespace imatmul
968
969
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
27974 TEST_P(IndirectMatMulQuantizedTest, EndToEnd) {
970 /* This is a bit special, as shape.k must be k_chunk_len * k_chunk_count
971 * so instead of inventing a new special kind of shape, simply multiply
972 * with `k_chunk_len` here */
973 55944 const auto& [variant, shape_k_chunk, output_portion, k_chunk_len, clamp_ratio] = GetParam();
974 27972 const KChunk k_chunk{shape_k_chunk.k, k_chunk_len};
975 27972 MatMulShape shape{shape_k_chunk.m, shape_k_chunk.n, k_chunk.count * k_chunk.length};
976
977
1/2
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
9324 if (!variant.is_supported()) {
978 GTEST_SKIP() << "Unsupported CPU feature";
979 }
980
981 // Toggle padding tests when LHS has more than one row
982 27972 TestDataId test_data_id{shape, variant.acc_pack, k_chunk.length, shape.m > 1, clamp_ratio};
983 9324 const TestReference& reference = get_test_reference(test_data_id);
984 37296 const Rect portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n);
985
986 18648 Buffer packed_lhs = imatmul::lhs_pack(variant.lhs_pack, portion, reference, shape.m, k_chunk);
987
2/4
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
18648 Buffer packed_rhs = imatmul::rhs_pack(variant.rhs_pack, portion, reference, shape.n, k_chunk);
988
2/4
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9324 times.
✗ Branch 3 not taken.
18648 Buffer impl_result = imatmul::matmul(variant.matmul, portion, reference, packed_lhs, packed_rhs, shape, k_chunk);
989
1/2
✓ Branch 0 taken 9324 times.
✗ Branch 1 not taken.
9324 compare_matmul_result(shape, portion, impl_result, reference.dst_qsi8_clamped);
990 9324 }
991
992 static constexpr std::array shapes{
993 // clang-format off
994 MatMulShape{ 1, 1, 1},
995 MatMulShape{ 1, 16, 4},
996 MatMulShape{ 1, 16, 16},
997 MatMulShape{ 1, 17, 4},
998 MatMulShape{ 1, 19, 24},
999 MatMulShape{ 1, 32, 4},
1000 MatMulShape{ 1, 32, 32},
1001 MatMulShape{ 1, 33,200},
1002 MatMulShape{ 1, 49, 21},
1003 MatMulShape{ 1, 64, 4},
1004 MatMulShape{ 1, 65, 4},
1005 MatMulShape{ 1, 300, 10},
1006 MatMulShape{ 1, 512, 4},
1007 MatMulShape{ 1, 1523, 10},
1008 MatMulShape{ 2, 195, 50},
1009 MatMulShape{ 3, 6, 6},
1010 MatMulShape{ 3, 28, 25},
1011 MatMulShape{ 3, 184,177},
1012 MatMulShape{ 4, 16, 27},
1013 MatMulShape{ 5, 136, 23},
1014 MatMulShape{ 6, 18, 31},
1015 MatMulShape{ 6, 28, 1},
1016 MatMulShape{ 6, 29, 24},
1017 MatMulShape{ 16, 16, 4},
1018 MatMulShape{ 20, 30, 40},
1019 MatMulShape{ 23, 1, 43},
1020 MatMulShape{ 32, 14, 1},
1021 MatMulShape{ 32, 16, 27},
1022 MatMulShape{ 32, 32, 3},
1023 MatMulShape{ 32, 32, 4},
1024 MatMulShape{ 33, 29, 24},
1025 MatMulShape{ 64, 64, 3},
1026 MatMulShape{ 64, 64, 4},
1027 MatMulShape{ 96, 96, 3},
1028 MatMulShape{123, 85, 45},
1029 MatMulShape{128, 128, 3},
1030 MatMulShape{130, 130, 6},
1031 // clang-format on
1032 };
1033
1034
14/44
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 666 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
1334 INSTANTIATE_TEST_SUITE_P(
1035 matmul_clamp_qai8_qai8p_qsi8cxp, MatMulQuantizedTest,
1036 testing::Combine(
1037 testing::ValuesIn(get_gemm_variants()), //
1038 testing::ValuesIn(shapes), //
1039 testing::ValuesIn({
1040 // clang-format off
1041 MatrixPortion( 0, 0, 1, 1), // Full matrix.
1042 MatrixPortion( 0, 0, 0.25, 0.25), // Top-left corner.
1043 MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner.
1044 // clang-format on
1045 }),
1046 testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})),
1047 [](const auto& info) -> std::string {
1048 return test_description(
1049 std::get<MatMulVariant>(info.param), //
1050 std::get<MatMulShape>(info.param), //
1051 std::get<MatrixPortion>(info.param), //
1052 std::get<float>(info.param));
1053 });
1054
1055
15/48
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 168 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
338 INSTANTIATE_TEST_SUITE_P(
1056 matmul_clamp_qai8_qai8_qsi8cxp, MatMulQuantizedTest,
1057 testing::Combine(
1058 testing::ValuesIn(get_gemv_variants()),
1059 testing::ValuesIn({
1060 // clang-format off
1061 MatMulShape{ 1, 1, 1},
1062 MatMulShape{ 1, 16, 4},
1063 MatMulShape{ 1, 16, 16},
1064 MatMulShape{ 1, 17, 4},
1065 MatMulShape{ 1, 19, 24},
1066 MatMulShape{ 1, 32, 4},
1067 MatMulShape{ 1, 32, 32},
1068 MatMulShape{ 1, 33,200},
1069 MatMulShape{ 1, 49, 21},
1070 MatMulShape{ 1, 64, 4},
1071 MatMulShape{ 1, 65, 4},
1072 MatMulShape{ 1, 300, 10},
1073 MatMulShape{ 1, 512, 4},
1074 MatMulShape{ 1, 1523, 10},
1075 // clang-format on
1076 }),
1077 testing::ValuesIn({
1078 // clang-format off
1079 MatrixPortion(0, 0, 1, 1), // Full matrix.
1080 MatrixPortion(0, .5, 1, .5), // Right half
1081 MatrixPortion(0, 0, 1, .5), // Left half
1082 MatrixPortion(0, .25, 1, .5) // Middle half
1083 // clang-format on
1084 }),
1085 // Clamp range
1086 testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})),
1087 [](const auto& info) -> std::string {
1088 return test_description(
1089 std::get<MatMulVariant>(info.param), //
1090 std::get<MatMulShape>(info.param), //
1091 std::get<MatrixPortion>(info.param), //
1092 std::get<float>(info.param));
1093 });
1094
1095
18/60
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 time.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 time.
✗ Branch 33 not taken.
✓ Branch 34 taken 9324 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
18650 INSTANTIATE_TEST_SUITE_P(
1096 indirect_matmul_clamp_qai8_qai8p_qsi8cxp, IndirectMatMulQuantizedTest,
1097 testing::Combine(
1098 testing::ValuesIn(get_indirect_gemm_variants()), //
1099 testing::ValuesIn(shapes), //
1100 testing::ValuesIn({
1101 // clang-format off
1102 // (Start row , start col , height , width)
1103 MatrixPortion( 0 , 0 , 1 , 1) , // Full matrix.
1104 MatrixPortion( 0 , 0 , 1 , 0.5) , // Left half
1105 MatrixPortion( 0 , 0 , 0.5 , 1) , // Upper half
1106 MatrixPortion( 0 , 0.5 , 1 , 0.5) , // Right half
1107 MatrixPortion( 0.5 , 0 , 0.5 , 1) , // Bottom half
1108 MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3) , // Center ninth
1109 // clang-format on
1110 }),
1111 // k_chunk_len
1112 testing::ValuesIn(std::initializer_list<size_t>{1, 2, 3, 4, 8, 11, 32}),
1113 // Clamp range
1114 testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})),
1115 [](const auto& info) -> std::string {
1116 return test_description(
1117 std::get<IndirectMatMulVariant>(info.param), //
1118 std::get<MatMulShape>(info.param), //
1119 std::get<MatrixPortion>(info.param), //
1120 std::get<size_t>(info.param), //
1121 std::get<float>(info.param));
1122 });
1123 } // namespace kai::test
1124