KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 89.2% 273 / 26 / 332
Functions: 80.0% 16 / 0 / 20
Branches: 42.5% 279 / 120 / 777

test/reference/matmul.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/matmul.hpp"
8
9 #include <algorithm>
10 #include <cstddef>
11 #include <cstdint>
12 #include <type_traits>
13
14 #include "kai/kai_common.h"
15 #include "test/common/buffer.hpp"
16 #include "test/common/data_format.hpp"
17 #include "test/common/data_type.hpp"
18 #include "test/common/float16.hpp"
19 #include "test/common/int4.hpp"
20 #include "test/common/memory.hpp"
21 #include "test/common/round.hpp"
22 #include "test/reference/binary_elementwise.hpp"
23 #include "test/reference/cast.hpp"
24 #include "test/reference/pack.hpp"
25 #include "test/reference/reduce.hpp"
26 #include "test/reference/transpose.hpp"
27
28 namespace kai::test {
29
30 namespace {
31
32 /// Matrix multiplication.
33 ///
34 /// @tparam T Data type.
35 ///
36 /// @param[in] lhs LHS operand data buffer.
37 /// @param[in] rhs RHS operand data buffer.
38 /// @param[in] m Output height.
39 /// @param[in] n Output width.
40 /// @param[in] k Non-transposed LHS width and non-transposed RHS height.
41 /// @param[in] lhs_transposed `true` if LHS operand is transposed.
42 /// @param[in] rhs_transposed `true` if RHS operand is transposed.
43 ///
44 /// @return The result data buffer.
45 template <typename In, typename Acc>
46 1655 Buffer matmul_any_type(
47 const void* lhs, const void* rhs, //
48 size_t m, size_t n, size_t k, //
49 bool lhs_transposed, bool rhs_transposed) {
50
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 811 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 844 times.
1655 const auto lhs_m_stride = lhs_transposed ? 1 : k;
51
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 811 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 844 times.
1655 const auto lhs_k_stride = lhs_transposed ? m : 1;
52
53
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 811 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 844 times.
1655 const auto rhs_n_stride = rhs_transposed ? k : 1;
54
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 811 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 844 times.
1655 const auto rhs_k_stride = rhs_transposed ? 1 : n;
55
56 1655 Buffer dst(m * n * size_in_bits<In> / 8);
57 KAI_ASSUME_ALWAYS(n * size_in_bits<In> % 8 == 0);
58
59
4/4
✓ Branch 0 taken 20034 times.
✓ Branch 1 taken 811 times.
✓ Branch 2 taken 22833 times.
✓ Branch 3 taken 844 times.
44522 for (size_t im = 0; im < m; ++im) {
60
4/4
✓ Branch 0 taken 20034 times.
✓ Branch 1 taken 1524926 times.
✓ Branch 2 taken 22833 times.
✓ Branch 3 taken 1626893 times.
3194686 for (size_t in = 0; in < n; ++in) {
61 3151819 Acc acc = Acc(0);
62
63
4/4
✓ Branch 0 taken 221798379 times.
✓ Branch 1 taken 1524926 times.
✓ Branch 2 taken 229649811 times.
✓ Branch 3 taken 1626893 times.
454600009 for (size_t ik = 0; ik < k; ++ik) {
64
2/4
✓ Branch 0 taken 221798379 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 229649811 times.
451448190 const auto lhs_value = read_array<In>(lhs, im * lhs_m_stride + ik * lhs_k_stride);
65
2/4
✓ Branch 0 taken 221798379 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 229649811 times.
✗ Branch 3 not taken.
451448190 const auto rhs_value = read_array<In>(rhs, in * rhs_n_stride + ik * rhs_k_stride);
66
2/4
✓ Branch 0 taken 229649811 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 229649811 times.
✗ Branch 3 not taken.
451448190 acc += static_cast<Acc>(lhs_value) * static_cast<Acc>(rhs_value);
67 451448190 }
68
69
3/6
✓ Branch 0 taken 1524926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1626893 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1626893 times.
✗ Branch 5 not taken.
3151819 write_array<In>(dst.data(), im * n + in, static_cast<In>(acc));
70 3151819 }
71 42867 }
72
73 1655 return dst;
74 1655 }
75
76 } // namespace
77
78 663 Buffer matmul_pack_rhs(
79 const void* data, const void* scales, const void* zero_points, const DataFormat& src_format,
80 const DataFormat& dst_format, size_t n, size_t k, bool transposing) {
81 663 const auto src_dt = src_format.data_type();
82 663 const auto src_pf = src_format.pack_format();
83
84 663 const auto dst_dt = dst_format.data_type();
85 663 const auto dst_pf = dst_format.pack_format();
86
87 663 Buffer tmp_data;
88 663 Buffer tmp_scales;
89 663 Buffer tmp_zero_points;
90
91
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 663 times.
663 if (transposing) {
92
1/2
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
663 tmp_data = transpose(data, src_dt, k, n);
93
1/2
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
663 data = tmp_data.data();
94 663 }
95
96
1/6
✗ Branch 0 not taken.
✓ Branch 1 taken 663 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
663 if (src_dt == DataType::QSU4 && src_pf == DataFormat::PackFormat::NONE && //
97 dst_dt == DataType::QSI4 && dst_pf == DataFormat::PackFormat::QUANTIZE_PER_ROW) {
98 // For this specific RHS format conversion:
99 //
100 // * 4-bit data is added by 8.
101 // * Scale is divided by 16.
102 // * Zero point is accumulation of all values in the same row.
103
104 KAI_ASSUME_ALWAYS(zero_points == nullptr);
105 const int32_t zero_point = 8;
106 const uint8_t zero_point_i4 = UInt4::pack_u8(UInt4(zero_point), UInt4(zero_point));
107 const int32_t row_zero_point = zero_point * static_cast<int32_t>(k);
108
109 KAI_ASSUME_ALWAYS(dst_format.subblock_width() > 0);
110 const auto subblock_width_i32 = static_cast<int32_t>(dst_format.subblock_width());
111 const auto subblock_width_f = static_cast<float>(dst_format.subblock_width());
112
113 tmp_zero_points = reduce_add(data, src_format, n, k, DataFormat(DataType::I32), 0);
114 tmp_zero_points = sub(tmp_zero_points.data(), DataType::I32, n, 1, &row_zero_point, DataType::I32, 1, 1);
115 tmp_zero_points = mul(tmp_zero_points.data(), DataType::I32, n, 1, &subblock_width_i32, DataType::I32, 1, 1);
116 zero_points = tmp_zero_points.data();
117
118 tmp_data = add(data, DataType::QSU4, n, k, &zero_point_i4, DataType::QSU4, 1, 1);
119 data = tmp_data.data();
120
121 tmp_scales = div(scales, DataType::FP32, n, 1, &subblock_width_f, DataType::FP32, 1, 1);
122 scales = tmp_scales.data();
123 }
124
125
1/2
✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
663 return pack(dst_format, data, scales, zero_points, src_format, n, k);
126 663 }
127
128 1655 Buffer matmul(
129 const void* lhs, [[maybe_unused]] const void* lhs_scales, [[maybe_unused]] const void* lhs_zero_points,
130 DataType lhs_dt, //
131 const void* rhs, [[maybe_unused]] const void* rhs_scales, [[maybe_unused]] const void* rhs_zero_points,
132 DataType rhs_dt, //
133 const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, //
134 DataType dst_dt, //
135 size_t m, size_t n, size_t k, //
136 bool lhs_transposed, bool rhs_transposed) {
137
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1655 times.
1655 const auto lhs_h = lhs_transposed ? k : m;
138
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1655 times.
1655 const auto lhs_w = lhs_transposed ? m : k;
139
140
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1655 times.
1655 const auto rhs_h = rhs_transposed ? n : k;
141
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1655 times.
1655 const auto rhs_w = rhs_transposed ? k : n;
142
143 1655 Buffer tmp_lhs;
144 1655 Buffer tmp_rhs;
145 1655 Buffer tmp_dst;
146 1655 Buffer tmp_bias;
147
148
1/2
✓ Branch 0 taken 1655 times.
✗ Branch 1 not taken.
1655 if (lhs_dt != dst_dt) {
149 tmp_lhs = cast(lhs, lhs_dt, dst_dt, lhs_h, lhs_w);
150 lhs = tmp_lhs.data();
151 }
152
153
1/2
✓ Branch 0 taken 1655 times.
✗ Branch 1 not taken.
1655 if (rhs_dt != dst_dt) {
154 tmp_rhs = cast(rhs, rhs_dt, dst_dt, rhs_h, rhs_w);
155 rhs = tmp_rhs.data();
156 }
157
158
2/3
✗ Branch 0 not taken.
✓ Branch 1 taken 811 times.
✓ Branch 2 taken 844 times.
1655 switch (dst_dt) {
159 case DataType::FP32:
160
1/2
✓ Branch 0 taken 811 times.
✗ Branch 1 not taken.
811 tmp_dst = matmul_any_type<float, float>(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed);
161 811 break;
162
163 case DataType::FP16:
164
1/2
✓ Branch 0 taken 844 times.
✗ Branch 1 not taken.
844 tmp_dst = matmul_any_type<Float16, float>(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed);
165 844 break;
166
167 default:
168 KAI_ERROR("Unknown data type!");
169 }
170
171
2/2
✓ Branch 0 taken 54 times.
✓ Branch 1 taken 1601 times.
1655 if (bias != nullptr) {
172 KAI_ASSUME_ALWAYS(!data_type_is_quantized(bias_dt));
173 KAI_ASSUME_ALWAYS(bias_scales == nullptr);
174 KAI_ASSUME_ALWAYS(bias_zero_points == nullptr);
175
176 // Add bias in f32 to reduce precision loss.
177
2/2
✓ Branch 0 taken 811 times.
✓ Branch 1 taken 790 times.
1601 if (dst_dt != DataType::FP32) {
178
2/4
✓ Branch 0 taken 790 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 790 times.
✗ Branch 3 not taken.
790 tmp_dst = cast(tmp_dst.data(), dst_dt, DataType::FP32, m, n);
179 790 }
180
2/2
✓ Branch 0 taken 811 times.
✓ Branch 1 taken 790 times.
1601 if (bias_dt != DataType::FP32) {
181
1/2
✓ Branch 0 taken 790 times.
✗ Branch 1 not taken.
790 tmp_bias = cast(bias, bias_dt, DataType::FP32, 1, n);
182
1/2
✓ Branch 0 taken 790 times.
✗ Branch 1 not taken.
790 bias = tmp_bias.data();
183 790 }
184
2/4
✓ Branch 0 taken 1601 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1601 times.
✗ Branch 3 not taken.
1601 tmp_dst = add(tmp_dst.data(), DataType::FP32, m, n, bias, DataType::FP32, 1, n);
185
2/2
✓ Branch 0 taken 811 times.
✓ Branch 1 taken 790 times.
1601 if (dst_dt != DataType::FP32) {
186
2/4
✓ Branch 0 taken 790 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 790 times.
✗ Branch 3 not taken.
790 tmp_dst = cast(tmp_dst.data(), DataType::FP32, dst_dt, m, n);
187 790 }
188 1601 }
189
190 1655 return tmp_dst;
191 1655 }
192
193 884 Buffer indirect_matmul(
194 const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales,
195 const void* lhs_zero_points,
196 DataType lhs_dt, //
197 const void* rhs, const void* rhs_scales, const void* rhs_zero_points,
198 DataType rhs_dt, //
199 const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, //
200 DataType dst_dt, //
201 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length) {
202 // This is inefficient, but allows code-reuse
203 884 const size_t chunk_bytes = k_chunk_length * round_up_division(data_type_size_in_bits(lhs_dt), 8);
204 884 const size_t n_chunks = m * k_chunk_count;
205 884 Buffer lhs(n_chunks * chunk_bytes);
206
207 884 const uintptr_t lhs_padding_ptr_uint = reinterpret_cast<uintptr_t>(lhs_padding_ptr);
208
209 // Copy all chunks to the created matrix
210
2/2
✓ Branch 0 taken 333242 times.
✓ Branch 1 taken 884 times.
334126 for (size_t i = 0; i < n_chunks; i += 1) {
211 333242 uintptr_t src_pointer = reinterpret_cast<uintptr_t>(lhs_idata[i]);
212
2/2
✓ Branch 0 taken 9516 times.
✓ Branch 1 taken 323726 times.
333242 if (src_pointer != lhs_padding_ptr_uint) {
213 323726 src_pointer += lhs_offset;
214 323726 }
215 333242 memcpy(
216
1/2
✓ Branch 0 taken 333242 times.
✗ Branch 1 not taken.
333242 lhs.data() + i * chunk_bytes, reinterpret_cast<const void*>(src_pointer),
217 333242 chunk_bytes); // NOLINT(performance-no-int-to-ptr)
218 333242 }
219
220
1/2
✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
884 return matmul(
221
1/2
✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
884 lhs.data(), lhs_scales, lhs_zero_points, lhs_dt, //
222 884 rhs, rhs_scales, rhs_zero_points, rhs_dt, //
223 884 bias, bias_scales, bias_zero_points, bias_dt, //
224 884 dst_dt, m, n, k_chunk_count * k_chunk_length, false, false);
225 884 }
226
227 template <
228 typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
229 typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData>
230 521 Buffer indirect_matmul_nt_t_quantized(
231 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, //
232 const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales,
233 const void* lhs_zero_points, size_t lhs_quant_height,
234 size_t lhs_quant_width, //
235 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
236 size_t rhs_quant_width, //
237 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width) {
238 KAI_ASSUME_ALWAYS(lhs_quant_width != 0);
239 KAI_ASSUME_ALWAYS(rhs_quant_width != 0);
240 KAI_ASSUME_ALWAYS(lhs_quant_height != 0);
241 KAI_ASSUME_ALWAYS(rhs_quant_height != 0);
242 KAI_ASSUME_ALWAYS(bias_quant_width != 0);
243 521 const auto lhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, lhs_quant_width);
244 521 const auto rhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, rhs_quant_width);
245
246 521 Buffer dst(m * n * sizeof(DstData));
247
248
2/2
✓ Branch 0 taken 11441 times.
✓ Branch 1 taken 521 times.
11962 for (size_t i_m = 0; i_m < m; ++i_m) {
249
2/2
✓ Branch 0 taken 11441 times.
✓ Branch 1 taken 929839 times.
941280 for (size_t i_n = 0; i_n < n; ++i_n) {
250 929839 DstData acc = 0;
251
252
2/2
✓ Branch 0 taken 9447359 times.
✓ Branch 1 taken 929839 times.
10377198 for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; ++i_k_chunk) {
253 // Calculate the K chunk pointer. Apply offset if this is not padding
254 9447359 const size_t k_chunk_idx = i_m * k_chunk_count + i_k_chunk;
255 9447359 const void* k_chunk_ptr = lhs_ptrs[k_chunk_idx];
256
2/2
✓ Branch 0 taken 562940 times.
✓ Branch 1 taken 8884419 times.
9447359 if (k_chunk_ptr != lhs_padding_ptr) {
257 8884419 k_chunk_ptr = reinterpret_cast<const void*>(reinterpret_cast<uintptr_t>(k_chunk_ptr) + lhs_offset);
258 8884419 }
259
260
2/2
✓ Branch 0 taken 61909478 times.
✓ Branch 1 taken 9447359 times.
71356837 for (size_t i_k_chunk_len = 0; i_k_chunk_len < k_chunk_length; ++i_k_chunk_len) {
261 61909478 const size_t i = i_k_chunk * k_chunk_length + i_k_chunk_len;
262
263 61909478 const auto lhs_data_index = i_k_chunk_len;
264 61909478 const auto lhs_quant_index = (i_m / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width;
265
1/2
✓ Branch 0 taken 61909478 times.
✗ Branch 1 not taken.
61909478 const auto lhs_value = read_array<LhsData>(k_chunk_ptr, lhs_data_index);
266
2/4
✓ Branch 0 taken 61909478 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 61909478 times.
61909478 const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index)
267 : static_cast<LhsScale>(1);
268
1/2
✓ Branch 0 taken 61909478 times.
✗ Branch 1 not taken.
123818956 const auto lhs_zero_point = lhs_zero_points != nullptr
269
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 61909478 times.
61909478 ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index)
270 : static_cast<LhsZeroPoint>(0);
271
272 61909478 const auto rhs_data_index = i_n * (k_chunk_count * k_chunk_length) + i;
273 61909478 const auto rhs_quant_index = (i_n / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width;
274
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 61909478 times.
61909478 const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index);
275
2/4
✓ Branch 0 taken 61909478 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 61909478 times.
61909478 const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index)
276 : static_cast<RhsScale>(1);
277
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 61909478 times.
61909478 const auto rhs_zero_point = rhs_zero_points != nullptr
278 ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index)
279 : static_cast<RhsZeroPoint>(0);
280
281 185728434 acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) *
282 123818956 static_cast<DstData>(lhs_scale) *
283 61909478 (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) *
284 61909478 static_cast<DstData>(rhs_scale);
285 61909478 }
286 9447359 }
287
288
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 929839 times.
929839 if (bias_data != nullptr) {
289
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 929839 times.
929839 const auto bias_value = read_array<BiasData>(bias_data, i_n);
290
1/2
✓ Branch 0 taken 929839 times.
✗ Branch 1 not taken.
1859678 const auto bias_scale = bias_scales != nullptr
291
1/2
✓ Branch 0 taken 929839 times.
✗ Branch 1 not taken.
929839 ? read_array<BiasScale>(bias_scales, i_n / bias_quant_width)
292 : static_cast<BiasScale>(1);
293
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 929839 times.
929839 const auto bias_zero_point = bias_zero_points != nullptr
294 ? read_array<BiasZeroPoint>(bias_zero_points, i_n / bias_quant_width)
295 : static_cast<BiasZeroPoint>(0);
296
297 1859678 acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) *
298 929839 static_cast<DstData>(bias_scale);
299 929839 }
300
301
2/4
✓ Branch 0 taken 929839 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 929839 times.
✗ Branch 3 not taken.
929839 write_array<DstData>(dst.data(), i_m * n + i_n, acc);
302 929839 }
303 11441 }
304
305 521 return dst;
306 521 }
307
308 template <
309 typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
310 typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData>
311 26832 Buffer matmul_nt_t_quantized(
312 size_t m, size_t n, size_t k, //
313 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, //
314 size_t lhs_quant_height, size_t lhs_quant_width, //
315 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, //
316 size_t rhs_quant_height, size_t rhs_quant_width, //
317 const void* bias_data, const void* bias_scales, const void* bias_zero_points, //
318 size_t bias_quant_width) {
319 KAI_ASSUME_ALWAYS(lhs_quant_width != 0);
320 KAI_ASSUME_ALWAYS(rhs_quant_width != 0);
321 KAI_ASSUME_ALWAYS(lhs_quant_height != 0);
322 KAI_ASSUME_ALWAYS(rhs_quant_height != 0);
323 KAI_ASSUME_ALWAYS(bias_quant_width != 0);
324
325 26832 const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width);
326 26832 const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width);
327
328 26832 Buffer dst(m * n * sizeof(DstData));
329
330
6/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 26064 times.
✓ Branch 3 taken 265116 times.
✓ Branch 4 taken 192 times.
✓ Branch 5 taken 2748 times.
✓ Branch 6 taken 576 times.
✓ Branch 7 taken 51051 times.
345747 for (size_t row = 0; row < m; ++row) {
331
6/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 19402236 times.
✓ Branch 3 taken 265116 times.
✓ Branch 4 taken 2748 times.
✓ Branch 5 taken 219744 times.
✓ Branch 6 taken 51051 times.
✓ Branch 7 taken 6371326 times.
26312221 for (size_t col = 0; col < n; ++col) {
332 25993306 DstData acc = 0;
333
334
6/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 1174598280 times.
✓ Branch 3 taken 19402236 times.
✓ Branch 4 taken 10572468 times.
✓ Branch 5 taken 219744 times.
✓ Branch 6 taken 1984792064 times.
✓ Branch 7 taken 6371326 times.
3195956118 for (size_t i = 0; i < k; ++i) {
335 3169962812 const auto lhs_data_index = row * k + i;
336 3169962812 const auto lhs_quant_index = (row / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width;
337
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 1174598280 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10572468 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1984792064 times.
✗ Branch 7 not taken.
3169962812 const auto lhs_value = read_array<LhsData>(lhs_data, lhs_data_index);
338
6/16
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 1174598280 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1174598280 times.
✓ Branch 8 taken 10572468 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 10572468 times.
✓ Branch 12 taken 1984792064 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1984792064 times.
3169962812 const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index)
339 : static_cast<LhsScale>(1);
340
4/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 27671688 times.
✓ Branch 3 taken 1146926592 times.
✓ Branch 4 taken 10572468 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1984792064 times.
✗ Branch 7 not taken.
5192999032 const auto lhs_zero_point = lhs_zero_points != nullptr
341
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 27671688 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 10572468 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1984792064 times.
2023036220 ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index)
342 : static_cast<LhsZeroPoint>(0);
343
344 3169962812 const auto rhs_data_index = col * k + i;
345 3169962812 const auto rhs_quant_index = (col / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width;
346
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1174598280 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 10572468 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1984792064 times.
3169962812 const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index);
347
6/16
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 1174598280 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1174598280 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 10572468 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 10572468 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1984792064 times.
✓ Branch 14 taken 1984792064 times.
✗ Branch 15 not taken.
3169962812 const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index)
348 : static_cast<RhsScale>(1);
349
4/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 1146926592 times.
✓ Branch 3 taken 27671688 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 10572468 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1984792064 times.
4316889404 const auto rhs_zero_point = rhs_zero_points != nullptr
350
1/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 1146926592 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
1146926592 ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index)
351 : static_cast<RhsZeroPoint>(0);
352
353 3509139990 acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) *
354 3180535280 static_cast<DstData>(lhs_scale) *
355
2/4
✓ Branch 0 taken 1174598280 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1984792064 times.
✗ Branch 3 not taken.
3169962812 (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) *
356
1/2
✓ Branch 0 taken 1984792064 times.
✗ Branch 1 not taken.
3169962812 static_cast<DstData>(rhs_scale);
357 3169962812 }
358
359
5/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 9701118 times.
✓ Branch 3 taken 9701118 times.
✓ Branch 4 taken 109872 times.
✓ Branch 5 taken 109872 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6371326 times.
25993306 if (bias_data != nullptr) {
360
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 9701118 times.
✓ Branch 4 taken 109872 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 6371326 times.
16182316 const auto bias_value = read_array<BiasData>(bias_data, col);
361
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 9701118 times.
✓ Branch 4 taken 109872 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 6371326 times.
16182316 const auto bias_scale = bias_scales != nullptr
362 ? read_array<BiasScale>(bias_scales, col / bias_quant_width)
363 : static_cast<BiasScale>(1);
364
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 9701118 times.
✓ Branch 4 taken 109872 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 6371326 times.
16182316 const auto bias_zero_point = bias_zero_points != nullptr
365 ? read_array<BiasZeroPoint>(bias_zero_points, col / bias_quant_width)
366 : static_cast<BiasZeroPoint>(0);
367
368 32364632 acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) *
369 16182316 static_cast<DstData>(bias_scale);
370 16182316 }
371
372
6/16
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 19402236 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19402236 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 219744 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 219744 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 6371326 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 6371326 times.
✗ Branch 15 not taken.
25993306 write_array<DstData>(dst.data(), row * n + col, acc);
373 25993306 }
374 318915 }
375
376 26832 return dst;
377 26832 }
378
379 template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>(
380 size_t m, size_t n, size_t k, //
381 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
382 size_t lhs_quant_width, //
383 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
384 size_t rhs_quant_width, //
385 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
386
387 template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>(
388 size_t m, size_t n, size_t k, //
389 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
390 size_t lhs_quant_width, //
391 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
392 size_t rhs_quant_width, //
393 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
394
395 template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, float, float, int32_t, float>(
396 size_t m, size_t n, size_t k, //
397 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
398 size_t lhs_quant_width, //
399 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
400 size_t rhs_quant_width, //
401 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
402
403 template Buffer
404 matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
405 size_t m, size_t n, size_t k, //
406 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
407 size_t lhs_quant_width, //
408 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
409 size_t rhs_quant_width, //
410 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
411
412 template <
413 typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
414 typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData>
415 972 Buffer matmul_nt_nt_quantized(
416 size_t m, size_t n, size_t k, //
417 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, //
418 size_t lhs_quant_height, size_t lhs_quant_width, //
419 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, //
420 size_t rhs_quant_height, size_t rhs_quant_width, //
421 const void* bias_data, const void* bias_scales, const void* bias_zero_points, //
422 size_t bias_quant_width) {
423 KAI_ASSUME_ALWAYS(lhs_quant_width != 0);
424 KAI_ASSUME_ALWAYS(rhs_quant_width != 0);
425 KAI_ASSUME_ALWAYS(lhs_quant_height != 0);
426 KAI_ASSUME_ALWAYS(rhs_quant_height != 0);
427 KAI_ASSUME_ALWAYS(bias_quant_width != 0);
428
429 972 const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width);
430 972 const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width);
431
432 972 Buffer dst(m * n * sizeof(DstData));
433
434
4/4
✓ Branch 0 taken 558 times.
✓ Branch 1 taken 51033 times.
✓ Branch 2 taken 414 times.
✓ Branch 3 taken 44721 times.
96726 for (size_t row = 0; row < m; ++row) {
435
4/4
✓ Branch 0 taken 51033 times.
✓ Branch 1 taken 6362191 times.
✓ Branch 2 taken 2838231 times.
✓ Branch 3 taken 44721 times.
9296176 for (size_t col = 0; col < n; ++col) {
436 9200422 DstData acc = 0;
437
438
4/4
✓ Branch 0 taken 1981913664 times.
✓ Branch 1 taken 6362191 times.
✓ Branch 2 taken 774427959 times.
✓ Branch 3 taken 2838231 times.
2765542045 for (size_t i = 0; i < k; ++i) {
439 2756341623 const auto lhs_data_index = row * k + i;
440 2756341623 const auto lhs_quant_index = (row / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width;
441
2/4
✓ Branch 0 taken 1981913664 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 774427959 times.
2756341623 const auto lhs_value = read_array<LhsData>(lhs_data, lhs_data_index);
442
3/8
✓ Branch 0 taken 1981913664 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1981913664 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 774427959 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
2756341623 const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index)
443 : static_cast<LhsScale>(1);
444
2/4
✓ Branch 0 taken 1981913664 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 774427959 times.
4738255287 const auto lhs_zero_point = lhs_zero_points != nullptr
445
1/4
✗ Branch 0 not taken.
✓ Branch 1 taken 1981913664 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
1981913664 ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index)
446 : static_cast<LhsZeroPoint>(0);
447
448 2756341623 const auto rhs_data_index = col + i * n;
449 2756341623 const auto rhs_quant_index = (col / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width;
450
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 1981913664 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 774427959 times.
2756341623 const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index);
451
3/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1981913664 times.
✓ Branch 2 taken 1981913664 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 774427959 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
2756341623 const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index)
452 : static_cast<RhsScale>(1);
453
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 1981913664 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 774427959 times.
2756341623 const auto rhs_zero_point = rhs_zero_points != nullptr
454 ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index)
455 : static_cast<RhsZeroPoint>(0);
456
457
1/2
✓ Branch 0 taken 774427959 times.
✗ Branch 1 not taken.
6720168951 acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) *
458 2756341623 static_cast<DstData>(lhs_scale) *
459
2/4
✓ Branch 0 taken 1981913664 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 774427959 times.
✗ Branch 3 not taken.
2756341623 (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) *
460
1/2
✓ Branch 0 taken 1981913664 times.
✗ Branch 1 not taken.
2756341623 static_cast<DstData>(rhs_scale);
461 2756341623 }
462
463
3/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6362191 times.
✓ Branch 2 taken 1263918 times.
✓ Branch 3 taken 1574313 times.
9200422 if (bias_data != nullptr) {
464
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6362191 times.
✓ Branch 2 taken 1574313 times.
✗ Branch 3 not taken.
7936504 const auto bias_value = read_array<BiasData>(bias_data, col);
465
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6362191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1574313 times.
7936504 const auto bias_scale = bias_scales != nullptr
466 ? read_array<BiasScale>(bias_scales, col / bias_quant_width)
467 : static_cast<BiasScale>(1);
468
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6362191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1574313 times.
7936504 const auto bias_zero_point = bias_zero_points != nullptr
469 ? read_array<BiasZeroPoint>(bias_zero_points, col / bias_quant_width)
470 : static_cast<BiasZeroPoint>(0);
471
472 15873008 acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) *
473 7936504 static_cast<DstData>(bias_scale);
474 7936504 }
475
476
4/8
✓ Branch 0 taken 6362191 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6362191 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2838231 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2838231 times.
✗ Branch 7 not taken.
9200422 write_array<DstData>(dst.data(), row * n + col, acc);
477 9200422 }
478 95754 }
479
480 972 return dst;
481 972 }
482
483 template Buffer
484 matmul_nt_nt_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>(
485 size_t m, size_t n, size_t k, //
486 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
487 size_t lhs_quant_width, //
488 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
489 size_t rhs_quant_width, //
490 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
491
492 template Buffer matmul_nt_nt_quantized<BFloat16<>, float, float, BFloat16<>, float, float, float, float, float, float>(
493 size_t m, size_t n, size_t k, //
494 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height,
495 size_t lhs_quant_width, //
496 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
497 size_t rhs_quant_width, //
498 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
499
500 template Buffer
501 indirect_matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>(
502 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, //
503 const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales,
504 const void* lhs_zero_points, size_t lhs_quant_height,
505 size_t lhs_quant_width, //
506 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height,
507 size_t rhs_quant_width, //
508 const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width);
509
510 template <
511 typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
512 typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData>
513 5786 Buffer matmul_clamp_nt_t(
514 size_t m, size_t n, size_t k, //
515 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
516 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
517 const void* biases, //
518 DstData min_value, DstData max_value) {
519 KAI_ASSUME_ALWAYS(lhs_quant_width != 0);
520 KAI_ASSUME_ALWAYS(rhs_quant_width != 0);
521 5786 const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width);
522 5786 const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width);
523
524 5786 Buffer dst(m * n * sizeof(DstData));
525
526 5786 const auto* lhs_scales_ptr = reinterpret_cast<const LhsScale*>(lhs_scales);
527 5786 const auto* rhs_scales_ptr = reinterpret_cast<const RhsScale*>(rhs_scales);
528 5786 const auto* lhs_zero_points_ptr = reinterpret_cast<const LhsZeroPoint*>(lhs_zero_points);
529 5786 const auto* rhs_zero_points_ptr = reinterpret_cast<const RhsZeroPoint*>(rhs_zero_points);
530 5786 const auto* biases_ptr = reinterpret_cast<const Bias*>(biases);
531
3/8
✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
5786 auto* dst_ptr = reinterpret_cast<DstData*>(dst.data());
532
533
6/8
✓ Branch 0 taken 4800 times.
✓ Branch 1 taken 125760 times.
✓ Branch 2 taken 16972 times.
✓ Branch 3 taken 920 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✓ Branch 7 taken 654 times.
149172 for (size_t y = 0; y < m; ++y) {
534
6/8
✓ Branch 0 taken 125760 times.
✓ Branch 1 taken 17208000 times.
✓ Branch 2 taken 16972 times.
✓ Branch 3 taken 1304540 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 654 times.
✓ Branch 7 taken 43470 times.
18699396 for (size_t x = 0; x < n; ++x) {
535 18556010 DstData acc = 0;
536
537
6/8
✓ Branch 0 taken 865691520 times.
✓ Branch 1 taken 17208000 times.
✓ Branch 2 taken 82565120 times.
✓ Branch 3 taken 1304540 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1472394 times.
✓ Branch 7 taken 43470 times.
968285044 for (size_t i = 0; i < k; ++i) {
538
3/8
✓ Branch 0 taken 865691520 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 82565120 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1472394 times.
✗ Branch 7 not taken.
949729034 const auto lhs_value = read_array<LhsData>(lhs_data, y * k + i);
539 949729034 const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width];
540
3/8
✓ Branch 0 taken 865691520 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 82565120 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1472394 times.
✗ Branch 7 not taken.
949729034 const auto lhs_zero_point = lhs_zero_points_ptr != nullptr
541 867163914 ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]
542 : 0;
543
544
3/8
✗ Branch 0 not taken.
✓ Branch 1 taken 865691520 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 82565120 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1472394 times.
✗ Branch 7 not taken.
949729034 const auto rhs_value = read_array<RhsData>(rhs_data, x * k + i);
545 949729034 const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width];
546
3/8
✗ Branch 0 not taken.
✓ Branch 1 taken 865691520 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 82565120 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1472394 times.
949729034 const auto rhs_zero_point = rhs_zero_points_ptr != nullptr
547 ? rhs_zero_points_ptr[y * rhs_num_quant_per_row + i / rhs_quant_width]
548 : 0;
549
550 949729034 acc += static_cast<DstData>(
551 951201428 (static_cast<IntAcc>(lhs_value) - static_cast<IntAcc>(lhs_zero_point)) *
552
2/6
✗ Branch 0 not taken.
✓ Branch 1 taken 865691520 times.
✓ Branch 2 taken 82565120 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
951201428 (static_cast<IntAcc>(rhs_value) - static_cast<IntAcc>(rhs_zero_point))) *
553
2/8
✓ Branch 0 taken 82565120 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 82565120 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
1816892948 static_cast<DstData>(lhs_scale) * static_cast<DstData>(rhs_scale);
554 949729034 }
555
556
3/8
✗ Branch 0 not taken.
✓ Branch 1 taken 17208000 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1304540 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 43470 times.
✗ Branch 7 not taken.
18556010 if (biases_ptr != nullptr) {
557 17251470 acc += static_cast<DstData>(biases_ptr[x]);
558 17251470 }
559
560
3/8
✓ Branch 0 taken 17208000 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1304540 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 43470 times.
✗ Branch 7 not taken.
18556010 acc = std::clamp(acc, min_value, max_value);
561 18556010 dst_ptr[y * n + x] = acc;
562 18556010 }
563 143386 }
564
565 5786 return dst;
566 5786 }
567
568 template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
569 size_t m, size_t n, size_t k, //
570 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
571 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
572 const void* biases, //
573 float min_value, float max_value);
574
575 template Buffer matmul_clamp_nt_t<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>(
576 size_t m, size_t n, size_t k, //
577 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
578 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
579 const void* biases, //
580 float min_value, float max_value);
581
582 template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, int32_t, float>(
583 size_t m, size_t n, size_t k, //
584 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
585 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
586 const void* biases, //
587 float min_value, float max_value);
588
589 template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>(
590 size_t m, size_t n, size_t k, //
591 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
592 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
593 const void* biases, //
594 float min_value, float max_value);
595
596 template <
597 typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
598 typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData>
599 4554 Buffer matmul_clamp_nt_nt(
600 size_t m, size_t n, size_t k, //
601 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
602 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
603 const void* biases, //
604 DstData min_value, DstData max_value) {
605 KAI_ASSUME_ALWAYS(lhs_quant_width != 0);
606 KAI_ASSUME_ALWAYS(rhs_quant_width != 0);
607 4554 const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width);
608 4554 const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width);
609
610 4554 Buffer dst(m * n * sizeof(DstData));
611
612 4554 const auto* lhs_scales_ptr = reinterpret_cast<const LhsScale*>(lhs_scales);
613 4554 const auto* rhs_scales_ptr = reinterpret_cast<const RhsScale*>(rhs_scales);
614 4554 const auto* lhs_zero_points_ptr = reinterpret_cast<const LhsZeroPoint*>(lhs_zero_points);
615 4554 const auto* rhs_zero_points_ptr = reinterpret_cast<const RhsZeroPoint*>(rhs_zero_points);
616 4554 const auto* biases_ptr = reinterpret_cast<const Bias*>(biases);
617
2/8
✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4488 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
4554 auto* dst_ptr = reinterpret_cast<DstData*>(dst.data());
618
619
4/8
✓ Branch 0 taken 66 times.
✓ Branch 1 taken 654 times.
✓ Branch 2 taken 4488 times.
✓ Branch 3 taken 117000 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
122208 for (size_t y = 0; y < m; ++y) {
620
4/8
✓ Branch 0 taken 654 times.
✓ Branch 1 taken 43470 times.
✓ Branch 2 taken 117000 times.
✓ Branch 3 taken 15801948 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
15963072 for (size_t x = 0; x < n; ++x) {
621 15845418 DstData acc = 0;
622
623
4/8
✓ Branch 0 taken 1472394 times.
✓ Branch 1 taken 43470 times.
✓ Branch 2 taken 796221588 times.
✓ Branch 3 taken 15801948 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
813539400 for (size_t i = 0; i < k; ++i) {
624
2/8
✓ Branch 0 taken 1472394 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 796221588 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
797693982 const auto lhs_value = read_array<LhsData>(lhs_data, y * k + i);
625 797693982 const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width];
626
2/8
✓ Branch 0 taken 1472394 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 796221588 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
797693982 const auto lhs_zero_point = lhs_zero_points_ptr != nullptr
627 797693982 ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]
628 : 0;
629
630
2/8
✓ Branch 0 taken 1472394 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 796221588 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
797693982 const auto rhs_value = read_array<RhsData>(rhs_data, x + i * n);
631 797693982 const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width];
632
2/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1472394 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 796221588 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
797693982 const auto rhs_zero_point = rhs_zero_points_ptr != nullptr
633 ? rhs_zero_points_ptr[y * rhs_num_quant_per_row + i / rhs_quant_width]
634 : 0;
635
636 797693982 acc += static_cast<DstData>(
637 799166376 (static_cast<IntAcc>(lhs_value) - static_cast<IntAcc>(lhs_zero_point)) *
638
1/6
✗ Branch 0 not taken.
✓ Branch 1 taken 796221588 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
799166376 (static_cast<IntAcc>(rhs_value) - static_cast<IntAcc>(rhs_zero_point))) *
639
0/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
1595387964 static_cast<DstData>(lhs_scale) * static_cast<DstData>(rhs_scale);
640 797693982 }
641
642
3/8
✓ Branch 0 taken 43470 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 157374 times.
✓ Branch 3 taken 15644574 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
15845418 if (biases_ptr != nullptr) {
643 15688044 acc += static_cast<DstData>(biases_ptr[x]);
644 15688044 }
645
646
2/8
✓ Branch 0 taken 43470 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 15801948 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
15845418 acc = std::clamp(acc, min_value, max_value);
647 15845418 dst_ptr[y * n + x] = acc;
648 15845418 }
649 117654 }
650
651 4554 return dst;
652 4554 }
653
654 template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>(
655 size_t m, size_t n, size_t k, //
656 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
657 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
658 const void* biases, //
659 float min_value, float max_value);
660 template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
661 size_t m, size_t n, size_t k, //
662 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
663 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
664 const void* biases, //
665 float min_value, float max_value);
666
667 template Buffer matmul_clamp_nt_nt<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>(
668 size_t m, size_t n, size_t k, //
669 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
670 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
671 const void* biases, //
672 float min_value, float max_value);
673
674 template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, int32_t, float>(
675 size_t m, size_t n, size_t k, //
676 const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, //
677 const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, //
678 const void* biases, //
679 float min_value, float max_value);
680
681 } // namespace kai::test
682