Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) | ||
8 | #error This file must be compiled for AArch64, FEAT_SVE2. | ||
9 | #else // Architectural features check. | ||
10 | |||
11 | #include "kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h" | ||
12 | |||
13 | #include <stddef.h> | ||
14 | #include <stdint.h> | ||
15 | |||
16 | #include "kai/kai_common.h" | ||
17 | |||
18 | enum { | ||
19 | NR = 2, | ||
20 | KR = 2, | ||
21 | SR = 1, | ||
22 | NUM_BYTES_DATA = 2, | ||
23 | NUM_BYTES_BIAS = 2, | ||
24 | MAX_BLOCK_HEIGHT = (NR * (KAI_SME_VEC_LENGTH_MAX_BYTES / NUM_BYTES_DATA) / KR), | ||
25 | }; | ||
26 | |||
// Declaration of the low-level packing routine invoked once per block of rows.
// Its definition is not part of the visible source.
// NOTE(review): presumably implemented in a separate assembly file - confirm.
//
// height  Number of rows in the block being packed.
// width   Row width; the caller passes `k`.
// in      Array of `height` pointers, one per source row.
// out     Destination for the packed block.
// bias    Pointer to the bias values for the block.
void kai_kernel_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(
    size_t height,
    size_t width,
    const void* in,
    void* out,
    const void* bias);
29 | |||
30 | 206 | static size_t get_block_height(void) { | |
31 | 206 | const size_t block_height = NR * kai_get_sme_vector_length_u16() / KR; | |
32 | 412 | return block_height; | |
33 | 206 | } | |
34 | |||
// Returns the step in the N dimension for this packing kernel, which is the block height.
size_t kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(void) {
    const size_t n_step = get_block_height();
    return n_step;
}
38 | |||
// Returns the offset of row `n_idx` in the unpacked RHS, computed as n_idx * rhs_stride.
//
// n_idx       Row index; must be a multiple of the block height.
// rhs_stride  Stride between consecutive RHS rows; the result is in the same unit.
size_t kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t n_idx, size_t rhs_stride) {
    KAI_ASSUME(n_idx % get_block_height() == 0);

    const size_t rhs_offset = rhs_stride * n_idx;
    return rhs_offset;
}
44 | |||
45 | ✗ | size_t kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t n_idx) { | |
46 | − | KAI_ASSUME(n_idx % get_block_height() == 0); | |
47 | |||
48 | ✗ | return n_idx * NUM_BYTES_BIAS; | |
49 | } | ||
50 | |||
51 | 68 | size_t kai_get_rhs_packed_stride_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t k) { | |
52 | 68 | return NUM_BYTES_BIAS + kai_roundup(k, KR) * NUM_BYTES_DATA; | |
53 | } | ||
54 | |||
// Returns the byte offset of row `n_idx` in the packed RHS,
// i.e. n_idx multiplied by the packed row stride for the given `k`.
//
// n_idx  Row index; must be a multiple of the block height.
// k      Row width used to derive the packed row stride.
size_t kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t n_idx, size_t k) {
    KAI_ASSUME(n_idx % get_block_height() == 0);

    const size_t packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(k);
    return n_idx * packed_stride;
}
60 | |||
// Returns the size in bytes of the packed RHS buffer:
// `n` rounded up to a whole number of blocks, multiplied by the packed row stride for `k`.
size_t kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t n, size_t k) {
    const size_t packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(k);
    return kai_roundup(n, get_block_height()) * packed_stride;
}
64 | |||
65 | 34 | void kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme( | |
66 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, | ||
67 | const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { | ||
68 | − | KAI_ASSUME(num_groups == 1); | |
69 | − | KAI_ASSUME(nr == get_block_height()); | |
70 | − | KAI_ASSUME(kr == KR); | |
71 | − | KAI_ASSUME(sr == SR); | |
72 | − | KAI_ASSUME(rhs != NULL); | |
73 | − | KAI_ASSUME(bias != NULL); | |
74 | − | KAI_ASSUME(scale == NULL); | |
75 | − | KAI_ASSUME(rhs_packed != NULL); | |
76 | − | KAI_ASSUME(extra_bytes == 0); | |
77 | − | KAI_ASSUME(params == NULL); | |
78 | |||
79 | 34 | const size_t block_height = get_block_height(); | |
80 | 34 | const size_t width = k; | |
81 | |||
82 | 34 | const uint8_t* in[MAX_BLOCK_HEIGHT]; | |
83 | 34 | uint8_t* rhs_packed_ptr = rhs_packed; | |
84 | 34 | const uint8_t* rhs_ptr = rhs; | |
85 | 34 | const uint8_t* bias_ptr = bias; | |
86 | |||
87 |
2/2✓ Branch 0 taken 34 times.
✓ Branch 1 taken 42 times.
|
76 | for (size_t block_y = 0; block_y < n; block_y += block_height) { |
88 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 18 times.
|
42 | const size_t height = KAI_MIN(n - block_y, block_height); |
89 | 42 | uint8_t* out = rhs_packed_ptr + block_y * (NUM_BYTES_BIAS + kai_roundup(k, KR) * NUM_BYTES_DATA); | |
90 | |||
91 |
2/2✓ Branch 0 taken 892 times.
✓ Branch 1 taken 42 times.
|
934 | for (size_t y = 0; y < height; y++) { |
92 | 892 | in[y] = rhs_ptr + (block_y + y) * rhs_stride; | |
93 | 892 | } | |
94 | |||
95 | 42 | kai_kernel_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme( | |
96 | 42 | height, width, in, out, bias); // NOLINT(bugprone-multi-level-implicit-pointer-conversion) | |
97 | |||
98 | 42 | bias_ptr += height * NUM_BYTES_BIAS; | |
99 | 42 | bias = bias_ptr; | |
100 | 42 | } | |
101 | 34 | } | |
102 | |||
103 | #endif // Architectural features check. | ||
104 |