test/nextgen/operators/matmul/matmul_tb.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/nextgen/operators/matmul/matmul_tb.hpp" | ||
| 8 | |||
| 9 | #include <algorithm> | ||
| 10 | #include <array> | ||
| 11 | #include <cstddef> | ||
| 12 | #include <tuple> | ||
| 13 | #include <utility> | ||
| 14 | #include <vector> | ||
| 15 | |||
| 16 | #include "test/common/assert.hpp" | ||
| 17 | #include "test/common/buffer.hpp" | ||
| 18 | #include "test/common/compare.hpp" | ||
| 19 | #include "test/common/data_type.hpp" | ||
| 20 | #include "test/nextgen/common/poly.hpp" | ||
| 21 | #include "test/nextgen/common/random.hpp" | ||
| 22 | #include "test/nextgen/format/format.hpp" | ||
| 23 | #include "test/nextgen/format/plain_format.hpp" | ||
| 24 | #include "test/nextgen/harness/kernel_wrapper.hpp" | ||
| 25 | #include "test/nextgen/operators/matmul/matmul_config.hpp" | ||
| 26 | #include "test/nextgen/operators/matmul/matmul_slots.hpp" | ||
| 27 | #include "test/nextgen/quantization/quantizer.hpp" | ||
| 28 | #include "test/nextgen/reference/binary_elementwise.hpp" | ||
| 29 | #include "test/nextgen/reference/clamp.hpp" | ||
| 30 | #include "test/nextgen/reference/matmul.hpp" | ||
| 31 | #include "test/nextgen/reference/reduce.hpp" | ||
| 32 | #include "test/nextgen/reference/unary_elementwise.hpp" | ||
| 33 | #include "test/reference/transpose.hpp" | ||
| 34 | |||
| 35 | namespace kai::test { | ||
| 36 | |||
| 37 | 600 | MatMulTb::MatMulTb( | |
| 38 | size_t shape_m, size_t shape_n, size_t shape_k, MatMulBiasMode bias_mode, float clamp_ratio, | ||
| 39 | const MatMulOperator* op) : | ||
| 40 | 200 | m_shape_m(shape_m), | |
| 41 | 200 | m_shape_n(shape_n), | |
| 42 | 200 | m_shape_k(shape_k), | |
| 43 | 200 | m_bias_mode(bias_mode), | |
| 44 | 200 | m_clamp_ratio(clamp_ratio), | |
| 45 | 200 | m_op(op), | |
| 46 | 400 | m_tensors_required() { | |
| 47 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | std::fill(m_tensors_required.begin(), m_tensors_required.end(), false); |
| 48 | 400 | } | |
| 49 | |||
/// Generates all input tensors and reference outputs for one test case.
///
/// The steps run in a fixed order because later stages consume earlier ones:
/// config -> constant info -> raw data -> quantization -> derived tensors ->
/// reference packing -> reference matmul.
void MatMulTb::generate_test_data(Rng& rng) {
    populate_config();
    determine_required_tensors();

    // Populates the constant information.
    m_op->matmul->populate_constant_info(m_tensors);

    if (m_op->pack_lhs.has_value()) {
        const KernelWrapper& pack_lhs = *m_op->pack_lhs.value();
        pack_lhs.populate_constant_info(m_tensors);
    }

    if (m_op->pack_rhs.has_value()) {
        const KernelWrapper& pack_rhs = *m_op->pack_rhs.value();
        pack_rhs.populate_constant_info(m_tensors);
    }

    // Generates the raw (floating-point) test data.
    generate_lhs_raw(rng);
    generate_rhs_raw(rng);
    generate_bias_raw(rng);

    compute_rhs_t_raw();  // The transposed RHS data is always needed for reference packing.

    // Quantizes the input data, if the operator defines quantizers for it.
    if (m_op->lhs_quant.has_value()) {
        quantize_lhs();
    }

    if (m_op->rhs_quant.has_value()) {
        quantize_rhs_t();
    }

    if (m_op->bias_quant.has_value()) {
        quantize_bias();
    }

    // Derived tensors are only computed when some kernel listed them as an
    // input (see determine_required_tensors()).
    if (m_tensors_required.at(MATMUL_SLOT_LHS_QZP_NEG)) {
        compute_lhs_qzp_neg();
    }

    if (m_tensors_required.at(MATMUL_SLOT_RHS_T_QDATA_SIGN)) {
        compute_rhs_t_qdata_sign();
    }

    if (m_tensors_required.at(MATMUL_SLOT_RHS_T_QDATA_SIGN_SUM)) {
        compute_rhs_t_qdata_sign_sum();
    }

    // Generates reference output.
    if (m_op->pack_lhs.has_value()) {
        compute_ref_packed_lhs();
    }

    if (m_op->pack_rhs.has_value()) {
        compute_ref_packed_rhs();
    }

    compute_ref_matmul();
}
| 110 | |||
| 111 | 200 | void MatMulTb::populate_config() { | |
| 112 | 200 | m_tensors.at(MATMUL_SLOT_CONFIG).set_value(MatMulConfig{m_bias_mode}); | |
| 113 | 200 | } | |
| 114 | |||
| 115 | 200 | void MatMulTb::determine_required_tensors() { | |
| 116 |
0/2✗ Branch 0 not taken.
✗ Branch 1 not taken.
|
200 | std::vector<const KernelWrapper*> kernels{m_op->matmul.get()}; |
| 117 | |||
| 118 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 200 times.
|
200 | if (m_op->pack_lhs.has_value()) { |
| 119 |
2/4✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
|
200 | kernels.emplace_back(m_op->pack_lhs.value().get()); |
| 120 | 200 | } | |
| 121 | |||
| 122 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 200 times.
|
200 | if (m_op->pack_rhs.has_value()) { |
| 123 |
2/4✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
|
200 | kernels.emplace_back(m_op->pack_rhs.value().get()); |
| 124 | 200 | } | |
| 125 | |||
| 126 |
2/2✓ Branch 0 taken 200 times.
✓ Branch 1 taken 600 times.
|
800 | for (const KernelWrapper* kernel : kernels) { |
| 127 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
|
600 | if (kernel != nullptr) { |
| 128 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
|
600 | const std::vector<size_t> run_inputs = kernel->run_inputs(m_tensors); |
| 129 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 600 times.
|
600 | const std::vector<size_t> ref_inputs = kernel->ref_inputs(m_tensors); |
| 130 | |||
| 131 |
2/2✓ Branch 0 taken 1345 times.
✓ Branch 1 taken 600 times.
|
1945 | for (const size_t id : run_inputs) { |
| 132 |
1/2✓ Branch 0 taken 1345 times.
✗ Branch 1 not taken.
|
1345 | m_tensors_required.at(id) = true; |
| 133 | 1345 | } | |
| 134 | |||
| 135 |
2/2✓ Branch 0 taken 600 times.
✓ Branch 1 taken 2545 times.
|
3145 | for (const size_t id : ref_inputs) { |
| 136 |
1/2✓ Branch 0 taken 2545 times.
✗ Branch 1 not taken.
|
2545 | m_tensors_required.at(id) = true; |
| 137 | 2545 | } | |
| 138 | 600 | } | |
| 139 | 600 | } | |
| 140 | 200 | } | |
| 141 | |||
| 142 | 200 | void MatMulTb::generate_lhs_raw(Rng& rng) { | |
| 143 | 200 | const std::array shape{m_shape_m, m_shape_k}; | |
| 144 | 200 | const Poly<Format> format(std::in_place_type<PlainFormat>, DataType::FP32); | |
| 145 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | Tensor& tensor = m_tensors.at(MATMUL_SLOT_LHS_RAW); |
| 146 | |||
| 147 |
5/10✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 200 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 200 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 200 times.
✗ Branch 9 not taken.
|
200 | tensor.set_shape(shape).set_format(format).set_data(format->generate_random(shape, rng)); |
| 148 | 200 | } | |
| 149 | |||
| 150 | 200 | void MatMulTb::generate_rhs_raw(Rng& rng) { | |
| 151 | 200 | const std::array shape{m_shape_k, m_shape_n}; | |
| 152 | 200 | const Poly<Format> format(std::in_place_type<PlainFormat>, DataType::FP32); | |
| 153 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | Tensor& tensor = m_tensors.at(MATMUL_SLOT_RHS_RAW); |
| 154 | |||
| 155 |
5/10✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 200 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 200 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 200 times.
✗ Branch 9 not taken.
|
200 | tensor.set_shape(shape).set_format(format).set_data(format->generate_random(shape, rng)); |
| 156 | 200 | } | |
| 157 | |||
| 158 | 200 | void MatMulTb::generate_bias_raw(Rng& rng) { | |
| 159 | 200 | const std::array shape{m_shape_n}; | |
| 160 | 200 | const Poly<Format> format(std::in_place_type<PlainFormat>, DataType::FP32); | |
| 161 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | Tensor& tensor = m_tensors.at(MATMUL_SLOT_BIAS_RAW); |
| 162 | |||
| 163 |
5/10✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 200 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 200 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 200 times.
✗ Branch 9 not taken.
|
200 | tensor.set_shape(shape).set_format(format).set_data(format->generate_random(shape, rng)); |
| 164 | 200 | } | |
| 165 | |||
| 166 | 200 | void MatMulTb::compute_rhs_t_raw() { | |
| 167 | 200 | const std::array shape{m_shape_n, m_shape_k}; | |
| 168 | 200 | const Poly<Format> format(std::in_place_type<PlainFormat>, DataType::FP32); | |
| 169 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | Tensor& rhs_t_raw = m_tensors.at(MATMUL_SLOT_RHS_T_RAW); |
| 170 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | const Tensor& rhs_raw = m_tensors.at(MATMUL_SLOT_RHS_RAW); |
| 171 | |||
| 172 |
5/10✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 200 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 200 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 200 times.
✗ Branch 9 not taken.
|
200 | rhs_t_raw.set_shape(shape).set_format(format).set_data(transpose<float>(rhs_raw.data_ptr(), m_shape_k, m_shape_n)); |
| 173 | 200 | } | |
| 174 | |||
| 175 | 200 | void MatMulTb::quantize_lhs() { | |
| 176 | 200 | const Quantizer& lhs_quant = *m_op->lhs_quant.value(); | |
| 177 | |||
| 178 | 200 | const std::array lhs_shape{m_shape_m, m_shape_k}; | |
| 179 | 200 | const Tensor& lhs_raw = m_tensors.at(MATMUL_SLOT_LHS_RAW); | |
| 180 | 200 | Tensor& lhs_qdata = m_tensors.at(MATMUL_SLOT_LHS_QDATA); | |
| 181 | 200 | Tensor& lhs_qscale = m_tensors.at(MATMUL_SLOT_LHS_QSCALE); | |
| 182 | 200 | Tensor& lhs_qzp = m_tensors.at(MATMUL_SLOT_LHS_QZP); | |
| 183 | |||
| 184 | 200 | lhs_quant.dynamic_quantize(DataType::FP32, lhs_shape, lhs_raw.data(), lhs_qdata, lhs_qscale, lhs_qzp); | |
| 185 | 200 | } | |
| 186 | |||
| 187 | 200 | void MatMulTb::quantize_rhs_t() { | |
| 188 | 200 | const Quantizer& rhs_quant = *m_op->rhs_quant.value(); | |
| 189 | |||
| 190 | 200 | const std::array rhs_t_shape{m_shape_n, m_shape_k}; | |
| 191 | 200 | const Tensor& rhs_t_raw = m_tensors.at(MATMUL_SLOT_RHS_T_RAW); | |
| 192 | 200 | Tensor& rhs_t_qdata = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA); | |
| 193 | 200 | Tensor& rhs_t_qscale = m_tensors.at(MATMUL_SLOT_RHS_T_QSCALE); | |
| 194 | 200 | Tensor& rhs_t_qzp = m_tensors.at(MATMUL_SLOT_RHS_T_QZP); | |
| 195 | |||
| 196 | 200 | rhs_quant.dynamic_quantize(DataType::FP32, rhs_t_shape, rhs_t_raw.data(), rhs_t_qdata, rhs_t_qscale, rhs_t_qzp); | |
| 197 | 200 | } | |
| 198 | |||
/// Quantizes the bias tensor. Bias quantization is not implemented; this path
/// is not exercised by the current test configurations (m_op->bias_quant is
/// never set for them), so reaching it is treated as a test setup error.
void MatMulTb::quantize_bias() {
    KAI_TEST_ERROR("Not supported.");
}
| 202 | |||
| 203 | 200 | void MatMulTb::compute_lhs_qzp_neg() { | |
| 204 | 200 | const Tensor& lhs_qzp = m_tensors.at(MATMUL_SLOT_LHS_QZP); | |
| 205 | 200 | Tensor& lhs_qzp_neg = m_tensors.at(MATMUL_SLOT_LHS_QZP_NEG); | |
| 206 | |||
| 207 | 200 | const Span<const size_t> shape = lhs_qzp.shape(); | |
| 208 | 200 | const Poly<Format>& format = lhs_qzp.format(); | |
| 209 | |||
| 210 | 200 | const UnaryElementwiseFn fn = make_negate(format->dtype()); | |
| 211 | 200 | Buffer data = fn(shape, lhs_qzp.data()); | |
| 212 | |||
| 213 |
3/6✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 200 times.
✗ Branch 5 not taken.
|
200 | lhs_qzp_neg.set_shape(shape).set_format(format).set_data(std::move(data)); |
| 214 | 200 | } | |
| 215 | |||
| 216 | 200 | void MatMulTb::compute_rhs_t_qdata_sign() { | |
| 217 | 200 | const Tensor& rhs_t_qdata = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA); | |
| 218 | 200 | Tensor& rhs_t_qdata_sign = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA_SIGN); | |
| 219 | |||
| 220 | 200 | const Span<const size_t> shape = rhs_t_qdata.shape(); | |
| 221 | 200 | const Poly<Format>& format = rhs_t_qdata.format(); | |
| 222 | |||
| 223 | 200 | const UnaryElementwiseFn fn = make_change_signedness(format->dtype()); | |
| 224 | 200 | Buffer data = fn(shape, rhs_t_qdata.data()); | |
| 225 | |||
| 226 |
3/6✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 200 times.
✗ Branch 5 not taken.
|
200 | rhs_t_qdata_sign.set_shape(shape).set_format(format).set_data(std::move(data)); |
| 227 | 200 | } | |
| 228 | |||
| 229 | 200 | void MatMulTb::compute_rhs_t_qdata_sign_sum() { | |
| 230 | 200 | const Tensor& rhs_t_qdata_sign = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA_SIGN); | |
| 231 | 200 | Tensor& rhs_t_qdata_sign_sum = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA_SIGN_SUM); | |
| 232 | |||
| 233 | 200 | const std::array rhs_t_shape = {m_shape_n, m_shape_k}; | |
| 234 | 200 | const std::array rhs_t_rowsum_shape = {m_shape_n}; | |
| 235 | 200 | const DataType src_dtype = rhs_t_qdata_sign.format()->dtype(); | |
| 236 | 200 | const DataType dst_dtype = rhs_t_qdata_sign_sum.format()->dtype(); | |
| 237 | |||
| 238 | 200 | const ReduceFn fn = make_reduce_add(src_dtype, dst_dtype); | |
| 239 | 200 | Buffer data = fn(0, rhs_t_shape, rhs_t_qdata_sign.data()); | |
| 240 | |||
| 241 |
2/4✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
|
200 | rhs_t_qdata_sign_sum.set_shape(rhs_t_rowsum_shape).set_data(std::move(data)); |
| 242 | 200 | } | |
| 243 | |||
| 244 | 200 | void MatMulTb::compute_ref_packed_lhs() { | |
| 245 | 200 | const KernelWrapper& pack_lhs = *m_op->pack_lhs.value(); | |
| 246 | 200 | const std::array lhs_shape{m_shape_m, m_shape_k}; | |
| 247 | 200 | pack_lhs.compute_reference(lhs_shape, m_tensors); | |
| 248 | 200 | } | |
| 249 | |||
| 250 | 200 | void MatMulTb::compute_ref_packed_rhs() { | |
| 251 | 200 | const KernelWrapper& pack_rhs = *m_op->pack_rhs.value(); | |
| 252 | 200 | const std::array rhs_t_shape{m_shape_n, m_shape_k}; | |
| 253 | 200 | pack_rhs.compute_reference(rhs_t_shape, m_tensors); | |
| 254 | 200 | } | |
| 255 | |||
| 256 | 200 | void MatMulTb::compute_ref_matmul() { | |
| 257 | 200 | const MatMulConfig& config = m_tensors.at(MATMUL_SLOT_CONFIG).value<MatMulConfig>(); | |
| 258 | 200 | const Tensor& lhs_qdata = m_tensors.at(MATMUL_SLOT_LHS_QDATA); | |
| 259 | 200 | const Tensor& lhs_qscale = m_tensors.at(MATMUL_SLOT_LHS_QSCALE); | |
| 260 | 200 | const Tensor& lhs_qzp = m_tensors.at(MATMUL_SLOT_LHS_QZP); | |
| 261 | 200 | const Tensor& rhs_t_qdata = m_tensors.at(MATMUL_SLOT_RHS_T_QDATA); | |
| 262 | 200 | const Tensor& rhs_t_qscale = m_tensors.at(MATMUL_SLOT_RHS_T_QSCALE); | |
| 263 | 200 | const Tensor& bias_raw = m_tensors.at(MATMUL_SLOT_BIAS_RAW); | |
| 264 | 200 | Tensor& kernel_args = m_tensors.at(MATMUL_SLOT_MATMUL_ARGS); | |
| 265 | 200 | Tensor& ref_dst_data = m_tensors.at(MATMUL_SLOT_REF_DST_DATA); | |
| 266 | |||
| 267 |
2/4✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
|
200 | ref_dst_data.set_shape({m_shape_m, m_shape_n}).set_format(make_poly<PlainFormat>(m_op->dst_dtype)); |
| 268 | |||
| 269 | // REVISIT: Assumes that the LHS and RHS are both quantized. | ||
| 270 | 200 | const Quantizer& lhs_quant = *m_op->lhs_quant.value(); | |
| 271 | 200 | const Quantizer& rhs_quant = *m_op->rhs_quant.value(); | |
| 272 | |||
| 273 | 1200 | const Buffer lhs_data = lhs_quant.dequantize( | |
| 274 | 1000 | m_op->acc_dtype, {m_shape_m, m_shape_k}, lhs_qdata.data(), lhs_qscale.data(), lhs_qzp.data()); | |
| 275 | 200 | const Buffer rhs_t_data = | |
| 276 |
3/6✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 200 times.
✗ Branch 5 not taken.
|
200 | rhs_quant.dequantize(m_op->acc_dtype, {m_shape_n, m_shape_k}, rhs_t_qdata.data(), rhs_t_qscale.data(), {}); |
| 277 | |||
| 278 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | const MatMulFn matmul_fn = make_matmul_nt_t(m_op->acc_dtype); |
| 279 |
3/6✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 200 times.
✗ Branch 5 not taken.
|
200 | Buffer dst = matmul_fn(m_shape_m, m_shape_n, m_shape_k, lhs_data, rhs_t_data); |
| 280 | |||
| 281 |
2/3✓ Branch 0 taken 55 times.
✓ Branch 1 taken 145 times.
✗ Branch 2 not taken.
|
200 | switch (config.bias_mode) { |
| 282 | case MatMulBiasMode::NO_BIAS: | ||
| 283 | break; | ||
| 284 | |||
| 285 | case MatMulBiasMode::PER_N: { | ||
| 286 |
1/2✓ Branch 0 taken 145 times.
✗ Branch 1 not taken.
|
145 | const BinaryElementwiseFn add_fn = make_add_2d(m_op->acc_dtype); |
| 287 |
3/6✓ Branch 0 taken 145 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 145 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 145 times.
✗ Branch 5 not taken.
|
145 | dst = add_fn(m_shape_m, m_shape_n, dst, 1, m_shape_n, bias_raw.data()); |
| 288 | break; | ||
| 289 | 145 | } | |
| 290 | |||
| 291 | default: | ||
| 292 | ✗ | KAI_TEST_ERROR("Not supported."); | |
| 293 | } | ||
| 294 | |||
| 295 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | const DynamicClampFn dynamic_clamp_fn = make_dynamic_clamp(m_op->acc_dtype); |
| 296 |
2/4✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
|
600 | auto [clamp_args, clampped_dst] = dynamic_clamp_fn(m_clamp_ratio, {m_shape_m, m_shape_n}, dst); |
| 297 | |||
| 298 |
5/10✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 200 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 200 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 200 times.
✗ Branch 9 not taken.
|
400 | kernel_args.set_shape({clamp_args.size()}).set_data(std::move(clamp_args)); |
| 299 | |||
| 300 |
1/4✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
200 | KAI_TEST_ASSERT_MSG( |
| 301 | m_op->dst_dtype == m_op->acc_dtype, "Only support the accumulator and output type being the same."); | ||
| 302 |
2/4✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
|
400 | ref_dst_data.set_data(std::move(clampped_dst)); |
| 303 | 200 | } | |
| 304 | |||
| 305 | ✗ | bool MatMulTb::has_lhs_packing() const { | |
| 306 | ✗ | return m_op->pack_lhs != nullptr; | |
| 307 | } | ||
| 308 | |||
| 309 | 600 | std::tuple<size_t, size_t> MatMulTb::lhs_packing_steps() const { | |
| 310 | 600 | const KernelWrapper& pack_lhs = *m_op->pack_lhs.value(); | |
| 311 | 600 | const std::vector<size_t> steps = pack_lhs.steps({m_shape_m, m_shape_k}, m_tensors); | |
| 312 |
2/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
|
600 | return {steps.at(0), steps.at(1)}; |
| 313 | 600 | } | |
| 314 | |||
| 315 | 600 | void MatMulTb::test_lhs_packing(size_t start_m, size_t start_k, size_t size_m, size_t size_k) { | |
| 316 | 600 | const KernelWrapper& pack_lhs = *m_op->pack_lhs.value(); | |
| 317 | |||
| 318 | 600 | const std::array full_shape{m_shape_m, m_shape_k}; | |
| 319 | 600 | const std::array tile_coords{start_m, start_k}; | |
| 320 | 600 | const std::array tile_shape{size_m, size_k}; | |
| 321 | |||
| 322 | 600 | pack_lhs.run(full_shape, tile_coords, tile_shape, m_tensors); | |
| 323 | |||
| 324 | 600 | const Tensor& ref_packed_lhs = m_tensors.at(MATMUL_SLOT_REF_LHS_PACKED); | |
| 325 | 600 | const Tensor& imp_packed_lhs = m_tensors.at(MATMUL_SLOT_IMP_LHS_PACKED); | |
| 326 | 600 | const Format& format = *ref_packed_lhs.format(); | |
| 327 | |||
| 328 | 600 | DefaultMismatchHandler handler(0.0F, 0.0F, 0, 0.0F); | |
| 329 | 1200 | const bool ok = | |
| 330 |
3/6✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 600 times.
✗ Branch 5 not taken.
|
600 | format.compare(full_shape, tile_coords, tile_shape, imp_packed_lhs.data(), ref_packed_lhs.data(), handler); |
| 331 |
1/6✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
600 | KAI_TEST_ASSERT(ok); |
| 332 | 600 | } | |
| 333 | |||
| 334 | ✗ | bool MatMulTb::has_rhs_packing() const { | |
| 335 | ✗ | return m_op->pack_rhs.has_value(); | |
| 336 | } | ||
| 337 | |||
| 338 | 600 | std::tuple<size_t, size_t> MatMulTb::rhs_packing_steps() const { | |
| 339 | 600 | const KernelWrapper& pack_rhs = *m_op->pack_rhs.value(); | |
| 340 | 600 | const std::vector<size_t> steps = pack_rhs.steps({m_shape_n, m_shape_k}, m_tensors); | |
| 341 |
2/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
|
600 | return {steps.at(0), steps.at(1)}; |
| 342 | 600 | } | |
| 343 | |||
| 344 | 600 | void MatMulTb::test_rhs_packing(size_t start_n, size_t start_k, size_t size_n, size_t size_k) { | |
| 345 | 600 | const KernelWrapper& pack_rhs = *m_op->pack_rhs.value(); | |
| 346 | |||
| 347 | 600 | const std::array full_shape{m_shape_n, m_shape_k}; | |
| 348 | 600 | const std::array tile_coords{start_n, start_k}; | |
| 349 | 600 | const std::array tile_shape{size_n, size_k}; | |
| 350 | |||
| 351 | 600 | pack_rhs.run(full_shape, tile_coords, tile_shape, m_tensors); | |
| 352 | |||
| 353 | 600 | const Tensor& ref_packed_rhs = m_tensors.at(MATMUL_SLOT_REF_RHS_PACKED); | |
| 354 | 600 | const Tensor& imp_packed_rhs = m_tensors.at(MATMUL_SLOT_IMP_RHS_PACKED); | |
| 355 | 600 | const Format& format = *ref_packed_rhs.format(); | |
| 356 | |||
| 357 | 600 | DefaultMismatchHandler handler(0.0F, 0.0F, 0, 0.0F); | |
| 358 | 1200 | const bool ok = | |
| 359 |
3/6✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 600 times.
✗ Branch 5 not taken.
|
600 | format.compare(full_shape, tile_coords, tile_shape, imp_packed_rhs.data(), ref_packed_rhs.data(), handler); |
| 360 |
1/6✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
600 | KAI_TEST_ASSERT(ok); |
| 361 | 600 | } | |
| 362 | |||
| 363 | 600 | std::tuple<size_t, size_t> MatMulTb::matmul_steps() const { | |
| 364 | 600 | const std::vector<size_t> steps = m_op->matmul->steps({m_shape_m, m_shape_n, m_shape_k}, m_tensors); | |
| 365 |
2/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
|
600 | return {steps.at(0), steps.at(1)}; |
| 366 | 600 | } | |
| 367 | |||
| 368 | 600 | void MatMulTb::test_matmul(size_t start_m, size_t start_n, size_t size_m, size_t size_n) { | |
| 369 | 600 | const std::array matmul_full_shape{m_shape_m, m_shape_n, m_shape_k}; | |
| 370 | 600 | const std::array matmul_tile_coords{start_m, start_n, static_cast<size_t>(0)}; | |
| 371 | 600 | const std::array matmul_tile_shape{size_m, size_n, m_shape_k}; | |
| 372 | |||
| 373 | 600 | const std::array dst_full_shape{m_shape_m, m_shape_n}; | |
| 374 | 600 | const std::array dst_tile_coords{start_m, start_n}; | |
| 375 | 600 | const std::array dst_tile_shape{size_m, size_n}; | |
| 376 | |||
| 377 | 600 | m_op->matmul->run(matmul_full_shape, matmul_tile_coords, matmul_tile_shape, m_tensors); | |
| 378 | |||
| 379 | 600 | const Tensor& ref_dst_data = m_tensors.at(MATMUL_SLOT_REF_DST_DATA); | |
| 380 | 600 | const Tensor& imp_dst_data = m_tensors.at(MATMUL_SLOT_IMP_DST_DATA); | |
| 381 | 600 | const Format& format = *ref_dst_data.format(); | |
| 382 | |||
| 383 | 600 | DefaultMismatchHandler handler(1e-3, 1e-3, 0, 0.0F); | |
| 384 |
1/2✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
|
2400 | const bool ok = format.compare( |
| 385 |
6/12✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 600 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 600 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 600 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 600 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 600 times.
✗ Branch 11 not taken.
|
2400 | dst_full_shape, dst_tile_coords, dst_tile_shape, imp_dst_data.data(), ref_dst_data.data(), handler); |
| 386 |
1/6✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
600 | KAI_TEST_ASSERT(ok); |
| 387 | 600 | } | |
| 388 | |||
| 389 | } // namespace kai::test | ||
| 390 |