Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #pragma once | ||
8 | |||
9 | #include <cstddef> | ||
10 | #include <cstring> | ||
11 | #include <type_traits> | ||
12 | |||
13 | #include "kai/kai_common.h" | ||
14 | #include "test/common/bfloat16.hpp" | ||
15 | #include "test/common/data_type.hpp" | ||
16 | #include "test/common/int4.hpp" | ||
17 | |||
18 | namespace kai::test { | ||
19 | |||
20 | /// The size in bits of type `T`. | ||
21 | template <typename T> | ||
22 | inline constexpr size_t size_in_bits = sizeof(T) * 8; | ||
23 | |||
24 | /// The size in bits of type `T`. | ||
25 | template <> | ||
26 | inline constexpr size_t size_in_bits<UInt4> = 4; | ||
27 | |||
28 | /// The size in bits of type `T`. | ||
29 | template <> | ||
30 | inline constexpr size_t size_in_bits<Int4> = 4; | ||
31 | |||
32 | /// Reads the array at the specified index. | ||
33 | /// | ||
34 | /// @param[in] array Data buffer. | ||
35 | /// @param[in] index Array index. | ||
36 | /// | ||
37 | /// @return The array value at the specified index. | ||
38 | template <typename T> | ||
39 | 1145064238 | T read_array(const void* array, size_t index) { | |
40 | if constexpr (std::is_same_v<T, UInt4>) { | ||
41 | 28125312 | const auto [lo, hi] = UInt4::unpack_u8(reinterpret_cast<const uint8_t*>(array)[index / 2]); | |
42 |
4/6int kai::test::read_array<int>(void const*, unsigned long):
✓ Branch 0 taken 13906208 times.
✓ Branch 1 taken 13906528 times.
kai::test::Int4 kai::test::read_array<kai::test::Int4>(void const*, unsigned long):
✗ Branch 0 not taken.
✗ Branch 1 not taken.
kai::test::UInt4 kai::test::read_array<kai::test::UInt4>(void const*, unsigned long):
✓ Branch 0 taken 156288 times.
✓ Branch 1 taken 156288 times.
|
28125312 | return index % 2 == 0 ? lo : hi; |
43 | 28125312 | } else if constexpr (std::is_same_v<T, Int4>) { | |
44 | 2119608532 | const auto [lo, hi] = Int4::unpack_u8(reinterpret_cast<const uint8_t*>(array)[index / 2]); | |
45 |
8/8float kai::test::read_array<float>(void const*, unsigned long):
✓ Branch 0 taken 5972876 times.
✓ Branch 1 taken 5973192 times.
int kai::test::read_array<int>(void const*, unsigned long):
✓ Branch 0 taken 2883196 times.
✓ Branch 1 taken 2883752 times.
kai::test::Float16 kai::test::read_array<kai::test::Float16>(void const*, unsigned long):
✓ Branch 0 taken 1026087420 times.
✓ Branch 1 taken 1026097928 times.
kai::test::Int4 kai::test::read_array<kai::test::Int4>(void const*, unsigned long):
✓ Branch 0 taken 24854996 times.
✓ Branch 1 taken 24855172 times.
|
2119608532 | return index % 2 == 0 ? lo : hi; |
46 | 2119608532 | } else if constexpr (std::is_same_v<T, BFloat16<false>>) { | |
47 | 1394600792 | uint16_t raw_value = reinterpret_cast<const uint16_t*>(array)[index]; | |
48 | 1394600792 | return BFloat16<false>(kai_cast_f32_bf16(raw_value)); | |
49 | 1394600792 | } else if constexpr (std::is_same_v<T, BFloat16<true>>) { | |
50 | 290269930 | uint16_t raw_value = reinterpret_cast<const uint16_t*>(array)[index]; | |
51 | 290269930 | return BFloat16<true>(kai_cast_f32_bf16(raw_value)); | |
52 | 290269930 | } else { | |
53 | 1043555706 | return reinterpret_cast<const T*>(array)[index]; | |
54 | } | ||
55 | 3832604566 | } | |
56 | |||
57 | /// Reads the array at the specified index | ||
58 | /// | ||
59 | /// @param[in] type Array element data type | ||
60 | /// @param[in] array Data buffer. | ||
61 | /// @param[in] index Array index. | ||
62 | /// | ||
63 | /// @return Value at specified index | ||
64 | double read_array(DataType type, const void* array, size_t index); | ||
65 | |||
66 | /// Writes the specified value to the array. | ||
67 | /// | ||
68 | /// @param[in] array Data buffer. | ||
69 | /// @param[in] index Array index. | ||
70 | /// @param[in] value Value to be stored. | ||
71 | template <typename T> | ||
72 | 558650255 | void write_array(void* array, size_t index, T value) { | |
73 | if constexpr (std::is_same_v<T, UInt4>) { | ||
74 | 77835480 | auto* arr_value = reinterpret_cast<uint8_t*>(array) + index / 2; | |
75 | 77835480 | const auto [lo, hi] = UInt4::unpack_u8(*arr_value); | |
76 | |||
77 |
6/8void kai::test::write_array<int>(void*, unsigned long, int):
✓ Branch 0 taken 13865936 times.
✓ Branch 1 taken 13946800 times.
void kai::test::write_array<kai::test::Int4>(void*, unsigned long, kai::test::Int4):
✗ Branch 0 not taken.
✗ Branch 1 not taken.
void kai::test::write_array<kai::test::UInt4>(void*, unsigned long, kai::test::UInt4):
✓ Branch 0 taken 24854996 times.
✓ Branch 1 taken 24855172 times.
void kai::test::write_array<std::nullptr_t>(void*, unsigned long, std::nullptr_t):
✓ Branch 0 taken 156288 times.
✓ Branch 1 taken 156288 times.
|
77835480 | if (index % 2 == 0) { |
78 | 77916520 | *arr_value = UInt4::pack_u8(value, hi); | |
79 | 38958260 | } else { | |
80 | 77754440 | *arr_value = UInt4::pack_u8(lo, value); | |
81 | } | ||
82 | 77835480 | } else if constexpr (std::is_same_v<T, Int4>) { | |
83 | 65000348 | auto* arr_value = reinterpret_cast<uint8_t*>(array) + index / 2; | |
84 | 65000348 | const auto [lo, hi] = Int4::unpack_u8(*arr_value); | |
85 | |||
86 |
6/8void kai::test::write_array<float>(void*, unsigned long, float):
✓ Branch 0 taken 5931588 times.
✓ Branch 1 taken 6014480 times.
void kai::test::write_array<int>(void*, unsigned long, int):
✓ Branch 0 taken 2864904 times.
✓ Branch 1 taken 2902044 times.
void kai::test::write_array<kai::test::Int4>(void*, unsigned long, kai::test::Int4):
✗ Branch 0 not taken.
✗ Branch 1 not taken.
void kai::test::write_array<signed char>(void*, unsigned long, signed char):
✓ Branch 0 taken 23643228 times.
✓ Branch 1 taken 23644104 times.
|
65000348 | if (index % 2 == 0) { |
87 | 65121256 | *arr_value = Int4::pack_u8(value, hi); | |
88 | 32560628 | } else { | |
89 | 64879440 | *arr_value = Int4::pack_u8(lo, value); | |
90 | } | ||
91 | 65000348 | } else { | |
92 | 415814427 | reinterpret_cast<T*>(array)[index] = value; | |
93 | } | ||
94 | 558650255 | } | |
95 | |||
96 | /// Writes the specified value to the array. | ||
97 | /// | ||
98 | /// @param[in] type Array element type. | ||
99 | /// @param[in] array Data buffer. | ||
100 | /// @param[in] index Array index. | ||
101 | /// @param[in] value Value to be stored. | ||
102 | void write_array(DataType type, void* array, size_t index, double value); | ||
103 | |||
104 | } // namespace kai::test | ||
105 |