Line |
Branch |
Exec |
Source |
1 |
|
|
// |
2 |
|
|
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> |
3 |
|
|
// |
4 |
|
|
// SPDX-License-Identifier: Apache-2.0 |
5 |
|
|
// |
6 |
|
|
|
7 |
|
|
#pragma once |
8 |
|
|
|
9 |
|
|
#include <cstddef> |
10 |
|
|
|
11 |
|
|
namespace kai::test { |
12 |
|
|
|
13 |
|
|
class DataFormat; |
14 |
|
|
class Rect; |
15 |
|
|
class MismatchHandler; |
16 |
|
|
|
17 |
|
|
/// Compares two matrices to check whether they are matched. |
18 |
|
|
/// |
19 |
|
|
/// @param[in] imp_data Data buffer of the actual implementation matrix. |
20 |
|
|
/// @param[in] ref_data Data buffer of the reference implementation matrix. |
21 |
|
|
/// @param[in] format Data format. |
22 |
|
|
/// @param[in] full_height Height of the full matrix. |
23 |
|
|
/// @param[in] full_width Width of the full matrix. |
24 |
|
|
/// @param[in] rect Rectangular region of the matrix that is populated with data. |
25 |
|
|
/// @param[in] handler Mismatch handler. |
26 |
|
|
/// |
27 |
|
|
/// @return `true` if the two matrices are considered matched. |
28 |
|
|
bool compare( |
29 |
|
|
const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, |
30 |
|
|
const Rect& rect, MismatchHandler& handler); |
31 |
|
|
|
32 |
|
|
/// Handles mismatches found by the @ref compare function. |
33 |
|
|
class MismatchHandler { |
34 |
|
|
public: |
35 |
|
|
/// Constructor. |
36 |
|
32956 |
MismatchHandler() = default; |
37 |
|
|
|
38 |
|
|
/// Destructor. |
39 |
|
32956 |
virtual ~MismatchHandler() = default; |
40 |
|
|
|
41 |
|
|
/// Copy constructor. |
42 |
|
|
MismatchHandler(const MismatchHandler&) = default; |
43 |
|
|
|
44 |
|
|
/// Copy assignment. |
45 |
|
|
MismatchHandler& operator=(const MismatchHandler&) = default; |
46 |
|
|
|
47 |
|
|
/// Move constructor. |
48 |
|
|
MismatchHandler(MismatchHandler&&) noexcept = default; |
49 |
|
|
|
50 |
|
|
/// Move assignment. |
51 |
|
|
MismatchHandler& operator=(MismatchHandler&&) noexcept = default; |
52 |
|
|
|
53 |
|
|
/// Handles new mismatch result. |
54 |
|
|
/// |
55 |
|
|
/// This method must be called even when no error is detected. |
56 |
|
|
/// |
57 |
|
|
/// @param[in] absolute_error Absolute error. |
58 |
|
|
/// @param[in] relative_error Relative error. |
59 |
|
|
/// |
60 |
|
|
/// @return `true` if the mismatch is sufficiently large to be logged as a real mismatch. |
61 |
|
|
virtual bool handle_data(float absolute_error, float relative_error) = 0; |
62 |
|
|
|
63 |
|
|
/// Marks the result as failed. |
64 |
|
|
/// |
65 |
|
|
/// There is zero tolerance for data points that are considered impossible to mismatch |
66 |
|
|
/// regardless of implementation method. |
67 |
|
|
/// These normally include data points outside of the portion of interest (these must be 0) |
68 |
|
|
/// and data points that belong to quantization information. |
69 |
|
|
virtual void mark_as_failed() = 0; |
70 |
|
|
|
71 |
|
|
/// Returns a value indicating whether the two matrices are considered matched. |
72 |
|
|
/// |
73 |
|
|
/// @param[in] num_checks Total number of data points that have been checked. |
74 |
|
|
/// |
75 |
|
|
/// @return `true` if the two matrices are considered matched. |
76 |
|
|
[[nodiscard]] virtual bool success(size_t num_checks) const = 0; |
77 |
|
|
}; |
78 |
|
|
|
79 |
|
|
/// This mismatch handler considers two values being mismatched when both the relative error |
80 |
|
|
/// and the absolute error exceed their respective thresholds. |
81 |
|
|
/// |
82 |
|
|
/// This mismatch handler considers two matrices being mismatched when the number of mismatches |
83 |
|
|
/// exceeds both the relative and absolute thresholds. |
84 |
|
|
class DefaultMismatchHandler final : public MismatchHandler { |
85 |
|
|
public: |
86 |
|
|
/// Creates a new mismatch handler. |
87 |
|
|
/// |
88 |
|
|
/// @param[in] abs_error_threshold Threshold for absolute error. |
89 |
|
|
/// @param[in] rel_error_threshold Threshold for relative error. |
90 |
|
|
/// @param[in] abs_mismatched_threshold Threshold for the number of mismatched data points. |
91 |
|
|
/// @param[in] rel_mismatched_threshold Threshold for the ratio of mismatched data points. |
92 |
|
|
DefaultMismatchHandler( |
93 |
|
|
float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, |
94 |
|
|
float rel_mismatched_threshold); |
95 |
|
|
|
96 |
|
|
/// Destructor. |
97 |
|
65912 |
~DefaultMismatchHandler() = default; |
98 |
|
|
|
99 |
|
|
/// Copy constructor. |
100 |
|
|
DefaultMismatchHandler(const DefaultMismatchHandler& rhs); |
101 |
|
|
|
102 |
|
|
/// Copy assignment. |
103 |
|
|
DefaultMismatchHandler& operator=(const DefaultMismatchHandler& rhs); |
104 |
|
|
|
105 |
|
|
/// Move constructor. |
106 |
|
|
DefaultMismatchHandler(DefaultMismatchHandler&& rhs) noexcept = default; |
107 |
|
|
|
108 |
|
|
/// Move assignment. |
109 |
|
|
DefaultMismatchHandler& operator=(DefaultMismatchHandler&& rhs) noexcept = default; |
110 |
|
|
|
111 |
|
|
bool handle_data(float absolute_error, float relative_error) override; |
112 |
|
|
void mark_as_failed() override; |
113 |
|
|
[[nodiscard]] bool success(size_t num_checks) const override; |
114 |
|
|
|
115 |
|
|
private: |
116 |
|
|
float _abs_error_threshold; |
117 |
|
|
float _rel_error_threshold; |
118 |
|
|
size_t _abs_mismatched_threshold; |
119 |
|
|
float _rel_mismatched_threshold; |
120 |
|
|
|
121 |
|
|
size_t _num_mismatches; |
122 |
|
|
bool _failed; |
123 |
|
|
}; |
124 |
|
|
|
125 |
|
|
} // namespace kai::test |
126 |
|
|
|