test/nextgen/operators/matmul/pack_lhs/matmul_pack_lhs_dq_wrapper.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #pragma once | ||
| 8 | |||
| 9 | #include <cstddef> | ||
| 10 | #include <string> | ||
| 11 | #include <string_view> | ||
| 12 | #include <utility> | ||
| 13 | #include <vector> | ||
| 14 | |||
| 15 | #include "test/common/span.hpp" | ||
| 16 | #include "test/nextgen/common/poly.hpp" | ||
| 17 | #include "test/nextgen/format/format.hpp" | ||
| 18 | #include "test/nextgen/harness/kernel_wrapper.hpp" | ||
| 19 | #include "test/nextgen/harness/tensor.hpp" | ||
| 20 | #include "test/nextgen/operators/matmul/pack_lhs/matmul_pack_lhs_interface.hpp" | ||
| 21 | |||
| 22 | namespace kai::test { | ||
| 23 | |||
| 24 | /// Wrapper for LHS packing kernel with dynamic quantization. | ||
| 25 | class MatMulPackLhsDqWrapper final : public KernelWrapper { | ||
| 26 | public: | ||
| 27 | /// Creates a new wrapper. | ||
| 28 | /// | ||
| 29 | /// @param[in] name The kernel name. | ||
| 30 | /// @param[in] kernel The kernel interface. | ||
| 31 | /// @param[in] src_format The input data format. | ||
| 32 | /// @param[in] dst_format The output data format. | ||
| 33 | 8 | MatMulPackLhsDqWrapper( | |
| 34 | std::string_view name, const MatMulPackLhsDqInterface& kernel, Poly<Format>&& src_format, | ||
| 35 | Poly<Format>&& dst_format) : | ||
| 36 |
3/6✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
|
8 | m_name(name), m_kernel(kernel), m_src_format(std::move(src_format)), m_dst_format(std::move(dst_format)) { |
| 37 | 8 | } | |
| 38 | |||
| 39 | [[nodiscard]] std::string_view name() const override; | ||
| 40 | [[nodiscard]] std::vector<size_t> run_inputs(Span<const Tensor> tensors) const override; | ||
| 41 | [[nodiscard]] std::vector<size_t> ref_inputs(Span<const Tensor> tensors) const override; | ||
| 42 | [[nodiscard]] std::vector<size_t> steps(Span<const size_t> shape, Span<const Tensor> tensors) const override; | ||
| 43 | void populate_constant_info(Span<Tensor> tensors) const override; | ||
| 44 | void run( | ||
| 45 | Span<const size_t> full_shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape, | ||
| 46 | Span<Tensor> tensors) const override; | ||
| 47 | void compute_reference(Span<const size_t> shape, Span<Tensor> tensors) const override; | ||
| 48 | |||
| 49 | private: | ||
| 50 | [[nodiscard]] size_t src_tensor_id() const; ///< Determines the tensor ID containing the input data. | ||
| 51 | |||
| 52 | std::string m_name; | ||
| 53 | MatMulPackLhsDqInterface m_kernel; | ||
| 54 | Poly<Format> m_src_format; | ||
| 55 | Poly<Format> m_dst_format; | ||
| 56 | }; | ||
| 57 | |||
| 58 | } // namespace kai::test | ||
| 59 |