KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 0.0% 0 / 3 / 110
Functions: 0.0% 0 / 0 / 8
Branches: 0.0% 0 / 2 / 55

test/common/printer.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/common/printer.hpp"
8
9 #include <cstddef>
10 #include <cstdint>
11 #include <ostream>
12 #include <string_view>
13
14 #include "kai/kai_common.h"
15 #include "test/common/bfloat16.hpp"
16 #include "test/common/data_format.hpp"
17 #include "test/common/data_type.hpp"
18 #include "test/common/float16.hpp"
19 #include "test/common/int4.hpp"
20
21 namespace kai::test {
22
23 namespace {
24
25 inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataType data_type) {
26 if (data_type == DataType::QSU4) {
27 for (size_t i = 0; i < len / 2; ++i) {
28 const auto [low, high] = UInt4::unpack_u8(data[i]);
29 os << static_cast<int32_t>(low) << ", " << static_cast<int32_t>(high) << ", ";
30 }
31 } else if (data_type == DataType::QSI4 || data_type == DataType::QAI4) {
32 for (size_t i = 0; i < len / 2; ++i) {
33 const auto [low, high] = Int4::unpack_u8(data[i]);
34 os << static_cast<int32_t>(low) << ", " << static_cast<int32_t>(high) << ", ";
35 }
36 } else {
37 for (size_t i = 0; i < len; ++i) {
38 switch (data_type) {
39 case DataType::FP32:
40 os << reinterpret_cast<const float*>(data)[i];
41 break;
42
43 case DataType::FP16:
44 os << reinterpret_cast<const Float16*>(data)[i];
45 break;
46
47 case DataType::BF16:
48 os << reinterpret_cast<const BFloat16<>*>(data)[i];
49 break;
50
51 case DataType::I32:
52 os << reinterpret_cast<const int32_t*>(data)[i];
53 break;
54
55 case DataType::QAI8:
56 case DataType::QSI8:
57 os << static_cast<int32_t>(reinterpret_cast<const int8_t*>(data)[i]);
58 break;
59
60 default:
61 KAI_ERROR("Unsupported data type!");
62 }
63
64 os << ", ";
65 }
66 }
67 }
68
69 void print_matrix_raw(std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) {
70 const auto data_type = format.data_type();
71 const auto esize_bits = data_type_size_in_bits(data_type);
72 const auto block_height = format.actual_block_height(height);
73 const auto block_width = format.actual_block_width(width);
74 const auto subblock_height = format.actual_subblock_height(height);
75 const auto subblock_width = format.actual_subblock_width(width);
76
77 os << "[\n";
78 for (size_t y_block = 0; y_block < height; y_block += block_height) {
79 if (block_height != height) {
80 os << " [\n";
81 }
82
83 for (size_t x_block = 0; x_block < width; x_block += block_width) {
84 if (block_width != width) {
85 os << " [\n";
86 }
87
88 for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
89 if (subblock_height != block_height) {
90 os << " [\n";
91 }
92
93 for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
94 if (subblock_width != block_width) {
95 os << " [\n";
96 }
97
98 for (size_t y = 0; y < subblock_height; ++y) {
99 os << " [";
100 print_data(os, data, subblock_width, data_type);
101 data += subblock_width * esize_bits / 8;
102 os << "],\n";
103 }
104
105 if (subblock_width != block_width) {
106 os << " ]\n";
107 }
108 }
109
110 if (subblock_height != block_height) {
111 os << " ]\n";
112 }
113 }
114
115 if (block_width != width) {
116 os << " ],\n";
117 }
118 }
119
120 if (block_height != height) {
121 os << " ],\n";
122 }
123 }
124 os << "]\n";
125 }
126
127 void print_matrix_per_row(
128 std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) {
129 const auto has_scale = format.pack_format() == DataFormat::PackFormat::QUANTIZE_PER_ROW;
130
131 const auto block_height = format.actual_block_height(height);
132
133 const auto num_blocks = (height + block_height - 1) / block_height;
134
135 KAI_ASSUME_ALWAYS(format.default_size_in_bytes(height, width) % num_blocks == 0);
136 const auto block_data_bytes = block_height * width * data_type_size_in_bits(format.data_type()) / 8;
137 const auto block_offsets_bytes = block_height * data_type_size_in_bits(format.zero_point_data_type()) / 8;
138 const auto block_scales_bytes = has_scale ? block_height * data_type_size_in_bits(format.scale_data_type()) / 8 : 0;
139
140 os << "[\n";
141 for (size_t y = 0; y < num_blocks; ++y) {
142 os << " {\"offsets\": [";
143 print_data(os, data, block_height, format.zero_point_data_type());
144 os << "], \"data\": [";
145 print_data(os, data + block_offsets_bytes, block_height * width, format.data_type());
146
147 if (has_scale) {
148 os << "], \"scales\": [";
149 print_data(os, data + block_offsets_bytes + block_data_bytes, block_height, format.scale_data_type());
150 }
151
152 os << "]},\n";
153
154 data += block_offsets_bytes + block_data_bytes + block_scales_bytes;
155 }
156 os << "]\n";
157 }
158
159 } // namespace
160
161 void print_matrix(
162 std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width) {
163 os << name << " = ";
164
165 switch (format.pack_format()) {
166 case DataFormat::PackFormat::NONE:
167 print_matrix_raw(os, reinterpret_cast<const uint8_t*>(data), format, height, width);
168 break;
169
170 case DataFormat::PackFormat::BIAS_PER_ROW:
171 case DataFormat::PackFormat::QUANTIZE_PER_ROW:
172 print_matrix_per_row(os, reinterpret_cast<const uint8_t*>(data), format, height, width);
173 break;
174
175 default:
176 KAI_ERROR("Unsupported quantization packing format!");
177 }
178 }
179
180 } // namespace kai::test
181