test/nextgen/operators/matmul/pack_lhs/matmul_pack_lhs_dq_wrapper.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | | | `//` |
| 2 | | | `// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>` |
| 3 | | | `//` |
| 4 | | | `// SPDX-License-Identifier: Apache-2.0` |
| 5 | | | `//` |
| 6 | | | |
| 7 | | | `#include "test/nextgen/operators/matmul/pack_lhs/matmul_pack_lhs_dq_wrapper.hpp"` |
| 8 | | | |
| 9 | | | `#include <array>` |
| 10 | | | `#include <cstddef>` |
| 11 | | | `#include <string_view>` |
| 12 | | | `#include <vector>` |
| 13 | | | |
| 14 | | | `#include "test/common/abi_checker.hpp"` |
| 15 | | | `#include "test/common/assert.hpp"` |
| 16 | | | `#include "test/common/data_type.hpp"` |
| 17 | | | `#include "test/common/span.hpp"` |
| 18 | | | `#include "test/nextgen/format/plain_format.hpp"` |
| 19 | | | `#include "test/nextgen/harness/tensor.hpp"` |
| 20 | | | `#include "test/nextgen/operators/matmul/matmul_pack_args.hpp"` |
| 21 | | | `#include "test/nextgen/operators/matmul/matmul_slots.hpp"` |
| 22 | | | |
| 23 | | | `namespace kai::test {` |
| 24 | | | |
| 25 | | ✗ | `std::string_view MatMulPackLhsDqWrapper::name() const {` |
| 26 | | ✗ | `return m_name;` |
| 27 | | | `}` |
| 28 | | | |
| 29 | | 800 | `size_t MatMulPackLhsDqWrapper::src_tensor_id() const {` |
| 30 | 1/2✓ Branch 0 taken 800 times.<br>✗ Branch 1 not taken. | 800 | `const size_t tensor_id = *m_src_format == PlainFormat(DataType::FP32) ? MATMUL_SLOT_LHS_RAW : MATMUL_SLOT_LHS_DATA;` |
| 31 | | 1600 | `return tensor_id;` |
| 32 | | 800 | `}` |
| 33 | | | |
| 34 | | 200 | `std::vector<size_t> MatMulPackLhsDqWrapper::run_inputs([[maybe_unused]] Span<const Tensor> tensors) const {` |
| 35 | | 200 | `const size_t src_id = src_tensor_id();` |
| 36 | 0/2✗ Branch 0 not taken.<br>✗ Branch 1 not taken. | 200 | `return {src_id};` |
| 37 | | 200 | `}` |
| 38 | | | |
| 39 | | 200 | `std::vector<size_t> MatMulPackLhsDqWrapper::ref_inputs([[maybe_unused]] Span<const Tensor> tensors) const {` |
| 40 | 0/2✗ Branch 0 not taken.<br>✗ Branch 1 not taken. | 200 | `return {MATMUL_SLOT_LHS_QDATA, MATMUL_SLOT_LHS_QSCALE, MATMUL_SLOT_LHS_QZP_NEG};` |
| 41 | | ✗ | `}` |
| 42 | | | |
| 43 | | 600 | `std::vector<size_t> MatMulPackLhsDqWrapper::steps(Span<const size_t> shape, Span<const Tensor> tensors) const {` |
| 44 | 1/4✓ Branch 0 taken 600 times.<br>✗ Branch 1 not taken.<br>✗ Branch 2 not taken.<br>✗ Branch 3 not taken. | 600 | `KAI_TEST_ASSERT_MSG(shape.size() == 2, "Only M and K dimensions are expected.");` |
| 45 | | | |
| 46 | | 600 | `const auto& pack_args = tensors.at(MATMUL_SLOT_PACK_ARGS).value<MatMulPackArgs>();` |
| 47 | | | |
| 48 | | 600 | `const size_t m_step = m_kernel.get_m_step(pack_args.mr);` |
| 49 | | 600 | `const size_t shape_k = shape.at(1);` |
| 50 | | | |
| 51 | 0/2✗ Branch 0 not taken.<br>✗ Branch 1 not taken. | 600 | `return {m_step, shape_k};` |
| 52 | | 600 | `}` |
| 53 | | | |
| 54 | | 200 | `void MatMulPackLhsDqWrapper::populate_constant_info(Span<Tensor> tensors) const {` |
| 55 | | 200 | `Tensor& lhs_raw = tensors.at(MATMUL_SLOT_LHS_RAW);` |
| 56 | | 200 | `Tensor& packed_lhs = tensors.at(MATMUL_SLOT_IMP_LHS_PACKED);` |
| 57 | | | |
| 58 | | 200 | `lhs_raw.set_format(m_src_format);` |
| 59 | | 200 | `packed_lhs.set_format(m_dst_format);` |
| 60 | | 200 | `}` |
| 61 | | | |
| 62 | | 600 | `void MatMulPackLhsDqWrapper::run(` |
| 63 | | | `Span<const size_t> full_shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape,` |
| 64 | | | `Span<Tensor> tensors) const {` |
| 65 | 1/4✓ Branch 0 taken 600 times.<br>✗ Branch 1 not taken.<br>✗ Branch 2 not taken.<br>✗ Branch 3 not taken. | 600 | `KAI_TEST_ASSERT_MSG(full_shape.size() == 2, "Only M and K dimensions are expected.");` |
| 66 | 1/4✓ Branch 0 taken 600 times.<br>✗ Branch 1 not taken.<br>✗ Branch 2 not taken.<br>✗ Branch 3 not taken. | 600 | `KAI_TEST_ASSERT_MSG(tile_coords.size() == 2, "Only M and K dimensions are expected.");` |
| 67 | 1/4✓ Branch 0 taken 600 times.<br>✗ Branch 1 not taken.<br>✗ Branch 2 not taken.<br>✗ Branch 3 not taken. | 600 | `KAI_TEST_ASSERT_MSG(tile_shape.size() == 2, "Only M and K dimensions are expected.");` |
| 68 | | | |
| 69 | | 600 | `const size_t full_m = full_shape.at(0);` |
| 70 | | 600 | `const size_t full_k = full_shape.at(1);` |
| 71 | | | |
| 72 | | 600 | `const size_t start_m = tile_coords.at(0);` |
| 73 | | 600 | `const size_t start_k = tile_coords.at(1);` |
| 74 | | | |
| 75 | | 600 | `const size_t size_m = tile_shape.at(0);` |
| 76 | | 600 | `const size_t size_k = tile_shape.at(1);` |
| 77 | | | |
| 78 | 1/4✓ Branch 0 taken 600 times.<br>✗ Branch 1 not taken.<br>✗ Branch 2 not taken.<br>✗ Branch 3 not taken. | 600 | `KAI_TEST_ASSERT(start_k == 0);` |
| 79 | 1/4✓ Branch 0 taken 600 times.<br>✗ Branch 1 not taken.<br>✗ Branch 2 not taken.<br>✗ Branch 3 not taken. | 600 | `KAI_TEST_ASSERT(size_k == full_k);` |
| 80 | | | |
| 81 | | 600 | `const size_t lhs_tensor_id = src_tensor_id();` |
| 82 | | 600 | `const Tensor& lhs_data = tensors.at(lhs_tensor_id);` |
| 83 | | 600 | `Tensor& packed_lhs = tensors.at(MATMUL_SLOT_IMP_LHS_PACKED);` |
| 84 | | | |
| 85 | | 600 | `const auto& pack_args = tensors.at(MATMUL_SLOT_PACK_ARGS).value<MatMulPackArgs>();` |
| 86 | | | |
| 87 | | 600 | `packed_lhs.set_shape({full_m, full_k}).allocate();` |
| 88 | | | |
| 89 | | 600 | `const size_t lhs_stride = m_src_format->compute_size({1, full_k});` |
| 90 | | | |
| 91 | | 600 | `const size_t lhs_offset = m_src_format->compute_offset(full_shape, tile_coords);` |
| 92 | | 600 | `const size_t imp_lhs_offset = m_kernel.get_lhs_offset(start_m, lhs_stride);` |
| 93 | 1/4✓ Branch 0 taken 600 times.<br>✗ Branch 1 not taken.<br>✗ Branch 2 not taken.<br>✗ Branch 3 not taken. | 600 | `KAI_TEST_ASSERT(imp_lhs_offset == lhs_offset);` |
| 94 | | | |
| 95 | | 600 | `const size_t packed_lhs_offset = m_dst_format->compute_offset(full_shape, tile_coords);` |
| 96 | | 1200 | `const size_t imp_packed_lhs_offset =` |
| 97 | | 600 | `m_kernel.get_lhs_packed_offset(start_m, full_k, pack_args.mr, pack_args.kr, pack_args.sr);` |
| 98 | 1/4✓ Branch 0 taken 600 times.<br>✗ Branch 1 not taken.<br>✗ Branch 2 not taken.<br>✗ Branch 3 not taken. | 600 | `KAI_TEST_ASSERT(imp_packed_lhs_offset == packed_lhs_offset);` |
| 99 | | | |
| 100 | | 600 | `const size_t packed_lhs_size = packed_lhs.data().size();` |
| 101 | | 1200 | `const size_t imp_packed_lhs_size =` |
| 102 | | 600 | `m_kernel.get_lhs_packed_size(full_m, full_k, pack_args.mr, pack_args.kr, pack_args.sr);` |
| 103 | 1/4✓ Branch 0 taken 600 times.<br>✗ Branch 1 not taken.<br>✗ Branch 2 not taken.<br>✗ Branch 3 not taken. | 600 | `KAI_TEST_ASSERT(imp_packed_lhs_size == packed_lhs_size);` |
| 104 | | | |
| 105 | | 600 | `const Span<const std::byte> lhs_tile = lhs_data.data().subspan(lhs_offset);` |
| 106 | | 600 | `const Span<std::byte> packed_lhs_tile = packed_lhs.data().subspan(packed_lhs_offset);` |
| 107 | | | |
| 108 | | 1200 | `abi_check([&] {` |
| 109 | | 1200 | `m_kernel.run(` |
| 110 | | 600 | `size_m, size_k, pack_args.mr, pack_args.kr, pack_args.sr, 0,` |
| 111 | | 600 | `reinterpret_cast<const float*>(lhs_tile.data()), lhs_stride, packed_lhs_tile.data());` |
| 112 | | 600 | `});` |
| 113 | | 600 | `}` |
| 114 | | | |
| 115 | | 200 | `void MatMulPackLhsDqWrapper::compute_reference(Span<const size_t> shape, Span<Tensor> tensors) const {` |
| 116 | | 200 | `const Tensor& lhs_qdata = tensors.at(MATMUL_SLOT_LHS_QDATA);` |
| 117 | | 200 | `const Tensor& lhs_qscale = tensors.at(MATMUL_SLOT_LHS_QSCALE);` |
| 118 | | 200 | `const Tensor& lhs_qzp_neg = tensors.at(MATMUL_SLOT_LHS_QZP_NEG);` |
| 119 | | 200 | `Tensor& ref_packed_lhs = tensors.at(MATMUL_SLOT_REF_LHS_PACKED);` |
| 120 | | | |
| 121 | | 200 | `ref_packed_lhs.set_shape(shape)` |
| 122 | | 200 | `.set_format(m_dst_format)` |
| 123 | 1/2✓ Branch 0 taken 200 times.<br>✗ Branch 1 not taken. | 200 | `.set_data(m_dst_format->pack(shape, std::array{lhs_qdata.data(), lhs_qzp_neg.data(), lhs_qscale.data()}));` |
| 124 | | 200 | `}` |
| 125 | | | |
| 126 | | | `} // namespace kai::test` |
| 127 | | | |