Line |
Branch |
Exec |
Source |
1 |
|
|
// |
2 |
|
|
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> |
3 |
|
|
// |
4 |
|
|
// SPDX-License-Identifier: Apache-2.0 |
5 |
|
|
// |
6 |
|
|
|
7 |
|
|
#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) |
8 |
|
|
#error This file must be compiled for AArch64, FEAT_SVE2. |
9 |
|
|
#else // Architectural features check. |
10 |
|
|
|
11 |
|
|
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" |
12 |
|
|
|
13 |
|
|
#include <stddef.h> |
14 |
|
|
#include <stdint.h> |
15 |
|
|
#include <string.h> |
16 |
|
|
|
17 |
|
|
#include "kai/kai_common.h" |
18 |
|
|
|
19 |
|
|
// Element sizes, in bytes, used by the offset/stride helpers in this file.
static const size_t kai_num_bytes_input = 4;   // one RHS input element (see the rhs offset helper)
static const size_t kai_num_bytes_output = 2;  // one packed RHS element (see the packed stride helper)
static const size_t kai_num_bytes_bias = 4;    // one bias element

// Blocking factors: the N step is kai_nr * <SME vector length in 16-bit elements> / kai_kr,
// and K is rounded up with kai_roundup(k, kai_kr) when sizing a packed block.
static const size_t kai_nr = 2;
static const size_t kai_kr = 2;
24 |
|
|
|
25 |
|
382 |
size_t kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(void) { |
26 |
|
382 |
return kai_nr * kai_get_sme_vector_length_u16() / kai_kr; |
27 |
|
|
} |
28 |
|
|
|
29 |
|
46 |
size_t kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx) { |
30 |
|
− |
KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() == 0); |
31 |
|
|
|
32 |
|
46 |
return n_idx * kai_num_bytes_input; |
33 |
|
|
} |
34 |
|
|
|
35 |
|
✗ |
size_t kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx) { |
36 |
|
✗ |
return n_idx * kai_num_bytes_bias; |
37 |
|
|
} |
38 |
|
|
|
39 |
|
112 |
size_t kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t k) { |
40 |
|
224 |
return kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() * |
41 |
|
112 |
(kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output); |
42 |
|
|
} |
43 |
|
|
|
44 |
|
56 |
// Returns the byte offset into the packed RHS buffer of the block that starts at
// column index `n_idx`, for a matrix with `k` rows.
//
// n_idx must be a multiple of the N step; this is enforced with KAI_ASSUME.
size_t kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx, size_t k) {
    KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() == 0);

    const size_t n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme();
    const size_t block_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(k);

    // Whole blocks preceding n_idx, each block_stride bytes long.
    return (n_idx / n_step) * block_stride;
}
50 |
|
|
|
51 |
|
56 |
// Returns the number of bytes needed for the packed RHS buffer of an `n`-column,
// `k`-row matrix: the packed offset of the first block past `n` after rounding `n`
// up with kai_roundup to the N step.
size_t kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n, size_t k) {
    const size_t n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme();
    const size_t n_padded = kai_roundup(n, n_step);

    return kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(n_padded, k);
}
55 |
|
|
|
56 |
|
56 |
void kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme( |
57 |
|
|
size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, |
58 |
|
|
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { |
59 |
|
− |
KAI_ASSUME(num_groups == 1); |
60 |
|
− |
KAI_ASSUME(nr == kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme()); |
61 |
|
− |
KAI_ASSUME(kr == kai_kr); |
62 |
|
− |
KAI_ASSUME(sr == 1); |
63 |
|
− |
KAI_ASSUME(rhs != NULL); |
64 |
|
− |
KAI_ASSUME(bias != NULL); |
65 |
|
− |
KAI_ASSUME(scale == NULL); |
66 |
|
− |
KAI_ASSUME(rhs_packed != NULL); |
67 |
|
− |
KAI_ASSUME(extra_bytes == 0); |
68 |
|
− |
KAI_ASSUME(params == NULL); |
69 |
|
|
|
70 |
|
56 |
size_t height = k; |
71 |
|
56 |
const size_t width = n; |
72 |
|
56 |
const void* in = rhs; |
73 |
|
56 |
void* out = rhs_packed; |
74 |
|
56 |
const size_t in_stride = rhs_stride; |
75 |
|
56 |
const float* pad_row = rhs; |
76 |
|
|
|
77 |
|
56 |
size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(height); |
78 |
|
|
|
79 |
|
112 |
__asm__ __volatile__( |
80 |
|
|
".inst 0xd503477f // SMSTART ZA\n" |
81 |
|
|
"mov x22, %x[out]\n" |
82 |
|
|
"mov x21, %x[width]\n" |
83 |
|
|
"ptrue p2.b\n" |
84 |
|
|
"1:" // Bias: Full loop |
85 |
|
|
"mov x20, x21\n" |
86 |
|
|
"decw x21, ALL, MUL #2\n" |
87 |
|
|
"whilelt p1.s, XZR, x20\n" |
88 |
|
|
"decw x20\n" |
89 |
|
|
"whilelt p0.s, XZR, x20\n" |
90 |
|
|
"cmp x21, #0x0\n" |
91 |
|
|
"ld1w { z17.s }, p1/Z, [%x[bias]]\n" |
92 |
|
|
"ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" |
93 |
|
|
"incb %x[bias], ALL, MUL #2\n" |
94 |
|
|
"st1w { z17.s }, p2, [x22]\n" |
95 |
|
|
"st1w { z16.s }, p2, [x22, #1, MUL VL]\n" |
96 |
|
|
"add x22, x22, %x[out_stride]\n" |
97 |
|
|
"bgt 1b\n" |
98 |
|
|
"cmp %x[height], #0x8\n" |
99 |
|
|
"incb %x[out], ALL, MUL #2\n" |
100 |
|
|
"blt 5f\n" |
101 |
|
|
"2:" // Main row loop: Head |
102 |
|
|
"mov x10, %x[in]\n" |
103 |
|
|
"mov x9, %x[out]\n" |
104 |
|
|
"add x28, x10, %x[in_stride]\n" |
105 |
|
|
"sub %x[height], %x[height], #0x8\n" |
106 |
|
|
"add x27, x28, %x[in_stride]\n" |
107 |
|
|
"mov x26, %x[width]\n" |
108 |
|
|
"add x25, x27, %x[in_stride]\n" |
109 |
|
|
"add x24, x25, %x[in_stride]\n" |
110 |
|
|
"add x23, x24, %x[in_stride]\n" |
111 |
|
|
"add x22, x23, %x[in_stride]\n" |
112 |
|
|
"add x21, x22, %x[in_stride]\n" |
113 |
|
|
"add %x[in], x21, %x[in_stride]\n" |
114 |
|
|
"3:" // Main row loop: Column loop |
115 |
|
|
"mov x20, x26\n" |
116 |
|
|
"decw x26, ALL, MUL #2\n" |
117 |
|
|
"whilelt p1.s, XZR, x20\n" |
118 |
|
|
"decw x20\n" |
119 |
|
|
"whilelt p0.s, XZR, x20\n" |
120 |
|
|
"ld1w { z19.s }, p1/Z, [x10]\n" |
121 |
|
|
"cmp x26, #0x0\n" |
122 |
|
|
"ld1w { z18.s }, p0/Z, [x10, #1, MUL VL]\n" |
123 |
|
|
"addvl x10, x10, #2\n" |
124 |
|
|
"ld1w { z17.s }, p1/Z, [x27]\n" |
125 |
|
|
"ld1w { z16.s }, p0/Z, [x27, #1, MUL VL]\n" |
126 |
|
|
".inst 0x658aaa7b // bfcvt z27.h, p2/M, z19.s\n" |
127 |
|
|
"addvl x27, x27, #2\n" |
128 |
|
|
"ld1w { z19.s }, p1/Z, [x24]\n" |
129 |
|
|
".inst 0x658aaa5a // bfcvt z26.h, p2/M, z18.s\n" |
130 |
|
|
"ld1w { z18.s }, p0/Z, [x24, #1, MUL VL]\n" |
131 |
|
|
".inst 0x658aaa39 // bfcvt z25.h, p2/M, z17.s\n" |
132 |
|
|
"addvl x24, x24, #2\n" |
133 |
|
|
"ld1w { z17.s }, p1/Z, [x22]\n" |
134 |
|
|
".inst 0x658aaa18 // bfcvt z24.h, p2/M, z16.s\n" |
135 |
|
|
"ld1w { z16.s }, p0/Z, [x22, #1, MUL VL]\n" |
136 |
|
|
".inst 0x658aaa77 // bfcvt z23.h, p2/M, z19.s\n" |
137 |
|
|
"addvl x22, x22, #2\n" |
138 |
|
|
".inst 0x658aaa56 // bfcvt z22.h, p2/M, z18.s\n" |
139 |
|
|
"ld1w { z19.s }, p1/Z, [x28]\n" |
140 |
|
|
".inst 0x658aaa35 // bfcvt z21.h, p2/M, z17.s\n" |
141 |
|
|
"ld1w { z18.s }, p0/Z, [x28, #1, MUL VL]\n" |
142 |
|
|
"addvl x28, x28, #2\n" |
143 |
|
|
".inst 0x658aaa14 // bfcvt z20.h, p2/M, z16.s\n" |
144 |
|
|
"ld1w { z17.s }, p1/Z, [x25]\n" |
145 |
|
|
"ld1w { z16.s }, p0/Z, [x25, #1, MUL VL]\n" |
146 |
|
|
"addvl x25, x25, #2\n" |
147 |
|
|
".inst 0x648aaa7b // bfcvtnt z27.h, p2/M, z19.s\n" |
148 |
|
|
"ld1w { z19.s }, p1/Z, [x23]\n" |
149 |
|
|
".inst 0x648aaa5a // bfcvtnt z26.h, p2/M, z18.s\n" |
150 |
|
|
"ld1w { z18.s }, p0/Z, [x23, #1, MUL VL]\n" |
151 |
|
|
"addvl x23, x23, #2\n" |
152 |
|
|
".inst 0x648aaa39 // bfcvtnt z25.h, p2/M, z17.s\n" |
153 |
|
|
"ld1w { z17.s }, p1/Z, [x21]\n" |
154 |
|
|
".inst 0x648aaa18 // bfcvtnt z24.h, p2/M, z16.s\n" |
155 |
|
|
"ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n" |
156 |
|
|
"addvl x21, x21, #2\n" |
157 |
|
|
".inst 0x648aaa77 // bfcvtnt z23.h, p2/M, z19.s\n" |
158 |
|
|
"st1h { z27.h }, p2, [x9]\n" |
159 |
|
|
".inst 0x648aaa56 // bfcvtnt z22.h, p2/M, z18.s\n" |
160 |
|
|
"st1h { z26.h }, p2, [x9, #1, MUL VL]\n" |
161 |
|
|
".inst 0x648aaa35 // bfcvtnt z21.h, p2/M, z17.s\n" |
162 |
|
|
"st1h { z25.h }, p2, [x9, #2, MUL VL]\n" |
163 |
|
|
".inst 0x648aaa14 // bfcvtnt z20.h, p2/M, z16.s\n" |
164 |
|
|
"st1h { z24.h }, p2, [x9, #3, MUL VL]\n" |
165 |
|
|
"st1h { z23.h }, p2, [x9, #4, MUL VL]\n" |
166 |
|
|
"st1h { z22.h }, p2, [x9, #5, MUL VL]\n" |
167 |
|
|
"st1h { z21.h }, p2, [x9, #6, MUL VL]\n" |
168 |
|
|
"st1h { z20.h }, p2, [x9, #7, MUL VL]\n" |
169 |
|
|
"add x9, x9, %x[out_stride]\n" |
170 |
|
|
"bgt 3b\n" |
171 |
|
|
"cmp %x[height], #0x8\n" |
172 |
|
|
"addvl %x[out], %x[out], #8\n" |
173 |
|
|
"bge 2b\n" |
174 |
|
|
"cbz %x[height], 9f\n" |
175 |
|
|
"5:" // Main loop skip |
176 |
|
|
"6:" // Tail row loop: Head |
177 |
|
|
"mov x10, %x[in]\n" |
178 |
|
|
"cmp %x[height], #0x1\n" |
179 |
|
|
"add x28, x10, %x[in_stride]\n" |
180 |
|
|
"mov x9, %x[out]\n" |
181 |
|
|
"add %x[in], x28, %x[in_stride]\n" |
182 |
|
|
"csel x28, x28, %x[pad_row], GT\n" |
183 |
|
|
"sub %x[height], %x[height], #0x2\n" |
184 |
|
|
"mov x21, %x[width]\n" |
185 |
|
|
"7:" // Tail row loop: Column loop |
186 |
|
|
"mov x20, x21\n" |
187 |
|
|
"decw x21, ALL, MUL #2\n" |
188 |
|
|
"whilelt p1.s, XZR, x20\n" |
189 |
|
|
"decw x20\n" |
190 |
|
|
"whilelt p0.s, XZR, x20\n" |
191 |
|
|
"ld1w { z17.s }, p1/Z, [x10]\n" |
192 |
|
|
"cmp x21, #0x0\n" |
193 |
|
|
"ld1w { z16.s }, p0/Z, [x10, #1, MUL VL]\n" |
194 |
|
|
"addvl x10, x10, #2\n" |
195 |
|
|
"ld1w { z19.s }, p1/Z, [x28]\n" |
196 |
|
|
".inst 0x658aaa32 // bfcvt z18.h, p2/M, z17.s\n" |
197 |
|
|
"ld1w { z17.s }, p0/Z, [x28, #1, MUL VL]\n" |
198 |
|
|
"addvl x28, x28, #2\n" |
199 |
|
|
".inst 0x658aaa10 // bfcvt z16.h, p2/M, z16.s\n" |
200 |
|
|
".inst 0x648aaa72 // bfcvtnt z18.h, p2/M, z19.s\n" |
201 |
|
|
".inst 0x648aaa30 // bfcvtnt z16.h, p2/M, z17.s\n" |
202 |
|
|
"st1h { z18.h }, p2, [x9]\n" |
203 |
|
|
"st1h { z16.h }, p2, [x9, #1, MUL VL]\n" |
204 |
|
|
"add x9, x9, %x[out_stride]\n" |
205 |
|
|
"bgt 7b\n" |
206 |
|
|
"cmp %x[height], #0x1\n" |
207 |
|
|
"addvl %x[out], %x[out], #2\n" |
208 |
|
|
"bge 6b\n" |
209 |
|
|
"9:" // Done |
210 |
|
|
".inst 0xd503467f // SMSTOP\n" |
211 |
|
|
: [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) |
212 |
|
56 |
: [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) |
213 |
|
|
: "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", |
214 |
|
|
"p8", "p9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", "z10", |
215 |
|
|
"z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", |
216 |
|
|
"z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); |
217 |
|
56 |
} |
218 |
|
|
|
219 |
|
|
#endif // Architectural features check. |
220 |
|
|
|