Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/reference/pad.hpp" | ||
8 | |||
9 | #include <cstddef> | ||
10 | #include <cstdint> | ||
11 | #include <cstring> | ||
12 | |||
13 | #include "kai/kai_common.h" | ||
14 | #include "test/common/buffer.hpp" | ||
15 | #include "test/common/data_type.hpp" | ||
16 | #include "test/common/memory.hpp" | ||
17 | #include "test/common/round.hpp" | ||
18 | |||
19 | namespace kai::test { | ||
20 | |||
21 | template <typename T> | ||
22 | 5010 | Buffer pad_row( | |
23 | const void* data, const size_t height, const size_t width, const size_t src_stride, const size_t dst_stride, | ||
24 | const size_t dst_size, const uint8_t val) { | ||
25 | 5010 | Buffer output(dst_size, val); | |
26 | |||
27 |
4/4✓ Branch 0 taken 106716 times.
✓ Branch 1 taken 2104 times.
✓ Branch 2 taken 218284 times.
✓ Branch 3 taken 2906 times.
|
330010 | for (size_t y = 0; y < height; ++y) { |
28 |
4/4✓ Branch 0 taken 5766948 times.
✓ Branch 1 taken 106716 times.
✓ Branch 2 taken 27812736 times.
✓ Branch 3 taken 218284 times.
|
33904684 | for (size_t x = 0; x < width; ++x) { |
29 |
2/4✓ Branch 0 taken 5766948 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 27812736 times.
✗ Branch 3 not taken.
|
33579684 | auto element = read_array<T>(data, (y * src_stride) + x); |
30 |
4/8✓ Branch 0 taken 5766948 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5766948 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 27812736 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 27812736 times.
✗ Branch 7 not taken.
|
33579684 | write_array<T>(output.data(), (y * dst_stride) + x, element); |
31 | 33579684 | } | |
32 | 325000 | } | |
33 | 5010 | return output; | |
34 | 5010 | } | |
35 | template Buffer pad_row<Int4>( | ||
36 | const void* data, const size_t height, const size_t width, const size_t src_stride, const size_t dst_stride, | ||
37 | const size_t dst_size, const uint8_t val); | ||
38 | |||
39 | template Buffer pad_row<UInt4>( | ||
40 | const void* data, const size_t height, const size_t width, const size_t src_stride, const size_t dst_stride, | ||
41 | const size_t dst_size, const uint8_t val); | ||
42 | |||
43 | template <typename T> | ||
44 | 1854 | Buffer pad_matrix( | |
45 | const void* data, size_t height, size_t width, size_t pad_left, size_t pad_top, size_t pad_right, size_t pad_bottom, | ||
46 | T pad_value) { | ||
47 | 1854 | const size_t dst_height = height + pad_top + pad_bottom; | |
48 | 1854 | const size_t dst_width = width + pad_left + pad_right; | |
49 | 1854 | const size_t dst_size = round_up_multiple(dst_height * dst_width * size_in_bits<T>, 8); | |
50 | |||
51 | 1854 | Buffer dst(dst_size); | |
52 | |||
53 |
4/4✓ Branch 0 taken 927 times.
✓ Branch 1 taken 927 times.
✓ Branch 2 taken 927 times.
✓ Branch 3 taken 927 times.
|
3708 | for (size_t row = 0; row < dst_height; ++row) { |
54 |
4/4✓ Branch 0 taken 927 times.
✓ Branch 1 taken 116832 times.
✓ Branch 2 taken 927 times.
✓ Branch 3 taken 116832 times.
|
235518 | for (size_t col = 0; col < dst_width; ++col) { |
55 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 116832 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 116832 times.
|
233664 | const bool valid_row = row >= pad_top && row < pad_top + height; |
56 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 116832 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 116832 times.
|
233664 | const bool valid_col = col >= pad_left && col < pad_left + width; |
57 |
6/8✓ Branch 0 taken 116832 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11454 times.
✓ Branch 3 taken 105378 times.
✓ Branch 4 taken 116832 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 11454 times.
✓ Branch 7 taken 105378 times.
|
233664 | if (valid_row && valid_col) { |
58 |
2/4✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 105378 times.
✗ Branch 3 not taken.
|
210756 | const T value = read_array<T>(data, (row - pad_top) * width + col - pad_left); |
59 |
4/8✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 105378 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 105378 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 105378 times.
✗ Branch 7 not taken.
|
210756 | write_array<T>(dst.data(), row * dst_width + col, value); |
60 | 210756 | } else { | |
61 |
4/8✓ Branch 0 taken 11454 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11454 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11454 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 11454 times.
✗ Branch 7 not taken.
|
22908 | write_array<T>(dst.data(), row * dst_width + col, pad_value); |
62 | } | ||
63 | 233664 | } | |
64 | 1854 | } | |
65 | |||
66 | 1854 | return dst; | |
67 | 1854 | } | |
68 | |||
69 | template Buffer pad_matrix( | ||
70 | const void* data, size_t height, size_t width, size_t pad_left, size_t pad_top, size_t pad_right, size_t pad_bottom, | ||
71 | float pad_value); | ||
72 | template Buffer pad_matrix( | ||
73 | const void* data, size_t height, size_t width, size_t pad_left, size_t pad_top, size_t pad_right, size_t pad_bottom, | ||
74 | int32_t pad_value); | ||
75 | |||
76 | } // namespace kai::test | ||
77 |