Line |
Branch |
Exec |
Source |
1 |
|
|
// |
2 |
|
|
// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> |
3 |
|
|
// |
4 |
|
|
// SPDX-License-Identifier: Apache-2.0 |
5 |
|
|
// |
6 |
|
|
#pragma once |
7 |
|
|
|
8 |
|
|
#if defined(__ARM_NEON) |
9 |
|
|
#include <arm_neon.h> |
10 |
|
|
#endif // defined(__ARM_NEON) |
11 |
|
|
|
12 |
|
|
#include <stddef.h> |
13 |
|
|
#include <stdint.h> |
14 |
|
|
#include <stdio.h> |
15 |
|
|
#include <stdlib.h> |
16 |
|
|
#include <string.h> |
17 |
|
|
|
18 |
|
|
#ifdef __cplusplus |
19 |
|
|
extern "C" { |
20 |
|
|
#endif |
21 |
|
|
|
22 |
|
|
// NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) |
23 |
|
|
// |
24 |
|
|
// * cppcoreguidelines-avoid-do-while: do-while is necessary for macros. |
25 |
|
|
// * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected. |
26 |
|
|
// * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting. |
27 |
|
|
/// Reports a fatal error and terminates the program.
///
/// Flushes stdout, writes "<file>:<line> <msg>" to stderr and then calls abort().
///
/// @param[in] msg Error message (C string) appended after the source location.
#define KAI_ERROR(msg)                                        \
    do {                                                      \
        fflush(stdout);                                       \
        fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \
        abort();                                              \
    } while (0)

/// Raises a fatal error with the given message when the condition does not hold.
///
/// @param[in] cond Condition expected to be true; it is evaluated exactly once.
/// @param[in] msg Error message passed to KAI_ERROR when the condition is false.
#define KAI_ASSERT_MSG(cond, msg) \
    do {                          \
        if (!(cond)) {            \
            KAI_ERROR(msg);       \
        }                         \
    } while (0)
40 |
|
|
|
41 |
|
|
// NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) |
42 |
|
|
|
43 |
|
|
/// Asserts that `cond` holds; the stringified condition is used as the error message.
#define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond) |
44 |
|
|
|
45 |
|
|
/// Asserts `cond` only when `precond` holds (checks `!(precond) || (cond)`), reporting `msg` on failure.
#define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg) |
46 |
|
|
/// Same as KAI_ASSERT_IF_MSG, with the message built from the stringified arguments joined by " |-> ".
#define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond) |
47 |
|
|
|
48 |
|
|
/// The KAI_ASSUME* macros below are plain aliases of the corresponding KAI_ASSERT* macros.
#define KAI_ASSUME_MSG KAI_ASSERT_MSG |
49 |
|
|
#define KAI_ASSUME KAI_ASSERT |
50 |
|
|
#define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG |
51 |
|
|
#define KAI_ASSUME_IF KAI_ASSERT_IF |
52 |
|
|
|
53 |
|
|
/// Explicitly discards `x` by casting it to void, marking it as intentionally unused.
#define KAI_UNUSED(x) (void)(x) |
54 |
|
|
/// Yields the smaller of `a` and `b`. One argument is expanded twice, so avoid arguments with side effects.
#define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b)) |
55 |
|
|
/// Yields the larger of `a` and `b`. One argument is expanded twice, so avoid arguments with side effects.
#define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b)) |
56 |
|
|
|
57 |
|
|
/// Largest supported SME vector length in bytes |
58 |
|
|
#define KAI_SME_VEC_LENGTH_MAX_BYTES 256 // NOLINT(cppcoreguidelines-macro-to-enum,modernize-macro-to-enum) |
59 |
|
|
|
60 |
|
|
/// Gets the version of the project in the Major.Minor.Patch semantic versioning format.
///
/// The function has internal linkage, like the other helpers in this header: a plain C99 `inline`
/// definition does not emit an external definition, so a call that is not inlined would fail to link
/// when this header is compiled as C.
///
/// @return Project version as a string literal.
inline static const char* kai_get_version(void) {
    return "1.15.1";
}
66 |
|
|
|
67 |
|
|
/// KleidiAI data types |
68 |
|
|
/// Format: <bits 15:8>(num-bytes)|<bits 7:4>(type)|<bits 3:0>(variant-type) |
/// Type values used by the enumerators below: 1 = floating point, 2 = signed integer, 3 = unsigned integer, 4 = bool.
69 |
|
|
enum kai_datatype { |
70 |
|
|
kai_dt_unknown = 0x0000, |
71 |
|
|
kai_dt_f32 = 0x0411, |
72 |
|
|
kai_dt_f16 = 0x0212, |
73 |
|
|
kai_dt_bf16 = 0x0213, |
74 |
|
|
kai_dt_int32 = 0x0421, |
75 |
|
|
kai_dt_int16 = 0x0222, |
76 |
|
|
kai_dt_int8 = 0x0124, |
77 |
|
|
kai_dt_uint32 = 0x0431, |
78 |
|
|
kai_dt_uint16 = 0x0232, |
79 |
|
|
kai_dt_uint8 = 0x0134, |
80 |
|
|
kai_dt_bool = 0x0441 |
81 |
|
|
}; |
82 |
|
|
|
83 |
|
|
/// Gets number of bytes for a given data type |
84 |
|
|
/// @param[in] dt KleidiAI data type |
85 |
|
|
/// |
86 |
|
|
/// @return the numbers of bytes for the data type |
87 |
|
9126 |
inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) { |
88 |
|
9126 |
return (size_t)(dt >> 8); |
89 |
|
|
} |
90 |
|
|
|
91 |
|
|
/// Converts a scalar f16 value to f32
/// @param[in] f16 Bit pattern of the f16 value
///
/// @return the f32 value
#if defined(__ARM_NEON)
inline static float kai_cast_f32_f16(uint16_t f16) {
    // Reinterpret the raw 16 bits as a half-precision float, then widen.
    float16_t half = 0;
    memcpy(&half, &f16, sizeof(f16));
    return (float)half;
}
#endif
102 |
|
|
|
103 |
|
|
/// Converts a scalar bf16 value to f32
/// @param[in] bf16 Bit pattern of the bf16 value
///
/// @return the f32 value
inline static float kai_cast_f32_bf16(uint16_t bf16) {
    // Widen before shifting: a uint16_t operand is promoted to (signed) int, and shifting a value
    // with bit 15 set into the sign bit of an int is undefined behavior.
    const uint32_t i32 = (uint32_t)bf16 << 16;
    float f32 = 0;
    // The bf16 bits become the upper half of the f32 bit pattern; the lower half is zero.
    memcpy(&f32, &i32, sizeof(i32));
    return f32;
}
113 |
|
|
|
114 |
|
|
/// Converts a f32 value to bf16
/// @param[in] f32 The f32 value
///
/// @return Bit pattern of the bf16 value
inline static uint16_t kai_cast_bf16_f32(float f32) {
    uint16_t bf16 = 0;
#ifdef __ARM_FEATURE_BF16
    __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(bf16) : [input] "w"(f32));
