test/reference/matmul.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/reference/matmul.hpp" | ||
| 8 | |||
| 9 | #include <algorithm> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <type_traits> | ||
| 13 | |||
| 14 | #include "kai/kai_common.h" | ||
| 15 | #include "test/common/buffer.hpp" | ||
| 16 | #include "test/common/data_format.hpp" | ||
| 17 | #include "test/common/data_type.hpp" | ||
| 18 | #include "test/common/float16.hpp" | ||
| 19 | #include "test/common/int4.hpp" | ||
| 20 | #include "test/common/memory.hpp" | ||
| 21 | #include "test/common/round.hpp" | ||
| 22 | #include "test/reference/binary_elementwise.hpp" | ||
| 23 | #include "test/reference/cast.hpp" | ||
| 24 | #include "test/reference/pack.hpp" | ||
| 25 | #include "test/reference/reduce.hpp" | ||
| 26 | #include "test/reference/transpose.hpp" | ||
| 27 | |||
| 28 | namespace kai::test { | ||
| 29 | |||
| 30 | namespace { | ||
| 31 | |||
| 32 | /// Matrix multiplication. | ||
| 33 | /// | ||
| 34 | /// @tparam In Element type of both operands and of the result (products are accumulated in `Acc`). | ||
| 35 | /// | ||
| 36 | /// @param[in] lhs LHS operand data buffer. | ||
| 37 | /// @param[in] rhs RHS operand data buffer. | ||
| 38 | /// @param[in] m Output height. | ||
| 39 | /// @param[in] n Output width. | ||
| 40 | /// @param[in] k Non-transposed LHS width and non-transposed RHS height. | ||
| 41 | /// @param[in] lhs_transposed `true` if LHS operand is transposed. | ||
| 42 | /// @param[in] rhs_transposed `true` if RHS operand is transposed. | ||
| 43 | /// | ||
| 44 | /// @return The result data buffer. | ||
| 45 | template <typename In, typename Acc> | ||
| 46 | 1655 | Buffer matmul_any_type( | |
| 47 | const void* lhs, const void* rhs, // | ||
| 48 | size_t m, size_t n, size_t k, // | ||
| 49 | bool lhs_transposed, bool rhs_transposed) { | ||
| 50 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 811 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 844 times.
|
1655 | const auto lhs_m_stride = lhs_transposed ? 1 : k; |
| 51 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 811 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 844 times.
|
1655 | const auto lhs_k_stride = lhs_transposed ? m : 1; |
| 52 | |||
| 53 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 811 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 844 times.
|
1655 | const auto rhs_n_stride = rhs_transposed ? k : 1; |
| 54 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 811 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 844 times.
|
1655 | const auto rhs_k_stride = rhs_transposed ? 1 : n; |
| 55 | |||
| 56 | 1655 | Buffer dst(m * n * size_in_bits<In> / 8); | |
| 57 | − | KAI_ASSUME_ALWAYS(n * size_in_bits<In> % 8 == 0); | |
| 58 | |||
| 59 |
4/4✓ Branch 0 taken 20034 times.
✓ Branch 1 taken 811 times.
✓ Branch 2 taken 22833 times.
✓ Branch 3 taken 844 times.
|
44522 | for (size_t im = 0; im < m; ++im) { |
| 60 |
4/4✓ Branch 0 taken 20034 times.
✓ Branch 1 taken 1524926 times.
✓ Branch 2 taken 22833 times.
✓ Branch 3 taken 1626893 times.
|
3194686 | for (size_t in = 0; in < n; ++in) { |
| 61 | 3151819 | Acc acc = Acc(0); | |
| 62 | |||
| 63 |
4/4✓ Branch 0 taken 221798379 times.
✓ Branch 1 taken 1524926 times.
✓ Branch 2 taken 229649811 times.
✓ Branch 3 taken 1626893 times.
|
454600009 | for (size_t ik = 0; ik < k; ++ik) { |
| 64 |
2/4✓ Branch 0 taken 221798379 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 229649811 times.
|
451448190 | const auto lhs_value = read_array<In>(lhs, im * lhs_m_stride + ik * lhs_k_stride); |
| 65 |
2/4✓ Branch 0 taken 221798379 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 229649811 times.
✗ Branch 3 not taken.
|
451448190 | const auto rhs_value = read_array<In>(rhs, in * rhs_n_stride + ik * rhs_k_stride); |
| 66 |
2/4✓ Branch 0 taken 229649811 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 229649811 times.
✗ Branch 3 not taken.
|
451448190 | acc += static_cast<Acc>(lhs_value) * static_cast<Acc>(rhs_value); |
| 67 | 451448190 | } | |
| 68 | |||
| 69 |
3/6✓ Branch 0 taken 1524926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1626893 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1626893 times.
✗ Branch 5 not taken.
|
3151819 | write_array<In>(dst.data(), im * n + in, static_cast<In>(acc)); |
| 70 | 3151819 | } | |
| 71 | 42867 | } | |
| 72 | |||
| 73 | 1655 | return dst; | |
| 74 | 1655 | } | |
| 75 | |||
| 76 | } // namespace | ||
| 77 | |||
| 78 | 663 | Buffer matmul_pack_rhs( | |
| 79 | const void* data, const void* scales, const void* zero_points, const DataFormat& src_format, | ||
| 80 | const DataFormat& dst_format, size_t n, size_t k, bool transposing) { | ||
| 81 | 663 | const auto src_dt = src_format.data_type(); | |
| 82 | 663 | const auto src_pf = src_format.pack_format(); | |
| 83 | |||
| 84 | 663 | const auto dst_dt = dst_format.data_type(); | |
| 85 | 663 | const auto dst_pf = dst_format.pack_format(); | |
| 86 | |||
| 87 | 663 | Buffer tmp_data; | |
| 88 | 663 | Buffer tmp_scales; | |
| 89 | 663 | Buffer tmp_zero_points; | |
| 90 | |||
| 91 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 663 times.
|
663 | if (transposing) { |
| 92 |
1/2✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
|
663 | tmp_data = transpose(data, src_dt, k, n); |
| 93 |
1/2✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
|
663 | data = tmp_data.data(); |
| 94 | 663 | } | |
| 95 | |||
| 96 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 663 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
663 | if (src_dt == DataType::QSU4 && src_pf == DataFormat::PackFormat::NONE && // |
| 97 | ✗ | dst_dt == DataType::QSI4 && dst_pf == DataFormat::PackFormat::QUANTIZE_PER_ROW) { | |
| 98 | // For this specific RHS format conversion: | ||
| 99 | // | ||
| 100 | // * 8 is added to every 4-bit data value. | ||
| 101 | // * Scale is divided by 16. | ||
| 102 | // * Zero point is the accumulation of all values in the same row. | ||
| 103 | |||
| 104 | − | KAI_ASSUME_ALWAYS(zero_points == nullptr); | |
| 105 | ✗ | const int32_t zero_point = 8; | |
| 106 | ✗ | const uint8_t zero_point_i4 = UInt4::pack_u8(UInt4(zero_point), UInt4(zero_point)); | |
| 107 | ✗ | const int32_t row_zero_point = zero_point * static_cast<int32_t>(k); | |
| 108 | |||
| 109 | − | KAI_ASSUME_ALWAYS(dst_format.subblock_width() > 0); | |
| 110 | ✗ | const auto subblock_width_i32 = static_cast<int32_t>(dst_format.subblock_width()); | |
| 111 | ✗ | const auto subblock_width_f = static_cast<float>(dst_format.subblock_width()); | |
| 112 | |||
| 113 | ✗ | tmp_zero_points = reduce_add(data, src_format, n, k, DataFormat(DataType::I32), 0); | |
| 114 | ✗ | tmp_zero_points = sub(tmp_zero_points.data(), DataType::I32, n, 1, &row_zero_point, DataType::I32, 1, 1); | |
| 115 | ✗ | tmp_zero_points = mul(tmp_zero_points.data(), DataType::I32, n, 1, &subblock_width_i32, DataType::I32, 1, 1); | |
| 116 | ✗ | zero_points = tmp_zero_points.data(); | |
| 117 | |||
| 118 | ✗ | tmp_data = add(data, DataType::QSU4, n, k, &zero_point_i4, DataType::QSU4, 1, 1); | |
| 119 | ✗ | data = tmp_data.data(); | |
| 120 | |||
| 121 | ✗ | tmp_scales = div(scales, DataType::FP32, n, 1, &subblock_width_f, DataType::FP32, 1, 1); | |
| 122 | ✗ | scales = tmp_scales.data(); | |
| 123 | ✗ | } | |
| 124 | |||
| 125 |
1/2✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
|
663 | return pack(dst_format, data, scales, zero_points, src_format, n, k); |
| 126 | 663 | } | |
| 127 | |||
| 128 | 1655 | Buffer matmul( | |
| 129 | const void* lhs, [[maybe_unused]] const void* lhs_scales, [[maybe_unused]] const void* lhs_zero_points, | ||
| 130 | DataType lhs_dt, // | ||
| 131 | const void* rhs, [[maybe_unused]] const void* rhs_scales, [[maybe_unused]] const void* rhs_zero_points, | ||
| 132 | DataType rhs_dt, // | ||
| 133 | const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, // | ||
| 134 | DataType dst_dt, // | ||
| 135 | size_t m, size_t n, size_t k, // | ||
| 136 | bool lhs_transposed, bool rhs_transposed) { | ||
| 137 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1655 times.
|
1655 | const auto lhs_h = lhs_transposed ? k : m; |
| 138 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1655 times.
|
1655 | const auto lhs_w = lhs_transposed ? m : k; |
| 139 | |||
| 140 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1655 times.
|
1655 | const auto rhs_h = rhs_transposed ? n : k; |
| 141 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1655 times.
|
1655 | const auto rhs_w = rhs_transposed ? k : n; |
| 142 | |||
| 143 | 1655 | Buffer tmp_lhs; | |
| 144 | 1655 | Buffer tmp_rhs; | |
| 145 | 1655 | Buffer tmp_dst; | |
| 146 | 1655 | Buffer tmp_bias; | |
| 147 | |||
| 148 |
1/2✓ Branch 0 taken 1655 times.
✗ Branch 1 not taken.
|
1655 | if (lhs_dt != dst_dt) { |
| 149 | ✗ | tmp_lhs = cast(lhs, lhs_dt, dst_dt, lhs_h, lhs_w); | |
| 150 | ✗ | lhs = tmp_lhs.data(); | |
| 151 | ✗ | } | |
| 152 | |||
| 153 |
1/2✓ Branch 0 taken 1655 times.
✗ Branch 1 not taken.
|
1655 | if (rhs_dt != dst_dt) { |
| 154 | ✗ | tmp_rhs = cast(rhs, rhs_dt, dst_dt, rhs_h, rhs_w); | |
| 155 | ✗ | rhs = tmp_rhs.data(); | |
| 156 | ✗ | } | |
| 157 | |||
| 158 |
2/3✗ Branch 0 not taken.
✓ Branch 1 taken 811 times.
✓ Branch 2 taken 844 times.
|
1655 | switch (dst_dt) { |
| 159 | case DataType::FP32: | ||
| 160 |
1/2✓ Branch 0 taken 811 times.
✗ Branch 1 not taken.
|
811 | tmp_dst = matmul_any_type<float, float>(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed); |
| 161 | 811 | break; | |
| 162 | |||
| 163 | case DataType::FP16: | ||
| 164 |
1/2✓ Branch 0 taken 844 times.
✗ Branch 1 not taken.
|
844 | tmp_dst = matmul_any_type<Float16, float>(lhs, rhs, m, n, k, lhs_transposed, rhs_transposed); |
| 165 | 844 | break; | |
| 166 | |||
| 167 | default: | ||
| 168 | − | KAI_ERROR("Unknown data type!"); | |
| 169 | ✗ | } | |
| 170 | |||
| 171 |
2/2✓ Branch 0 taken 54 times.
✓ Branch 1 taken 1601 times.
|
1655 | if (bias != nullptr) { |
| 172 | − | KAI_ASSUME_ALWAYS(!data_type_is_quantized(bias_dt)); | |
| 173 | − | KAI_ASSUME_ALWAYS(bias_scales == nullptr); | |
| 174 | − | KAI_ASSUME_ALWAYS(bias_zero_points == nullptr); | |
| 175 | |||
| 176 | // Add bias in f32 to reduce precision loss. | ||
| 177 |
2/2✓ Branch 0 taken 811 times.
✓ Branch 1 taken 790 times.
|
1601 | if (dst_dt != DataType::FP32) { |
| 178 |
2/4✓ Branch 0 taken 790 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 790 times.
✗ Branch 3 not taken.
|
790 | tmp_dst = cast(tmp_dst.data(), dst_dt, DataType::FP32, m, n); |
| 179 | 790 | } | |
| 180 |
2/2✓ Branch 0 taken 811 times.
✓ Branch 1 taken 790 times.
|
1601 | if (bias_dt != DataType::FP32) { |
| 181 |
1/2✓ Branch 0 taken 790 times.
✗ Branch 1 not taken.
|
790 | tmp_bias = cast(bias, bias_dt, DataType::FP32, 1, n); |
| 182 |
1/2✓ Branch 0 taken 790 times.
✗ Branch 1 not taken.
|
790 | bias = tmp_bias.data(); |
| 183 | 790 | } | |
| 184 |
2/4✓ Branch 0 taken 1601 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1601 times.
✗ Branch 3 not taken.
|
1601 | tmp_dst = add(tmp_dst.data(), DataType::FP32, m, n, bias, DataType::FP32, 1, n); |
| 185 |
2/2✓ Branch 0 taken 811 times.
✓ Branch 1 taken 790 times.
|
1601 | if (dst_dt != DataType::FP32) { |
| 186 |
2/4✓ Branch 0 taken 790 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 790 times.
✗ Branch 3 not taken.
|
790 | tmp_dst = cast(tmp_dst.data(), DataType::FP32, dst_dt, m, n); |
| 187 | 790 | } | |
| 188 | 1601 | } | |
| 189 | |||
| 190 | 1655 | return tmp_dst; | |
| 191 | 1655 | } | |
| 192 | |||
| 193 | 884 | Buffer indirect_matmul( | |
| 194 | const void* const* lhs_idata, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, | ||
| 195 | const void* lhs_zero_points, | ||
| 196 | DataType lhs_dt, // | ||
| 197 | const void* rhs, const void* rhs_scales, const void* rhs_zero_points, | ||
| 198 | DataType rhs_dt, // | ||
| 199 | const void* bias, const void* bias_scales, const void* bias_zero_points, DataType bias_dt, // | ||
| 200 | DataType dst_dt, // | ||
| 201 | size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length) { | ||
| 202 | // Gather the indirect LHS chunks into one dense buffer: inefficient, but it allows reuse of matmul(). | ||
| 203 | 884 | const size_t chunk_bytes = k_chunk_length * round_up_division(data_type_size_in_bits(lhs_dt), 8); | |
| 204 | 884 | const size_t n_chunks = m * k_chunk_count; | |
| 205 | 884 | Buffer lhs(n_chunks * chunk_bytes); | |
| 206 | |||
| 207 | 884 | const uintptr_t lhs_padding_ptr_uint = reinterpret_cast<uintptr_t>(lhs_padding_ptr); | |
| 208 | |||
| 209 | // Copy all chunks to the created matrix | ||
| 210 |
2/2✓ Branch 0 taken 333242 times.
✓ Branch 1 taken 884 times.
|
334126 | for (size_t i = 0; i < n_chunks; i += 1) { |
| 211 | 333242 | uintptr_t src_pointer = reinterpret_cast<uintptr_t>(lhs_idata[i]); | |
| 212 |
2/2✓ Branch 0 taken 9516 times.
✓ Branch 1 taken 323726 times.
|
333242 | if (src_pointer != lhs_padding_ptr_uint) { |
| 213 | 323726 | src_pointer += lhs_offset; | |
| 214 | 323726 | } | |
| 215 | 333242 | memcpy( | |
| 216 |
1/2✓ Branch 0 taken 333242 times.
✗ Branch 1 not taken.
|
333242 | lhs.data() + i * chunk_bytes, reinterpret_cast<const void*>(src_pointer), |
| 217 | 333242 | chunk_bytes); // NOLINT(performance-no-int-to-ptr) | |
| 218 | 333242 | } | |
| 219 | |||
| 220 |
1/2✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
|
884 | return matmul( |
| 221 |
1/2✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
|
884 | lhs.data(), lhs_scales, lhs_zero_points, lhs_dt, // |
| 222 | 884 | rhs, rhs_scales, rhs_zero_points, rhs_dt, // | |
| 223 | 884 | bias, bias_scales, bias_zero_points, bias_dt, // | |
| 224 | 884 | dst_dt, m, n, k_chunk_count * k_chunk_length, false, false); | |
| 225 | 884 | } | |
| 226 | |||
| 227 | template < | ||
| 228 | typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, | ||
| 229 | typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> | ||
| 230 | 521 | Buffer indirect_matmul_nt_t_quantized( | |
| 231 | size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // | ||
| 232 | const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding_ptr, const void* lhs_scales, | ||
| 233 | const void* lhs_zero_points, size_t lhs_quant_height, | ||
| 234 | size_t lhs_quant_width, // | ||
| 235 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
| 236 | size_t rhs_quant_width, // | ||
| 237 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width) { | ||
| 238 | − | KAI_ASSUME_ALWAYS(lhs_quant_width != 0); | |
| 239 | − | KAI_ASSUME_ALWAYS(rhs_quant_width != 0); | |
| 240 | − | KAI_ASSUME_ALWAYS(lhs_quant_height != 0); | |
| 241 | − | KAI_ASSUME_ALWAYS(rhs_quant_height != 0); | |
| 242 | − | KAI_ASSUME_ALWAYS(bias_quant_width != 0); | |
| 243 | 521 | const auto lhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, lhs_quant_width); | |
| 244 | 521 | const auto rhs_num_quant_per_row = round_up_division(k_chunk_count * k_chunk_length, rhs_quant_width); | |
| 245 | |||
| 246 | 521 | Buffer dst(m * n * sizeof(DstData)); | |
| 247 | |||
| 248 |
2/2✓ Branch 0 taken 11441 times.
✓ Branch 1 taken 521 times.
|
11962 | for (size_t i_m = 0; i_m < m; ++i_m) { |
| 249 |
2/2✓ Branch 0 taken 11441 times.
✓ Branch 1 taken 929839 times.
|
941280 | for (size_t i_n = 0; i_n < n; ++i_n) { |
| 250 | 929839 | DstData acc = 0; | |
| 251 | |||
| 252 |
2/2✓ Branch 0 taken 9447359 times.
✓ Branch 1 taken 929839 times.
|
10377198 | for (size_t i_k_chunk = 0; i_k_chunk < k_chunk_count; ++i_k_chunk) { |
| 253 | // Calculate the K chunk pointer. Apply offset if this is not padding | ||
| 254 | 9447359 | const size_t k_chunk_idx = i_m * k_chunk_count + i_k_chunk; | |
| 255 | 9447359 | const void* k_chunk_ptr = lhs_ptrs[k_chunk_idx]; | |
| 256 |
2/2✓ Branch 0 taken 562940 times.
✓ Branch 1 taken 8884419 times.
|
9447359 | if (k_chunk_ptr != lhs_padding_ptr) { |
| 257 | 8884419 | k_chunk_ptr = reinterpret_cast<const void*>(reinterpret_cast<uintptr_t>(k_chunk_ptr) + lhs_offset); | |
| 258 | 8884419 | } | |
| 259 | |||
| 260 |
2/2✓ Branch 0 taken 61909478 times.
✓ Branch 1 taken 9447359 times.
|
71356837 | for (size_t i_k_chunk_len = 0; i_k_chunk_len < k_chunk_length; ++i_k_chunk_len) { |
| 261 | 61909478 | const size_t i = i_k_chunk * k_chunk_length + i_k_chunk_len; | |
| 262 | |||
| 263 | 61909478 | const auto lhs_data_index = i_k_chunk_len; | |
| 264 | 61909478 | const auto lhs_quant_index = (i_m / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; | |
| 265 |
1/2✓ Branch 0 taken 61909478 times.
✗ Branch 1 not taken.
|
61909478 | const auto lhs_value = read_array<LhsData>(k_chunk_ptr, lhs_data_index); |
| 266 |
2/4✓ Branch 0 taken 61909478 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 61909478 times.
|
61909478 | const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index) |
| 267 | : static_cast<LhsScale>(1); | ||
| 268 |
1/2✓ Branch 0 taken 61909478 times.
✗ Branch 1 not taken.
|
123818956 | const auto lhs_zero_point = lhs_zero_points != nullptr |
| 269 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 61909478 times.
|
61909478 | ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index) |
| 270 | : static_cast<LhsZeroPoint>(0); | ||
| 271 | |||
| 272 | 61909478 | const auto rhs_data_index = i_n * (k_chunk_count * k_chunk_length) + i; | |
| 273 | 61909478 | const auto rhs_quant_index = (i_n / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; | |
| 274 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 61909478 times.
|
61909478 | const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index); |
| 275 |
2/4✓ Branch 0 taken 61909478 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 61909478 times.
|
61909478 | const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index) |
| 276 | : static_cast<RhsScale>(1); | ||
| 277 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 61909478 times.
|
61909478 | const auto rhs_zero_point = rhs_zero_points != nullptr |
| 278 | ✗ | ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index) | |
| 279 | : static_cast<RhsZeroPoint>(0); | ||
| 280 | |||
| 281 | 185728434 | acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) * | |
| 282 | 123818956 | static_cast<DstData>(lhs_scale) * | |
| 283 | 61909478 | (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) * | |
| 284 | 61909478 | static_cast<DstData>(rhs_scale); | |
| 285 | 61909478 | } | |
| 286 | 9447359 | } | |
| 287 | |||
| 288 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 929839 times.
|
929839 | if (bias_data != nullptr) { |
| 289 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 929839 times.
|
929839 | const auto bias_value = read_array<BiasData>(bias_data, i_n); |
| 290 |
1/2✓ Branch 0 taken 929839 times.
✗ Branch 1 not taken.
|
1859678 | const auto bias_scale = bias_scales != nullptr |
| 291 |
1/2✓ Branch 0 taken 929839 times.
✗ Branch 1 not taken.
|
929839 | ? read_array<BiasScale>(bias_scales, i_n / bias_quant_width) |
| 292 | : static_cast<BiasScale>(1); | ||
| 293 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 929839 times.
|
929839 | const auto bias_zero_point = bias_zero_points != nullptr |
| 294 | ✗ | ? read_array<BiasZeroPoint>(bias_zero_points, i_n / bias_quant_width) | |
| 295 | : static_cast<BiasZeroPoint>(0); | ||
| 296 | |||
| 297 | 1859678 | acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) * | |
| 298 | 929839 | static_cast<DstData>(bias_scale); | |
| 299 | 929839 | } | |
| 300 | |||
| 301 |
2/4✓ Branch 0 taken 929839 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 929839 times.
✗ Branch 3 not taken.
|
929839 | write_array<DstData>(dst.data(), i_m * n + i_n, acc); |
| 302 | 929839 | } | |
| 303 | 11441 | } | |
| 304 | |||
| 305 | 521 | return dst; | |
| 306 | 521 | } | |
| 307 | |||
| 308 | template < | ||
| 309 | typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, | ||
| 310 | typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> | ||
| 311 | 26832 | Buffer matmul_nt_t_quantized( | |
| 312 | size_t m, size_t n, size_t k, // | ||
| 313 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, // | ||
| 314 | size_t lhs_quant_height, size_t lhs_quant_width, // | ||
| 315 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, // | ||
| 316 | size_t rhs_quant_height, size_t rhs_quant_width, // | ||
| 317 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, // | ||
| 318 | size_t bias_quant_width) { | ||
| 319 | − | KAI_ASSUME_ALWAYS(lhs_quant_width != 0); | |
| 320 | − | KAI_ASSUME_ALWAYS(rhs_quant_width != 0); | |
| 321 | − | KAI_ASSUME_ALWAYS(lhs_quant_height != 0); | |
| 322 | − | KAI_ASSUME_ALWAYS(rhs_quant_height != 0); | |
| 323 | − | KAI_ASSUME_ALWAYS(bias_quant_width != 0); | |
| 324 | |||
| 325 | 26832 | const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); | |
| 326 | 26832 | const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); | |
| 327 | |||
| 328 | 26832 | Buffer dst(m * n * sizeof(DstData)); | |
| 329 | |||
| 330 |
6/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 26064 times.
✓ Branch 3 taken 265116 times.
✓ Branch 4 taken 192 times.
✓ Branch 5 taken 2748 times.
✓ Branch 6 taken 576 times.
✓ Branch 7 taken 51051 times.
|
345747 | for (size_t row = 0; row < m; ++row) { |
| 331 |
6/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 19402236 times.
✓ Branch 3 taken 265116 times.
✓ Branch 4 taken 2748 times.
✓ Branch 5 taken 219744 times.
✓ Branch 6 taken 51051 times.
✓ Branch 7 taken 6371326 times.
|
26312221 | for (size_t col = 0; col < n; ++col) { |
| 332 | 25993306 | DstData acc = 0; | |
| 333 | |||
| 334 |
6/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 1174598280 times.
✓ Branch 3 taken 19402236 times.
✓ Branch 4 taken 10572468 times.
✓ Branch 5 taken 219744 times.
✓ Branch 6 taken 1984792064 times.
✓ Branch 7 taken 6371326 times.
|
3195956118 | for (size_t i = 0; i < k; ++i) { |
| 335 | 3169962812 | const auto lhs_data_index = row * k + i; | |
| 336 | 3169962812 | const auto lhs_quant_index = (row / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; | |
| 337 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 1174598280 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10572468 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1984792064 times.
✗ Branch 7 not taken.
|
3169962812 | const auto lhs_value = read_array<LhsData>(lhs_data, lhs_data_index); |
| 338 |
6/16✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 1174598280 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1174598280 times.
✓ Branch 8 taken 10572468 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 10572468 times.
✓ Branch 12 taken 1984792064 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1984792064 times.
|
3169962812 | const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index) |
| 339 | : static_cast<LhsScale>(1); | ||
| 340 |
4/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 27671688 times.
✓ Branch 3 taken 1146926592 times.
✓ Branch 4 taken 10572468 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1984792064 times.
✗ Branch 7 not taken.
|
5192999032 | const auto lhs_zero_point = lhs_zero_points != nullptr |
| 341 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 27671688 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 10572468 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1984792064 times.
|
2023036220 | ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index) |
| 342 | : static_cast<LhsZeroPoint>(0); | ||
| 343 | |||
| 344 | 3169962812 | const auto rhs_data_index = col * k + i; | |
| 345 | 3169962812 | const auto rhs_quant_index = (col / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; | |
| 346 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1174598280 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 10572468 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1984792064 times.
|
3169962812 | const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index); |
| 347 |
6/16✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 1174598280 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1174598280 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 10572468 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 10572468 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1984792064 times.
✓ Branch 14 taken 1984792064 times.
✗ Branch 15 not taken.
|
3169962812 | const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index) |
| 348 | ✗ | : static_cast<RhsScale>(1); | |
| 349 |
4/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 1146926592 times.
✓ Branch 3 taken 27671688 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 10572468 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1984792064 times.
|
4316889404 | const auto rhs_zero_point = rhs_zero_points != nullptr |
| 350 |
1/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 1146926592 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
1146926592 | ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index) |
| 351 | : static_cast<RhsZeroPoint>(0); | ||
| 352 | |||
| 353 | 3509139990 | acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) * | |
| 354 | 3180535280 | static_cast<DstData>(lhs_scale) * | |
| 355 |
2/4✓ Branch 0 taken 1174598280 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1984792064 times.
✗ Branch 3 not taken.
|
3169962812 | (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) * |
| 356 |
1/2✓ Branch 0 taken 1984792064 times.
✗ Branch 1 not taken.
|
3169962812 | static_cast<DstData>(rhs_scale); |
| 357 | 3169962812 | } | |
| 358 | |||
| 359 |
5/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 9701118 times.
✓ Branch 3 taken 9701118 times.
✓ Branch 4 taken 109872 times.
✓ Branch 5 taken 109872 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6371326 times.
|
25993306 | if (bias_data != nullptr) { |
| 360 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 9701118 times.
✓ Branch 4 taken 109872 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 6371326 times.
|
16182316 | const auto bias_value = read_array<BiasData>(bias_data, col); |
| 361 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 9701118 times.
✓ Branch 4 taken 109872 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 6371326 times.
|
16182316 | const auto bias_scale = bias_scales != nullptr |
| 362 | ✗ | ? read_array<BiasScale>(bias_scales, col / bias_quant_width) | |
| 363 | : static_cast<BiasScale>(1); | ||
| 364 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 9701118 times.
✓ Branch 4 taken 109872 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 6371326 times.
|
16182316 | const auto bias_zero_point = bias_zero_points != nullptr |
| 365 | ✗ | ? read_array<BiasZeroPoint>(bias_zero_points, col / bias_quant_width) | |
| 366 | : static_cast<BiasZeroPoint>(0); | ||
| 367 | |||
| 368 | 32364632 | acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) * | |
| 369 | 16182316 | static_cast<DstData>(bias_scale); | |
| 370 | 16182316 | } | |
| 371 | |||
| 372 |
6/16✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 19402236 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 19402236 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 219744 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 219744 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 6371326 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 6371326 times.
✗ Branch 15 not taken.
|
25993306 | write_array<DstData>(dst.data(), row * n + col, acc); |
| 373 | 25993306 | } | |
| 374 | 318915 | } | |
| 375 | |||
| 376 | 26832 | return dst; | |
| 377 | 26832 | } | |
| 378 | |||
| 379 | template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>( | ||
| 380 | size_t m, size_t n, size_t k, // | ||
| 381 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
| 382 | size_t lhs_quant_width, // | ||
| 383 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
| 384 | size_t rhs_quant_width, // | ||
| 385 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
| 386 | |||
| 387 | template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>( | ||
| 388 | size_t m, size_t n, size_t k, // | ||
| 389 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
| 390 | size_t lhs_quant_width, // | ||
| 391 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
| 392 | size_t rhs_quant_width, // | ||
| 393 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
| 394 | |||
| 395 | template Buffer matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, float, float, int32_t, float>( | ||
| 396 | size_t m, size_t n, size_t k, // | ||
| 397 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
| 398 | size_t lhs_quant_width, // | ||
| 399 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
| 400 | size_t rhs_quant_width, // | ||
| 401 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
| 402 | |||
| 403 | template Buffer | ||
| 404 | matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( | ||
| 405 | size_t m, size_t n, size_t k, // | ||
| 406 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
| 407 | size_t lhs_quant_width, // | ||
| 408 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
| 409 | size_t rhs_quant_width, // | ||
| 410 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
| 411 | |||
| 412 | template < | ||
| 413 | typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, | ||
| 414 | typename RhsZeroPoint, typename BiasData, typename BiasScale, typename BiasZeroPoint, typename DstData> | ||
| 415 | 972 | Buffer matmul_nt_nt_quantized( | |
| 416 | size_t m, size_t n, size_t k, // | ||
| 417 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, // | ||
| 418 | size_t lhs_quant_height, size_t lhs_quant_width, // | ||
| 419 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, // | ||
| 420 | size_t rhs_quant_height, size_t rhs_quant_width, // | ||
| 421 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, // | ||
| 422 | size_t bias_quant_width) { | ||
| 423 | − | KAI_ASSUME_ALWAYS(lhs_quant_width != 0); | |
| 424 | − | KAI_ASSUME_ALWAYS(rhs_quant_width != 0); | |
| 425 | − | KAI_ASSUME_ALWAYS(lhs_quant_height != 0); | |
| 426 | − | KAI_ASSUME_ALWAYS(rhs_quant_height != 0); | |
| 427 | − | KAI_ASSUME_ALWAYS(bias_quant_width != 0); | |
| 428 | |||
| 429 | 972 | const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); | |
| 430 | 972 | const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); | |
| 431 | |||
| 432 | 972 | Buffer dst(m * n * sizeof(DstData)); | |
| 433 | |||
| 434 |
4/4✓ Branch 0 taken 558 times.
✓ Branch 1 taken 51033 times.
✓ Branch 2 taken 414 times.
✓ Branch 3 taken 44721 times.
|
96726 | for (size_t row = 0; row < m; ++row) { |
| 435 |
4/4✓ Branch 0 taken 51033 times.
✓ Branch 1 taken 6362191 times.
✓ Branch 2 taken 2838231 times.
✓ Branch 3 taken 44721 times.
|
9296176 | for (size_t col = 0; col < n; ++col) { |
| 436 | 9200422 | DstData acc = 0; | |
| 437 | |||
| 438 |
4/4✓ Branch 0 taken 1981913664 times.
✓ Branch 1 taken 6362191 times.
✓ Branch 2 taken 774427959 times.
✓ Branch 3 taken 2838231 times.
|
2765542045 | for (size_t i = 0; i < k; ++i) { |
| 439 | 2756341623 | const auto lhs_data_index = row * k + i; | |
| 440 | 2756341623 | const auto lhs_quant_index = (row / lhs_quant_height) * lhs_num_quant_per_row + i / lhs_quant_width; | |
| 441 |
2/4✓ Branch 0 taken 1981913664 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 774427959 times.
|
2756341623 | const auto lhs_value = read_array<LhsData>(lhs_data, lhs_data_index); |
| 442 |
3/8✓ Branch 0 taken 1981913664 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1981913664 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 774427959 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
2756341623 | const auto lhs_scale = lhs_scales != nullptr ? read_array<LhsScale>(lhs_scales, lhs_quant_index) |
| 443 | : static_cast<LhsScale>(1); | ||
| 444 |
2/4✓ Branch 0 taken 1981913664 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 774427959 times.
|
4738255287 | const auto lhs_zero_point = lhs_zero_points != nullptr |
| 445 |
1/4✗ Branch 0 not taken.
✓ Branch 1 taken 1981913664 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
1981913664 | ? read_array<LhsZeroPoint>(lhs_zero_points, lhs_quant_index) |
| 446 | : static_cast<LhsZeroPoint>(0); | ||
| 447 | |||
| 448 | 2756341623 | const auto rhs_data_index = col + i * n; | |
| 449 | 2756341623 | const auto rhs_quant_index = (col / rhs_quant_height) * rhs_num_quant_per_row + i / rhs_quant_width; | |
| 450 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1981913664 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 774427959 times.
|
2756341623 | const auto rhs_value = read_array<RhsData>(rhs_data, rhs_data_index); |
| 451 |
3/8✗ Branch 0 not taken.
✓ Branch 1 taken 1981913664 times.
✓ Branch 2 taken 1981913664 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 774427959 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
2756341623 | const auto rhs_scale = rhs_scales != nullptr ? read_array<RhsScale>(rhs_scales, rhs_quant_index) |
| 452 | ✗ | : static_cast<RhsScale>(1); | |
| 453 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1981913664 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 774427959 times.
|
2756341623 | const auto rhs_zero_point = rhs_zero_points != nullptr |
| 454 | ✗ | ? read_array<RhsZeroPoint>(rhs_zero_points, rhs_quant_index) | |
| 455 | : static_cast<RhsZeroPoint>(0); | ||
| 456 | |||
| 457 |
1/2✓ Branch 0 taken 774427959 times.
✗ Branch 1 not taken.
|
6720168951 | acc += (static_cast<DstData>(lhs_value) - static_cast<DstData>(lhs_zero_point)) * |
| 458 | 2756341623 | static_cast<DstData>(lhs_scale) * | |
| 459 |
2/4✓ Branch 0 taken 1981913664 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 774427959 times.
✗ Branch 3 not taken.
|
2756341623 | (static_cast<DstData>(rhs_value) - static_cast<DstData>(rhs_zero_point)) * |
| 460 |
1/2✓ Branch 0 taken 1981913664 times.
✗ Branch 1 not taken.
|
2756341623 | static_cast<DstData>(rhs_scale); |
| 461 | 2756341623 | } | |
| 462 | |||
| 463 |
3/4✗ Branch 0 not taken.
✓ Branch 1 taken 6362191 times.
✓ Branch 2 taken 1263918 times.
✓ Branch 3 taken 1574313 times.
|
9200422 | if (bias_data != nullptr) { |
| 464 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 6362191 times.
✓ Branch 2 taken 1574313 times.
✗ Branch 3 not taken.
|
7936504 | const auto bias_value = read_array<BiasData>(bias_data, col); |
| 465 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 6362191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1574313 times.
|
7936504 | const auto bias_scale = bias_scales != nullptr |
| 466 | ✗ | ? read_array<BiasScale>(bias_scales, col / bias_quant_width) | |
| 467 | : static_cast<BiasScale>(1); | ||
| 468 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 6362191 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1574313 times.
|
7936504 | const auto bias_zero_point = bias_zero_points != nullptr |
| 469 | ✗ | ? read_array<BiasZeroPoint>(bias_zero_points, col / bias_quant_width) | |
| 470 | : static_cast<BiasZeroPoint>(0); | ||
| 471 | |||
| 472 | 15873008 | acc += (static_cast<DstData>(bias_value) - static_cast<DstData>(bias_zero_point)) * | |
| 473 | 7936504 | static_cast<DstData>(bias_scale); | |
| 474 | 7936504 | } | |
| 475 | |||
| 476 |
4/8✓ Branch 0 taken 6362191 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6362191 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2838231 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2838231 times.
✗ Branch 7 not taken.
|
9200422 | write_array<DstData>(dst.data(), row * n + col, acc); |
| 477 | 9200422 | } | |
| 478 | 95754 | } | |
| 479 | |||
| 480 | 972 | return dst; | |
| 481 | 972 | } | |
| 482 | |||
| 483 | template Buffer | ||
| 484 | matmul_nt_nt_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( | ||
| 485 | size_t m, size_t n, size_t k, // | ||
| 486 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
| 487 | size_t lhs_quant_width, // | ||
| 488 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
| 489 | size_t rhs_quant_width, // | ||
| 490 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
| 491 | |||
| 492 | template Buffer matmul_nt_nt_quantized<BFloat16<>, float, float, BFloat16<>, float, float, float, float, float, float>( | ||
| 493 | size_t m, size_t n, size_t k, // | ||
| 494 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_height, | ||
| 495 | size_t lhs_quant_width, // | ||
| 496 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
| 497 | size_t rhs_quant_width, // | ||
| 498 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
| 499 | |||
| 500 | template Buffer | ||
| 501 | indirect_matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>( | ||
| 502 | size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, // | ||
| 503 | const void* const* lhs_ptrs, uintptr_t lhs_offset, const void* lhs_padding, const void* lhs_scales, | ||
| 504 | const void* lhs_zero_points, size_t lhs_quant_height, | ||
| 505 | size_t lhs_quant_width, // | ||
| 506 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_height, | ||
| 507 | size_t rhs_quant_width, // | ||
| 508 | const void* bias_data, const void* bias_scales, const void* bias_zero_points, size_t bias_quant_width); | ||
| 509 | |||
| 510 | template < | ||
| 511 | typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, | ||
| 512 | typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> | ||
| 513 | 5786 | Buffer matmul_clamp_nt_t( | |
| 514 | size_t m, size_t n, size_t k, // | ||
| 515 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
| 516 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
| 517 | const void* biases, // | ||
| 518 | DstData min_value, DstData max_value) { | ||
| 519 | − | KAI_ASSUME_ALWAYS(lhs_quant_width != 0); | |
| 520 | − | KAI_ASSUME_ALWAYS(rhs_quant_width != 0); | |
| 521 | 5786 | const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); | |
| 522 | 5786 | const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); | |
| 523 | |||
| 524 | 5786 | Buffer dst(m * n * sizeof(DstData)); | |
| 525 | |||
| 526 | 5786 | const auto* lhs_scales_ptr = reinterpret_cast<const LhsScale*>(lhs_scales); | |
| 527 | 5786 | const auto* rhs_scales_ptr = reinterpret_cast<const RhsScale*>(rhs_scales); | |
| 528 | 5786 | const auto* lhs_zero_points_ptr = reinterpret_cast<const LhsZeroPoint*>(lhs_zero_points); | |
| 529 | 5786 | const auto* rhs_zero_points_ptr = reinterpret_cast<const RhsZeroPoint*>(rhs_zero_points); | |
| 530 | 5786 | const auto* biases_ptr = reinterpret_cast<const Bias*>(biases); | |
| 531 |
3/8✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
|
5786 | auto* dst_ptr = reinterpret_cast<DstData*>(dst.data()); |
| 532 | |||
| 533 |
6/8✓ Branch 0 taken 4800 times.
✓ Branch 1 taken 125760 times.
✓ Branch 2 taken 16972 times.
✓ Branch 3 taken 920 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✓ Branch 7 taken 654 times.
|
149172 | for (size_t y = 0; y < m; ++y) { |
| 534 |
6/8✓ Branch 0 taken 125760 times.
✓ Branch 1 taken 17208000 times.
✓ Branch 2 taken 16972 times.
✓ Branch 3 taken 1304540 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 654 times.
✓ Branch 7 taken 43470 times.
|
18699396 | for (size_t x = 0; x < n; ++x) { |
| 535 | 18556010 | DstData acc = 0; | |
| 536 | |||
| 537 |
6/8✓ Branch 0 taken 865691520 times.
✓ Branch 1 taken 17208000 times.
✓ Branch 2 taken 82565120 times.
✓ Branch 3 taken 1304540 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1472394 times.
✓ Branch 7 taken 43470 times.
|
968285044 | for (size_t i = 0; i < k; ++i) { |
| 538 |
3/8✓ Branch 0 taken 865691520 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 82565120 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1472394 times.
✗ Branch 7 not taken.
|
949729034 | const auto lhs_value = read_array<LhsData>(lhs_data, y * k + i); |
| 539 | 949729034 | const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]; | |
| 540 |
3/8✓ Branch 0 taken 865691520 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 82565120 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1472394 times.
✗ Branch 7 not taken.
|
949729034 | const auto lhs_zero_point = lhs_zero_points_ptr != nullptr |
| 541 | 867163914 | ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width] | |
| 542 | : 0; | ||
| 543 | |||
| 544 |
3/8✗ Branch 0 not taken.
✓ Branch 1 taken 865691520 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 82565120 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1472394 times.
✗ Branch 7 not taken.
|
949729034 | const auto rhs_value = read_array<RhsData>(rhs_data, x * k + i); |
| 545 | 949729034 | const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width]; | |
| 546 |
3/8✗ Branch 0 not taken.
✓ Branch 1 taken 865691520 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 82565120 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1472394 times.
|
949729034 | const auto rhs_zero_point = rhs_zero_points_ptr != nullptr |
| 547 | ✗ | ? rhs_zero_points_ptr[y * rhs_num_quant_per_row + i / rhs_quant_width] | |
| 548 | : 0; | ||
| 549 | |||
| 550 | 949729034 | acc += static_cast<DstData>( | |
| 551 | 951201428 | (static_cast<IntAcc>(lhs_value) - static_cast<IntAcc>(lhs_zero_point)) * | |
| 552 |
2/6✗ Branch 0 not taken.
✓ Branch 1 taken 865691520 times.
✓ Branch 2 taken 82565120 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
951201428 | (static_cast<IntAcc>(rhs_value) - static_cast<IntAcc>(rhs_zero_point))) * |
| 553 |
2/8✓ Branch 0 taken 82565120 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 82565120 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
1816892948 | static_cast<DstData>(lhs_scale) * static_cast<DstData>(rhs_scale); |
| 554 | 949729034 | } | |
| 555 | |||
| 556 |
3/8✗ Branch 0 not taken.
✓ Branch 1 taken 17208000 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1304540 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 43470 times.
✗ Branch 7 not taken.
|
18556010 | if (biases_ptr != nullptr) { |
| 557 | 17251470 | acc += static_cast<DstData>(biases_ptr[x]); | |
| 558 | 17251470 | } | |
| 559 | |||
| 560 |
3/8✓ Branch 0 taken 17208000 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1304540 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 43470 times.
✗ Branch 7 not taken.
|
18556010 | acc = std::clamp(acc, min_value, max_value); |
| 561 | 18556010 | dst_ptr[y * n + x] = acc; | |
| 562 | 18556010 | } | |
| 563 | 143386 | } | |
| 564 | |||
| 565 | 5786 | return dst; | |
| 566 | 5786 | } | |
| 567 | |||
| 568 | template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( | ||
| 569 | size_t m, size_t n, size_t k, // | ||
| 570 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
| 571 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
| 572 | const void* biases, // | ||
| 573 | float min_value, float max_value); | ||
| 574 | |||
| 575 | template Buffer matmul_clamp_nt_t<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>( | ||
| 576 | size_t m, size_t n, size_t k, // | ||
| 577 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
| 578 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
| 579 | const void* biases, // | ||
| 580 | float min_value, float max_value); | ||
| 581 | |||
| 582 | template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, int32_t, float>( | ||
| 583 | size_t m, size_t n, size_t k, // | ||
| 584 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
| 585 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
| 586 | const void* biases, // | ||
| 587 | float min_value, float max_value); | ||
| 588 | |||
| 589 | template Buffer matmul_clamp_nt_t<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>( | ||
| 590 | size_t m, size_t n, size_t k, // | ||
| 591 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
| 592 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
| 593 | const void* biases, // | ||
| 594 | float min_value, float max_value); | ||
| 595 | |||
| 596 | template < | ||
| 597 | typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, | ||
| 598 | typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> | ||
| 599 | 4554 | Buffer matmul_clamp_nt_nt( | |
| 600 | size_t m, size_t n, size_t k, // | ||
| 601 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
| 602 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
| 603 | const void* biases, // | ||
| 604 | DstData min_value, DstData max_value) { | ||
| 605 | − | KAI_ASSUME_ALWAYS(lhs_quant_width != 0); | |
| 606 | − | KAI_ASSUME_ALWAYS(rhs_quant_width != 0); | |
| 607 | 4554 | const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); | |
| 608 | 4554 | const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); | |
| 609 | |||
| 610 | 4554 | Buffer dst(m * n * sizeof(DstData)); | |
| 611 | |||
| 612 | 4554 | const auto* lhs_scales_ptr = reinterpret_cast<const LhsScale*>(lhs_scales); | |
| 613 | 4554 | const auto* rhs_scales_ptr = reinterpret_cast<const RhsScale*>(rhs_scales); | |
| 614 | 4554 | const auto* lhs_zero_points_ptr = reinterpret_cast<const LhsZeroPoint*>(lhs_zero_points); | |
| 615 | 4554 | const auto* rhs_zero_points_ptr = reinterpret_cast<const RhsZeroPoint*>(rhs_zero_points); | |
| 616 | 4554 | const auto* biases_ptr = reinterpret_cast<const Bias*>(biases); | |
| 617 |
2/8✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4488 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
4554 | auto* dst_ptr = reinterpret_cast<DstData*>(dst.data()); |
| 618 | |||
| 619 |
4/8✓ Branch 0 taken 66 times.
✓ Branch 1 taken 654 times.
✓ Branch 2 taken 4488 times.
✓ Branch 3 taken 117000 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
122208 | for (size_t y = 0; y < m; ++y) { |
| 620 |
4/8✓ Branch 0 taken 654 times.
✓ Branch 1 taken 43470 times.
✓ Branch 2 taken 117000 times.
✓ Branch 3 taken 15801948 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
15963072 | for (size_t x = 0; x < n; ++x) { |
| 621 | 15845418 | DstData acc = 0; | |
| 622 | |||
| 623 |
4/8✓ Branch 0 taken 1472394 times.
✓ Branch 1 taken 43470 times.
✓ Branch 2 taken 796221588 times.
✓ Branch 3 taken 15801948 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
813539400 | for (size_t i = 0; i < k; ++i) { |
| 624 |
2/8✓ Branch 0 taken 1472394 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 796221588 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
797693982 | const auto lhs_value = read_array<LhsData>(lhs_data, y * k + i); |
| 625 | 797693982 | const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]; | |
| 626 |
2/8✓ Branch 0 taken 1472394 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 796221588 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
797693982 | const auto lhs_zero_point = lhs_zero_points_ptr != nullptr |
| 627 | 797693982 | ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width] | |
| 628 | : 0; | ||
| 629 | |||
| 630 |
2/8✓ Branch 0 taken 1472394 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 796221588 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
797693982 | const auto rhs_value = read_array<RhsData>(rhs_data, x + i * n); |
| 631 | 797693982 | const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width]; | |
| 632 |
2/8✗ Branch 0 not taken.
✓ Branch 1 taken 1472394 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 796221588 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
797693982 | const auto rhs_zero_point = rhs_zero_points_ptr != nullptr |
| 633 | ✗ | ? rhs_zero_points_ptr[y * rhs_num_quant_per_row + i / rhs_quant_width] | |
| 634 | : 0; | ||
| 635 | |||
| 636 | 797693982 | acc += static_cast<DstData>( | |
| 637 | 799166376 | (static_cast<IntAcc>(lhs_value) - static_cast<IntAcc>(lhs_zero_point)) * | |
| 638 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 796221588 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
799166376 | (static_cast<IntAcc>(rhs_value) - static_cast<IntAcc>(rhs_zero_point))) * |
| 639 |
0/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
1595387964 | static_cast<DstData>(lhs_scale) * static_cast<DstData>(rhs_scale); |
| 640 | 797693982 | } | |
| 641 | |||
| 642 |
3/8✓ Branch 0 taken 43470 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 157374 times.
✓ Branch 3 taken 15644574 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
15845418 | if (biases_ptr != nullptr) { |
| 643 | 15688044 | acc += static_cast<DstData>(biases_ptr[x]); | |
| 644 | 15688044 | } | |
| 645 | |||
| 646 |
2/8✓ Branch 0 taken 43470 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 15801948 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
15845418 | acc = std::clamp(acc, min_value, max_value); |
| 647 | 15845418 | dst_ptr[y * n + x] = acc; | |
| 648 | 15845418 | } | |
| 649 | 117654 | } | |
| 650 | |||
| 651 | 4554 | return dst; | |
| 652 | 4554 | } | |
| 653 | |||
| 654 | template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>( | ||
| 655 | size_t m, size_t n, size_t k, // | ||
| 656 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
| 657 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
| 658 | const void* biases, // | ||
| 659 | float min_value, float max_value); | ||
| 660 | template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( | ||
| 661 | size_t m, size_t n, size_t k, // | ||
| 662 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
| 663 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
| 664 | const void* biases, // | ||
| 665 | float min_value, float max_value); | ||
| 666 | |||
| 667 | template Buffer matmul_clamp_nt_nt<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>( | ||
| 668 | size_t m, size_t n, size_t k, // | ||
| 669 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
| 670 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
| 671 | const void* biases, // | ||
| 672 | float min_value, float max_value); | ||
| 673 | |||
| 674 | template Buffer matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, int32_t, float>( | ||
| 675 | size_t m, size_t n, size_t k, // | ||
| 676 | const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // | ||
| 677 | const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // | ||
| 678 | const void* biases, // | ||
| 679 | float min_value, float max_value); | ||
| 680 | |||
| 681 | } // namespace kai::test | ||
| 682 |