test/reference/reduce.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/reference/reduce.hpp" | ||
| 8 | |||
| 9 | #include <algorithm> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | |||
| 13 | #include "kai/kai_common.h" | ||
| 14 | #include "test/common/buffer.hpp" | ||
| 15 | #include "test/common/data_format.hpp" | ||
| 16 | #include "test/common/data_type.hpp" | ||
| 17 | #include "test/common/int4.hpp" | ||
| 18 | #include "test/common/memory.hpp" | ||
| 19 | #include "test/common/round.hpp" | ||
| 20 | |||
| 21 | namespace kai::test { | ||
| 22 | |||
| 23 | namespace { | ||
| 24 | |||
| 25 | template <const ReductionOperator op, typename T> | ||
| 26 | ✗ | T scalar_reduce(T curr_value, T new_value) { | |
| 27 | if constexpr (op == ReductionOperator::ADD) { | ||
| 28 | ✗ | return curr_value + new_value; | |
| 29 | } | ||
| 30 | } | ||
| 31 | |||
| 32 | template <const ReductionOperator op, typename Input, typename Output> | ||
| 33 | ✗ | Buffer reduce_any_op_type(const void* src, size_t height, size_t width, size_t dimension) { | |
| 34 | ✗ | switch (dimension) { | |
| 35 | case 0: { | ||
| 36 | ✗ | Buffer dst(height * size_in_bits<Output> / 8); | |
| 37 | − | KAI_ASSUME_ALWAYS(height * size_in_bits<Output> % 8 == 0); | |
| 38 | |||
| 39 | ✗ | for (size_t y = 0; y < height; ++y) { | |
| 40 | ✗ | Output acc = read_array<Input>(src, y * width); | |
| 41 | |||
| 42 | ✗ | for (size_t x = 1; x < width; ++x) { | |
| 43 | ✗ | Output value = read_array<Input>(src, y * width + x); | |
| 44 | ✗ | acc = scalar_reduce<op, Output>(acc, value); | |
| 45 | ✗ | } | |
| 46 | |||
| 47 | ✗ | write_array<Output>(dst.data(), y, acc); | |
| 48 | ✗ | } | |
| 49 | |||
| 50 | ✗ | return dst; | |
| 51 | ✗ | } | |
| 52 | |||
| 53 | case 1: { | ||
| 54 | ✗ | Buffer dst(width * size_in_bits<Output> / 8); | |
| 55 | − | KAI_ASSUME_ALWAYS(width * size_in_bits<Output> % 8 == 0); | |
| 56 | |||
| 57 | ✗ | for (size_t x = 0; x < width; ++x) { | |
| 58 | ✗ | Output acc = read_array<Input>(src, x); | |
| 59 | |||
| 60 | ✗ | for (size_t y = 1; y < height; ++y) { | |
| 61 | ✗ | Output value = read_array<Input>(src, y * width + x); | |
| 62 | ✗ | acc = scalar_reduce<op, Output>(acc, value); | |
| 63 | ✗ | } | |
| 64 | |||
| 65 | ✗ | write_array<Output>(dst.data(), x, acc); | |
| 66 | ✗ | } | |
| 67 | |||
| 68 | ✗ | return dst; | |
| 69 | ✗ | } | |
| 70 | |||
| 71 | default: | ||
| 72 | − | KAI_ERROR("Only 2D data is supported!"); | |
| 73 | ✗ | } | |
| 74 | ✗ | } | |
| 75 | |||
| 76 | template <const ReductionOperator op> | ||
| 77 | ✗ | Buffer reduce_any_op( | |
| 78 | const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, | ||
| 79 | size_t dimension) { | ||
| 80 | − | KAI_ASSUME_ALWAYS(src_format.is_raw()); | |
| 81 | − | KAI_ASSUME_ALWAYS(dst_format.is_raw()); | |
| 82 | − | KAI_ASSUME_ALWAYS(dimension < 2); | |
| 83 | − | KAI_ASSUME_ALWAYS(height > 0); | |
| 84 | − | KAI_ASSUME_ALWAYS(width > 0); | |
| 85 | |||
| 86 | ✗ | const auto src_dt = src_format.data_type(); | |
| 87 | ✗ | const auto dst_dt = dst_format.data_type(); | |
| 88 | |||
| 89 | ✗ | switch (src_dt) { | |
| 90 | case DataType::QSU4: | ||
| 91 | ✗ | switch (dst_dt) { | |
| 92 | case DataType::I32: | ||
| 93 | ✗ | return reduce_any_op_type<op, UInt4, int32_t>(src, height, width, dimension); | |
| 94 | break; | ||
| 95 | |||
| 96 | default: | ||
| 97 | − | KAI_ERROR("Unsupported data type!"); | |
| 98 | ✗ | } | |
| 99 | |||
| 100 | default: | ||
| 101 | − | KAI_ERROR("Unsupported data type!"); | |
| 102 | ✗ | } | |
| 103 | ✗ | } | |
| 104 | |||
| 105 | } // namespace | ||
| 106 | |||
| 107 | ✗ | Buffer reduce_add( | |
| 108 | const void* src, const DataFormat& src_format, size_t height, size_t width, const DataFormat& dst_format, | ||
| 109 | size_t dimension) { | ||
| 110 | ✗ | return reduce_any_op<ReductionOperator::ADD>(src, src_format, height, width, dst_format, dimension); | |
| 111 | } | ||
| 112 | |||
| 113 | template <typename Value, typename Accumulator> | ||
| 114 | 521 | Buffer reduce_add_x(const void* src, size_t height, size_t width) { | |
| 115 | 521 | Buffer dst(round_up_division(height * size_in_bits<Accumulator>, 8)); | |
| 116 | |||
| 117 |
2/2✓ Branch 0 taken 521 times.
✓ Branch 1 taken 60763 times.
|
61284 | for (size_t y = 0; y < height; ++y) { |
| 118 | 60763 | Accumulator acc = 0; | |
| 119 | |||
| 120 |
2/2✓ Branch 0 taken 5901096 times.
✓ Branch 1 taken 60763 times.
|
5961859 | for (size_t x = 0; x < width; ++x) { |
| 121 |
1/2✓ Branch 0 taken 5901096 times.
✗ Branch 1 not taken.
|
5901096 | acc += static_cast<Accumulator>(read_array<Value>(src, y * width + x)); |
| 122 | 5901096 | } | |
| 123 | |||
| 124 |
2/4✓ Branch 0 taken 60763 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60763 times.
✗ Branch 3 not taken.
|
60763 | write_array<Accumulator>(dst.data(), y, acc); |
| 125 | 60763 | } | |
| 126 | |||
| 127 | 521 | return dst; | |
| 128 | 521 | } | |
| 129 | |||
| 130 | template Buffer reduce_add_x<int8_t, int32_t>(const void* src, size_t height, size_t width); | ||
| 131 | |||
| 132 | template <typename T> | ||
| 133 | 521 | T reduce_min(const void* src, size_t len) { | |
| 134 | − | KAI_ASSUME_ALWAYS(len > 0); | |
| 135 | |||
| 136 | 521 | T min = read_array<T>(src, 0); | |
| 137 | |||
| 138 |
2/2✓ Branch 0 taken 521 times.
✓ Branch 1 taken 929318 times.
|
929839 | for (size_t i = 1; i < len; ++i) { |
| 139 | 929318 | min = std::min(min, read_array<T>(src, i)); | |
| 140 | 929318 | } | |
| 141 | |||
| 142 | 1042 | return min; | |
| 143 | 521 | } | |
| 144 | |||
| 145 | template float reduce_min(const void* src, size_t len); | ||
| 146 | |||
| 147 | template <typename T> | ||
| 148 | 521 | T reduce_max(const void* src, size_t len) { | |
| 149 | − | KAI_ASSUME_ALWAYS(len > 0); | |
| 150 | |||
| 151 | 521 | T max = read_array<T>(src, 0); | |
| 152 | |||
| 153 |
2/2✓ Branch 0 taken 521 times.
✓ Branch 1 taken 929318 times.
|
929839 | for (size_t i = 1; i < len; ++i) { |
| 154 | 929318 | max = std::max(max, read_array<T>(src, i)); | |
| 155 | 929318 | } | |
| 156 | |||
| 157 | 1042 | return max; | |
| 158 | 521 | } | |
| 159 | |||
| 160 | template float reduce_max(const void* src, size_t len); | ||
| 161 | |||
| 162 | } // namespace kai::test | ||
| 163 |