test/common/compare.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #pragma once | ||
| 8 | |||
| 9 | #include <cstddef> | ||
| 10 | |||
| 11 | namespace kai::test { | ||
| 12 | |||
| 13 | class DataFormat; | ||
| 14 | class Rect; | ||
| 15 | class MismatchHandler; | ||
| 16 | |||
| 17 | /// Compares the actual implementation matrix against the reference matrix to check whether they match. | ||
| 18 | /// | ||
| 19 | /// @param[in] imp_data Data buffer of the actual implementation matrix. | ||
| 20 | /// @param[in] ref_data Data buffer of the reference implementation matrix. | ||
| 21 | /// @param[in] format Data format shared by both matrices. | ||
| 22 | /// @param[in] full_height Height of the full matrix. | ||
| 23 | /// @param[in] full_width Width of the full matrix. | ||
| 24 | /// @param[in] rect Rectangular region of the matrix that is populated with data. | ||
| 25 | /// @param[in,out] handler Mismatch handler that is notified of the comparison results (passed by non-const reference). | ||
| 26 | /// | ||
| 27 | /// @return `true` if the two matrices are considered matched. | ||
| 28 | bool compare( | ||
| 29 | const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, | ||
| 30 | const Rect& rect, MismatchHandler& handler); | ||
| 31 | |||
| 32 | /// Abstract interface that handles mismatches found by the @ref compare function. | ||
| 33 | class MismatchHandler { | ||
| 34 | public: | ||
| 35 | /// Default constructor. | ||
| 36 | 62070 | MismatchHandler() = default; | |
| 37 | |||
| 38 | /// Virtual destructor, so derived handlers can be destroyed through a base reference or pointer. | ||
| 39 | 62070 | virtual ~MismatchHandler() = default; | |
| 40 | |||
| 41 | /// Copy constructor. | ||
| 42 | MismatchHandler(const MismatchHandler&) = default; | ||
| 43 | |||
| 44 | /// Copy assignment. | ||
| 45 | MismatchHandler& operator=(const MismatchHandler&) = default; | ||
| 46 | |||
| 47 | /// Move constructor. | ||
| 48 | MismatchHandler(MismatchHandler&&) noexcept = default; | ||
| 49 | |||
| 50 | /// Move assignment. | ||
| 51 | MismatchHandler& operator=(MismatchHandler&&) noexcept = default; | ||
| 52 | |||
| 53 | /// Handles a new mismatch result. | ||
| 54 | /// | ||
| 55 | /// This method must be called even when no error is detected. | ||
| 56 | /// | ||
| 57 | /// @param[in] absolute_error Absolute error. | ||
| 58 | /// @param[in] relative_error Relative error. | ||
| 59 | /// | ||
| 60 | /// @return `true` if the mismatch is sufficiently large to be logged as a real mismatch. | ||
| 61 | virtual bool handle_data(float absolute_error, float relative_error) = 0; | ||
| 62 | |||
| 63 | /// Marks the result as failed. | ||
| 64 | /// | ||
| 65 | /// Use this for zero-tolerance data points, i.e. points that must never mismatch | ||
| 66 | /// regardless of the implementation method. | ||
| 67 | /// These normally include data points outside of the portion of interest (these must be 0) | ||
| 68 | /// and data points that belong to quantization information. | ||
| 69 | virtual void mark_as_failed() = 0; | ||
| 70 | |||
| 71 | /// Returns a value indicating whether the two matrices are considered matched. | ||
| 72 | /// | ||
| 73 | /// @param[in] num_checks Total number of data points that have been checked. | ||
| 74 | /// | ||
| 75 | /// @return `true` if the two matrices are considered matched. | ||
| 76 | [[nodiscard]] virtual bool success(size_t num_checks) const = 0; | ||
| 77 | }; | ||
| 78 | |||
| 79 | /// This mismatch handler considers two values to be mismatched when both the relative error | ||
| 80 | /// and the absolute error exceed their respective thresholds. | ||
| 81 | /// | ||
| 82 | /// This mismatch handler considers two matrices to be mismatched when the number of mismatches | ||
| 83 | /// exceeds both the relative and the absolute thresholds. | ||
| 84 | class DefaultMismatchHandler final : public MismatchHandler { | ||
| 85 | public: | ||
| 86 | /// Creates a new mismatch handler. | ||
| 87 | /// | ||
| 88 | /// @param[in] abs_error_threshold Threshold for the absolute error. | ||
| 89 | /// @param[in] rel_error_threshold Threshold for the relative error. | ||
| 90 | /// @param[in] abs_mismatched_threshold Threshold for the number of mismatched data points. | ||
| 91 | /// @param[in] rel_mismatched_threshold Threshold for the ratio of mismatched data points. | ||
| 92 | DefaultMismatchHandler( | ||
| 93 | float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, | ||
| 94 | float rel_mismatched_threshold); | ||
| 95 | |||
| 96 | /// Destructor. | ||
| 97 | 84232 | ~DefaultMismatchHandler() = default; | |
| 98 | |||
| 99 | /// Copy constructor (defined out of line). | ||
| 100 | DefaultMismatchHandler(const DefaultMismatchHandler& rhs); | ||
| 101 | |||
| 102 | /// Copy assignment (defined out of line). | ||
| 103 | DefaultMismatchHandler& operator=(const DefaultMismatchHandler& rhs); | ||
| 104 | |||
| 105 | /// Move constructor. | ||
| 106 | DefaultMismatchHandler(DefaultMismatchHandler&& rhs) noexcept = default; | ||
| 107 | |||
| 108 | /// Move assignment. | ||
| 109 | DefaultMismatchHandler& operator=(DefaultMismatchHandler&& rhs) noexcept = default; | ||
| 110 | |||
| 111 | bool handle_data(float absolute_error, float relative_error) override; | ||
| 112 | void mark_as_failed() override; | ||
| 113 | [[nodiscard]] bool success(size_t num_checks) const override; | ||
| 114 | |||
| 115 | private: | ||
| 116 | float _abs_error_threshold; | ||
| 117 | float _rel_error_threshold; | ||
| 118 | size_t _abs_mismatched_threshold; | ||
| 119 | float _rel_mismatched_threshold; | ||
| 120 | |||
| 121 | size_t _num_mismatches; | ||
| 122 | bool _failed; | ||
| 123 | }; | ||
| 124 | |||
| 125 | } // namespace kai::test | ||
| 126 |