Line |
Branch |
Exec |
Source |
1 |
|
|
// |
2 |
|
|
// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> |
3 |
|
|
// |
4 |
|
|
// SPDX-License-Identifier: Apache-2.0 |
5 |
|
|
// |
6 |
|
|
|
7 |
|
|
#if !defined(__aarch64__) |
8 |
|
|
#error This file must be compiled for AArch64. |
9 |
|
|
#else // Architectural features check. |
10 |
|
|
|
11 |
|
|
#include "kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h" |
12 |
|
|
|
13 |
|
|
#include <stddef.h> |
14 |
|
|
#include <stdint.h> |
15 |
|
|
|
16 |
|
|
#include "kai/kai_common.h" |
17 |
|
|
|
18 |
|
|
// Column block width of the packed layout: returned as the N step, required as
// the `nr` argument of the run function, and the multiple that n_idx must be.
static const size_t kai_nr = 8; |
19 |
|
|
// Value that the `kr` argument of the run function is required to equal.
static const size_t kai_kr = 1; |
20 |
|
|
|
21 |
|
17 |
size_t kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(void) { |
22 |
|
17 |
return kai_nr; |
23 |
|
|
} |
24 |
|
|
|
25 |
|
17 |
size_t kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx) { |
26 |
|
− |
KAI_ASSUME(n_idx % kai_nr == 0); |
27 |
|
|
|
28 |
|
17 |
return n_idx * sizeof(uint32_t); |
29 |
|
|
} |
30 |
|
|
|
31 |
|
17 |
// Returns the byte offset of entry n_idx within the bias vector.
//
// The offset is n_idx scaled by the size of one 32-bit element. Unlike the RHS
// offset helper, no alignment of n_idx is checked here.
size_t kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx) {
    const size_t bias_element_bytes = sizeof(uint32_t);
    return bias_element_bytes * n_idx;
}
34 |
|
|
|
35 |
|
34 |
size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx, size_t k) { |
36 |
|
− |
KAI_ASSUME(n_idx % kai_nr == 0); |
37 |
|
|
|
38 |
|
34 |
return n_idx * (sizeof(uint32_t) + k * sizeof(uint32_t)); |
39 |
|
|
} |
40 |
|
|
|
41 |
|
17 |
size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n, size_t k) { |
42 |
|
17 |
return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(kai_roundup(n, kai_nr), k); |
43 |
|
|
} |
44 |
|
|
|
45 |
|
17 |
void kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon( |
46 |
|
|
size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, |
47 |
|
|
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { |
48 |
|
− |
KAI_ASSUME(num_groups == 1); |
49 |
|
− |
KAI_ASSUME(nr == kai_nr); |
50 |
|
− |
KAI_ASSUME(kr == kai_kr); |
51 |
|
− |
KAI_ASSUME(sr == 1); |
52 |
|
− |
KAI_ASSUME(rhs != NULL); |
53 |
|
− |
KAI_ASSUME(bias != NULL); |
54 |
|
− |
KAI_ASSUME(scale == NULL); |
55 |
|
− |
KAI_ASSUME(rhs_packed != NULL); |
56 |
|
− |
KAI_ASSUME(extra_bytes == 0); |
57 |
|
− |
KAI_ASSUME(params == NULL); |
58 |
|
|
|
59 |
|
17 |
size_t height = k; |
60 |
|
17 |
const size_t width = n; |
61 |
|
17 |
const void* in = rhs; |
62 |
|
17 |
void* out = rhs_packed; |
63 |
|
17 |
const size_t in_stride = rhs_stride; |
64 |
|
17 |
size_t out_stride = kai_nr * height * sizeof(uint32_t) + kai_nr * sizeof(uint32_t); |
65 |
|
|
|
66 |
|
34 |
__asm__ __volatile__( |
67 |
|
|
"mov x22, %x[width]\n" |
68 |
|
|
"mov x21, %x[out]\n" |
69 |
|
|
"cmp x22, #0x8\n" |
70 |
|
|
"blt 2f\n" |
71 |
|
|
"1:" // Bias: Full loop |
72 |
|
|
"ldr q17, [%x[bias], #0x0]\n" |
73 |
|
|
"ldr q16, [%x[bias], #0x10]\n" |
74 |
|
|
"sub x22, x22, #0x8\n" |
75 |
|
|
"add %x[bias], %x[bias], #0x20\n" |
76 |
|
|
"cmp x22, #0x8\n" |
77 |
|
|
"str q17, [x21, #0x0]\n" |
78 |
|
|
"str q16, [x21, #0x10]\n" |
79 |
|
|
"add x21, x21, %x[out_stride]\n" |
80 |
|
|
"bge 1b\n" |
81 |
|
|
"cbz x22, 3f\n" |
82 |
|
|
"2:" // Bias: Tail loop |
83 |
|
|
"ldr w20, [%x[bias], #0x0]\n" |
84 |
|
|
"sub x22, x22, #0x1\n" |
85 |
|
|
"add %x[bias], %x[bias], #0x4\n" |
86 |
|
|
"cmp x22, #0x0\n" |
87 |
|
|
"str x20, [x21]\n" |
88 |
|
|
"add x21, x21, #0x4\n" |
89 |
|
|
"bgt 2b\n" |
90 |
|
|
"3:" // Bias: Done |
91 |
|
|
"cmp %x[height], #0x4\n" |
92 |
|
|
"add %x[out], %x[out], #0x20\n" |
93 |
|
|
"blt 12f\n" |
94 |
|
|
"4:" // Main row loop: Head |
95 |
|
|
"mov x25, %x[in]\n" |
96 |
|
|
"mov x24, %x[width]\n" |
97 |
|
|
"mov x23, %x[out]\n" |
98 |
|
|
"sub %x[height], %x[height], #0x4\n" |
99 |
|
|
"add x22, x25, %x[in_stride]\n" |
100 |
|
|
"add x21, x22, %x[in_stride]\n" |
101 |
|
|
"add x20, x21, %x[in_stride]\n" |
102 |
|
|
"cmp x24, #0x8\n" |
103 |
|
|
"add %x[in], x20, %x[in_stride]\n" |
104 |
|
|
"blt 6f\n" |
105 |
|
|
"5:" // Main row loop: Column loop |
106 |
|
|
"ldr q23, [x25], #0x10\n" |
107 |
|
|
"ldr q22, [x22], #0x10\n" |
108 |
|
|
"sub x24, x24, #0x8\n" |
109 |
|
|
"ldr q21, [x21], #0x10\n" |
110 |
|
|
"ldr q20, [x20], #0x10\n" |
111 |
|
|
"cmp x24, #0x8\n" |
112 |
|
|
"ldr q19, [x25], #0x10\n" |
113 |
|
|
"ldr q18, [x22], #0x10\n" |
114 |
|
|
"ldr q17, [x21], #0x10\n" |
115 |
|
|
"ldr q16, [x20], #0x10\n" |
116 |
|
|
"str q23, [x23, #0x0]\n" |
117 |
|
|
"str q19, [x23, #0x10]\n" |
118 |
|
|
"str q22, [x23, #0x20]\n" |
119 |
|
|
"str q18, [x23, #0x30]\n" |
120 |
|
|
"str q21, [x23, #0x40]\n" |
121 |
|
|
"str q17, [x23, #0x50]\n" |
122 |
|
|
"str q20, [x23, #0x60]\n" |
123 |
|
|
"str q16, [x23, #0x70]\n" |
124 |
|
|
"add x23, x23, %x[out_stride]\n" |
125 |
|
|
"bge 5b\n" |
126 |
|
|
"6:" // Main row loop: Column loop skip |
127 |
|
|
"cbz x24, 11f\n" |
128 |
|
|
"cmp x24, #0x4\n" |
129 |
|
|
"movi v16.4s, #0x0\n" |
130 |
|
|
"str q16, [x23, #0x0]\n" |
131 |
|
|
"str q16, [x23, #0x10]\n" |
132 |
|
|
"str q16, [x23, #0x20]\n" |
133 |
|
|
"str q16, [x23, #0x30]\n" |
134 |
|
|
"str q16, [x23, #0x40]\n" |
135 |
|
|
"str q16, [x23, #0x50]\n" |
136 |
|
|
"str q16, [x23, #0x60]\n" |
137 |
|
|
"str q16, [x23, #0x70]\n" |
138 |
|
|
"blt 8f\n" |
139 |
|
|
"7:" // Main row loop: width 4 loop: loop |
140 |
|
|
"ldr q19, [x25], #0x10\n" |
141 |
|
|
"ldr q18, [x22], #0x10\n" |
142 |
|
|
"sub x24, x24, #0x4\n" |
143 |
|
|
"ldr q17, [x21], #0x10\n" |
144 |
|
|
"ldr q16, [x20], #0x10\n" |
145 |
|
|
"cmp x24, #0x4\n" |
146 |
|
|
"str q19, [x23, #0x0]\n" |
147 |
|
|
"str q18, [x23, #0x20]\n" |
148 |
|
|
"str q17, [x23, #0x40]\n" |
149 |
|
|
"str q16, [x23, #0x60]\n" |
150 |
|
|
"add x23, x23, #0x10\n" |
151 |
|
|
"bge 7b\n" |
152 |
|
|
"8:" // Main row loop: width 4 loop: skip |
153 |
|
|
"cmp x24, #0x1\n" |
154 |
|
|
"blt 10f\n" |
155 |
|
|
"9:" // Main row loop: width 1 loop: loop |
156 |
|
|
"ldr s19, [x25], #0x4\n" |
157 |
|
|
"ldr s18, [x22], #0x4\n" |
158 |
|
|
"sub x24, x24, #0x1\n" |
159 |
|
|
"ldr s17, [x21], #0x4\n" |
160 |
|
|
"ldr s16, [x20], #0x4\n" |
161 |
|
|
"cmp x24, #0x1\n" |
162 |
|
|
"str s19, [x23, #0x0]\n" |
163 |
|
|
"str s18, [x23, #0x20]\n" |
164 |
|
|
"str s17, [x23, #0x40]\n" |
165 |
|
|
"str s16, [x23, #0x60]\n" |
166 |
|
|
"add x23, x23, #0x4\n" |
167 |
|
|
"bge 9b\n" |
168 |
|
|
"10:" // Main row loop: width 1 loop: skip |
169 |
|
|
"11:" // Main row loop: odd col skip |
170 |
|
|
"cmp %x[height], #0x4\n" |
171 |
|
|
"add %x[out], %x[out], #0x80\n" |
172 |
|
|
"bge 4b\n" |
173 |
|
|
"cbz %x[height], 21f\n" |
174 |
|
|
"12:" // Main loop skip |
175 |
|
|
"13:" // Tail row loop: Head |
176 |
|
|
"mov x20, %x[width]\n" |
177 |
|
|
"mov x25, %x[in]\n" |
178 |
|
|
"mov x23, %x[out]\n" |
179 |
|
|
"sub %x[height], %x[height], #0x1\n" |
180 |
|
|
"cmp x20, #0x8\n" |
181 |
|
|
"add %x[in], x25, %x[in_stride]\n" |
182 |
|
|
"blt 15f\n" |
183 |
|
|
"14:" // Tail row loop: Column loop |
184 |
|
|
"ldr q17, [x25], #0x10\n" |
185 |
|
|
"sub x20, x20, #0x8\n" |
186 |
|
|
"ldr q16, [x25], #0x10\n" |
187 |
|
|
"cmp x20, #0x8\n" |
188 |
|
|
"str q17, [x23, #0x0]\n" |
189 |
|
|
"str q16, [x23, #0x10]\n" |
190 |
|
|
"add x23, x23, %x[out_stride]\n" |
191 |
|
|
"bge 14b\n" |
192 |
|
|
"15:" // Tail row loop: Column loop skip |
193 |
|
|
"cbz x20, 20f\n" |
194 |
|
|
"cmp x20, #0x4\n" |
195 |
|
|
"movi v16.4s, #0x0\n" |
196 |
|
|
"str q16, [x23, #0x0]\n" |
197 |
|
|
"str q16, [x23, #0x10]\n" |
198 |
|
|
"blt 17f\n" |
199 |
|
|
"16:" // Tail row loop: width 4 loop: loop |
200 |
|
|
"ldr q16, [x25], #0x10\n" |
201 |
|
|
"sub x20, x20, #0x4\n" |
202 |
|
|
"cmp x20, #0x4\n" |
203 |
|
|
"str q16, [x23, #0x0]\n" |
204 |
|
|
"add x23, x23, #0x10\n" |
205 |
|
|
"bge 16b\n" |
206 |
|
|
"17:" // Tail row loop: width 4 loop: skip |
207 |
|
|
"cmp x20, #0x1\n" |
208 |
|
|
"blt 19f\n" |
209 |
|
|
"18:" // Tail row loop: width 1 loop: loop |
210 |
|
|
"ldr s16, [x25], #0x4\n" |
211 |
|
|
"sub x20, x20, #0x1\n" |
212 |
|
|
"cmp x20, #0x1\n" |
213 |
|
|
"str s16, [x23, #0x0]\n" |
214 |
|
|
"add x23, x23, #0x4\n" |
215 |
|
|
"bge 18b\n" |
216 |
|
|
"19:" // Tail row loop: width 1 loop: skip |
217 |
|
|
"20:" // Tail row loop: odd col skip |
218 |
|
|
"cmp %x[height], #0x1\n" |
219 |
|
|
"add %x[out], %x[out], #0x20\n" |
220 |
|
|
"bge 13b\n" |
221 |
|
|
"21:" // Done |
222 |
|
|
: [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) |
223 |
|
17 |
: [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [width] "r"(width) |
224 |
|
|
: "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24", |
225 |
|
|
"x25"); |
226 |
|
17 |
} |
227 |
|
|
|
228 |
|
|
#endif // Architectural features check. |
229 |
|
|
|