KleidiAI Coverage Report


Directory: ./
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 2 / 0 / 2
Functions: 100.0% 2 / 0 / 2
Branches: n/a (no branches) 0 / 0 / 0

test/nextgen/harness/kernel_wrapper.hpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #pragma once
8
9 #include <cstddef>
10 #include <string_view>
11 #include <vector>
12
13 #include "test/common/span.hpp"
14 #include "test/nextgen/harness/tensor.hpp"
15
16 namespace kai::test {
17
18 /// Wrapper to provide unified API for all micro-kernels.
19 class KernelWrapper {
20 public:
21 18 KernelWrapper() = default; ///< Default constructor.
22 18 virtual ~KernelWrapper() = default; ///< Virtual destructor, so derived wrappers can be destroyed through a base pointer.
23 KernelWrapper(const KernelWrapper&) = delete; ///< Copy construction is disabled.
24 KernelWrapper& operator=(const KernelWrapper&) = delete; ///< Copy assignment is disabled.
25 KernelWrapper(KernelWrapper&&) = default; ///< Defaulted move constructor.
26 KernelWrapper& operator=(KernelWrapper&&) = default; ///< Defaulted move assignment. NOTE(review): public on a polymorphic base, so moving through base references only assigns the base part (slicing) — consider making it protected.
27
28 /// Gets the micro-kernel name.
29 [[nodiscard]] virtual std::string_view name() const = 0;
30
31 /// Gets the tensors from the data pool that are required as inputs to run the micro-kernel.
32 ///
33 /// @param[in] tensors The data pool.
34 ///
35 /// @return The IDs of the required input tensors.
36 [[nodiscard]] virtual std::vector<size_t> run_inputs(Span<const Tensor> tensors) const = 0;
37
38 /// Gets the tensors from the data pool that are required as inputs to run the reference implementation.
39 ///
40 /// @param[in] tensors The data pool.
41 ///
42 /// @return The IDs of the required input tensors.
43 [[nodiscard]] virtual std::vector<size_t> ref_inputs(Span<const Tensor> tensors) const = 0;
44
45 /// Gets the scheduling step size in each dimension of the problem shape.
46 ///
47 /// @param[in] shape The full problem shape.
48 /// @param[in] tensors The data pool.
49 ///
50 /// @return The step for each dimension of the problem shape.
51 [[nodiscard]] virtual std::vector<size_t> steps(Span<const size_t> shape, Span<const Tensor> tensors) const = 0;
52
53 /// Populates the data pool with constant information.
54 ///
55 /// @param[in, out] tensors The data pool, which is updated in place.
56 virtual void populate_constant_info(Span<Tensor> tensors) const = 0;
57
58 /// Runs the micro-kernel to process a single tile of the problem shape.
59 ///
60 /// @param[in] full_shape The full problem shape.
61 /// @param[in] tile_coords The starting coordinate of the tile to be processed by the kernel.
62 /// @param[in] tile_shape The size of the tile to be processed by the kernel.
63 /// @param[in, out] tensors The data pool.
64 virtual void run(
65 Span<const size_t> full_shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape,
66 Span<Tensor> tensors) const = 0;
67
68 /// Computes the reference data for the whole problem shape.
69 ///
70 /// @param[in] shape The problem shape.
71 /// @param[in, out] tensors The data pool.
72 virtual void compute_reference(Span<const size_t> shape, Span<Tensor> tensors) const = 0;
73 };
74
75 } // namespace kai::test
76