KleidiAI Coverage Report


Directory: ./
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%
Coverage Exec / Excl / Total (executed / excluded / total)
Lines: 0.0% 0 / 0 / 34
Functions: 0.0% 0 / 0 / 21
Branches: 0.0% 0 / 0 / 230

test/nextgen/reference/print.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/nextgen/reference/print.hpp"
8
9 #include <cstddef>
10 #include <cstdint>
11 #include <functional>
12 #include <numeric>
13 #include <ostream>
14 #include <string>
15
16 #include "test/common/assert.hpp"
17 #include "test/common/data_type.hpp"
18 #include "test/common/int4.hpp"
19 #include "test/common/memory.hpp"
20 #include "test/common/round.hpp"
21 #include "test/common/span.hpp"
22
23 namespace kai::test {
24
25 namespace {
26
27 template <typename T>
28 void print_impl(std::ostream& os, Span<const size_t> shape, Span<const std::byte> data, size_t level = 0) {
29 const std::string indent(level * 2, ' ');
30 const size_t len = shape.at(0);
31
32 if (shape.size() == 1) {
33 os << indent << "[";
34
35 for (size_t i = 0; i < len; ++i) {
36 const T value = read_array<T>(data, i);
37 os << displayable(value) << ", ";
38 }
39
40 os << "]";
41 } else {
42 const size_t row_size = round_up_division(shape.at(shape.size() - 1) * size_in_bits<T>, 8);
43 const size_t num_rows = std::accumulate(shape.begin() + 1, shape.end() - 1, 1, std::multiplies<>());
44 const size_t stride = num_rows * row_size;
45
46 os << indent << "[\n";
47
48 for (size_t i = 0; i < len; ++i) {
49 print_impl<T>(os, shape.subspan(1), data.subspan(i * stride), level + 1);
50 os << ",\n";
51 }
52
53 os << indent << "]";
54 }
55 }
56
57 template <typename T>
58 void print_array(std::ostream& os, Span<const size_t> shape, Span<const std::byte> data, size_t level) {
59 print_impl<T>(os, shape, data, level);
60 }
61
62 } // namespace
63
64 PrintFn make_print_array(DataType dtype) {
65 switch (dtype) {
66 case DataType::FP32:
67 return print_array<float>;
68
69 case DataType::I32:
70 return print_array<int32_t>;
71
72 case DataType::I8:
73 return print_array<int8_t>;
74
75 case DataType::U4:
76 return print_array<UInt4>;
77
78 case DataType::I4:
79 return print_array<Int4>;
80
81 default:
82 KAI_TEST_ERROR("Not supported.");
83 }
84 }
85
86 } // namespace kai::test
87