KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 22 / 2 / 24
Functions: 89.6% 129 / 0 / 144
Branches: -% 0 / 28 / 28

test/common/abi_checker.hpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #pragma once
8
9 #include <cstdint>
10 #include <functional>
11 #include <optional>
12 #include <type_traits>
13 #include <utility>
14
15 #include "kai/kai_common.h"
16
17 namespace kai::test {
18
19 #if defined(__ARM_FEATURE_SME) && !_MSC_VER
20
21 /// Checker for FP ABI compliance
22 template <typename Func, typename... Args>
23 170062 inline auto abi_check_fp(Func&& func, Args&&... args)
24 -> decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...)) {
25 using ResultType = std::invoke_result_t<Func, Args...>;
26 using StorageType = std::conditional_t<std::is_void_v<ResultType>, int, ResultType>;
27
28 170062 std::optional<StorageType> result;
29 static constexpr const uint64_t canary = 0xAAAABBBBCCCCDDDDULL;
30
31 /* The block below attempts to verify that the callee-saved FP registers
32 * (d8-d15) are preserved as expected. GP registers cannot easily be verified
33 * using this method, as this function itself might change them */
34 170062 __asm__ __volatile__(
35 // Fill callee saved registers with canaries
36 "ldr x9, [%x[canary]]\n\t"
37
38 // FP registers
39 "fmov d8, x9\n\t"
40 "fmov d9, x9\n\t"
41 "fmov d10, x9\n\t"
42 "fmov d11, x9\n\t"
43 "fmov d12, x9\n\t"
44 "fmov d13, x9\n\t"
45 "fmov d14, x9\n\t"
46 "fmov d15, x9\n\t"
47 :
48 : [canary] "r"(&canary)
49 : "x9", // Canary storage
50 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15");
51
52 if constexpr (std::is_void_v<ResultType>) {
53 170062 std::invoke(std::forward<Func>(func), std::forward<Args>(args)...);
54 } else {
55 result = std::invoke(std::forward<Func>(func), std::forward<Args>(args)...);
56 }
57
58 170062 uint64_t first_mismatch = canary;
59 170062 __asm__ __volatile__(
60 // Check that canary is still present in all callee saved registers
61 "ldr x9, [%x[canary]]\n\t"
62
63 // Check FP registers
64 "11: fmov x10, d8\n\t"
65 "cmp x10, x9\n\t"
66 "b.eq 12f\n\t"
67 "mov %x[first_mismatch], #8\n\t"
68 "b 20f\n\t"
69
70 "12: fmov x10, d9\n\t"
71 "cmp x10, x9\n\t"
72 "b.eq 13f\n\t"
73 "mov %x[first_mismatch], #9\n\t"
74 "b 20f\n\t"
75
76 "13: fmov x10, d10\n\t"
77 "cmp x10, x9\n\t"
78 "b.eq 14f\n\t"
79 "mov %x[first_mismatch], #10\n\t"
80 "b 20f\n\t"
81
82 "14: fmov x10, d11\n\t"
83 "cmp x10, x9\n\t"
84 "b.eq 15f\n\t"
85 "mov %x[first_mismatch], #11\n\t"
86 "b 20f\n\t"
87
88 "15: fmov x10, d12\n\t"
89 "cmp x10, x9\n\t"
90 "b.eq 16f\n\t"
91 "mov %x[first_mismatch], #12\n\t"
92 "b 20f\n\t"
93
94 "16: fmov x10, d13\n\t"
95 "cmp x10, x9\n\t"
96 "b.eq 17f\n\t"
97 "mov %x[first_mismatch], #13\n\t"
98 "b 20f\n\t"
99
100 "17: fmov x10, d14\n\t"
101 "cmp x10, x9\n\t"
102 "b.eq 18f\n\t"
103 "mov %x[first_mismatch], #14\n\t"
104 "b 20f\n\t"
105
106 "18: fmov x10, d15\n\t"
107 "cmp x10, x9\n\t"
108 "b.eq 20f\n\t"
109 "mov %x[first_mismatch], #15\n\t"
110
111 "20:\n\t"
112 : [first_mismatch] "+r"(first_mismatch)
113 : [canary] "r"(&canary)
114 : "cc", "x9", "x10");
115
116 KAI_ASSERT_MSG(first_mismatch == canary, "FP register corruption detected");
117 if constexpr (!std::is_void_v<ResultType>) {
118 return *result;
119 }
120 170062 }
121
122 /// Checker for SME ABI compliance
123 template <typename Func, typename... Args>
124 170062 __arm_new("za") __arm_locally_streaming inline auto abi_check_za(Func&& func, Args&&... args)
125 -> decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...)) {
126 using ResultType = std::invoke_result_t<Func, Args...>;
127 using StorageType = std::conditional_t<std::is_void_v<ResultType>, int, ResultType>;
128
129 170062 std::optional<StorageType> result;
130 static constexpr const uint64_t canary = 0xAAAABBBBCCCCDDDDULL;
131
132 /* This block attempts to check that the ZA register is correctly preserved
133 * by filling it with a known pattern, and then checking the pattern after
134 * returning from the function call */
135 170062 __asm__ __volatile__(
136 "ldr x9, [%x[canary]]\n\t"
137
138 // Fill ZA with canary pattern
139 "dup z16.d, x9\n\t" // Broadcast canary to vector
140 "rdsvl x9, #1\n\t" // Read number of ZA rows
141 "ptrue p0.b\n\t" // Make p0.b fully enabled
142 "mov w12, wzr\n\t" // Set row index to 0
143 "1: mova za0h.b[w12, #0], p0/m, z16.b\n\t" // Copy vector to row
144 "add w12, w12, #1\n\t" // Increment row index
145 "cmp x12, x9\n\t" // Repeat until all rows are filled
146 "blt 1b\n\t"
147 :
148 : [canary] "r"(&canary)
149 : "cc",
150 "x9", // Canary storage, and then ZA row count
151 "x12", // ZA row index
152 "z16", // canary vector
153 "p0", // predicate
154 "za");
155
156 if constexpr (std::is_void_v<ResultType>) {
157 170062 abi_check_fp(std::forward<Func>(func), std::forward<Args>(args)...);
158 } else {
159 result = abi_check_fp<Func>(std::forward<Func>(func), std::forward<Args>(args)...);
160 }
161
162 170062 uint64_t first_mismatch = canary;
163 170062 __asm__ __volatile__(
164 // Check that canary is still present in ZA
165 "ldr x9, [%x[canary]]\n\t"
166
167 "dup z16.d, x9\n\t" // Broadcast canary to vector
168 "rdsvl x9, #1\n\t" // get rows of ZA
169 "ptrue p0.b\n\t" // Make p0.b fully enabled
170 "mov w12, wzr\n\t" // Clear w12
171 "20: mova z17.b, p0/m, za0h.b[w12, #0]\n\t" // Read row w12 from ZA
172 "cmpne p1.b, p0/z, z16.b, z17.b\n\t" // p1 = true for any mismatch
173 "cntp x10, p0, p1.b\n\t" // x10 = number of mismatches
174 "cmp x10, xzr\n\t" // if (mismatches == 0)
175 "b.eq 21f\n\t" // proceed
176 "mov %x[first_mismatch], x12\n\t" // else, store mismatching row
177 "b 30f\n\t" // and leave checker
178 "21: add w12, w12, #1\n\t" // w12 += 1
179 "cmp x12, x9\n\t" // if (w12 < SVL_b)
180 "blt 20b\n\t" // check next row
181 "30:\n\t"
182 : [first_mismatch] "+r"(first_mismatch)
183 : [canary] "r"(&canary)
184 : "cc",
185 "x9", // Canary storage, then ZA row count
186 "x10", // Row mismatch counter
187 "x12", // Row index
188 "z16", // Canary vector
189 "z17", // Current ZA row
190 "p0");
191 KAI_ASSERT_MSG(first_mismatch == canary, "ZA register corruption detected");
192
193 if constexpr (!std::is_void_v<ResultType>) {
194 return *result;
195 }
196 170062 }
197
198 /// Wrapper for checking ABI compliance
199 template <typename Func, typename... Args>
200 170062 inline auto abi_check(Func&& func, Args&&... args)
201 -> decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...)) {
202 using ResultType = std::invoke_result_t<Func, Args...>;
203 using StorageType = std::conditional_t<std::is_void_v<ResultType>, int, ResultType>;
204
205 170062 std::optional<StorageType> result;
206
207 if constexpr (std::is_void_v<ResultType>) {
208 170062 abi_check_za<Func>(std::forward<Func>(func), std::forward<Args>(args)...);
209 } else {
210 result = abi_check_za<Func>(std::forward<Func>(func), std::forward<Args>(args)...);
211 }
212
213 if constexpr (!std::is_void_v<ResultType>) {
214 return *result;
215 }
216 170062 }
217
218 #else
219
220 /// Call wrapped function, without any checking
221 template <typename Func, typename... Args>
222 70576 inline auto abi_check(Func&& func, Args&&... args)
223 -> decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...)) {
224 using ResultType = std::invoke_result_t<Func, Args...>;
225 using StorageType = std::conditional_t<std::is_void_v<ResultType>, int, ResultType>;
226
227 70576 std::optional<StorageType> result;
228
229 if constexpr (std::is_void_v<ResultType>) {
230 70576 std::invoke(std::forward<Func>(func), std::forward<Args>(args)...);
231 } else {
232 result = std::invoke(std::forward<Func>(func), std::forward<Args>(args)...);
233 }
234
235 if constexpr (!std::is_void_v<ResultType>) {
236 return *result;
237 }
238 70576 }
239
240 #endif
241
242 } // namespace kai::test
243