Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "buffer.hpp" | ||
8 | |||
9 | #if defined(__linux__) || defined(__APPLE__) | ||
10 | #include <sys/mman.h> | ||
11 | #include <unistd.h> | ||
12 | #endif // defined(__linux__) || defined(__APPLE__) | ||
13 | |||
14 | #include <algorithm> | ||
15 | #include <cstddef> | ||
16 | #include <functional> | ||
17 | #include <sstream> | ||
18 | #include <string> | ||
19 | |||
20 | #include "kai/kai_common.h" | ||
21 | |||
22 | namespace kai::test { | ||
23 | |||
24 | 321254 | Buffer::Buffer(const size_t size) : Buffer(size, 0) { | |
25 | 321254 | } | |
26 | |||
27 | 733698 | Buffer::Buffer(const size_t size, const uint8_t init_value = 0) : m_user_buffer_size(size) { | |
28 | KAI_ASSUME_MSG(size > 0, "Buffers must be of non-zero size"); | ||
29 | |||
30 | const char* val = getenv(buffer_policy_env_name); | ||
31 | const std::string buffer_policy = (val != nullptr) ? std::string(val) : std::string("NONE"); | ||
32 | |||
33 | std::ostringstream oss; | ||
34 | |||
35 | if (buffer_policy == "PROTECT_UNDERFLOW" || buffer_policy == "PROTECT_OVERFLOW") { | ||
36 | #if defined(__linux__) || defined(__APPLE__) | ||
37 | m_protection_policy = (buffer_policy == "PROTECT_UNDERFLOW") ? BufferProtectionPolicy::ProtectUnderflow | ||
38 | : BufferProtectionPolicy::ProtectOverflow; | ||
39 | #else // defined(__linux__) || defined(__APPLE__) | ||
40 | oss << buffer_policy << " buffer protection policy is not supported on target platform"; | ||
41 | #endif // defined(__linux__) || defined(__APPLE__) | ||
42 | } else if (buffer_policy == "NONE") { | ||
43 | m_protection_policy = BufferProtectionPolicy::None; | ||
44 | } else { | ||
45 | oss << "Unrecognized buffer protection policy provided by " << buffer_policy_env_name << ": "; | ||
46 | oss << buffer_policy; | ||
47 | } | ||
48 | |||
49 | if (!oss.str().empty()) { | ||
50 | KAI_ERROR(oss.str().c_str()); | ||
51 | } | ||
52 | |||
53 | switch (m_protection_policy) { | ||
54 | #if defined(__linux__) || defined(__APPLE__) | ||
55 | case BufferProtectionPolicy::ProtectUnderflow: | ||
56 | case BufferProtectionPolicy::ProtectOverflow: | ||
57 | allocate_with_guard_pages(); | ||
58 | break; | ||
59 | #endif // defined(__linux__) || defined(__APPLE__) | ||
60 | default: | ||
61 | allocate(); | ||
62 | } | ||
63 | |||
64 | memset(data(), init_value, size); | ||
65 | 366849 | } | |
66 | |||
67 | 366649 | void Buffer::allocate() { | |
68 | 366649 | m_buffer = handle(std::malloc(m_user_buffer_size), &std::free); | |
69 | − | KAI_ASSUME_MSG(m_buffer.get() != nullptr, "Failure allocating memory"); | |
70 | − | KAI_ASSUME_MSG(m_user_buffer_offset == 0, "Buffer offset must be zero for naive allocation"); | |
71 | 366649 | } | |
72 | |||
73 | #if defined(__linux__) || defined(__APPLE__) | ||
74 | 200 | void Buffer::allocate_with_guard_pages() { | |
75 | 200 | const auto sc_pagesize_res = sysconf(_SC_PAGESIZE); | |
76 | − | KAI_ASSUME_MSG(sc_pagesize_res != -1, "Error finding page size"); | |
77 | |||
78 | 200 | const auto page_size = static_cast<size_t>(sc_pagesize_res); | |
79 | |||
80 | // Offset the user buffer by the size of the first guard page | ||
81 | 200 | m_user_buffer_offset = page_size; | |
82 | |||
83 | // The user buffer is rounded to the size of the nearest whole page. | ||
84 | // This forms the valid region between the two guard pages | ||
85 | 200 | const size_t valid_region_size = kai_roundup(m_user_buffer_size, page_size); | |
86 | 200 | const size_t protected_region_size = 2 * page_size; | |
87 | 200 | const size_t total_memory_size = valid_region_size + protected_region_size; | |
88 | |||
89 |
2/2✓ Branch 0 taken 100 times.
✓ Branch 1 taken 100 times.
|
200 | if (m_protection_policy == BufferProtectionPolicy::ProtectOverflow) { |
90 | // To detect overflows we offset the user buffer so that edge of the buffer is aligned to the start of the | ||
91 | // higher guard page thus detecting whenever a buffer overflow occurs. | ||
92 | 100 | m_user_buffer_offset += valid_region_size - m_user_buffer_size; | |
93 | 100 | } | |
94 | |||
95 | 400 | auto mmap_deleter = [total_memory_size](void* ptr) { | |
96 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | if (munmap(ptr, total_memory_size) != 0) { |
97 | − | KAI_ERROR("Failure deleting memory mappings"); | |
98 | ✗ | } | |
99 | 200 | }; | |
100 | |||
101 | 200 | m_buffer = | |
102 | 200 | handle(mmap(nullptr, total_memory_size, PROT_READ | PROT_WRITE, MAP_ANON | MAP_PRIVATE, -1, 0), mmap_deleter); | |
103 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | if (m_buffer.get() == MAP_FAILED) { |
104 | − | KAI_ERROR("Failure mapping memory"); | |
105 | ✗ | } | |
106 | |||
107 | 200 | void* head_guard_page = m_buffer.get(); | |
108 | 200 | void* tail_guard_page = static_cast<std::byte*>(m_buffer.get()) + (total_memory_size - page_size); | |
109 | |||
110 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | if (mprotect(head_guard_page, std::max(static_cast<size_t>(0), page_size), PROT_NONE) != 0) { |
111 | − | KAI_ERROR("Failure protecting page immediately preceding buffer"); | |
112 | ✗ | } | |
113 |
1/2✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
|
200 | if (mprotect(tail_guard_page, std::max(static_cast<size_t>(0), page_size), PROT_NONE) != 0) { |
114 | − | KAI_ERROR("Failure protecting page immediately following buffer"); | |
115 | ✗ | } | |
116 | 200 | } | |
117 | #endif // defined(__linux__) || defined(__APPLE__) | ||
118 | |||
119 | } // namespace kai::test | ||
120 |