Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
// Silence -Woverlength-strings: the inline assembly template below is a single string literal longer than ISO C requires compilers to support.
8 | #pragma GCC diagnostic ignored "-Woverlength-strings" | ||
9 | |||
10 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || \ | ||
11 | !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) | ||
12 | #error This file must be compiled for AArch64, FEAT_BF16, FEAT_FP16. | ||
13 | #else // Architectural features check. | ||
14 | |||
15 | #include "kai_lhs_pack_bf16p8x4_f16_neon.h" | ||
16 | |||
17 | #include <stddef.h> | ||
18 | #include <stdint.h> | ||
19 | |||
20 | #include "kai/kai_common.h" | ||
21 | |||
// Packing geometry of this kernel.
//
// Declared as enumerators rather than `static const size_t` objects so that they are integer
// constant expressions. This keeps `in[kai_mr]` in the packing routine an ordinary fixed-size
// array instead of a variable-length array (VLAs are optional since C11 and trip -Wvla).
enum {
    kai_mr = 8,  // Rows per packed block.
    kai_kr = 4,  // Consecutive K values stored together for each row in a packed block.
    kai_sr = 1,  // Only sr == 1 is accepted by the public functions of this file.
};
25 | |||
26 | ✗ | size_t kai_get_m_step_lhs_pack_bf16p8x4_f16_neon(size_t mr) { | |
27 | − | KAI_ASSUME(mr == kai_mr); | |
28 | |||
29 | ✗ | return kai_mr; | |
30 | } | ||
31 | |||
32 | 176 | size_t kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon(size_t m_idx, size_t lhs_stride) { | |
33 | − | KAI_ASSUME(m_idx % (kai_mr) == 0); | |
34 | |||
35 | 176 | return m_idx * lhs_stride; | |
36 | } | ||
37 | |||
38 | ✗ | size_t kai_get_lhs_packed_offset_lhs_pack_bf16p8x4_f16_neon(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | |
39 | − | KAI_ASSUME(m_idx % kai_mr == 0); | |
40 | − | KAI_ASSUME(mr == kai_mr); | |
41 | − | KAI_ASSUME(kr == kai_kr); | |
42 | − | KAI_ASSUME(sr == kai_sr); | |
43 | |||
44 | ✗ | return m_idx * kai_roundup(k, kai_kr) * sizeof(uint16_t); | |
45 | } | ||
46 | |||
47 | 176 | size_t kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
48 | − | KAI_ASSUME(mr == kai_mr); | |
49 | − | KAI_ASSUME(kr == kai_kr); | |
50 | − | KAI_ASSUME(sr == kai_sr); | |
51 | |||
52 | 176 | return kai_roundup(m, kai_mr) * kai_roundup(k, kai_kr) * sizeof(uint16_t); | |
53 | } | ||
54 | |||
55 | 176 | void kai_run_lhs_pack_bf16p8x4_f16_neon( | |
56 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||
57 | void* lhs_packed) { | ||
58 | − | KAI_ASSUME(mr == kai_mr); | |
59 | − | KAI_ASSUME(kr == kai_kr); | |
60 | − | KAI_ASSUME(sr == kai_sr); | |
61 | − | KAI_ASSUME(lhs != NULL); | |
62 | − | KAI_ASSUME(lhs_packed != NULL); | |
63 | |||
64 | − | KAI_ASSUME(m_idx_start == 0); | |
65 | |||
66 | 176 | const size_t block_height = kai_mr; | |
67 | 176 | const size_t row_offset = 0; | |
68 | |||
69 | 176 | const void* in[kai_mr]; | |
70 | |||
71 |
2/2✓ Branch 0 taken 176 times.
✓ Branch 1 taken 1204 times.
|
1380 | for (size_t block_y = 0; block_y < m; block_y += block_height) { |
72 |
2/2✓ Branch 0 taken 130 times.
✓ Branch 1 taken 1074 times.
|
1204 | const size_t height = KAI_MIN(m - block_y, block_height); |
73 | 1204 | void* out = (char*)lhs_packed + block_y * kai_roundup(k, kai_kr) * sizeof(uint16_t); | |
74 | 1204 | size_t width = k; | |
75 | |||
76 |
2/2✓ Branch 0 taken 8966 times.
✓ Branch 1 taken 1204 times.
|
10170 | for (size_t y = 0; y < height; y++) { |
77 | 8966 | in[y] = (const char*)lhs + (block_y + y) * lhs_stride; | |
78 | 8966 | } | |
79 | |||
80 | 2408 | __asm__ __volatile__( | |
81 | "ldr x28, [%x[in], #0x0]\n" | ||
82 | "ldr x27, [%x[in], #0x8]\n" | ||
83 | "cmp %x[height], #0x8\n" | ||
84 | "ldr x26, [%x[in], #0x10]\n" | ||
85 | "ldr x25, [%x[in], #0x18]\n" | ||
86 | "ldr x24, [%x[in], #0x20]\n" | ||
87 | "ldr x23, [%x[in], #0x28]\n" | ||
88 | "ldr x22, [%x[in], #0x30]\n" | ||
89 | "ldr x21, [%x[in], #0x38]\n" | ||
90 | "add x28, x28, %x[row_offset], LSL #1\n" | ||
91 | "add x27, x27, %x[row_offset], LSL #1\n" | ||
92 | "add x26, x26, %x[row_offset], LSL #1\n" | ||
93 | "add x25, x25, %x[row_offset], LSL #1\n" | ||
94 | "add x24, x24, %x[row_offset], LSL #1\n" | ||
95 | "add x23, x23, %x[row_offset], LSL #1\n" | ||
96 | "add x22, x22, %x[row_offset], LSL #1\n" | ||
97 | "add x21, x21, %x[row_offset], LSL #1\n" | ||
98 | "beq 1f\n" | ||
99 | "cmp %x[height], #0x2\n" | ||
100 | "mov x21, x28\n" | ||
101 | "csel x27, x27, x28, GE\n" | ||
102 | "csel x26, x26, x28, GT\n" | ||
103 | "cmp %x[height], #0x4\n" | ||
104 | "csel x25, x25, x28, GE\n" | ||
105 | "csel x24, x24, x28, GT\n" | ||
106 | "cmp %x[height], #0x6\n" | ||
107 | "csel x23, x23, x28, GE\n" | ||
108 | "csel x22, x22, x28, GT\n" | ||
109 | "1:" // no_pointer_adj | ||
110 | "cmp %x[width], #0x8\n" | ||
111 | "prfm pldl1keep, [x28, #0x0]\n" | ||
112 | "prfm pldl1keep, [x27, #0x0]\n" | ||
113 | "prfm pldl1keep, [x26, #0x0]\n" | ||
114 | "prfm pldl1keep, [x25, #0x0]\n" | ||
115 | "prfm pldl1keep, [x24, #0x0]\n" | ||
116 | "prfm pldl1keep, [x23, #0x0]\n" | ||
117 | "prfm pldl1keep, [x22, #0x0]\n" | ||
118 | "prfm pldl1keep, [x21, #0x0]\n" | ||
119 | "prfm pldl1keep, [x28, #0x40]\n" | ||
120 | "prfm pldl1keep, [x27, #0x40]\n" | ||
121 | "prfm pldl1keep, [x26, #0x40]\n" | ||
122 | "prfm pldl1keep, [x25, #0x40]\n" | ||
123 | "prfm pldl1keep, [x24, #0x40]\n" | ||
124 | "prfm pldl1keep, [x23, #0x40]\n" | ||
125 | "prfm pldl1keep, [x22, #0x40]\n" | ||
126 | "prfm pldl1keep, [x21, #0x40]\n" | ||
127 | "blt 3f\n" | ||
128 | "2:" // Main loop head | ||
129 | "ldr q19, [x28], #0x10\n" | ||
130 | "ldr q18, [x26], #0x10\n" | ||
131 | "subs %x[width], %x[width], #0x8\n" | ||
132 | "ldr q17, [x24], #0x10\n" | ||
133 | "ldr q16, [x22], #0x10\n" | ||
134 | "cmp %x[width], #0x8\n" | ||
135 | "ldr q25, [x27], #0x10\n" | ||
136 | "ldr q24, [x25], #0x10\n" | ||
137 | "ldr q1, [x23], #0x10\n" | ||
138 | "ldr q0, [x21], #0x10\n" | ||
139 | "fcvtl v23.4s, v19.4h\n" | ||
140 | "fcvtl2 v22.4s, v19.8h\n" | ||
141 | "fcvtl v21.4s, v18.4h\n" | ||
142 | "fcvtl2 v20.4s, v18.8h\n" | ||
143 | "prfm pldl1keep, [x28, #0x70]\n" | ||
144 | "fcvtl v19.4s, v17.4h\n" | ||
145 | "fcvtl2 v18.4s, v17.8h\n" | ||
146 | "prfm pldl1keep, [x27, #0x70]\n" | ||
147 | "prfm pldl1keep, [x26, #0x70]\n" | ||
148 | "fcvtl v17.4s, v16.4h\n" | ||
149 | "fcvtl2 v16.4s, v16.8h\n" | ||
150 | "prfm pldl1keep, [x25, #0x70]\n" | ||
151 | "prfm pldl1keep, [x24, #0x70]\n" | ||
152 | ".inst 0x0ea16aff // bfcvtn v31.4h, v23.4s\n" | ||
153 | ".inst 0x0ea16ade // bfcvtn v30.4h, v22.4s\n" | ||
154 | "prfm pldl1keep, [x23, #0x70]\n" | ||
155 | "prfm pldl1keep, [x22, #0x70]\n" | ||
156 | "fcvtl v29.4s, v25.4h\n" | ||
157 | "fcvtl2 v28.4s, v25.8h\n" | ||
158 | "prfm pldl1keep, [x21, #0x70]\n" | ||
159 | ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" | ||
160 | ".inst 0x0ea16a9a // bfcvtn v26.4h, v20.4s\n" | ||
161 | "fcvtl v25.4s, v24.4h\n" | ||
162 | "fcvtl2 v24.4s, v24.8h\n" | ||
163 | ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" | ||
164 | ".inst 0x0ea16a56 // bfcvtn v22.4h, v18.4s\n" | ||
165 | "fcvtl v21.4s, v1.4h\n" | ||
166 | "fcvtl2 v20.4s, v1.8h\n" | ||
167 | ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" | ||
168 | ".inst 0x0ea16a12 // bfcvtn v18.4h, v16.4s\n" | ||
169 | "fcvtl v17.4s, v0.4h\n" | ||
170 | "fcvtl2 v16.4s, v0.8h\n" | ||
171 | ".inst 0x4ea16bbf // bfcvtn2 v31.8h, v29.4s\n" | ||
172 | ".inst 0x4ea16b9e // bfcvtn2 v30.8h, v28.4s\n" | ||
173 | ".inst 0x4ea16b3b // bfcvtn2 v27.8h, v25.4s\n" | ||
174 | ".inst 0x4ea16b1a // bfcvtn2 v26.8h, v24.4s\n" | ||
175 | ".inst 0x4ea16ab7 // bfcvtn2 v23.8h, v21.4s\n" | ||
176 | ".inst 0x4ea16a96 // bfcvtn2 v22.8h, v20.4s\n" | ||
177 | ".inst 0x4ea16a33 // bfcvtn2 v19.8h, v17.4s\n" | ||
178 | ".inst 0x4ea16a12 // bfcvtn2 v18.8h, v16.4s\n" | ||
179 | "str q31, [%x[out_ptr], #0x0]\n" | ||
180 | "str q27, [%x[out_ptr], #0x10]\n" | ||
181 | "str q23, [%x[out_ptr], #0x20]\n" | ||
182 | "str q19, [%x[out_ptr], #0x30]\n" | ||
183 | "str q30, [%x[out_ptr], #0x40]\n" | ||
184 | "str q26, [%x[out_ptr], #0x50]\n" | ||
185 | "str q22, [%x[out_ptr], #0x60]\n" | ||
186 | "str q18, [%x[out_ptr], #0x70]\n" | ||
187 | "add %x[out_ptr], %x[out_ptr], #0x80\n" | ||
188 | "bge 2b\n" | ||
189 | "3:" // Main loop skip | ||
190 | "cbz %x[width], 8f\n" | ||
191 | "tbz %x[width], #2, 5f\n" | ||
192 | "ldr d19, [x28], #0x8\n" | ||
193 | "ldr d25, [x27], #0x8\n" | ||
194 | "ldr d18, [x26], #0x8\n" | ||
195 | "ldr d24, [x25], #0x8\n" | ||
196 | "ldr d17, [x24], #0x8\n" | ||
197 | "ldr d1, [x23], #0x8\n" | ||
198 | "ldr d16, [x22], #0x8\n" | ||
199 | "ldr d0, [x21], #0x8\n" | ||
200 | "tbz %x[width], #1, 4f\n" | ||
201 | "ld1 { v19.s }[2], [x28], #0x4\n" | ||
202 | "ld1 { v25.s }[2], [x27], #0x4\n" | ||
203 | "mov x20, #0x2\n" | ||
204 | "ld1 { v18.s }[2], [x26], #0x4\n" | ||
205 | "ld1 { v24.s }[2], [x25], #0x4\n" | ||
206 | "ld1 { v17.s }[2], [x24], #0x4\n" | ||
207 | "ld1 { v1.s }[2], [x23], #0x4\n" | ||
208 | "ld1 { v16.s }[2], [x22], #0x4\n" | ||
209 | "ld1 { v0.s }[2], [x21], #0x4\n" | ||
210 | "tbz %x[width], #0, 7f\n" | ||
211 | "ld1 { v19.h }[6], [x28]\n" | ||
212 | "ld1 { v25.h }[6], [x27]\n" | ||
213 | "ld1 { v18.h }[6], [x26]\n" | ||
214 | "ld1 { v24.h }[6], [x25]\n" | ||
215 | "ld1 { v17.h }[6], [x24]\n" | ||
216 | "ld1 { v1.h }[6], [x23]\n" | ||
217 | "ld1 { v16.h }[6], [x22]\n" | ||
218 | "ld1 { v0.h }[6], [x21]\n" | ||
219 | "b 7f\n" | ||
220 | "4:" // odd_loads_1_4 | ||
221 | "mov x20, #0x1\n" | ||
222 | "tbz %x[width], #0, 7f\n" | ||
223 | "ld1 { v19.h }[4], [x28]\n" | ||
224 | "ld1 { v25.h }[4], [x27]\n" | ||
225 | "mov x20, #0x2\n" | ||
226 | "ld1 { v18.h }[4], [x26]\n" | ||
227 | "ld1 { v24.h }[4], [x25]\n" | ||
228 | "ld1 { v17.h }[4], [x24]\n" | ||
229 | "ld1 { v1.h }[4], [x23]\n" | ||
230 | "ld1 { v16.h }[4], [x22]\n" | ||
231 | "ld1 { v0.h }[4], [x21]\n" | ||
232 | "b 7f\n" | ||
233 | "5:" // odd_loads_2_0 | ||
234 | "tbz %x[width], #1, 6f\n" | ||
235 | "ldr s19, [x28], #0x4\n" | ||
236 | "ldr s25, [x27], #0x4\n" | ||
237 | "mov x20, #0x1\n" | ||
238 | "ldr s18, [x26], #0x4\n" | ||
239 | "ldr s24, [x25], #0x4\n" | ||
240 | "ldr s17, [x24], #0x4\n" | ||
241 | "ldr s1, [x23], #0x4\n" | ||
242 | "ldr s16, [x22], #0x4\n" | ||
243 | "ldr s0, [x21], #0x4\n" | ||
244 | "tbz %x[width], #0, 7f\n" | ||
245 | "ld1 { v19.h }[2], [x28]\n" | ||
246 | "ld1 { v25.h }[2], [x27]\n" | ||
247 | "ld1 { v18.h }[2], [x26]\n" | ||
248 | "ld1 { v24.h }[2], [x25]\n" | ||
249 | "ld1 { v17.h }[2], [x24]\n" | ||
250 | "ld1 { v1.h }[2], [x23]\n" | ||
251 | "ld1 { v16.h }[2], [x22]\n" | ||
252 | "ld1 { v0.h }[2], [x21]\n" | ||
253 | "b 7f\n" | ||
254 | "6:" // odd_loads_1_0 | ||
255 | "ldr h19, [x28, #0x0]\n" | ||
256 | "ldr h25, [x27, #0x0]\n" | ||
257 | "mov x20, #0x1\n" | ||
258 | "ldr h18, [x26, #0x0]\n" | ||
259 | "ldr h24, [x25, #0x0]\n" | ||
260 | "ldr h17, [x24, #0x0]\n" | ||
261 | "ldr h1, [x23, #0x0]\n" | ||
262 | "ldr h16, [x22, #0x0]\n" | ||
263 | "ldr h0, [x21, #0x0]\n" | ||
264 | "7:" // Odd load end | ||
265 | "fcvtl v23.4s, v19.4h\n" | ||
266 | "fcvtl2 v22.4s, v19.8h\n" | ||
267 | "subs x20, x20, #0x1\n" | ||
268 | "fcvtl v21.4s, v18.4h\n" | ||
269 | "fcvtl2 v20.4s, v18.8h\n" | ||
270 | "fcvtl v19.4s, v17.4h\n" | ||
271 | "fcvtl2 v18.4s, v17.8h\n" | ||
272 | "fcvtl v17.4s, v16.4h\n" | ||
273 | "fcvtl2 v16.4s, v16.8h\n" | ||
274 | ".inst 0x0ea16aff // bfcvtn v31.4h, v23.4s\n" | ||
275 | ".inst 0x0ea16ade // bfcvtn v30.4h, v22.4s\n" | ||
276 | "fcvtl v29.4s, v25.4h\n" | ||
277 | "fcvtl2 v28.4s, v25.8h\n" | ||
278 | ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" | ||
279 | ".inst 0x0ea16a9a // bfcvtn v26.4h, v20.4s\n" | ||
280 | "fcvtl v25.4s, v24.4h\n" | ||
281 | "fcvtl2 v24.4s, v24.8h\n" | ||
282 | ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" | ||
283 | ".inst 0x0ea16a56 // bfcvtn v22.4h, v18.4s\n" | ||
284 | "fcvtl v21.4s, v1.4h\n" | ||
285 | "fcvtl2 v20.4s, v1.8h\n" | ||
286 | ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" | ||
287 | ".inst 0x0ea16a12 // bfcvtn v18.4h, v16.4s\n" | ||
288 | "fcvtl v17.4s, v0.4h\n" | ||
289 | "fcvtl2 v16.4s, v0.8h\n" | ||
290 | ".inst 0x4ea16bbf // bfcvtn2 v31.8h, v29.4s\n" | ||
291 | ".inst 0x4ea16b9e // bfcvtn2 v30.8h, v28.4s\n" | ||
292 | ".inst 0x4ea16b3b // bfcvtn2 v27.8h, v25.4s\n" | ||
293 | ".inst 0x4ea16b1a // bfcvtn2 v26.8h, v24.4s\n" | ||
294 | ".inst 0x4ea16ab7 // bfcvtn2 v23.8h, v21.4s\n" | ||
295 | ".inst 0x4ea16a96 // bfcvtn2 v22.8h, v20.4s\n" | ||
296 | ".inst 0x4ea16a33 // bfcvtn2 v19.8h, v17.4s\n" | ||
297 | ".inst 0x4ea16a12 // bfcvtn2 v18.8h, v16.4s\n" | ||
298 | "str q31, [%x[out_ptr], #0x0]\n" | ||
299 | "str q27, [%x[out_ptr], #0x10]\n" | ||
300 | "str q23, [%x[out_ptr], #0x20]\n" | ||
301 | "str q19, [%x[out_ptr], #0x30]\n" | ||
302 | "add %x[out_ptr], %x[out_ptr], #0x40\n" | ||
303 | "beq 8f\n" | ||
304 | "str q30, [%x[out_ptr], #0x0]\n" | ||
305 | "str q26, [%x[out_ptr], #0x10]\n" | ||
306 | "str q22, [%x[out_ptr], #0x20]\n" | ||
307 | "str q18, [%x[out_ptr], #0x30]\n" | ||
308 | "add %x[out_ptr], %x[out_ptr], #0x40\n" | ||
309 | "8:" // Odds skip | ||
310 | : [out_ptr] "+&r"(out), [width] "+&r"(width) | ||
311 | 1204 | : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset) | |
312 | : "cc", "memory", "v0", "v1", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", | ||
313 | "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); | ||
314 | 1204 | } | |
315 | 176 | } | |
316 | |||
317 | #endif // Architectural features check. | ||
318 |