test/nextgen/reference/unary_elementwise.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | | | `//` |
| 2 | | | `// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>` |
| 3 | | | `//` |
| 4 | | | `// SPDX-License-Identifier: Apache-2.0` |
| 5 | | | `//` |
| 6 | | | |
| 7 | | | `#include "test/nextgen/reference/unary_elementwise.hpp"` |
| 8 | | | |
| 9 | | | `#include <cstddef>` |
| 10 | | | `#include <cstdint>` |
| 11 | | | `#include <functional>` |
| 12 | | | `#include <numeric>` |
| 13 | | | |
| 14 | | | `#include "test/common/assert.hpp"` |
| 15 | | | `#include "test/common/buffer.hpp"` |
| 16 | | | `#include "test/common/data_type.hpp"` |
| 17 | | | `#include "test/common/int4.hpp"` |
| 18 | | | `#include "test/common/memory.hpp"` |
| 19 | | | `#include "test/common/round.hpp"` |
| 20 | | | `#include "test/common/span.hpp"` |
| 21 | | | `#include "test/common/type_traits.hpp"` |
| 22 | | | |
| 23 | | | `namespace kai::test {` |
| 24 | | | |
| 25 | | | `namespace {` |
| 26 | | | |
| 27 | | | `template <typename Op>` |
| 28 | | 400 | `[[nodiscard]] Buffer unary_elementwise(Span<const size_t> shape, Span<const std::byte> data) {` |
| 29 | | | `using Type = typename Op::Type;` |
| 30 | | | |
| 31 | | 400 | `const size_t width = shape.at(shape.size() - 1);` |
| 32 | | 400 | `const size_t row_size = round_up_division(width * size_in_bits<Type>, 8);` |
| 33 | | 400 | `const size_t num_rows = std::accumulate(shape.begin(), shape.end() - 1, 1, std::multiplies<>());` |
| 34 | | 400 | `const size_t size = num_rows * row_size;` |
| 35 | | | |
| 36 | | 400 | `Buffer output(size, 0);` |
| 37 | | | |
| 38 | 4/4<br>✓ Branch 0 taken 15452 times.<br>✓ Branch 1 taken 200 times.<br>✓ Branch 2 taken 15923 times.<br>✓ Branch 3 taken 200 times. | 31775 | `for (size_t row = 0; row < num_rows; ++row) {` |
| 39 | 2/4<br>✓ Branch 0 taken 15452 times.<br>✗ Branch 1 not taken.<br>✓ Branch 2 taken 15923 times.<br>✗ Branch 3 not taken. | 31375 | `const Span<const std::byte> src_row_data = data.subspan(row * row_size, row_size);` |
| 40 | 4/8<br>✓ Branch 0 taken 15452 times.<br>✗ Branch 1 not taken.<br>✓ Branch 2 taken 15452 times.<br>✗ Branch 3 not taken.<br>✓ Branch 4 taken 15923 times.<br>✗ Branch 5 not taken.<br>✓ Branch 6 taken 15923 times.<br>✗ Branch 7 not taken. | 31375 | `const Span<std::byte> dst_row_data = Span<std::byte>(output).subspan(row * row_size, row_size);` |
| 41 | | | |
| 42 | 4/4<br>✓ Branch 0 taken 15452 times.<br>✓ Branch 1 taken 15452 times.<br>✓ Branch 2 taken 1272221 times.<br>✓ Branch 3 taken 15923 times. | 1319048 | `for (size_t col = 0; col < width; ++col) {` |
| 43 | 2/4<br>✓ Branch 0 taken 15452 times.<br>✗ Branch 1 not taken.<br>✓ Branch 2 taken 1272221 times.<br>✗ Branch 3 not taken. | 1287673 | `const Type src_value = read_array<Type>(src_row_data, col);` |
| 44 | 2/4<br>✓ Branch 0 taken 15452 times.<br>✗ Branch 1 not taken.<br>✓ Branch 2 taken 1272221 times.<br>✗ Branch 3 not taken. | 1287673 | `const Type dst_value = Op::compute(src_value);` |
| 45 | 2/4<br>✓ Branch 0 taken 15452 times.<br>✗ Branch 1 not taken.<br>✓ Branch 2 taken 1272221 times.<br>✗ Branch 3 not taken. | 1287673 | `write_array<Type>(dst_row_data, col, dst_value);` |
| 46 | | 1287673 | `}` |
| 47 | | 31375 | `}` |
| 48 | | | |
| 49 | | 400 | `return output;` |
| 50 | | 400 | `}` |
| 51 | | | |
| 52 | | | `template <typename T>` |
| 53 | | | `struct NegateOp {` |
| 54 | | | `using Type = T;` |
| 55 | | | |
| 56 | | 15452 | `[[nodiscard]] static T compute(T value) {` |
| 57 | | 15452 | `return -value;` |
| 58 | | | `}` |
| 59 | | | `};` |
| 60 | | | |
| 61 | | | `template <typename T>` |
| 62 | | | `struct ChangeSignednessOp {` |
| 63 | | | `using Type = T;` |
| 64 | | | |
| 65 | | 1272221 | `[[nodiscard]] static T compute(T value) {` |
| 66 | | | `static_assert(is_integral<T>);` |
| 67 | | | `static_assert(sizeof(T) < sizeof(uint64_t));` |
| 68 | | | |
| 69 | | 1272221 | `constexpr T mid_point = static_cast<T>(static_cast<uint64_t>(1) << (size_in_bits<T> - 1));` |
| 70 | | | |
| 71 | | 1272221 | `return value + mid_point;` |
| 72 | | 1272221 | `}` |
| 73 | | | `};` |
| 74 | | | |
| 75 | | | `} // namespace` |
| 76 | | | |
| 77 | | 200 | `UnaryElementwiseFn make_negate(DataType dtype) {` |
| 78 | 1/2<br>✓ Branch 0 taken 200 times.<br>✗ Branch 1 not taken. | 200 | `switch (dtype) {` |
| 79 | | | `case DataType::I32:` |
| 80 | | 200 | `return unary_elementwise<NegateOp<int32_t>>;` |
| 81 | | | |
| 82 | | | `default:` |
| 83 | | ✗ | `KAI_TEST_ERROR("Not supported.");` |
| 84 | | | `}` |
| 85 | | ✗ | `}` |
| 86 | | | |
| 87 | | 200 | `UnaryElementwiseFn make_change_signedness(DataType dtype) {` |
| 88 | 1/2<br>✓ Branch 0 taken 200 times.<br>✗ Branch 1 not taken. | 200 | `switch (dtype) {` |
| 89 | | | `case DataType::U4:` |
| 90 | | | `case DataType::I4:` |
| 91 | | 200 | `return unary_elementwise<ChangeSignednessOp<UInt4>>;` |
| 92 | | | |
| 93 | | | `default:` |
| 94 | | ✗ | `KAI_TEST_ERROR("Not supported.");` |
| 95 | | | `}` |
| 96 | | ✗ | `}` |
| 97 | | | |
| 98 | | | `} // namespace kai::test` |
| 99 | | | |