kai/kai_common.h
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | #pragma once | ||
| 7 | |||
| 8 | #if defined(__ARM_NEON) | ||
| 9 | #include <arm_neon.h> | ||
| 10 | #endif // defined(__ARM_NEON) | ||
| 11 | |||
| 12 | #include <stddef.h> | ||
| 13 | #include <stdint.h> | ||
| 14 | #include <stdio.h> | ||
| 15 | #include <stdlib.h> | ||
| 16 | #include <string.h> | ||
| 17 | |||
| 18 | #ifdef __cplusplus | ||
| 19 | extern "C" { | ||
| 20 | #endif | ||
| 21 | |||
| 22 | // NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) | ||
| 23 | // | ||
| 24 | // * cppcoreguidelines-avoid-do-while: do-while is necessary for macros. | ||
| 25 | // * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected. | ||
| 26 | // * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting. | ||
| 27 | |||
| 28 | #ifndef KLEIDIAI_ERROR_TRAP | ||
| 29 | #define KLEIDIAI_ERROR_TRAP 0 | ||
| 30 | #endif | ||
| 31 | |||
| 32 | #ifndef KLEIDIAI_HAS_BUILTIN_UNREACHABLE | ||
| 33 | #define KLEIDIAI_HAS_BUILTIN_UNREACHABLE 0 | ||
| 34 | #endif | ||
| 35 | |||
| 36 | #ifndef KLEIDIAI_HAS_BUILTIN_ASSUME0 | ||
| 37 | #define KLEIDIAI_HAS_BUILTIN_ASSUME0 0 | ||
| 38 | #endif | ||
| 39 | |||
| 40 | #if KLEIDIAI_ERROR_TRAP | ||
| 41 | #define KAI_ABORT() __builtin_trap() | ||
| 42 | #else | ||
| 43 | #define KAI_ABORT() abort() | ||
| 44 | #endif | ||
| 45 | |||
| 46 | #if KLEIDIAI_HAS_BUILTIN_UNREACHABLE | ||
| 47 | #define KAI_UNREACHABLE() __builtin_unreachable() | ||
| 48 | #elif KLEIDIAI_HAS_BUILTIN_ASSUME0 | ||
| 49 | #define KAI_UNREACHABLE() __assume(0); | ||
| 50 | #else | ||
| 51 | #define KAI_UNREACHABLE() | ||
| 52 | #endif | ||
| 53 | |||
| 54 | #ifdef NDEBUG | ||
| 55 | #define KAI_ERROR(msg) \ | ||
| 56 | do { \ | ||
| 57 | KAI_UNUSED(msg); \ | ||
| 58 | KAI_ABORT(); \ | ||
| 59 | } while (0) | ||
| 60 | |||
| 61 | #define KAI_ASSERT_MSG(cond, msg) \ | ||
| 62 | do { \ | ||
| 63 | KAI_UNUSED(msg); \ | ||
| 64 | if (!(cond)) { \ | ||
| 65 | KAI_UNREACHABLE(); \ | ||
| 66 | } \ | ||
| 67 | } while (0) | ||
| 68 | #else | ||
| 69 | #define KAI_ERROR(msg) \ | ||
| 70 | do { \ | ||
| 71 | fflush(stdout); \ | ||
| 72 | fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \ | ||
| 73 | KAI_ABORT(); \ | ||
| 74 | } while (0) | ||
| 75 | |||
| 76 | #define KAI_ASSERT_MSG(cond, msg) KAI_ASSERT_ALWAYS_MSG(cond, msg) | ||
| 77 | #endif // NDEBUG | ||
| 78 | |||
| 79 | #define KAI_ASSERT_ALWAYS_MSG(cond, msg) \ | ||
| 80 | do { \ | ||
| 81 | if (!(cond)) { \ | ||
| 82 | KAI_ERROR(msg); \ | ||
| 83 | } \ | ||
| 84 | } while (0) | ||
| 85 | |||
| 86 | // NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) | ||
| 87 | |||
| 88 | /// KAI_ASSERT* is used for logic sanity checking in the program | ||
| 89 | /// flow. Checks are optimized away in release builds, the same as | ||
| 90 | /// `assert`. | ||
| 91 | #define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond) | ||
| 92 | #define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg) | ||
| 93 | #define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond) | ||
| 94 | |||
| 95 | /// `KAI_ASSERT_ALWAYS*` is the same as `KAI_ASSERT*`, but is not removed when `NDEBUG` is defined | ||
| 96 | #define KAI_ASSERT_ALWAYS(cond) KAI_ASSERT_ALWAYS_MSG(cond, #cond) | ||
| 97 | #define KAI_ASSERT_ALWAYS_IF_MSG(precond, cond, msg) KAI_ASSERT_ALWAYS_MSG(!(precond) || (cond), msg) | ||
| 98 | #define KAI_ASSERT_ALWAYS_IF(precond, cond) KAI_ASSERT_ALWAYS_IF_MSG(precond, cond, #precond " |-> " #cond) | ||
| 99 | |||
| 100 | /// KAI_ASSUME* is used for function pre-condition checking, similar to `[[assume]]` in C++23. | ||
| 101 | /// So KAI_ASSUME should be used directly on the function parameters, rather than inside | ||
| 102 | /// function logic. | ||
| 103 | #define KAI_ASSUME_MSG KAI_ASSERT_MSG | ||
| 104 | #define KAI_ASSUME KAI_ASSERT | ||
| 105 | #define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG | ||
| 106 | #define KAI_ASSUME_IF KAI_ASSERT_IF | ||
| 107 | |||
| 108 | /// `KAI_ASSUME_ALWAYS*` is the same as `KAI_ASSUME*`, but is not removed when `NDEBUG` is defined | ||
| 109 | #define KAI_ASSUME_ALWAYS_MSG KAI_ASSERT_ALWAYS_MSG | ||
| 110 | #define KAI_ASSUME_ALWAYS KAI_ASSERT_ALWAYS | ||
| 111 | #define KAI_ASSUME_ALWAYS_IF_MSG KAI_ASSERT_ALWAYS_IF_MSG | ||
| 112 | #define KAI_ASSUME_ALWAYS_IF KAI_ASSERT_ALWAYS_IF | ||
| 113 | |||
| 114 | /// Indicates that the result of `x` is unused | ||
| 115 | #define KAI_UNUSED(x) (void)(x) | ||
| 116 | |||
| 117 | /// Return minimum or maximum of `a` and `b` | ||
| 118 | #define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b)) | ||
| 119 | #define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b)) | ||
| 120 | |||
| 121 | /// Largest supported SME vector length in bytes | ||
| 122 | #define KAI_SME_VEC_LENGTH_MAX_BYTES 256 // NOLINT(cppcoreguidelines-macro-to-enum,modernize-macro-to-enum) | ||
| 123 | |||
| 124 | /// Gets the version of the project in the Major.Minor.Patch semantic versioning format. | ||
| 125 | /// | ||
| 126 | /// @return Project version as a string literal. | ||
| 127 | 36 | inline const char* kai_get_version(void) { | |
| 128 | 36 | return "1.20.0"; | |
| 129 | } | ||
| 130 | |||
| 131 | /// KleidiAI data types | ||
| 132 | /// Format: <bytes 3-2>(reserved)|<byte 1>(num-bytes)|<byte 0, high nibble>(type)|<byte 0, low nibble>(variant-type) | ||
| 133 | enum kai_datatype { | ||
| 134 | kai_dt_unknown = 0x0000, | ||
| 135 | kai_dt_f32 = 0x0411, | ||
| 136 | kai_dt_f16 = 0x0212, | ||
| 137 | kai_dt_bf16 = 0x0213, | ||
| 138 | kai_dt_int32 = 0x0421, | ||
| 139 | kai_dt_int16 = 0x0222, | ||
| 140 | kai_dt_int8 = 0x0124, | ||
| 141 | kai_dt_uint32 = 0x0431, | ||
| 142 | kai_dt_uint16 = 0x0232, | ||
| 143 | kai_dt_uint8 = 0x0134, | ||
| 144 | kai_dt_bool = 0x0441 | ||
| 145 | }; | ||
| 146 | |||
| 147 | /// Gets number of bytes for a given data type | ||
| 148 | /// @param[in] dt KleidiAI data type | ||
| 149 | /// | ||
| 150 | /// @return the number of bytes for the data type | ||
| 151 | 20756 | inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) { | |
| 152 | 20756 | return (size_t)(dt >> 8); | |
| 153 | } | ||
| 154 | |||
| 155 | /// Converts a scalar f16 value to f32 | ||
| 156 | /// @param[in] f16 The f16 value | ||
| 157 | /// | ||
| 158 | /// @return the f32 value | ||
| 159 | #if defined(__ARM_NEON) | ||
| 160 | inline static float kai_cast_f32_f16(uint16_t f16) { | ||
| 161 | float16_t f32 = 0; | ||
| 162 | memcpy(&f32, &f16, sizeof(uint16_t)); | ||
| 163 | return (float)f32; | ||
| 164 | } | ||
| 165 | #endif | ||
| 166 | |||
| 167 | /// Converts a scalar bf16 value to f32 | ||
| 168 | /// @param[in] bf16 The bf16 value | ||
| 169 | /// | ||
| 170 | /// @return the f32 value | ||
| 171 | 5538698370 | inline static float kai_cast_f32_bf16(uint16_t bf16) { | |
| 172 | 5538698370 | const uint32_t i32 = (bf16 << 16); | |
| 173 | 5538698370 | float f32 = 0; | |
| 174 | 5538698370 | memcpy(&f32, &i32, sizeof(i32)); | |
| 175 | 46273448 | return f32; | |
| 176 | 5538698370 | } | |
| 177 | |||
| 178 | /// Converts a f32 value to bf16 | ||
| 179 | /// @param[in] f32 The f32 value | ||
| 180 | /// | ||
| 181 | /// @return the bf16 value | ||
| 182 | 816 | inline static uint16_t kai_cast_bf16_f32(float f32) { | |
| 183 | 816 | uint16_t bf16 = 0; | |
| 184 | #ifdef __ARM_FEATURE_BF16 | ||
| 185 | 816 | __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(bf16) : [input] "w"(f32)); | |
| 186 | #else | ||
| 187 | const uint32_t* i32 = (uint32_t*)(&f32); | ||
| 188 | bf16 = (*i32 >> 16); | ||
| 189 | #endif | ||
| 190 | 1632 | return bf16; | |
| 191 | 816 | } | |
| 192 | |||
| 193 | /// Converts a scalar f32 value to f16 | ||
| 194 | /// @param[in] f32 The f32 value | ||
| 195 | /// | ||
| 196 | /// @return the f16 value | ||
| 197 | #if defined(__ARM_NEON) | ||
| 198 | 24716 | inline static uint16_t kai_cast_f16_f32(float f32) { | |
| 199 | 24716 | uint16_t f16 = 0; | |
| 200 | 24716 | float16_t tmp = (float16_t)f32; | |
| 201 | 24716 | memcpy(&f16, &tmp, sizeof(uint16_t)); | |
| 202 | 49432 | return f16; | |
| 203 | 24716 | } | |
| 204 | #endif | ||
| 205 | |||
| 206 | 42620722 | inline static size_t kai_roundup(size_t a, size_t b) { | |
| 207 | 42620722 | return ((a + b - 1) / b) * b; | |
| 208 | } | ||
| 209 | |||
| 210 | #if defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64) | ||
| 211 | /// Gets the SME vector length for 8-bit elements. | ||
| 212 | uint64_t kai_get_sme_vector_length_u8(void); | ||
| 213 | |||
| 214 | /// Gets the SME vector length for 16-bit elements. | ||
| 215 | 105528 | inline static uint64_t kai_get_sme_vector_length_u16(void) { | |
| 216 | 105528 | return kai_get_sme_vector_length_u8() / 2; | |
| 217 | } | ||
| 218 | |||
| 219 | /// Gets the SME vector length for 32-bit elements. | ||
| 220 | 275409 | inline static uint64_t kai_get_sme_vector_length_u32(void) { | |
| 221 | 275409 | return kai_get_sme_vector_length_u8() / 4; | |
| 222 | } | ||
| 223 | |||
| 224 | /// Commit ZA to lazy save buffer | ||
| 225 | void kai_commit_za(void); | ||
| 226 | #endif // defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64) | ||
| 227 | |||
| 228 | /// Gets the SVE vector length for 8-bit elements. | ||
| 229 | uint64_t kai_get_sve_vector_length_u8(void); | ||
| 230 | |||
| 231 | /// Gets the SVE vector length for 16-bit elements. | ||
| 232 | ✗ | inline static uint64_t kai_get_sve_vector_length_u16(void) { | |
| 233 | ✗ | return kai_get_sve_vector_length_u8() / 2; | |
| 234 | } | ||
| 235 | |||
| 236 | /// Gets the SVE vector length for 32-bit elements. | ||
| 237 | 1838 | inline static uint64_t kai_get_sve_vector_length_u32(void) { | |
| 238 | 1838 | return kai_get_sve_vector_length_u8() / 4; | |
| 239 | } | ||
| 240 | |||
| 241 | /// Extends the sign bit of a 4-bit integer value (stored in an int8_t variable) | ||
| 242 | /// @param[in] value The 4-bit int value | ||
| 243 | /// | ||
| 244 | /// @return the int8_t value with sign extended | ||
| 245 | 34499574 | inline static int8_t kai_ext_sign_i8_i4(int8_t value) { | |
| 246 | // Make sure value holds correct int4 value | ||
| 247 | − | KAI_ASSERT(value <= 0xF); | |
| 248 | |||
| 249 | 34499574 | return (value ^ 0x8) - 8; // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | |
| 250 | } | ||
| 251 | |||
| 252 | /// Parameter struct for RHS matrix packing (Quantized Symmetric Integer 8-bit with per-channel quantization) | ||
| 253 | struct kai_rhs_pack_qsi8cx_params { | ||
| 254 | int32_t lhs_zero_point; ///< LHS Matrix quantization zero-point | ||
| 255 | float scale_multiplier; ///< Product of input (refers to lhs and rhs) and output quantization scales. | ||
| 256 | }; | ||
| 257 | |||
| 258 | /// Parameter struct for RHS matrix packing (Quantized Symmetric Integer 4-bit with per-block quantization and s1s0 | ||
| 259 | /// nibble ordering) | ||
| 260 | struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params { | ||
| 261 | int8_t lhs_zero_point; | ||
| 262 | uint8_t rhs_zero_point; | ||
| 263 | enum kai_datatype scale_dt; | ||
| 264 | }; | ||
| 265 | |||
| 266 | /// Parameter struct for RHS matrix packing (KxN variant for int4 qsi4c32p_qsu4c32s1s0) | ||
| 267 | struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params { | ||
| 268 | int8_t lhs_zero_point; | ||
| 269 | uint8_t rhs_zero_point; | ||
| 270 | enum kai_datatype scale_dt; | ||
| 271 | }; | ||
| 272 | |||
| 273 | /// Parameter struct for RHS matrix packing | ||
| 274 | struct kai_rhs_pack_qs4cxs1s0_param { | ||
| 275 | int8_t lhs_zero_point; ///< LHS Matrix quantization zero-point | ||
| 276 | uint8_t rhs_zero_point; ///< RHS Matrix quantization zero-point | ||
| 277 | }; | ||
| 278 | |||
| 279 | /// Requantization and clamp parameters for GEMM/GEMV output stage. | ||
| 280 | struct kai_matmul_requantize32_params { | ||
| 281 | int32_t min_value; ///< Minimum output value. | ||
| 282 | int32_t max_value; ///< Maximum output value. | ||
| 283 | int32_t output_zero_point; ///< Output quantization zero point. | ||
| 284 | }; | ||
| 285 | |||
| 286 | #ifdef __cplusplus | ||
| 287 | } | ||
| 288 | #endif | ||
| 289 |