### `test/nextgen/harness/kernel_wrapper.hpp`
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | | | `//` |
| 2 | | | `// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>` |
| 3 | | | `//` |
| 4 | | | `// SPDX-License-Identifier: Apache-2.0` |
| 5 | | | `//` |
| 6 | | | |
| 7 | | | `#pragma once` |
| 8 | | | |
| 9 | | | `#include <cstddef>` |
| 10 | | | `#include <string_view>` |
| 11 | | | `#include <vector>` |
| 12 | | | |
| 13 | | | `#include "test/common/span.hpp"` |
| 14 | | | `#include "test/nextgen/harness/tensor.hpp"` |
| 15 | | | |
| 16 | | | `namespace kai::test {` |
| 17 | | | |
| 18 | | | `/// Wrapper to provide unified API for all micro-kernels.` |
| 19 | | | `class KernelWrapper {` |
| 20 | | | `public:` |
| 21 | | 18 | `KernelWrapper() = default; ///< Default constructor.` |
| 22 | | 18 | `virtual ~KernelWrapper() = default; ///< Destructor.` |
| 23 | | | `KernelWrapper(const KernelWrapper&) = delete; ///< No copy constructor.` |
| 24 | | | `KernelWrapper& operator=(const KernelWrapper&) = delete; ///< No copy assignment.` |
| 25 | | | `KernelWrapper(KernelWrapper&&) = default; ///< Move constructor.` |
| 26 | | | `KernelWrapper& operator=(KernelWrapper&&) = default; ///< Move assignment.` |
| 27 | | | |
| 28 | | | `/// Gets the micro-kernel name.` |
| 29 | | | `[[nodiscard]] virtual std::string_view name() const = 0;` |
| 30 | | | |
| 31 | | | `/// Gets the list of input tensors required to run the micro-kernel.` |
| 32 | | | `///` |
| 33 | | | `/// @param[in] tensors The data pool.` |
| 34 | | | `///` |
| 35 | | | `/// @return The list of tensor IDs.` |
| 36 | | | `[[nodiscard]] virtual std::vector<size_t> run_inputs(Span<const Tensor> tensors) const = 0;` |
| 37 | | | |
| 38 | | | `/// Gets the list of input tensors required to run the reference implementation.` |
| 39 | | | `///` |
| 40 | | | `/// @param[in] tensors The data pool.` |
| 41 | | | `///` |
| 42 | | | `/// @return The list of tensor IDs.` |
| 43 | | | `[[nodiscard]] virtual std::vector<size_t> ref_inputs(Span<const Tensor> tensors) const = 0;` |
| 44 | | | |
| 45 | | | `/// Gets the scheduling steps in each dimension.` |
| 46 | | | `///` |
| 47 | | | `/// @param[in] shape The full problem shape.` |
| 48 | | | `/// @param[in] tensors The data pool.` |
| 49 | | | `///` |
| 50 | | | `/// @return The step in each dimension.` |
| 51 | | | `[[nodiscard]] virtual std::vector<size_t> steps(Span<const size_t> shape, Span<const Tensor> tensors) const = 0;` |
| 52 | | | |
| 53 | | | `/// Populates the data pool with constant information.` |
| 54 | | | `///` |
| 55 | | | `/// @param[in, out] tensors The data pool.` |
| 56 | | | `virtual void populate_constant_info(Span<Tensor> tensors) const = 0;` |
| 57 | | | |
| 58 | | | `/// Runs the micro-kernel to process a tile of the problem shape.` |
| 59 | | | `///` |
| 60 | | | `/// @param[in] full_shape The full problem shape.` |
| 61 | | | `/// @param[in] tile_coords The starting coordinate of the tile to be processed by the kernel.` |
| 62 | | | `/// @param[in] tile_shape The size of the tile to be processed by the kernel.` |
| 63 | | | `/// @param[in, out] tensors The data pool.` |
| 64 | | | `virtual void run(` |
| 65 | | | `Span<const size_t> full_shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape,` |
| 66 | | | `Span<Tensor> tensors) const = 0;` |
| 67 | | | |
| 68 | | | `/// Computes the reference data.` |
| 69 | | | `///` |
| 70 | | | `/// @param[in] shape The problem shape.` |
| 71 | | | `/// @param[in, out] tensors The data pool.` |
| 72 | | | `virtual void compute_reference(Span<const size_t> shape, Span<Tensor> tensors) const = 0;` |
| 73 | | | `};` |
| 74 | | | |
| 75 | | | `} // namespace kai::test` |
| 76 | | | |