KleidiAI Coverage Report


Directory: ./
File: test/common/printer.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 0.0% 0 3 110
Functions: 0.0% 0 0 4
Branches: 0.0% 0 2 53

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/common/printer.hpp"
8
9 #include <cstddef>
10 #include <cstdint>
11 #include <ostream>
12 #include <string_view>
13
14 #include "kai/kai_common.h"
15 #include "test/common/bfloat16.hpp"
16 #include "test/common/data_format.hpp"
17 #include "test/common/data_type.hpp"
18 #include "test/common/float16.hpp"
19 #include "test/common/int4.hpp"
20
21 namespace kai::test {
22
23 namespace {
24
25 inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataType data_type) {
26 if (data_type == DataType::QSU4) {
27 for (size_t i = 0; i < len / 2; ++i) {
28 const auto [low, high] = UInt4::unpack_u8(data[i]);
29 os << static_cast<int32_t>(low) << ", " << static_cast<int32_t>(high) << ", ";
30 }
31 } else if (data_type == DataType::QSI4) {
32 for (size_t i = 0; i < len / 2; ++i) {
33 const auto [low, high] = Int4::unpack_u8(data[i]);
34 os << static_cast<int32_t>(low) << ", " << static_cast<int32_t>(high) << ", ";
35 }
36 } else {
37 for (size_t i = 0; i < len; ++i) {
38 switch (data_type) {
39 case DataType::FP32:
40 os << reinterpret_cast<const float*>(data)[i];
41 break;
42
43 case DataType::FP16:
44 os << reinterpret_cast<const Float16*>(data)[i];
45 break;
46
47 case DataType::BF16:
48 os << reinterpret_cast<const BFloat16<>*>(data)[i];
49 break;
50
51 case DataType::I32:
52 os << reinterpret_cast<const int32_t*>(data)[i];
53 break;
54
55 case DataType::QAI8:
56 os << static_cast<int32_t>(reinterpret_cast<const int8_t*>(data)[i]);
57 break;
58
59 default:
60 KAI_ERROR("Unsupported data type!");
61 }
62
63 os << ", ";
64 }
65 }
66 }
67
68 void print_matrix_raw(std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) {
69 const auto data_type = format.data_type();
70 const auto esize_bits = data_type_size_in_bits(data_type);
71 const auto block_height = format.actual_block_height(height);
72 const auto block_width = format.actual_block_width(width);
73 const auto subblock_height = format.actual_subblock_height(height);
74 const auto subblock_width = format.actual_subblock_width(width);
75
76 os << "[\n";
77 for (size_t y_block = 0; y_block < height; y_block += block_height) {
78 if (block_height != height) {
79 os << " [\n";
80 }
81
82 for (size_t x_block = 0; x_block < width; x_block += block_width) {
83 if (block_width != width) {
84 os << " [\n";
85 }
86
87 for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
88 if (subblock_height != block_height) {
89 os << " [\n";
90 }
91
92 for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
93 if (subblock_width != block_width) {
94 os << " [\n";
95 }
96
97 for (size_t y = 0; y < subblock_height; ++y) {
98 os << " [";
99 print_data(os, data, subblock_width, data_type);
100 data += subblock_width * esize_bits / 8;
101 os << "],\n";
102 }
103
104 if (subblock_width != block_width) {
105 os << " ]\n";
106 }
107 }
108
109 if (subblock_height != block_height) {
110 os << " ]\n";
111 }
112 }
113
114 if (block_width != width) {
115 os << " ],\n";
116 }
117 }
118
119 if (block_height != height) {
120 os << " ],\n";
121 }
122 }
123 os << "]\n";
124 }
125
126 void print_matrix_per_row(
127 std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) {
128 const auto has_scale = format.pack_format() == DataFormat::PackFormat::QUANTIZE_PER_ROW;
129
130 const auto block_height = format.actual_block_height(height);
131
132 const auto num_blocks = (height + block_height - 1) / block_height;
133
134 KAI_ASSUME(format.default_size_in_bytes(height, width) % num_blocks == 0);
135 const auto block_data_bytes = block_height * width * data_type_size_in_bits(format.data_type()) / 8;
136 const auto block_offsets_bytes = block_height * data_type_size_in_bits(format.zero_point_data_type()) / 8;
137 const auto block_scales_bytes = has_scale ? block_height * data_type_size_in_bits(format.scale_data_type()) / 8 : 0;
138
139 os << "[\n";
140 for (size_t y = 0; y < num_blocks; ++y) {
141 os << " {\"offsets\": [";
142 print_data(os, data, block_height, format.zero_point_data_type());
143 os << "], \"data\": [";
144 print_data(os, data + block_offsets_bytes, block_height * width, format.data_type());
145
146 if (has_scale) {
147 os << "], \"scales\": [";
148 print_data(os, data + block_offsets_bytes + block_data_bytes, block_height, format.scale_data_type());
149 }
150
151 os << "]},\n";
152
153 data += block_offsets_bytes + block_data_bytes + block_scales_bytes;
154 }
155 os << "]\n";
156 }
157
158 } // namespace
159
160 void print_matrix(
161 std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width) {
162 os << name << " = ";
163
164 switch (format.pack_format()) {
165 case DataFormat::PackFormat::NONE:
166 print_matrix_raw(os, reinterpret_cast<const uint8_t*>(data), format, height, width);
167 break;
168
169 case DataFormat::PackFormat::BIAS_PER_ROW:
170 case DataFormat::PackFormat::QUANTIZE_PER_ROW:
171 print_matrix_per_row(os, reinterpret_cast<const uint8_t*>(data), format, height, width);
172 break;
173
174 default:
175 KAI_ERROR("Unsupported quantization packing format!");
176 }
177 }
178
179 } // namespace kai::test
180