KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 75.0% 24 15 47
Functions: 66.7% 4 0 6
Branches: 100.0% 6 30 36

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // Do not flag up inline assembly blocks
8 #pragma GCC diagnostic ignored "-Woverlength-strings"
9
10 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
11 #error This file must be compiled for AArch64, FEAT_SVE2.
12 #else // Architectural features check.
13
14 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
15
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include "kai/kai_common.h"
20
// Fixed packing parameters of this micro-kernel. The kr and sr arguments of the
// functions below are checked against kai_kr and kai_sr; kai_mr and kai_kr are
// also the factors used to derive the vector-length-dependent mr value.
21 static const size_t kai_mr = 2;
22 static const size_t kai_kr = 2;
23 static const size_t kai_sr = 1;
24
// Returns the mr value that the functions in this file require their `mr`
// argument to equal: kai_mr * kai_get_sme_vector_length_u16() / kai_kr.
// kai_get_sme_vector_length_u16() is declared outside this file; presumably it
// reports the SME vector length in 16-bit elements -- TODO confirm in kai_common.h.
25 184 static size_t kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme(void) {
26 184 return kai_mr * kai_get_sme_vector_length_u16() / kai_kr;
27 }
28
// Returns the m step of this packing routine, which is the value of
// kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme().
//
// mr: checked with KAI_ASSUME to equal that same value; otherwise unused.
//
// NOTE(review): none of this function's lines carries an execution count in
// this report, i.e. it was not exercised by the measured run.
29 size_t kai_get_m_step_lhs_pack_bf16p2vlx2_f32_sme(size_t mr) {
30 KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme());
31 KAI_UNUSED(mr);
32
33 return kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme();
34 }
35
// Returns the offset of row `m_idx` in the unpacked LHS matrix, computed as
// m_idx * lhs_stride. The result is in the units of lhs_stride; the run
// function in this file applies lhs_stride to a char pointer, i.e. as bytes.
//
// m_idx:      row index; checked with KAI_ASSUME to be a multiple of the mr value.
// lhs_stride: distance between consecutive LHS rows.
36 46 size_t kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme(size_t m_idx, size_t lhs_stride) {
37 KAI_ASSUME(m_idx % kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme() == 0);
38
39 46 return m_idx * lhs_stride;
40 }
41
// Returns the byte offset of row `m_idx` in the packed LHS buffer:
// m_idx * kai_roundup(k, kr) * sizeof(uint16_t), i.e. each packed row holds k
// rounded up by kr, at two bytes per element.
//
// m_idx: row index; checked to be a multiple of the m step.
// k:     number of columns of the LHS matrix.
// mr:    checked to equal kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme().
// kr:    checked to equal kai_kr.
// sr:    checked to equal kai_sr.
//
// NOTE(review): the return expression reads `kr` although it is marked
// KAI_UNUSED just above, whereas the packed-size and run functions round k up
// with kai_kr. The results agree only while kr == kai_kr holds -- confirm
// whether kai_kr was intended here.
// NOTE(review): none of this function's lines carries an execution count in
// this report, i.e. it was not exercised by the measured run.
42 size_t kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
43 KAI_ASSUME(m_idx % kai_get_m_step_lhs_pack_bf16p2vlx2_f32_sme(mr) == 0);
44 KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme());
45 KAI_ASSUME(kr == kai_kr);
46 KAI_ASSUME(sr == kai_sr);
47
48 KAI_UNUSED(mr);
49 KAI_UNUSED(kr);
50 KAI_UNUSED(sr);
51
52 return m_idx * kai_roundup(k, kr) * sizeof(uint16_t);
53 }
54
// Returns the size in bytes of the packed LHS buffer for an m x k matrix:
// m rounded up by the mr value, times k rounded up by kai_kr, times
// sizeof(uint16_t).
//
// m, k: dimensions of the LHS matrix.
// mr:   checked to equal kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme(); otherwise unused.
// kr:   checked to equal kai_kr; otherwise unused.
// sr:   checked to equal kai_sr; otherwise unused.
55 46 size_t kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
56 KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme());
57 KAI_ASSUME(kr == kai_kr);
58 KAI_ASSUME(sr == kai_sr);
59
60 46 KAI_UNUSED(mr);
61 46 KAI_UNUSED(kr);
62 46 KAI_UNUSED(sr);
63
64 46 return kai_roundup(m, kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()) * kai_roundup(k, kai_kr) * sizeof(uint16_t);
65 }
66
// Packs an m x k LHS matrix into lhs_packed, one block of up to `mr` rows per
// iteration of the outer loop. Going by the instruction comments embedded in
// the asm text, each block is loaded as 32-bit words (ld1w), converted with
// bfcvt, written into ZA tiles as horizontal slices, read back as vertical
// slices and stored with st1w. Presumably the input is f32 and the output
// bf16, as the function name suggests -- TODO confirm against the header.
//
// m, k:        dimensions of the LHS matrix.
// mr, kr, sr:  checked to equal kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme(),
//              kai_kr and kai_sr respectively.
// m_idx_start: checked to be 0.
// lhs:         source matrix; must not be NULL.
// lhs_stride:  distance in bytes between consecutive source rows (it is added
//              to a char pointer below).
// lhs_packed:  destination buffer; must not be NULL.
67 46 void kai_run_lhs_pack_bf16p2vlx2_f32_sme(
68 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
69 void* lhs_packed) {
70 KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme());
71 KAI_ASSUME(kr == kai_kr);
72 KAI_ASSUME(sr == kai_sr);
73 KAI_ASSUME(lhs != NULL);
74 KAI_ASSUME(lhs_packed != NULL);
75
// m_idx_start is not used anywhere else in this function, so only a start
// index of 0 is accepted.
76 KAI_ASSUME(m_idx_start == 0);
77
78 46 const size_t block_height = mr;
79 46 const size_t width = k;
80 46 const size_t row_offset = 0;
81
// Table of source-row pointers handed to the asm; a variable-length array
// with one entry per row of a block (mr entries).
82 46 const void* in[block_height];
83
// Outer loop: one iteration per block of block_height rows.
84
2/2
✓ Branch 0 taken 46 times.
✓ Branch 1 taken 145 times.
191 for (size_t block_y = 0; block_y < m; block_y += block_height) {
// The last block may contain fewer than block_height rows.
85
2/2
✓ Branch 0 taken 37 times.
✓ Branch 1 taken 108 times.
145 const size_t height = KAI_MIN(m - block_y, block_height);
// Start of this block in the packed buffer: block_y packed rows, each holding
// k rounded up by kai_kr elements of sizeof(uint16_t) bytes.
86 145 void* out = (char*)lhs_packed + (block_y * kai_roundup(k, kai_kr) * sizeof(uint16_t));
87
// Only the first `height` entries of in[] are written here; for a partial
// block the remaining entries keep whatever value they had.
// NOTE(review): the asm loops below read one pointer from `in` and one from
// `in + cntw` for every x12 in [0, cntw) regardless of height. The loads
// through them are predicated by p9/p8, which psel derives from the
// `whilelt ..., height` predicates p12/p11 -- confirm that pointers at
// indices >= height are never dereferenced.
88
2/2
✓ Branch 0 taken 3794 times.
✓ Branch 1 taken 145 times.
3939 for (size_t y = 0; y < height; y++) {
89 3794 in[y] = (const void*)((const char*)lhs + (block_y + y) * lhs_stride);
90 3794 }
91
// Hand-encoded SME kernel. Instructions are emitted as raw `.inst` words with
// their mnemonic in the trailing asm comment.
// NOTE(review): the first word, 0xd503477f, is commented "SMSTART ZA", but that
// encoding appears to be the plain SMSTART alias (streaming mode and ZA
// together), mirroring the unqualified SMSTOP at the end -- confirm against
// the Arm Architecture Reference Manual.
92 290 __asm__ __volatile__(
93 ".inst 0xd503477f // SMSTART ZA\n"
94 "sub x10, %x[width], #0x1\n"
95 "mov x9, #0x0\n"
96 "cntw x22, ALL, MUL #2\n"
97 "cntw x28\n"
98 "cntw x21, ALL, MUL #2\n"
99 "sub x20, x22, #0x1\n"
100 ".inst 0x25207815 // ptrue pn13.b\n"
101 "whilelt p12.s, XZR, %x[height]\n"
102 "whilelt p11.s, x28, %x[height]\n"
103 "add x10, x10, x21\n"
104 "ands x27, %x[width], x20\n"
105 "udiv x10, x10, x21\n"
106 "csel x27, x27, x22, NE\n"
107 "and x26, x10, #0x1\n"
108 "sub x10, x10, #0x1\n"
109 "add x27, x27, #0x1\n"
110 "mov x20, %x[width]\n"
111 "mov x25, %x[in]\n"
112 "ptrue p0.b\n"
113 "mov x24, %x[outptr_raw]\n"
114 "mov x23, %x[row_offset]\n"
115 "lsr x10, x10, #0x1\n"
116 "lsr x27, x27, #0x1\n"
117 "mov x12, #0x0\n"
118 ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n"
119 "add x22, x25, x28, LSL #3\n"
120 "1:" // Width loop: Preamble: Loop
121 "ldr x21, [x25], #0x8\n"
122 ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n"
123 ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n"
124 "ldr x20, [x22], #0x8\n"
125 ".inst 0xa01746b4 // ld1w { z20.s-z21.s }, pn9.s/Z, [x21, x23, LSL #2]\n"
126 ".inst 0xa017428c // ld1w { z12.s-z13.s }, pn8.s/Z, [x20, x23, LSL #2]\n"
127 ".inst 0xc160e294 // bfcvt z20.h, { z20.s-z21.s }\n"
128 ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n"
129 ".inst 0xc0800280 // mova za0h.s[x12], p0/M, z20.s\n"
130 ".inst 0xc0800184 // mova za1h.s[x12], p0/M, z12.s\n"
131 "add x12, x12, #0x1\n"
132 "cmp x12, x28\n"
133 "blt 1b\n"
134 "incw x23, ALL, MUL #2\n"
135 "incw x9, ALL, MUL #2\n"
136 "cbz x10, 5f\n"
137 "2:" // Width loop
138 "mov x20, %x[width]\n"
139 "mov x25, %x[in]\n"
140 "mov x12, #0x0\n"
141 ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n"
142 "add x22, x25, x28, LSL #3\n"
143 "3:" // Width loop: Odd: Loop
144 "ldr x21, [x25], #0x8\n"
145 ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n"
146 ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n"
147 ".inst 0xc0828007 // mova z7.s, p0/M, za0v.s[x12]\n"
148 "ldr x20, [x22], #0x8\n"
149 ".inst 0xc082808f // mova z15.s, p0/M, za1v.s[x12]\n"
150 ".inst 0xa01746b6 // ld1w { z22.s-z23.s }, pn9.s/Z, [x21, x23, LSL #2]\n"
151 ".inst 0xa017429a // ld1w { z26.s-z27.s }, pn8.s/Z, [x20, x23, LSL #2]\n"
152 ".inst 0xa1605707 // st1w { z7.s, z15.s }, pn13.b, [x24]\n"
153 "addvl x24, x24, #2\n"
154 ".inst 0xc160e2d6 // bfcvt z22.h, { z22.s-z23.s }\n"
155 ".inst 0xc160e35a // bfcvt z26.h, { z26.s-z27.s }\n"
156 ".inst 0xc08002c8 // mova za2h.s[x12], p0/M, z22.s\n"
157 ".inst 0xc080034c // mova za3h.s[x12], p0/M, z26.s\n"
158 "add x12, x12, #0x1\n"
159 "cmp x12, x28\n"
160 "blt 3b\n"
161 "incw x9, ALL, MUL #2\n"
162 "mov x20, %x[width]\n"
163 "mov x25, %x[in]\n"
164 "incw x23, ALL, MUL #2\n"
165 "mov x12, #0x0\n"
166 ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n"
167 "add x22, x25, x28, LSL #3\n"
168 "4:" // Width loop: Even: Loop
169 "ldr x21, [x25], #0x8\n"
170 ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n"
171 ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n"
172 ".inst 0xc0828108 // mova z8.s, p0/M, za2v.s[x12]\n"
173 "ldr x20, [x22], #0x8\n"
174 ".inst 0xc0828189 // mova z9.s, p0/M, za3v.s[x12]\n"
175 ".inst 0xa01746ae // ld1w { z14.s-z15.s }, pn9.s/Z, [x21, x23, LSL #2]\n"
176 ".inst 0xa017428c // ld1w { z12.s-z13.s }, pn8.s/Z, [x20, x23, LSL #2]\n"
177 ".inst 0xa0605708 // st1w { z8.s-z9.s }, pn13.b, [x24]\n"
178 "addvl x24, x24, #2\n"
179 ".inst 0xc160e1ce // bfcvt z14.h, { z14.s-z15.s }\n"
180 ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n"
181 ".inst 0xc08001c0 // mova za0h.s[x12], p0/M, z14.s\n"
182 ".inst 0xc0800184 // mova za1h.s[x12], p0/M, z12.s\n"
183 "add x12, x12, #0x1\n"
184 "cmp x12, x28\n"
185 "blt 4b\n"
186 "subs x10, x10, #0x1\n"
187 "incw x23, ALL, MUL #2\n"
188 "incw x9, ALL, MUL #2\n"
189 "bgt 2b\n"
190 "5:" // Width loop: Tails
191 "cbnz x26, 8f\n"
192 "mov x20, %x[width]\n"
193 "mov x25, %x[in]\n"
194 "mov x12, #0x0\n"
195 ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n"
196 "add x22, x25, x28, LSL #3\n"
197 "6:" // Width loop: Tails: Even: Odd: Loop
198 "ldr x21, [x25], #0x8\n"
199 ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n"
200 ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n"
201 ".inst 0xc0828003 // mova z3.s, p0/M, za0v.s[x12]\n"
202 "ldr x20, [x22], #0x8\n"
203 ".inst 0xc082808b // mova z11.s, p0/M, za1v.s[x12]\n"
204 ".inst 0xa01746ac // ld1w { z12.s-z13.s }, pn9.s/Z, [x21, x23, LSL #2]\n"
205 ".inst 0xa017428e // ld1w { z14.s-z15.s }, pn8.s/Z, [x20, x23, LSL #2]\n"
206 ".inst 0xa1605703 // st1w { z3.s, z11.s }, pn13.b, [x24]\n"
207 "addvl x24, x24, #2\n"
208 ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n"
209 ".inst 0xc160e1ce // bfcvt z14.h, { z14.s-z15.s }\n"
210 ".inst 0xc0800188 // mova za2h.s[x12], p0/M, z12.s\n"
211 ".inst 0xc08001cc // mova za3h.s[x12], p0/M, z14.s\n"
212 "add x12, x12, #0x1\n"
213 "cmp x12, x28\n"
214 "blt 6b\n"
215 "mov x12, #0x0\n"
216 "7:" // Width loop: Tails: Even: Even: Loop
217 ".inst 0xc082810e // mova z14.s, p0/M, za2v.s[x12]\n"
218 ".inst 0xc082818f // mova z15.s, p0/M, za3v.s[x12]\n"
219 "add x12, x12, #0x1\n"
220 "cmp x12, x27\n"
221 ".inst 0xa060570e // st1w { z14.s-z15.s }, pn13.b, [x24]\n"
222 "addvl x24, x24, #2\n"
223 "blt 7b\n"
224 "b 10f\n"
225 "8:" // Width loop: Tails: Odd
226 "mov x12, #0x0\n"
227 "9:" // Width loop: Tails: Odd: Loop
228 ".inst 0xc0828014 // mova z20.s, p0/M, za0v.s[x12]\n"
229 ".inst 0xc0828095 // mova z21.s, p0/M, za1v.s[x12]\n"
230 "add x12, x12, #0x1\n"
231 "cmp x12, x27\n"
232 ".inst 0xa0605714 // st1w { z20.s-z21.s }, pn13.b, [x24]\n"
233 "addvl x24, x24, #2\n"
234 "blt 9b\n"
235 "10:" // End
236 "mov %x[outptr_raw], x24\n"
237 ".inst 0xd503467f // SMSTOP\n"
// Operands: `out` is read and written (the advanced store pointer x24 is
// copied back into it); height, in, row_offset and width are inputs.
// Clobbers: condition flags, memory, p0-p15, z0-z31 and the general-purpose
// registers x9, x10, x12 and x20-x28 used as scratch above.
238 : [outptr_raw] "+&r"(out)
239 145 : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset), [width] "r"(width)
240 : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7",
241 "p8", "p9", "x10", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1",
242 "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23",
243 "z24", "z25", "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9");
244 145 }
245 46 }
246
247 #endif // Architectural features check.
248