KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 93.9% 31 / 1 / 34
Functions: 93.8% 15 / 0 / 16
Branches: -% 0 / 2 / 2

kai/kai_common.h
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #pragma once
7
8 #if defined(__ARM_NEON)
9 #include <arm_neon.h>
10 #endif // defined(__ARM_NEON)
11
12 #include <stddef.h>
13 #include <stdint.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <string.h>
17
18 #ifdef __cplusplus
19 extern "C" {
20 #endif
21
22 // NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c)
23 //
24 // * cppcoreguidelines-avoid-do-while: do-while is necessary for macros.
25 // * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected.
26 // * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting.
27
28 #ifndef KLEIDIAI_ERROR_TRAP
29 #define KLEIDIAI_ERROR_TRAP 0
30 #endif
31
32 #ifndef KLEIDIAI_HAS_BUILTIN_UNREACHABLE
33 #define KLEIDIAI_HAS_BUILTIN_UNREACHABLE 0
34 #endif
35
36 #ifndef KLEIDIAI_HAS_BUILTIN_ASSUME0
37 #define KLEIDIAI_HAS_BUILTIN_ASSUME0 0
38 #endif
39
40 #if KLEIDIAI_ERROR_TRAP
41 #define KAI_ABORT() __builtin_trap()
42 #else
43 #define KAI_ABORT() abort()
44 #endif
45
46 #if KLEIDIAI_HAS_BUILTIN_UNREACHABLE
47 #define KAI_UNREACHABLE() __builtin_unreachable()
48 #elif KLEIDIAI_HAS_BUILTIN_ASSUME0
49 #define KAI_UNREACHABLE() __assume(0);
50 #else
51 #define KAI_UNREACHABLE()
52 #endif
53
54 #ifdef NDEBUG
55 #define KAI_ERROR(msg) \
56 do { \
57 KAI_UNUSED(msg); \
58 KAI_ABORT(); \
59 } while (0)
60
61 #define KAI_ASSERT_MSG(cond, msg) \
62 do { \
63 KAI_UNUSED(msg); \
64 if (!(cond)) { \
65 KAI_UNREACHABLE(); \
66 } \
67 } while (0)
68 #else
69 #define KAI_ERROR(msg) \
70 do { \
71 fflush(stdout); \
72 fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \
73 KAI_ABORT(); \
74 } while (0)
75
76 #define KAI_ASSERT_MSG(cond, msg) KAI_ASSERT_ALWAYS_MSG(cond, msg)
77 #endif // NDEBUG
78
79 #define KAI_ASSERT_ALWAYS_MSG(cond, msg) \
80 do { \
81 if (!(cond)) { \
82 KAI_ERROR(msg); \
83 } \
84 } while (0)
85
86 // NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c)
87
88 /// KAI_ASSERT* is used for logic sanity checking in the program
89 /// flow. Checks are optimized away in release builds same as
90 /// `assert`
91 #define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond)
92 #define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg)
93 #define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond)
94
95 /// `KAI_ASSERT_ALWAYS*` is same as `KAI_ASSERT*`, but doesn't get removed by `NDEBUG`
96 #define KAI_ASSERT_ALWAYS(cond) KAI_ASSERT_ALWAYS_MSG(cond, #cond)
97 #define KAI_ASSERT_ALWAYS_IF_MSG(precond, cond, msg) KAI_ASSERT_ALWAYS_MSG(!(precond) || (cond), msg)
98 #define KAI_ASSERT_ALWAYS_IF(precond, cond) KAI_ASSERT_ALWAYS_IF_MSG(precond, cond, #precond " |-> " #cond)
99
100 /// KAI_ASSUME* is used for function pre-condition checking, similar to `[[assume]]` in C++23.
101 /// So KAI_ASSUME should be used directly on the function parameters, rather than inside
102 /// function logic.
103 #define KAI_ASSUME_MSG KAI_ASSERT_MSG
104 #define KAI_ASSUME KAI_ASSERT
105 #define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG
106 #define KAI_ASSUME_IF KAI_ASSERT_IF
107
108 /// `KAI_ASSUME_ALWAYS*` is same as `KAI_ASSUME*`, but doesn't get removed by `NDEBUG`
109 #define KAI_ASSUME_ALWAYS_MSG KAI_ASSERT_ALWAYS_MSG
110 #define KAI_ASSUME_ALWAYS KAI_ASSERT_ALWAYS
111 #define KAI_ASSUME_ALWAYS_IF_MSG KAI_ASSERT_ALWAYS_IF_MSG
112 #define KAI_ASSUME_ALWAYS_IF KAI_ASSERT_ALWAYS_IF
113
114 /// Indicate that result of `x` is unused
115 #define KAI_UNUSED(x) (void)(x)
116
117 /// Return minimum or maximum of `a` and `b`
118 #define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b))
119 #define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b))
120
121 /// Largest supported SME vector length in bytes
122 #define KAI_SME_VEC_LENGTH_MAX_BYTES 256 // NOLINT(cppcoreguidelines-macro-to-enum,modernize-macro-to-enum)
123
124 /// Gets the version of the project in the Major.Minor.Patch semantic versioning format.
125 ///
126 /// @return Project version as a string literal.
127 36 inline const char* kai_get_version(void) {
128 36 return "1.20.0";
129 }
130
131 /// KleidiAI data types
132 /// Format: <byte 3>(reserved)|<byte 2>(num-bytes)|<byte 1>(type)|<byte 0>(variant-type)
133 enum kai_datatype {
134 kai_dt_unknown = 0x0000,
135 kai_dt_f32 = 0x0411,
136 kai_dt_f16 = 0x0212,
137 kai_dt_bf16 = 0x0213,
138 kai_dt_int32 = 0x0421,
139 kai_dt_int16 = 0x0222,
140 kai_dt_int8 = 0x0124,
141 kai_dt_uint32 = 0x0431,
142 kai_dt_uint16 = 0x0232,
143 kai_dt_uint8 = 0x0134,
144 kai_dt_bool = 0x0441
145 };
146
147 /// Gets number of bytes for a given data type
148 /// @param[in] dt KleidiAI data type
149 ///
150 /// @return the numbers of bytes for the data type
151 20756 inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) {
152 20756 return (size_t)(dt >> 8);
153 }
154
155 /// Converts a scalar f16 value to f32
156 /// @param[in] f16 The f16 value
157 ///
158 /// @return the f32 value
159 #if defined(__ARM_NEON)
160 inline static float kai_cast_f32_f16(uint16_t f16) {
161 float16_t f32 = 0;
162 memcpy(&f32, &f16, sizeof(uint16_t));
163 return (float)f32;
164 }
165 #endif
166
167 /// Converts a scalar bf16 value to f32
168 /// @param[in] bf16 The f16 value
169 ///
170 /// @return the f32 value
171 5538698370 inline static float kai_cast_f32_bf16(uint16_t bf16) {
172 5538698370 const uint32_t i32 = (bf16 << 16);
173 5538698370 float f32 = 0;
174 5538698370 memcpy(&f32, &i32, sizeof(i32));
175 46273448 return f32;
176 5538698370 }
177
178 /// Converts a f32 value to bf16
179 /// @param[in] f32 The f32 value
180 ///
181 /// @return the bf16 value
182 816 inline static uint16_t kai_cast_bf16_f32(float f32) {
183 816 uint16_t bf16 = 0;
184 #ifdef __ARM_FEATURE_BF16
185 816 __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(bf16) : [input] "w"(f32));
186 #else
187 const uint32_t* i32 = (uint32_t*)(&f32);
188 bf16 = (*i32 >> 16);
189 #endif
190 1632 return bf16;
191 816 }
192
193 /// Converts a scalar f32 value to f16
194 /// @param[in] f32 The f32 value
195 ///
196 /// @return the f16 value
197 #if defined(__ARM_NEON)
198 24716 inline static uint16_t kai_cast_f16_f32(float f32) {
199 24716 uint16_t f16 = 0;
200 24716 float16_t tmp = (float16_t)f32;
201 24716 memcpy(&f16, &tmp, sizeof(uint16_t));
202 49432 return f16;
203 24716 }
204 #endif
205
206 42620722 inline static size_t kai_roundup(size_t a, size_t b) {
207 42620722 return ((a + b - 1) / b) * b;
208 }
209
210 #if defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64)
211 /// Gets the SME vector length for 8-bit elements.
212 uint64_t kai_get_sme_vector_length_u8(void);
213
214 /// Gets the SME vector length for 16-bit elements.
215 105528 inline static uint64_t kai_get_sme_vector_length_u16(void) {
216 105528 return kai_get_sme_vector_length_u8() / 2;
217 }
218
219 /// Gets the SME vector length for 32-bit elements.
220 275409 inline static uint64_t kai_get_sme_vector_length_u32(void) {
221 275409 return kai_get_sme_vector_length_u8() / 4;
222 }
223
224 /// Commit ZA to lazy save buffer
225 void kai_commit_za(void);
226 #endif // defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64)
227
228 /// Gets the SVE vector length for 8-bit elements.
229 uint64_t kai_get_sve_vector_length_u8(void);
230
231 /// Gets the SVE vector length for 16-bit elements.
232 inline static uint64_t kai_get_sve_vector_length_u16(void) {
233 return kai_get_sve_vector_length_u8() / 2;
234 }
235
236 /// Gets the SVE vector length for 32-bit elements.
237 1838 inline static uint64_t kai_get_sve_vector_length_u32(void) {
238 1838 return kai_get_sve_vector_length_u8() / 4;
239 }
240
241 /// Extends the sign bit of int 4-bit value (stored in int8_t variable)
242 /// @param[in] value The 4-bit int value
243 ///
244 /// @return the int8_t value with sign extended
245 34499574 inline static int8_t kai_ext_sign_i8_i4(int8_t value) {
246 // Make sure value holds correct int4 value
247 KAI_ASSERT(value <= 0xF);
248
249 34499574 return (value ^ 0x8) - 8; // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
250 }
251
252 /// Parameter struct for RHS matrix packing (Quantized Symmetric Integer 8-bit with per-channel quantization)
253 struct kai_rhs_pack_qsi8cx_params {
254 int32_t lhs_zero_point; ///< LHS Matrix quantization zero-point
255 float scale_multiplier; ///< Product of input (refers to lhs and rhs) and output quantization scales.
256 };
257
258 /// Parameter struct for RHS matrix packing (Quantized Symmetric Integer 4-bit with per-block quantizatio and s1s0
259 /// nibble ordering)
260 struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params {
261 int8_t lhs_zero_point;
262 uint8_t rhs_zero_point;
263 enum kai_datatype scale_dt;
264 };
265
266 /// Parameter struct for RHS matrix packing (KxN variant for int4 qsi4c32p_qsu4c32s1s0)
267 struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params {
268 int8_t lhs_zero_point;
269 uint8_t rhs_zero_point;
270 enum kai_datatype scale_dt;
271 };
272
273 /// Parameter struct for RHS matrix packing
274 struct kai_rhs_pack_qs4cxs1s0_param {
275 int8_t lhs_zero_point; ///< LHS Matrix quantization zero-point
276 uint8_t rhs_zero_point; ///< RHS Matrix quantization zero-point
277 };
278
279 /// Requantization and clamp parameters for GEMM/GEMV output stage.
280 struct kai_matmul_requantize32_params {
281 int32_t min_value; ///< Minimum output value.
282 int32_t max_value; ///< Maximum output value.
283 int32_t output_zero_point; ///< Output quantization zero point.
284 };
285
286 #ifdef __cplusplus
287 }
288 #endif
289