KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 50 / 0 / 50
Functions: 97.4% 37 / 0 / 38
Branches: 36.4% 16 / 0 / 44

test/common/memory.hpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #pragma once
8
9 #include <cstddef>
10 #include <cstring>
11 #include <type_traits>
12
13 #include "kai/kai_common.h"
14 #include "test/common/assert.hpp"
15 #include "test/common/bfloat16.hpp"
16 #include "test/common/data_type.hpp"
17 #include "test/common/int4.hpp"
18 #include "test/common/round.hpp"
19 #include "test/common/span.hpp"
20
21 namespace kai::test {
22
23 /// The size in bits of type `T`.
24 template <typename T>
25 inline constexpr size_t size_in_bits = sizeof(T) * 8;
26
27 /// The size in bits of type `T`.
28 template <>
29 inline constexpr size_t size_in_bits<UInt4> = 4;
30
31 /// The size in bits of type `T`.
32 template <>
33 inline constexpr size_t size_in_bits<Int4> = 4;
34
35 /// Reads the array at the specified index.
36 ///
37 /// @param[in] array Data buffer.
38 /// @param[in] index Array index.
39 ///
40 /// @return The array value at the specified index.
41 template <typename T>
42 2202669645 T read_array(const void* array, size_t index) {
43 if constexpr (std::is_same_v<T, UInt4>) {
44 69119703 const auto [lo, hi] = UInt4::unpack_u8(reinterpret_cast<const uint8_t*>(array)[index / 2]);
45
2/2
✓ Branch 0 taken 34547130 times.
✓ Branch 1 taken 34572573 times.
69119703 return index % 2 == 0 ? lo : hi;
46 69119703 } else if constexpr (std::is_same_v<T, Int4>) {
47 7174989285 const auto [lo, hi] = Int4::unpack_u8(reinterpret_cast<const uint8_t*>(array)[index / 2]);
48
2/2
✓ Branch 0 taken 3587471950 times.
✓ Branch 1 taken 3587517335 times.
7174989285 return index % 2 == 0 ? lo : hi;
49 7174989285 } else if constexpr (std::is_same_v<T, BFloat16<false>>) {
50 3969618908 uint16_t raw_value = reinterpret_cast<const uint16_t*>(array)[index];
51 3969618908 return BFloat16<false>(kai_cast_f32_bf16(raw_value));
52 3969618908 } else if constexpr (std::is_same_v<T, BFloat16<true>>) {
53 1563441534 uint16_t raw_value = reinterpret_cast<const uint16_t*>(array)[index];
54 1563441534 return BFloat16<true>(kai_cast_f32_bf16(raw_value));
55 1563441534 } else {
56 1816196033 return reinterpret_cast<const T*>(array)[index];
57 }
58 386473612 }
59
60 /// Reads the array at the specified index.
61 ///
62 /// @param[in] array Data buffer.
63 /// @param[in] index Array index.
64 ///
65 /// @return The array value at the specified index.
66 template <typename T>
67 246617207 T read_array(Span<const std::byte> array, size_t index) {
68 246617207 const size_t min_size = round_up_division((index + 1) * size_in_bits<T>, 8);
69
4/20
✓ Branch 0 taken 220430629 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4015676 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 9921401 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 12249501 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
246617207 KAI_TEST_ASSERT_MSG(array.size() >= min_size, "The read access is out-of-bound!");
70 477168250 return read_array<T>(array.data(), index);
71 246617207 }
72
73 /// Reads the 2D array at the specified coordinates.
74 ///
75 /// @param[in] data The data buffer.
76 /// @param[in] width The array width.
77 /// @param[in] row The row index.
78 /// @param[in] col The column index.
79 ///
80 /// @return The array value at the specified coordinates.
81 template <typename T>
82 215156436 T read_2d(Span<const std::byte> data, size_t width, size_t row, size_t col) {
83 215156436 const size_t stride = round_up_division(width * size_in_bits<T>, 8);
84 429040651 return read_array<T>(data.subspan(row * stride, stride), col);
85 215156436 }
86
87 /// Reads the array at the specified index
88 ///
89 /// @param[in] type Array element data type
90 /// @param[in] array Data buffer.
91 /// @param[in] index Array index.
92 ///
93 /// @return Value at specified index
94 double read_array(DataType type, const void* array, size_t index);
95
96 /// Writes the specified value to the array.
97 ///
98 /// @param[in] array Data buffer.
99 /// @param[in] index Array index.
100 /// @param[in] value Value to be stored.
101 template <typename T>
102 1382361094 void write_array(void* array, size_t index, T value) {
103 if constexpr (std::is_same_v<T, UInt4>) {
104 276255178 auto* arr_value = reinterpret_cast<uint8_t*>(array) + index / 2;
105 276255178 const auto [lo, hi] = UInt4::unpack_u8(*arr_value);
106
107
2/2
✓ Branch 0 taken 138037548 times.
✓ Branch 1 taken 138217630 times.
276255178 if (index % 2 == 0) {
108 276435260 *arr_value = UInt4::pack_u8(value, hi);
109 138217630 } else {
110 218801460 *arr_value = UInt4::pack_u8(lo, value);
111 }
112 276255178 } else if constexpr (std::is_same_v<T, Int4>) {
113 252637969 auto* arr_value = reinterpret_cast<uint8_t*>(array) + index / 2;
114 252637969 const auto [lo, hi] = Int4::unpack_u8(*arr_value);
115
116
2/2
✓ Branch 0 taken 126142010 times.
✓ Branch 1 taken 126495959 times.
252637969 if (index % 2 == 0) {
117 252991918 *arr_value = Int4::pack_u8(value, hi);
118 126495959 } else {
119 200913354 *arr_value = Int4::pack_u8(lo, value);
120 }
121 252637969 } else {
122 853467947 reinterpret_cast<T*>(array)[index] = value;
123 }
124 1382361094 }
125
126 /// Writes the specified value to the array.
127 ///
128 /// @param[in] array Data buffer.
129 /// @param[in] index Array index.
130 /// @param[in] value Value to be stored.
131 template <typename T>
132 12624192 void write_array(Span<std::byte> array, size_t index, T value) {
133 12624192 const size_t min_size = round_up_division((index + 1) * size_in_bits<T>, 8);
134
4/16
✓ Branch 0 taken 7517873 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2559894 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 1274204 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1272221 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
12624192 KAI_TEST_ASSERT_MSG(array.size() >= min_size, "The write access is out-of-bound!");
135 12624192 write_array<T>(array.data(), index, value);
136 12624192 }
137
138 /// Writes the specified value to the 2D array at the specified coordinates.
139 ///
140 /// @param[out] data The data buffer.
141 /// @param[in] width The array width.
142 /// @param[in] row The row index.
143 /// @param[in] col The column index.
144 /// @param[in] value The value to be stored.
145 template <typename T>
146 7586756 void write_2d(Span<std::byte> data, size_t width, size_t row, size_t col, T value) {
147 7586756 const size_t stride = round_up_division(width * size_in_bits<T>, 8);
148 7586756 write_array<T>(data.subspan(row * stride, stride), col, value);
149 7586756 }
150
151 /// Writes the specified value to the array.
152 ///
153 /// @param[in] type Array element type.
154 /// @param[in] array Data buffer.
155 /// @param[in] index Array index.
156 /// @param[in] value Value to be stored.
157 void write_array(DataType type, void* array, size_t index, double value);
158
159 } // namespace kai::test
160