KleidiAI Coverage Report


Directory: ./
File: test/reference/matmul.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 87.9% 261 26 323
Functions: 80.0% 16 0 20
Branches: 41.6% 268 136 781

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/matmul.hpp"
8
9 #include <algorithm>
10 #include <cstddef>
11 #include <cstdint>
12
13 #include "kai/kai_common.h"
14 #include "test/common/buffer.hpp"
15 #include "test/common/data_format.hpp"
16 #include "test/common/data_type.hpp"
17 #include "test/common/float16.hpp"
18 #include "test/common/int4.hpp"
19 #include "test/common/memory.hpp"
20 #include "test/common/round.hpp"
21 #include "test/reference/binary_elementwise.hpp"
22 #include "test/reference/cast.hpp"
23 #include "test/reference/pack.hpp"
24 #include "test/reference/reduce.hpp"
25 #include "test/reference/transpose.hpp"
26
27 namespace kai::test {
28
29 namespace {
30
31 /// Matrix multiplication.
32 ///
33 /// @tparam T Data type.
34 ///
35 /// @param[in] lhs LHS operand data buffer.
36 /// @param[in] rhs RHS operand data buffer.
37 /// @param[in] m Output height.
38 /// @param[in] n Output width.
39 /// @param[in] k Non-transposed LHS width and non-transposed RHS height.
40 /// @param[in] lhs_transposed `true` if LHS operand is transposed.
41 /// @param[in] rhs_transposed `true` if RHS operand is transposed.
42 ///
43 /// @return The result data buffer.
44 template <typename T>
45 2387 Buffer matmul_any_type(
46 const void* lhs, const void* rhs, //
47 size_t m, size_t n, size_t k, //
48 bool lhs_transposed, bool rhs_transposed) {
49
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 1191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1196 times.
2387 const auto lhs_m_stride = lhs_transposed ? 1 : k;
50
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 1191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1196 times.
2387 const auto lhs_k_stride = lhs_transposed ? m : 1;
51
52
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 1191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1196 times.
2387 const auto rhs_n_stride = rhs_transposed ? k : 1;
53
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 1191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1196 times.
2387 const auto rhs_k_stride = rhs_transposed ? 1 : n;
54
55 2387 Buffer dst(m * n * size_in_bits<T> / 8);
56 KAI_ASSUME(n * size_in_bits<T> % 8 == 0);
57
58
4/4
✓ Branch 0 taken 36473 times.
✓ Branch 1 taken 1191 times.
✓ Branch 2 taken 37072 times.
✓ Branch 3 taken 1196 times.
75932 for (size_t im = 0; im < m; ++im) {
59
4/4
✓ Branch 0 taken 36473 times.
✓ Branch 1 taken 2892060 times.
✓ Branch 2 taken 37072 times.
✓ Branch 3 taken 2916561 times.
5882166 for (size_t in = 0; in < n; ++in) {
60
1/2
✓ Branch 0 taken 2916561 times.
✗ Branch 1 not taken.
5808621 T acc{0};
61
62
4/4
✓ Branch 0 taken 599017351 times.
✓ Branch 1 taken 2892060 times.
✓ Branch 2 taken 600720321 times.
✓ Branch 3 taken 2916561 times.
1205546293 for (size_t ik = 0; ik < k; ++ik) {
63
2/4
✓ Branch 0 taken 599017351 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 600720321 times.
1199737672 const auto lhs_value = read_array<T>(lhs, im * lhs_m_stride + ik * lhs_k_stride);
64
2/4
✓ Branch 0 taken 599017351 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600720321 times.
✗ Branch 3 not taken.
1199737672 const auto rhs_value = read_array<T>(rhs, in * rhs_n_stride + ik * rhs_k_stride);
65
2/4
✓ Branch 0 taken 600720321 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600720321 times.
✗ Branch 3 not taken.
1199737672 acc += lhs_value * rhs_value;
66 1199737672 }
67
68
2/4
✓ Branch 0 taken 2892060 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2916561 times.
✗ Branch 3 not taken.
5808621 write_array<T>(dst.data(), im * n + in, acc);
69 5808621 }
70 73545 }
71
72 2387 return dst;
73 2387 }
74
75 } // namespace
76
77 125 Buffer matmul_pack_rhs(
78 const void* data, const void* scales, const void* zero_points, const DataFormat& src_format,
79 const DataFormat& dst_format, size_t n, size_t k, bool transposing) {
80 125 const auto src_dt = src_format.data_type();
81 125 const auto src_pf = src_format.pack_format();
82
83 125 const auto dst_dt = dst_format.data_type();
84 125 const auto dst_pf = dst_format.pack_format();
85
86 125 Buffer tmp_data;
87 125 Buffer tmp_scales;
88 125 Buffer tmp_zero_points;
89
90
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 125 times.
125 if (transposing) {
91
1/2
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
125 tmp_data = transpose(data, src_dt, k, n);
92
1/2
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
125 data = tmp_data.data();
93 125 }
94
95
1/6
✗ Branch 0 not taken.
✓ Branch 1 taken 125 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
125 if (src_dt == DataType::QSU4 && src_pf == DataFormat::PackFormat::NONE && //
96 dst_dt == DataType::QSI4 && dst_pf == DataFormat::PackFormat::QUANTIZE_PER_ROW) {
97 // For this specific RHS format conversion:
98 //
99 // * 4-bit data is added by 8.
100 // * Scale is divided by 16.
101 // * Zero point is accumulation of all values in the same row.
102
103 KAI_ASSUME(zero_points == nullptr);
104 const int32_t zero_point = 8;
105 const uint8_t zero_point_i4 = UInt4::pack_u8(UInt4(zero_point), UInt4(zero_point));
106 const int32_t row_zero_point = zero_point * static_cast<int32_t>(k);
107
108 KAI_ASSUME(dst_format.subblock_width() > 0);
109 const auto subblock_width_i32 = static_cast<int32_t>(dst_format.subblock_width());
110 const auto subblock_width_f = static_cast<float>(dst_format.subblock_width());
111
112 tmp_zero_points = reduce_add(data, src_format, n, k, DataFormat(DataType::I32), 0);
113 tmp_zero_points = sub(tmp_zero_points.data(), DataType::I32, n, 1, &row_zero_point, DataType::I32, 1, 1);
114 tmp_zero_points = mul(tmp_zero_points.data(), DataType::I32, n, 1, &subblock_width_i32, DataType::I32, 1, 1);
115 zero_points = tmp_zero_points.data();
116
117 tmp_data = add(data, DataType::QSU4, n, k, &zero_point_i4, DataType::QSU4, 1, 1);
118 data = tmp_data.data();
119
120 tmp_scales = div(scales, DataType::FP32, n, 1, &subblock_width_f, DataType::FP32, 1, 1);
121 scales = tmp_scales.data();
122 }
123
124
1/2
✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
125 return pack(dst_format, data, scales, zero_points, src_format, n, k);
125 125 }
126
127 2387 Buffer matmul(
128 const void* lhs, [[maybe_unused]] const void* lhs_scales, [[maybe_unused]] const void* lhs_zero_points,
129 DataType lhs_dt, //
130 const void* rhs, [[maybe_unused]] const void* rhs_scales, [[maybe_unused]] const void* rhs_zero_points,
131 DataType rhs_dt, //
132 const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, //
133 DataType dst_dt, //
134 size_t m, size_t n, size_t k, //
135 bool lhs_transposed, bool rhs_transposed) {
136
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2387 times.
2387 const auto lhs_h = lhs_transposed ? k : m;
137
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2387 times.
2387 const auto lhs_w = lhs_transposed ? m : k;
138
139
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2387 times.
2387 const auto rhs_h = rhs_transposed ? n : k;
140
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2387 times.
2387 const auto rhs_w = rhs_transposed ? k : n;
141
142 2387 Buffer tmp_lhs;
143 2387 Buffer tmp_rhs;
144 2387 Buffer tmp_dst;
145 2387 Buffer tmp_bias;
146
147
1/2
✓ Branch 0 taken 2387 times.
✗ Branch 1 not taken.
2387 if (lhs_dt != dst_dt) {
148 tmp_lhs = cast(lhs, lhs_dt, dst_dt, lhs_h, lhs_w);
149 lhs = tmp_lhs.data();
150 }
151
152
1/2
✓ Branch 0 taken 2387 times.
✗ Branch 1 not taken.
2387 if (rhs_dt != dst_dt) {
153 tmp_rhs = cast(rhs, rhs_dt, dst_dt, rhs_h, rhs_w);
154 rhs = tmp_rhs.data();
155 }
156
157
2/3
✗ Branch 0 not taken.
✓ Branch 1 taken 1191 times.
✓ Branch 2 taken 1196 times.
2387 switch (dst_dt) {
158 case DataType::FP32:
159
1/2
✓ Branch 0 taken 1191 times.
✗ Branch 1 not taken.
1191 tmp_dst = matmul_any_type<float>(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed);
160 1191 break;
161
162 case DataType::FP16:
163
1/2
✓ Branch 0 taken 1196 times.
✗ Branch 1 not taken.
1196 tmp_dst = matmul_any_type<Float16>(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed);
164 1196 break;
165
166 default:
167 KAI_ERROR("Unknown data type!");
168 }
169
170
2/2
✓ Branch 0 taken 9 times.
✓ Branch 1 taken 2378 times.
2387 if (bias != nullptr) {
171
1/2
✓ Branch 0 taken 2378 times.
✗ Branch 1 not taken.
2378 if (bias_dt != dst_dt) {
172 tmp_bias = cast(bias, bias_dt, dst_dt, 1, n);
173 bias = tmp_bias.data();
174 }
175
176 KAI_ASSUME(!data_type_is_quantized(bias_dt));
177 KAI_ASSUME(bias_scales == nullptr);
178 KAI_ASSUME(bias_zero_points == nullptr);
179
180
2/4
✓ Branch 0 taken 2378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2378 times.
✗ Branch 3 not taken.
2378 tmp_dst = add(tmp_dst.data(), dst_dt, m, n, bias, bias_dt, 1, n);
181 2378 }
182
183 2387 return tmp_dst;
184 2387 }
185
186 2244 Buffer indirect_matmul(
187 const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales,
188 const void* lhs_zero_points,
189 DataType lhs_dt, //
190 const void* rhs, const void* rhs_scales, const void* rhs_zero_points,
191 DataType rhs_dt, //
192 const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, //
193 DataType dst_dt, //
194 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length) {
195 // This is inefficient, but allows code-reuse
196 2244 const size_t chunk_bytes = k_chunk_length * round_up_division(data_type_size_in_bits(lhs_dt), 8);
197 2244 const size_t n_chunks = m * k_chunk_count;
198 2244 Buffer lhs(n_chunks * chunk_bytes);
199
200 // Copy all chunks to the created matrix
201
2/2
✓ Branch 0 taken 858660 times.
✓ Branch 1 taken 2244 times.
860904 for (size_t i = 0; i < n_chunks; i += 1) {
202 858660 const uint8_t* src_pointer = static_cast<const uint8_t*>(lhs_idata[i]);
203
2/2
✓ Branch 0 taken 24156 times.
✓ Branch 1 taken 834504 times.
858660 if (src_pointer != lhs_padding_ptr) {
204 834504 src_pointer += lhs_offset;
205 834504 }
206
1/2
✓ Branch 0 taken 858660 times.
✗ Branch 1 not taken.
858660 memcpy(lhs.data() + i * chunk_bytes, src_pointer, chunk_bytes);
207 858660 }
208
209
1/2
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
2244 return matmul(
210
1/2
✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
2244 lhs.data(), lhs_scales, lhs_zero_points, lhs_dt, //
211 2244 rhs, rhs_scales, rhs_zero_points, rhs_dt, //
212 2244 bias, bias_scales, bias_zero_points, bias_dt, //
213 2244 dst_dt, m, n, k_chunk_count * k_chunk_length, false, false);
214 2244 }
215
216 template <
217 typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
218 typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData>
219 927 Buffer indirect_matmul_nt_t_quantized(
220 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, //
221 const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales,
222 const void* lhs_zero_points, size_t lhs_quant_height,
223 size_t lhs_quant_width, //
224 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
225 size_t rhs_quant_width, //
226 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width) {
227 KAI_ASSUME(lhs_quant_width != 0);
228 KAI_ASSUME(rhs_quant_width != 0);
229 KAI_ASSUME(lhs_quant_height != 0);
230 KAI_ASSUME(rhs_quant_height != 0);
231 KAI_ASSUME(bias_quant_width != 0);
232 927 const auto lhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, lhs_quant_width);
233 927 const auto rhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, rhs_quant_width);
234
235 927 Buffer dst(m * n * sizeof(DstData));
236
237
2/2
✓ Branch 0 taken 21087 times.
✓ Branch 1 taken 927 times.
22014 for (size_t i_m = 0; i_m < m; ++i_m) {
238
2/2
✓ Branch 0 taken 21087 times.
✓ Branch 1 taken 1709826 times.
1730913 for (size_t i_n = 0; i_n < n; ++i_n) {
239 1709826 DstData acc = 0;
240
241
2/2
✓ Branch 0 taken 19596618 times.
✓ Branch 1 taken 1709826 times.
21306444 for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; ++i_k_chunk) {
242 // Calculate the K chunk pointer. Apply offset if this is not padding
243 19596618 const size_t k_chunk_idx = i_m * k_chunk_count + i_k_chunk;
244 19596618 const void* k_chunk_ptr = lhs_ptrs[k_chunk_idx];
245
2/2
✓ Branch 0 taken 1182174 times.
✓ Branch 1 taken 18414444 times.
19596618 if (k_chunk_ptr != lhs_padding_ptr) {
246 18414444 k_chunk_ptr = reinterpret_cast<const void*>(reinterpret_cast<uintptr_t>(k_chunk_ptr) + lhs_offset);
247 18414444 }
248
249
2/2
✓ Branch 0 taken 171706017 times.
✓ Branch 1 taken 19596618 times.
191302635 for (size_t i_k_chunk_len = 0; i_k_chunk_len < k_chunk_length; ++i_k_chunk_len) {
250 171706017 const size_t i = i_k_chunk * k_chunk_length + i_k_chunk_len;
251
252 171706017 const auto lhs_data_index = i_k_chunk_len;
253 171706017 const auto lhs_quant_index = (i_m / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width;
254
1/2
✓ Branch 0 taken 171706017 times.
✗ Branch 1 not taken.
171706017 const auto lhs_value = read_array<LhsData>(k_chunk_ptr, lhs_data_index);
255
2/4
✓ Branch 0 taken 171706017 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 171706017 times.
171706017 const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index)
256 : static_cast<LhsScale>(1);
257
1/2
✓ Branch 0 taken 171706017 times.
✗ Branch 1 not taken.
343412034 const auto lhs_zero_point = lhs_zero_points != nullptr
258
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 171706017 times.
171706017 ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index)
259 : static_cast<LhsZeroPoint>(0);
260
261 171706017 const auto rhs_data_index = i_n * (k_chunk_count * k_chunk_length) + i;
262 171706017 const auto rhs_quant_index = (i_n / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width;
263
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 171706017 times.
171706017 const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index);
264
2/4
✓ Branch 0 taken 171706017 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 171706017 times.
171706017 const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index)
265 : static_cast<RhsScale>(1);
266
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 171706017 times.
171706017 const auto rhs_zero_point = rhs_zero_points != nullptr
267 ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index)
268 : static_cast<RhsZeroPoint>(0);
269
270 515118051 acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) *
271 343412034 static_cast<DstData>(lhs_scale) *
272 171706017 (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) *
273 171706017 static_cast<DstData>(rhs_scale);
274 171706017 }
275 19596618 }
276
277
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1709826 times.
1709826 if (bias_data != nullptr) {
278
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1709826 times.
1709826 const auto bias_value = read_array<BiasData>(bias_data, i_n);
279
1/2
✓ Branch 0 taken 1709826 times.
✗ Branch 1 not taken.
3419652 const auto bias_scale = bias_scales != nullptr
280
1/2
✓ Branch 0 taken 1709826 times.
✗ Branch 1 not taken.
1709826 ? read_array<BiasScale>(bias_scales, i_n / bias_quant_width)
281 : static_cast<BiasScale>(1);
282
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1709826 times.
1709826 const auto bias_zero_point = bias_zero_points != nullptr
283 ? read_array<BiasZeroPoint>(bias_zero_points, i_n / bias_quant_width)
284 : static_cast<BiasZeroPoint>(0);
285
286 3419652 acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) *
287 1709826 static_cast<DstData>(bias_scale);
288 1709826 }
289
290
2/4
✓ Branch 0 taken 1709826 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1709826 times.
✗ Branch 3 not taken.
1709826 write_array<DstData>(dst.data(), i_m * n + i_n, acc);
291 1709826 }
292 21087 }
293
294 927 return dst;
295 927 }
296
297 template <
298 typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
299 typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData>
300 7426 Buffer matmul_nt_t_quantized(
301 size_t m, size_t n, size_t k, //
302 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, //
303 size_t lhs_quant_height, size_t lhs_quant_width, //
304 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, //
305 size_t rhs_quant_height, size_t rhs_quant_width, //
306 const void* bias_data, const void* bias_scales, const void* bias_zero_points, //
307 size_t bias_quant_width) {
308 KAI_ASSUME(lhs_quant_width != 0);
309 KAI_ASSUME(rhs_quant_width != 0);
310 KAI_ASSUME(lhs_quant_height != 0);
311 KAI_ASSUME(rhs_quant_height != 0);
312 KAI_ASSUME(bias_quant_width != 0);
313
314 7426 const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width);
315 7426 const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width);
316
317 7426 Buffer dst(m * n * sizeof(DstData));
318
319
6/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 6164 times.
✓ Branch 3 taken 72384 times.
✓ Branch 4 taken 560 times.
✓ Branch 5 taken 6524 times.
✓ Branch 6 taken 702 times.
✓ Branch 7 taken 16510 times.
102844 for (size_t row = 0; row < m; ++row) {
320
6/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 5416304 times.
✓ Branch 3 taken 72384 times.
✓ Branch 4 taken 6524 times.
✓ Branch 5 taken 516260 times.
✓ Branch 6 taken 16510 times.
✓ Branch 7 taken 2832622 times.
8860604 for (size_t col = 0; col < n; ++col) {
321 8765186 DstData acc = 0;
322
323
6/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 317394496 times.
✓ Branch 3 taken 5416304 times.
✓ Branch 4 taken 24804860 times.
✓ Branch 5 taken 516260 times.
✓ Branch 6 taken 696329088 times.
✓ Branch 7 taken 2832622 times.
1047293630 for (size_t i = 0; i < k; ++i) {
324 1038528444 const auto lhs_data_index = row * k + i;
325 1038528444 const auto lhs_quant_index = (row / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width;
326
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 317394496 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24804860 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 696329088 times.
✗ Branch 7 not taken.
1038528444 const auto lhs_value = read_array<LhsData>(lhs_data, lhs_data_index);
327
6/16
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 317394496 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 317394496 times.
✓ Branch 8 taken 24804860 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 24804860 times.
✓ Branch 12 taken 696329088 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 696329088 times.
1038528444 const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index)
328 : static_cast<LhsScale>(1);
329
4/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 64700992 times.
✓ Branch 3 taken 252693504 times.
✓ Branch 4 taken 24804860 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 696329088 times.
✗ Branch 7 not taken.
1824363384 const auto lhs_zero_point = lhs_zero_points != nullptr
330
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 64700992 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 24804860 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 696329088 times.
785834940 ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index)
331 : static_cast<LhsZeroPoint>(0);
332
333 1038528444 const auto rhs_data_index = col * k + i;
334 1038528444 const auto rhs_quant_index = (col / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width;
335
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 317394496 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 24804860 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 696329088 times.
1038528444 const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index);
336
6/16
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 317394496 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 317394496 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 24804860 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 24804860 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 696329088 times.
✓ Branch 14 taken 696329088 times.
✗ Branch 15 not taken.
1038528444 const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index)
337 : static_cast<RhsScale>(1);
338
4/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 252693504 times.
✓ Branch 3 taken 64700992 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 24804860 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 696329088 times.
1291221948 const auto rhs_zero_point = rhs_zero_points != nullptr
339
1/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 252693504 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
252693504 ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index)
340 : static_cast<RhsZeroPoint>(0);
341
342 2798190836 acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) *
343 1063333304 static_cast<DstData>(lhs_scale) *
344
2/4
✓ Branch 0 taken 317394496 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 696329088 times.
✗ Branch 3 not taken.
1038528444 (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) *
345
1/2
✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
1038528444 static_cast<DstData>(rhs_scale);
346 1038528444 }
347
348
5/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 2708152 times.
✓ Branch 3 taken 2708152 times.
✓ Branch 4 taken 258130 times.
✓ Branch 5 taken 258130 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2832622 times.
8765186 if (bias_data != nullptr) {
349
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2708152 times.
✓ Branch 4 taken 258130 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2832622 times.
5798904 const auto bias_value = read_array<BiasData>(bias_data, col);
350
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2708152 times.
✓ Branch 4 taken 258130 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2832622 times.
5798904 const auto bias_scale = bias_scales != nullptr
351 ? read_array<BiasScale>(bias_scales, col / bias_quant_width)
352 : static_cast<BiasScale>(1);
353
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2708152 times.
✓ Branch 4 taken 258130 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2832622 times.
5798904 const auto bias_zero_point = bias_zero_points != nullptr
354 ? read_array<BiasZeroPoint>(bias_zero_points, col / bias_quant_width)
355 : static_cast<BiasZeroPoint>(0);
356
357 11597808 acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) *
358 5798904 static_cast<DstData>(bias_scale);
359 5798904 }
360
361
6/16
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 5416304 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 5416304 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 516260 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 516260 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2832622 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2832622 times.
✗ Branch 15 not taken.
8765186 write_array<DstData>(dst.data(), row * n + col, acc);
362 8765186 }
363 95418 }
364
365 7426 return dst;
366 7426 }
367
368 template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>(
369 size_t m, size_t n, size_t k, //
370 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
371 size_t lhs_quant_width, //
372 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
373 size_t rhs_quant_width, //
374 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
375
376 template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>(
377 size_t m, size_t n, size_t k, //
378 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
379 size_t lhs_quant_width, //
380 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
381 size_t rhs_quant_width, //
382 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
383
384 template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, float, float, int32_t, float>(
385 size_t m, size_t n, size_t k, //
386 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
387 size_t lhs_quant_width, //
388 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
389 size_t rhs_quant_width, //
390 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
391
392 template Buffer
393 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
394 size_t m, size_t n, size_t k, //
395 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
396 size_t lhs_quant_width, //
397 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
398 size_t rhs_quant_width, //
399 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
400
401 template <
402 typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
403 typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData>
404 776 Buffer matmul_nt_nt_quantized(
405 size_t m, size_t n, size_t k, //
406 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, //
407 size_t lhs_quant_height, size_t lhs_quant_width, //
408 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, //
409 size_t rhs_quant_height, size_t rhs_quant_width, //
410 const void* bias_data, const void* bias_scales, const void* bias_zero_points, //
411 size_t bias_quant_width) {
412 KAI_ASSUME(lhs_quant_width != 0);
413 KAI_ASSUME(rhs_quant_width != 0);
414 KAI_ASSUME(lhs_quant_height != 0);
415 KAI_ASSUME(rhs_quant_height != 0);
416 KAI_ASSUME(bias_quant_width != 0);
417
418 776 const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width);
419 776 const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width);
420
421 776 Buffer dst(m * n * sizeof(DstData));
422
423
4/4
✓ Branch 0 taken 702 times.
✓ Branch 1 taken 16510 times.
✓ Branch 2 taken 74 times.
✓ Branch 3 taken 8279 times.
25565 for (size_t row = 0; row < m; ++row) {
424
4/4
✓ Branch 0 taken 16510 times.
✓ Branch 1 taken 2832622 times.
✓ Branch 2 taken 524771 times.
✓ Branch 3 taken 8279 times.
3382182 for (size_t col = 0; col < n; ++col) {
425 3357393 DstData acc = 0;
426
427
4/4
✓ Branch 0 taken 696329088 times.
✓ Branch 1 taken 2832622 times.
✓ Branch 2 taken 143126255 times.
✓ Branch 3 taken 524771 times.
842812736 for (size_t i = 0; i < k; ++i) {
428 839455343 const auto lhs_data_index = row * k + i;
429 839455343 const auto lhs_quant_index = (row / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width;
430
2/4
✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 143126255 times.
839455343 const auto lhs_value = read_array<LhsData>(lhs_data, lhs_data_index);
431
3/8
✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 696329088 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 143126255 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
839455343 const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index)
432 : static_cast<LhsScale>(1);
433
2/4
✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 143126255 times.
1535784431 const auto lhs_zero_point = lhs_zero_points != nullptr
434
1/4
✗ Branch 0 not taken.
✓ Branch 1 taken 696329088 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
696329088 ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index)
435 : static_cast<LhsZeroPoint>(0);
436
437 839455343 const auto rhs_data_index = col + i * n;
438 839455343 const auto rhs_quant_index = (col / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width;
439
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 696329088 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 143126255 times.
839455343 const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index);
440
3/8
✗ Branch 0 not taken.
✓ Branch 1 taken 696329088 times.
✓ Branch 2 taken 696329088 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 143126255 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
839455343 const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index)
441 : static_cast<RhsScale>(1);
442
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 696329088 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 143126255 times.
839455343 const auto rhs_zero_point = rhs_zero_points != nullptr
443 ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index)
444 : static_cast<RhsZeroPoint>(0);
445
446
1/2
✓ Branch 0 taken 143126255 times.
✗ Branch 1 not taken.
2232113519 acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) *
447 839455343 static_cast<DstData>(lhs_scale) *
448
2/4
✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 143126255 times.
✗ Branch 3 not taken.
839455343 (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) *
449
1/2
✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
839455343 static_cast<DstData>(rhs_scale);
450 839455343 }
451
452
3/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2832622 times.
✓ Branch 2 taken 210653 times.
✓ Branch 3 taken 314118 times.
3357393 if (bias_data != nullptr) {
453
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2832622 times.
✓ Branch 2 taken 314118 times.
✗ Branch 3 not taken.
3146740 const auto bias_value = read_array<BiasData>(bias_data, col);
454
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2832622 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 314118 times.
3146740 const auto bias_scale = bias_scales != nullptr
455 ? read_array<BiasScale>(bias_scales, col / bias_quant_width)
456 : static_cast<BiasScale>(1);
457
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2832622 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 314118 times.
3146740 const auto bias_zero_point = bias_zero_points != nullptr
458 ? read_array<BiasZeroPoint>(bias_zero_points, col / bias_quant_width)
459 : static_cast<BiasZeroPoint>(0);
460
461 6293480 acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) *
462 3146740 static_cast<DstData>(bias_scale);
463 3146740 }
464
465
4/8
✓ Branch 0 taken 2832622 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2832622 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 524771 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 524771 times.
✗ Branch 7 not taken.
3357393 write_array<DstData>(dst.data(), row * n + col, acc);
466 3357393 }
467 24789 }
468
469 776 return dst;
470 776 }
471
472 template Buffer
473 matmul_nt_nt_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
474 size_t m, size_t n, size_t k, //
475 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
476 size_t lhs_quant_width, //
477 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
478 size_t rhs_quant_width, //
479 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
480
481 template Buffer matmul_nt_nt_quantized<BFloat16<>, float, float, BFloat16<>, float, float, float, float, float, float>(
482 size_t m, size_t n, size_t k, //
483 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
484 size_t lhs_quant_width, //
485 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
486 size_t rhs_quant_width, //
487 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
488
489 template Buffer
490 indirect_matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>(
491 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, //
492 const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales,
493 const void* lhs_zero_points, size_t lhs_quant_height,
494 size_t lhs_quant_width, //
495 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
496 size_t rhs_quant_width, //
497 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
498
499 template <
500 typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
501 typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData>
502 1620 Buffer matmul_clamp_nt_t(
503 size_t m, size_t n, size_t k, //
504 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
505 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
506 const void* biases, //
507 DstData min_value, DstData max_value) {
508 KAI_ASSUME(lhs_quant_width != 0);
509 KAI_ASSUME(rhs_quant_width != 0);
510 1620 const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width);
511 1620 const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width);
512
513 1620 Buffer dst(m * n * sizeof(DstData));
514
515 1620 const auto* lhs_scales_ptr = reinterpret_cast<const LhsScale*>(lhs_scales);
516 1620 const auto* rhs_scales_ptr = reinterpret_cast<const RhsScale*>(rhs_scales);
517 1620 const auto* lhs_zero_points_ptr = reinterpret_cast<const LhsZeroPoint*>(lhs_zero_points);
518 1620 const auto* rhs_zero_points_ptr = reinterpret_cast<const RhsZeroPoint*>(rhs_zero_points);
519 1620 const auto* biases_ptr = reinterpret_cast<const Bias*>(biases);
520
3/8
✓ Branch 0 taken 880 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 616 times.
✗ Branch 7 not taken.
1620 auto* dst_ptr = reinterpret_cast<DstData*>(dst.data());
521
522
6/8
✓ Branch 0 taken 880 times.
✓ Branch 1 taken 23056 times.
✓ Branch 2 taken 2984 times.
✓ Branch 3 taken 124 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 616 times.
✓ Branch 7 taken 6104 times.
33764 for (size_t y = 0; y < m; ++y) {
523
6/8
✓ Branch 0 taken 23056 times.
✓ Branch 1 taken 3154800 times.
✓ Branch 2 taken 2984 times.
✓ Branch 3 taken 218428 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 6104 times.
✓ Branch 7 taken 405720 times.
3811092 for (size_t x = 0; x < n; ++x) {
524 3778948 DstData acc = 0;
525
526
6/8
✓ Branch 0 taken 158710112 times.
✓ Branch 1 taken 3154800 times.
✓ Branch 2 taken 13670656 times.
✓ Branch 3 taken 218428 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 13742344 times.
✓ Branch 7 taken 405720 times.
189902060 for (size_t i = 0; i < k; ++i) {
527
3/8
✓ Branch 0 taken 158710112 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13670656 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 13742344 times.
✗ Branch 7 not taken.
186123112 const auto lhs_value = read_array<LhsData>(lhs_data, y * k + i);
528 186123112 const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width];
529
3/8
✓ Branch 0 taken 158710112 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 13670656 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 13742344 times.
✗ Branch 7 not taken.
186123112 const auto lhs_zero_point = lhs_zero_points_ptr != nullptr
530 172452456 ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]
531 : 0;
532
533
3/8
✗ Branch 0 not taken.
✓ Branch 1 taken 158710112 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 13670656 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 13742344 times.
✗ Branch 7 not taken.
186123112 const auto rhs_value = read_array<RhsData>(rhs_data, x * k + i);
534 186123112 const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width];
535
3/8
✗ Branch 0 not taken.
✓ Branch 1 taken 158710112 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 13670656 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 13742344 times.
186123112 const auto rhs_zero_point = rhs_zero_points_ptr != nullptr
536 ? rhs_zero_points_ptr[y * rhs_num_quant_per_row + i / rhs_quant_width]
537 : 0;
538
539 186123112 acc += static_cast<DstData>(
540 199865456 (static_cast<IntAcc>(lhs_value) - static_cast<IntAcc>(lhs_zero_point)) *
541
2/6
✗ Branch 0 not taken.
✓ Branch 1 taken 158710112 times.
✓ Branch 2 taken 13670656 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
199865456 (static_cast<IntAcc>(rhs_value) - static_cast<IntAcc>(rhs_zero_point))) *
542
2/8
✓ Branch 0 taken 13670656 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13670656 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
358575568 static_cast<DstData>(lhs_scale) * static_cast<DstData>(rhs_scale);
543 186123112 }
544
545
3/8
✗ Branch 0 not taken.
✓ Branch 1 taken 3154800 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 218428 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 405720 times.
✗ Branch 7 not taken.
3778948 if (biases_ptr != nullptr) {
546 3560520 acc += static_cast<DstData>(biases_ptr[x]);
547 3560520 }
548
549
3/8
✓ Branch 0 taken 3154800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 218428 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 405720 times.
✗ Branch 7 not taken.
3778948 acc = std::clamp(acc, min_value, max_value);
550 3778948 dst_ptr[y * n + x] = acc;
551 3778948 }
552 32144 }
553
554 1620 return dst;
555 1620 }
556
557 template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
558 size_t m, size_t n, size_t k, //
559 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
560 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
561 const void* biases, //
562 float min_value, float max_value);
563
564 template Buffer matmul_clamp_nt_t<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>(
565 size_t m, size_t n, size_t k, //
566 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
567 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
568 const void* biases, //
569 float min_value, float max_value);
570
571 template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, int32_t, float>(
572 size_t m, size_t n, size_t k, //
573 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
574 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
575 const void* biases, //
576 float min_value, float max_value);
577
578 template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>(
579 size_t m, size_t n, size_t k, //
580 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
581 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
582 const void* biases, //
583 float min_value, float max_value);
584
585 template <
586 typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
587 typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData>
588 1728 Buffer matmul_clamp_nt_nt(
589 size_t m, size_t n, size_t k, //
590 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
591 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
592 const void* biases, //
593 DstData min_value, DstData max_value) {
594 KAI_ASSUME(lhs_quant_width != 0);
595 KAI_ASSUME(rhs_quant_width != 0);
596 1728 const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width);
597 1728 const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width);
598
599 1728 Buffer dst(m * n * sizeof(DstData));
600
601 1728 const auto* lhs_scales_ptr = reinterpret_cast<const LhsScale*>(lhs_scales);
602 1728 const auto* rhs_scales_ptr = reinterpret_cast<const RhsScale*>(rhs_scales);
603 1728 const auto* lhs_zero_points_ptr = reinterpret_cast<const LhsZeroPoint*>(lhs_zero_points);
604 1728 const auto* rhs_zero_points_ptr = reinterpret_cast<const RhsZeroPoint*>(rhs_zero_points);
605 1728 const auto* biases_ptr = reinterpret_cast<const Bias*>(biases);
606
2/8
✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1112 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
1728 auto* dst_ptr = reinterpret_cast<DstData*>(dst.data());
607
608
4/8
✓ Branch 0 taken 616 times.
✓ Branch 1 taken 6104 times.
✓ Branch 2 taken 1112 times.
✓ Branch 3 taken 27768 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
35600 for (size_t y = 0; y < m; ++y) {
609
4/8
✓ Branch 0 taken 6104 times.
✓ Branch 1 taken 405720 times.
✓ Branch 2 taken 27768 times.
✓ Branch 3 taken 3315612 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
3755204 for (size_t x = 0; x < n; ++x) {
610 3721332 DstData acc = 0;
611
612
4/8
✓ Branch 0 taken 13742344 times.
✓ Branch 1 taken 405720 times.
✓ Branch 2 taken 169751908 times.
✓ Branch 3 taken 3315612 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
187215584 for (size_t i = 0; i < k; ++i) {
613
2/8
✓ Branch 0 taken 13742344 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 169751908 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
183494252 const auto lhs_value = read_array<LhsData>(lhs_data, y * k + i);
614 183494252 const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width];
615
2/8
✓ Branch 0 taken 13742344 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 169751908 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
183494252 const auto lhs_zero_point = lhs_zero_points_ptr != nullptr
616 183494252 ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]
617 : 0;
618
619
2/8
✓ Branch 0 taken 13742344 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 169751908 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
183494252 const auto rhs_value = read_array<RhsData>(rhs_data, x + i * n);
620 183494252 const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width];
621
2/8
✗ Branch 0 not taken.
✓ Branch 1 taken 13742344 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 169751908 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
183494252 const auto rhs_zero_point = rhs_zero_points_ptr != nullptr
622 ? rhs_zero_points_ptr[y * rhs_num_quant_per_row + i / rhs_quant_width]
623 : 0;
624
625 183494252 acc += static_cast<DstData>(
626 197236596 (static_cast<IntAcc>(lhs_value) - static_cast<IntAcc>(lhs_zero_point)) *
627
1/6
✗ Branch 0 not taken.
✓ Branch 1 taken 169751908 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
197236596 (static_cast<IntAcc>(rhs_value) - static_cast<IntAcc>(rhs_zero_point))) *
628
0/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
366988504 static_cast<DstData>(lhs_scale) * static_cast<DstData>(rhs_scale);
629 183494252 }
630
631
3/8
✓ Branch 0 taken 405720 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 367206 times.
✓ Branch 3 taken 2948406 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
3721332 if (biases_ptr != nullptr) {
632 3354126 acc += static_cast<DstData>(biases_ptr[x]);
633 3354126 }
634
635
2/8
✓ Branch 0 taken 405720 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3315612 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
3721332 acc = std::clamp(acc, min_value, max_value);
636 3721332 dst_ptr[y * n + x] = acc;
637 3721332 }
638 33872 }
639
640 1728 return dst;
641 1728 }
642
643 template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>(
644 size_t m, size_t n, size_t k, //
645 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
646 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
647 const void* biases, //
648 float min_value, float max_value);
649 template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
650 size_t m, size_t n, size_t k, //
651 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
652 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
653 const void* biases, //
654 float min_value, float max_value);
655
656 template Buffer matmul_clamp_nt_nt<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>(
657 size_t m, size_t n, size_t k, //
658 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
659 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
660 const void* biases, //
661 float min_value, float max_value);
662
663 template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, int32_t, float>(
664 size_t m, size_t n, size_t k, //
665 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
666 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
667 const void* biases, //
668 float min_value, float max_value);
669
670 } // namespace kai::test
671