KleidiAI Coverage Report


Directory: ./
File: test/common/compare.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 69.1% 141 9 213
Functions: 57.1% 12 0 21
Branches: 41.3% 210 52 560

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/common/compare.hpp"
8
9 #include <cstddef>
10 #include <cstdint>
11 #include <cstdlib>
12 #include <sstream>
13 #include <tuple>
14 #include <type_traits>
15
16 #include "kai/kai_common.h"
17 #include "test/common/bfloat16.hpp"
18 #include "test/common/data_format.hpp"
19 #include "test/common/data_type.hpp"
20 #include "test/common/float16.hpp"
21 #include "test/common/int4.hpp"
22 #include "test/common/logging.hpp"
23 #include "test/common/memory.hpp"
24 #include "test/common/printer.hpp"
25 #include "test/common/rect.hpp"
26 #include "test/common/round.hpp"
27
28 namespace kai::test {
29
30 namespace {
31
32 /// Calculates the absolute and relative errors.
33 ///
34 /// @param[in] imp Value under test.
35 /// @param[in] ref Reference value.
36 ///
37 /// @return The absolute error and relative error.
38 template <typename T>
39 120119698 std::tuple<float, float> calculate_error(T imp, T ref) {
40 120119698 const auto imp_f = static_cast<float>(imp);
41 120119698 const auto ref_f = static_cast<float>(ref);
42
43 120119698 const auto abs_error = std::abs(imp_f - ref_f);
44
6/12
✓ Branch 0 taken 58355864 times.
✓ Branch 1 taken 20016612 times.
✓ Branch 2 taken 24431130 times.
✓ Branch 3 taken 14540032 times.
✓ Branch 4 taken 1241360 times.
✓ Branch 5 taken 1534700 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
120119698 const auto rel_error = ref_f != 0 ? abs_error / std::abs(ref_f) : 0.0F;
45
46 120119698 return {abs_error, rel_error};
47 120119698 }
48
49 /// Compares matrices with per-row quantization.
50 template <typename Data>
51 33190 bool compare_raw(
52 const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width,
53 const Rect& rect, MismatchHandler& handler) {
54 33190 const size_t block_height = format.actual_block_height(full_height);
55 33190 const size_t block_width = format.actual_block_width(full_width);
56 33190 const size_t subblock_height = format.actual_subblock_height(full_height);
57 33190 const size_t subblock_width = format.actual_subblock_width(full_width);
58
59 33190 size_t idx = 0;
60
61 33190 bool block_heading_written = false;
62 33190 bool subblock_heading_written = false;
63 33190 bool row_heading_written = false;
64 33190 std::ostringstream sstream;
65
6/6
✓ Branch 0 taken 16756 times.
✓ Branch 1 taken 16744 times.
✓ Branch 2 taken 15390 times.
✓ Branch 3 taken 15378 times.
✓ Branch 4 taken 1068 times.
✓ Branch 5 taken 1068 times.
66404 for (size_t y_block = 0; y_block < full_height; y_block += block_height) {
66
6/6
✓ Branch 0 taken 18334 times.
✓ Branch 1 taken 16756 times.
✓ Branch 2 taken 15390 times.
✓ Branch 3 taken 16162 times.
✓ Branch 4 taken 1068 times.
✓ Branch 5 taken 1068 times.
68778 for (size_t x_block = 0; x_block < full_width; x_block += block_width) {
67
6/6
✓ Branch 0 taken 18334 times.
✓ Branch 1 taken 18334 times.
✓ Branch 2 taken 16162 times.
✓ Branch 3 taken 16162 times.
✓ Branch 4 taken 1068 times.
✓ Branch 5 taken 1068 times.
71128 for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
68
6/6
✓ Branch 0 taken 18334 times.
✓ Branch 1 taken 18334 times.
✓ Branch 2 taken 16162 times.
✓ Branch 3 taken 16162 times.
✓ Branch 4 taken 1068 times.
✓ Branch 5 taken 1068 times.
71128 for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
69
6/6
✓ Branch 0 taken 18334 times.
✓ Branch 1 taken 596681 times.
✓ Branch 2 taken 16162 times.
✓ Branch 3 taken 495092 times.
✓ Branch 4 taken 1068 times.
✓ Branch 5 taken 25388 times.
1152725 for (size_t y_element = 0; y_element < subblock_height; ++y_element) {
70
6/6
✓ Branch 0 taken 75881580 times.
✓ Branch 1 taken 596681 times.
✓ Branch 2 taken 37450874 times.
✓ Branch 3 taken 495092 times.
✓ Branch 4 taken 2776060 times.
✓ Branch 5 taken 25388 times.
117225675 for (size_t x_element = 0; x_element < subblock_width; ++x_element) {
71 116108514 const auto y = y_block + y_subblock + y_element;
72 116108514 const auto x = x_block + x_subblock + x_element;
73
74
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 75881580 times.
✓ Branch 2 taken 37450874 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2776060 times.
✗ Branch 5 not taken.
116108514 const auto in_roi = rect.contains(y, x);
75
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 75881580 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 37450874 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2776060 times.
116108514 const auto imp_value = read_array<Data>(imp_data, idx);
76
11/16
✓ Branch 0 taken 56326618 times.
✓ Branch 1 taken 19554962 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 56326618 times.
✓ Branch 4 taken 14411534 times.
✓ Branch 5 taken 23039340 times.
✓ Branch 6 taken 23039340 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 14411534 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1534700 times.
✓ Branch 11 taken 1241360 times.
✓ Branch 12 taken 1241360 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1534700 times.
✗ Branch 15 not taken.
116108514 const auto ref_value = in_roi ? read_array<Data>(ref_data, idx) : static_cast<Data>(0);
77
78
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 75881580 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 37450874 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2776060 times.
191848864 const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value);
79
80
9/12
✓ Branch 0 taken 50878608 times.
✓ Branch 1 taken 25002972 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 50878608 times.
✓ Branch 4 taken 25202102 times.
✓ Branch 5 taken 12248772 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 25202102 times.
✓ Branch 8 taken 2157629 times.
✓ Branch 9 taken 618431 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2157629 times.
116108514 if (abs_err != 0 || rel_err != 0) {
81
3/6
✓ Branch 0 taken 25002972 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12248772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 618431 times.
✗ Branch 5 not taken.
37870175 if (!in_roi) {
82 handler.mark_as_failed();
83 }
84
85
12/24
✗ Branch 0 not taken.
✓ Branch 1 taken 25002972 times.
✓ Branch 2 taken 25002972 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 25002972 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 25002972 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 12248772 times.
✓ Branch 10 taken 12248772 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 12248772 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 12248772 times.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 618431 times.
✓ Branch 18 taken 618431 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 618431 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 618431 times.
✗ Branch 23 not taken.
37870175 const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err);
86
87
5/6
✓ Branch 0 taken 13406 times.
✓ Branch 1 taken 24989566 times.
✓ Branch 2 taken 13406 times.
✓ Branch 3 taken 12235366 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 618431 times.
37870175 if (notifying) {
88
4/6
✓ Branch 0 taken 360 times.
✓ Branch 1 taken 13046 times.
✓ Branch 2 taken 360 times.
✓ Branch 3 taken 13046 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
26812 if (!block_heading_written) {
89
10/30
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 360 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 360 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 360 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 360 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 360 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 360 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
720 sstream << "block @ (" << y_block << ", " << x_block << "):\n";
90 720 block_heading_written = true;
91 720 }
92
4/6
✓ Branch 0 taken 360 times.
✓ Branch 1 taken 13046 times.
✓ Branch 2 taken 360 times.
✓ Branch 3 taken 13046 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
26812 if (!subblock_heading_written) {
93
10/30
✓ Branch 0 taken 360 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 360 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 360 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 360 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 360 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 360 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 360 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 360 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 360 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 360 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
720 sstream << " sub-block @ (" << y_subblock << ", " << x_subblock << "):\n";
94 720 subblock_heading_written = true;
95 720 }
96
4/6
✓ Branch 0 taken 5010 times.
✓ Branch 1 taken 8396 times.
✓ Branch 2 taken 5010 times.
✓ Branch 3 taken 8396 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
26812 if (!row_heading_written) {
97
6/18
✓ Branch 0 taken 5010 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5010 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5010 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 5010 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5010 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5010 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
10020 sstream << " row=" << y_element << ": ";
98 10020 row_heading_written = true;
99 10020 }
100
4/12
✓ Branch 0 taken 13406 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13406 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13406 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 13406 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
26812 sstream << x_element << ", ";
101 26812 }
102 37870175 }
103
104 116108514 ++idx;
105 116108514 }
106
5/6
✓ Branch 0 taken 591671 times.
✓ Branch 1 taken 5010 times.
✓ Branch 2 taken 490082 times.
✓ Branch 3 taken 5010 times.
✓ Branch 4 taken 25388 times.
✗ Branch 5 not taken.
1117161 if (row_heading_written) {
107
2/6
✓ Branch 0 taken 5010 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5010 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
10020 sstream << "\n";
108 10020 }
109 1117161 row_heading_written = false;
110 1117161 }
111 35564 subblock_heading_written = false;
112 35564 }
113 35564 }
114 35564 block_heading_written = false;
115 35564 }
116 33214 }
117
118
3/6
✓ Branch 0 taken 16744 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 15378 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1068 times.
✗ Branch 5 not taken.
33190 const bool success = handler.success(full_height * full_width);
119
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 16744 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 15378 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1068 times.
33190 if (!success) {
120 KAI_LOGE("mismatches:\n", sstream.str());
121 }
122 33190 return success;
123 33190 }
124
125 /// Compares matrices with per-row bias or per-row quantization.
126 template <typename Data, typename Scale, typename Offset>
127 468 bool compare_per_row(
128 const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width,
129 const Rect& rect, MismatchHandler& handler) {
130 static constexpr auto has_scale = !std::is_null_pointer_v<Scale>;
131
132 468 const auto block_height = format.actual_block_height(full_height);
133 468 const auto block_width = format.actual_block_width(full_width);
134 468 const auto subblock_height = format.actual_subblock_height(full_height);
135 468 const auto subblock_width = format.actual_subblock_width(full_width);
136
137 KAI_ASSUME(format.scheduler_block_height(full_height) == block_height);
138 KAI_ASSUME(format.scheduler_block_width(full_width) == full_width);
139 KAI_ASSUME(rect.start_col() == 0);
140 KAI_ASSUME(rect.width() == full_width);
141
142 468 const size_t row_block_zero_points_bytes = block_height * sizeof(Offset);
143 468 const size_t row_block_scales_bytes = has_scale ? block_height * sizeof(Scale) : 0;
144 468 const size_t row_block_data_bytes = block_height * round_up_multiple(full_width, block_width) * sizeof(Data);
145
146 468 const auto* imp_ptr = reinterpret_cast<const uint8_t*>(imp_data);
147 468 const auto* ref_ptr = reinterpret_cast<const uint8_t*>(ref_data);
148
149
4/10
✓ Branch 0 taken 211 times.
✓ Branch 1 taken 993 times.
✓ Branch 2 taken 257 times.
✓ Branch 3 taken 1141 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
2602 for (size_t y_block = 0; y_block < full_height; y_block += block_height) {
150
4/10
✓ Branch 0 taken 18 times.
✓ Branch 1 taken 975 times.
✓ Branch 2 taken 36 times.
✓ Branch 3 taken 1105 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
2134 const auto in_roi = y_block >= rect.start_row() && y_block < rect.end_row();
151
152 // Checks the zero points.
153
4/10
✓ Branch 0 taken 31120 times.
✓ Branch 1 taken 993 times.
✓ Branch 2 taken 51720 times.
✓ Branch 3 taken 1141 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
84974 for (size_t i = 0; i < block_height; ++i) {
154 82840 const auto imp_zero_point = reinterpret_cast<const Offset*>(imp_ptr)[i];
155
4/10
✓ Branch 0 taken 29952 times.
✓ Branch 1 taken 1168 times.
✓ Branch 2 taken 50392 times.
✓ Branch 3 taken 1328 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
82840 const Offset ref_zero_point = in_roi ? reinterpret_cast<const Offset*>(ref_ptr)[i] : static_cast<Offset>(0);
156 82840 const auto [abs_err, rel_err] = calculate_error(imp_zero_point, ref_zero_point);
157
158
4/20
✓ Branch 0 taken 31120 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 31120 times.
✓ Branch 4 taken 51720 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 51720 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
82840 if (abs_err != 0 || rel_err != 0) {
159 handler.mark_as_failed();
160
161 const auto raw_row = y_block + i;
162 KAI_LOGE(
163 "Mismatched zero point ", raw_row, ": actual = ", imp_zero_point, ", expected: ", ref_zero_point);
164 }
165 82840 }
166
167 2134 imp_ptr += row_block_zero_points_bytes;
168 2134 ref_ptr += row_block_zero_points_bytes;
169
170 // Checks the data.
171
4/10
✓ Branch 0 taken 22109 times.
✓ Branch 1 taken 993 times.
✓ Branch 2 taken 44477 times.
✓ Branch 3 taken 1141 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
68720 for (size_t x_block = 0; x_block < full_width; x_block += block_width) {
172
4/10
✓ Branch 0 taken 22109 times.
✓ Branch 1 taken 22109 times.
✓ Branch 2 taken 44477 times.
✓ Branch 3 taken 44477 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
133172 for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
173
4/10
✓ Branch 0 taken 25271 times.
✓ Branch 1 taken 22109 times.
✓ Branch 2 taken 53617 times.
✓ Branch 3 taken 44477 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
145474 for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
174
4/10
✓ Branch 0 taken 784656 times.
✓ Branch 1 taken 25271 times.
✓ Branch 2 taken 2439176 times.
✓ Branch 3 taken 53617 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
3302720 for (size_t y = 0; y < subblock_height; ++y) {
175
4/10
✓ Branch 0 taken 1489168 times.
✓ Branch 1 taken 784656 times.
✓ Branch 2 taken 2439176 times.
✓ Branch 3 taken 2439176 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
7152176 for (size_t x = 0; x < subblock_width; ++x) {
176 3928344 const auto offset = (y_subblock + y) * full_width + x_block + x_subblock + x;
177 3928344 const auto imp_value = read_array<Data>(imp_ptr, offset);
178
4/10
✓ Branch 0 taken 1432800 times.
✓ Branch 1 taken 56368 times.
✓ Branch 2 taken 2379440 times.
✓ Branch 3 taken 59736 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
3928344 const Data ref_value = in_roi ? read_array<Data>(ref_ptr, offset) : static_cast<Data>(0);
179 3928344 const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value);
180
181
4/20
✓ Branch 0 taken 1489168 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1489168 times.
✓ Branch 4 taken 2439176 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2439176 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
3928344 if (abs_err != 0 || rel_err != 0) {
182 if (!in_roi) {
183 handler.mark_as_failed();
184 }
185
186 const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err);
187
188 if (notifying) {
189 const auto raw_index = y_block * block_height * block_width + offset;
190 KAI_LOGE(
191 "Mismatched data ", raw_index, ": actual = ", imp_value,
192 ", expected: ", ref_value);
193 }
194 }
195 3928344 }
196 3223832 }
197 78888 }
198 66586 }
199 66586 }
200
201 2134 imp_ptr += row_block_data_bytes;
202 2134 ref_ptr += row_block_data_bytes;
203
204 // Checks the scales (if exists).
205 if constexpr (has_scale) {
206 for (size_t i = 0; i < block_height; ++i) {
207 const auto imp_scale = reinterpret_cast<const Scale*>(imp_ptr)[i];
208 const Scale ref_scale = in_roi ? reinterpret_cast<const Scale*>(ref_ptr)[i] : 0;
209 const auto [abs_err, rel_err] = calculate_error(imp_scale, ref_scale);
210
211 if (abs_err != 0 || rel_err != 0) {
212 handler.mark_as_failed();
213
214 const auto raw_row = y_block + i;
215 KAI_LOGE(
216 "Mismatched quantization scale ", raw_row, ": actual = ", imp_scale, ", expected: ", ref_scale);
217 }
218 }
219
220 imp_ptr += row_block_scales_bytes;
221 ref_ptr += row_block_scales_bytes;
222 }
223 2134 }
224
225 936 return handler.success(rect.height() * full_width);
226 468 }
227
228 } // namespace
229
230 33658 bool compare(
231 const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width,
232 const Rect& rect, MismatchHandler& handler) {
233 33658 const auto data_type = format.data_type();
234 33658 const auto scale_dt = format.scale_data_type();
235 33658 const auto offset_dt = format.zero_point_data_type();
236
237
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 33190 times.
✓ Branch 2 taken 468 times.
✗ Branch 3 not taken.
33658 switch (format.pack_format()) {
238 case DataFormat::PackFormat::NONE:
239
3/4
✓ Branch 0 taken 15378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 16744 times.
✓ Branch 3 taken 1068 times.
33190 switch (data_type) {
240 case DataType::FP32:
241 16744 return compare_raw<float>(imp_data, ref_data, format, full_height, full_width, rect, handler);
242
243 case DataType::FP16:
244 15378 return compare_raw<Float16>(imp_data, ref_data, format, full_height, full_width, rect, handler);
245
246 case DataType::BF16:
247 1068 return compare_raw<BFloat16<>>(imp_data, ref_data, format, full_height, full_width, rect, handler);
248
249 default:
250 break;
251 }
252
253 break;
254
255 case DataFormat::PackFormat::BIAS_PER_ROW:
256
3/4
✓ Branch 0 taken 211 times.
✓ Branch 1 taken 257 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 211 times.
468 if (data_type == DataType::FP16 && offset_dt == DataType::FP16) {
257 211 return compare_per_row<Float16, std::nullptr_t, Float16>(
258 211 imp_data, ref_data, format, full_height, full_width, rect, handler);
259
2/4
✓ Branch 0 taken 257 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 257 times.
257 } else if (data_type == DataType::FP32 && offset_dt == DataType::FP32) {
260 257 return compare_per_row<float, std::nullptr_t, float>(
261 257 imp_data, ref_data, format, full_height, full_width, rect, handler);
262 } else if (data_type == DataType::BF16 && offset_dt == DataType::FP32) {
263 return compare_per_row<BFloat16<>, std::nullptr_t, float>(
264 imp_data, ref_data, format, full_height, full_width, rect, handler);
265 }
266
267 break;
268
269 case DataFormat::PackFormat::QUANTIZE_PER_ROW:
270 if (data_type == DataType::QAI8 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) {
271 return compare_per_row<int8_t, float, int32_t>(
272 imp_data, ref_data, format, full_height, full_width, rect, handler);
273 } else if (data_type == DataType::QSI4 && scale_dt == DataType::FP32 && offset_dt == DataType::I32) {
274 return compare_per_row<Int4, float, int32_t>(
275 imp_data, ref_data, format, full_height, full_width, rect, handler);
276 }
277
278 break;
279
280 default:
281 break;
282 }
283
284 KAI_ERROR("Unsupported format!");
285 33658 }
286
287 // =====================================================================================================================
288
289 65912 DefaultMismatchHandler::DefaultMismatchHandler(
290 float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold,
291 float rel_mismatched_threshold) :
292 32956 _abs_error_threshold(abs_error_threshold),
293 32956 _rel_error_threshold(rel_error_threshold),
294 32956 _abs_mismatched_threshold(abs_mismatched_threshold),
295 32956 _rel_mismatched_threshold(rel_mismatched_threshold),
296 32956 _num_mismatches(0),
297 98868 _failed(false) {
298 65912 }
299
300 DefaultMismatchHandler::DefaultMismatchHandler(const DefaultMismatchHandler& rhs) :
301 _abs_error_threshold(rhs._abs_error_threshold),
302 _rel_error_threshold(rhs._rel_error_threshold),
303 _abs_mismatched_threshold(rhs._abs_mismatched_threshold),
304 _rel_mismatched_threshold(rhs._rel_mismatched_threshold),
305 _num_mismatches(0),
306 _failed(false) {
307 // Cannot copy mismatch handler that is already in use.
308 KAI_ASSUME(rhs._num_mismatches == 0);
309 KAI_ASSUME(!rhs._failed);
310 }
311
312 DefaultMismatchHandler& DefaultMismatchHandler::operator=(const DefaultMismatchHandler& rhs) {
313 if (this != &rhs) {
314 // Cannot copy mismatch handler that is already in use.
315 KAI_ASSUME(rhs._num_mismatches == 0);
316 KAI_ASSUME(!rhs._failed);
317
318 _abs_error_threshold = rhs._abs_error_threshold;
319 _rel_error_threshold = rhs._rel_error_threshold;
320 _abs_mismatched_threshold = rhs._abs_mismatched_threshold;
321 _rel_mismatched_threshold = rhs._rel_mismatched_threshold;
322 }
323
324 return *this;
325 }
326
327 37870175 bool DefaultMismatchHandler::handle_data(float absolute_error, float relative_error) {
328
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 37870175 times.
37870175 const auto mismatched = absolute_error > _abs_error_threshold && relative_error > _rel_error_threshold;
329
330
2/2
✓ Branch 0 taken 37843363 times.
✓ Branch 1 taken 26812 times.
37870175 if (mismatched) {
331 26812 ++_num_mismatches;
332 26812 }
333
334 75740350 return mismatched;
335 37870175 }
336
337 void DefaultMismatchHandler::mark_as_failed() {
338 _failed = true;
339 }
340
341 33658 bool DefaultMismatchHandler::success(size_t num_checks) const {
342
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 33658 times.
33658 if (_failed) {
343 return false;
344 }
345
346 33658 const auto mismatched_rate = static_cast<float>(_num_mismatches) / static_cast<float>(num_checks);
347
2/2
✓ Branch 0 taken 32938 times.
✓ Branch 1 taken 720 times.
33658 return _num_mismatches <= _abs_mismatched_threshold || mismatched_rate <= _rel_mismatched_threshold;
348 33658 }
349
350 } // namespace kai::test
351