Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/reference/clamp.hpp" | ||
8 | |||
9 | #include <algorithm> | ||
10 | #include <cstddef> | ||
11 | |||
12 | #include "kai/kai_common.h" | ||
13 | #include "test/common/buffer.hpp" | ||
14 | #include "test/common/float16.hpp" | ||
15 | #include "test/common/memory.hpp" | ||
16 | #include "test/common/numeric_limits.hpp" | ||
17 | #include "test/common/round.hpp" | ||
18 | |||
19 | namespace kai::test { | ||
20 | |||
21 | template <typename T> | ||
22 | 8520 | std::tuple<T, T> find_clamp_range(const void* src, size_t len, float ratio) { | |
23 | − | KAI_ASSUME(ratio > 0.0F); | |
24 | − | KAI_ASSUME(ratio <= 1.0F); | |
25 | |||
26 | 8520 | T min_value = numeric_highest<T>; | |
27 | 8520 | T max_value = numeric_lowest<T>; | |
28 | |||
29 |
2/4✓ Branch 0 taken 8520 times.
✓ Branch 1 taken 12332220 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
12340740 | for (size_t i = 0; i < len; ++i) { |
30 | 12332220 | const T value = read_array<T>(src, i); | |
31 | |||
32 | 12332220 | min_value = std::min(min_value, value); | |
33 | 12332220 | max_value = std::max(max_value, value); | |
34 | 12332220 | } | |
35 | |||
36 | 8520 | min_value = std::max(min_value, numeric_lowest<T>); | |
37 | 8520 | max_value = std::min(max_value, numeric_highest<T>); | |
38 | |||
39 | 8520 | const T range = max_value - min_value; | |
40 | 8520 | const T reduction = static_cast<T>(static_cast<float>(range) * (1.0F - ratio) / 2); | |
41 | |||
42 | 8520 | const T clamp_min_value = min_value + reduction; | |
43 | 8520 | const T clamp_max_value = max_value - reduction; | |
44 | |||
45 | 8520 | return {clamp_min_value, clamp_max_value}; | |
46 | 8520 | } | |
47 | |||
48 | template std::tuple<float, float> find_clamp_range(const void* src, size_t len, float ratio); | ||
49 | template std::tuple<Float16, Float16> find_clamp_range(const void* src, size_t len, float ratio); | ||
50 | |||
51 | 2429 | std::tuple<float, float> find_clamp_range(DataType type, const void* src, size_t len, float ratio) { | |
52 | 2429 | auto max = std::numeric_limits<float>::min(); | |
53 | 2429 | auto min = std::numeric_limits<float>::max(); | |
54 | |||
55 |
2/2✓ Branch 0 taken 2429 times.
✓ Branch 1 taken 34287499 times.
|
34289928 | for (size_t i = 0; i < len; i += 1) { |
56 | 34287499 | const float value = read_array(type, src, i); | |
57 | 34287499 | max = std::max(value, max); | |
58 | 34287499 | min = std::min(value, min); | |
59 | 34287499 | } | |
60 | |||
61 | 2429 | const float reduction = (max - min) * (1.0F - ratio) / 2.0F; | |
62 | 2429 | return {min + reduction, max - reduction}; | |
63 | 2429 | } | |
64 | |||
65 | template <typename T> | ||
66 | 9447 | Buffer clamp(const void* src, size_t len, T min_value, T max_value) { | |
67 | 9447 | Buffer dst(round_up_division(len * size_in_bits<T>, 8)); | |
68 | |||
69 |
2/4✓ Branch 0 taken 14042046 times.
✓ Branch 1 taken 9447 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
14051493 | for (size_t i = 0; i < len; ++i) { |
70 |
4/16✗ Branch 0 not taken.
✓ Branch 1 taken 14042046 times.
✓ Branch 2 taken 14042046 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14042046 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 14042046 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
|
14042046 | write_array<T>(dst.data(), i, std::clamp(read_array<T>(src, i), min_value, max_value)); |
71 | 14042046 | } | |
72 | |||
73 | 9447 | return dst; | |
74 | 9447 | } | |
75 | |||
76 | template Buffer clamp(const void* src, size_t len, float min_value, float max_value); | ||
77 | template Buffer clamp(const void* src, size_t len, Float16 min_value, Float16 max_value); | ||
78 | |||
79 | 2429 | Buffer clamp(DataType type, const void* src, size_t len, float min_value, float max_value) { | |
80 | 2429 | Buffer dst(round_up_division(len * data_type_size_in_bits(type), 8)); | |
81 | |||
82 |
2/2✓ Branch 0 taken 34287499 times.
✓ Branch 1 taken 2429 times.
|
34289928 | for (size_t i = 0; i < len; ++i) { |
83 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 34287499 times.
✓ Branch 2 taken 34287499 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 34287499 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 34287499 times.
✗ Branch 7 not taken.
|
34287499 | write_array(type, dst.data(), i, std::clamp<float>(read_array(type, src, i), min_value, max_value)); |
84 | 34287499 | } | |
85 | |||
86 | 2429 | return dst; | |
87 | 2429 | } | |
88 | |||
89 | } // namespace kai::test | ||
90 |