KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 9 / 0 / 9
Functions: 100.0% 2 / 0 / 2
Branches: 50.0% 4 / 0 / 8

test/nextgen/operators/matmul/matmul/matmul_dq_wrapper.hpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #pragma once
8
9 #include <memory>
10 #include <string>
11 #include <string_view>
12 #include <utility>
13
14 #include "test/nextgen/common/poly.hpp"
15 #include "test/nextgen/format/format.hpp"
16 #include "test/nextgen/harness/kernel_wrapper.hpp"
17 #include "test/nextgen/operators/matmul/matmul/matmul_interface.hpp"
18 #include "test/nextgen/quantization/quantizer.hpp"
19
20 namespace kai::test {
21
22 /// Wrapper for matrix multiplication kernel with dynamic quantization.
23 class MatMulDqWrapper : public KernelWrapper {
24 public:
25 /// Creates a new wrapper.
26 8 MatMulDqWrapper(
27 std::string_view name, const MatMulDqInterface& kernel, std::unique_ptr<Quantizer> lhs_quant,
28 std::unique_ptr<Quantizer> rhs_quant, const Poly<Format>& lhs_format, const Poly<Format>& rhs_format,
29 const Poly<Format>& dst_format) :
30
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 m_name(name),
31 6 m_kernel(kernel),
32 6 m_lhs_quant(std::move(lhs_quant)),
33 6 m_rhs_quant(std::move(rhs_quant)),
34
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 m_lhs_format(lhs_format),
35
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 m_rhs_format(rhs_format),
36
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
14 m_dst_format(dst_format) {
37 8 }
38
39 [[nodiscard]] std::string_view name() const override;
40 [[nodiscard]] std::vector<size_t> run_inputs(Span<const Tensor> tensors) const override;
41 [[nodiscard]] std::vector<size_t> ref_inputs(Span<const Tensor> tensors) const override;
42 [[nodiscard]] std::vector<size_t> steps(Span<const size_t> shape, Span<const Tensor> tensors) const override;
43 void populate_constant_info(Span<Tensor> tensors) const override;
44 void run(
45 Span<const size_t> full_shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape,
46 Span<Tensor> tensors) const override;
47 void compute_reference(Span<const size_t> shape, Span<Tensor> tensors) const override;
48
49 private:
50 std::string m_name;
51 MatMulDqInterface m_kernel;
52 std::unique_ptr<Quantizer> m_lhs_quant;
53 std::unique_ptr<Quantizer> m_rhs_quant;
54 Poly<Format> m_lhs_format;
55 Poly<Format> m_rhs_format;
56 Poly<Format> m_dst_format;
57 };
58
59 } // namespace kai::test
60