KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 77.1% 27 / 16 / 51
Functions: 66.7% 4 / 0 / 6
Branches: 100.0% 6 / 32 / 38

kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // Do not flag up inline assembly blocks
8 #pragma GCC diagnostic ignored "-Woverlength-strings"
9
10 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
11 #error This file must be compiled for AArch64, FEAT_SVE2.
12 #else // Architectural features check.
13 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
14
15 #include <stddef.h>
16 #include <stdint.h>
17
18 #include "kai/kai_common.h"
19
// Fixed packing parameters for this kernel variant.
// KR and SR are the only kr/sr values accepted by the functions in this file (checked with KAI_ASSUME).
// MAX_M_STEP is the compile-time bound used to size the row-pointer array in
// kai_run_lhs_pack_bf16p2vlx2_f32_sme; it has the same MR * <16-bit elements> / KR shape as the runtime
// m-step, with KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t) in place of the runtime vector length.
// NOTE(review): KAI_SME_VEC_LENGTH_MAX_BYTES is presumably the largest supported SME vector length in
// bytes - confirm in kai_common.h.
20 enum {
21 MR = 2,
22 KR = 2,
23 MAX_M_STEP = MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR,
24 SR = 1,
25 };
26
// Returns the mr value this kernel expects: MR * kai_get_sme_vector_length_u16() / KR.
// Every public function in this file that takes mr checks it against this value.
// NOTE(review): kai_get_sme_vector_length_u16() presumably returns the number of 16-bit elements in
// one SME vector at runtime - confirm in kai_common.h.
27 750 static size_t kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme(void) {
28 750 return MR * kai_get_sme_vector_length_u16() / KR;
29 }
30
// Returns the row step (m step) for this kernel, which is the same value as
// kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme().
//
// mr: must equal kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme(); it is only checked, never used in the result.
31 size_t kai_get_m_step_lhs_pack_bf16p2vlx2_f32_sme(size_t mr) {
32 KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme());
// mr is otherwise unused; KAI_UNUSED presumably silences the unused-parameter warning when
// KAI_ASSUME compiles to nothing - confirm in kai_common.h.
33 KAI_UNUSED(mr);
34 return kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme();
35 }
36
// Returns the offset of row m_idx in the unpacked LHS, computed as m_idx * lhs_stride_row.
//
// m_idx:          row index; must be a multiple of kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme().
// lhs_stride_row: row stride of the LHS. kai_run_lhs_pack_bf16p2vlx2_f32_sme applies the same stride to
//                 a uint8_t pointer, so both the stride and the returned offset are in bytes.
37 150 size_t kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme(size_t m_idx, size_t lhs_stride_row) {
38 KAI_ASSUME(m_idx % kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme() == 0);
39
40 150 return m_idx * lhs_stride_row;
41 }
42
// Returns the offset of row m_idx in the packed LHS buffer:
// m_idx * kai_roundup(k, KR) * sizeof(uint16_t), i.e. each packed element is counted as 2 bytes.
//
// m_idx: row index; must be a multiple of the m step.
// k:     number of columns; passed through kai_roundup(k, KR) before being multiplied
//        (presumably rounded up to a multiple of KR - confirm in kai_common.h).
// mr:    must equal kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme().
// kr:    must equal KR.
// sr:    must equal SR.
43 size_t kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
44 KAI_ASSUME(m_idx % kai_get_m_step_lhs_pack_bf16p2vlx2_f32_sme(mr) == 0);
45 KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme());
46 KAI_ASSUME(kr == KR);
47 KAI_ASSUME(sr == SR);
48
// mr, kr and sr are only validated above; they do not enter the returned value.
49 KAI_UNUSED(mr);
50 KAI_UNUSED(kr);
51 KAI_UNUSED(sr);
52
53 return m_idx * kai_roundup(k, KR) * sizeof(uint16_t);
54 }
55
// Returns the size in bytes of the packed LHS buffer for an m x k matrix:
// kai_roundup(m, mr) * kai_roundup(k, KR) * sizeof(uint16_t), where mr is the value returned by
// kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme().
// NOTE(review): kai_roundup(a, b) presumably rounds a up to a multiple of b - confirm in kai_common.h.
//
// m:  number of rows.
// k:  number of columns.
// mr: must equal kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme().
// kr: must equal KR.
// sr: must equal SR.
56 150 size_t kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
57 KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme());
58 KAI_ASSUME(kr == KR);
59 KAI_ASSUME(sr == SR);
60
// mr, kr and sr are only validated above; the size is computed from the kernel's own constants.
61 150 KAI_UNUSED(mr);
62 150 KAI_UNUSED(kr);
63 150 KAI_UNUSED(sr);
64 150 return kai_roundup(m, kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()) * kai_roundup(k, KR) * sizeof(uint16_t);
65 }
66
// Packs an m x k LHS matrix into lhs_packed, one block of m_step rows per inline-assembly invocation.
//
// m:              number of rows to pack.
// k:              number of columns; passed to the assembly as `width`.
// mr:             must equal kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme().
// kr:             must equal KR.
// sr:             must equal SR.
// m_idx_start:    must be 0.
// lhs:            source matrix; must not be NULL. Rows are addressed as a byte pointer plus
//                 row_index * lhs_stride_row.
// lhs_stride_row: row stride of lhs in bytes.
// lhs_packed:     destination buffer; must not be NULL.
//
// NOTE(review): going by the instruction comments in the assembly, each block is loaded as 32-bit words
// (ld1w), converted with bfcvt, staged in ZA tiles (mova) and written out with st1w - confirm against
// the Arm SME2 instruction reference before relying on details.
67 150 void kai_run_lhs_pack_bf16p2vlx2_f32_sme(
68 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride_row,
69 void* lhs_packed) {
70 KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme());
71 KAI_ASSUME(kr == KR);
72 KAI_ASSUME(sr == SR);
73 KAI_ASSUME(m_idx_start == 0);
74 KAI_ASSUME(lhs != NULL);
75 KAI_ASSUME(lhs_packed != NULL);
76
77 150 const size_t m_step = kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme();
78 150 const size_t width = k;
79
// in[] holds MAX_M_STEP row pointers, so the runtime m_step must fit within it.
80 KAI_ASSERT(m_step <= MAX_M_STEP);
81 150 const uint8_t* in[MAX_M_STEP];
82
83 150 uint8_t* out_base = lhs_packed;
84 150 const uint8_t* lhs_ptr = lhs;
85
// NOTE(review): kai_commit_za() is defined elsewhere; presumably it saves any live ZA state before the
// assembly below enables and overwrites ZA - confirm in kai_common.h.
86 150 kai_commit_za();
87
// Walk the rows in blocks of m_step.
88
2/2
✓ Branch 0 taken 150 times.
✓ Branch 1 taken 471 times.
621 for (size_t i_m = 0; i_m < m; i_m += m_step) {
// Rows in this block: m_step, or fewer for the final block.
89
2/2
✓ Branch 0 taken 120 times.
✓ Branch 1 taken 351 times.
471 const size_t height = KAI_MIN(m - i_m, m_step);
// `out` is the write pointer handed to the assembly for this block. out_base always advances by a full
// block (m_step rows of kai_roundup(k, KR) 16-bit elements), independent of `height`.
90 471 void* out = out_base;
91 471 out_base += m_step * kai_roundup(k, KR) * sizeof(uint16_t);
92
// Start address of each row in this block. Entries at index >= height are left unset.
93
2/2
✓ Branch 0 taken 12258 times.
✓ Branch 1 taken 471 times.
12729 for (size_t y = 0; y < height; y++) {
94 12258 in[y] = lhs_ptr + (i_m + y) * lhs_stride_row;
95 12258 }
96
// Operands: outptr_raw (read/write, early-clobber) is the destination pointer and is updated by the
// assembly; height, in and width are inputs. ZA is enabled at the start (SMSTART ZA) and streaming/ZA
// state is disabled at the end (SMSTOP).
// NOTE(review): row loads appear to be predicated on `height` (whilelt ... %x[height] feeding psel), which
// is what would keep the unset in[] entries from being dereferenced - confirm against the ISA reference.
97 942 __asm__ __volatile__(
98 ".inst 0xd503477f // SMSTART ZA\n"
99 "sub x10, %x[width], #0x1\n"
100 "mov x9, #0x0\n"
101 "cntw x22, ALL, MUL #2\n"
102 "cntw x28\n"
103 "cntw x21, ALL, MUL #2\n"
104 "sub x20, x22, #0x1\n"
105 ".inst 0x25207815 // ptrue pn13.b\n"
106 "whilelt p12.s, XZR, %x[height]\n"
107 "whilelt p11.s, x28, %x[height]\n"
108 "add x10, x10, x21\n"
109 "ands x27, %x[width], x20\n"
110 "udiv x10, x10, x21\n"
111 "csel x27, x27, x22, NE\n"
112 "and x26, x10, #0x1\n"
113 "sub x10, x10, #0x1\n"
114 "add x27, x27, #0x1\n"
115 "mov x20, %x[width]\n"
116 "mov x25, %x[in]\n"
117 "ptrue p0.b\n"
118 "mov x24, %x[outptr_raw]\n"
119 "mov x23, #0x0\n"
120 "lsr x10, x10, #0x1\n"
121 "lsr x27, x27, #0x1\n"
122 "mov x12, #0x0\n"
123 ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n"
124 "add x22, x25, x28, LSL #3\n"
125 "1:" // Width loop: Preamble: Loop
126 "ldr x21, [x25], #0x8\n"
127 ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n"
128 ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n"
129 "ldr x20, [x22], #0x8\n"
130 ".inst 0xa01746b4 // ld1w { z20.s-z21.s }, pn9.s/Z, [x21, x23, LSL #2]\n"
131 ".inst 0xa017428c // ld1w { z12.s-z13.s }, pn8.s/Z, [x20, x23, LSL #2]\n"
132 ".inst 0xc160e294 // bfcvt z20.h, { z20.s-z21.s }\n"
133 ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n"
134 ".inst 0xc0800280 // mova za0h.s[x12], p0/M, z20.s\n"
135 ".inst 0xc0800184 // mova za1h.s[x12], p0/M, z12.s\n"
136 "add x12, x12, #0x1\n"
137 "cmp x12, x28\n"
138 "blt 1b\n"
139 "incw x23, ALL, MUL #2\n"
140 "incw x9, ALL, MUL #2\n"
141 "cbz x10, 5f\n"
142 "2:" // Width loop
143 "mov x20, %x[width]\n"
144 "mov x25, %x[in]\n"
145 "mov x12, #0x0\n"
146 ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n"
147 "add x22, x25, x28, LSL #3\n"
148 "3:" // Width loop: Odd: Loop
149 "ldr x21, [x25], #0x8\n"
150 ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n"
151 ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n"
152 ".inst 0xc0828007 // mova z7.s, p0/M, za0v.s[x12]\n"
153 "ldr x20, [x22], #0x8\n"
154 ".inst 0xc082808f // mova z15.s, p0/M, za1v.s[x12]\n"
155 ".inst 0xa01746b6 // ld1w { z22.s-z23.s }, pn9.s/Z, [x21, x23, LSL #2]\n"
156 ".inst 0xa017429a // ld1w { z26.s-z27.s }, pn8.s/Z, [x20, x23, LSL #2]\n"
157 ".inst 0xa1605707 // st1w { z7.s, z15.s }, pn13.b, [x24]\n"
158 "addvl x24, x24, #2\n"
159 ".inst 0xc160e2d6 // bfcvt z22.h, { z22.s-z23.s }\n"
160 ".inst 0xc160e35a // bfcvt z26.h, { z26.s-z27.s }\n"
161 ".inst 0xc08002c8 // mova za2h.s[x12], p0/M, z22.s\n"
162 ".inst 0xc080034c // mova za3h.s[x12], p0/M, z26.s\n"
163 "add x12, x12, #0x1\n"
164 "cmp x12, x28\n"
165 "blt 3b\n"
166 "incw x9, ALL, MUL #2\n"
167 "mov x20, %x[width]\n"
168 "mov x25, %x[in]\n"
169 "incw x23, ALL, MUL #2\n"
170 "mov x12, #0x0\n"
171 ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n"
172 "add x22, x25, x28, LSL #3\n"
173 "4:" // Width loop: Even: Loop
174 "ldr x21, [x25], #0x8\n"
175 ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n"
176 ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n"
177 ".inst 0xc0828108 // mova z8.s, p0/M, za2v.s[x12]\n"
178 "ldr x20, [x22], #0x8\n"
179 ".inst 0xc0828189 // mova z9.s, p0/M, za3v.s[x12]\n"
180 ".inst 0xa01746ae // ld1w { z14.s-z15.s }, pn9.s/Z, [x21, x23, LSL #2]\n"
181 ".inst 0xa017428c // ld1w { z12.s-z13.s }, pn8.s/Z, [x20, x23, LSL #2]\n"
182 ".inst 0xa0605708 // st1w { z8.s-z9.s }, pn13.b, [x24]\n"
183 "addvl x24, x24, #2\n"
184 ".inst 0xc160e1ce // bfcvt z14.h, { z14.s-z15.s }\n"
185 ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n"
186 ".inst 0xc08001c0 // mova za0h.s[x12], p0/M, z14.s\n"
187 ".inst 0xc0800184 // mova za1h.s[x12], p0/M, z12.s\n"
188 "add x12, x12, #0x1\n"
189 "cmp x12, x28\n"
190 "blt 4b\n"
191 "subs x10, x10, #0x1\n"
192 "incw x23, ALL, MUL #2\n"
193 "incw x9, ALL, MUL #2\n"
194 "bgt 2b\n"
195 "5:" // Width loop: Tails
196 "cbnz x26, 8f\n"
197 "mov x20, %x[width]\n"
198 "mov x25, %x[in]\n"
199 "mov x12, #0x0\n"
200 ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n"
201 "add x22, x25, x28, LSL #3\n"
202 "6:" // Width loop: Tails: Even: Odd: Loop
203 "ldr x21, [x25], #0x8\n"
204 ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n"
205 ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n"
206 ".inst 0xc0828003 // mova z3.s, p0/M, za0v.s[x12]\n"
207 "ldr x20, [x22], #0x8\n"
208 ".inst 0xc082808b // mova z11.s, p0/M, za1v.s[x12]\n"
209 ".inst 0xa01746ac // ld1w { z12.s-z13.s }, pn9.s/Z, [x21, x23, LSL #2]\n"
210 ".inst 0xa017428e // ld1w { z14.s-z15.s }, pn8.s/Z, [x20, x23, LSL #2]\n"
211 ".inst 0xa1605703 // st1w { z3.s, z11.s }, pn13.b, [x24]\n"
212 "addvl x24, x24, #2\n"
213 ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n"
214 ".inst 0xc160e1ce // bfcvt z14.h, { z14.s-z15.s }\n"
215 ".inst 0xc0800188 // mova za2h.s[x12], p0/M, z12.s\n"
216 ".inst 0xc08001cc // mova za3h.s[x12], p0/M, z14.s\n"
217 "add x12, x12, #0x1\n"
218 "cmp x12, x28\n"
219 "blt 6b\n"
220 "mov x12, #0x0\n"
221 "7:" // Width loop: Tails: Even: Even: Loop
222 ".inst 0xc082810e // mova z14.s, p0/M, za2v.s[x12]\n"
223 ".inst 0xc082818f // mova z15.s, p0/M, za3v.s[x12]\n"
224 "add x12, x12, #0x1\n"
225 "cmp x12, x27\n"
226 ".inst 0xa060570e // st1w { z14.s-z15.s }, pn13.b, [x24]\n"
227 "addvl x24, x24, #2\n"
228 "blt 7b\n"
229 "b 10f\n"
230 "8:" // Width loop: Tails: Odd
231 "mov x12, #0x0\n"
232 "9:" // Width loop: Tails: Odd: Loop
233 ".inst 0xc0828014 // mova z20.s, p0/M, za0v.s[x12]\n"
234 ".inst 0xc0828095 // mova z21.s, p0/M, za1v.s[x12]\n"
235 "add x12, x12, #0x1\n"
236 "cmp x12, x27\n"
237 ".inst 0xa0605714 // st1w { z20.s-z21.s }, pn13.b, [x24]\n"
238 "addvl x24, x24, #2\n"
239 "blt 9b\n"
240 "10:" // End
241 "mov %x[outptr_raw], x24\n"
242 ".inst 0xd503467f // SMSTOP\n"
243 : [outptr_raw] "+&r"(out)
244 471 : [height] "r"(height), [in] "r"(in), [width] "r"(width)
245 : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13",
246 "p14", "p15", "x9", "x10", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0",
247 "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
248 "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
249 471 }
250 150 }
251
252 #endif // Architectural features check.
253