#else
    // Copy the bits out with memcpy: reading a float object through a uint32_t pointer
    // violates the strict-aliasing rules.
    uint32_t i32 = 0;
    memcpy(&i32, &f32, sizeof(i32));
    // Keep the upper 16 bits; the lower 16 bits are dropped without rounding.
    // NOTE(review): this may differ from the result of the bfcvt path for values that are not
    // exactly representable in bf16 - confirm whether that is intended.
    bf16 = (uint16_t)(i32 >> 16);
#endif
    return bf16;
}
128 |
|
|
|
129 |
|
|
/// Converts a scalar f32 value to f16
/// @param[in] f32 The f32 value
///
/// @return Bit pattern of the f16 value
#if defined(__ARM_NEON)
inline static uint16_t kai_cast_f16_f32(float f32) {
    // Narrow to half precision first, then copy out the raw 16 bits.
    const float16_t half = (float16_t)f32;
    uint16_t bits = 0;
    memcpy(&bits, &half, sizeof(bits));
    return bits;
}
#endif
141 |
|
|
|
142 |
|
10736077 |
/// Rounds a value up to the nearest multiple of another value.
///
/// The result is computed from the remainder, so no intermediate sum can wrap around when `a` is
/// close to SIZE_MAX and the rounded value itself is representable.
///
/// @param[in] a Value to round up.
/// @param[in] b Multiple to round up to. Must not be 0.
///
/// @return The smallest multiple of `b` that is greater than or equal to `a`.
inline static size_t kai_roundup(size_t a, size_t b) {
    const size_t remainder = a % b;
    return (remainder == 0) ? a : a + (b - remainder);
}
145 |
|
|
|
146 |
|
|
#if defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64) |
147 |
|
|
/// Gets the SME vector length for 8-bit elements. |
148 |
|
|
uint64_t kai_get_sme_vector_length_u8(void); |
149 |
|
|
|
150 |
|
|
/// Gets the SME vector length for 16-bit elements.
///
/// @return The vector length for 8-bit elements divided by 2.
inline static uint64_t kai_get_sme_vector_length_u16(void) {
    const uint64_t length_u8 = kai_get_sme_vector_length_u8();
    return length_u8 / 2;
}
154 |
|
|
|
155 |
|
|
/// Gets the SME vector length for 32-bit elements.
///
/// @return The vector length for 8-bit elements divided by 4.
inline static uint64_t kai_get_sme_vector_length_u32(void) {
    const uint64_t length_u8 = kai_get_sme_vector_length_u8();
    return length_u8 / 4;
}
159 |
|
|
#endif // defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64) |
160 |
|
|
|
161 |
|
|
/// Extends the sign bit of int 4-bit value (stored in int8_t variable)
/// @param[in] value The 4-bit int value
///
/// @return the int8_t value with sign extended
inline static int8_t kai_ext_sign_i8_i4(int8_t value) {
    // Make sure value holds correct int4 value
    // NOTE(review): negative inputs also satisfy this check - confirm whether a lower bound is intended.
    KAI_ASSERT(value <= 0xF);

    // Flipping bit 3 and subtracting 8 maps 0x0..0x7 to 0..7 and 0x8..0xF to -8..-1.
    const int flipped = value ^ 0x8;
    return flipped - 8;  // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
}
171 |
|
|
|
172 |
|
|
/// Parameter struct for RHS matrix packing (Quantized Symmetric Integer 8-bit with per-channel quantization). |
173 |
|
|
struct kai_rhs_pack_qsi8cx_params { |
174 |
|
|
int32_t lhs_zero_point; ///< LHS matrix quantization zero-point. |
175 |
|
|
float scale_multiplier; ///< Product of input (refers to LHS and RHS) and output quantization scales. |
176 |
|
|
}; |
177 |
|
|
|
178 |
|
|
/// Parameter struct for RHS matrix packing (Quantized Symmetric Integer 4-bit with per-block quantization and s1s0 |
179 |
|
|
/// nibble ordering) |
180 |
|
|
struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params { |
181 |
|
|
/// LHS matrix quantization zero-point.
int8_t lhs_zero_point; |
182 |
|
|
/// RHS matrix quantization zero-point.
uint8_t rhs_zero_point; |
183 |
|
|
/// Data type of the scales - presumably the per-block quantization scales; TODO confirm.
enum kai_datatype scale_dt; |
184 |
|
|
}; |
185 |
|
|
|
186 |
|
|
/// Parameter struct for RHS matrix packing, carrying the LHS and RHS quantization zero-points. |
/// NOTE(review): the struct name suggests 4-bit quantized data with s1s0 nibble ordering - confirm.
187 |
|
|
struct kai_rhs_pack_qs4cxs1s0_param { |
188 |
|
|
int8_t lhs_zero_point; ///< LHS matrix quantization zero-point. |
189 |
|
|
uint8_t rhs_zero_point; ///< RHS matrix quantization zero-point. |
190 |
|
|
}; |
191 |
|
|
|
192 |
|
|
/// Requantization and clamp parameters for the GEMM/GEMV output stage. |
193 |
|
|
struct kai_matmul_requantize32_params { |
194 |
|
|
int32_t min_value; ///< Minimum output value (lower clamp bound). |
195 |
|
|
int32_t max_value; ///< Maximum output value (upper clamp bound). |
196 |
|
|
int32_t output_zero_point; ///< Output quantization zero point. |
197 |
|
|
}; |
198 |
|
|
|
199 |
|
|
#ifdef __cplusplus |
200 |
|
|
} |
201 |
|
|
#endif |
202 |
|
|
|