Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/reference/reduce.hpp" | ||
8 | |||
9 | #include <algorithm> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | |||
13 | #include "kai/kai_common.h" | ||
14 | #include "test/common/buffer.hpp" | ||
15 | #include "test/common/data_format.hpp" | ||
16 | #include "test/common/data_type.hpp" | ||
17 | #include "test/common/int4.hpp" | ||
18 | #include "test/common/memory.hpp" | ||
19 | #include "test/common/round.hpp" | ||
20 | |||
21 | namespace kai::test { | ||
22 | |||
23 | namespace { | ||
24 | |||
25 | template <const ReductionOperator op, typename T> | ||
26 | ✗ | T scalar_reduce(T curr_value, T new_value) { | |
27 | if constexpr (op == ReductionOperator::ADD) { | ||
28 | ✗ | return curr_value + new_value; | |
29 | } | ||
30 | } | ||
31 | |||
32 | template <const ReductionOperator op, typename Input, typename Output> | ||
33 | ✗ | Buffer reduce_any_op_type(const void* src, size_t height, size_t width, size_t dimension) { | |
34 | ✗ | switch (dimension) { | |
35 | case 0: { | ||
36 | ✗ | Buffer dst(height * size_in_bits<Output> / 8); | |
37 | − | KAI_ASSUME(height * size_in_bits<Output> % 8 == 0); | |
38 | |||
39 | ✗ | for (size_t y = 0; y < height; ++y) { | |
40 | ✗ | Output acc = read_array<Input>(src, y * width); | |
41 | |||
42 | ✗ | for (size_t x = 1; x < width; ++x) { | |
43 | ✗ | Output value = read_array<Input>(src, y * width + x); | |
44 | ✗ | acc = scalar_reduce<op, Output>(acc, value); | |
45 | ✗ | } | |
46 | |||
47 | ✗ | write_array<Output>(dst.data(), y, acc); | |
48 | ✗ | } | |
49 | |||
50 | ✗ | return dst; | |
51 | ✗ | } | |
52 | |||
53 | case 1: { | ||
54 | ✗ | Buffer dst(width * size_in_bits<Output> / 8); | |
55 | − | KAI_ASSUME(width * size_in_bits<Output> % 8 == 0); | |
56 | |||
57 | ✗ | for (size_t x = 0; x < width; ++x) { | |
58 | ✗ | Output acc = read_array<Input>(src, x); | |
59 | |||
60 | ✗ | for (size_t y = 1; y < height; ++y) { | |
61 | ✗ | Output value = read_array<Input>(src, y * width + x); | |
62 | ✗ | acc = scalar_reduce<op, Output>(acc, value); | |
63 | ✗ | } | |
64 | |||
65 | ✗ | write_array<Output>(dst.data(), x, acc); | |
66 | ✗ | } | |
67 | |||
68 | ✗ | return dst; | |
69 | ✗ | } | |
70 | |||
71 | default: | ||
72 | − | KAI_ERROR("Only 2D data is supported!"); | |
73 | ✗ | } | |
74 | ✗ | } | |
75 | |||
76 | template <const ReductionOperator op> | ||
77 | ✗ | Buffer reduce_any_op( | |
78 | const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, | ||
79 | size_t dimension) { | ||
80 | − | KAI_ASSUME(src_format.is_raw()); | |
81 | − | KAI_ASSUME(dst_format.is_raw()); | |
82 | − | KAI_ASSUME(dimension < 2); | |
83 | − | KAI_ASSUME(height > 0); | |
84 | − | KAI_ASSUME(width > 0); | |
85 | |||
86 | ✗ | const auto src_dt = src_format.data_type(); | |
87 | ✗ | const auto dst_dt = dst_format.data_type(); | |
88 | |||
89 | ✗ | switch (src_dt) { | |
90 | case DataType::QSU4: | ||
91 | ✗ | switch (dst_dt) { | |
92 | case DataType::I32: | ||
93 | ✗ | return reduce_any_op_type<op, UInt4, int32_t>(src, height, width, dimension); | |
94 | break; | ||
95 | |||
96 | default: | ||
97 | − | KAI_ERROR("Unsupported data type!"); | |
98 | ✗ | } | |
99 | |||
100 | default: | ||
101 | − | KAI_ERROR("Unsupported data type!"); | |
102 | ✗ | } | |
103 | ✗ | } | |
104 | |||
105 | } // namespace | ||
106 | |||
107 | ✗ | Buffer reduce_add( | |
108 | const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, | ||
109 | size_t dimension) { | ||
110 | ✗ | return reduce_any_op<ReductionOperator::ADD>(src, src_format, height, width, dst_format, dimension); | |
111 | } | ||
112 | |||
113 | template <typename Value, typename Accumulator> | ||
114 | 927 | Buffer reduce_add_x(const void* src, size_t height, size_t width) { | |
115 | 927 | Buffer dst(round_up_division(height * size_in_bits<Accumulator>, 8)); | |
116 | |||
117 |
2/2✓ Branch 0 taken 927 times.
✓ Branch 1 taken 105378 times.
|
106305 | for (size_t y = 0; y < height; ++y) { |
118 | 105378 | Accumulator acc = 0; | |
119 | |||
120 |
2/2✓ Branch 0 taken 16220061 times.
✓ Branch 1 taken 105378 times.
|
16325439 | for (size_t x = 0; x < width; ++x) { |
121 |
1/2✓ Branch 0 taken 16220061 times.
✗ Branch 1 not taken.
|
16220061 | acc += static_cast<Accumulator>(read_array<Value>(src, y * width + x)); |
122 | 16220061 | } | |
123 | |||
124 |
2/4✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 105378 times.
✗ Branch 3 not taken.
|
105378 | write_array<Accumulator>(dst.data(), y, acc); |
125 | 105378 | } | |
126 | |||
127 | 927 | return dst; | |
128 | 927 | } | |
129 | |||
130 | template Buffer reduce_add_x<int8_t, int32_t>(const void* src, size_t height, size_t width); | ||
131 | |||
132 | template <typename T> | ||
133 | 927 | T reduce_min(const void* src, size_t len) { | |
134 | − | KAI_ASSUME(len > 0); | |
135 | |||
136 | 927 | T min = read_array<T>(src, 0); | |
137 | |||
138 |
2/2✓ Branch 0 taken 927 times.
✓ Branch 1 taken 1708899 times.
|
1709826 | for (size_t i = 1; i < len; ++i) { |
139 | 1708899 | min = std::min(min, read_array<T>(src, i)); | |
140 | 1708899 | } | |
141 | |||
142 | 1854 | return min; | |
143 | 927 | } | |
144 | |||
145 | template float reduce_min(const void* src, size_t len); | ||
146 | |||
147 | template <typename T> | ||
148 | 927 | T reduce_max(const void* src, size_t len) { | |
149 | − | KAI_ASSUME(len > 0); | |
150 | |||
151 | 927 | T max = read_array<T>(src, 0); | |
152 | |||
153 |
2/2✓ Branch 0 taken 927 times.
✓ Branch 1 taken 1708899 times.
|
1709826 | for (size_t i = 1; i < len; ++i) { |
154 | 1708899 | max = std::max(max, read_array<T>(src, i)); | |
155 | 1708899 | } | |
156 | |||
157 | 1854 | return max; | |
158 | 927 | } | |
159 | |||
160 | template float reduce_max(const void* src, size_t len); | ||
161 | |||
162 | } // namespace kai::test | ||
163 |