Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/common/compare.hpp" | ||
8 | |||
9 | #include <cstddef> | ||
10 | #include <cstdint> | ||
11 | #include <cstdlib> | ||
12 | #include <sstream> | ||
13 | #include <tuple> | ||
14 | #include <type_traits> | ||
15 | |||
16 | #include "kai/kai_common.h" | ||
17 | #include "test/common/bfloat16.hpp" | ||
18 | #include "test/common/data_format.hpp" | ||
19 | #include "test/common/data_type.hpp" | ||
20 | #include "test/common/float16.hpp" | ||
21 | #include "test/common/int4.hpp" | ||
22 | #include "test/common/logging.hpp" | ||
23 | #include "test/common/memory.hpp" | ||
24 | #include "test/common/printer.hpp" | ||
25 | #include "test/common/rect.hpp" | ||
26 | #include "test/common/round.hpp" | ||
27 | |||
28 | namespace kai::test { | ||
29 | |||
30 | namespace { | ||
31 | |||
32 | /// Calculates the absolute and relative errors. | ||
33 | /// | ||
34 | /// @param[in] imp Value under test. | ||
35 | /// @param[in] ref Reference value. | ||
36 | /// | ||
37 | /// @return The absolute error and relative error. | ||
38 | template <typename T> | ||
39 | 120119698 | std::tuple<float, float> calculate_error(T imp, T ref) { | |
40 | 120119698 | const auto imp_f = static_cast<float>(imp); | |
41 | 120119698 | const auto ref_f = static_cast<float>(ref); | |
42 | |||
43 | 120119698 | const auto abs_error = std::abs(imp_f - ref_f); | |
44 |
6/12✓ Branch 0 taken 58355864 times.
✓ Branch 1 taken 20016612 times.
✓ Branch 2 taken 24431130 times.
✓ Branch 3 taken 14540032 times.
✓ Branch 4 taken 1241360 times.
✓ Branch 5 taken 1534700 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
|
120119698 | const auto rel_error = ref_f != 0 ? abs_error / std::abs(ref_f) : 0.0F; |
45 | |||
46 | 120119698 | return {abs_error, rel_error}; | |
47 | 120119698 | } | |
48 | |||
49 | /// Compares matrices with per-row quantization. | ||
50 | template <typename Data> | ||
51 | 33190 | bool compare_raw( | |
52 | const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, | ||
53 | const Rect& rect, MismatchHandler& handler) { | ||
54 | 33190 | const size_t block_height = format.actual_block_height(full_height); | |
55 | 33190 | const size_t block_width = format.actual_block_width(full_width); | |
56 | 33190 | const size_t subblock_height = format.actual_subblock_height(full_height); | |
57 | 33190 | const size_t subblock_width = format.actual_subblock_width(full_width); | |
58 | |||
59 | 33190 | size_t idx = 0; | |
60 | |||
61 | 33190 | bool block_heading_written = false; | |
62 | 33190 | bool subblock_heading_written = false; | |
63 | 33190 | bool row_heading_written = false; | |
64 | 33190 | std::ostringstream sstream; | |
65 |
6/6✓ Branch 0 taken 16756 times.
✓ Branch 1 taken 16744 times.
✓ Branch 2 taken 15390 times.
✓ Branch 3 taken 15378 times.
✓ Branch 4 taken 1068 times.
✓ Branch 5 taken 1068 times.
|
66404 | for (size_t y_block = 0; y_block < full_height; y_block += block_height) { |
66 |
6/6✓ Branch 0 taken 18334 times.
✓ Branch 1 taken 16756 times.
✓ Branch 2 taken 15390 times.
✓ Branch 3 taken 16162 times.
✓ Branch 4 taken 1068 times.
✓ Branch 5 taken 1068 times.
|
68778 | for (size_t x_block = 0; x_block < full_width; x_block += block_width) { |
67 |
6/6✓ Branch 0 taken 18334 times.
✓ Branch 1 taken 18334 times.
✓ Branch 2 taken 16162 times.
✓ Branch 3 taken 16162 times.
✓ Branch 4 taken 1068 times.
✓ Branch 5 taken 1068 times.
|
71128 | for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { |
68 |
6/6✓ Branch 0 taken 18334 times.
✓ Branch 1 taken 18334 times.
✓ Branch 2 taken 16162 times.
✓ Branch 3 taken 16162 times.
✓ Branch 4 taken 1068 times.
✓ Branch 5 taken 1068 times.
|
71128 | for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { |
69 |
6/6✓ Branch 0 taken 18334 times.
✓ Branch 1 taken 596681 times.
✓ Branch 2 taken 16162 times.
✓ Branch 3 taken 495092 times.
✓ Branch 4 taken 1068 times.
✓ Branch 5 taken 25388 times.
|
1152725 | for (size_t y_element = 0; y_element < subblock_height; ++y_element) { |
70 |
6/6✓ Branch 0 taken 75881580 times.
✓ Branch 1 taken 596681 times.
✓ Branch 2 taken 37450874 times.
✓ Branch 3 taken 495092 times.
✓ Branch 4 taken 2776060 times.
✓ Branch 5 taken 25388 times.
|
117225675 | for (size_t x_element = 0; x_element < subblock_width; ++x_element) { |
71 | 116108514 | const auto y = y_block + y_subblock + y_element; | |
72 | 116108514 | const auto x = x_block + x_subblock + x_element; | |
73 | |||
74 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 75881580 times.
✓ Branch 2 taken 37450874 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2776060 times.
✗ Branch 5 not taken.
|
116108514 | const auto in_roi = rect.contains(y, x); |
75 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 75881580 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 37450874 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2776060 times.
|
116108514 | const auto imp_value = read_array<Data>(imp_data, idx); |
76 |
11/16✓ Branch 0 taken 56326618 times.
✓ Branch 1 taken 19554962 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 56326618 times.
✓ Branch 4 taken 14411534 times.
✓ Branch 5 taken 23039340 times.
✓ Branch 6 taken 23039340 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 14411534 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1534700 times.
✓ Branch 11 taken 1241360 times.
✓ Branch 12 taken 1241360 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1534700 times.
✗ Branch 15 not taken.
|
116108514 | const auto ref_value = in_roi ? read_array<Data>(ref_data, idx) : static_cast<Data>(0); |
77 | |||
78 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 75881580 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 37450874 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2776060 times.
|
191848864 | const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value); |
79 | |||
80 |
9/12✓ Branch 0 taken 50878608 times.
✓ Branch 1 taken 25002972 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 50878608 times.
✓ Branch 4 taken 25202102 times.
✓ Branch 5 taken 12248772 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 25202102 times.
✓ Branch 8 taken 2157629 times.
✓ Branch 9 taken 618431 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2157629 times.
|
116108514 | if (abs_err != 0 || rel_err != 0) { |
81 |
3/6✓ Branch 0 taken 25002972 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12248772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 618431 times.
✗ Branch 5 not taken.
|
37870175 | if (!in_roi) { |
82 | ✗ | handler.mark_as_failed(); | |
83 | ✗ | } | |
84 | |||
85 |
12/24✗ Branch 0 not taken.
✓ Branch 1 taken 25002972 times.
✓ Branch 2 taken 25002972 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 25002972 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 25002972 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 12248772 times.
✓ Branch 10 taken 12248772 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 12248772 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 12248772 times.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 618431 times.
✓ Branch 18 taken 618431 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 618431 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 618431 times.
✗ Branch 23 not taken.
|
37870175 | const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); |
86 | |||
87 |
5/6✓ Branch 0 taken 13406 times.
✓ Branch 1 taken 24989566 times.
✓ Branch 2 taken 13406 times.
✓ Branch 3 taken 12235366 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 618431 times.
|
37870175 | if (notifying) { |
88 |
4/6✓ Branch 0 taken 360 times.
✓ Branch 1 taken 13046 times.
✓ Branch 2 taken 360 times.
✓ Branch 3 taken 13046 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
26812 | if (!block_heading_written) { |
89 |
10/30✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 360 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 360 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 360 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 360 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 360 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 360 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
|
720 | sstream << "block @ (" << y_block << ", " << x_block << "):\n"; |
90 | 720 | block_heading_written = true; | |
91 | 720 | } | |
92 |
4/6✓ Branch 0 taken 360 times.
✓ Branch 1 taken 13046 times.
✓ Branch 2 taken 360 times.
✓ Branch 3 taken 13046 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
26812 | if (!subblock_heading_written) { |
93 |
10/30✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 360 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 360 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 360 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 360 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 360 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 360 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
|
720 | sstream << " sub-block @ (" << y_subblock << ", " << x_subblock << "):\n"; |
94 | 720 | subblock_heading_written = true; | |
95 | 720 | } | |
96 |
4/6✓ Branch 0 taken 5010 times.
✓ Branch 1 taken 8396 times.
✓ Branch 2 taken 5010 times.
✓ Branch 3 taken 8396 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
26812 | if (!row_heading_written) { |
97 |
6/18✓ Branch 0 taken 5010 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5010 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5010 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 5010 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5010 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5010 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
|
10020 | sstream << " row=" << y_element << ": "; |
98 | 10020 | row_heading_written = true; | |
99 | 10020 | } | |
100 |
4/12✓ Branch 0 taken 13406 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13406 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13406 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 13406 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
|
26812 | sstream << x_element << ", "; |
101 | 26812 | } | |
102 | 37870175 | } | |
103 | |||
104 | 116108514 | ++idx; | |
105 | 116108514 | } | |
106 |
5/6✓ Branch 0 taken 591671 times.
✓ Branch 1 taken 5010 times.
✓ Branch 2 taken 490082 times.
✓ Branch 3 taken 5010 times.
✓ Branch 4 taken 25388 times.
✗ Branch 5 not taken.
|
1117161 | if (row_heading_written) { |
107 |
2/6✓ Branch 0 taken 5010 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5010 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
10020 | sstream << "\n"; |
108 | 10020 | } | |
109 | 1117161 | row_heading_written = false; | |
110 | 1117161 | } | |
111 | 35564 | subblock_heading_written = false; | |
112 | 35564 | } | |
113 | 35564 | } | |
114 | 35564 | block_heading_written = false; | |
115 | 35564 | } | |
116 | 33214 | } | |
117 | |||
118 |
3/6✓ Branch 0 taken 16744 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 15378 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1068 times.
✗ Branch 5 not taken.
|
33190 | const bool success = handler.success(full_height * full_width); |
119 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 16744 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 15378 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1068 times.
|
33190 | if (!success) { |
120 | ✗ | KAI_LOGE("mismatches:\n", sstream.str()); | |
121 | ✗ | } | |
122 | 33190 | return success; | |
123 | 33190 | } | |
124 | |||
125 | /// Compares matrices with per-row bias or per-row quantization. | ||
126 | template <typename Data, typename Scale, typename Offset> | ||
127 | 468 | bool compare_per_row( | |
128 | const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, | ||
129 | const Rect& rect, MismatchHandler& handler) { | ||
130 | static constexpr auto has_scale = !std::is_null_pointer_v<Scale>; | ||
131 | |||
132 | 468 | const auto block_height = format.actual_block_height(full_height); | |
133 | 468 | const auto block_width = format.actual_block_width(full_width); | |
134 | 468 | const auto subblock_height = format.actual_subblock_height(full_height); | |
135 | 468 | const auto subblock_width = format.actual_subblock_width(full_width); | |
136 | |||
137 | − | KAI_ASSUME(format.scheduler_block_height(full_height) == block_height); | |
138 | − | KAI_ASSUME(format.scheduler_block_width(full_width) == full_width); | |
139 | − | KAI_ASSUME(rect.start_col() == 0); | |
140 | − | KAI_ASSUME(rect.width() == full_width); | |
141 | |||
142 | 468 | const size_t row_block_zero_points_bytes = block_height * sizeof(Offset); | |
143 | 468 | const size_t row_block_scales_bytes = has_scale ? block_height * sizeof(Scale) : 0; | |
144 | 468 | const size_t row_block_data_bytes = block_height * round_up_multiple(full_width, block_width) * sizeof(Data); | |
145 | |||
146 | 468 | const auto* imp_ptr = reinterpret_cast<const uint8_t*>(imp_data); | |
147 | 468 | const auto* ref_ptr = reinterpret_cast<const uint8_t*>(ref_data); | |
148 | |||
149 |
4/10✓ Branch 0 taken 211 times.
✓ Branch 1 taken 993 times.
✓ Branch 2 taken 257 times.
✓ Branch 3 taken 1141 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
2602 | for (size_t y_block = 0; y_block < full_height; y_block += block_height) { |
150 |
4/10✓ Branch 0 taken 18 times.
✓ Branch 1 taken 975 times.
✓ Branch 2 taken 36 times.
✓ Branch 3 taken 1105 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
2134 | const auto in_roi = y_block >= rect.start_row() && y_block < rect.end_row(); |
151 | |||
152 | // Checks the zero points. | ||
153 |
4/10✓ Branch 0 taken 31120 times.
✓ Branch 1 taken 993 times.
✓ Branch 2 taken 51720 times.
✓ Branch 3 taken 1141 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
84974 | for (size_t i = 0; i < block_height; ++i) { |
154 | 82840 | const auto imp_zero_point = reinterpret_cast<const Offset*>(imp_ptr)[i]; | |
155 |
4/10✓ Branch 0 taken 29952 times.
✓ Branch 1 taken 1168 times.
✓ Branch 2 taken 50392 times.
✓ Branch 3 taken 1328 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
82840 | const Offset ref_zero_point = in_roi ? reinterpret_cast<const Offset*>(ref_ptr)[i] : static_cast<Offset>(0); |
156 | 82840 | const auto [abs_err, rel_err] = calculate_error(imp_zero_point, ref_zero_point); | |
157 | |||
158 |
4/20✓ Branch 0 taken 31120 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 31120 times.
✓ Branch 4 taken 51720 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 51720 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
|
82840 | if (abs_err != 0 || rel_err != 0) { |
159 | ✗ | handler.mark_as_failed(); | |
160 | |||
161 | ✗ | const auto raw_row = y_block + i; | |
162 | ✗ | KAI_LOGE( | |
163 | "Mismatched zero point ", raw_row, ": actual = ", imp_zero_point, ", expected: ", ref_zero_point); | ||
164 | ✗ | } | |
165 | 82840 | } | |
166 | |||
167 | 2134 | imp_ptr += row_block_zero_points_bytes; | |
168 | 2134 | ref_ptr += row_block_zero_points_bytes; | |
169 | |||
170 | // Checks the data. | ||
171 |
4/10✓ Branch 0 taken 22109 times.
✓ Branch 1 taken 993 times.
✓ Branch 2 taken 44477 times.
✓ Branch 3 taken 1141 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
68720 | for (size_t x_block = 0; x_block < full_width; x_block += block_width) { |
172 |
4/10✓ Branch 0 taken 22109 times.
✓ Branch 1 taken 22109 times.
✓ Branch 2 taken 44477 times.
✓ Branch 3 taken 44477 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
133172 | for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { |
173 |
4/10✓ Branch 0 taken 25271 times.
✓ Branch 1 taken 22109 times.
✓ Branch 2 taken 53617 times.
✓ Branch 3 taken 44477 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
145474 | for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { |
174 |
4/10✓ Branch 0 taken 784656 times.
✓ Branch 1 taken 25271 times.
✓ Branch 2 taken 2439176 times.
✓ Branch 3 taken 53617 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
3302720 | for (size_t y = 0; y < subblock_height; ++y) { |
175 |
4/10✓ Branch 0 taken 1489168 times.
✓ Branch 1 taken 784656 times.
✓ Branch 2 taken 2439176 times.
✓ Branch 3 taken 2439176 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
7152176 | for (size_t x = 0; x < subblock_width; ++x) { |
176 | 3928344 | const auto offset = (y_subblock + y) * full_width + x_block + x_subblock + x; | |
177 | 3928344 | const auto imp_value = read_array<Data>(imp_ptr, offset); | |
178 |
4/10✓ Branch 0 taken 1432800 times.
✓ Branch 1 taken 56368 times.
✓ Branch 2 taken 2379440 times.
✓ Branch 3 taken 59736 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
3928344 | const Data ref_value = in_roi ? read_array<Data>(ref_ptr, offset) : static_cast<Data>(0); |
179 | 3928344 | const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value); | |
180 | |||
181 |
4/20✓ Branch 0 taken 1489168 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1489168 times.
✓ Branch 4 taken 2439176 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2439176 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
|
3928344 | if (abs_err != 0 || rel_err != 0) { |
182 | ✗ | if (!in_roi) { | |
183 | ✗ | handler.mark_as_failed(); | |
184 | ✗ | } | |
185 | |||
186 | ✗ | const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err); | |
187 | |||
188 | ✗ | if (notifying) { | |
189 | ✗ | const auto raw_index = y_block * block_height * block_width + offset; | |
190 | ✗ | KAI_LOGE( | |
191 | "Mismatched data ", raw_index, ": actual = ", imp_value, | ||
192 | ", expected: ", ref_value); | ||
193 | ✗ | } | |
194 | ✗ | } | |
195 | 3928344 | } | |
196 | 3223832 | } | |
197 | 78888 | } | |
198 | 66586 | } | |
199 | 66586 | } | |
200 | |||
201 | 2134 | imp_ptr += row_block_data_bytes; | |
202 | 2134 | ref_ptr += row_block_data_bytes; | |
203 | |||
204 | // Checks the scales (if exists). | ||
205 | if constexpr (has_scale) { | ||
206 | ✗ | for (size_t i = 0; i < block_height; ++i) { | |
207 | ✗ | const auto imp_scale = reinterpret_cast<const Scale*>(imp_ptr)[i]; | |
208 | ✗ | const Scale ref_scale = in_roi ? reinterpret_cast<const Scale*>(ref_ptr)[i] : 0; | |
209 | ✗ | const auto [abs_err, rel_err] = calculate_error(imp_scale, ref_scale); | |
210 | |||
211 | ✗ | if (abs_err != 0 || rel_err != 0) { | |
212 | ✗ | handler.mark_as_failed(); | |
213 | |||
214 | ✗ | const auto raw_row = y_block + i; | |
215 | ✗ | KAI_LOGE( | |
216 | "Mismatched quantization scale ", raw_row, ": actual = ", imp_scale, ", expected: ", ref_scale); | ||
217 | ✗ | } | |
218 | ✗ | } | |
219 | |||
220 | ✗ | imp_ptr += row_block_scales_bytes; | |
221 | ✗ | ref_ptr += row_block_scales_bytes; | |
222 | } | ||
223 | 2134 | } | |
224 | |||
225 | 936 | return handler.success(rect.height() * full_width); | |
226 | 468 | } | |
227 | |||
228 | } // namespace | ||
229 | |||
230 | 33658 | bool compare( | |
231 | const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width, | ||
232 | const Rect& rect, MismatchHandler& handler) { | ||
233 | 33658 | const auto data_type = format.data_type(); | |
234 | 33658 | const auto scale_dt = format.scale_data_type(); | |
235 | 33658 | const auto offset_dt = format.zero_point_data_type(); | |
236 | |||
237 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 33190 times.
✓ Branch 2 taken 468 times.
✗ Branch 3 not taken.
|
33658 | switch (format.pack_format()) { |
238 | case DataFormat::PackFormat::NONE: | ||
239 |
3/4✓ Branch 0 taken 15378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 16744 times.
✓ Branch 3 taken 1068 times.
|
33190 | switch (data_type) { |
240 | case DataType::FP32: | ||
241 | 16744 | return compare_raw<float>(imp_data, ref_data, format, full_height, full_width, rect, handler); | |
242 | |||
243 | case DataType::FP16: | ||
244 | 15378 | return compare_raw<Float16>(imp_data, ref_data, format, full_height, full_width, rect, handler); | |
245 | |||
246 | case DataType::BF16: | ||
247 | 1068 | return compare_raw<BFloat16<>>(imp_data, ref_data, format, full_height, full_width, rect, handler); | |
248 | |||
249 | default: | ||
250 | ✗ | break; | |
251 | } | ||
252 | |||
253 | ✗ | break; | |
254 | |||
255 | case DataFormat::PackFormat::BIAS_PER_ROW: | ||
256 |
3/4✓ Branch 0 taken 211 times.
✓ Branch 1 taken 257 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 211 times.
|
468 | if (data_type == DataType::FP16 && offset_dt == DataType::FP16) { |
257 | 211 | return compare_per_row<Float16, std::nullptr_t, Float16>( | |
258 | 211 | imp_data, ref_data, format, full_height, full_width, rect, handler); | |
259 |
2/4✓ Branch 0 taken 257 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 257 times.
|
257 | } else if (data_type == DataType::FP32 && offset_dt == DataType::FP32) { |
260 | 257 | return compare_per_row<float, std::nullptr_t, float>( | |
261 | 257 | imp_data, ref_data, format, full_height, full_width, rect, handler); | |
262 | ✗ | } else if (data_type == DataType::BF16 && offset_dt == DataType::FP32) { | |
263 | ✗ | return compare_per_row<BFloat16<>, std::nullptr_t, float>( | |
264 | ✗ | imp_data, ref_data, format, full_height, full_width, rect, handler); | |
265 | } | ||
266 | |||
267 | ✗ | break; | |
268 | |||
269 | case DataFormat::PackFormat::QUANTIZE_PER_ROW: | ||
270 | ✗ | if (data_type == DataType::QAI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { | |
271 | ✗ | return compare_per_row<int8_t, float, int32_t>( | |
272 | ✗ | imp_data, ref_data, format, full_height, full_width, rect, handler); | |
273 | ✗ | } else if (data_type == DataType::QSI4 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) { | |
274 | ✗ | return compare_per_row<Int4, float, int32_t>( | |
275 | ✗ | imp_data, ref_data, format, full_height, full_width, rect, handler); | |
276 | } | ||
277 | |||
278 | ✗ | break; | |
279 | |||
280 | default: | ||
281 | ✗ | break; | |
282 | } | ||
283 | |||
284 | − | KAI_ERROR("Unsupported format!"); | |
285 | 33658 | } | |
286 | |||
287 | // ===================================================================================================================== | ||
288 | |||
289 | 65912 | DefaultMismatchHandler::DefaultMismatchHandler( | |
290 | float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold, | ||
291 | float rel_mismatched_threshold) : | ||
292 | 32956 | _abs_error_threshold(abs_error_threshold), | |
293 | 32956 | _rel_error_threshold(rel_error_threshold), | |
294 | 32956 | _abs_mismatched_threshold(abs_mismatched_threshold), | |
295 | 32956 | _rel_mismatched_threshold(rel_mismatched_threshold), | |
296 | 32956 | _num_mismatches(0), | |
297 | 98868 | _failed(false) { | |
298 | 65912 | } | |
299 | |||
300 | ✗ | DefaultMismatchHandler::DefaultMismatchHandler(const DefaultMismatchHandler& rhs) : | |
301 | ✗ | _abs_error_threshold(rhs._abs_error_threshold), | |
302 | ✗ | _rel_error_threshold(rhs._rel_error_threshold), | |
303 | ✗ | _abs_mismatched_threshold(rhs._abs_mismatched_threshold), | |
304 | ✗ | _rel_mismatched_threshold(rhs._rel_mismatched_threshold), | |
305 | ✗ | _num_mismatches(0), | |
306 | ✗ | _failed(false) { | |
307 | // Cannot copy mismatch handler that is already in use. | ||
308 | − | KAI_ASSUME(rhs._num_mismatches == 0); | |
309 | − | KAI_ASSUME(!rhs._failed); | |
310 | ✗ | } | |
311 | |||
312 | ✗ | DefaultMismatchHandler& DefaultMismatchHandler::operator=(const DefaultMismatchHandler& rhs) { | |
313 | ✗ | if (this != &rhs) { | |
314 | // Cannot copy mismatch handler that is already in use. | ||
315 | − | KAI_ASSUME(rhs._num_mismatches == 0); | |
316 | − | KAI_ASSUME(!rhs._failed); | |
317 | |||
318 | ✗ | _abs_error_threshold = rhs._abs_error_threshold; | |
319 | ✗ | _rel_error_threshold = rhs._rel_error_threshold; | |
320 | ✗ | _abs_mismatched_threshold = rhs._abs_mismatched_threshold; | |
321 | ✗ | _rel_mismatched_threshold = rhs._rel_mismatched_threshold; | |
322 | ✗ | } | |
323 | |||
324 | ✗ | return *this; | |
325 | } | ||
326 | |||
327 | 37870175 | bool DefaultMismatchHandler::handle_data(float absolute_error, float relative_error) { | |
328 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 37870175 times.
|
37870175 | const auto mismatched = absolute_error > _abs_error_threshold && relative_error > _rel_error_threshold; |
329 | |||
330 |
2/2✓ Branch 0 taken 37843363 times.
✓ Branch 1 taken 26812 times.
|
37870175 | if (mismatched) { |
331 | 26812 | ++_num_mismatches; | |
332 | 26812 | } | |
333 | |||
334 | 75740350 | return mismatched; | |
335 | 37870175 | } | |
336 | |||
337 | ✗ | void DefaultMismatchHandler::mark_as_failed() { | |
338 | ✗ | _failed = true; | |
339 | ✗ | } | |
340 | |||
341 | 33658 | bool DefaultMismatchHandler::success(size_t num_checks) const { | |
342 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 33658 times.
|
33658 | if (_failed) { |
343 | ✗ | return false; | |
344 | } | ||
345 | |||
346 | 33658 | const auto mismatched_rate = static_cast<float>(_num_mismatches) / static_cast<float>(num_checks); | |
347 |
2/2✓ Branch 0 taken 32938 times.
✓ Branch 1 taken 720 times.
|
33658 | return _num_mismatches <= _abs_mismatched_threshold || mismatched_rate <= _rel_mismatched_threshold; |
348 | 33658 | } | |
349 | |||
350 | } // namespace kai::test | ||
351 |