test/common/compare.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/common/compare.hpp" | ||
| 8 | |||
| 9 | #include <cstddef> | ||
| 10 | #include <cstdint> | ||
| 11 | #include <cstdlib> | ||
| 12 | #include <sstream> | ||
| 13 | #include <tuple> | ||
| 14 | #include <type_traits> | ||
| 15 | |||
| 16 | #include "kai/kai_common.h" | ||
| 17 | #include "test/common/bfloat16.hpp" | ||
| 18 | #include "test/common/data_format.hpp" | ||
| 19 | #include "test/common/data_type.hpp" | ||
| 20 | #include "test/common/float16.hpp" | ||
| 21 | #include "test/common/int4.hpp" | ||
| 22 | #include "test/common/logging.hpp" | ||
| 23 | #include "test/common/memory.hpp" | ||
| 24 | #include "test/common/printer.hpp" | ||
| 25 | #include "test/common/rect.hpp" | ||
| 26 | #include "test/common/round.hpp" | ||
| 27 | |||
| 28 | namespace kai::test { | ||
| 29 | |||
| 30 | namespace { | ||
| 31 | |||
| 32 | /// Calculates the absolute and relative errors. | ||
| 33 | /// | ||
| 34 | /// @param[in] imp Value under test. | ||
| 35 | /// @param[in] ref Reference value. | ||
| 36 | /// | ||
| 37 | /// @return The absolute error and the relative error (the relative error is 0 when the reference value is 0). | ||
| 38 | template <typename T> | ||
| 39 | 156982263 | std::tuple<float, float> calculate_error(T imp, T ref) { | |
| 40 | 156982263 | const auto imp_f = static_cast<float>(imp); | |
| 41 | 156982263 | const auto ref_f = static_cast<float>(ref); | |
| 42 | |||
| 43 | 156982263 | const auto abs_error = std::abs(imp_f - ref_f); | |
| 44 |
6/12✓ Branch 0 taken 78018717 times.
✓ Branch 1 taken 34277676 times.
✓ Branch 2 taken 20604111 times.
✓ Branch 3 taken 14129727 times.
✓ Branch 4 taken 4633584 times.
✓ Branch 5 taken 5318448 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
|
156982263 | const auto rel_error = ref_f != 0 ? abs_error / std::abs(ref_f) : 0.0F; |
| 45 | |||
| 46 | 156982263 | return {abs_error, rel_error}; | |
| 47 | 156982263 | } | |
| 48 | |||
| 49 | /// Compares matrices of raw data, i.e. without per-row bias or per-row quantization parameters. | ||
| 50 | template <typename Data> | ||
| 51 | 57870 | bool compare_raw( | |
| 52 | const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, | ||
| 53 | const Rect& rect, MismatchHandler& handler) { | ||
| 54 | 57870 | const size_t block_height = format.actual_block_height(full_height); | |
| 55 | 57870 | const size_t block_width = format.actual_block_width(full_width); | |
| 56 | 57870 | const size_t subblock_height = format.actual_subblock_height(full_height); | |
| 57 | 57870 | const size_t subblock_width = format.actual_subblock_width(full_width); | |
| 58 | |||
| 59 | 57870 | size_t idx = 0; | |
| 60 | |||
| 61 | 57870 | bool block_heading_written = false; | |
| 62 | 57870 | bool subblock_heading_written = false; | |
| 63 | 57870 | bool row_heading_written = false; | |
| 64 | 57870 | std::ostringstream sstream; | |
| 65 |
6/6✓ Branch 0 taken 26466 times.
✓ Branch 1 taken 26430 times.
✓ Branch 2 taken 26580 times.
✓ Branch 3 taken 26544 times.
✓ Branch 4 taken 4896 times.
✓ Branch 5 taken 4896 times.
|
115812 | for (size_t y_block = 0; y_block < full_height; y_block += block_height) { |
| 66 |
6/6✓ Branch 0 taken 31290 times.
✓ Branch 1 taken 26466 times.
✓ Branch 2 taken 26580 times.
✓ Branch 3 taken 28938 times.
✓ Branch 4 taken 4896 times.
✓ Branch 5 taken 4896 times.
|
123066 | for (size_t x_block = 0; x_block < full_width; x_block += block_width) { |
| 67 |
6/6✓ Branch 0 taken 31290 times.
✓ Branch 1 taken 31290 times.
✓ Branch 2 taken 28938 times.
✓ Branch 3 taken 28938 times.
✓ Branch 4 taken 4896 times.
✓ Branch 5 taken 4896 times.
|
130248 | for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { |
| 68 |
6/6✓ Branch 0 taken 31290 times.
✓ Branch 1 taken 31290 times.
✓ Branch 2 taken 28938 times.
✓ Branch 3 taken 28938 times.
✓ Branch 4 taken 4896 times.
✓ Branch 5 taken 4896 times.
|
130248 | for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { |
| 69 |
6/6✓ Branch 0 taken 31290 times.
✓ Branch 1 taken 989456 times.
✓ Branch 2 taken 28938 times.
✓ Branch 3 taken 471438 times.
✓ Branch 4 taken 4896 times.
✓ Branch 5 taken 119952 times.
|
1645970 | for (size_t y_element = 0; y_element < subblock_height; ++y_element) { |
| 70 |
6/6✓ Branch 0 taken 104115945 times.
✓ Branch 1 taken 989456 times.
✓ Branch 2 taken 29549166 times.
✓ Branch 3 taken 471438 times.
✓ Branch 4 taken 9952032 times.
✓ Branch 5 taken 119952 times.
|
145197989 | for (size_t x_element = 0; x_element < subblock_width; ++x_element) { |
| 71 | 143617143 | const auto y = y_block + y_subblock + y_element; | |
| 72 | 143617143 | const auto x = x_block + x_subblock + x_element; | |
| 73 | |||
| 74 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 104115945 times.
✓ Branch 2 taken 29549166 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9952032 times.
✗ Branch 5 not taken.
|
143617143 | const auto in_roi = rect.contains(y, x); |
| 75 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 104115945 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 29549166 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 9952032 times.
|
143617143 | const auto imp_value = read_array<Data>(imp_data, idx); |
| 76 |
11/16✓ Branch 0 taken 71557437 times.
✓ Branch 1 taken 32558508 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 71557437 times.
✓ Branch 4 taken 13434080 times.
✓ Branch 5 taken 16115086 times.
✓ Branch 6 taken 16115086 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 13434080 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5318448 times.
✓ Branch 11 taken 4633584 times.
✓ Branch 12 taken 4633584 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 5318448 times.
✗ Branch 15 not taken.
|
143617143 | const auto ref_value = in_roi ? read_array<Data>(ref_data, idx) : static_cast<Data>(0); |
| 77 | |||
| 78 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 104115945 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 29549166 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 9952032 times.
|
237706947 | const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value); |
| 79 | |||
| 80 |
9/12✓ Branch 0 taken 63584591 times.
✓ Branch 1 taken 40531354 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 63584591 times.
✓ Branch 4 taken 25347998 times.
✓ Branch 5 taken 4201168 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 25347998 times.
✓ Branch 8 taken 7639652 times.
✓ Branch 9 taken 2312380 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 7639652 times.
|
143617143 | if (abs_err != 0 || rel_err != 0) { |
| 81 |
3/6✓ Branch 0 taken 40531354 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4201168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2312380 times.
✗ Branch 5 not taken.
|
47044902 | if (!in_roi) { |
| 82 | ✗ | handler.mark_as_failed(); | |
| 83 | ✗ | } | |
| 84 | |||
| 85 |
12/24✗ Branch 0 not taken.
✓ Branch 1 taken 40531354 times.
✓ Branch 2 taken 40531354 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40531354 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 40531354 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 4201168 times.
✓ Branch 10 taken 4201168 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4201168 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4201168 times.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2312380 times.
✓ Branch 18 taken 2312380 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2312380 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2312380 times.
✗ Branch 23 not taken.
|
47044902 | const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); |
| 86 | |||
| 87 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 40531354 times.
✓ Branch 2 taken 995 times.
✓ Branch 3 taken 4200173 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2312380 times.
|
47044902 | if (notifying) { |
| 88 |
2/6✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 172 times.
✓ Branch 3 taken 823 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
995 | if (!block_heading_written) { |
| 89 |
5/30✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 172 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 172 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 172 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 172 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 172 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
|
172 | sstream << "block @ (" << y_block << ", " << x_block << "):\n"; |
| 90 | 172 | block_heading_written = true; | |
| 91 | 172 | } | |
| 92 |
2/6✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 172 times.
✓ Branch 3 taken 823 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
995 | if (!subblock_heading_written) { |
| 93 |
5/30✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 172 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 172 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 172 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 172 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 172 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
|
172 | sstream << " sub-block @ (" << y_subblock << ", " << x_subblock << "):\n"; |
| 94 | 172 | subblock_heading_written = true; | |
| 95 | 172 | } | |
| 96 |
2/6✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 877 times.
✓ Branch 3 taken 118 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
995 | if (!row_heading_written) { |
| 97 |
3/18✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 877 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 877 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 877 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
|
877 | sstream << " row=" << y_element << ": "; |
| 98 | 877 | row_heading_written = true; | |
| 99 | 877 | } | |
| 100 |
2/12✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 995 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 995 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
|
995 | sstream << x_element << ", "; |
| 101 | 995 | } | |
| 102 | 47044902 | } | |
| 103 | |||
| 104 | 143617143 | ++idx; | |
| 105 | 143617143 | } | |
| 106 |
4/6✓ Branch 0 taken 989456 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 470561 times.
✓ Branch 3 taken 877 times.
✓ Branch 4 taken 119952 times.
✗ Branch 5 not taken.
|
1580846 | if (row_heading_written) { |
| 107 |
1/6✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 877 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
877 | sstream << "\n"; |
| 108 | 877 | } | |
| 109 | 1580846 | row_heading_written = false; | |
| 110 | 1580846 | } | |
| 111 | 65124 | subblock_heading_written = false; | |
| 112 | 65124 | } | |
| 113 | 65124 | } | |
| 114 | 65124 | block_heading_written = false; | |
| 115 | 65124 | } | |
| 116 | 57942 | } | |
| 117 | |||
| 118 |
3/6✓ Branch 0 taken 26430 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 26544 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4896 times.
✗ Branch 5 not taken.
|
57870 | const bool success = handler.success(full_height * full_width); |
| 119 |
6/6✓ Branch 0 taken 8338 times.
✓ Branch 1 taken 18092 times.
✓ Branch 2 taken 8520 times.
✓ Branch 3 taken 18024 times.
✓ Branch 4 taken 2448 times.
✓ Branch 5 taken 2448 times.
|
57870 | if (!success) { |
| 120 | ✗ | KAI_LOGE("mismatches:\n", sstream.str()); | |
| 121 | ✗ | } | |
| 122 | 57870 | return success; | |
| 123 | 57870 | } | |
| 124 | |||
| 125 | /// Compares matrices with per-row bias or per-row quantization. | ||
| 126 | template <typename Data, typename Scale, typename Offset> | ||
| 127 | 2400 | bool compare_per_row( | |
| 128 | const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, | ||
| 129 | const Rect& rect, MismatchHandler& handler) { | ||
| 130 | static constexpr auto has_scale = !std::is_null_pointer_v<Scale>; | ||
| 131 | |||
| 132 | 2400 | const auto block_height = format.actual_block_height(full_height); | |
| 133 | 2400 | const auto block_width = format.actual_block_width(full_width); | |
| 134 | 2400 | const auto subblock_height = format.actual_subblock_height(full_height); | |
| 135 | 2400 | const auto subblock_width = format.actual_subblock_width(full_width); | |
| 136 | |||
| 137 | − | KAI_ASSUME_ALWAYS(format.scheduler_block_height(full_height) == block_height); | |
| 138 | − | KAI_ASSUME_ALWAYS(format.scheduler_block_width(full_width) == full_width); | |
| 139 | − | KAI_ASSUME_ALWAYS(rect.start_col() == 0); | |
| 140 | − | KAI_ASSUME_ALWAYS(rect.width() == full_width); | |
| 141 | |||
| 142 | 2400 | const size_t row_block_zero_points_bytes = block_height * sizeof(Offset); | |
| 143 | 2400 | const size_t row_block_scales_bytes = has_scale ? block_height * sizeof(Scale) : 0; | |
| 144 | 2400 | const size_t row_block_data_bytes = block_height * round_up_multiple(full_width, block_width) * sizeof(Data); | |
| 145 | |||
| 146 | 2400 | const auto* imp_ptr = reinterpret_cast<const uint8_t*>(imp_data); | |
| 147 | 2400 | const auto* ref_ptr = reinterpret_cast<const uint8_t*>(ref_data); | |
| 148 | |||
| 149 |
4/10✓ Branch 0 taken 1068 times.
✓ Branch 1 taken 3648 times.
✓ Branch 2 taken 1332 times.
✓ Branch 3 taken 4695 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
10743 | for (size_t y_block = 0; y_block < full_height; y_block += block_height) { |
| 150 |
4/10✓ Branch 0 taken 108 times.
✓ Branch 1 taken 3540 times.
✓ Branch 2 taken 288 times.
✓ Branch 3 taken 4407 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
8343 | const auto in_roi = y_block >= rect.start_row() && y_block < rect.end_row(); |
| 151 | |||
| 152 | // Checks the zero points. | ||
| 153 |
4/10✓ Branch 0 taken 112704 times.
✓ Branch 1 taken 3648 times.
✓ Branch 2 taken 180000 times.
✓ Branch 3 taken 4695 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
301047 | for (size_t i = 0; i < block_height; ++i) { |
| 154 | 292704 | const auto imp_zero_point = reinterpret_cast<const Offset*>(imp_ptr)[i]; | |
| 155 |
4/10✓ Branch 0 taken 105696 times.
✓ Branch 1 taken 7008 times.
✓ Branch 2 taken 170400 times.
✓ Branch 3 taken 9600 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
292704 | const Offset ref_zero_point = in_roi ? reinterpret_cast<const Offset*>(ref_ptr)[i] : static_cast<Offset>(0); |
| 156 | 292704 | const auto [abs_err, rel_err] = calculate_error(imp_zero_point, ref_zero_point); | |
| 157 | |||
| 158 |
4/20✓ Branch 0 taken 112704 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 112704 times.
✓ Branch 4 taken 180000 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 180000 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
|
292704 | if (abs_err != 0 || rel_err != 0) { |
| 159 | ✗ | handler.mark_as_failed(); | |
| 160 | |||
| 161 | ✗ | const auto raw_row = y_block + i; | |
| 162 | ✗ | KAI_LOGE( | |
| 163 | "Mismatched zero point ", raw_row, ": actual = ", imp_zero_point, ", expected: ", ref_zero_point); | ||
| 164 | ✗ | } | |
| 165 | 292704 | } | |
| 166 | |||
| 167 | 8343 | imp_ptr += row_block_zero_points_bytes; | |
| 168 | 8343 | ref_ptr += row_block_zero_points_bytes; | |
| 169 | |||
| 170 | // Checks the data. | ||
| 171 |
4/10✓ Branch 0 taken 67104 times.
✓ Branch 1 taken 3648 times.
✓ Branch 2 taken 134703 times.
✓ Branch 3 taken 4695 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
210150 | for (size_t x_block = 0; x_block < full_width; x_block += block_width) { |
| 172 |
4/10✓ Branch 0 taken 67104 times.
✓ Branch 1 taken 67104 times.
✓ Branch 2 taken 134703 times.
✓ Branch 3 taken 134703 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
403614 | for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { |
| 173 |
4/10✓ Branch 0 taken 96858 times.
✓ Branch 1 taken 67104 times.
✓ Branch 2 taken 204696 times.
✓ Branch 3 taken 134703 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
503361 | for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { |
| 174 |
4/10✓ Branch 0 taken 2953440 times.
✓ Branch 1 taken 96858 times.
✓ Branch 2 taken 8000448 times.
✓ Branch 3 taken 204696 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
11255442 | for (size_t y = 0; y < subblock_height; ++y) { |
| 175 |
4/10✓ Branch 0 taken 5071968 times.
✓ Branch 1 taken 2953440 times.
✓ Branch 2 taken 8000448 times.
✓ Branch 3 taken 8000448 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
24026304 | for (size_t x = 0; x < subblock_width; ++x) { |
| 176 | 13072416 | const auto offset = (y_subblock + y) * full_width + x_block + x_subblock + x; | |
| 177 | 13072416 | const auto imp_value = read_array<Data>(imp_ptr, offset); | |
| 178 |
4/10✓ Branch 0 taken 4734144 times.
✓ Branch 1 taken 337824 times.
✓ Branch 2 taken 7582032 times.
✓ Branch 3 taken 418416 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
13072416 | const Data ref_value = in_roi ? read_array<Data>(ref_ptr, offset) : static_cast<Data>(0); |
| 179 | 13072416 | const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value); | |
| 180 | |||
| 181 |
4/20✓ Branch 0 taken 5071968 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 5071968 times.
✓ Branch 4 taken 8000448 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 8000448 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
|
13072416 | if (abs_err != 0 || rel_err != 0) { |
| 182 | ✗ | if (!in_roi) { | |
| 183 | ✗ | handler.mark_as_failed(); | |
| 184 | ✗ | } | |
| 185 | |||
| 186 | ✗ | const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); | |
| 187 | |||
| 188 | ✗ | if (notifying) { | |
| 189 | ✗ | const auto raw_index = y_block * block_height * block_width + offset; | |
| 190 | ✗ | KAI_LOGE( | |
| 191 | "Mismatched data ", raw_index, ": actual = ", imp_value, | ||
| 192 | ", expected: ", ref_value); | ||
| 193 | ✗ | } | |
| 194 | ✗ | } | |
| 195 | 13072416 | } | |
| 196 | 10953888 | } | |
| 197 | 301554 | } | |
| 198 | 201807 | } | |
| 199 | 201807 | } | |
| 200 | |||
| 201 | 8343 | imp_ptr += row_block_data_bytes; | |
| 202 | 8343 | ref_ptr += row_block_data_bytes; | |
| 203 | |||
| 204 | // Checks the scales (if they exist). | ||
| 205 | if constexpr (has_scale) { | ||
| 206 | ✗ | for (size_t i = 0; i < block_height; ++i) { | |
| 207 | ✗ | const auto imp_scale = reinterpret_cast<const Scale*>(imp_ptr)[i]; | |
| 208 | ✗ | const Scale ref_scale = in_roi ? reinterpret_cast<const Scale*>(ref_ptr)[i] : 0; | |
| 209 | ✗ | const auto [abs_err, rel_err] = calculate_error(imp_scale, ref_scale); | |
| 210 | |||
| 211 | ✗ | if (abs_err != 0 || rel_err != 0) { | |
| 212 | ✗ | handler.mark_as_failed(); | |
| 213 | |||
| 214 | ✗ | const auto raw_row = y_block + i; | |
| 215 | ✗ | KAI_LOGE( | |
| 216 | "Mismatched quantization scale ", raw_row, ": actual = ", imp_scale, ", expected: ", ref_scale); | ||
| 217 | ✗ | } | |
| 218 | ✗ | } | |
| 219 | |||
| 220 | ✗ | imp_ptr += row_block_scales_bytes; | |
| 221 | ✗ | ref_ptr += row_block_scales_bytes; | |
| 222 | } | ||
| 223 | 8343 | } | |
| 224 | |||
| 225 | 4800 | return handler.success(rect.height() * full_width); | |
| 226 | 2400 | } | |
| 227 | |||
| 228 | } // namespace | ||
| 229 | |||
| 230 | 60270 | bool compare( | |
| 231 | const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, | ||
| 232 | const Rect& rect, MismatchHandler& handler) { | ||
| 233 | 60270 | const auto data_type = format.data_type(); | |
| 234 | 60270 | const auto scale_dt = format.scale_data_type(); | |
| 235 | 60270 | const auto offset_dt = format.zero_point_data_type(); | |
| 236 | |||
| 237 |
3/4✗ Branch 0 not taken.
✓ Branch 1 taken 39212 times.
✓ Branch 2 taken 1752 times.
✓ Branch 3 taken 19306 times.
|
60270 | switch (format.pack_format()) { |
| 238 | case DataFormat::PackFormat::NONE: | ||
| 239 |
3/4✓ Branch 0 taken 26544 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 26430 times.
✓ Branch 3 taken 4896 times.
|
57870 | switch (data_type) { |
| 240 | case DataType::FP32: | ||
| 241 | 26430 | return compare_raw<float>(imp_data, ref_data, format, full_height, full_width, rect, handler); | |
| 242 | |||
| 243 | case DataType::FP16: | ||
| 244 | 26544 | return compare_raw<Float16>(imp_data, ref_data, format, full_height, full_width, rect, handler); | |
| 245 | |||
| 246 | case DataType::BF16: | ||
| 247 | 4896 | return compare_raw<BFloat16<>>(imp_data, ref_data, format, full_height, full_width, rect, handler); | |
| 248 | |||
| 249 | default: | ||
| 250 | ✗ | break; | |
| 251 | } | ||
| 252 | |||
| 253 | ✗ | break; | |
| 254 | |||
| 255 | case DataFormat::PackFormat::BIAS_PER_ROW: | ||
| 256 |
3/4✓ Branch 0 taken 1068 times.
✓ Branch 1 taken 1332 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1068 times.
|
2400 | if (data_type == DataType::FP16 && offset_dt == DataType::FP16) { |
| 257 | 1068 | return compare_per_row<Float16, std::nullptr_t, Float16>( | |
| 258 | 1068 | imp_data, ref_data, format, full_height, full_width, rect, handler); | |
| 259 |
2/4✓ Branch 0 taken 1332 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1332 times.
|
1332 | } else if (data_type == DataType::FP32 && offset_dt == DataType::FP32) { |
| 260 | 1332 | return compare_per_row<float, std::nullptr_t, float>( | |
| 261 | 1332 | imp_data, ref_data, format, full_height, full_width, rect, handler); | |
| 262 | ✗ | } else if (data_type == DataType::BF16 && offset_dt == DataType::FP32) { | |
| 263 | ✗ | return compare_per_row<BFloat16<>, std::nullptr_t, float>( | |
| 264 | ✗ | imp_data, ref_data, format, full_height, full_width, rect, handler); | |
| 265 | } | ||
| 266 | |||
| 267 | ✗ | break; | |
| 268 | |||
| 269 | case DataFormat::PackFormat::QUANTIZE_PER_ROW: | ||
| 270 | ✗ | if (data_type_is_quantized_int8(data_type) && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { | |
| 271 | ✗ | return compare_per_row<int8_t, float, int32_t>( | |
| 272 | ✗ | imp_data, ref_data, format, full_height, full_width, rect, handler); | |
| 273 | } else if ( | ||
| 274 | ✗ | data_type_is_quantized_int4(data_type) && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { | |
| 275 | ✗ | return compare_per_row<Int4, float, int32_t>( | |
| 276 | ✗ | imp_data, ref_data, format, full_height, full_width, rect, handler); | |
| 277 | } | ||
| 278 | |||
| 279 | ✗ | break; | |
| 280 | |||
| 281 | default: | ||
| 282 | ✗ | break; | |
| 283 | } | ||
| 284 | |||
| 285 | − | KAI_ERROR("Unsupported format!"); | |
| 286 | 60270 | } | |
| 287 | |||
| 288 | // ===================================================================================================================== | ||
| 289 | |||
| 290 | 104186 | DefaultMismatchHandler::DefaultMismatchHandler( | |
| 291 | float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, | ||
| 292 | float rel_mismatched_threshold) : | ||
| 293 | 62070 | _abs_error_threshold(abs_error_threshold), | |
| 294 | 62070 | _rel_error_threshold(rel_error_threshold), | |
| 295 | 62070 | _abs_mismatched_threshold(abs_mismatched_threshold), | |
| 296 | 62070 | _rel_mismatched_threshold(rel_mismatched_threshold), | |
| 297 | 62070 | _num_mismatches(0), | |
| 298 | 166256 | _failed(false) { | |
| 299 | 104186 | } | |
| 300 | |||
| 301 | ✗ | DefaultMismatchHandler::DefaultMismatchHandler(const DefaultMismatchHandler& rhs) : | |
| 302 | ✗ | _abs_error_threshold(rhs._abs_error_threshold), | |
| 303 | ✗ | _rel_error_threshold(rhs._rel_error_threshold), | |
| 304 | ✗ | _abs_mismatched_threshold(rhs._abs_mismatched_threshold), | |
| 305 | ✗ | _rel_mismatched_threshold(rhs._rel_mismatched_threshold), | |
| 306 | ✗ | _num_mismatches(0), | |
| 307 | ✗ | _failed(false) { | |
| 308 | // Cannot copy a mismatch handler that is already in use. | ||
| 309 | − | KAI_ASSUME_ALWAYS(rhs._num_mismatches == 0); | |
| 310 | − | KAI_ASSUME_ALWAYS(!rhs._failed); | |
| 311 | ✗ | } | |
| 312 | |||
| 313 | ✗ | DefaultMismatchHandler& DefaultMismatchHandler::operator=(const DefaultMismatchHandler& rhs) { | |
| 314 | ✗ | if (this != &rhs) { | |
| 315 | // Cannot copy a mismatch handler that is already in use. | ||
| 316 | − | KAI_ASSUME_ALWAYS(rhs._num_mismatches == 0); | |
| 317 | − | KAI_ASSUME_ALWAYS(!rhs._failed); | |
| 318 | |||
| 319 | ✗ | _abs_error_threshold = rhs._abs_error_threshold; | |
| 320 | ✗ | _rel_error_threshold = rhs._rel_error_threshold; | |
| 321 | ✗ | _abs_mismatched_threshold = rhs._abs_mismatched_threshold; | |
| 322 | ✗ | _rel_mismatched_threshold = rhs._rel_mismatched_threshold; | |
| 323 | ✗ | } | |
| 324 | |||
| 325 | ✗ | return *this; | |
| 326 | } | ||
| 327 | |||
| 328 | 48073517 | bool DefaultMismatchHandler::handle_data(float absolute_error, float relative_error) { | |
| 329 |
2/2✓ Branch 0 taken 1028615 times.
✓ Branch 1 taken 47044902 times.
|
48073517 | const auto mismatched = absolute_error > _abs_error_threshold && relative_error > _rel_error_threshold; |
| 330 | |||
| 331 |
2/2✓ Branch 0 taken 48072522 times.
✓ Branch 1 taken 995 times.
|
48073517 | if (mismatched) { |
| 332 | 995 | ++_num_mismatches; | |
| 333 | 995 | } | |
| 334 | |||
| 335 | 96147034 | return mismatched; | |
| 336 | 48073517 | } | |
| 337 | |||
| 338 | ✗ | void DefaultMismatchHandler::mark_as_failed() { | |
| 339 | ✗ | _failed = true; | |
| 340 | ✗ | } | |
| 341 | |||
| 342 | 62070 | bool DefaultMismatchHandler::success(size_t num_checks) const { | |
| 343 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 62070 times.
|
62070 | if (_failed) { |
| 344 | ✗ | return false; | |
| 345 | } | ||
| 346 | |||
| 347 | 62070 | const auto mismatched_rate = static_cast<float>(_num_mismatches) / static_cast<float>(num_checks); | |
| 348 |
2/2✓ Branch 0 taken 61898 times.
✓ Branch 1 taken 172 times.
|
62070 | return _num_mismatches <= _abs_mismatched_threshold || mismatched_rate <= _rel_mismatched_threshold; |
| 349 | 62070 | } | |
| 350 | |||
| 351 | } // namespace kai::test | ||
| 352 |