test/nextgen/operators/matmul/matmul/matmul_dq_wrapper.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #pragma once | ||
| 8 | |||
| 9 | #include <memory> | ||
| 10 | #include <string> | ||
| 11 | #include <string_view> | ||
| 12 | #include <utility> | ||
| 13 | |||
| 14 | #include "test/nextgen/common/poly.hpp" | ||
| 15 | #include "test/nextgen/format/format.hpp" | ||
| 16 | #include "test/nextgen/harness/kernel_wrapper.hpp" | ||
| 17 | #include "test/nextgen/operators/matmul/matmul/matmul_interface.hpp" | ||
| 18 | #include "test/nextgen/quantization/quantizer.hpp" | ||
| 19 | |||
| 20 | namespace kai::test { | ||
| 21 | |||
| 22 | /// Wrapper for matrix multiplication kernel with dynamic quantization. | ||
| 23 | class MatMulDqWrapper : public KernelWrapper { | ||
| 24 | public: | ||
| 25 | /// Creates a new wrapper. | ||
| 26 | 8 | MatMulDqWrapper( | |
| 27 | std::string_view name, const MatMulDqInterface& kernel, std::unique_ptr<Quantizer> lhs_quant, | ||
| 28 | std::unique_ptr<Quantizer> rhs_quant, const Poly<Format>& lhs_format, const Poly<Format>& rhs_format, | ||
| 29 | const Poly<Format>& dst_format) : | ||
| 30 | 1/2 — ✓ Branch 0 taken 6 times; ✗ Branch 1 not taken. | 6 | m_name(name), |
| 31 | 6 | m_kernel(kernel), | |
| 32 | 6 | m_lhs_quant(std::move(lhs_quant)), | |
| 33 | 6 | m_rhs_quant(std::move(rhs_quant)), | |
| 34 | 1/2 — ✓ Branch 0 taken 6 times; ✗ Branch 1 not taken. | 6 | m_lhs_format(lhs_format), |
| 35 | 1/2 — ✓ Branch 0 taken 6 times; ✗ Branch 1 not taken. | 6 | m_rhs_format(rhs_format), |
| 36 | 1/2 — ✓ Branch 0 taken 6 times; ✗ Branch 1 not taken. | 14 | m_dst_format(dst_format) { |
| 37 | 8 | } | |
| 38 | |||
| 39 | [[nodiscard]] std::string_view name() const override; | ||
| 40 | [[nodiscard]] std::vector<size_t> run_inputs(Span<const Tensor> tensors) const override; | ||
| 41 | [[nodiscard]] std::vector<size_t> ref_inputs(Span<const Tensor> tensors) const override; | ||
| 42 | [[nodiscard]] std::vector<size_t> steps(Span<const size_t> shape, Span<const Tensor> tensors) const override; | ||
| 43 | void populate_constant_info(Span<Tensor> tensors) const override; | ||
| 44 | void run( | ||
| 45 | Span<const size_t> full_shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape, | ||
| 46 | Span<Tensor> tensors) const override; | ||
| 47 | void compute_reference(Span<const size_t> shape, Span<Tensor> tensors) const override; | ||
| 48 | |||
| 49 | private: | ||
| 50 | std::string m_name; | ||
| 51 | MatMulDqInterface m_kernel; | ||
| 52 | std::unique_ptr<Quantizer> m_lhs_quant; | ||
| 53 | std::unique_ptr<Quantizer> m_rhs_quant; | ||
| 54 | Poly<Format> m_lhs_format; | ||
| 55 | Poly<Format> m_rhs_format; | ||
| 56 | Poly<Format> m_dst_format; | ||
| 57 | }; | ||
| 58 | |||
| 59 | } // namespace kai::test | ||
| 60 |