KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 69.1% 141 / 9 / 213
Functions: 55.6% 15 / 0 / 27
Branches: 37.2% 191 / 48 / 562

test/common/compare.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/common/compare.hpp"
8
9 #include <cstddef>
10 #include <cstdint>
11 #include <cstdlib>
12 #include <sstream>
13 #include <tuple>
14 #include <type_traits>
15
16 #include "kai/kai_common.h"
17 #include "test/common/bfloat16.hpp"
18 #include "test/common/data_format.hpp"
19 #include "test/common/data_type.hpp"
20 #include "test/common/float16.hpp"
21 #include "test/common/int4.hpp"
22 #include "test/common/logging.hpp"
23 #include "test/common/memory.hpp"
24 #include "test/common/printer.hpp"
25 #include "test/common/rect.hpp"
26 #include "test/common/round.hpp"
27
28 namespace kai::test {
29
30 namespace {
31
32 /// Calculates the absolute and relative errors.
33 ///
34 /// @param[in] imp Value under test.
35 /// @param[in] ref Reference value.
36 ///
37 /// @return The absolute error and relative error.
38 template <typename T>
39 156982263 std::tuple<float, float> calculate_error(T imp, T ref) {
40 156982263 const auto imp_f = static_cast<float>(imp);
41 156982263 const auto ref_f = static_cast<float>(ref);
42
43 156982263 const auto abs_error = std::abs(imp_f - ref_f);
44
6/12
✓ Branch 0 taken 78018717 times.
✓ Branch 1 taken 34277676 times.
✓ Branch 2 taken 20604111 times.
✓ Branch 3 taken 14129727 times.
✓ Branch 4 taken 4633584 times.
✓ Branch 5 taken 5318448 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
156982263 const auto rel_error = ref_f != 0 ? abs_error / std::abs(ref_f) : 0.0F;
45
46 156982263 return {abs_error, rel_error};
47 156982263 }
48
49 /// Compares matrices that carry no per-row bias or quantization data (pack format NONE), element by element.
50 template <typename Data>
51 57870 bool compare_raw(
52 const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width,
53 const Rect& rect, MismatchHandler& handler) {
54 57870 const size_t block_height = format.actual_block_height(full_height);
55 57870 const size_t block_width = format.actual_block_width(full_width);
56 57870 const size_t subblock_height = format.actual_subblock_height(full_height);
57 57870 const size_t subblock_width = format.actual_subblock_width(full_width);
58
59 57870 size_t idx = 0;
60
61 57870 bool block_heading_written = false;
62 57870 bool subblock_heading_written = false;
63 57870 bool row_heading_written = false;
64 57870 std::ostringstream sstream;
65
6/6
✓ Branch 0 taken 26466 times.
✓ Branch 1 taken 26430 times.
✓ Branch 2 taken 26580 times.
✓ Branch 3 taken 26544 times.
✓ Branch 4 taken 4896 times.
✓ Branch 5 taken 4896 times.
115812 for (size_t y_block = 0; y_block < full_height; y_block += block_height) {
66
6/6
✓ Branch 0 taken 31290 times.
✓ Branch 1 taken 26466 times.
✓ Branch 2 taken 26580 times.
✓ Branch 3 taken 28938 times.
✓ Branch 4 taken 4896 times.
✓ Branch 5 taken 4896 times.
123066 for (size_t x_block = 0; x_block < full_width; x_block += block_width) {
67
6/6
✓ Branch 0 taken 31290 times.
✓ Branch 1 taken 31290 times.
✓ Branch 2 taken 28938 times.
✓ Branch 3 taken 28938 times.
✓ Branch 4 taken 4896 times.
✓ Branch 5 taken 4896 times.
130248 for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
68
6/6
✓ Branch 0 taken 31290 times.
✓ Branch 1 taken 31290 times.
✓ Branch 2 taken 28938 times.
✓ Branch 3 taken 28938 times.
✓ Branch 4 taken 4896 times.
✓ Branch 5 taken 4896 times.
130248 for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
69
6/6
✓ Branch 0 taken 31290 times.
✓ Branch 1 taken 989456 times.
✓ Branch 2 taken 28938 times.
✓ Branch 3 taken 471438 times.
✓ Branch 4 taken 4896 times.
✓ Branch 5 taken 119952 times.
1645970 for (size_t y_element = 0; y_element < subblock_height; ++y_element) {
70
6/6
✓ Branch 0 taken 104115945 times.
✓ Branch 1 taken 989456 times.
✓ Branch 2 taken 29549166 times.
✓ Branch 3 taken 471438 times.
✓ Branch 4 taken 9952032 times.
✓ Branch 5 taken 119952 times.
145197989 for (size_t x_element = 0; x_element < subblock_width; ++x_element) {
71 143617143 const auto y = y_block + y_subblock + y_element;
72 143617143 const auto x = x_block + x_subblock + x_element;
73
74
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 104115945 times.
✓ Branch 2 taken 29549166 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9952032 times.
✗ Branch 5 not taken.
143617143 const auto in_roi = rect.contains(y, x);
75
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 104115945 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 29549166 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 9952032 times.
143617143 const auto imp_value = read_array<Data>(imp_data, idx);
76
11/16
✓ Branch 0 taken 71557437 times.
✓ Branch 1 taken 32558508 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 71557437 times.
✓ Branch 4 taken 13434080 times.
✓ Branch 5 taken 16115086 times.
✓ Branch 6 taken 16115086 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 13434080 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5318448 times.
✓ Branch 11 taken 4633584 times.
✓ Branch 12 taken 4633584 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 5318448 times.
✗ Branch 15 not taken.
143617143 const auto ref_value = in_roi ? read_array<Data>(ref_data, idx) : static_cast<Data>(0);
77
78
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 104115945 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 29549166 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 9952032 times.
237706947 const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value);
79
80
9/12
✓ Branch 0 taken 63584591 times.
✓ Branch 1 taken 40531354 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 63584591 times.
✓ Branch 4 taken 25347998 times.
✓ Branch 5 taken 4201168 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 25347998 times.
✓ Branch 8 taken 7639652 times.
✓ Branch 9 taken 2312380 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 7639652 times.
143617143 if (abs_err != 0 || rel_err != 0) {
81
3/6
✓ Branch 0 taken 40531354 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4201168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2312380 times.
✗ Branch 5 not taken.
47044902 if (!in_roi) {
82 handler.mark_as_failed();
83 }
84
85
12/24
✗ Branch 0 not taken.
✓ Branch 1 taken 40531354 times.
✓ Branch 2 taken 40531354 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40531354 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 40531354 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 4201168 times.
✓ Branch 10 taken 4201168 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4201168 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4201168 times.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2312380 times.
✓ Branch 18 taken 2312380 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2312380 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2312380 times.
✗ Branch 23 not taken.
47044902 const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err);
86
87
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 40531354 times.
✓ Branch 2 taken 995 times.
✓ Branch 3 taken 4200173 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2312380 times.
47044902 if (notifying) {
88
2/6
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 172 times.
✓ Branch 3 taken 823 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
995 if (!block_heading_written) {
89
5/30
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 172 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 172 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 172 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 172 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 172 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
172 sstream << "block @ (" << y_block << ", " << x_block << "):\n";
90 172 block_heading_written = true;
91 172 }
92
2/6
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 172 times.
✓ Branch 3 taken 823 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
995 if (!subblock_heading_written) {
93
5/30
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 172 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 172 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 172 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 172 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 172 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
172 sstream << " sub-block @ (" << y_subblock << ", " << x_subblock << "):\n";
94 172 subblock_heading_written = true;
95 172 }
96
2/6
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 877 times.
✓ Branch 3 taken 118 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
995 if (!row_heading_written) {
97
3/18
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 877 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 877 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 877 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
877 sstream << " row=" << y_element << ": ";
98 877 row_heading_written = true;
99 877 }
100
2/12
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 995 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 995 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
995 sstream << x_element << ", ";
101 995 }
102 47044902 }
103
104 143617143 ++idx;
105 143617143 }
106
4/6
✓ Branch 0 taken 989456 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 470561 times.
✓ Branch 3 taken 877 times.
✓ Branch 4 taken 119952 times.
✗ Branch 5 not taken.
1580846 if (row_heading_written) {
107
1/6
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 877 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
877 sstream << "\n";
108 877 }
109 1580846 row_heading_written = false;
110 1580846 }
111 65124 subblock_heading_written = false;
112 65124 }
113 65124 }
114 65124 block_heading_written = false;
115 65124 }
116 57942 }
117
118
3/6
✓ Branch 0 taken 26430 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 26544 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4896 times.
✗ Branch 5 not taken.
57870 const bool success = handler.success(full_height * full_width);
119
6/6
✓ Branch 0 taken 8338 times.
✓ Branch 1 taken 18092 times.
✓ Branch 2 taken 8520 times.
✓ Branch 3 taken 18024 times.
✓ Branch 4 taken 2448 times.
✓ Branch 5 taken 2448 times.
57870 if (!success) {
120 KAI_LOGE("mismatches:\n", sstream.str());
121 }
122 57870 return success;
123 57870 }
124
125 /// Compares matrices with per-row bias or per-row quantization.
126 template <typename Data, typename Scale, typename Offset>
127 2400 bool compare_per_row(
128 const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width,
129 const Rect& rect, MismatchHandler& handler) {
130 static constexpr auto has_scale = !std::is_null_pointer_v<Scale>;
131
132 2400 const auto block_height = format.actual_block_height(full_height);
133 2400 const auto block_width = format.actual_block_width(full_width);
134 2400 const auto subblock_height = format.actual_subblock_height(full_height);
135 2400 const auto subblock_width = format.actual_subblock_width(full_width);
136
137 KAI_ASSUME_ALWAYS(format.scheduler_block_height(full_height) == block_height);
138 KAI_ASSUME_ALWAYS(format.scheduler_block_width(full_width) == full_width);
139 KAI_ASSUME_ALWAYS(rect.start_col() == 0);
140 KAI_ASSUME_ALWAYS(rect.width() == full_width);
141
142 2400 const size_t row_block_zero_points_bytes = block_height * sizeof(Offset);
143 2400 const size_t row_block_scales_bytes = has_scale ? block_height * sizeof(Scale) : 0;
144 2400 const size_t row_block_data_bytes = block_height * round_up_multiple(full_width, block_width) * sizeof(Data);
145
146 2400 const auto* imp_ptr = reinterpret_cast<const uint8_t*>(imp_data);
147 2400 const auto* ref_ptr = reinterpret_cast<const uint8_t*>(ref_data);
148
149
4/10
✓ Branch 0 taken 1068 times.
✓ Branch 1 taken 3648 times.
✓ Branch 2 taken 1332 times.
✓ Branch 3 taken 4695 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
10743 for (size_t y_block = 0; y_block < full_height; y_block += block_height) {
150
4/10
✓ Branch 0 taken 108 times.
✓ Branch 1 taken 3540 times.
✓ Branch 2 taken 288 times.
✓ Branch 3 taken 4407 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
8343 const auto in_roi = y_block >= rect.start_row() && y_block < rect.end_row();
151
152 // Checks the zero points.
153
4/10
✓ Branch 0 taken 112704 times.
✓ Branch 1 taken 3648 times.
✓ Branch 2 taken 180000 times.
✓ Branch 3 taken 4695 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
301047 for (size_t i = 0; i < block_height; ++i) {
154 292704 const auto imp_zero_point = reinterpret_cast<const Offset*>(imp_ptr)[i];
155
4/10
✓ Branch 0 taken 105696 times.
✓ Branch 1 taken 7008 times.
✓ Branch 2 taken 170400 times.
✓ Branch 3 taken 9600 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
292704 const Offset ref_zero_point = in_roi ? reinterpret_cast<const Offset*>(ref_ptr)[i] : static_cast<Offset>(0);
156 292704 const auto [abs_err, rel_err] = calculate_error(imp_zero_point, ref_zero_point);
157
158
4/20
✓ Branch 0 taken 112704 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 112704 times.
✓ Branch 4 taken 180000 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 180000 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
292704 if (abs_err != 0 || rel_err != 0) {
159 handler.mark_as_failed();
160
161 const auto raw_row = y_block + i;
162 KAI_LOGE(
163 "Mismatched zero point ", raw_row, ": actual = ", imp_zero_point, ", expected: ", ref_zero_point);
164 }
165 292704 }
166
167 8343 imp_ptr += row_block_zero_points_bytes;
168 8343 ref_ptr += row_block_zero_points_bytes;
169
170 // Checks the data.
171
4/10
✓ Branch 0 taken 67104 times.
✓ Branch 1 taken 3648 times.
✓ Branch 2 taken 134703 times.
✓ Branch 3 taken 4695 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
210150 for (size_t x_block = 0; x_block < full_width; x_block += block_width) {
172
4/10
✓ Branch 0 taken 67104 times.
✓ Branch 1 taken 67104 times.
✓ Branch 2 taken 134703 times.
✓ Branch 3 taken 134703 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
403614 for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
173
4/10
✓ Branch 0 taken 96858 times.
✓ Branch 1 taken 67104 times.
✓ Branch 2 taken 204696 times.
✓ Branch 3 taken 134703 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
503361 for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
174
4/10
✓ Branch 0 taken 2953440 times.
✓ Branch 1 taken 96858 times.
✓ Branch 2 taken 8000448 times.
✓ Branch 3 taken 204696 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
11255442 for (size_t y = 0; y < subblock_height; ++y) {
175
4/10
✓ Branch 0 taken 5071968 times.
✓ Branch 1 taken 2953440 times.
✓ Branch 2 taken 8000448 times.
✓ Branch 3 taken 8000448 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
24026304 for (size_t x = 0; x < subblock_width; ++x) {
176 13072416 const auto offset = (y_subblock + y) * full_width + x_block + x_subblock + x;
177 13072416 const auto imp_value = read_array<Data>(imp_ptr, offset);
178
4/10
✓ Branch 0 taken 4734144 times.
✓ Branch 1 taken 337824 times.
✓ Branch 2 taken 7582032 times.
✓ Branch 3 taken 418416 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
13072416 const Data ref_value = in_roi ? read_array<Data>(ref_ptr, offset) : static_cast<Data>(0);
179 13072416 const auto [abs_err, rel_err] = calculate_error(imp_value, ref_value);
180
181
4/20
✓ Branch 0 taken 5071968 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 5071968 times.
✓ Branch 4 taken 8000448 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 8000448 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
13072416 if (abs_err != 0 || rel_err != 0) {
182 if (!in_roi) {
183 handler.mark_as_failed();
184 }
185
186 const auto notifying = !in_roi || handler.handle_data(abs_err, rel_err);
187
188 if (notifying) {
189 const auto raw_index = y_block * block_height * block_width + offset;
190 KAI_LOGE(
191 "Mismatched data ", raw_index, ": actual = ", imp_value,
192 ", expected: ", ref_value);
193 }
194 }
195 13072416 }
196 10953888 }
197 301554 }
198 201807 }
199 201807 }
200
201 8343 imp_ptr += row_block_data_bytes;
202 8343 ref_ptr += row_block_data_bytes;
203
204 // Checks the scales (if they exist).
205 if constexpr (has_scale) {
206 for (size_t i = 0; i < block_height; ++i) {
207 const auto imp_scale = reinterpret_cast<const Scale*>(imp_ptr)[i];
208 const Scale ref_scale = in_roi ? reinterpret_cast<const Scale*>(ref_ptr)[i] : 0;
209 const auto [abs_err, rel_err] = calculate_error(imp_scale, ref_scale);
210
211 if (abs_err != 0 || rel_err != 0) {
212 handler.mark_as_failed();
213
214 const auto raw_row = y_block + i;
215 KAI_LOGE(
216 "Mismatched quantization scale ", raw_row, ": actual = ", imp_scale, ", expected: ", ref_scale);
217 }
218 }
219
220 imp_ptr += row_block_scales_bytes;
221 ref_ptr += row_block_scales_bytes;
222 }
223 8343 }
224
225 4800 return handler.success(rect.height() * full_width);
226 2400 }
227
228 } // namespace
229
230 60270 bool compare(
231 const void* imp_data, const void* ref_data, const DataFormat& format, size_t full_height, size_t full_width,
232 const Rect& rect, MismatchHandler& handler) {
233 60270 const auto data_type = format.data_type();
234 60270 const auto scale_dt = format.scale_data_type();
235 60270 const auto offset_dt = format.zero_point_data_type();
236
237
3/4
✗ Branch 0 not taken.
✓ Branch 1 taken 39212 times.
✓ Branch 2 taken 1752 times.
✓ Branch 3 taken 19306 times.
60270 switch (format.pack_format()) {
238 case DataFormat::PackFormat::NONE:
239
3/4
✓ Branch 0 taken 26544 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 26430 times.
✓ Branch 3 taken 4896 times.
57870 switch (data_type) {
240 case DataType::FP32:
241 26430 return compare_raw<float>(imp_data, ref_data, format, full_height, full_width, rect, handler);
242
243 case DataType::FP16:
244 26544 return compare_raw<Float16>(imp_data, ref_data, format, full_height, full_width, rect, handler);
245
246 case DataType::BF16:
247 4896 return compare_raw<BFloat16<>>(imp_data, ref_data, format, full_height, full_width, rect, handler);
248
249 default:
250 break;
251 }
252
253 break;
254
255 case DataFormat::PackFormat::BIAS_PER_ROW:
256
3/4
✓ Branch 0 taken 1068 times.
✓ Branch 1 taken 1332 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1068 times.
2400 if (data_type == DataType::FP16 && offset_dt == DataType::FP16) {
257 1068 return compare_per_row<Float16, std::nullptr_t, Float16>(
258 1068 imp_data, ref_data, format, full_height, full_width, rect, handler);
259
2/4
✓ Branch 0 taken 1332 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1332 times.
1332 } else if (data_type == DataType::FP32 && offset_dt == DataType::FP32) {
260 1332 return compare_per_row<float, std::nullptr_t, float>(
261 1332 imp_data, ref_data, format, full_height, full_width, rect, handler);
262 } else if (data_type == DataType::BF16 && offset_dt == DataType::FP32) {
263 return compare_per_row<BFloat16<>, std::nullptr_t, float>(
264 imp_data, ref_data, format, full_height, full_width, rect, handler);
265 }
266
267 break;
268
269 case DataFormat::PackFormat::QUANTIZE_PER_ROW:
270 if (data_type_is_quantized_int8(data_type) && scale_dt == DataType::FP32 && offset_dt == DataType::I32) {
271 return compare_per_row<int8_t, float, int32_t>(
272 imp_data, ref_data, format, full_height, full_width, rect, handler);
273 } else if (
274 data_type_is_quantized_int4(data_type) && scale_dt == DataType::FP32 && offset_dt == DataType::I32) {
275 return compare_per_row<Int4, float, int32_t>(
276 imp_data, ref_data, format, full_height, full_width, rect, handler);
277 }
278
279 break;
280
281 default:
282 break;
283 }
284
285 KAI_ERROR("Unsupported format!");
286 60270 }
287
288 // =====================================================================================================================
289
290 104186 DefaultMismatchHandler::DefaultMismatchHandler(
291 float abs_error_threshold, float rel_error_threshold, size_t abs_mismatched_threshold,
292 float rel_mismatched_threshold) :
293 62070 _abs_error_threshold(abs_error_threshold),
294 62070 _rel_error_threshold(rel_error_threshold),
295 62070 _abs_mismatched_threshold(abs_mismatched_threshold),
296 62070 _rel_mismatched_threshold(rel_mismatched_threshold),
297 62070 _num_mismatches(0),
298 166256 _failed(false) {
299 104186 }
300
301 DefaultMismatchHandler::DefaultMismatchHandler(const DefaultMismatchHandler& rhs) :
302 _abs_error_threshold(rhs._abs_error_threshold),
303 _rel_error_threshold(rhs._rel_error_threshold),
304 _abs_mismatched_threshold(rhs._abs_mismatched_threshold),
305 _rel_mismatched_threshold(rhs._rel_mismatched_threshold),
306 _num_mismatches(0),
307 _failed(false) {
308 // Cannot copy a mismatch handler that is already in use.
309 KAI_ASSUME_ALWAYS(rhs._num_mismatches == 0);
310 KAI_ASSUME_ALWAYS(!rhs._failed);
311 }
312
313 DefaultMismatchHandler& DefaultMismatchHandler::operator=(const DefaultMismatchHandler& rhs) {
314 if (this != &rhs) {
315 // Cannot copy a mismatch handler that is already in use.
316 KAI_ASSUME_ALWAYS(rhs._num_mismatches == 0);
317 KAI_ASSUME_ALWAYS(!rhs._failed);
318
319 _abs_error_threshold = rhs._abs_error_threshold;
320 _rel_error_threshold = rhs._rel_error_threshold;
321 _abs_mismatched_threshold = rhs._abs_mismatched_threshold;
322 _rel_mismatched_threshold = rhs._rel_mismatched_threshold;
323 }
324
325 return *this;
326 }
327
328 48073517 bool DefaultMismatchHandler::handle_data(float absolute_error, float relative_error) {
329
2/2
✓ Branch 0 taken 1028615 times.
✓ Branch 1 taken 47044902 times.
48073517 const auto mismatched = absolute_error > _abs_error_threshold && relative_error > _rel_error_threshold;
330
331
2/2
✓ Branch 0 taken 48072522 times.
✓ Branch 1 taken 995 times.
48073517 if (mismatched) {
332 995 ++_num_mismatches;
333 995 }
334
335 96147034 return mismatched;
336 48073517 }
337
338 void DefaultMismatchHandler::mark_as_failed() {
339 _failed = true;
340 }
341
342 62070 bool DefaultMismatchHandler::success(size_t num_checks) const {
343
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 62070 times.
62070 if (_failed) {
344 return false;
345 }
346
347 62070 const auto mismatched_rate = static_cast<float>(_num_mismatches) / static_cast<float>(num_checks);
348
2/2
✓ Branch 0 taken 61898 times.
✓ Branch 1 taken 172 times.
62070 return _num_mismatches <= _abs_mismatched_threshold || mismatched_rate <= _rel_mismatched_threshold;
349 62070 }
350
351 } // namespace kai::test
352