Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) | ||
8 | #error This file must be compiled for AArch64, FEAT_BF16. | ||
9 | #else // Architectural features check. | ||
10 | |||
11 | #define MAX_MR 8 | ||
12 | |||
13 | #include "kai_lhs_quant_pack_bf16p8x4_f32_neon.h" | ||
14 | |||
15 | #include <arm_neon.h> | ||
16 | #include <stddef.h> | ||
17 | #include <stdint.h> | ||
18 | |||
19 | #include "kai/kai_common.h" | ||
20 | |||
// Fixed packing parameters of this micro-kernel. The packing entry points in this file
// assert that the mr / kr / sr arguments they receive are equal to these values.
static const size_t kai_mr = 8;  // expected value of the mr argument (rows per packed block)
static const size_t kai_kr = 4;  // expected value of the kr argument (K is rounded up to a multiple of it)
static const size_t kai_sr = 1;  // expected value of the sr argument
24 | |||
25 | ✗ | size_t kai_get_m_step_lhs_quant_pack_bf16p8x4_f32_neon(size_t mr) { | |
26 | − | KAI_ASSUME(mr == kai_mr); | |
27 | ✗ | return kai_mr; | |
28 | } | ||
29 | |||
// Returns the offset of row m_idx inside the unpacked LHS matrix.
//
// The result is m_idx * lhs_stride. The packing routine in this file applies
// lhs_stride to a byte pointer, so with the same stride this is a byte offset.
size_t kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon(size_t m_idx, size_t lhs_stride) {
    const size_t row_offset = lhs_stride * m_idx;
    return row_offset;
}
33 | |||
34 | ✗ | size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p8x4_f32_neon( | |
35 | size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | ||
36 | ✗ | KAI_UNUSED(sr); | |
37 | − | KAI_ASSUME(mr == kai_mr); | |
38 | − | KAI_ASSUME(kr == kai_kr); | |
39 | − | KAI_ASSUME(sr == kai_sr); | |
40 | − | KAI_ASSUME(m_idx % mr == 0); | |
41 | |||
42 | ✗ | return m_idx * kai_roundup(k, kr) * sizeof(uint16_t); | |
43 | } | ||
44 | |||
45 | 92 | size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
46 | 92 | KAI_UNUSED(sr); | |
47 | − | KAI_ASSUME(mr == kai_mr); | |
48 | − | KAI_ASSUME(kr == kai_kr); | |
49 | − | KAI_ASSUME(sr == kai_sr); | |
50 | |||
51 | 92 | return kai_roundup(m, mr) * kai_roundup(k, kr) * sizeof(uint16_t); | |
52 | } | ||
53 | |||
54 | 92 | void kai_run_lhs_quant_pack_bf16p8x4_f32_neon( | |
55 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||
56 | void* lhs_packed) { | ||
57 | − | KAI_ASSUME(mr == kai_mr); | |
58 | − | KAI_ASSUME(sr == kai_sr); | |
59 | − | KAI_ASSUME(kr == kai_kr); | |
60 | 92 | KAI_UNUSED(sr); | |
61 | − | KAI_ASSUME(lhs != NULL); | |
62 | − | KAI_ASSUME(lhs_packed != NULL); | |
63 | |||
64 | − | KAI_ASSUME(m_idx_start == 0); | |
65 | − | KAI_ASSUME(mr <= MAX_MR); | |
66 | |||
67 | 92 | const size_t block_height = mr; | |
68 | 92 | const size_t row_offset = 0; | |
69 | |||
70 | 92 | const void* in[MAX_MR]; | |
71 | |||
72 |
2/2✓ Branch 0 taken 92 times.
✓ Branch 1 taken 962 times.
|
1054 | for (size_t block_y = 0; block_y < m; block_y += block_height) { |
73 |
2/2✓ Branch 0 taken 64 times.
✓ Branch 1 taken 898 times.
|
962 | const size_t height = KAI_MIN(m - block_y, block_height); |
74 | 962 | void* out = (char*)lhs_packed + block_y * kai_roundup(k, kr) * sizeof(uint16_t); | |
75 | 962 | size_t width = k; | |
76 | |||
77 |
2/2✓ Branch 0 taken 7366 times.
✓ Branch 1 taken 962 times.
|
8328 | for (size_t y = 0; y < height; y++) { |
78 | 7366 | in[y] = (const char*)lhs + (block_y + y) * lhs_stride; | |
79 | 7366 | } | |
80 | |||
81 | 1924 | __asm__ __volatile__( | |
82 | "ldr x28, [%x[in], #0x0]\n" | ||
83 | "ldr x27, [%x[in], #0x8]\n" | ||
84 | "cmp %x[height], #0x8\n" | ||
85 | "ldr x26, [%x[in], #0x10]\n" | ||
86 | "ldr x25, [%x[in], #0x18]\n" | ||
87 | "ldr x24, [%x[in], #0x20]\n" | ||
88 | "ldr x23, [%x[in], #0x28]\n" | ||
89 | "ldr x22, [%x[in], #0x30]\n" | ||
90 | "ldr x21, [%x[in], #0x38]\n" | ||
91 | "add x28, x28, %x[row_offset], LSL #2\n" | ||
92 | "add x27, x27, %x[row_offset], LSL #2\n" | ||
93 | "add x26, x26, %x[row_offset], LSL #2\n" | ||
94 | "add x25, x25, %x[row_offset], LSL #2\n" | ||
95 | "add x24, x24, %x[row_offset], LSL #2\n" | ||
96 | "add x23, x23, %x[row_offset], LSL #2\n" | ||
97 | "add x22, x22, %x[row_offset], LSL #2\n" | ||
98 | "add x21, x21, %x[row_offset], LSL #2\n" | ||
99 | "beq 1f\n" | ||
100 | "cmp %x[height], #0x2\n" | ||
101 | "mov x21, x28\n" | ||
102 | "csel x27, x27, x28, GE\n" | ||
103 | "csel x26, x26, x28, GT\n" | ||
104 | "cmp %x[height], #0x4\n" | ||
105 | "csel x25, x25, x28, GE\n" | ||
106 | "csel x24, x24, x28, GT\n" | ||
107 | "cmp %x[height], #0x6\n" | ||
108 | "csel x23, x23, x28, GE\n" | ||
109 | "csel x22, x22, x28, GT\n" | ||
110 | "1:" // no_pointer_adj | ||
111 | "cmp %x[width], #0x4\n" | ||
112 | "prfm pldl1keep, [x28, #0x0]\n" | ||
113 | "prfm pldl1keep, [x27, #0x0]\n" | ||
114 | "prfm pldl1keep, [x26, #0x0]\n" | ||
115 | "prfm pldl1keep, [x25, #0x0]\n" | ||
116 | "prfm pldl1keep, [x24, #0x0]\n" | ||
117 | "prfm pldl1keep, [x23, #0x0]\n" | ||
118 | "prfm pldl1keep, [x22, #0x0]\n" | ||
119 | "prfm pldl1keep, [x21, #0x0]\n" | ||
120 | "prfm pldl1keep, [x28, #0x40]\n" | ||
121 | "prfm pldl1keep, [x27, #0x40]\n" | ||
122 | "prfm pldl1keep, [x26, #0x40]\n" | ||
123 | "prfm pldl1keep, [x25, #0x40]\n" | ||
124 | "prfm pldl1keep, [x24, #0x40]\n" | ||
125 | "prfm pldl1keep, [x23, #0x40]\n" | ||
126 | "prfm pldl1keep, [x22, #0x40]\n" | ||
127 | "prfm pldl1keep, [x21, #0x40]\n" | ||
128 | "blt 3f\n" | ||
129 | "2:" // Main loop head | ||
130 | "ldr q19, [x28], #0x10\n" | ||
131 | "ldr q18, [x26], #0x10\n" | ||
132 | "subs %x[width], %x[width], #0x4\n" | ||
133 | "ldr q17, [x24], #0x10\n" | ||
134 | "ldr q16, [x22], #0x10\n" | ||
135 | "cmp %x[width], #0x4\n" | ||
136 | "ldr q23, [x27], #0x10\n" | ||
137 | "ldr q22, [x25], #0x10\n" | ||
138 | "ldr q21, [x23], #0x10\n" | ||
139 | "ldr q20, [x21], #0x10\n" | ||
140 | ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" | ||
141 | ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" | ||
142 | ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" | ||
143 | ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" | ||
144 | "prfm pldl1keep, [x28, #0x70]\n" | ||
145 | "prfm pldl1keep, [x27, #0x70]\n" | ||
146 | "prfm pldl1keep, [x26, #0x70]\n" | ||
147 | "prfm pldl1keep, [x25, #0x70]\n" | ||
148 | "prfm pldl1keep, [x24, #0x70]\n" | ||
149 | "prfm pldl1keep, [x23, #0x70]\n" | ||
150 | ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" | ||
151 | ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" | ||
152 | "prfm pldl1keep, [x22, #0x70]\n" | ||
153 | "prfm pldl1keep, [x21, #0x70]\n" | ||
154 | ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" | ||
155 | ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" | ||
156 | "str q19, [%x[out_ptr], #0x0]\n" | ||
157 | "str q18, [%x[out_ptr], #0x10]\n" | ||
158 | "str q17, [%x[out_ptr], #0x20]\n" | ||
159 | "str q16, [%x[out_ptr], #0x30]\n" | ||
160 | "add %x[out_ptr], %x[out_ptr], #0x40\n" | ||
161 | "bge 2b\n" | ||
162 | "3:" // Main loop skip | ||
163 | "cbz %x[width], 6f\n" | ||
164 | "tbz %x[width], #1, 4f\n" | ||
165 | "ldr d19, [x28], #0x8\n" | ||
166 | "ldr d23, [x27], #0x8\n" | ||
167 | "mov x20, #0x1\n" | ||
168 | "ldr d18, [x26], #0x8\n" | ||
169 | "ldr d22, [x25], #0x8\n" | ||
170 | "ldr d17, [x24], #0x8\n" | ||
171 | "ldr d21, [x23], #0x8\n" | ||
172 | "ldr d16, [x22], #0x8\n" | ||
173 | "ldr d20, [x21], #0x8\n" | ||
174 | "tbz %x[width], #0, 5f\n" | ||
175 | "ld1 { v19.s }[2], [x28]\n" | ||
176 | "ld1 { v23.s }[2], [x27]\n" | ||
177 | "ld1 { v18.s }[2], [x26]\n" | ||
178 | "ld1 { v22.s }[2], [x25]\n" | ||
179 | "ld1 { v17.s }[2], [x24]\n" | ||
180 | "ld1 { v21.s }[2], [x23]\n" | ||
181 | "ld1 { v16.s }[2], [x22]\n" | ||
182 | "ld1 { v20.s }[2], [x21]\n" | ||
183 | "b 5f\n" | ||
184 | "4:" // odd_loads_1_0 | ||
185 | "ldr s19, [x28, #0x0]\n" | ||
186 | "ldr s23, [x27, #0x0]\n" | ||
187 | "mov x20, #0x1\n" | ||
188 | "ldr s18, [x26, #0x0]\n" | ||
189 | "ldr s22, [x25, #0x0]\n" | ||
190 | "ldr s17, [x24, #0x0]\n" | ||
191 | "ldr s21, [x23, #0x0]\n" | ||
192 | "ldr s16, [x22, #0x0]\n" | ||
193 | "ldr s20, [x21, #0x0]\n" | ||
194 | "5:" // Odd load end | ||
195 | ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" | ||
196 | ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" | ||
197 | ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" | ||
198 | ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" | ||
199 | ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n" | ||
200 | ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n" | ||
201 | ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n" | ||
202 | ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n" | ||
203 | "str q19, [%x[out_ptr], #0x0]\n" | ||
204 | "str q18, [%x[out_ptr], #0x10]\n" | ||
205 | "str q17, [%x[out_ptr], #0x20]\n" | ||
206 | "str q16, [%x[out_ptr], #0x30]\n" | ||
207 | "add %x[out_ptr], %x[out_ptr], #0x40\n" | ||
208 | "6:" // Odds skip | ||
209 | : [out_ptr] "+&r"(out), [width] "+&r"(width) | ||
210 | 962 | : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset) | |
211 | : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24", | ||
212 | "x25", "x26", "x27", "x28"); | ||
213 | 962 | } | |
214 | 92 | } | |
215 | |||
216 | #endif // Architectural features check. | ||
217 |