test/nextgen/operators/matmul/matmul/matmul_dq_wrapper.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/nextgen/operators/matmul/matmul/matmul_dq_wrapper.hpp" | ||
| 8 | |||
| 9 | #include <cstddef> | ||
| 10 | #include <string_view> | ||
| 11 | #include <utility> | ||
| 12 | #include <vector> | ||
| 13 | |||
| 14 | #include "test/common/abi_checker.hpp" | ||
| 15 | #include "test/common/assert.hpp" | ||
| 16 | #include "test/common/buffer.hpp" | ||
| 17 | #include "test/common/span.hpp" | ||
| 18 | #include "test/nextgen/harness/tensor.hpp" | ||
| 19 | #include "test/nextgen/operators/matmul/matmul_bias_mode.hpp" | ||
| 20 | #include "test/nextgen/operators/matmul/matmul_config.hpp" | ||
| 21 | #include "test/nextgen/operators/matmul/matmul_main_args.hpp" | ||
| 22 | #include "test/nextgen/operators/matmul/matmul_pack_args.hpp" | ||
| 23 | #include "test/nextgen/operators/matmul/matmul_slots.hpp" | ||
| 24 | #include "test/nextgen/reference/binary_elementwise.hpp" | ||
| 25 | #include "test/nextgen/reference/clamp.hpp" | ||
| 26 | #include "test/nextgen/reference/matmul.hpp" | ||
| 27 | |||
| 28 | namespace kai::test { | ||
| 29 | |||
| 30 | ✗ | std::string_view MatMulDqWrapper::name() const { | |
| 31 | ✗ | return m_name; | |
| 32 | } | ||
| 33 | |||
| 34 | 200 | std::vector<size_t> MatMulDqWrapper::run_inputs([[maybe_unused]] Span<const Tensor> tensors) const { | |
| 35 |
0/2✗ Branch 0 not taken.
✗ Branch 1 not taken.
|
200 | return {MATMUL_SLOT_REF_LHS_PACKED, MATMUL_SLOT_REF_RHS_PACKED, MATMUL_SLOT_MATMUL_ARGS}; |
| 36 | ✗ | } | |
| 37 | |||
| 38 | 200 | std::vector<size_t> MatMulDqWrapper::ref_inputs([[maybe_unused]] Span<const Tensor> tensors) const { | |
| 39 |
0/2✗ Branch 0 not taken.
✗ Branch 1 not taken.
|
200 | return {MATMUL_SLOT_LHS_QDATA, MATMUL_SLOT_LHS_QSCALE, MATMUL_SLOT_LHS_QZP, |
| 40 | MATMUL_SLOT_RHS_T_QDATA, MATMUL_SLOT_RHS_T_QSCALE, MATMUL_SLOT_BIAS_RAW}; | ||
| 41 | ✗ | } | |
| 42 | |||
| 43 | 600 | std::vector<size_t> MatMulDqWrapper::steps( | |
| 44 | Span<const size_t> shape, [[maybe_unused]] Span<const Tensor> tensorsf) const { | ||
| 45 | 600 | const size_t step_m = m_kernel.get_m_step(); | |
| 46 | 600 | const size_t step_n = m_kernel.get_n_step(); | |
| 47 | 600 | const size_t shape_k = shape.at(2); | |
| 48 | |||
| 49 |
0/2✗ Branch 0 not taken.
✗ Branch 1 not taken.
|
600 | return {step_m, step_n, shape_k}; |
| 50 | 600 | } | |
| 51 | |||
| 52 | 200 | void MatMulDqWrapper::populate_constant_info(Span<Tensor> tensors) const { | |
| 53 | // Populates the packing arguments. | ||
| 54 | 200 | Tensor& pack_args_tensor = tensors.at(MATMUL_SLOT_PACK_ARGS); | |
| 55 | 200 | pack_args_tensor.set_shape({sizeof(MatMulPackArgs)}).allocate(); | |
| 56 | 200 | auto& pack_args = pack_args_tensor.value<MatMulPackArgs>(); | |
| 57 | |||
| 58 | 200 | pack_args.mr = m_kernel.get_mr(); | |
| 59 | 200 | pack_args.nr = m_kernel.get_nr(); | |
| 60 | 200 | pack_args.kr = m_kernel.get_kr(); | |
| 61 | 200 | pack_args.sr = m_kernel.get_sr(); | |
| 62 | 200 | pack_args.bl = 0; | |
| 63 | |||
| 64 | // Setups data format. | ||
| 65 | 200 | } | |
| 66 | |||
| 67 | 600 | void MatMulDqWrapper::run( | |
| 68 | Span<const size_t> full_shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape, | ||
| 69 | Span<Tensor> tensors) const { | ||
| 70 |
1/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
600 | KAI_TEST_ASSERT(tile_coords.size() == full_shape.size()); |
| 71 |
1/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
600 | KAI_TEST_ASSERT(tile_shape.size() == full_shape.size()); |
| 72 | |||
| 73 |
1/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
600 | KAI_TEST_ASSERT_MSG(full_shape.size() == 3, "Only M, N and K dimensions are expected."); |
| 74 | |||
| 75 | 600 | const size_t full_m = full_shape.at(0); | |
| 76 | 600 | const size_t full_n = full_shape.at(1); | |
| 77 | 600 | const size_t full_k = full_shape.at(2); | |
| 78 | |||
| 79 | 600 | const size_t start_m = tile_coords.at(0); | |
| 80 | 600 | const size_t start_n = tile_coords.at(1); | |
| 81 | 600 | const size_t start_k = tile_coords.at(2); | |
| 82 | |||
| 83 | 600 | const size_t size_m = tile_shape.at(0); | |
| 84 | 600 | const size_t size_n = tile_shape.at(1); | |
| 85 | 600 | const size_t size_k = tile_shape.at(2); | |
| 86 | |||
| 87 |
1/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
600 | KAI_TEST_ASSERT_MSG(start_k == 0, "Only full K is supported."); |
| 88 |
1/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
600 | KAI_TEST_ASSERT_MSG(size_k == full_k, "Only full K is supported."); |
| 89 | |||
| 90 | 600 | const Tensor& ref_packed_lhs = tensors.at(MATMUL_SLOT_REF_LHS_PACKED); | |
| 91 | 600 | const Tensor& ref_packed_rhs = tensors.at(MATMUL_SLOT_REF_RHS_PACKED); | |
| 92 | 600 | const Tensor& kernel_args = tensors.at(MATMUL_SLOT_MATMUL_ARGS); | |
| 93 | 600 | Tensor& imp_dst_data = tensors.at(MATMUL_SLOT_IMP_DST_DATA); | |
| 94 | |||
| 95 | 600 | const size_t ref_packed_lhs_offset = m_lhs_format->compute_offset({full_m, full_k}, {start_m, start_k}); | |
| 96 | 600 | const size_t imp_packed_lhs_offset = m_kernel.get_lhs_packed_offset(start_m, full_k); | |
| 97 |
1/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
600 | KAI_TEST_ASSERT(imp_packed_lhs_offset == ref_packed_lhs_offset); |
| 98 | |||
| 99 | 600 | const size_t ref_packed_rhs_offset = m_rhs_format->compute_offset({full_n, full_k}, {start_n, start_k}); | |
| 100 | 600 | const size_t imp_packed_rhs_offset = m_kernel.get_rhs_packed_offset(start_n, full_k); | |
| 101 |
1/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
600 | KAI_TEST_ASSERT(imp_packed_rhs_offset == ref_packed_rhs_offset); |
| 102 | |||
| 103 | 600 | const size_t ref_dst_stride_row = m_dst_format->compute_size({full_n}); | |
| 104 | 600 | const size_t ref_dst_stride_col = m_dst_format->compute_size({1}); | |
| 105 | 600 | const size_t ref_dst_offset = m_dst_format->compute_offset({full_m, full_n}, {start_m, start_n}); | |
| 106 | 600 | const size_t imp_dst_offset = m_kernel.get_dst_offset(start_m, start_n, ref_dst_stride_row); | |
| 107 |
1/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
600 | KAI_TEST_ASSERT(imp_dst_offset == ref_dst_offset); |
| 108 | |||
| 109 | 600 | imp_dst_data.set_shape({full_m, full_n}).set_format(m_dst_format).allocate(); | |
| 110 | 600 | const size_t imp_dst_size = m_kernel.get_dst_size(full_m, full_n); | |
| 111 |
1/4✓ Branch 0 taken 600 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
600 | KAI_TEST_ASSERT(imp_dst_size == imp_dst_data.data().size()); |
| 112 | |||
| 113 | 600 | const Span<const std::byte> packed_lhs_tile = ref_packed_lhs.data().subspan(ref_packed_lhs_offset); | |
| 114 | 600 | const Span<const std::byte> packed_rhs_tile = ref_packed_rhs.data().subspan(ref_packed_rhs_offset); | |
| 115 | 600 | const Span<std::byte> dst_tile = imp_dst_data.data().subspan(ref_dst_offset); | |
| 116 | |||
| 117 | 600 | const auto& clamp_args = kernel_args.value<MatMulClampArgsF32>(); | |
| 118 | |||
| 119 | 1200 | abi_check([&] { | |
| 120 | 1200 | m_kernel.run( | |
| 121 | 600 | size_m, size_n, size_k, packed_lhs_tile.data(), packed_rhs_tile.data(), | |
| 122 | 600 | reinterpret_cast<float*>(dst_tile.data()), ref_dst_stride_row, ref_dst_stride_col, clamp_args.clamp_min, | |
| 123 | 600 | clamp_args.clamp_max); | |
| 124 | 600 | }); | |
| 125 | 600 | } | |
| 126 | |||
| 127 | ✗ | void MatMulDqWrapper::compute_reference( | |
| 128 | [[maybe_unused]] Span<const size_t> shape, [[maybe_unused]] Span<Tensor> tensors) const { | ||
| 129 | ✗ | } | |
| 130 | |||
| 131 | } // namespace kai::test | ||
| 132 |