KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 82.6% 19 15 38
Functions: 60.0% 3 0 5
Branches: 100.0% 6 30 36

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // Do not flag up inline assembly blocks
8 #pragma GCC diagnostic ignored "-Woverlength-strings"
9
10 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || \
11 !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
12 #error This file must be compiled for AArch64, FEAT_BF16, FEAT_FP16.
13 #else // Architectural features check.
14
15 #include "kai_lhs_pack_bf16p8x4_f16_neon.h"
16
17 #include <stddef.h>
18 #include <stdint.h>
19
20 #include "kai/kai_common.h"
21
// Fixed packing geometry of this micro-kernel. Every public entry point below
// asserts that the caller-supplied mr/kr/sr match these values.
static const size_t kai_mr = 8U;  // LHS rows per packed block
static const size_t kai_kr = 4U;  // consecutive k-values per row inside one packed group
static const size_t kai_sr = 1U;  // the only sr value accepted by this kernel
25
26 size_t kai_get_m_step_lhs_pack_bf16p8x4_f16_neon(size_t mr) {
27 KAI_ASSUME(mr == kai_mr);
28
29 return kai_mr;
30 }
31
32 176 size_t kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon(size_t m_idx, size_t lhs_stride) {
33 KAI_ASSUME(m_idx % (kai_mr) == 0);
34
35 176 return m_idx * lhs_stride;
36 }
37
38 size_t kai_get_lhs_packed_offset_lhs_pack_bf16p8x4_f16_neon(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
39 KAI_ASSUME(m_idx % kai_mr == 0);
40 KAI_ASSUME(mr == kai_mr);
41 KAI_ASSUME(kr == kai_kr);
42 KAI_ASSUME(sr == kai_sr);
43
44 return m_idx * kai_roundup(k, kai_kr) * sizeof(uint16_t);
45 }
46
47 176 size_t kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
48 KAI_ASSUME(mr == kai_mr);
49 KAI_ASSUME(kr == kai_kr);
50 KAI_ASSUME(sr == kai_sr);
51
52 176 return kai_roundup(m, kai_mr) * kai_roundup(k, kai_kr) * sizeof(uint16_t);
53 }
54
// Packs an FP16 LHS matrix into BF16 blocks of kai_mr (8) rows.
//
// Input: row r starts at (const char*)lhs + r * lhs_stride and holds k contiguous
// FP16 values. The asm widens them to FP32 with fcvtl/fcvtl2 and then narrows them
// to BF16 with bfcvtn/bfcvtn2.
//
// Output layout, for each block of 8 rows starting at row block_y:
//   - the block starts at byte offset block_y * roundup(k, kai_kr) * sizeof(uint16_t)
//     in lhs_packed;
//   - k is walked in groups of kai_kr (4) values. Each group is 64 bytes and holds
//     rows 0..7 of the block in order, with 4 BF16 values per row;
//   - in the last group of a row, positions past k are zero;
//   - in a final partial block (m not a multiple of 8), the missing rows are copies
//     of the block's first row.
// In total roundup(m, kai_mr) * roundup(k, kai_kr) * sizeof(uint16_t) bytes are written.
// This equals kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon(m, k, ...).
//
// @param[in]  m            Number of LHS rows to pack.
// @param[in]  k            Number of FP16 values per LHS row.
// @param[in]  mr           Must equal kai_mr.
// @param[in]  kr           Must equal kai_kr.
// @param[in]  sr           Must equal kai_sr.
// @param[in]  m_idx_start  Must be 0; packing always starts at the first row of lhs.
// @param[in]  lhs          Unpacked FP16 LHS matrix; must not be NULL.
// @param[in]  lhs_stride   Distance in bytes between consecutive LHS rows.
// @param[out] lhs_packed   Destination buffer; must not be NULL.
void kai_run_lhs_pack_bf16p8x4_f16_neon(
    size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
    void* lhs_packed) {
    KAI_ASSUME(mr == kai_mr);
    KAI_ASSUME(kr == kai_kr);
    KAI_ASSUME(sr == kai_sr);
    KAI_ASSUME(lhs != NULL);
    KAI_ASSUME(lhs_packed != NULL);

    // Only packing from row 0 is supported.
    KAI_ASSUME(m_idx_start == 0);

    // Rows per packed block. The asm below is written for exactly 8 row pointers.
    const size_t block_height = kai_mr;
    // Element offset added to every row pointer by the asm (scaled by 2 bytes per
    // FP16 value via "LSL #1"). It is always 0 here.
    const size_t row_offset = 0;

    // Row pointers for the current block. The asm always loads all 8 entries.
    // NOTE(review): the loop below does not write entries at index >= height. For a
    // partial block, the asm replaces those registers with the row-0 pointer (csel/mov)
    // before any prefetch or load, so the stale values are never dereferenced.
    // kai_mr is a const object, not a constant expression, so in C this declaration
    // is a variable-length array.
    const void* in[kai_mr];

    for (size_t block_y = 0; block_y < m; block_y += block_height) {
        // Rows present in this block. This is fewer than 8 only for the last block.
        const size_t height = KAI_MIN(m - block_y, block_height);
        // Blocks are block_height * roundup(k, kai_kr) 16-bit values apart in the packed buffer.
        void* out = (char*)lhs_packed + block_y * kai_roundup(k, kai_kr) * sizeof(uint16_t);
        // Remaining FP16 values per row. The asm reads and modifies this counter.
        size_t width = k;

        for (size_t y = 0; y < height; y++) {
            in[y] = (const char*)lhs + (block_y + y) * lhs_stride;
        }

        __asm__ __volatile__(
            // Load the 8 row pointers (x28 = row 0, x27 = row 1, ..., x21 = row 7) and
            // add row_offset FP16 elements to each.
            "ldr x28, [%x[in], #0x0]\n"
            "ldr x27, [%x[in], #0x8]\n"
            "cmp %x[height], #0x8\n"
            "ldr x26, [%x[in], #0x10]\n"
            "ldr x25, [%x[in], #0x18]\n"
            "ldr x24, [%x[in], #0x20]\n"
            "ldr x23, [%x[in], #0x28]\n"
            "ldr x22, [%x[in], #0x30]\n"
            "ldr x21, [%x[in], #0x38]\n"
            "add x28, x28, %x[row_offset], LSL #1\n"
            "add x27, x27, %x[row_offset], LSL #1\n"
            "add x26, x26, %x[row_offset], LSL #1\n"
            "add x25, x25, %x[row_offset], LSL #1\n"
            "add x24, x24, %x[row_offset], LSL #1\n"
            "add x23, x23, %x[row_offset], LSL #1\n"
            "add x22, x22, %x[row_offset], LSL #1\n"
            "add x21, x21, %x[row_offset], LSL #1\n"
            // If height == 8 (flags from the cmp above), keep every pointer.
            // Otherwise point each row with index >= height at row 0.
            "beq 1f\n"
            "cmp %x[height], #0x2\n"
            "mov x21, x28\n"
            "csel x27, x27, x28, GE\n"
            "csel x26, x26, x28, GT\n"
            "cmp %x[height], #0x4\n"
            "csel x25, x25, x28, GE\n"
            "csel x24, x24, x28, GT\n"
            "cmp %x[height], #0x6\n"
            "csel x23, x23, x28, GE\n"
            "csel x22, x22, x28, GT\n"
            "1:"  // no_pointer_adj
            // Prefetch the start of every row. Skip the main loop if fewer than 8 values remain.
            "cmp %x[width], #0x8\n"
            "prfm pldl1keep, [x28, #0x0]\n"
            "prfm pldl1keep, [x27, #0x0]\n"
            "prfm pldl1keep, [x26, #0x0]\n"
            "prfm pldl1keep, [x25, #0x0]\n"
            "prfm pldl1keep, [x24, #0x0]\n"
            "prfm pldl1keep, [x23, #0x0]\n"
            "prfm pldl1keep, [x22, #0x0]\n"
            "prfm pldl1keep, [x21, #0x0]\n"
            "prfm pldl1keep, [x28, #0x40]\n"
            "prfm pldl1keep, [x27, #0x40]\n"
            "prfm pldl1keep, [x26, #0x40]\n"
            "prfm pldl1keep, [x25, #0x40]\n"
            "prfm pldl1keep, [x24, #0x40]\n"
            "prfm pldl1keep, [x23, #0x40]\n"
            "prfm pldl1keep, [x22, #0x40]\n"
            "prfm pldl1keep, [x21, #0x40]\n"
            "blt 3f\n"
            // Main loop, 8 values per row per iteration:
            //   - load 8 FP16 values from each row;
            //   - widen to FP32 (fcvtl/fcvtl2) and narrow to BF16 (bfcvtn/bfcvtn2), so each
            //     output register holds rows (0,1), (2,3), (4,5) or (6,7);
            //   - store two 64-byte groups: values 0..3 of rows 0..7, then values 4..7
            //     of rows 0..7.
            "2:"  // Main loop head
            "ldr q19, [x28], #0x10\n"
            "ldr q18, [x26], #0x10\n"
            "subs %x[width], %x[width], #0x8\n"
            "ldr q17, [x24], #0x10\n"
            "ldr q16, [x22], #0x10\n"
            "cmp %x[width], #0x8\n"
            "ldr q25, [x27], #0x10\n"
            "ldr q24, [x25], #0x10\n"
            "ldr q1, [x23], #0x10\n"
            "ldr q0, [x21], #0x10\n"
            "fcvtl v23.4s, v19.4h\n"
            "fcvtl2 v22.4s, v19.8h\n"
            "fcvtl v21.4s, v18.4h\n"
            "fcvtl2 v20.4s, v18.8h\n"
            "prfm pldl1keep, [x28, #0x70]\n"
            "fcvtl v19.4s, v17.4h\n"
            "fcvtl2 v18.4s, v17.8h\n"
            "prfm pldl1keep, [x27, #0x70]\n"
            "prfm pldl1keep, [x26, #0x70]\n"
            "fcvtl v17.4s, v16.4h\n"
            "fcvtl2 v16.4s, v16.8h\n"
            "prfm pldl1keep, [x25, #0x70]\n"
            "prfm pldl1keep, [x24, #0x70]\n"
            ".inst 0x0ea16aff // bfcvtn v31.4h, v23.4s\n"
            ".inst 0x0ea16ade // bfcvtn v30.4h, v22.4s\n"
            "prfm pldl1keep, [x23, #0x70]\n"
            "prfm pldl1keep, [x22, #0x70]\n"
            "fcvtl v29.4s, v25.4h\n"
            "fcvtl2 v28.4s, v25.8h\n"
            "prfm pldl1keep, [x21, #0x70]\n"
            ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n"
            ".inst 0x0ea16a9a // bfcvtn v26.4h, v20.4s\n"
            "fcvtl v25.4s, v24.4h\n"
            "fcvtl2 v24.4s, v24.8h\n"
            ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n"
            ".inst 0x0ea16a56 // bfcvtn v22.4h, v18.4s\n"
            "fcvtl v21.4s, v1.4h\n"
            "fcvtl2 v20.4s, v1.8h\n"
            ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n"
            ".inst 0x0ea16a12 // bfcvtn v18.4h, v16.4s\n"
            "fcvtl v17.4s, v0.4h\n"
            "fcvtl2 v16.4s, v0.8h\n"
            ".inst 0x4ea16bbf // bfcvtn2 v31.8h, v29.4s\n"
            ".inst 0x4ea16b9e // bfcvtn2 v30.8h, v28.4s\n"
            ".inst 0x4ea16b3b // bfcvtn2 v27.8h, v25.4s\n"
            ".inst 0x4ea16b1a // bfcvtn2 v26.8h, v24.4s\n"
            ".inst 0x4ea16ab7 // bfcvtn2 v23.8h, v21.4s\n"
            ".inst 0x4ea16a96 // bfcvtn2 v22.8h, v20.4s\n"
            ".inst 0x4ea16a33 // bfcvtn2 v19.8h, v17.4s\n"
            ".inst 0x4ea16a12 // bfcvtn2 v18.8h, v16.4s\n"
            "str q31, [%x[out_ptr], #0x0]\n"
            "str q27, [%x[out_ptr], #0x10]\n"
            "str q23, [%x[out_ptr], #0x20]\n"
            "str q19, [%x[out_ptr], #0x30]\n"
            "str q30, [%x[out_ptr], #0x40]\n"
            "str q26, [%x[out_ptr], #0x50]\n"
            "str q22, [%x[out_ptr], #0x60]\n"
            "str q18, [%x[out_ptr], #0x70]\n"
            "add %x[out_ptr], %x[out_ptr], #0x80\n"
            "bge 2b\n"
            // Tail: 0..7 values per row remain; with none left, jump to the end.
            //   - the scalar ldr d/s/h loads zero the rest of the register, and the lane
            //     loads fill only the selected lanes, so unused lanes stay zero;
            //   - x20 is set to the number of 4-value groups to store: 1 if at most 4
            //     values remain, otherwise 2.
            "3:"  // Main loop skip
            "cbz %x[width], 8f\n"
            "tbz %x[width], #2, 5f\n"
            "ldr d19, [x28], #0x8\n"
            "ldr d25, [x27], #0x8\n"
            "ldr d18, [x26], #0x8\n"
            "ldr d24, [x25], #0x8\n"
            "ldr d17, [x24], #0x8\n"
            "ldr d1, [x23], #0x8\n"
            "ldr d16, [x22], #0x8\n"
            "ldr d0, [x21], #0x8\n"
            "tbz %x[width], #1, 4f\n"
            "ld1 { v19.s }[2], [x28], #0x4\n"
            "ld1 { v25.s }[2], [x27], #0x4\n"
            "mov x20, #0x2\n"
            "ld1 { v18.s }[2], [x26], #0x4\n"
            "ld1 { v24.s }[2], [x25], #0x4\n"
            "ld1 { v17.s }[2], [x24], #0x4\n"
            "ld1 { v1.s }[2], [x23], #0x4\n"
            "ld1 { v16.s }[2], [x22], #0x4\n"
            "ld1 { v0.s }[2], [x21], #0x4\n"
            "tbz %x[width], #0, 7f\n"
            "ld1 { v19.h }[6], [x28]\n"
            "ld1 { v25.h }[6], [x27]\n"
            "ld1 { v18.h }[6], [x26]\n"
            "ld1 { v24.h }[6], [x25]\n"
            "ld1 { v17.h }[6], [x24]\n"
            "ld1 { v1.h }[6], [x23]\n"
            "ld1 { v16.h }[6], [x22]\n"
            "ld1 { v0.h }[6], [x21]\n"
            "b 7f\n"
            "4:"  // odd_loads_1_4
            "mov x20, #0x1\n"
            "tbz %x[width], #0, 7f\n"
            "ld1 { v19.h }[4], [x28]\n"
            "ld1 { v25.h }[4], [x27]\n"
            "mov x20, #0x2\n"
            "ld1 { v18.h }[4], [x26]\n"
            "ld1 { v24.h }[4], [x25]\n"
            "ld1 { v17.h }[4], [x24]\n"
            "ld1 { v1.h }[4], [x23]\n"
            "ld1 { v16.h }[4], [x22]\n"
            "ld1 { v0.h }[4], [x21]\n"
            "b 7f\n"
            "5:"  // odd_loads_2_0
            "tbz %x[width], #1, 6f\n"
            "ldr s19, [x28], #0x4\n"
            "ldr s25, [x27], #0x4\n"
            "mov x20, #0x1\n"
            "ldr s18, [x26], #0x4\n"
            "ldr s24, [x25], #0x4\n"
            "ldr s17, [x24], #0x4\n"
            "ldr s1, [x23], #0x4\n"
            "ldr s16, [x22], #0x4\n"
            "ldr s0, [x21], #0x4\n"
            "tbz %x[width], #0, 7f\n"
            "ld1 { v19.h }[2], [x28]\n"
            "ld1 { v25.h }[2], [x27]\n"
            "ld1 { v18.h }[2], [x26]\n"
            "ld1 { v24.h }[2], [x25]\n"
            "ld1 { v17.h }[2], [x24]\n"
            "ld1 { v1.h }[2], [x23]\n"
            "ld1 { v16.h }[2], [x22]\n"
            "ld1 { v0.h }[2], [x21]\n"
            "b 7f\n"
            "6:"  // odd_loads_1_0
            "ldr h19, [x28, #0x0]\n"
            "ldr h25, [x27, #0x0]\n"
            "mov x20, #0x1\n"
            "ldr h18, [x26, #0x0]\n"
            "ldr h24, [x25, #0x0]\n"
            "ldr h17, [x24, #0x0]\n"
            "ldr h1, [x23, #0x0]\n"
            "ldr h16, [x22, #0x0]\n"
            "ldr h0, [x21, #0x0]\n"
            // Convert exactly as in the main loop. Always store the first 64-byte group;
            // store the second only when x20 was 2 (the flags come from "subs x20").
            "7:"  // Odd load end
            "fcvtl v23.4s, v19.4h\n"
            "fcvtl2 v22.4s, v19.8h\n"
            "subs x20, x20, #0x1\n"
            "fcvtl v21.4s, v18.4h\n"
            "fcvtl2 v20.4s, v18.8h\n"
            "fcvtl v19.4s, v17.4h\n"
            "fcvtl2 v18.4s, v17.8h\n"
            "fcvtl v17.4s, v16.4h\n"
            "fcvtl2 v16.4s, v16.8h\n"
            ".inst 0x0ea16aff // bfcvtn v31.4h, v23.4s\n"
            ".inst 0x0ea16ade // bfcvtn v30.4h, v22.4s\n"
            "fcvtl v29.4s, v25.4h\n"
            "fcvtl2 v28.4s, v25.8h\n"
            ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n"
            ".inst 0x0ea16a9a // bfcvtn v26.4h, v20.4s\n"
            "fcvtl v25.4s, v24.4h\n"
            "fcvtl2 v24.4s, v24.8h\n"
            ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n"
            ".inst 0x0ea16a56 // bfcvtn v22.4h, v18.4s\n"
            "fcvtl v21.4s, v1.4h\n"
            "fcvtl2 v20.4s, v1.8h\n"
            ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n"
            ".inst 0x0ea16a12 // bfcvtn v18.4h, v16.4s\n"
            "fcvtl v17.4s, v0.4h\n"
            "fcvtl2 v16.4s, v0.8h\n"
            ".inst 0x4ea16bbf // bfcvtn2 v31.8h, v29.4s\n"
            ".inst 0x4ea16b9e // bfcvtn2 v30.8h, v28.4s\n"
            ".inst 0x4ea16b3b // bfcvtn2 v27.8h, v25.4s\n"
            ".inst 0x4ea16b1a // bfcvtn2 v26.8h, v24.4s\n"
            ".inst 0x4ea16ab7 // bfcvtn2 v23.8h, v21.4s\n"
            ".inst 0x4ea16a96 // bfcvtn2 v22.8h, v20.4s\n"
            ".inst 0x4ea16a33 // bfcvtn2 v19.8h, v17.4s\n"
            ".inst 0x4ea16a12 // bfcvtn2 v18.8h, v16.4s\n"
            "str q31, [%x[out_ptr], #0x0]\n"
            "str q27, [%x[out_ptr], #0x10]\n"
            "str q23, [%x[out_ptr], #0x20]\n"
            "str q19, [%x[out_ptr], #0x30]\n"
            "add %x[out_ptr], %x[out_ptr], #0x40\n"
            "beq 8f\n"
            "str q30, [%x[out_ptr], #0x0]\n"
            "str q26, [%x[out_ptr], #0x10]\n"
            "str q22, [%x[out_ptr], #0x20]\n"
            "str q18, [%x[out_ptr], #0x30]\n"
            "add %x[out_ptr], %x[out_ptr], #0x40\n"
            "8:"  // Odds skip
            // out (destination cursor) and width (remaining values) are read-write operands.
            : [out_ptr] "+&r"(out), [width] "+&r"(width)
            : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset)
            // Scratch: flags, memory (in[] is read and lhs_packed is written through pointers),
            // the vector registers used for conversion, and x20-x28 (group count + row pointers).
            : "cc", "memory", "v0", "v1", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
              "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28");
    }
}
316
317 #endif // Architectural features check.
318