test/nextgen/operators/matmul/pack_rhs/matmul_pack_rhs_quant_wrapper.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | | | `//` |
| 2 | | | `// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>` |
| 3 | | | `//` |
| 4 | | | `// SPDX-License-Identifier: Apache-2.0` |
| 5 | | | `//` |
| 6 | | | |
| 7 | | | `#pragma once` |
| 8 | | | |
| 9 | | | `#include <string>` |
| 10 | | | `#include <string_view>` |
| 11 | | | |
| 12 | | | `#include "test/nextgen/common/poly.hpp"` |
| 13 | | | `#include "test/nextgen/format/format.hpp"` |
| 14 | | | `#include "test/nextgen/harness/kernel_wrapper.hpp"` |
| 15 | | | `#include "test/nextgen/operators/matmul/pack_rhs/matmul_pack_rhs_interface.hpp"` |
| 16 | | | |
| 17 | | | `namespace kai::test {` |
| 18 | | | |
| 19 | | | `/// Wrapper for RHS packing kernel with per-channel quantization.` |
| 20 | | | `class MatMulPackRhsQuantWrapper : public KernelWrapper {` |
| 21 | | | `public:` |
| 22 | | | `/// Creates a new wrapper.` |
| 23 | | 8 | `MatMulPackRhsQuantWrapper(` |
| 24 | | | `std::string_view name, const MatMulPackRhsQuantInterface& kernel, const Poly<Format>& src_data_format,` |
| 25 | | | `const Poly<Format>& src_scale_format, const Poly<Format>& src_bias_format, const Poly<Format>& src_sum_format,` |
| 26 | | | `const Poly<Format>& dst_format) :` |
| 27 | 1/2 ✓ Branch 0 taken 6 times. ✗ Branch 1 not taken. | 6 | `m_name(name),` |
| 28 | | 6 | `m_kernel(kernel),` |
| 29 | 1/2 ✓ Branch 0 taken 6 times. ✗ Branch 1 not taken. | 6 | `m_src_data_format(src_data_format),` |
| 30 | 1/2 ✓ Branch 0 taken 6 times. ✗ Branch 1 not taken. | 6 | `m_src_scale_format(src_scale_format),` |
| 31 | 1/2 ✓ Branch 0 taken 6 times. ✗ Branch 1 not taken. | 6 | `m_src_bias_format(src_bias_format),` |
| 32 | 1/2 ✓ Branch 0 taken 6 times. ✗ Branch 1 not taken. | 6 | `m_src_sum_format(src_sum_format),` |
| 33 | 1/2 ✓ Branch 0 taken 6 times. ✗ Branch 1 not taken. | 14 | `m_dst_format(dst_format) {` |
| 34 | | 8 | `}` |
| 35 | | | |
| 36 | | | `[[nodiscard]] std::string_view name() const override;` |
| 37 | | | `[[nodiscard]] std::vector<size_t> run_inputs(Span<const Tensor> tensors) const override;` |
| 38 | | | `[[nodiscard]] std::vector<size_t> ref_inputs(Span<const Tensor> tensors) const override;` |
| 39 | | | `[[nodiscard]] std::vector<size_t> steps(Span<const size_t> shape, Span<const Tensor> tensors) const override;` |
| 40 | | | `void populate_constant_info(Span<Tensor> tensors) const override;` |
| 41 | | | `void run(` |
| 42 | | | `Span<const size_t> full_shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape,` |
| 43 | | | `Span<Tensor> tensors) const override;` |
| 44 | | | `void compute_reference(Span<const size_t> shape, Span<Tensor> tensors) const override;` |
| 45 | | | |
| 46 | | | `private:` |
| 47 | | | `std::string m_name;` |
| 48 | | | `MatMulPackRhsQuantInterface m_kernel;` |
| 49 | | | `Poly<Format> m_src_data_format;` |
| 50 | | | `Poly<Format> m_src_scale_format;` |
| 51 | | | `Poly<Format> m_src_bias_format;` |
| 52 | | | `Poly<Format> m_src_sum_format;` |
| 53 | | | `Poly<Format> m_dst_format;` |
| 54 | | | `};` |
| 55 | | | |
| 56 | | | `} // namespace kai::test` |
| 57 | | | |