Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "kai_rhs_dwconv_pack_x32p1vlx1b_x32_x32_sme.h" | ||
8 | |||
9 | #include <stdint.h> | ||
10 | #include <string.h> | ||
11 | |||
12 | #include "kai/kai_common.h" | ||
13 | |||
// Returns the size, in bytes, required for the packed RHS buffer.
//
// The channel count is first passed through kai_roundup() together with the
// value reported by kai_get_sme_vector_length_u32(). Each resulting channel
// slot then holds (filter_height * filter_width + 1) values of sizeof(float)
// bytes: one per filter tap plus one extra row, which the packing routine
// fills with the bias.
size_t kai_rhs_get_dst_size_dwconv_pack_x32p1vlx1b_x32_x32_sme(
    size_t filter_height, size_t filter_width, size_t num_channels) {
    const size_t padded_channels = kai_roundup(num_channels, kai_get_sme_vector_length_u32());
    const size_t rows_per_channel = filter_height * filter_width + 1;  // filter taps + 1 bias row

    return padded_channels * rows_per_channel * sizeof(float);
}
19 | |||
// Packs the bias and the depthwise-convolution filter into rhs_packed.
//
// Channels are processed in tiles of `lanes` channels, where `lanes` is the
// value returned by kai_get_sme_vector_length_u32(). For every tile the output
// receives, back to back:
//   - one row holding the bias values of the tile's channels, then
//   - filter_height * filter_width rows, one per filter tap, each holding that
//     tap's values for the tile's channels.
// Every output row is `lanes` 32-bit elements wide. When fewer than `lanes`
// channels remain (last tile), only the valid channels are copied and the rest
// of the row is left unwritten.
//
// The source filter is read as [tap][channel] with a row stride of
// num_channels elements. `height` and `width` are not used by this routine.
void kai_run_rhs_dwconv_pack_x32p1vlx1b_x32_x32_sme(
    size_t filter_height, size_t filter_width, size_t height, size_t width, size_t num_channels, const void* rhs,
    const void* bias, void* rhs_packed) {
    KAI_ASSUME(rhs != NULL);
    KAI_ASSUME(rhs_packed != NULL);
    KAI_ASSUME(bias != NULL);
    KAI_UNUSED(height);
    KAI_UNUSED(width);

    const size_t elem_bytes = sizeof(float);
    const size_t lanes = kai_get_sme_vector_length_u32();
    const size_t row_stride_bytes = lanes * elem_bytes;      // one full output row
    const size_t num_taps = filter_height * filter_width;    // filter rows per tile

    // Work with byte pointers so that all offsets below are expressed in bytes.
    const uint8_t* weights = (const uint8_t*)rhs;
    const uint8_t* bias_in = (const uint8_t*)bias;
    uint8_t* out = (uint8_t*)rhs_packed;

    size_t first_channel = 0;
    while (first_channel < num_channels) {
        // Number of valid channels in this tile, expressed in bytes.
        const size_t remaining = num_channels - first_channel;
        const size_t copy_bytes = ((remaining > lanes) ? lanes : remaining) * elem_bytes;

        // Bias row: the input is densely packed, the output advances a full row.
        memcpy(out, bias_in, copy_bytes);
        out += row_stride_bytes;
        bias_in += copy_bytes;

        // One output row per filter tap.
        for (size_t tap = 0; tap < num_taps; ++tap) {
            const uint8_t* tap_src = weights + ((tap * num_channels + first_channel) * elem_bytes);
            memcpy(out, tap_src, copy_bytes);
            out += row_stride_bytes;
        }

        first_channel += lanes;
    }
}
50 |