KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 93.1% 27 / 12 / 41
Functions: 85.7% 6 / 0 / 7
Branches: -% 0 / 24 / 24

kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15 #include <string.h>
16
17 #include "kai/kai_common.h"
18
19 static const size_t kai_nr = 2;
20 static const size_t kai_kr = 2;
21 static const size_t kai_num_bytes_input = 4;
22 static const size_t kai_num_bytes_output = 2;
23 static const size_t kai_num_bytes_bias = 4;
24
25 1230 size_t kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(void) {
26 1230 return kai_nr * kai_get_sme_vector_length_u16() / kai_kr;
27 }
28
29 150 size_t kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx) {
30 KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() == 0);
31
32 150 return n_idx * kai_num_bytes_input;
33 }
34
35 size_t kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx) {
36 return n_idx * kai_num_bytes_bias;
37 }
38
39 360 size_t kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t k) {
40 720 return kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() *
41 360 (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output);
42 }
43
44 180 size_t kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx, size_t k) {
45 KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() == 0);
46
47 180 const size_t block_idx = n_idx / kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme();
48 360 return block_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(k);
49 180 }
50
51 180 size_t kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n, size_t k) {
52 180 const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme());
53 360 return kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(n_rounded_up, k);
54 180 }
55
56 180 void kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(
57 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
58 const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) {
59 KAI_ASSUME(num_groups == 1);
60 KAI_ASSUME(nr == kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme());
61 KAI_ASSUME(kr == kai_kr);
62 KAI_ASSUME(sr == 1);
63 KAI_ASSUME(rhs != NULL);
64 KAI_ASSUME(bias != NULL);
65 KAI_ASSUME(scale == NULL);
66 KAI_ASSUME(rhs_packed != NULL);
67 KAI_ASSUME(extra_bytes == 0);
68 KAI_ASSUME(params == NULL);
69
70 180 size_t height = k;
71 180 const size_t width = n;
72 180 const void* in = rhs;
73 180 void* out = rhs_packed;
74 180 const size_t in_stride = rhs_stride;
75 180 const float* pad_row = rhs;
76
77 180 size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(height);
78
79 180 kai_commit_za();
80
81 360 __asm__ __volatile__(
82 ".inst 0xd503477f // SMSTART ZA\n"
83 "mov x22, %x[out]\n"
84 "mov x21, %x[width]\n"
85 "ptrue p2.b\n"
86 "1:" // Bias: Full loop
87 "mov x20, x21\n"
88 "decw x21, ALL, MUL #2\n"
89 "whilelt p1.s, XZR, x20\n"
90 "decw x20\n"
91 "whilelt p0.s, XZR, x20\n"
92 "cmp x21, #0x0\n"
93 "ld1w { z17.s }, p1/Z, [%x[bias]]\n"
94 "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n"
95 "incb %x[bias], ALL, MUL #2\n"
96 "st1w { z17.s }, p2, [x22]\n"
97 "st1w { z16.s }, p2, [x22, #1, MUL VL]\n"
98 "add x22, x22, %x[out_stride]\n"
99 "bgt 1b\n"
100 "cmp %x[height], #0x8\n"
101 "incb %x[out], ALL, MUL #2\n"
102 "blt 5f\n"
103 "2:" // Main row loop: Head
104 "mov x10, %x[in]\n"
105 "mov x9, %x[out]\n"
106 "add x28, x10, %x[in_stride]\n"
107 "sub %x[height], %x[height], #0x8\n"
108 "add x27, x28, %x[in_stride]\n"
109 "mov x26, %x[width]\n"
110 "add x25, x27, %x[in_stride]\n"
111 "add x24, x25, %x[in_stride]\n"
112 "add x23, x24, %x[in_stride]\n"
113 "add x22, x23, %x[in_stride]\n"
114 "add x21, x22, %x[in_stride]\n"
115 "add %x[in], x21, %x[in_stride]\n"
116 "3:" // Main row loop: Column loop
117 "mov x20, x26\n"
118 "decw x26, ALL, MUL #2\n"
119 "whilelt p1.s, XZR, x20\n"
120 "decw x20\n"
121 "whilelt p0.s, XZR, x20\n"
122 "ld1w { z19.s }, p1/Z, [x10]\n"
123 "cmp x26, #0x0\n"
124 "ld1w { z18.s }, p0/Z, [x10, #1, MUL VL]\n"
125 "addvl x10, x10, #2\n"
126 "ld1w { z17.s }, p1/Z, [x27]\n"
127 "ld1w { z16.s }, p0/Z, [x27, #1, MUL VL]\n"
128 ".inst 0x658aaa7b // bfcvt z27.h, p2/M, z19.s\n"
129 "addvl x27, x27, #2\n"
130 "ld1w { z19.s }, p1/Z, [x24]\n"
131 ".inst 0x658aaa5a // bfcvt z26.h, p2/M, z18.s\n"
132 "ld1w { z18.s }, p0/Z, [x24, #1, MUL VL]\n"
133 ".inst 0x658aaa39 // bfcvt z25.h, p2/M, z17.s\n"
134 "addvl x24, x24, #2\n"
135 "ld1w { z17.s }, p1/Z, [x22]\n"
136 ".inst 0x658aaa18 // bfcvt z24.h, p2/M, z16.s\n"
137 "ld1w { z16.s }, p0/Z, [x22, #1, MUL VL]\n"
138 ".inst 0x658aaa77 // bfcvt z23.h, p2/M, z19.s\n"
139 "addvl x22, x22, #2\n"
140 ".inst 0x658aaa56 // bfcvt z22.h, p2/M, z18.s\n"
141 "ld1w { z19.s }, p1/Z, [x28]\n"
142 ".inst 0x658aaa35 // bfcvt z21.h, p2/M, z17.s\n"
143 "ld1w { z18.s }, p0/Z, [x28, #1, MUL VL]\n"
144 "addvl x28, x28, #2\n"
145 ".inst 0x658aaa14 // bfcvt z20.h, p2/M, z16.s\n"
146 "ld1w { z17.s }, p1/Z, [x25]\n"
147 "ld1w { z16.s }, p0/Z, [x25, #1, MUL VL]\n"
148 "addvl x25, x25, #2\n"
149 ".inst 0x648aaa7b // bfcvtnt z27.h, p2/M, z19.s\n"
150 "ld1w { z19.s }, p1/Z, [x23]\n"
151 ".inst 0x648aaa5a // bfcvtnt z26.h, p2/M, z18.s\n"
152 "ld1w { z18.s }, p0/Z, [x23, #1, MUL VL]\n"
153 "addvl x23, x23, #2\n"
154 ".inst 0x648aaa39 // bfcvtnt z25.h, p2/M, z17.s\n"
155 "ld1w { z17.s }, p1/Z, [x21]\n"
156 ".inst 0x648aaa18 // bfcvtnt z24.h, p2/M, z16.s\n"
157 "ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n"
158 "addvl x21, x21, #2\n"
159 ".inst 0x648aaa77 // bfcvtnt z23.h, p2/M, z19.s\n"
160 "st1h { z27.h }, p2, [x9]\n"
161 ".inst 0x648aaa56 // bfcvtnt z22.h, p2/M, z18.s\n"
162 "st1h { z26.h }, p2, [x9, #1, MUL VL]\n"
163 ".inst 0x648aaa35 // bfcvtnt z21.h, p2/M, z17.s\n"
164 "st1h { z25.h }, p2, [x9, #2, MUL VL]\n"
165 ".inst 0x648aaa14 // bfcvtnt z20.h, p2/M, z16.s\n"
166 "st1h { z24.h }, p2, [x9, #3, MUL VL]\n"
167 "st1h { z23.h }, p2, [x9, #4, MUL VL]\n"
168 "st1h { z22.h }, p2, [x9, #5, MUL VL]\n"
169 "st1h { z21.h }, p2, [x9, #6, MUL VL]\n"
170 "st1h { z20.h }, p2, [x9, #7, MUL VL]\n"
171 "add x9, x9, %x[out_stride]\n"
172 "bgt 3b\n"
173 "cmp %x[height], #0x8\n"
174 "addvl %x[out], %x[out], #8\n"
175 "bge 2b\n"
176 "cbz %x[height], 9f\n"
177 "5:" // Main loop skip
178 "6:" // Tail row loop: Head
179 "mov x10, %x[in]\n"
180 "cmp %x[height], #0x1\n"
181 "add x28, x10, %x[in_stride]\n"
182 "mov x9, %x[out]\n"
183 "add %x[in], x28, %x[in_stride]\n"
184 "csel x28, x28, %x[pad_row], GT\n"
185 "sub %x[height], %x[height], #0x2\n"
186 "mov x21, %x[width]\n"
187 "7:" // Tail row loop: Column loop
188 "mov x20, x21\n"
189 "decw x21, ALL, MUL #2\n"
190 "whilelt p1.s, XZR, x20\n"
191 "decw x20\n"
192 "whilelt p0.s, XZR, x20\n"
193 "ld1w { z17.s }, p1/Z, [x10]\n"
194 "cmp x21, #0x0\n"
195 "ld1w { z16.s }, p0/Z, [x10, #1, MUL VL]\n"
196 "addvl x10, x10, #2\n"
197 "ld1w { z19.s }, p1/Z, [x28]\n"
198 ".inst 0x658aaa32 // bfcvt z18.h, p2/M, z17.s\n"
199 "ld1w { z17.s }, p0/Z, [x28, #1, MUL VL]\n"
200 "addvl x28, x28, #2\n"
201 ".inst 0x658aaa10 // bfcvt z16.h, p2/M, z16.s\n"
202 ".inst 0x648aaa72 // bfcvtnt z18.h, p2/M, z19.s\n"
203 ".inst 0x648aaa30 // bfcvtnt z16.h, p2/M, z17.s\n"
204 "st1h { z18.h }, p2, [x9]\n"
205 "st1h { z16.h }, p2, [x9, #1, MUL VL]\n"
206 "add x9, x9, %x[out_stride]\n"
207 "bgt 7b\n"
208 "cmp %x[height], #0x1\n"
209 "addvl %x[out], %x[out], #2\n"
210 "bge 6b\n"
211 "9:" // Done
212 ".inst 0xd503467f // SMSTOP\n"
213 : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out)
214 180 : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width)
215 : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7",
216 "p8", "p9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", "z10",
217 "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25",
218 "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9");
219 180 }
220
221 #endif // Architectural features check.
222