Line |
Branch |
Exec |
Source |
1 |
|
|
// |
2 |
|
|
// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> |
3 |
|
|
// |
4 |
|
|
// SPDX-License-Identifier: Apache-2.0 |
5 |
|
|
// |
6 |
|
|
|
7 |
|
|
#pragma once |
8 |
|
|
|
9 |
|
|
#include <cstdint> |
10 |
|
|
#include <iosfwd> |
11 |
|
|
|
12 |
|
|
#include "test/common/type_traits.hpp" |
13 |
|
|
|
14 |
|
|
extern "C" {

/// Converts single-precision floating-point to half-precision floating-point.
///
/// @param[in] value The single-precision floating-point value.
///
/// @return The half-precision floating-point value reinterpreted as 16-bit unsigned integer.
uint16_t kai_test_float16_from_float(float value);

/// Converts half-precision floating-point to single-precision floating-point.
///
/// @param[in] value The half-precision floating-point value reinterpreted as 16-bit unsigned integer.
///
/// @return The single-precision floating-point value.
float kai_test_float_from_float16(uint16_t value);

/// Adds two half-precision floating-point numbers.
///
/// All half-precision floating-point values are reinterpreted as 16-bit unsigned integer.
///
/// @param[in] lhs The LHS operand.
/// @param[in] rhs The RHS operand.
///
/// @return The result of the addition.
uint16_t kai_test_float16_add(uint16_t lhs, uint16_t rhs);

/// Subtracts two half-precision floating-point numbers.
///
/// All half-precision floating-point values are reinterpreted as 16-bit unsigned integer.
///
/// @param[in] lhs The LHS operand.
/// @param[in] rhs The RHS operand.
///
/// @return The result of the subtraction.
uint16_t kai_test_float16_sub(uint16_t lhs, uint16_t rhs);

/// Multiplies two half-precision floating-point numbers.
///
/// All half-precision floating-point values are reinterpreted as 16-bit unsigned integer.
///
/// @param[in] lhs The LHS operand.
/// @param[in] rhs The RHS operand.
///
/// @return The result of the multiplication.
uint16_t kai_test_float16_mul(uint16_t lhs, uint16_t rhs);

/// Divides two half-precision floating-point numbers.
///
/// All half-precision floating-point values are reinterpreted as 16-bit unsigned integer.
///
/// @param[in] lhs The LHS operand.
/// @param[in] rhs The RHS operand.
///
/// @return The result of the division.
uint16_t kai_test_float16_div(uint16_t lhs, uint16_t rhs);

/// Determines whether the first operand is less than the second operand.
///
/// Both operands are half-precision floating-point reinterpreted as 16-bit unsigned integers.
///
/// @param[in] lhs The LHS operand.
/// @param[in] rhs The RHS operand.
///
/// @return `true` if the first operand is less than the second operand, otherwise `false`.
bool kai_test_float16_lt(uint16_t lhs, uint16_t rhs);

/// Determines whether the first operand is greater than the second operand.
///
/// Both operands are half-precision floating-point reinterpreted as 16-bit unsigned integers.
///
/// @param[in] lhs The LHS operand.
/// @param[in] rhs The RHS operand.
///
/// @return `true` if the first operand is greater than the second operand, otherwise `false`.
bool kai_test_float16_gt(uint16_t lhs, uint16_t rhs);

}  // extern "C"
91 |
|
|
|
92 |
|
|
namespace kai::test { |
93 |
|
|
|
94 |
|
|
/// Half-precision floating-point. |
95 |
|
|
class Float16 { |
96 |
|
|
public: |
97 |
|
|
/// Constructor. |
98 |
|
1810869321 |
constexpr Float16() = default; |
99 |
|
|
|
100 |
|
|
/// Creates a new half-precision floating-point value from the specified |
101 |
|
|
/// single-precision floating-point value. |
102 |
|
|
/// |
103 |
|
|
/// @param[in] value The single-precision floating-point value. |
104 |
|
103717472 |
explicit Float16(float value) : m_data(kai_test_float16_from_float(value)) { |
105 |
|
103717472 |
} |
106 |
|
|
|
107 |
|
|
/// Creates a new half-precision floating-point value from the raw data. |
108 |
|
|
/// |
109 |
|
|
/// @param[in] data The binary representation of the floating-point value. |
110 |
|
|
/// |
111 |
|
|
/// @return The half-precision floating-point value. |
112 |
|
|
static constexpr Float16 from_binary(uint16_t data) { |
113 |
|
|
Float16 value{}; |
114 |
|
|
value.m_data = data; |
115 |
|
|
return value; |
116 |
|
|
} |
117 |
|
|
|
118 |
|
|
/// Assigns to the specified numeric value. |
119 |
|
|
template <typename T, std::enable_if_t<is_arithmetic<T>, bool> = true> |
120 |
|
|
Float16& operator=(T value) { |
121 |
|
|
const auto value_f32 = static_cast<float>(value); |
122 |
|
|
m_data = kai_test_float16_from_float(value_f32); |
123 |
|
|
return *this; |
124 |
|
|
} |
125 |
|
|
|
126 |
|
|
/// Converts to single-precision floating-point. |
127 |
|
114399181 |
explicit operator float() const { |
128 |
|
114399181 |
return kai_test_float_from_float16(m_data); |
129 |
|
|
} |
130 |
|
|
|
131 |
|
|
/// Addition and assignment operator. |
132 |
|
600720322 |
Float16& operator+=(Float16 rhs) { |
133 |
|
600720322 |
m_data = kai_test_float16_add(m_data, rhs.m_data); |
134 |
|
600720322 |
return *this; |
135 |
|
|
} |
136 |
|
|
|
137 |
|
|
/// Subtraction and assignment operator. |
138 |
|
1 |
Float16& operator-=(Float16 rhs) { |
139 |
|
1 |
m_data = kai_test_float16_sub(m_data, rhs.m_data); |
140 |
|
1 |
return *this; |
141 |
|
|
} |
142 |
|
|
|
143 |
|
|
/// Multiplication and assignment operator. |
144 |
|
1 |
Float16& operator*=(Float16 rhs) { |
145 |
|
1 |
m_data = kai_test_float16_mul(m_data, rhs.m_data); |
146 |
|
1 |
return *this; |
147 |
|
|
} |
148 |
|
|
|
149 |
|
|
/// Division and assignment operator. |
150 |
|
1 |
Float16& operator/=(Float16 rhs) { |
151 |
|
1 |
m_data = kai_test_float16_div(m_data, rhs.m_data); |
152 |
|
1 |
return *this; |
153 |
|
|
} |
154 |
|
|
|
155 |
|
|
private: |
156 |
|
|
/// Addition operator. |
157 |
|
2902782 |
[[nodiscard]] friend Float16 operator+(Float16 lhs, Float16 rhs) { |
158 |
|
2902782 |
Float16 value; |
159 |
|
2902782 |
value.m_data = kai_test_float16_add(lhs.m_data, rhs.m_data); |
160 |
|
2902782 |
return value; |
161 |
|
|
} |
162 |
|
|
|
163 |
|
|
/// Subtraction operator. |
164 |
|
1 |
[[nodiscard]] friend Float16 operator-(Float16 lhs, Float16 rhs) { |
165 |
|
1 |
Float16 value; |
166 |
|
1 |
value.m_data = kai_test_float16_sub(lhs.m_data, rhs.m_data); |
167 |
|
1 |
return value; |
168 |
|
|
} |
169 |
|
|
|
170 |
|
|
/// Multiplication operator. |
171 |
|
600720322 |
[[nodiscard]] friend Float16 operator*(Float16 lhs, Float16 rhs) { |
172 |
|
600720322 |
Float16 value; |
173 |
|
600720322 |
value.m_data = kai_test_float16_mul(lhs.m_data, rhs.m_data); |
174 |
|
600720322 |
return value; |
175 |
|
|
} |
176 |
|
|
|
177 |
|
|
/// Division operator. |
178 |
|
1 |
[[nodiscard]] friend Float16 operator/(Float16 lhs, Float16 rhs) { |
179 |
|
1 |
Float16 value; |
180 |
|
1 |
value.m_data = kai_test_float16_div(lhs.m_data, rhs.m_data); |
181 |
|
1 |
return value; |
182 |
|
|
} |
183 |
|
|
|
184 |
|
|
/// Equality operator. |
185 |
|
11 |
[[nodiscard]] friend bool operator==(Float16 lhs, Float16 rhs) { |
186 |
|
11 |
return lhs.m_data == rhs.m_data; |
187 |
|
|
} |
188 |
|
|
|
189 |
|
|
/// Unequality operator. |
190 |
|
3 |
[[nodiscard]] friend bool operator!=(Float16 lhs, Float16 rhs) { |
191 |
|
3 |
return lhs.m_data != rhs.m_data; |
192 |
|
|
} |
193 |
|
|
|
194 |
|
|
/// Less operator. |
195 |
|
6 |
[[nodiscard]] friend bool operator<(Float16 lhs, Float16 rhs) { |
196 |
|
6 |
return kai_test_float16_lt(lhs.m_data, rhs.m_data); |
197 |
|
|
} |
198 |
|
|
|
199 |
|
|
/// Greater operator. |
200 |
|
6 |
[[nodiscard]] friend bool operator>(Float16 lhs, Float16 rhs) { |
201 |
|
6 |
return kai_test_float16_gt(lhs.m_data, rhs.m_data); |
202 |
|
|
} |
203 |
|
|
|
204 |
|
|
/// Less-or-equal operator. |
205 |
|
3 |
[[nodiscard]] friend bool operator<=(Float16 lhs, Float16 rhs) { |
206 |
|
3 |
return !(lhs > rhs); |
207 |
|
|
} |
208 |
|
|
|
209 |
|
|
/// Greater-or-equal operator. |
210 |
|
3 |
[[nodiscard]] friend bool operator>=(Float16 lhs, Float16 rhs) { |
211 |
|
3 |
return !(lhs < rhs); |
212 |
|
|
} |
213 |
|
|
|
214 |
|
603623107 |
uint16_t m_data{0}; |
215 |
|
|
}; |
216 |
|
|
|
217 |
|
|
/// Writes the value to the output stream. |
218 |
|
|
/// |
219 |
|
|
/// @param[in] os Output stream to be written to. |
220 |
|
|
/// @param[in] value Value to be written. |
221 |
|
|
/// |
222 |
|
|
/// @return The output stream. |
223 |
|
|
std::ostream& operator<<(std::ostream& os, Float16 value); |
224 |
|
|
|
225 |
|
|
} // namespace kai::test |
226 |
|
|
|