test/nextgen/reference/compare.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/nextgen/reference/compare.hpp" | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <functional> | ||
| 13 | #include <iostream> | ||
| 14 | #include <ostream> | ||
| 15 | |||
| 16 | #include "test/common/assert.hpp" | ||
| 17 | #include "test/common/data_type.hpp" | ||
| 18 | #include "test/common/int4.hpp" | ||
| 19 | #include "test/common/memory.hpp" | ||
| 20 | #include "test/common/span.hpp" | ||
| 21 | #include "test/common/type_traits.hpp" | ||
| 22 | |||
| 23 | namespace kai::test { | ||
| 24 | |||
| 25 | namespace { | ||
| 26 | |||
| 27 | /// Calculates the absolute and relative errors. | ||
| 28 | /// | ||
| 29 | /// @param[in] imp Value under test. | ||
| 30 | /// @param[in] ref Reference value. | ||
| 31 | /// | ||
| 32 | /// @return The absolute error and relative error. | ||
| 33 | template <typename T> | ||
| 34 | 15498495 | std::tuple<float, float> calculate_error(T imp, T ref) { | |
| 35 | 15498495 | const auto imp_f = static_cast<float>(imp); | |
| 36 | 15498495 | const auto ref_f = static_cast<float>(ref); | |
| 37 | |||
| 38 | 15498495 | const auto abs_error = std::abs(imp_f - ref_f); | |
| 39 |
8/8✓ Branch 0 taken 1801132 times.
✓ Branch 1 taken 2126849 times.
✓ Branch 2 taken 60812 times.
✓ Branch 3 taken 56998 times.
✓ Branch 4 taken 2440186 times.
✓ Branch 5 taken 2340134 times.
✓ Branch 6 taken 2647617 times.
✓ Branch 7 taken 4024767 times.
|
15498495 | const auto rel_error = ref_f != 0 ? abs_error / std::abs(ref_f) : 0.0F; |
| 40 | |||
| 41 | 15498495 | return {abs_error, rel_error}; | |
| 42 | 15498495 | } | |
| 43 | |||
| 44 | template <typename T> | ||
| 45 | 84141 | size_t compare_plain_2d( | |
| 46 | Span<const size_t> shape, Span<const size_t> tile_coords, Span<const size_t> tile_shape, | ||
| 47 | Span<const std::byte> imp_buffer, Span<const std::byte> ref_buffer, | ||
| 48 | const std::function<void(std::ostream& os, Span<const size_t> coords)>& report_fn, MismatchHandler& handler) { | ||
| 49 | 84141 | const size_t height = shape.at(0); | |
| 50 | 84141 | const size_t width = shape.at(1); | |
| 51 | |||
| 52 | 84141 | const size_t start_row = tile_coords.at(0); | |
| 53 | 84141 | const size_t start_col = tile_coords.at(1); | |
| 54 | |||
| 55 | 84141 | const size_t tile_height = tile_shape.at(0); | |
| 56 | 84141 | const size_t tile_width = tile_shape.at(1); | |
| 57 | |||
| 58 | 84141 | const size_t end_row = start_row + tile_height; | |
| 59 | 84141 | const size_t end_col = start_col + tile_width; | |
| 60 | |||
| 61 |
8/8✓ Branch 0 taken 29169 times.
✓ Branch 1 taken 74925 times.
✓ Branch 2 taken 27486 times.
✓ Branch 3 taken 27486 times.
✓ Branch 4 taken 26403 times.
✓ Branch 5 taken 654720 times.
✓ Branch 6 taken 1083 times.
✓ Branch 7 taken 26064 times.
|
867336 | for (size_t row = 0; row < height; ++row) { |
| 62 |
8/8✓ Branch 0 taken 3927981 times.
✓ Branch 1 taken 74925 times.
✓ Branch 2 taken 117810 times.
✓ Branch 3 taken 27486 times.
✓ Branch 4 taken 4780320 times.
✓ Branch 5 taken 654720 times.
✓ Branch 6 taken 6672384 times.
✓ Branch 7 taken 26064 times.
|
16281690 | for (size_t col = 0; col < width; ++col) { |
| 63 |
16/24✓ Branch 0 taken 3063577 times.
✓ Branch 1 taken 864404 times.
✓ Branch 2 taken 2194339 times.
✓ Branch 3 taken 869238 times.
✓ Branch 4 taken 152832 times.
✓ Branch 5 taken 2041507 times.
✓ Branch 6 taken 117810 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 117810 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 117810 times.
✓ Branch 12 taken 4780320 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4780320 times.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 4780320 times.
✓ Branch 18 taken 6672384 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 5667840 times.
✓ Branch 21 taken 1004544 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 5667840 times.
|
15498495 | const bool in_tile = row >= start_row && row < end_row && col >= start_col && col < end_col; |
| 64 | 15498495 | const size_t index = row * width + col; | |
| 65 | |||
| 66 | 15498495 | const T imp_value = read_array<T>(imp_buffer, index); | |
| 67 |
8/8✓ Branch 0 taken 1847117 times.
✓ Branch 1 taken 2080864 times.
✓ Branch 2 taken 79220 times.
✓ Branch 3 taken 38590 times.
✓ Branch 4 taken 2594656 times.
✓ Branch 5 taken 2185664 times.
✓ Branch 6 taken 4304896 times.
✓ Branch 7 taken 2367488 times.
|
15498495 | const T ref_value = in_tile ? read_array<T>(ref_buffer, index) : static_cast<T>(0); |
| 68 | |||
| 69 | 17555725 | const auto [abs_err, rel_err] = calculate_error<T>(imp_value, ref_value); | |
| 70 | |||
| 71 |
9/16✓ Branch 0 taken 2899366 times.
✓ Branch 1 taken 1028615 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2899366 times.
✓ Branch 4 taken 117810 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 117810 times.
✓ Branch 8 taken 4780320 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 4780320 times.
✓ Branch 12 taken 6672384 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 6672384 times.
|
15498495 | if (abs_err != 0 || rel_err != 0) { |
| 72 | // If the mismatch happens outside the tile, it's an error straightaway | ||
| 73 | // since these are expected to be 0 and the kernel is likely to write out-of-bound. | ||
| 74 | // If the mismatch happens inside the tile, the mismatch handler makes the decision | ||
| 75 | // based on the absolute error and relative error. | ||
| 76 | |||
| 77 |
1/8✓ Branch 0 taken 1028615 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
1028615 | if (!in_tile) { |
| 78 | ✗ | handler.mark_as_failed(); | |
| 79 | ✗ | } | |
| 80 | |||
| 81 |
1/8✗ Branch 0 not taken.
✓ Branch 1 taken 1028615 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
1028615 | const auto notifying = !in_tile || handler.handle_data(abs_err, rel_err); |
| 82 | |||
| 83 |
1/8✓ Branch 0 taken 1028615 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
1028615 | if (notifying) { |
| 84 | ✗ | report_fn(std::cerr, std::array{row, col}); | |
| 85 | ✗ | std::cerr << ": actual = " << displayable(imp_value) << ", expected = " << displayable(ref_value) | |
| 86 | ✗ | << "\n"; | |
| 87 | ✗ | } | |
| 88 | 1028615 | } | |
| 89 | 15498495 | } | |
| 90 | 783195 | } | |
| 91 | |||
| 92 | 168282 | return tile_height * tile_width; | |
| 93 | 84141 | } | |
| 94 | |||
| 95 | } // namespace | ||
| 96 | |||
| 97 | 4800 | CompareFn make_compare_plain_2d(DataType dtype) { | |
| 98 |
4/5✓ Branch 0 taken 2400 times.
✓ Branch 1 taken 1200 times.
✓ Branch 2 taken 600 times.
✓ Branch 3 taken 600 times.
✗ Branch 4 not taken.
|
4800 | switch (dtype) { |
| 99 | case DataType::FP32: | ||
| 100 | 2400 | return compare_plain_2d<float>; | |
| 101 | |||
| 102 | case DataType::I32: | ||
| 103 | 1200 | return compare_plain_2d<int32_t>; | |
| 104 | |||
| 105 | case DataType::I8: | ||
| 106 | 600 | return compare_plain_2d<int8_t>; | |
| 107 | |||
| 108 | case DataType::I4: | ||
| 109 | 600 | return compare_plain_2d<Int4>; | |
| 110 | |||
| 111 | default: | ||
| 112 | ✗ | KAI_TEST_ERROR("Not implemented."); | |
| 113 | } | ||
| 114 | 4800 | } | |
| 115 | |||
| 116 | } // namespace kai::test | ||
| 117 |