Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/reference/matmul.hpp" | ||
8 | |||
9 | #include <algorithm> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | |||
13 | #include "kai/kai_common.h" | ||
14 | #include "test/common/buffer.hpp" | ||
15 | #include "test/common/data_format.hpp" | ||
16 | #include "test/common/data_type.hpp" | ||
17 | #include "test/common/float16.hpp" | ||
18 | #include "test/common/int4.hpp" | ||
19 | #include "test/common/memory.hpp" | ||
20 | #include "test/common/round.hpp" | ||
21 | #include "test/reference/binary_elementwise.hpp" | ||
22 | #include "test/reference/cast.hpp" | ||
23 | #include "test/reference/pack.hpp" | ||
24 | #include "test/reference/reduce.hpp" | ||
25 | #include "test/reference/transpose.hpp" | ||
26 | |||
27 | namespace kai::test { | ||
28 | |||
29 | namespace { | ||
30 | |||
31 | /// Matrix multiplication. | ||
32 | /// | ||
33 | /// @tparam T Data type. | ||
34 | /// | ||
35 | /// @param[in] lhs LHS operand data buffer. | ||
36 | /// @param[in] rhs RHS operand data buffer. | ||
37 | /// @param[in] m Output height. | ||
38 | /// @param[in] n Output width. | ||
39 | /// @param[in] k Non-transposed LHS width and non-transposed RHS height. | ||
40 | /// @param[in] lhs_transposed `true` if LHS operand is transposed. | ||
41 | /// @param[in] rhs_transposed `true` if RHS operand is transposed. | ||
42 | /// | ||
43 | /// @return The result data buffer. | ||
44 | template <typename T> | ||
45 | 2387 | Buffer matmul_any_type( | |
46 | const void* lhs, const void* rhs, // | ||
47 | size_t m, size_t n, size_t k, // | ||
48 | bool lhs_transposed, bool rhs_transposed) { | ||
49 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1196 times.
|
2387 | const auto lhs_m_stride = lhs_transposed ? 1 : k; |
50 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1196 times.
|
2387 | const auto lhs_k_stride = lhs_transposed ? m : 1; |
51 | |||
52 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1196 times.
|
2387 | const auto rhs_n_stride = rhs_transposed ? k : 1; |
53 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1196 times.
|
2387 | const auto rhs_k_stride = rhs_transposed ? 1 : n; |
54 | |||
55 | 2387 | Buffer dst(m * n * size_in_bits<T> / 8); | |
56 | − | KAI_ASSUME(n * size_in_bits<T> % 8 == 0); | |
57 | |||
58 |
4/4✓ Branch 0 taken 36473 times.
✓ Branch 1 taken 1191 times.
✓ Branch 2 taken 37072 times.
✓ Branch 3 taken 1196 times.
|
75932 | for (size_t im = 0; im < m; ++im) { |
59 |
4/4✓ Branch 0 taken 36473 times.
✓ Branch 1 taken 2892060 times.
✓ Branch 2 taken 37072 times.
✓ Branch 3 taken 2916561 times.
|
5882166 | for (size_t in = 0; in < n; ++in) { |
60 |
1/2✓ Branch 0 taken 2916561 times.
✗ Branch 1 not taken.
|
5808621 | T acc{0}; |
61 | |||
62 |
4/4✓ Branch 0 taken 599017351 times.
✓ Branch 1 taken 2892060 times.
✓ Branch 2 taken 600720321 times.
✓ Branch 3 taken 2916561 times.
|
1205546293 | for (size_t ik = 0; ik < k; ++ik) { |
63 |
2/4✓ Branch 0 taken 599017351 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 600720321 times.
|
1199737672 | const auto lhs_value = read_array<T>(lhs, im * lhs_m_stride + ik * lhs_k_stride); |
64 |
2/4✓ Branch 0 taken 599017351 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600720321 times.
✗ Branch 3 not taken.
|
1199737672 | const auto rhs_value = read_array<T>(rhs, in * rhs_n_stride + ik * rhs_k_stride); |
65 |
2/4✓ Branch 0 taken 600720321 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600720321 times.
✗ Branch 3 not taken.
|
1199737672 | acc += lhs_value * rhs_value; |
66 | 1199737672 | } | |
67 | |||
68 |
2/4✓ Branch 0 taken 2892060 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2916561 times.
✗ Branch 3 not taken.
|
5808621 | write_array<T>(dst.data(), im * n + in, acc); |
69 | 5808621 | } | |
70 | 73545 | } | |
71 | |||
72 | 2387 | return dst; | |
73 | 2387 | } | |
74 | |||
75 | } // namespace | ||
76 | |||
77 | 125 | Buffer matmul_pack_rhs( | |
78 | const void* data, const void* scales, const void* zero_points, const DataFormat& src_format, | ||
79 | const DataFormat& dst_format, size_t n, size_t k, bool transposing) { | ||
80 | 125 | const auto src_dt = src_format.data_type(); | |
81 | 125 | const auto src_pf = src_format.pack_format(); | |
82 | |||
83 | 125 | const auto dst_dt = dst_format.data_type(); | |
84 | 125 | const auto dst_pf = dst_format.pack_format(); | |
85 | |||
86 | 125 | Buffer tmp_data; | |
87 | 125 | Buffer tmp_scales; | |
88 | 125 | Buffer tmp_zero_points; | |
89 | |||
90 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 125 times.
|
125 | if (transposing) { |
91 |
1/2✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
|
125 | tmp_data = transpose(data, src_dt, k, n); |
92 |
1/2✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
|
125 | data = tmp_data.data(); |
93 | 125 | } | |
94 | |||
95 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 125 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
125 | if (src_dt == DataType::QSU4 && src_pf == DataFormat::PackFormat::NONE && // |
96 | ✗ | dst_dt == DataType::QSI4 && dst_pf == DataFormat::PackFormat::QUANTIZE_PER_ROW) { | |
97 | // For this specific RHS format conversion: | ||
98 | // | ||
99 | // * 4-bit data is added by 8. | ||
100 | // * Scale is divided by 16. | ||
101 | // * Zero point is accumulation of all values in the same row. | ||
102 | |||
103 | − | KAI_ASSUME(zero_points == nullptr); | |
104 | ✗ | const int32_t zero_point = 8; | |
105 | ✗ | const uint8_t zero_point_i4 = UInt4::pack_u8(UInt4(zero_point), UInt4(zero_point)); | |
106 | ✗ | const int32_t row_zero_point = zero_point * static_cast<int32_t>(k); | |
107 | |||
108 | − | KAI_ASSUME(dst_format.subblock_width() > 0); | |
109 | ✗ | const auto subblock_width_i32 = static_cast<int32_t>(dst_format.subblock_width()); | |
110 | ✗ | const auto subblock_width_f = static_cast<float>(dst_format.subblock_width()); | |
111 | |||
112 | ✗ | tmp_zero_points = reduce_add(data, src_format, n, k, DataFormat(DataType::I32), 0); | |
113 | ✗ | tmp_zero_points = sub(tmp_zero_points.data(), DataType::I32, n, 1, &row_zero_point, DataType::I32, 1, 1); | |
114 | ✗ | tmp_zero_points = mul(tmp_zero_points.data(), DataType::I32, n, 1, &subblock_width_i32, DataType::I32, 1, 1); | |
115 | ✗ | zero_points = tmp_zero_points.data(); | |
116 | |||
117 | ✗ | tmp_data = add(data, DataType::QSU4, n, k, &zero_point_i4, DataType::QSU4, 1, 1); | |
118 | ✗ | data = tmp_data.data(); | |
119 | |||
120 | ✗ | tmp_scales = div(scales, DataType::FP32, n, 1, &subblock_width_f, DataType::FP32, 1, 1); | |
121 | ✗ | scales = tmp_scales.data(); | |
122 | ✗ | } | |
123 | |||
124 |
1/2✓ Branch 0 taken 125 times.
✗ Branch 1 not taken.
|
125 | return pack(dst_format, data, scales, zero_points, src_format, n, k); |
125 | 125 | } | |
126 | |||
127 | 2387 | Buffer matmul( | |
128 | const void* lhs, [[maybe_unused]] const void* lhs_scales, [[maybe_unused]] const void* lhs_zero_points, | ||
129 | DataType lhs_dt, // | ||
130 | const void* rhs, [[maybe_unused]] const void* rhs_scales, [[maybe_unused]] const void* rhs_zero_points, | ||
131 | DataType rhs_dt, // | ||
132 | const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, // | ||
133 | DataType dst_dt, // | ||
134 | size_t m, size_t n, size_t k, // | ||
135 | bool lhs_transposed, bool rhs_transposed) { | ||
136 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2387 times.
|
2387 | const auto lhs_h = lhs_transposed ? k : m; |
137 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2387 times.
|
2387 | const auto lhs_w = lhs_transposed ? m : k; |
138 | |||
139 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2387 times.
|
2387 | const auto rhs_h = rhs_transposed ? n : k; |
140 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2387 times.
|
2387 | const auto rhs_w = rhs_transposed ? k : n; |
141 | |||
142 | 2387 | Buffer tmp_lhs; | |
143 | 2387 | Buffer tmp_rhs; | |
144 | 2387 | Buffer tmp_dst; | |
145 | 2387 | Buffer tmp_bias; | |
146 | |||
147 |
1/2✓ Branch 0 taken 2387 times.
✗ Branch 1 not taken.
|
2387 | if (lhs_dt != dst_dt) { |
148 | ✗ | tmp_lhs = cast(lhs, lhs_dt, dst_dt, lhs_h, lhs_w); | |
149 | ✗ | lhs = tmp_lhs.data(); | |
150 | ✗ | } | |
151 | |||
152 |
1/2✓ Branch 0 taken 2387 times.
✗ Branch 1 not taken.
|
2387 | if (rhs_dt != dst_dt) { |
153 | ✗ | tmp_rhs = cast(rhs, rhs_dt, dst_dt, rhs_h, rhs_w); | |
154 | ✗ | rhs = tmp_rhs.data(); | |
155 | ✗ | } | |
156 | |||
157 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 1191 times.
✓ Branch 2 taken 1196 times.
|
2387 | switch (dst_dt) { |
158 | case DataType::FP32: | ||
159 |
1/2✓ Branch 0 taken 1191 times.
✗ Branch 1 not taken.
|
1191 | tmp_dst = matmul_any_type<float>(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed); |
160 | 1191 | break; | |
161 | |||
162 | case DataType::FP16: | ||
163 |
1/2✓ Branch 0 taken 1196 times.
✗ Branch 1 not taken.
|
1196 | tmp_dst = matmul_any_type<Float16>(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed); |
164 | 1196 | break; | |
165 | |||
166 | default: | ||
167 | − | KAI_ERROR("Unknown data type!"); | |
168 | ✗ | } | |
169 | |||
170 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 2378 times.
|
2387 | if (bias != nullptr) { |
171 |
1/2✓ Branch 0 taken 2378 times.
✗ Branch 1 not taken.
|
2378 | if (bias_dt != dst_dt) { |
172 | ✗ | tmp_bias = cast(bias, bias_dt, dst_dt, 1, n); | |
173 | ✗ | bias = tmp_bias.data(); | |
174 | ✗ | } | |
175 | |||
176 | − | KAI_ASSUME(!data_type_is_quantized(bias_dt)); | |
177 | − | KAI_ASSUME(bias_scales == nullptr); | |
178 | − | KAI_ASSUME(bias_zero_points == nullptr); | |
179 | |||
180 |
2/4✓ Branch 0 taken 2378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2378 times.
✗ Branch 3 not taken.
|
2378 | tmp_dst = add(tmp_dst.data(), dst_dt, m, n, bias, bias_dt, 1, n); |
181 | 2378 | } | |
182 | |||
183 | 2387 | return tmp_dst; | |
184 | 2387 | } | |
185 | |||
186 | 2244 | Buffer indirect_matmul( | |
187 | const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, | ||
188 | const void* lhs_zero_points, | ||
189 | DataType lhs_dt, // | ||
190 | const void* rhs, const void* rhs_scales, const void* rhs_zero_points, | ||
191 | DataType rhs_dt, // | ||
192 | const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, // | ||
193 | DataType dst_dt, // | ||
194 | size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length) { | ||
195 | // This is inefficient, but allows code-reuse | ||
196 | 2244 | const size_t chunk_bytes = k_chunk_length * round_up_division(data_type_size_in_bits(lhs_dt), 8); | |
197 | 2244 | const size_t n_chunks = m * k_chunk_count; | |
198 | 2244 | Buffer lhs(n_chunks * chunk_bytes); | |
199 | |||
200 | // Copy all chunks to the created matrix | ||
201 |
2/2✓ Branch 0 taken 858660 times.
✓ Branch 1 taken 2244 times.
|
860904 | for (size_t i = 0; i < n_chunks; i += 1) { |
202 | 858660 | const uint8_t* src_pointer = static_cast<const uint8_t*>(lhs_idata[i]); | |
203 |
2/2✓ Branch 0 taken 24156 times.
✓ Branch 1 taken 834504 times.
|
858660 | if (src_pointer != lhs_padding_ptr) { |
204 | 834504 | src_pointer += lhs_offset; | |
205 | 834504 | } | |
206 |
1/2✓ Branch 0 taken 858660 times.
✗ Branch 1 not taken.
|
858660 | memcpy(lhs.data() + i * chunk_bytes, src_pointer, chunk_bytes); |
207 | 858660 | } | |
208 | |||
209 |
1/2✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
|
2244 | return matmul( |
210 |
1/2✓ Branch 0 taken 2244 times.
✗ Branch 1 not taken.
|
2244 | lhs.data(), lhs_scales, lhs_zero_points, lhs_dt, // |
211 | 2244 | rhs, rhs_scales, rhs_zero_points, rhs_dt, // | |
212 | 2244 | bias, bias_scales, bias_zero_points, bias_dt, // | |
213 | 2244 | dst_dt, m, n, k_chunk_count * k_chunk_length, false, false); | |
214 | 2244 | } | |
215 | |||
216 | template < | ||
217 | typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, | ||
218 | typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> | ||
219 | 927 | Buffer indirect_matmul_nt_t_quantized( | |
220 | size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // | ||
221 | const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, | ||
222 | const void* lhs_zero_points, size_t lhs_quant_height, | ||
223 | size_t lhs_quant_width, // | ||
224 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
225 | size_t rhs_quant_width, // | ||
226 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width) { | ||
227 | − | KAI_ASSUME(lhs_quant_width != 0); | |
228 | − | KAI_ASSUME(rhs_quant_width != 0); | |
229 | − | KAI_ASSUME(lhs_quant_height != 0); | |
230 | − | KAI_ASSUME(rhs_quant_height != 0); | |
231 | − | KAI_ASSUME(bias_quant_width != 0); | |
232 | 927 | const auto lhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, lhs_quant_width); | |
233 | 927 | const auto rhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, rhs_quant_width); | |
234 | |||
235 | 927 | Buffer dst(m * n * sizeof(DstData)); | |
236 | |||
237 |
2/2✓ Branch 0 taken 21087 times.
✓ Branch 1 taken 927 times.
|
22014 | for (size_t i_m = 0; i_m < m; ++i_m) { |
238 |
2/2✓ Branch 0 taken 21087 times.
✓ Branch 1 taken 1709826 times.
|
1730913 | for (size_t i_n = 0; i_n < n; ++i_n) { |
239 | 1709826 | DstData acc = 0; | |
240 | |||
241 |
2/2✓ Branch 0 taken 19596618 times.
✓ Branch 1 taken 1709826 times.
|
21306444 | for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; ++i_k_chunk) { |
242 | // Calculate the K chunk pointer. Apply offset if this is not padding | ||
243 | 19596618 | const size_t k_chunk_idx = i_m * k_chunk_count + i_k_chunk; | |
244 | 19596618 | const void* k_chunk_ptr = lhs_ptrs[k_chunk_idx]; | |
245 |
2/2✓ Branch 0 taken 1182174 times.
✓ Branch 1 taken 18414444 times.
|
19596618 | if (k_chunk_ptr != lhs_padding_ptr) { |
246 | 18414444 | k_chunk_ptr = reinterpret_cast<const void*>(reinterpret_cast<uintptr_t>(k_chunk_ptr) + lhs_offset); | |
247 | 18414444 | } | |
248 | |||
249 |
2/2✓ Branch 0 taken 171706017 times.
✓ Branch 1 taken 19596618 times.
|
191302635 | for (size_t i_k_chunk_len = 0; i_k_chunk_len < k_chunk_length; ++i_k_chunk_len) { |
250 | 171706017 | const size_t i = i_k_chunk * k_chunk_length + i_k_chunk_len; | |
251 | |||
252 | 171706017 | const auto lhs_data_index = i_k_chunk_len; | |
253 | 171706017 | const auto lhs_quant_index = (i_m / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; | |
254 |
1/2✓ Branch 0 taken 171706017 times.
✗ Branch 1 not taken.
|
171706017 | const auto lhs_value = read_array<LhsData>(k_chunk_ptr, lhs_data_index); |
255 |
2/4✓ Branch 0 taken 171706017 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 171706017 times.
|
171706017 | const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index) |
256 | : static_cast<LhsScale>(1); | ||
257 |
1/2✓ Branch 0 taken 171706017 times.
✗ Branch 1 not taken.
|
343412034 | const auto lhs_zero_point = lhs_zero_points != nullptr |
258 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 171706017 times.
|
171706017 | ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index) |
259 | : static_cast<LhsZeroPoint>(0); | ||
260 | |||
261 | 171706017 | const auto rhs_data_index = i_n * (k_chunk_count * k_chunk_length) + i; | |
262 | 171706017 | const auto rhs_quant_index = (i_n / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; | |
263 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 171706017 times.
|
171706017 | const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index); |
264 |
2/4✓ Branch 0 taken 171706017 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 171706017 times.
|
171706017 | const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index) |
265 | : static_cast<RhsScale>(1); | ||
266 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 171706017 times.
|
171706017 | const auto rhs_zero_point = rhs_zero_points != nullptr |
267 | ✗ | ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index) | |
268 | : static_cast<RhsZeroPoint>(0); | ||
269 | |||
270 | 515118051 | acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) * | |
271 | 343412034 | static_cast<DstData>(lhs_scale) * | |
272 | 171706017 | (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) * | |
273 | 171706017 | static_cast<DstData>(rhs_scale); | |
274 | 171706017 | } | |
275 | 19596618 | } | |
276 | |||
277 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1709826 times.
|
1709826 | if (bias_data != nullptr) { |
278 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1709826 times.
|
1709826 | const auto bias_value = read_array<BiasData>(bias_data, i_n); |
279 |
1/2✓ Branch 0 taken 1709826 times.
✗ Branch 1 not taken.
|
3419652 | const auto bias_scale = bias_scales != nullptr |
280 |
1/2✓ Branch 0 taken 1709826 times.
✗ Branch 1 not taken.
|
1709826 | ? read_array<BiasScale>(bias_scales, i_n / bias_quant_width) |
281 | : static_cast<BiasScale>(1); | ||
282 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1709826 times.
|
1709826 | const auto bias_zero_point = bias_zero_points != nullptr |
283 | ✗ | ? read_array<BiasZeroPoint>(bias_zero_points, i_n / bias_quant_width) | |
284 | : static_cast<BiasZeroPoint>(0); | ||
285 | |||
286 | 3419652 | acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) * | |
287 | 1709826 | static_cast<DstData>(bias_scale); | |
288 | 1709826 | } | |
289 | |||
290 |
2/4✓ Branch 0 taken 1709826 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1709826 times.
✗ Branch 3 not taken.
|
1709826 | write_array<DstData>(dst.data(), i_m * n + i_n, acc); |
291 | 1709826 | } | |
292 | 21087 | } | |
293 | |||
294 | 927 | return dst; | |
295 | 927 | } | |
296 | |||
297 | template < | ||
298 | typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, | ||
299 | typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> | ||
300 | 7426 | Buffer matmul_nt_t_quantized( | |
301 | size_t m, size_t n, size_t k, // | ||
302 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, // | ||
303 | size_t lhs_quant_height, size_t lhs_quant_width, // | ||
304 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, // | ||
305 | size_t rhs_quant_height, size_t rhs_quant_width, // | ||
306 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, // | ||
307 | size_t bias_quant_width) { | ||
308 | − | KAI_ASSUME(lhs_quant_width != 0); | |
309 | − | KAI_ASSUME(rhs_quant_width != 0); | |
310 | − | KAI_ASSUME(lhs_quant_height != 0); | |
311 | − | KAI_ASSUME(rhs_quant_height != 0); | |
312 | − | KAI_ASSUME(bias_quant_width != 0); | |
313 | |||
314 | 7426 | const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); | |
315 | 7426 | const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); | |
316 | |||
317 | 7426 | Buffer dst(m * n * sizeof(DstData)); | |
318 | |||
319 |
6/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 6164 times.
✓ Branch 3 taken 72384 times.
✓ Branch 4 taken 560 times.
✓ Branch 5 taken 6524 times.
✓ Branch 6 taken 702 times.
✓ Branch 7 taken 16510 times.
|
102844 | for (size_t row = 0; row < m; ++row) { |
320 |
6/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 5416304 times.
✓ Branch 3 taken 72384 times.
✓ Branch 4 taken 6524 times.
✓ Branch 5 taken 516260 times.
✓ Branch 6 taken 16510 times.
✓ Branch 7 taken 2832622 times.
|
8860604 | for (size_t col = 0; col < n; ++col) { |
321 | 8765186 | DstData acc = 0; | |
322 | |||
323 |
6/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 317394496 times.
✓ Branch 3 taken 5416304 times.
✓ Branch 4 taken 24804860 times.
✓ Branch 5 taken 516260 times.
✓ Branch 6 taken 696329088 times.
✓ Branch 7 taken 2832622 times.
|
1047293630 | for (size_t i = 0; i < k; ++i) { |
324 | 1038528444 | const auto lhs_data_index = row * k + i; | |
325 | 1038528444 | const auto lhs_quant_index = (row / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; | |
326 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 317394496 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24804860 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 696329088 times.
✗ Branch 7 not taken.
|
1038528444 | const auto lhs_value = read_array<LhsData>(lhs_data, lhs_data_index); |
327 |
6/16✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 317394496 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 317394496 times.
✓ Branch 8 taken 24804860 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 24804860 times.
✓ Branch 12 taken 696329088 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 696329088 times.
|
1038528444 | const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index) |
328 | : static_cast<LhsScale>(1); | ||
329 |
4/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 64700992 times.
✓ Branch 3 taken 252693504 times.
✓ Branch 4 taken 24804860 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 696329088 times.
✗ Branch 7 not taken.
|
1824363384 | const auto lhs_zero_point = lhs_zero_points != nullptr |
330 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 64700992 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 24804860 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 696329088 times.
|
785834940 | ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index) |
331 | : static_cast<LhsZeroPoint>(0); | ||
332 | |||
333 | 1038528444 | const auto rhs_data_index = col * k + i; | |
334 | 1038528444 | const auto rhs_quant_index = (col / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; | |
335 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 317394496 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 24804860 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 696329088 times.
|
1038528444 | const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index); |
336 |
6/16✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 317394496 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 317394496 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 24804860 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 24804860 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 696329088 times.
✓ Branch 14 taken 696329088 times.
✗ Branch 15 not taken.
|
1038528444 | const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index) |
337 | ✗ | : static_cast<RhsScale>(1); | |
338 |
4/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 252693504 times.
✓ Branch 3 taken 64700992 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 24804860 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 696329088 times.
|
1291221948 | const auto rhs_zero_point = rhs_zero_points != nullptr |
339 |
1/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 252693504 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
252693504 | ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index) |
340 | : static_cast<RhsZeroPoint>(0); | ||
341 | |||
342 | 2798190836 | acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) * | |
343 | 1063333304 | static_cast<DstData>(lhs_scale) * | |
344 |
2/4✓ Branch 0 taken 317394496 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 696329088 times.
✗ Branch 3 not taken.
|
1038528444 | (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) * |
345 |
1/2✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
|
1038528444 | static_cast<DstData>(rhs_scale); |
346 | 1038528444 | } | |
347 | |||
348 |
5/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 2708152 times.
✓ Branch 3 taken 2708152 times.
✓ Branch 4 taken 258130 times.
✓ Branch 5 taken 258130 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2832622 times.
|
8765186 | if (bias_data != nullptr) { |
349 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2708152 times.
✓ Branch 4 taken 258130 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2832622 times.
|
5798904 | const auto bias_value = read_array<BiasData>(bias_data, col); |
350 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2708152 times.
✓ Branch 4 taken 258130 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2832622 times.
|
5798904 | const auto bias_scale = bias_scales != nullptr |
351 | ✗ | ? read_array<BiasScale>(bias_scales, col / bias_quant_width) | |
352 | : static_cast<BiasScale>(1); | ||
353 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2708152 times.
✓ Branch 4 taken 258130 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2832622 times.
|
5798904 | const auto bias_zero_point = bias_zero_points != nullptr |
354 | ✗ | ? read_array<BiasZeroPoint>(bias_zero_points, col / bias_quant_width) | |
355 | : static_cast<BiasZeroPoint>(0); | ||
356 | |||
357 | 11597808 | acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) * | |
358 | 5798904 | static_cast<DstData>(bias_scale); | |
359 | 5798904 | } | |
360 | |||
361 |
6/16✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 5416304 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 5416304 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 516260 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 516260 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2832622 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2832622 times.
✗ Branch 15 not taken.
|
8765186 | write_array<DstData>(dst.data(), row * n + col, acc); |
362 | 8765186 | } | |
363 | 95418 | } | |
364 | |||
365 | 7426 | return dst; | |
366 | 7426 | } | |
367 | |||
368 | template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>( | ||
369 | size_t m, size_t n, size_t k, // | ||
370 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
371 | size_t lhs_quant_width, // | ||
372 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
373 | size_t rhs_quant_width, // | ||
374 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
375 | |||
376 | template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>( | ||
377 | size_t m, size_t n, size_t k, // | ||
378 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
379 | size_t lhs_quant_width, // | ||
380 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
381 | size_t rhs_quant_width, // | ||
382 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
383 | |||
384 | template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, float, float, int32_t, float>( | ||
385 | size_t m, size_t n, size_t k, // | ||
386 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
387 | size_t lhs_quant_width, // | ||
388 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
389 | size_t rhs_quant_width, // | ||
390 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
391 | |||
392 | template Buffer | ||
393 | matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( | ||
394 | size_t m, size_t n, size_t k, // | ||
395 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
396 | size_t lhs_quant_width, // | ||
397 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
398 | size_t rhs_quant_width, // | ||
399 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
400 | |||
401 | template < | ||
402 | typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, | ||
403 | typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> | ||
404 | 776 | Buffer matmul_nt_nt_quantized( | |
405 | size_t m, size_t n, size_t k, // | ||
406 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, // | ||
407 | size_t lhs_quant_height, size_t lhs_quant_width, // | ||
408 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, // | ||
409 | size_t rhs_quant_height, size_t rhs_quant_width, // | ||
410 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, // | ||
411 | size_t bias_quant_width) { | ||
412 | − | KAI_ASSUME(lhs_quant_width != 0); | |
413 | − | KAI_ASSUME(rhs_quant_width != 0); | |
414 | − | KAI_ASSUME(lhs_quant_height != 0); | |
415 | − | KAI_ASSUME(rhs_quant_height != 0); | |
416 | − | KAI_ASSUME(bias_quant_width != 0); | |
417 | |||
418 | 776 | const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); | |
419 | 776 | const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); | |
420 | |||
421 | 776 | Buffer dst(m * n * sizeof(DstData)); | |
422 | |||
423 |
4/4✓ Branch 0 taken 702 times.
✓ Branch 1 taken 16510 times.
✓ Branch 2 taken 74 times.
✓ Branch 3 taken 8279 times.
|
25565 | for (size_t row = 0; row < m; ++row) { |
424 |
4/4✓ Branch 0 taken 16510 times.
✓ Branch 1 taken 2832622 times.
✓ Branch 2 taken 524771 times.
✓ Branch 3 taken 8279 times.
|
3382182 | for (size_t col = 0; col < n; ++col) { |
425 | 3357393 | DstData acc = 0; | |
426 | |||
427 |
4/4✓ Branch 0 taken 696329088 times.
✓ Branch 1 taken 2832622 times.
✓ Branch 2 taken 143126255 times.
✓ Branch 3 taken 524771 times.
|
842812736 | for (size_t i = 0; i < k; ++i) { |
428 | 839455343 | const auto lhs_data_index = row * k + i; | |
429 | 839455343 | const auto lhs_quant_index = (row / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; | |
430 |
2/4✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 143126255 times.
|
839455343 | const auto lhs_value = read_array<LhsData>(lhs_data, lhs_data_index); |
431 |
3/8✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 696329088 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 143126255 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
839455343 | const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index) |
432 | : static_cast<LhsScale>(1); | ||
433 |
2/4✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 143126255 times.
|
1535784431 | const auto lhs_zero_point = lhs_zero_points != nullptr |
434 |
1/4✗ Branch 0 not taken.
✓ Branch 1 taken 696329088 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
696329088 | ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index) |
435 | : static_cast<LhsZeroPoint>(0); | ||
436 | |||
437 | 839455343 | const auto rhs_data_index = col + i * n; | |
438 | 839455343 | const auto rhs_quant_index = (col / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; | |
439 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 696329088 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 143126255 times.
|
839455343 | const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index); |
440 |
3/8✗ Branch 0 not taken.
✓ Branch 1 taken 696329088 times.
✓ Branch 2 taken 696329088 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 143126255 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
839455343 | const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index) |
441 | ✗ | : static_cast<RhsScale>(1); | |
442 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 696329088 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 143126255 times.
|
839455343 | const auto rhs_zero_point = rhs_zero_points != nullptr |
443 | ✗ | ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index) | |
444 | : static_cast<RhsZeroPoint>(0); | ||
445 | |||
446 |
1/2✓ Branch 0 taken 143126255 times.
✗ Branch 1 not taken.
|
2232113519 | acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) * |
447 | 839455343 | static_cast<DstData>(lhs_scale) * | |
448 |
2/4✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 143126255 times.
✗ Branch 3 not taken.
|
839455343 | (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) * |
449 |
1/2✓ Branch 0 taken 696329088 times.
✗ Branch 1 not taken.
|
839455343 | static_cast<DstData>(rhs_scale); |
450 | 839455343 | } | |
451 | |||
452 |
3/4✗ Branch 0 not taken.
✓ Branch 1 taken 2832622 times.
✓ Branch 2 taken 210653 times.
✓ Branch 3 taken 314118 times.
|
3357393 | if (bias_data != nullptr) { |
453 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2832622 times.
✓ Branch 2 taken 314118 times.
✗ Branch 3 not taken.
|
3146740 | const auto bias_value = read_array<BiasData>(bias_data, col); |
454 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2832622 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 314118 times.
|
3146740 | const auto bias_scale = bias_scales != nullptr |
455 | ✗ | ? read_array<BiasScale>(bias_scales, col / bias_quant_width) | |
456 | : static_cast<BiasScale>(1); | ||
457 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2832622 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 314118 times.
|
3146740 | const auto bias_zero_point = bias_zero_points != nullptr |
458 | ✗ | ? read_array<BiasZeroPoint>(bias_zero_points, col / bias_quant_width) | |
459 | : static_cast<BiasZeroPoint>(0); | ||
460 | |||
461 | 6293480 | acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) * | |
462 | 3146740 | static_cast<DstData>(bias_scale); | |
463 | 3146740 | } | |
464 | |||
465 |
4/8✓ Branch 0 taken 2832622 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2832622 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 524771 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 524771 times.
✗ Branch 7 not taken.
|
3357393 | write_array<DstData>(dst.data(), row * n + col, acc); |
466 | 3357393 | } | |
467 | 24789 | } | |
468 | |||
469 | 776 | return dst; | |
470 | 776 | } | |
471 | |||
472 | template Buffer | ||
473 | matmul_nt_nt_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( | ||
474 | size_t m, size_t n, size_t k, // | ||
475 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
476 | size_t lhs_quant_width, // | ||
477 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
478 | size_t rhs_quant_width, // | ||
479 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
480 | |||
481 | template Buffer matmul_nt_nt_quantized<BFloat16<>, float, float, BFloat16<>, float, float, float, float, float, float>( | ||
482 | size_t m, size_t n, size_t k, // | ||
483 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
484 | size_t lhs_quant_width, // | ||
485 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
486 | size_t rhs_quant_width, // | ||
487 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
488 | |||
489 | template Buffer | ||
490 | indirect_matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>( | ||
491 | size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // | ||
492 | const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, | ||
493 | const void* lhs_zero_points, size_t lhs_quant_height, | ||
494 | size_t lhs_quant_width, // | ||
495 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
496 | size_t rhs_quant_width, // | ||
497 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
498 | |||
499 | template < | ||
500 | typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, | ||
501 | typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> | ||
502 | 1620 | Buffer matmul_clamp_nt_t( | |
503 | size_t m, size_t n, size_t k, // | ||
504 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
505 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
506 | const void* biases, // | ||
507 | DstData min_value, DstData max_value) { | ||
508 | − | KAI_ASSUME(lhs_quant_width != 0); | |
509 | − | KAI_ASSUME(rhs_quant_width != 0); | |
510 | 1620 | const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); | |
511 | 1620 | const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); | |
512 | |||
513 | 1620 | Buffer dst(m * n * sizeof(DstData)); | |
514 | |||
515 | 1620 | const auto* lhs_scales_ptr = reinterpret_cast<const LhsScale*>(lhs_scales); | |
516 | 1620 | const auto* rhs_scales_ptr = reinterpret_cast<const RhsScale*>(rhs_scales); | |
517 | 1620 | const auto* lhs_zero_points_ptr = reinterpret_cast<const LhsZeroPoint*>(lhs_zero_points); | |
518 | 1620 | const auto* rhs_zero_points_ptr = reinterpret_cast<const RhsZeroPoint*>(rhs_zero_points); | |
519 | 1620 | const auto* biases_ptr = reinterpret_cast<const Bias*>(biases); | |
520 |
3/8✓ Branch 0 taken 880 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 124 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 616 times.
✗ Branch 7 not taken.
|
1620 | auto* dst_ptr = reinterpret_cast<DstData*>(dst.data()); |
521 | |||
522 |
6/8✓ Branch 0 taken 880 times.
✓ Branch 1 taken 23056 times.
✓ Branch 2 taken 2984 times.
✓ Branch 3 taken 124 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 616 times.
✓ Branch 7 taken 6104 times.
|
33764 | for (size_t y = 0; y < m; ++y) { |
523 |
6/8✓ Branch 0 taken 23056 times.
✓ Branch 1 taken 3154800 times.
✓ Branch 2 taken 2984 times.
✓ Branch 3 taken 218428 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 6104 times.
✓ Branch 7 taken 405720 times.
|
3811092 | for (size_t x = 0; x < n; ++x) { |
524 | 3778948 | DstData acc = 0; | |
525 | |||
526 |
6/8✓ Branch 0 taken 158710112 times.
✓ Branch 1 taken 3154800 times.
✓ Branch 2 taken 13670656 times.
✓ Branch 3 taken 218428 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 13742344 times.
✓ Branch 7 taken 405720 times.
|
189902060 | for (size_t i = 0; i < k; ++i) { |
527 |
3/8✓ Branch 0 taken 158710112 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13670656 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 13742344 times.
✗ Branch 7 not taken.
|
186123112 | const auto lhs_value = read_array<LhsData>(lhs_data, y * k + i); |
528 | 186123112 | const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]; | |
529 |
3/8✓ Branch 0 taken 158710112 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 13670656 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 13742344 times.
✗ Branch 7 not taken.
|
186123112 | const auto lhs_zero_point = lhs_zero_points_ptr != nullptr |
530 | 172452456 | ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width] | |
531 | : 0; | ||
532 | |||
533 |
3/8✗ Branch 0 not taken.
✓ Branch 1 taken 158710112 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 13670656 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 13742344 times.
✗ Branch 7 not taken.
|
186123112 | const auto rhs_value = read_array<RhsData>(rhs_data, x * k + i); |
534 | 186123112 | const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width]; | |
535 |
3/8✗ Branch 0 not taken.
✓ Branch 1 taken 158710112 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 13670656 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 13742344 times.
|
186123112 | const auto rhs_zero_point = rhs_zero_points_ptr != nullptr |
536 | ✗ | ? rhs_zero_points_ptr[y * rhs_num_quant_per_row + i / rhs_quant_width] | |
537 | : 0; | ||
538 | |||
539 | 186123112 | acc += static_cast<DstData>( | |
540 | 199865456 | (static_cast<IntAcc>(lhs_value) - static_cast<IntAcc>(lhs_zero_point)) * | |
541 |
2/6✗ Branch 0 not taken.
✓ Branch 1 taken 158710112 times.
✓ Branch 2 taken 13670656 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
199865456 | (static_cast<IntAcc>(rhs_value) - static_cast<IntAcc>(rhs_zero_point))) * |
542 |
2/8✓ Branch 0 taken 13670656 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13670656 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
358575568 | static_cast<DstData>(lhs_scale) * static_cast<DstData>(rhs_scale); |
543 | 186123112 | } | |
544 | |||
545 |
3/8✗ Branch 0 not taken.
✓ Branch 1 taken 3154800 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 218428 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 405720 times.
✗ Branch 7 not taken.
|
3778948 | if (biases_ptr != nullptr) { |
546 | 3560520 | acc += static_cast<DstData>(biases_ptr[x]); | |
547 | 3560520 | } | |
548 | |||
549 |
3/8✓ Branch 0 taken 3154800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 218428 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 405720 times.
✗ Branch 7 not taken.
|
3778948 | acc = std::clamp(acc, min_value, max_value); |
550 | 3778948 | dst_ptr[y * n + x] = acc; | |
551 | 3778948 | } | |
552 | 32144 | } | |
553 | |||
554 | 1620 | return dst; | |
555 | 1620 | } | |
556 | |||
557 | template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( | ||
558 | size_t m, size_t n, size_t k, // | ||
559 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
560 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
561 | const void* biases, // | ||
562 | float min_value, float max_value); | ||
563 | |||
564 | template Buffer matmul_clamp_nt_t<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>( | ||
565 | size_t m, size_t n, size_t k, // | ||
566 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
567 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
568 | const void* biases, // | ||
569 | float min_value, float max_value); | ||
570 | |||
571 | template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, int32_t, float>( | ||
572 | size_t m, size_t n, size_t k, // | ||
573 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
574 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
575 | const void* biases, // | ||
576 | float min_value, float max_value); | ||
577 | |||
578 | template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>( | ||
579 | size_t m, size_t n, size_t k, // | ||
580 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
581 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
582 | const void* biases, // | ||
583 | float min_value, float max_value); | ||
584 | |||
585 | template < | ||
586 | typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, | ||
587 | typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> | ||
588 | 1728 | Buffer matmul_clamp_nt_nt( | |
589 | size_t m, size_t n, size_t k, // | ||
590 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
591 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
592 | const void* biases, // | ||
593 | DstData min_value, DstData max_value) { | ||
594 | − | KAI_ASSUME(lhs_quant_width != 0); | |
595 | − | KAI_ASSUME(rhs_quant_width != 0); | |
596 | 1728 | const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); | |
597 | 1728 | const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); | |
598 | |||
599 | 1728 | Buffer dst(m * n * sizeof(DstData)); | |
600 | |||
601 | 1728 | const auto* lhs_scales_ptr = reinterpret_cast<const LhsScale*>(lhs_scales); | |
602 | 1728 | const auto* rhs_scales_ptr = reinterpret_cast<const RhsScale*>(rhs_scales); | |
603 | 1728 | const auto* lhs_zero_points_ptr = reinterpret_cast<const LhsZeroPoint*>(lhs_zero_points); | |
604 | 1728 | const auto* rhs_zero_points_ptr = reinterpret_cast<const RhsZeroPoint*>(rhs_zero_points); | |
605 | 1728 | const auto* biases_ptr = reinterpret_cast<const Bias*>(biases); | |
606 |
2/8✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1112 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
1728 | auto* dst_ptr = reinterpret_cast<DstData*>(dst.data()); |
607 | |||
608 |
4/8✓ Branch 0 taken 616 times.
✓ Branch 1 taken 6104 times.
✓ Branch 2 taken 1112 times.
✓ Branch 3 taken 27768 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
35600 | for (size_t y = 0; y < m; ++y) { |
609 |
4/8✓ Branch 0 taken 6104 times.
✓ Branch 1 taken 405720 times.
✓ Branch 2 taken 27768 times.
✓ Branch 3 taken 3315612 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
3755204 | for (size_t x = 0; x < n; ++x) { |
610 | 3721332 | DstData acc = 0; | |
611 | |||
612 |
4/8✓ Branch 0 taken 13742344 times.
✓ Branch 1 taken 405720 times.
✓ Branch 2 taken 169751908 times.
✓ Branch 3 taken 3315612 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
187215584 | for (size_t i = 0; i < k; ++i) { |
613 |
2/8✓ Branch 0 taken 13742344 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 169751908 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
183494252 | const auto lhs_value = read_array<LhsData>(lhs_data, y * k + i); |
614 | 183494252 | const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]; | |
615 |
2/8✓ Branch 0 taken 13742344 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 169751908 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
183494252 | const auto lhs_zero_point = lhs_zero_points_ptr != nullptr |
616 | 183494252 | ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width] | |
617 | : 0; | ||
618 | |||
619 |
2/8✓ Branch 0 taken 13742344 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 169751908 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
183494252 | const auto rhs_value = read_array<RhsData>(rhs_data, x + i * n); |
620 | 183494252 | const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width]; | |
621 |
2/8✗ Branch 0 not taken.
✓ Branch 1 taken 13742344 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 169751908 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
183494252 | const auto rhs_zero_point = rhs_zero_points_ptr != nullptr |
622 | ✗ | ? rhs_zero_points_ptr[y * rhs_num_quant_per_row + i / rhs_quant_width] | |
623 | : 0; | ||
624 | |||
625 | 183494252 | acc += static_cast<DstData>( | |
626 | 197236596 | (static_cast<IntAcc>(lhs_value) - static_cast<IntAcc>(lhs_zero_point)) * | |
627 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 169751908 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
197236596 | (static_cast<IntAcc>(rhs_value) - static_cast<IntAcc>(rhs_zero_point))) * |
628 |
0/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
366988504 | static_cast<DstData>(lhs_scale) * static_cast<DstData>(rhs_scale); |
629 | 183494252 | } | |
630 | |||
631 |
3/8✓ Branch 0 taken 405720 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 367206 times.
✓ Branch 3 taken 2948406 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
3721332 | if (biases_ptr != nullptr) { |
632 | 3354126 | acc += static_cast<DstData>(biases_ptr[x]); | |
633 | 3354126 | } | |
634 | |||
635 |
2/8✓ Branch 0 taken 405720 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3315612 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
3721332 | acc = std::clamp(acc, min_value, max_value); |
636 | 3721332 | dst_ptr[y * n + x] = acc; | |
637 | 3721332 | } | |
638 | 33872 | } | |
639 | |||
640 | 1728 | return dst; | |
641 | 1728 | } | |
642 | |||
643 | template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>( | ||
644 | size_t m, size_t n, size_t k, // | ||
645 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
646 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
647 | const void* biases, // | ||
648 | float min_value, float max_value); | ||
649 | template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( | ||
650 | size_t m, size_t n, size_t k, // | ||
651 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
652 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
653 | const void* biases, // | ||
654 | float min_value, float max_value); | ||
655 | |||
656 | template Buffer matmul_clamp_nt_nt<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>( | ||
657 | size_t m, size_t n, size_t k, // | ||
658 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
659 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
660 | const void* biases, // | ||
661 | float min_value, float max_value); | ||
662 | |||
663 | template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, int32_t, float>( | ||
664 | size_t m, size_t n, size_t k, // | ||
665 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
666 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
667 | const void* biases, // | ||
668 | float min_value, float max_value); | ||
669 | |||
670 | } // namespace kai::test | ||
671 |