KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 84.8% 39 / 0 / 46
Functions: 52.9% 9 / 0 / 17
Branches: 62.1% 64 / 0 / 103

test/nextgen/reference/compare.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/nextgen/reference/compare.hpp"
8
9 #include <array>
10 #include <cstddef>
11 #include <cstdint>
12 #include <functional>
13 #include <iostream>
14 #include <ostream>
15
16 #include "test/common/assert.hpp"
17 #include "test/common/data_type.hpp"
18 #include "test/common/int4.hpp"
19 #include "test/common/memory.hpp"
20 #include "test/common/span.hpp"
21 #include "test/common/type_traits.hpp"
22
23 namespace kai::test {
24
25 namespace {
26
27 /// Calculates the absolute and relative errors.
28 ///
29 /// @param[in] imp Value under test.
30 /// @param[in] ref Reference value.
31 ///
32 /// @return The absolute error and the relative error (the relative error is 0 if the reference value is 0).
33 template <typename T>
34 15498495 std::tuple<float, float> calculate_error(T imp, T ref) {
35 15498495 const auto imp_f = static_cast<float>(imp);
36 15498495 const auto ref_f = static_cast<float>(ref);
37
38 15498495 const auto abs_error = std::abs(imp_f - ref_f);
39
8/8
✓ Branch 0 taken 1801132 times.
✓ Branch 1 taken 2126849 times.
✓ Branch 2 taken 60812 times.
✓ Branch 3 taken 56998 times.
✓ Branch 4 taken 2440186 times.
✓ Branch 5 taken 2340134 times.
✓ Branch 6 taken 2647617 times.
✓ Branch 7 taken 4024767 times.
15498495 const auto rel_error = ref_f != 0 ? abs_error / std::abs(ref_f) : 0.0F;
40
41 15498495 return {abs_error, rel_error};
42 15498495 }
43
44 template <typename T>
45 84141 size_t compare_plain_2d(
46 Span<const size_t> shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape,
47 Span<const std::byte> imp_buffer, Span<const std::byte> ref_buffer,
48 const std::function<void(std::ostream& os, Span<const size_t> coords)>& report_fn, MismatchHandler& handler) {
49 84141 const size_t height = shape.at(0);
50 84141 const size_t width = shape.at(1);
51
52 84141 const size_t start_row = tile_coords.at(0);
53 84141 const size_t start_col = tile_coords.at(1);
54
55 84141 const size_t tile_height = tile_shape.at(0);
56 84141 const size_t tile_width = tile_shape.at(1);
57
58 84141 const size_t end_row = start_row + tile_height;
59 84141 const size_t end_col = start_col + tile_width;
60
61
8/8
✓ Branch 0 taken 29169 times.
✓ Branch 1 taken 74925 times.
✓ Branch 2 taken 27486 times.
✓ Branch 3 taken 27486 times.
✓ Branch 4 taken 26403 times.
✓ Branch 5 taken 654720 times.
✓ Branch 6 taken 1083 times.
✓ Branch 7 taken 26064 times.
867336 for (size_t row = 0; row < height; ++row) {
62
8/8
✓ Branch 0 taken 3927981 times.
✓ Branch 1 taken 74925 times.
✓ Branch 2 taken 117810 times.
✓ Branch 3 taken 27486 times.
✓ Branch 4 taken 4780320 times.
✓ Branch 5 taken 654720 times.
✓ Branch 6 taken 6672384 times.
✓ Branch 7 taken 26064 times.
16281690 for (size_t col = 0; col < width; ++col) {
63
16/24
✓ Branch 0 taken 3063577 times.
✓ Branch 1 taken 864404 times.
✓ Branch 2 taken 2194339 times.
✓ Branch 3 taken 869238 times.
✓ Branch 4 taken 152832 times.
✓ Branch 5 taken 2041507 times.
✓ Branch 6 taken 117810 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 117810 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 117810 times.
✓ Branch 12 taken 4780320 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4780320 times.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 4780320 times.
✓ Branch 18 taken 6672384 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 5667840 times.
✓ Branch 21 taken 1004544 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 5667840 times.
15498495 const bool in_tile = row >= start_row && row < end_row && col >= start_col && col < end_col;
64 15498495 const size_t index = row * width + col;
65
66 15498495 const T imp_value = read_array<T>(imp_buffer, index);
67
8/8
✓ Branch 0 taken 1847117 times.
✓ Branch 1 taken 2080864 times.
✓ Branch 2 taken 79220 times.
✓ Branch 3 taken 38590 times.
✓ Branch 4 taken 2594656 times.
✓ Branch 5 taken 2185664 times.
✓ Branch 6 taken 4304896 times.
✓ Branch 7 taken 2367488 times.
15498495 const T ref_value = in_tile ? read_array<T>(ref_buffer, index) : static_cast<T>(0);
68
69 17555725 const auto [abs_err, rel_err] = calculate_error<T>(imp_value, ref_value);
70
71
9/16
✓ Branch 0 taken 2899366 times.
✓ Branch 1 taken 1028615 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2899366 times.
✓ Branch 4 taken 117810 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 117810 times.
✓ Branch 8 taken 4780320 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 4780320 times.
✓ Branch 12 taken 6672384 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 6672384 times.
15498495 if (abs_err != 0 || rel_err != 0) {
72 // If the mismatch happens outside the tile, it is an error straight away:
73 // those elements are expected to be 0, so a mismatch there suggests the kernel has written out of bounds.
74 // If the mismatch happens inside the tile, the mismatch handler decides whether to report it
75 // based on the absolute error and the relative error.
76
77
1/8
✓ Branch 0 taken 1028615 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
1028615 if (!in_tile) {
78 handler.mark_as_failed();
79 }
80
81
1/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1028615 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
1028615 const auto notifying = !in_tile || handler.handle_data(abs_err, rel_err);
82
83
1/8
✓ Branch 0 taken 1028615 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
1028615 if (notifying) {
84 report_fn(std::cerr, std::array{row, col});
85 std::cerr << ": actual = " << displayable(imp_value) << ", expected = " << displayable(ref_value)
86 << "\n";
87 }
88 1028615 }
89 15498495 }
90 783195 }
91
92 168282 return tile_height * tile_width;
93 84141 }
94
95 } // namespace
96
97 4800 CompareFn make_compare_plain_2d(DataType dtype) {
98
4/5
✓ Branch 0 taken 2400 times.
✓ Branch 1 taken 1200 times.
✓ Branch 2 taken 600 times.
✓ Branch 3 taken 600 times.
✗ Branch 4 not taken.
4800 switch (dtype) {
99 case DataType::FP32:
100 2400 return compare_plain_2d<float>;
101
102 case DataType::I32:
103 1200 return compare_plain_2d<int32_t>;
104
105 case DataType::I8:
106 600 return compare_plain_2d<int8_t>;
107
108 case DataType::I4:
109 600 return compare_plain_2d<Int4>;
110
111 default:
112 KAI_TEST_ERROR("Not implemented.");
113 }
114 4800 }
115
116 } // namespace kai::test
117