test/common/abi_checker.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #pragma once | ||
| 8 | |||
| 9 | #include <cstdint> | ||
| 10 | #include <functional> | ||
| 11 | #include <optional> | ||
| 12 | #include <type_traits> | ||
| 13 | #include <utility> | ||
| 14 | |||
| 15 | #include "kai/kai_common.h" | ||
| 16 | |||
| 17 | namespace kai::test { | ||
| 18 | |||
| 19 | #if defined(__ARM_FEATURE_SME) && !_MSC_VER | ||
| 20 | |||
| 21 | /// Checker for FP ABI compliance: fills d8-d15 with a canary, invokes func(args...), then uses KAI_ASSERT_MSG to flag any of those registers that no longer holds the canary. Returns the result of func, if any. | ||
| 22 | template <typename Func, typename... Args> | ||
| 23 | 170062 | inline auto abi_check_fp(Func&& func, Args&&... args) | |
| 24 | -> decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...)) { | ||
| 25 | using ResultType = std::invoke_result_t<Func, Args...>; | ||
| 26 | using StorageType = std::conditional_t<std::is_void_v<ResultType>, int, ResultType>; | ||
| 27 | |||
| 28 | 170062 | std::optional<StorageType> result; | |
| 29 | static constexpr const uint64_t canary = 0xAAAABBBBCCCCDDDDULL; | ||
| 30 | |||
| 31 | /* The asm blocks in this function attempt to verify that the FP registers | ||
| 32 | * d8-d15 are preserved across the call. GP registers cannot easily be verified | ||
| 33 | * with this method, because this function itself might change them */ | ||
| 34 | 170062 | __asm__ __volatile__( | |
| 35 | // Load the canary value into x9, then copy it into each checked register | ||
| 36 | "ldr x9, [%x[canary]]\n\t" | ||
| 37 | |||
| 38 | // FP registers d8-d15 | ||
| 39 | "fmov d8, x9\n\t" | ||
| 40 | "fmov d9, x9\n\t" | ||
| 41 | "fmov d10, x9\n\t" | ||
| 42 | "fmov d11, x9\n\t" | ||
| 43 | "fmov d12, x9\n\t" | ||
| 44 | "fmov d13, x9\n\t" | ||
| 45 | "fmov d14, x9\n\t" | ||
| 46 | "fmov d15, x9\n\t" | ||
| 47 | : | ||
| 48 | : [canary] "r"(&canary) | ||
| 49 | : "x9", // Canary storage | ||
| 50 | "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15"); | ||
| 51 | |||
| 52 | if constexpr (std::is_void_v<ResultType>) { | ||
| 53 | 170062 | std::invoke(std::forward<Func>(func), std::forward<Args>(args)...); | |
| 54 | } else { | ||
| 55 | result = std::invoke(std::forward<Func>(func), std::forward<Args>(args)...); | ||
| 56 | } | ||
| 57 | |||
| 58 | 170062 | uint64_t first_mismatch = canary; | |
| 59 | 170062 | __asm__ __volatile__( | |
| 60 | // Reload the canary and compare it against d8-d15 in order; on the first mismatch, first_mismatch receives the register number (8-15) and the remaining checks are skipped. If nothing mismatches, first_mismatch keeps its initial value (the canary) | ||
| 61 | "ldr x9, [%x[canary]]\n\t" | ||
| 62 | |||
| 63 | // Check FP registers, one compare-and-branch group per register | ||
| 64 | "11: fmov x10, d8\n\t" | ||
| 65 | "cmp x10, x9\n\t" | ||
| 66 | "b.eq 12f\n\t" | ||
| 67 | "mov %x[first_mismatch], #8\n\t" | ||
| 68 | "b 20f\n\t" | ||
| 69 | |||
| 70 | "12: fmov x10, d9\n\t" | ||
| 71 | "cmp x10, x9\n\t" | ||
| 72 | "b.eq 13f\n\t" | ||
| 73 | "mov %x[first_mismatch], #9\n\t" | ||
| 74 | "b 20f\n\t" | ||
| 75 | |||
| 76 | "13: fmov x10, d10\n\t" | ||
| 77 | "cmp x10, x9\n\t" | ||
| 78 | "b.eq 14f\n\t" | ||
| 79 | "mov %x[first_mismatch], #10\n\t" | ||
| 80 | "b 20f\n\t" | ||
| 81 | |||
| 82 | "14: fmov x10, d11\n\t" | ||
| 83 | "cmp x10, x9\n\t" | ||
| 84 | "b.eq 15f\n\t" | ||
| 85 | "mov %x[first_mismatch], #11\n\t" | ||
| 86 | "b 20f\n\t" | ||
| 87 | |||
| 88 | "15: fmov x10, d12\n\t" | ||
| 89 | "cmp x10, x9\n\t" | ||
| 90 | "b.eq 16f\n\t" | ||
| 91 | "mov %x[first_mismatch], #12\n\t" | ||
| 92 | "b 20f\n\t" | ||
| 93 | |||
| 94 | "16: fmov x10, d13\n\t" | ||
| 95 | "cmp x10, x9\n\t" | ||
| 96 | "b.eq 17f\n\t" | ||
| 97 | "mov %x[first_mismatch], #13\n\t" | ||
| 98 | "b 20f\n\t" | ||
| 99 | |||
| 100 | "17: fmov x10, d14\n\t" | ||
| 101 | "cmp x10, x9\n\t" | ||
| 102 | "b.eq 18f\n\t" | ||
| 103 | "mov %x[first_mismatch], #14\n\t" | ||
| 104 | "b 20f\n\t" | ||
| 105 | |||
| 106 | "18: fmov x10, d15\n\t" | ||
| 107 | "cmp x10, x9\n\t" | ||
| 108 | "b.eq 20f\n\t" | ||
| 109 | "mov %x[first_mismatch], #15\n\t" | ||
| 110 | |||
| 111 | "20:\n\t" | ||
| 112 | : [first_mismatch] "+r"(first_mismatch) | ||
| 113 | : [canary] "r"(&canary) | ||
| 114 | : "cc", "x9", "x10"); | ||
| 115 | |||
| 116 | − | KAI_ASSERT_MSG(first_mismatch == canary, "FP register corruption detected"); | |
| 117 | if constexpr (!std::is_void_v<ResultType>) { | ||
| 118 | return *result; | ||
| 119 | } | ||
| 120 | 170062 | } | |
| 121 | |||
| 122 | /// Checker for SME ABI compliance: fills every ZA row with a canary pattern, runs func through abi_check_fp, then uses KAI_ASSERT_MSG to flag any ZA row that no longer matches the pattern. Returns the result of func, if any. | ||
| 123 | template <typename Func, typename... Args> | ||
| 124 | 170062 | __arm_new("za") __arm_locally_streaming inline auto abi_check_za(Func&& func, Args&&... args) | |
| 125 | -> decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...)) { | ||
| 126 | using ResultType = std::invoke_result_t<Func, Args...>; | ||
| 127 | using StorageType = std::conditional_t<std::is_void_v<ResultType>, int, ResultType>; | ||
| 128 | |||
| 129 | 170062 | std::optional<StorageType> result; | |
| 130 | static constexpr const uint64_t canary = 0xAAAABBBBCCCCDDDDULL; | ||
| 131 | |||
| 132 | /* This block attempts to check that the ZA register is correctly preserved | ||
| 133 | * by filling it with a known pattern, and then checking the pattern after | ||
| 134 | * returning from the function call */ | ||
| 135 | 170062 | __asm__ __volatile__( | |
| 136 | "ldr x9, [%x[canary]]\n\t" | ||
| 137 | |||
| 138 | // Fill ZA with canary pattern | ||
| 139 | "dup z16.d, x9\n\t" // Broadcast canary to vector | ||
| 140 | "rdsvl x9, #1\n\t" // Read number of ZA rows | ||
| 141 | "ptrue p0.b\n\t" // Make p0.b fully enabled | ||
| 142 | "mov w12, wzr\n\t" // Set row index to 0 | ||
| 143 | "1: mova za0h.b[w12, #0], p0/m, z16.b\n\t" // Copy vector to row | ||
| 144 | "add w12, w12, #1\n\t" // Increment row index | ||
| 145 | "cmp x12, x9\n\t" // Repeat until all rows are filled | ||
| 146 | "blt 1b\n\t" | ||
| 147 | : | ||
| 148 | : [canary] "r"(&canary) | ||
| 149 | : "cc", | ||
| 150 | "x9", // Canary storage, and then ZA row count | ||
| 151 | "x12", // ZA row index | ||
| 152 | "z16", // canary vector | ||
| 153 | "p0", // predicate | ||
| 154 | "za"); | ||
| 155 | |||
| 156 | if constexpr (std::is_void_v<ResultType>) { | ||
| 157 | 170062 | abi_check_fp(std::forward<Func>(func), std::forward<Args>(args)...); | |
| 158 | } else { | ||
| 159 | result = abi_check_fp<Func>(std::forward<Func>(func), std::forward<Args>(args)...); | ||
| 160 | } | ||
| 161 | |||
| 162 | 170062 | uint64_t first_mismatch = canary; | |
| 163 | 170062 | __asm__ __volatile__( | |
| 164 | // Check that canary is still present in every ZA row; on the first mismatch, first_mismatch receives the row index and the loop is left early | ||
| 165 | "ldr x9, [%x[canary]]\n\t" | ||
| 166 | |||
| 167 | "dup z16.d, x9\n\t" // Broadcast canary to vector | ||
| 168 | "rdsvl x9, #1\n\t" // get rows of ZA | ||
| 169 | "ptrue p0.b\n\t" // Make p0.b fully enabled | ||
| 170 | "mov w12, wzr\n\t" // Set row index to 0 | ||
| 171 | "20: mova z17.b, p0/m, za0h.b[w12, #0]\n\t" // Read row w12 from ZA | ||
| 172 | "cmpne p1.b, p0/z, z16.b, z17.b\n\t" // p1 = true for any mismatch | ||
| 173 | "cntp x10, p0, p1.b\n\t" // x10 = number of mismatches | ||
| 174 | "cmp x10, xzr\n\t" // if (mismatches == 0) | ||
| 175 | "b.eq 21f\n\t" // proceed | ||
| 176 | "mov %x[first_mismatch], x12\n\t" // else, store mismatching row | ||
| 177 | "b 30f\n\t" // and leave checker | ||
| 178 | "21: add w12, w12, #1\n\t" // w12 += 1 | ||
| 179 | "cmp x12, x9\n\t" // if (w12 < SVL_b) | ||
| 180 | "blt 20b\n\t" // check next row | ||
| 181 | "30:\n\t" | ||
| 182 | : [first_mismatch] "+r"(first_mismatch) | ||
| 183 | : [canary] "r"(&canary) | ||
| 184 | : "cc", | ||
| 185 | "x9", // Canary storage, then ZA row count | ||
| 186 | "x10", // Row mismatch counter | ||
| 187 | "x12", // Row index | ||
| 188 | "z16", // Canary vector | ||
| 189 | "z17", // Current ZA row | ||
| 190 | "p0", "p1", "za"); // p0: all-true predicate; p1: mismatch predicate written by cmpne; za: the asm reads ZA rows via mova | ||
| 191 | − | KAI_ASSERT_MSG(first_mismatch == canary, "ZA register corruption detected"); | |
| 192 | |||
| 193 | if constexpr (!std::is_void_v<ResultType>) { | ||
| 194 | return *result; | ||
| 195 | } | ||
| 196 | 170062 | } | |
| 197 | |||
| 198 | /// Wrapper for checking ABI compliance: forwards func and args to abi_check_za (which in turn runs them through abi_check_fp) and returns the result of func, if any. | ||
| 199 | template <typename Func, typename... Args> | ||
| 200 | 170062 | inline auto abi_check(Func&& func, Args&&... args) | |
| 201 | -> decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...)) { | ||
| 202 | using ResultType = std::invoke_result_t<Func, Args...>; | ||
| 203 | using StorageType = std::conditional_t<std::is_void_v<ResultType>, int, ResultType>; | ||
| 204 | |||
| 205 | 170062 | std::optional<StorageType> result; | |
| 206 | |||
| 207 | if constexpr (std::is_void_v<ResultType>) { | ||
| 208 | 170062 | abi_check_za<Func>(std::forward<Func>(func), std::forward<Args>(args)...); | |
| 209 | } else { | ||
| 210 | result = abi_check_za<Func>(std::forward<Func>(func), std::forward<Args>(args)...); | ||
| 211 | } | ||
| 212 | |||
| 213 | if constexpr (!std::is_void_v<ResultType>) { | ||
| 214 | return *result; | ||
| 215 | } | ||
| 216 | 170062 | } | |
| 217 | |||
| 218 | #else | ||
| 219 | |||
| 220 | /// Fallback used when the SME checkers are not compiled in (__ARM_FEATURE_SME undefined, or an MSVC build): calls the wrapped function via std::invoke without any checking and returns its result, if any. | ||
| 221 | template <typename Func, typename... Args> | ||
| 222 | 70576 | inline auto abi_check(Func&& func, Args&&... args) | |
| 223 | -> decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...)) { | ||
| 224 | using ResultType = std::invoke_result_t<Func, Args...>; | ||
| 225 | using StorageType = std::conditional_t<std::is_void_v<ResultType>, int, ResultType>; | ||
| 226 | |||
| 227 | 70576 | std::optional<StorageType> result; | |
| 228 | |||
| 229 | if constexpr (std::is_void_v<ResultType>) { | ||
| 230 | 70576 | std::invoke(std::forward<Func>(func), std::forward<Args>(args)...); | |
| 231 | } else { | ||
| 232 | result = std::invoke(std::forward<Func>(func), std::forward<Args>(args)...); | ||
| 233 | } | ||
| 234 | |||
| 235 | if constexpr (!std::is_void_v<ResultType>) { | ||
| 236 | return *result; | ||
| 237 | } | ||
| 238 | 70576 | } | |
| 239 | |||
| 240 | #endif | ||
| 241 | |||
| 242 | } // namespace kai::test | ||
| 243 |