test/common/float16.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #pragma once | ||
| 8 | |||
| 9 | #include <cstdint> | ||
| 10 | #include <iosfwd> | ||
| 11 | |||
| 12 | #include "test/common/type_traits.hpp" | ||
| 13 | |||
| 14 | extern "C" { | ||
| 15 | |||
| 16 | /// Converts single-precision floating-point to half-precision floating-point. | ||
| 17 | /// | ||
| 18 | /// @param[in] value The single-precision floating-point value. | ||
| 19 | /// | ||
| 20 | /// @return The half-precision floating-point value reinterpreted as a 16-bit unsigned integer. | ||
| 21 | uint16_t kai_test_float16_from_float(float value); | ||
| 22 | |||
| 23 | /// Converts half-precision floating-point to single-precision floating-point. | ||
| 24 | /// | ||
| 25 | /// @param[in] value The half-precision floating-point value reinterpreted as a 16-bit unsigned integer. | ||
| 26 | /// | ||
| 27 | /// @return The single-precision floating-point value. | ||
| 28 | float kai_test_float_from_float16(uint16_t value); | ||
| 29 | |||
| 30 | /// Adds two half-precision floating-point numbers. | ||
| 31 | /// | ||
| 32 | /// All half-precision floating-point values are reinterpreted as 16-bit unsigned integers. | ||
| 33 | /// | ||
| 34 | /// @param[in] lhs The LHS operand. | ||
| 35 | /// @param[in] rhs The RHS operand. | ||
| 36 | /// | ||
| 37 | /// @return The result of the addition. | ||
| 38 | uint16_t kai_test_float16_add(uint16_t lhs, uint16_t rhs); | ||
| 39 | |||
| 40 | /// Subtracts two half-precision floating-point numbers. | ||
| 41 | /// | ||
| 42 | /// All half-precision floating-point values are reinterpreted as 16-bit unsigned integers. | ||
| 43 | /// | ||
| 44 | /// @param[in] lhs The LHS operand. | ||
| 45 | /// @param[in] rhs The RHS operand. | ||
| 46 | /// | ||
| 47 | /// @return The result of the subtraction. | ||
| 48 | uint16_t kai_test_float16_sub(uint16_t lhs, uint16_t rhs); | ||
| 49 | |||
| 50 | /// Multiplies two half-precision floating-point numbers. | ||
| 51 | /// | ||
| 52 | /// All half-precision floating-point values are reinterpreted as 16-bit unsigned integers. | ||
| 53 | /// | ||
| 54 | /// @param[in] lhs The LHS operand. | ||
| 55 | /// @param[in] rhs The RHS operand. | ||
| 56 | /// | ||
| 57 | /// @return The result of the multiplication. | ||
| 58 | uint16_t kai_test_float16_mul(uint16_t lhs, uint16_t rhs); | ||
| 59 | |||
| 60 | /// Divides two half-precision floating-point numbers. | ||
| 61 | /// | ||
| 62 | /// All half-precision floating-point values are reinterpreted as 16-bit unsigned integers. | ||
| 63 | /// | ||
| 64 | /// @param[in] lhs The LHS operand. | ||
| 65 | /// @param[in] rhs The RHS operand. | ||
| 66 | /// | ||
| 67 | /// @return The result of the division. | ||
| 68 | uint16_t kai_test_float16_div(uint16_t lhs, uint16_t rhs); | ||
| 69 | |||
| 70 | /// Determines whether the first operand is less than the second operand. | ||
| 71 | /// | ||
| 72 | /// Both operands are half-precision floating-point reinterpreted as 16-bit unsigned integers. | ||
| 73 | /// | ||
| 74 | /// @param[in] lhs The LHS operand. | ||
| 75 | /// @param[in] rhs The RHS operand. | ||
| 76 | /// | ||
| 77 | /// @return `true` if the first operand is less than the second operand, otherwise `false`. | ||
| 78 | bool kai_test_float16_lt(uint16_t lhs, uint16_t rhs); | ||
| 79 | |||
| 80 | /// Determines whether the first operand is greater than the second operand. | ||
| 81 | /// | ||
| 82 | /// Both operands are half-precision floating-point reinterpreted as 16-bit unsigned integers. | ||
| 83 | /// | ||
| 84 | /// @param[in] lhs The LHS operand. | ||
| 85 | /// @param[in] rhs The RHS operand. | ||
| 86 | /// | ||
| 87 | /// @return `true` if the first operand is greater than the second operand, otherwise `false`. | ||
| 88 | bool kai_test_float16_gt(uint16_t lhs, uint16_t rhs); | ||
| 89 | |||
| 90 | } // extern "C" | ||
| 91 | |||
| 92 | namespace kai::test { | ||
| 93 | |||
| 94 | /// Half-precision floating-point. | ||
| 95 | class Float16 { | ||
| 96 | public: | ||
| 97 | /// Default constructor. The value is initialized to positive zero (all bits clear). | ||
| 98 | 25 | constexpr Float16() = default; | |
| 99 | |||
| 100 | /// Creates a new half-precision floating-point value from the specified | ||
| 101 | /// single-precision floating-point value. | ||
| 102 | /// | ||
| 103 | /// @param[in] value The single-precision floating-point value. | ||
| 104 | 109855622 | explicit Float16(float value) : m_data(kai_test_float16_from_float(value)) { | |
| 105 | 109855622 | } | |
| 106 | |||
| 107 | /// Creates a new half-precision floating-point value from the raw data. | ||
| 108 | /// | ||
| 109 | /// @param[in] data The binary representation of the floating-point value. | ||
| 110 | /// | ||
| 111 | /// @return The half-precision floating-point value. | ||
| 112 | static constexpr Float16 from_binary(uint16_t data) { | ||
| 113 | Float16 value{}; | ||
| 114 | value.m_data = data; | ||
| 115 | return value; | ||
| 116 | } | ||
| 117 | |||
| 118 | /// Assigns the specified arithmetic value, converting it to half precision via single precision. | ||
| 119 | template <typename T, std::enable_if_t<is_arithmetic<T>, bool> = true> | ||
| 120 | Float16& operator=(T value) { | ||
| 121 | const auto value_f32 = static_cast<float>(value); | ||
| 122 | m_data = kai_test_float16_from_float(value_f32); | ||
| 123 | return *this; | ||
| 124 | } | ||
| 125 | |||
| 126 | /// Converts to single-precision floating-point. | ||
| 127 | 712667014 | explicit operator float() const { | |
| 128 | 712667014 | return kai_test_float_from_float16(m_data); | |
| 129 | } | ||
| 130 | |||
| 131 | /// Addition and assignment operator. | ||
| 132 | 2294 | Float16& operator+=(Float16 rhs) { | |
| 133 | 2294 | m_data = kai_test_float16_add(m_data, rhs.m_data); | |
| 134 | 2294 | return *this; | |
| 135 | } | ||
| 136 | |||
| 137 | /// Subtraction and assignment operator. | ||
| 138 | 2 | Float16& operator-=(Float16 rhs) { | |
| 139 | 2 | m_data = kai_test_float16_sub(m_data, rhs.m_data); | |
| 140 | 2 | return *this; | |
| 141 | } | ||
| 142 | |||
| 143 | /// Multiplication and assignment operator. | ||
| 144 | 2 | Float16& operator*=(Float16 rhs) { | |
| 145 | 2 | m_data = kai_test_float16_mul(m_data, rhs.m_data); | |
| 146 | 2 | return *this; | |
| 147 | } | ||
| 148 | |||
| 149 | /// Division and assignment operator. | ||
| 150 | 2 | Float16& operator/=(Float16 rhs) { | |
| 151 | 2 | m_data = kai_test_float16_div(m_data, rhs.m_data); | |
| 152 | 2 | return *this; | |
| 153 | } | ||
| 154 | |||
| 155 | private: | ||
| 156 | /// Addition operator. | ||
| 157 | 2 | [[nodiscard]] friend Float16 operator+(Float16 lhs, Float16 rhs) { | |
| 158 | 2 | Float16 value; | |
| 159 | 2 | value.m_data = kai_test_float16_add(lhs.m_data, rhs.m_data); | |
| 160 | 2 | return value; | |
| 161 | } | ||
| 162 | |||
| 163 | /// Subtraction operator. | ||
| 164 | 2 | [[nodiscard]] friend Float16 operator-(Float16 lhs, Float16 rhs) { | |
| 165 | 2 | Float16 value; | |
| 166 | 2 | value.m_data = kai_test_float16_sub(lhs.m_data, rhs.m_data); | |
| 167 | 2 | return value; | |
| 168 | } | ||
| 169 | |||
| 170 | /// Multiplication operator. | ||
| 171 | 2 | [[nodiscard]] friend Float16 operator*(Float16 lhs, Float16 rhs) { | |
| 172 | 2 | Float16 value; | |
| 173 | 2 | value.m_data = kai_test_float16_mul(lhs.m_data, rhs.m_data); | |
| 174 | 2 | return value; | |
| 175 | } | ||
| 176 | |||
| 177 | /// Division operator. | ||
| 178 | 2 | [[nodiscard]] friend Float16 operator/(Float16 lhs, Float16 rhs) { | |
| 179 | 2 | Float16 value; | |
| 180 | 2 | value.m_data = kai_test_float16_div(lhs.m_data, rhs.m_data); | |
| 181 | 2 | return value; | |
| 182 | } | ||
| 183 | |||
| 184 | /// Equality operator. | ||
| 185 | 22 | [[nodiscard]] friend bool operator==(Float16 lhs, Float16 rhs) { | |
| 186 | 22 | return lhs.m_data == rhs.m_data; | |
| 187 | } | ||
| 188 | |||
| 189 | /// Inequality operator. | ||
| 190 | 6 | [[nodiscard]] friend bool operator!=(Float16 lhs, Float16 rhs) { | |
| 191 | 6 | return lhs.m_data != rhs.m_data; | |
| 192 | } | ||
| 193 | |||
| 194 | /// Less-than operator. | ||
| 195 | 12 | [[nodiscard]] friend bool operator<(Float16 lhs, Float16 rhs) { | |
| 196 | 12 | return kai_test_float16_lt(lhs.m_data, rhs.m_data); | |
| 197 | } | ||
| 198 | |||
| 199 | /// Greater-than operator. | ||
| 200 | 12 | [[nodiscard]] friend bool operator>(Float16 lhs, Float16 rhs) { | |
| 201 | 12 | return kai_test_float16_gt(lhs.m_data, rhs.m_data); | |
| 202 | } | ||
| 203 | |||
| 204 | /// Less-or-equal operator. | ||
| 205 | 6 | [[nodiscard]] friend bool operator<=(Float16 lhs, Float16 rhs) { | |
| 206 | 6 | return !(lhs > rhs); | |
| 207 | } | ||
| 208 | |||
| 209 | /// Greater-or-equal operator. | ||
| 210 | 6 | [[nodiscard]] friend bool operator>=(Float16 lhs, Float16 rhs) { | |
| 211 | 6 | return !(lhs < rhs); | |
| 212 | } | ||
| 213 | |||
| 214 | 10 | uint16_t m_data{0}; | |
| 215 | }; | ||
| 216 | |||
| 217 | /// Writes the value to the output stream. | ||
| 218 | /// | ||
| 219 | /// @param[in] os Output stream to be written to. | ||
| 220 | /// @param[in] value Value to be written. | ||
| 221 | /// | ||
| 222 | /// @return The output stream. | ||
| 223 | std::ostream& operator<<(std::ostream& os, Float16 value); | ||
| 224 | |||
| 225 | } // namespace kai::test | ||
| 226 |