Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/reference/fill.hpp" | ||
8 | |||
9 | #include <cstddef> | ||
10 | #include <cstdint> | ||
11 | #include <functional> | ||
12 | #include <random> | ||
13 | #include <type_traits> | ||
14 | |||
15 | #include "kai/kai_common.h" | ||
16 | #include "test/common/bfloat16.hpp" | ||
17 | #include "test/common/buffer.hpp" | ||
18 | #include "test/common/data_format.hpp" | ||
19 | #include "test/common/data_type.hpp" | ||
20 | #include "test/common/float16.hpp" | ||
21 | #include "test/common/int4.hpp" | ||
22 | #include "test/common/memory.hpp" | ||
23 | |||
24 | namespace kai::test { | ||
25 | |||
26 | namespace { | ||
27 | |||
28 | template <typename T> | ||
29 | 35739 | Buffer fill_matrix_random_raw(size_t height, size_t width, uint32_t seed) { | |
30 | using TDist = std::conditional_t< | ||
31 | std::is_floating_point_v<T>, std::uniform_real_distribution<float>, std::uniform_int_distribution<T>>; | ||
32 | |||
33 | 35739 | std::mt19937 rnd(seed); | |
34 | 35739 | TDist dist; | |
35 | |||
36 |
1/2✓ Branch 0 taken 35739 times.
✗ Branch 1 not taken.
|
133053365 | return fill_matrix_raw<T>(height, width, [&](size_t, size_t) { return dist(rnd); }); |
37 | 35739 | } | |
38 | |||
39 | template <> | ||
40 | 7355 | Buffer fill_matrix_random_raw<Float16>(size_t height, size_t width, uint32_t seed) { | |
41 | 7355 | std::mt19937 rnd(seed); | |
42 | 7355 | std::uniform_real_distribution<float> dist; | |
43 | |||
44 |
1/2✓ Branch 0 taken 7355 times.
✗ Branch 1 not taken.
|
28436245 | return fill_matrix_raw<Float16>(height, width, [&](size_t, size_t) { return static_cast<Float16>(dist(rnd)); }); |
45 | 7355 | } | |
46 | |||
47 | template <> | ||
48 | ✗ | Buffer fill_matrix_random_raw<BFloat16<>>(size_t height, size_t width, uint32_t seed) { | |
49 | ✗ | std::mt19937 rnd(seed); | |
50 | ✗ | std::uniform_real_distribution<float> dist; | |
51 | |||
52 | ✗ | return fill_matrix_raw<BFloat16<>>( | |
53 | ✗ | height, width, [&](size_t, size_t) { return static_cast<BFloat16<>>(dist(rnd)); }); | |
54 | ✗ | } | |
55 | |||
56 | template <> | ||
57 | 980 | Buffer fill_matrix_random_raw<BFloat16<false>>(size_t height, size_t width, uint32_t seed) { | |
58 | 980 | std::mt19937 rnd(seed); | |
59 | 980 | std::uniform_real_distribution<float> dist; | |
60 | |||
61 |
1/2✓ Branch 0 taken 980 times.
✗ Branch 1 not taken.
|
980 | return fill_matrix_raw<BFloat16<false>>( |
62 | 1943596 | height, width, [&](size_t, size_t) { return static_cast<BFloat16<false>>(dist(rnd)); }); | |
63 | 980 | } | |
64 | |||
65 | template <> | ||
66 | ✗ | Buffer fill_matrix_random_raw<Int4>(size_t height, size_t width, uint32_t seed) { | |
67 | ✗ | std::mt19937 rnd(seed); | |
68 | ✗ | std::uniform_int_distribution<int16_t> dist(-8, 7); | |
69 | |||
70 | ✗ | return fill_matrix_raw<Int4>(height, width, [&](size_t, size_t) { return Int4(static_cast<int8_t>(dist(rnd))); }); | |
71 | ✗ | } | |
72 | |||
73 | template <> | ||
74 | ✗ | Buffer fill_matrix_random_raw<UInt4>(size_t height, size_t width, uint32_t seed) { | |
75 | ✗ | std::mt19937 rnd(seed); | |
76 | ✗ | std::uniform_int_distribution<int16_t> dist(0, 15); | |
77 | |||
78 | ✗ | return fill_matrix_raw<UInt4>(height, width, [&](size_t, size_t) { return UInt4(static_cast<int8_t>(dist(rnd))); }); | |
79 | ✗ | } | |
80 | |||
81 | } // namespace | ||
82 | |||
83 | template <typename T> | ||
84 | 45674 | Buffer fill_matrix_raw(size_t height, size_t width, std::function<T(size_t, size_t)> gen) { | |
85 | 45674 | const auto size = height * width * size_in_bits<T> / 8; | |
86 | − | KAI_ASSUME(width * size_in_bits<T> % 8 == 0); | |
87 | |||
88 | 45674 | Buffer data(size); | |
89 |
1/2✓ Branch 0 taken 37339 times.
✗ Branch 1 not taken.
|
45674 | auto ptr = reinterpret_cast<T*>(data.data()); |
90 | |||
91 |
6/12✓ Branch 0 taken 37339 times.
✓ Branch 1 taken 569796 times.
✓ Branch 2 taken 7355 times.
✓ Branch 3 taken 524591 times.
✓ Branch 4 taken 980 times.
✓ Branch 5 taken 980 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
|
1141041 | for (size_t y = 0; y < height; ++y) { |
92 |
6/12✓ Branch 0 taken 569796 times.
✓ Branch 1 taken 139891226 times.
✓ Branch 2 taken 524591 times.
✓ Branch 3 taken 28428890 times.
✓ Branch 4 taken 980 times.
✓ Branch 5 taken 1942616 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
|
171358099 | for (size_t x = 0; x < width; ++x) { |
93 |
6/24✓ Branch 0 taken 139891226 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 139891226 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 28428890 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 28428890 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1942616 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1942616 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
|
170262732 | write_array<T>(ptr, y * width + x, gen(y, x)); |
94 | 170262732 | } | |
95 | 1095367 | } | |
96 | |||
97 | 45674 | return data; | |
98 | 45674 | } | |
99 | |||
100 | 7555 | Buffer fill_matrix_random(size_t height, size_t width, const DataFormat& format, uint32_t seed) { | |
101 |
1/2✓ Branch 0 taken 7555 times.
✗ Branch 1 not taken.
|
7555 | switch (format.pack_format()) { |
102 | case DataFormat::PackFormat::NONE: | ||
103 |
2/6✗ Branch 0 not taken.
✓ Branch 1 taken 3936 times.
✓ Branch 2 taken 3619 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
7555 | switch (format.data_type()) { |
104 | case DataType::FP32: | ||
105 | 3936 | return fill_matrix_random_raw<float>(height, width, seed); | |
106 | |||
107 | case DataType::FP16: | ||
108 | 3619 | return fill_matrix_random_raw<Float16>(height, width, seed); | |
109 | |||
110 | case DataType::BF16: | ||
111 | ✗ | return fill_matrix_random_raw<BFloat16<>>(height, width, seed); | |
112 | |||
113 | case DataType::QSU4: | ||
114 | ✗ | return fill_matrix_random_raw<UInt4>(height, width, seed); | |
115 | |||
116 | case DataType::QSI4: | ||
117 | ✗ | return fill_matrix_random_raw<Int4>(height, width, seed); | |
118 | |||
119 | default: | ||
120 | − | KAI_ERROR("Unsupported data type!"); | |
121 | ✗ | } | |
122 | |||
123 | ✗ | break; | |
124 | |||
125 | default: | ||
126 | − | KAI_ERROR("Unsupported data format!"); | |
127 | ✗ | } | |
128 | 7555 | } | |
129 | |||
130 | template <typename Value> | ||
131 | 36519 | Buffer fill_random(size_t length, uint32_t seed) { | |
132 | 36519 | return fill_matrix_random_raw<Value>(1, length, seed); | |
133 | } | ||
134 | |||
135 | template Buffer fill_random<float>(size_t length, uint32_t seed); | ||
136 | template Buffer fill_random<Float16>(size_t length, uint32_t seed); | ||
137 | template Buffer fill_matrix_raw<float>(size_t height, size_t width, std::function<float(size_t, size_t)> gen); | ||
138 | template Buffer fill_random<BFloat16<false>>(size_t length, uint32_t seed); | ||
139 | |||
140 | } // namespace kai::test | ||
141 |