Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | // Do not flag up inline assembly blocks | ||
8 | #pragma GCC diagnostic ignored "-Woverlength-strings" | ||
9 | |||
10 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) | ||
11 | #error This file must be compiled for AArch64, FEAT_SVE2. | ||
12 | #else // Architectural features check. | ||
13 | |||
14 | #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" | ||
15 | |||
16 | #include <stddef.h> | ||
17 | #include <stdint.h> | ||
18 | |||
19 | #include "kai/kai_common.h" | ||
20 | |||
21 | static const size_t kai_mr = 2; | ||
22 | static const size_t kai_kr = 2; | ||
23 | static const size_t kai_sr = 1; | ||
24 | |||
25 | 184 | static size_t kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme(void) { | |
26 | 184 | return kai_mr * kai_get_sme_vector_length_u16() / kai_kr; | |
27 | } | ||
28 | |||
29 | ✗ | size_t kai_get_m_step_lhs_pack_bf16p2vlx2_f32_sme(size_t mr) { | |
30 | − | KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()); | |
31 | ✗ | KAI_UNUSED(mr); | |
32 | |||
33 | ✗ | return kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme(); | |
34 | } | ||
35 | |||
36 | 46 | size_t kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme(size_t m_idx, size_t lhs_stride) { | |
37 | − | KAI_ASSUME(m_idx % kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme() == 0); | |
38 | |||
39 | 46 | return m_idx * lhs_stride; | |
40 | } | ||
41 | |||
42 | ✗ | size_t kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | |
43 | − | KAI_ASSUME(m_idx % kai_get_m_step_lhs_pack_bf16p2vlx2_f32_sme(mr) == 0); | |
44 | − | KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()); | |
45 | − | KAI_ASSUME(kr == kai_kr); | |
46 | − | KAI_ASSUME(sr == kai_sr); | |
47 | |||
48 | ✗ | KAI_UNUSED(mr); | |
49 | ✗ | KAI_UNUSED(kr); | |
50 | ✗ | KAI_UNUSED(sr); | |
51 | |||
52 | ✗ | return m_idx * kai_roundup(k, kr) * sizeof(uint16_t); | |
53 | } | ||
54 | |||
55 | 46 | size_t kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
56 | − | KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()); | |
57 | − | KAI_ASSUME(kr == kai_kr); | |
58 | − | KAI_ASSUME(sr == kai_sr); | |
59 | |||
60 | 46 | KAI_UNUSED(mr); | |
61 | 46 | KAI_UNUSED(kr); | |
62 | 46 | KAI_UNUSED(sr); | |
63 | |||
64 | 46 | return kai_roundup(m, kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()) * kai_roundup(k, kai_kr) * sizeof(uint16_t); | |
65 | } | ||
66 | |||
67 | 46 | void kai_run_lhs_pack_bf16p2vlx2_f32_sme( | |
68 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||
69 | void* lhs_packed) { | ||
70 | − | KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()); | |
71 | − | KAI_ASSUME(kr == kai_kr); | |
72 | − | KAI_ASSUME(sr == kai_sr); | |
73 | − | KAI_ASSUME(lhs != NULL); | |
74 | − | KAI_ASSUME(lhs_packed != NULL); | |
75 | |||
76 | − | KAI_ASSUME(m_idx_start == 0); | |
77 | |||
78 | 46 | const size_t block_height = mr; | |
79 | 46 | const size_t width = k; | |
80 | 46 | const size_t row_offset = 0; | |
81 | |||
82 | 46 | const void* in[block_height]; | |
83 | |||
84 |
2/2✓ Branch 0 taken 46 times.
✓ Branch 1 taken 145 times.
|
191 | for (size_t block_y = 0; block_y < m; block_y += block_height) { |
85 |
2/2✓ Branch 0 taken 37 times.
✓ Branch 1 taken 108 times.
|
145 | const size_t height = KAI_MIN(m - block_y, block_height); |
86 | 145 | void* out = (char*)lhs_packed + (block_y * kai_roundup(k, kai_kr) * sizeof(uint16_t)); | |
87 | |||
88 |
2/2✓ Branch 0 taken 3794 times.
✓ Branch 1 taken 145 times.
|
3939 | for (size_t y = 0; y < height; y++) { |
89 | 3794 | in[y] = (const void*)((const char*)lhs + (block_y + y) * lhs_stride); | |
90 | 3794 | } | |
91 | |||
92 | 290 | __asm__ __volatile__( | |
93 | ".inst 0xd503477f // SMSTART ZA\n" | ||
94 | "sub x10, %x[width], #0x1\n" | ||
95 | "mov x9, #0x0\n" | ||
96 | "cntw x22, ALL, MUL #2\n" | ||
97 | "cntw x28\n" | ||
98 | "cntw x21, ALL, MUL #2\n" | ||
99 | "sub x20, x22, #0x1\n" | ||
100 | ".inst 0x25207815 // ptrue pn13.b\n" | ||
101 | "whilelt p12.s, XZR, %x[height]\n" | ||
102 | "whilelt p11.s, x28, %x[height]\n" | ||
103 | "add x10, x10, x21\n" | ||
104 | "ands x27, %x[width], x20\n" | ||
105 | "udiv x10, x10, x21\n" | ||
106 | "csel x27, x27, x22, NE\n" | ||
107 | "and x26, x10, #0x1\n" | ||
108 | "sub x10, x10, #0x1\n" | ||
109 | "add x27, x27, #0x1\n" | ||
110 | "mov x20, %x[width]\n" | ||
111 | "mov x25, %x[in]\n" | ||
112 | "ptrue p0.b\n" | ||
113 | "mov x24, %x[outptr_raw]\n" | ||
114 | "mov x23, %x[row_offset]\n" | ||
115 | "lsr x10, x10, #0x1\n" | ||
116 | "lsr x27, x27, #0x1\n" | ||
117 | "mov x12, #0x0\n" | ||
118 | ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n" | ||
119 | "add x22, x25, x28, LSL #3\n" | ||
120 | "1:" // Width loop: Preamble: Loop | ||
121 | "ldr x21, [x25], #0x8\n" | ||
122 | ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n" | ||
123 | ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n" | ||
124 | "ldr x20, [x22], #0x8\n" | ||
125 | ".inst 0xa01746b4 // ld1w { z20.s-z21.s }, pn9.s/Z, [x21, x23, LSL #2]\n" | ||
126 | ".inst 0xa017428c // ld1w { z12.s-z13.s }, pn8.s/Z, [x20, x23, LSL #2]\n" | ||
127 | ".inst 0xc160e294 // bfcvt z20.h, { z20.s-z21.s }\n" | ||
128 | ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n" | ||
129 | ".inst 0xc0800280 // mova za0h.s[x12], p0/M, z20.s\n" | ||
130 | ".inst 0xc0800184 // mova za1h.s[x12], p0/M, z12.s\n" | ||
131 | "add x12, x12, #0x1\n" | ||
132 | "cmp x12, x28\n" | ||
133 | "blt 1b\n" | ||
134 | "incw x23, ALL, MUL #2\n" | ||
135 | "incw x9, ALL, MUL #2\n" | ||
136 | "cbz x10, 5f\n" | ||
137 | "2:" // Width loop | ||
138 | "mov x20, %x[width]\n" | ||
139 | "mov x25, %x[in]\n" | ||
140 | "mov x12, #0x0\n" | ||
141 | ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n" | ||
142 | "add x22, x25, x28, LSL #3\n" | ||
143 | "3:" // Width loop: Odd: Loop | ||
144 | "ldr x21, [x25], #0x8\n" | ||
145 | ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n" | ||
146 | ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n" | ||
147 | ".inst 0xc0828007 // mova z7.s, p0/M, za0v.s[x12]\n" | ||
148 | "ldr x20, [x22], #0x8\n" | ||
149 | ".inst 0xc082808f // mova z15.s, p0/M, za1v.s[x12]\n" | ||
150 | ".inst 0xa01746b6 // ld1w { z22.s-z23.s }, pn9.s/Z, [x21, x23, LSL #2]\n" | ||
151 | ".inst 0xa017429a // ld1w { z26.s-z27.s }, pn8.s/Z, [x20, x23, LSL #2]\n" | ||
152 | ".inst 0xa1605707 // st1w { z7.s, z15.s }, pn13.b, [x24]\n" | ||
153 | "addvl x24, x24, #2\n" | ||
154 | ".inst 0xc160e2d6 // bfcvt z22.h, { z22.s-z23.s }\n" | ||
155 | ".inst 0xc160e35a // bfcvt z26.h, { z26.s-z27.s }\n" | ||
156 | ".inst 0xc08002c8 // mova za2h.s[x12], p0/M, z22.s\n" | ||
157 | ".inst 0xc080034c // mova za3h.s[x12], p0/M, z26.s\n" | ||
158 | "add x12, x12, #0x1\n" | ||
159 | "cmp x12, x28\n" | ||
160 | "blt 3b\n" | ||
161 | "incw x9, ALL, MUL #2\n" | ||
162 | "mov x20, %x[width]\n" | ||
163 | "mov x25, %x[in]\n" | ||
164 | "incw x23, ALL, MUL #2\n" | ||
165 | "mov x12, #0x0\n" | ||
166 | ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n" | ||
167 | "add x22, x25, x28, LSL #3\n" | ||
168 | "4:" // Width loop: Even: Loop | ||
169 | "ldr x21, [x25], #0x8\n" | ||
170 | ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n" | ||
171 | ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n" | ||
172 | ".inst 0xc0828108 // mova z8.s, p0/M, za2v.s[x12]\n" | ||
173 | "ldr x20, [x22], #0x8\n" | ||
174 | ".inst 0xc0828189 // mova z9.s, p0/M, za3v.s[x12]\n" | ||
175 | ".inst 0xa01746ae // ld1w { z14.s-z15.s }, pn9.s/Z, [x21, x23, LSL #2]\n" | ||
176 | ".inst 0xa017428c // ld1w { z12.s-z13.s }, pn8.s/Z, [x20, x23, LSL #2]\n" | ||
177 | ".inst 0xa0605708 // st1w { z8.s-z9.s }, pn13.b, [x24]\n" | ||
178 | "addvl x24, x24, #2\n" | ||
179 | ".inst 0xc160e1ce // bfcvt z14.h, { z14.s-z15.s }\n" | ||
180 | ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n" | ||
181 | ".inst 0xc08001c0 // mova za0h.s[x12], p0/M, z14.s\n" | ||
182 | ".inst 0xc0800184 // mova za1h.s[x12], p0/M, z12.s\n" | ||
183 | "add x12, x12, #0x1\n" | ||
184 | "cmp x12, x28\n" | ||
185 | "blt 4b\n" | ||
186 | "subs x10, x10, #0x1\n" | ||
187 | "incw x23, ALL, MUL #2\n" | ||
188 | "incw x9, ALL, MUL #2\n" | ||
189 | "bgt 2b\n" | ||
190 | "5:" // Width loop: Tails | ||
191 | "cbnz x26, 8f\n" | ||
192 | "mov x20, %x[width]\n" | ||
193 | "mov x25, %x[in]\n" | ||
194 | "mov x12, #0x0\n" | ||
195 | ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n" | ||
196 | "add x22, x25, x28, LSL #3\n" | ||
197 | "6:" // Width loop: Tails: Even: Odd: Loop | ||
198 | "ldr x21, [x25], #0x8\n" | ||
199 | ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n" | ||
200 | ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n" | ||
201 | ".inst 0xc0828003 // mova z3.s, p0/M, za0v.s[x12]\n" | ||
202 | "ldr x20, [x22], #0x8\n" | ||
203 | ".inst 0xc082808b // mova z11.s, p0/M, za1v.s[x12]\n" | ||
204 | ".inst 0xa01746ac // ld1w { z12.s-z13.s }, pn9.s/Z, [x21, x23, LSL #2]\n" | ||
205 | ".inst 0xa017428e // ld1w { z14.s-z15.s }, pn8.s/Z, [x20, x23, LSL #2]\n" | ||
206 | ".inst 0xa1605703 // st1w { z3.s, z11.s }, pn13.b, [x24]\n" | ||
207 | "addvl x24, x24, #2\n" | ||
208 | ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n" | ||
209 | ".inst 0xc160e1ce // bfcvt z14.h, { z14.s-z15.s }\n" | ||
210 | ".inst 0xc0800188 // mova za2h.s[x12], p0/M, z12.s\n" | ||
211 | ".inst 0xc08001cc // mova za3h.s[x12], p0/M, z14.s\n" | ||
212 | "add x12, x12, #0x1\n" | ||
213 | "cmp x12, x28\n" | ||
214 | "blt 6b\n" | ||
215 | "mov x12, #0x0\n" | ||
216 | "7:" // Width loop: Tails: Even: Even: Loop | ||
217 | ".inst 0xc082810e // mova z14.s, p0/M, za2v.s[x12]\n" | ||
218 | ".inst 0xc082818f // mova z15.s, p0/M, za3v.s[x12]\n" | ||
219 | "add x12, x12, #0x1\n" | ||
220 | "cmp x12, x27\n" | ||
221 | ".inst 0xa060570e // st1w { z14.s-z15.s }, pn13.b, [x24]\n" | ||
222 | "addvl x24, x24, #2\n" | ||
223 | "blt 7b\n" | ||
224 | "b 10f\n" | ||
225 | "8:" // Width loop: Tails: Odd | ||
226 | "mov x12, #0x0\n" | ||
227 | "9:" // Width loop: Tails: Odd: Loop | ||
228 | ".inst 0xc0828014 // mova z20.s, p0/M, za0v.s[x12]\n" | ||
229 | ".inst 0xc0828095 // mova z21.s, p0/M, za1v.s[x12]\n" | ||
230 | "add x12, x12, #0x1\n" | ||
231 | "cmp x12, x27\n" | ||
232 | ".inst 0xa0605714 // st1w { z20.s-z21.s }, pn13.b, [x24]\n" | ||
233 | "addvl x24, x24, #2\n" | ||
234 | "blt 9b\n" | ||
235 | "10:" // End | ||
236 | "mov %x[outptr_raw], x24\n" | ||
237 | ".inst 0xd503467f // SMSTOP\n" | ||
238 | : [outptr_raw] "+&r"(out) | ||
239 | 145 | : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width) | |
240 | : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", | ||
241 | "p8", "p9", "x10", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", | ||
242 | "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", | ||
243 | "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); | ||
244 | 145 | } | |
245 | 46 | } | |
246 | |||
247 | #endif // Architectural features check. | ||
248 |