Line |
Branch |
Exec |
Source |
1 |
|
|
// |
2 |
|
|
// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> |
3 |
|
|
// |
4 |
|
|
// SPDX-License-Identifier: Apache-2.0 |
5 |
|
|
// |
6 |
|
|
|
7 |
|
|
#include "test/common/printer.hpp" |
8 |
|
|
|
9 |
|
|
#include <cstddef> |
10 |
|
|
#include <cstdint> |
11 |
|
|
#include <ostream> |
12 |
|
|
#include <string_view> |
13 |
|
|
|
14 |
|
|
#include "kai/kai_common.h" |
15 |
|
|
#include "test/common/bfloat16.hpp" |
16 |
|
|
#include "test/common/data_format.hpp" |
17 |
|
|
#include "test/common/data_type.hpp" |
18 |
|
|
#include "test/common/float16.hpp" |
19 |
|
|
#include "test/common/int4.hpp" |
20 |
|
|
|
21 |
|
|
namespace kai::test { |
22 |
|
|
|
23 |
|
|
namespace { |
24 |
|
|
|
25 |
|
✗ |
inline void print_data(std::ostream& os, const uint8_t* data, size_t len, DataType data_type) { |
26 |
|
✗ |
if (data_type == DataType::QSU4) { |
27 |
|
✗ |
for (size_t i = 0; i < len / 2; ++i) { |
28 |
|
✗ |
const auto [low, high] = UInt4::unpack_u8(data[i]); |
29 |
|
✗ |
os << static_cast<int32_t>(low) << ", " << static_cast<int32_t>(high) << ", "; |
30 |
|
✗ |
} |
31 |
|
✗ |
} else if (data_type == DataType::QSI4) { |
32 |
|
✗ |
for (size_t i = 0; i < len / 2; ++i) { |
33 |
|
✗ |
const auto [low, high] = Int4::unpack_u8(data[i]); |
34 |
|
✗ |
os << static_cast<int32_t>(low) << ", " << static_cast<int32_t>(high) << ", "; |
35 |
|
✗ |
} |
36 |
|
✗ |
} else { |
37 |
|
✗ |
for (size_t i = 0; i < len; ++i) { |
38 |
|
✗ |
switch (data_type) { |
39 |
|
|
case DataType::FP32: |
40 |
|
✗ |
os << reinterpret_cast<const float*>(data)[i]; |
41 |
|
✗ |
break; |
42 |
|
|
|
43 |
|
|
case DataType::FP16: |
44 |
|
✗ |
os << reinterpret_cast<const Float16*>(data)[i]; |
45 |
|
✗ |
break; |
46 |
|
|
|
47 |
|
|
case DataType::BF16: |
48 |
|
✗ |
os << reinterpret_cast<const BFloat16<>*>(data)[i]; |
49 |
|
✗ |
break; |
50 |
|
|
|
51 |
|
|
case DataType::I32: |
52 |
|
✗ |
os << reinterpret_cast<const int32_t*>(data)[i]; |
53 |
|
✗ |
break; |
54 |
|
|
|
55 |
|
|
case DataType::QAI8: |
56 |
|
✗ |
os << static_cast<int32_t>(reinterpret_cast<const int8_t*>(data)[i]); |
57 |
|
✗ |
break; |
58 |
|
|
|
59 |
|
|
default: |
60 |
|
− |
KAI_ERROR("Unsupported data type!"); |
61 |
|
✗ |
} |
62 |
|
|
|
63 |
|
✗ |
os << ", "; |
64 |
|
✗ |
} |
65 |
|
|
} |
66 |
|
✗ |
} |
67 |
|
|
|
68 |
|
✗ |
void print_matrix_raw(std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) { |
69 |
|
✗ |
const auto data_type = format.data_type(); |
70 |
|
✗ |
const auto esize_bits = data_type_size_in_bits(data_type); |
71 |
|
✗ |
const auto block_height = format.actual_block_height(height); |
72 |
|
✗ |
const auto block_width = format.actual_block_width(width); |
73 |
|
✗ |
const auto subblock_height = format.actual_subblock_height(height); |
74 |
|
✗ |
const auto subblock_width = format.actual_subblock_width(width); |
75 |
|
|
|
76 |
|
✗ |
os << "[\n"; |
77 |
|
✗ |
for (size_t y_block = 0; y_block < height; y_block += block_height) { |
78 |
|
✗ |
if (block_height != height) { |
79 |
|
✗ |
os << " [\n"; |
80 |
|
✗ |
} |
81 |
|
|
|
82 |
|
✗ |
for (size_t x_block = 0; x_block < width; x_block += block_width) { |
83 |
|
✗ |
if (block_width != width) { |
84 |
|
✗ |
os << " [\n"; |
85 |
|
✗ |
} |
86 |
|
|
|
87 |
|
✗ |
for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) { |
88 |
|
✗ |
if (subblock_height != block_height) { |
89 |
|
✗ |
os << " [\n"; |
90 |
|
✗ |
} |
91 |
|
|
|
92 |
|
✗ |
for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) { |
93 |
|
✗ |
if (subblock_width != block_width) { |
94 |
|
✗ |
os << " [\n"; |
95 |
|
✗ |
} |
96 |
|
|
|
97 |
|
✗ |
for (size_t y = 0; y < subblock_height; ++y) { |
98 |
|
✗ |
os << " ["; |
99 |
|
✗ |
print_data(os, data, subblock_width, data_type); |
100 |
|
✗ |
data += subblock_width * esize_bits / 8; |
101 |
|
✗ |
os << "],\n"; |
102 |
|
✗ |
} |
103 |
|
|
|
104 |
|
✗ |
if (subblock_width != block_width) { |
105 |
|
✗ |
os << " ]\n"; |
106 |
|
✗ |
} |
107 |
|
✗ |
} |
108 |
|
|
|
109 |
|
✗ |
if (subblock_height != block_height) { |
110 |
|
✗ |
os << " ]\n"; |
111 |
|
✗ |
} |
112 |
|
✗ |
} |
113 |
|
|
|
114 |
|
✗ |
if (block_width != width) { |
115 |
|
✗ |
os << " ],\n"; |
116 |
|
✗ |
} |
117 |
|
✗ |
} |
118 |
|
|
|
119 |
|
✗ |
if (block_height != height) { |
120 |
|
✗ |
os << " ],\n"; |
121 |
|
✗ |
} |
122 |
|
✗ |
} |
123 |
|
✗ |
os << "]\n"; |
124 |
|
✗ |
} |
125 |
|
|
|
126 |
|
✗ |
void print_matrix_per_row( |
127 |
|
|
std::ostream& os, const uint8_t* data, const DataFormat& format, size_t height, size_t width) { |
128 |
|
✗ |
const auto has_scale = format.pack_format() == DataFormat::PackFormat::QUANTIZE_PER_ROW; |
129 |
|
|
|
130 |
|
✗ |
const auto block_height = format.actual_block_height(height); |
131 |
|
|
|
132 |
|
✗ |
const auto num_blocks = (height + block_height - 1) / block_height; |
133 |
|
|
|
134 |
|
− |
KAI_ASSUME(format.default_size_in_bytes(height, width) % num_blocks == 0); |
135 |
|
✗ |
const auto block_data_bytes = block_height * width * data_type_size_in_bits(format.data_type()) / 8; |
136 |
|
✗ |
const auto block_offsets_bytes = block_height * data_type_size_in_bits(format.zero_point_data_type()) / 8; |
137 |
|
✗ |
const auto block_scales_bytes = has_scale ? block_height * data_type_size_in_bits(format.scale_data_type()) / 8 : 0; |
138 |
|
|
|
139 |
|
✗ |
os << "[\n"; |
140 |
|
✗ |
for (size_t y = 0; y < num_blocks; ++y) { |
141 |
|
✗ |
os << " {\"offsets\": ["; |
142 |
|
✗ |
print_data(os, data, block_height, format.zero_point_data_type()); |
143 |
|
✗ |
os << "], \"data\": ["; |
144 |
|
✗ |
print_data(os, data + block_offsets_bytes, block_height * width, format.data_type()); |
145 |
|
|
|
146 |
|
✗ |
if (has_scale) { |
147 |
|
✗ |
os << "], \"scales\": ["; |
148 |
|
✗ |
print_data(os, data + block_offsets_bytes + block_data_bytes, block_height, format.scale_data_type()); |
149 |
|
✗ |
} |
150 |
|
|
|
151 |
|
✗ |
os << "]},\n"; |
152 |
|
|
|
153 |
|
✗ |
data += block_offsets_bytes + block_data_bytes + block_scales_bytes; |
154 |
|
✗ |
} |
155 |
|
✗ |
os << "]\n"; |
156 |
|
✗ |
} |
157 |
|
|
|
158 |
|
|
} // namespace |
159 |
|
|
|
160 |
|
✗ |
void print_matrix( |
161 |
|
|
std::ostream& os, std::string_view name, const void* data, const DataFormat& format, size_t height, size_t width) { |
162 |
|
✗ |
os << name << " = "; |
163 |
|
|
|
164 |
|
✗ |
switch (format.pack_format()) { |
165 |
|
|
case DataFormat::PackFormat::NONE: |
166 |
|
✗ |
print_matrix_raw(os, reinterpret_cast<const uint8_t*>(data), format, height, width); |
167 |
|
✗ |
break; |
168 |
|
|
|
169 |
|
|
case DataFormat::PackFormat::BIAS_PER_ROW: |
170 |
|
|
case DataFormat::PackFormat::QUANTIZE_PER_ROW: |
171 |
|
✗ |
print_matrix_per_row(os, reinterpret_cast<const uint8_t*>(data), format, height, width); |
172 |
|
✗ |
break; |
173 |
|
|
|
174 |
|
|
default: |
175 |
|
− |
KAI_ERROR("Unsupported quantization packing format!"); |
176 |
|
✗ |
} |
177 |
|
✗ |
} |
178 |
|
|
|
179 |
|
|
} // namespace kai::test |
180 |
|
|
|