KleidiAI Coverage Report


Directory: ./
File: kai/kai_common.h
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 100.0% 41 1 42
Functions: 100.0% 13 0 13
Branches: -% 0 2 2

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #pragma once
7
8 #if defined(__ARM_NEON)
9 #include <arm_neon.h>
10 #endif // defined(__ARM_NEON)
11
12 #include <stddef.h>
13 #include <stdint.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <string.h>
17
18 #ifdef __cplusplus
19 extern "C" {
20 #endif
21
22 // NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c)
23 //
24 // * cppcoreguidelines-avoid-do-while: do-while is necessary for macros.
25 // * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected.
26 // * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting.
27 #define KAI_ERROR(msg) \
28 do { \
29 fflush(stdout); \
30 fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \
31 abort(); \
32 } while (0)
33
34 #define KAI_ASSERT_MSG(cond, msg) \
35 do { \
36 if (!(cond)) { \
37 KAI_ERROR(msg); \
38 } \
39 } while (0)
40
41 // NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c)
42
43 #define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond)
44
45 #define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg)
46 #define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond)
47
48 #define KAI_ASSUME_MSG KAI_ASSERT_MSG
49 #define KAI_ASSUME KAI_ASSERT
50 #define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG
51 #define KAI_ASSUME_IF KAI_ASSERT_IF
52
53 #define KAI_UNUSED(x) (void)(x)
54 #define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b))
55 #define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b))
56
57 /// Largest supported SME vector length in bytes
58 #define KAI_SME_VEC_LENGTH_MAX_BYTES 256 // NOLINT(cppcoreguidelines-macro-to-enum,modernize-macro-to-enum)
59
60 /// Gets the version of the project in the Major.Minor.Patch semantic versioning format.
61 ///
62 /// @return Project version as a string literal.
63 1 inline const char* kai_get_version(void) {
64 1 return "1.15.1";
65 }
66
67 /// KleidiAI data types
68 /// Format: <byte 3>(reserved)|<byte 2>(num-bytes)|<byte 1>(type)|<byte 0>(variant-type)
69 enum kai_datatype {
70 kai_dt_unknown = 0x0000,
71 kai_dt_f32 = 0x0411,
72 kai_dt_f16 = 0x0212,
73 kai_dt_bf16 = 0x0213,
74 kai_dt_int32 = 0x0421,
75 kai_dt_int16 = 0x0222,
76 kai_dt_int8 = 0x0124,
77 kai_dt_uint32 = 0x0431,
78 kai_dt_uint16 = 0x0232,
79 kai_dt_uint8 = 0x0134,
80 kai_dt_bool = 0x0441
81 };
82
83 /// Gets number of bytes for a given data type
84 /// @param[in] dt KleidiAI data type
85 ///
86 /// @return the numbers of bytes for the data type
87 9126 inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) {
88 9126 return (size_t)(dt >> 8);
89 }
90
91 /// Converts a scalar f16 value to f32
92 /// @param[in] f16 The f16 value
93 ///
94 /// @return the f32 value
95 #if defined(__ARM_NEON)
96 inline static float kai_cast_f32_f16(uint16_t f16) {
97 float16_t f32 = 0;
98 memcpy(&f32, &f16, sizeof(uint16_t));
99 return (float)f32;
100 }
101 #endif
102
103 /// Converts a scalar bf16 value to f32
104 /// @param[in] bf16 The f16 value
105 ///
106 /// @return the f32 value
107 1685863882 inline static float kai_cast_f32_bf16(uint16_t bf16) {
108 1685863882 const uint32_t i32 = (bf16 << 16);
109 1685863882 float f32 = 0;
110 1685863882 memcpy(&f32, &i32, sizeof(i32));
111 3371727764 return f32;
112 1685863882 }
113
114 /// Converts a f32 value to bf16
115 /// @param[in] f32 The f32 value
116 ///
117 /// @return the bf16 value
118 128 inline static uint16_t kai_cast_bf16_f32(float f32) {
119 128 uint16_t bf16 = 0;
120 #ifdef __ARM_FEATURE_BF16
121 128 __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(bf16) : [input] "w"(f32));
122 #else
123 const uint32_t* i32 = (uint32_t*)(&f32);
124 bf16 = (*i32 >> 16);
125 #endif
126 256 return bf16;
127 128 }
128
129 /// Converts a scalar f32 value to f16
130 /// @param[in] f32 The f32 value
131 ///
132 /// @return the f16 value
133 #if defined(__ARM_NEON)
134 4540 inline static uint16_t kai_cast_f16_f32(float f32) {
135 4540 uint16_t f16 = 0;
136 4540 float16_t tmp = (float16_t)f32;
137 4540 memcpy(&f16, &tmp, sizeof(uint16_t));
138 9080 return f16;
139 4540 }
140 #endif
141
142 10736077 inline static size_t kai_roundup(size_t a, size_t b) {
143 10736077 return ((a + b - 1) / b) * b;
144 }
145
146 #if defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64)
147 /// Gets the SME vector length for 8-bit elements.
148 uint64_t kai_get_sme_vector_length_u8(void);
149
150 /// Gets the SME vector length for 16-bit elements.
151 230550 inline static uint64_t kai_get_sme_vector_length_u16(void) {
152 230550 return kai_get_sme_vector_length_u8() / 2;
153 }
154
155 /// Gets the SME vector length for 32-bit elements.
156 273336 inline static uint64_t kai_get_sme_vector_length_u32(void) {
157 273336 return kai_get_sme_vector_length_u8() / 4;
158 }
159 #endif // defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64)
160
161 /// Extends the sign bit of int 4-bit value (stored in int8_t variable)
162 /// @param[in] value The 4-bit int value
163 ///
164 /// @return the int8_t value with sign extended
165 5788242 inline static int8_t kai_ext_sign_i8_i4(int8_t value) {
166 // Make sure value holds correct int4 value
167 KAI_ASSERT(value <= 0xF);
168
169 5788242 return (value ^ 0x8) - 8; // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
170 }
171
172 /// Parameter struct for RHS matrix packing (Quantized Symmetric Integer 8-bit with per-channel quantization)
173 struct kai_rhs_pack_qsi8cx_params {
174 int32_t lhs_zero_point; ///< LHS Matrix quantization zero-point
175 float scale_multiplier; ///< Product of input (refers to lhs and rhs) and output quantization scales.
176 };
177
178 /// Parameter struct for RHS matrix packing (Quantized Symmetric Integer 4-bit with per-block quantizatio and s1s0
179 /// nibble ordering)
180 struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params {
181 int8_t lhs_zero_point;
182 uint8_t rhs_zero_point;
183 enum kai_datatype scale_dt;
184 };
185
186 /// Parameter struct for RHS matrix packing
187 struct kai_rhs_pack_qs4cxs1s0_param {
188 int8_t lhs_zero_point; ///< LHS Matrix quantization zero-point
189 uint8_t rhs_zero_point; ///< RHS Matrix quantization zero-point
190 };
191
192 /// Requantization and clamp parameters for GEMM/GEMV output stage.
193 struct kai_matmul_requantize32_params {
194 int32_t min_value; ///< Minimum output value.
195 int32_t max_value; ///< Maximum output value.
196 int32_t output_zero_point; ///< Output quantization zero point.
197 };
198
199 #ifdef __cplusplus
200 }
201 #endif
202