KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 92.9% 26 12 40
Functions: 85.7% 6 0 7
Branches: -% 0 24 24

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15 #include <string.h>
16
17 #include "kai/kai_common.h"
18
// Packing geometry: kai_nr and kai_kr feed the n-step computation, and kai_kr is
// also the multiple that k is rounded up to in the packed stride.
19 static const size_t kai_nr = 2;
20 static const size_t kai_kr = 2;
// Bytes per RHS input element (multiplies n_idx in the RHS offset helper).
21 static const size_t kai_num_bytes_input = 4;
// Bytes per packed output element (multiplies the rounded-up k in the packed stride).
22 static const size_t kai_num_bytes_output = 2;
// Bytes per bias element (used by the bias offset and the packed stride).
23 static const size_t kai_num_bytes_bias = 4;
24
// Returns the n step of this packing kernel: kai_nr * kai_get_sme_vector_length_u16() / kai_kr.
// With kai_nr == kai_kr == 2 this equals the value returned by kai_get_sme_vector_length_u16()
// (presumably the number of 16-bit elements in a streaming vector - confirm in kai_common.h).
25 382 size_t kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(void) {
26 382 return kai_nr * kai_get_sme_vector_length_u16() / kai_kr;
27 }
28
// Returns the byte offset of column n_idx within a row of the unpacked RHS:
// n_idx * kai_num_bytes_input.
// n_idx is expected to be a multiple of the n step (checked with KAI_ASSUME).
29 46 size_t kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx) {
30 KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() == 0);
31
32 46 return n_idx * kai_num_bytes_input;
33 }
34
// Returns the byte offset of element n_idx in the bias vector: n_idx * kai_num_bytes_bias.
// Unlike the RHS offset helpers, this one does not assert that n_idx is a multiple of the n step.
35 size_t kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx) {
36 return n_idx * kai_num_bytes_bias;
37 }
38
// Returns the size in bytes of one packed block of n-step columns:
// per column, one bias element plus kai_roundup(k, kai_kr) packed output elements.
// (kai_roundup presumably rounds k up to the next multiple of kai_kr - confirm in kai_common.h.)
39 112 size_t kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t k) {
40 224 return kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() *
41 112 (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output);
42 }
43
// Returns the byte offset in the packed buffer of the block that starts at column n_idx:
// (n_idx / n_step) * packed stride for k.
// n_idx is expected to be a multiple of the n step (checked with KAI_ASSUME).
44 56 size_t kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx, size_t k) {
45 KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() == 0);
46
// Index of the n-step-wide block containing column n_idx.
47 56 const size_t block_idx = n_idx / kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme();
48 112 return block_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(k);
49 56 }
50
// Returns the total packed buffer size in bytes for an RHS with n columns and k rows.
// n is rounded up to a multiple of the n step, so the result is the packed offset
// one past the last (possibly partial) block.
51 56 size_t kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n, size_t k) {
52 56 const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme());
53 112 return kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(n_rounded_up, k);
54 56 }
55
// Packs the RHS matrix and the bias vector into rhs_packed using the inline assembly below.
//
// Preconditions checked with KAI_ASSUME:
//   num_groups == 1, nr == n step of this kernel, kr == kai_kr, sr == 1,
//   rhs, bias and rhs_packed are non-NULL, scale and params are NULL, extra_bytes == 0.
//
// n is used as the number of columns (width), k as the number of rows (height) and
// rhs_stride as the byte distance between consecutive rows of rhs.
//
// Packed layout produced by the assembly: blocks are out_stride bytes apart; each block
// starts with two vectors of 32-bit bias words, followed by vectors in which two
// consecutive rows are converted with BFCVT / BFCVTNT and combined into one vector.
56 56 void kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(
57 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
58 const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) {
59 KAI_ASSUME(num_groups == 1);
60 KAI_ASSUME(nr == kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme());
61 KAI_ASSUME(kr == kai_kr);
62 KAI_ASSUME(sr == 1);
63 KAI_ASSUME(rhs != NULL);
64 KAI_ASSUME(bias != NULL);
65 KAI_ASSUME(scale == NULL);
66 KAI_ASSUME(rhs_packed != NULL);
67 KAI_ASSUME(extra_bytes == 0);
68 KAI_ASSUME(params == NULL);
69
// Local aliases bound to the assembly operands; height, in and out are modified by the assembly.
70 56 size_t height = k;
71 56 const size_t width = n;
72 56 const void* in = rhs;
73 56 void* out = rhs_packed;
74 56 const size_t in_stride = rhs_stride;
// Row substituted for the missing second row when only one row remains in the tail loop.
// NOTE(review): this aliases the first row of rhs, so for odd k the padded lanes hold
// converted rhs data rather than zeros - confirm that consumers tolerate this.
75 56 const float* pad_row = rhs;
76
// Byte distance between consecutive packed blocks.
77 56 size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(height);
78
79 112 __asm__ __volatile__(
// NOTE(review): 0xd503477f looks like the encoding of plain SMSTART (streaming mode and ZA),
// not of "SMSTART ZA" alone as the in-string comment says - confirm against the Arm ARM.
80 ".inst 0xd503477f // SMSTART ZA\n"
81 "mov x22, %x[out]\n"
82 "mov x21, %x[width]\n"
83 "ptrue p2.b\n"
// Bias loop: per block, load two vectors of 32-bit bias words (lanes beyond the remaining
// width are zeroed by the /Z loads), store them with the all-true predicate p2 at the start
// of the block, then step x22 to the next block. x21 counts the remaining columns.
84 "1:" // Bias: Full loop
85 "mov x20, x21\n"
86 "decw x21, ALL, MUL #2\n"
87 "whilelt p1.s, XZR, x20\n"
88 "decw x20\n"
89 "whilelt p0.s, XZR, x20\n"
90 "cmp x21, #0x0\n"
91 "ld1w { z17.s }, p1/Z, [%x[bias]]\n"
92 "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n"
93 "incb %x[bias], ALL, MUL #2\n"
94 "st1w { z17.s }, p2, [x22]\n"
95 "st1w { z16.s }, p2, [x22, #1, MUL VL]\n"
96 "add x22, x22, %x[out_stride]\n"
97 "bgt 1b\n"
// Advance out past the two bias vectors and skip the 8-row loop when fewer than 8 rows exist.
// NOTE(review): with k == 0 this branch lands on the tail loop without passing the cbz
// below, so one tail iteration still appears to run (reading rhs and storing two vectors
// per block) - confirm whether k == 0 is a supported input.
98 "cmp %x[height], #0x8\n"
99 "incb %x[out], ALL, MUL #2\n"
100 "blt 5f\n"
// Main row loop: set up eight row pointers (x10, x28, x27, x25, x24, x23, x22, x21),
// consume 8 rows from height and advance in by 8 rows.
101 "2:" // Main row loop: Head
102 "mov x10, %x[in]\n"
103 "mov x9, %x[out]\n"
104 "add x28, x10, %x[in_stride]\n"
105 "sub %x[height], %x[height], #0x8\n"
106 "add x27, x28, %x[in_stride]\n"
107 "mov x26, %x[width]\n"
108 "add x25, x27, %x[in_stride]\n"
109 "add x24, x25, %x[in_stride]\n"
110 "add x23, x24, %x[in_stride]\n"
111 "add x22, x23, %x[in_stride]\n"
112 "add x21, x22, %x[in_stride]\n"
113 "add %x[in], x21, %x[in_stride]\n"
// Column loop: p1 / p0 predicate the first / second vector of the remaining columns (x26).
// Rows at x10, x27, x24, x22 are converted with BFCVT and rows at x28, x25, x23, x21 with
// BFCVTNT into the same destination registers, giving eight output vectors per block.
114 "3:" // Main row loop: Column loop
115 "mov x20, x26\n"
116 "decw x26, ALL, MUL #2\n"
117 "whilelt p1.s, XZR, x20\n"
118 "decw x20\n"
119 "whilelt p0.s, XZR, x20\n"
120 "ld1w { z19.s }, p1/Z, [x10]\n"
121 "cmp x26, #0x0\n"
122 "ld1w { z18.s }, p0/Z, [x10, #1, MUL VL]\n"
123 "addvl x10, x10, #2\n"
124 "ld1w { z17.s }, p1/Z, [x27]\n"
125 "ld1w { z16.s }, p0/Z, [x27, #1, MUL VL]\n"
126 ".inst 0x658aaa7b // bfcvt z27.h, p2/M, z19.s\n"
127 "addvl x27, x27, #2\n"
128 "ld1w { z19.s }, p1/Z, [x24]\n"
129 ".inst 0x658aaa5a // bfcvt z26.h, p2/M, z18.s\n"
130 "ld1w { z18.s }, p0/Z, [x24, #1, MUL VL]\n"
131 ".inst 0x658aaa39 // bfcvt z25.h, p2/M, z17.s\n"
132 "addvl x24, x24, #2\n"
133 "ld1w { z17.s }, p1/Z, [x22]\n"
134 ".inst 0x658aaa18 // bfcvt z24.h, p2/M, z16.s\n"
135 "ld1w { z16.s }, p0/Z, [x22, #1, MUL VL]\n"
136 ".inst 0x658aaa77 // bfcvt z23.h, p2/M, z19.s\n"
137 "addvl x22, x22, #2\n"
138 ".inst 0x658aaa56 // bfcvt z22.h, p2/M, z18.s\n"
139 "ld1w { z19.s }, p1/Z, [x28]\n"
140 ".inst 0x658aaa35 // bfcvt z21.h, p2/M, z17.s\n"
141 "ld1w { z18.s }, p0/Z, [x28, #1, MUL VL]\n"
142 "addvl x28, x28, #2\n"
143 ".inst 0x658aaa14 // bfcvt z20.h, p2/M, z16.s\n"
144 "ld1w { z17.s }, p1/Z, [x25]\n"
145 "ld1w { z16.s }, p0/Z, [x25, #1, MUL VL]\n"
146 "addvl x25, x25, #2\n"
147 ".inst 0x648aaa7b // bfcvtnt z27.h, p2/M, z19.s\n"
148 "ld1w { z19.s }, p1/Z, [x23]\n"
149 ".inst 0x648aaa5a // bfcvtnt z26.h, p2/M, z18.s\n"
150 "ld1w { z18.s }, p0/Z, [x23, #1, MUL VL]\n"
151 "addvl x23, x23, #2\n"
152 ".inst 0x648aaa39 // bfcvtnt z25.h, p2/M, z17.s\n"
153 "ld1w { z17.s }, p1/Z, [x21]\n"
154 ".inst 0x648aaa18 // bfcvtnt z24.h, p2/M, z16.s\n"
155 "ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n"
156 "addvl x21, x21, #2\n"
157 ".inst 0x648aaa77 // bfcvtnt z23.h, p2/M, z19.s\n"
158 "st1h { z27.h }, p2, [x9]\n"
159 ".inst 0x648aaa56 // bfcvtnt z22.h, p2/M, z18.s\n"
160 "st1h { z26.h }, p2, [x9, #1, MUL VL]\n"
161 ".inst 0x648aaa35 // bfcvtnt z21.h, p2/M, z17.s\n"
162 "st1h { z25.h }, p2, [x9, #2, MUL VL]\n"
163 ".inst 0x648aaa14 // bfcvtnt z20.h, p2/M, z16.s\n"
164 "st1h { z24.h }, p2, [x9, #3, MUL VL]\n"
165 "st1h { z23.h }, p2, [x9, #4, MUL VL]\n"
166 "st1h { z22.h }, p2, [x9, #5, MUL VL]\n"
167 "st1h { z21.h }, p2, [x9, #6, MUL VL]\n"
168 "st1h { z20.h }, p2, [x9, #7, MUL VL]\n"
169 "add x9, x9, %x[out_stride]\n"
170 "bgt 3b\n"
// Eight vectors were written per block: advance out accordingly, repeat while at least
// 8 rows remain, and finish if no rows are left.
171 "cmp %x[height], #0x8\n"
172 "addvl %x[out], %x[out], #8\n"
173 "bge 2b\n"
174 "cbz %x[height], 9f\n"
175 "5:" // Main loop skip
// Tail row loop: two rows per iteration. When only one row remains (height not greater
// than 1) the second row pointer x28 is replaced by pad_row.
176 "6:" // Tail row loop: Head
177 "mov x10, %x[in]\n"
178 "cmp %x[height], #0x1\n"
179 "add x28, x10, %x[in_stride]\n"
180 "mov x9, %x[out]\n"
181 "add %x[in], x28, %x[in_stride]\n"
182 "csel x28, x28, %x[pad_row], GT\n"
183 "sub %x[height], %x[height], #0x2\n"
184 "mov x21, %x[width]\n"
// Same column scheme as the main loop, producing two output vectors per block.
185 "7:" // Tail row loop: Column loop
186 "mov x20, x21\n"
187 "decw x21, ALL, MUL #2\n"
188 "whilelt p1.s, XZR, x20\n"
189 "decw x20\n"
190 "whilelt p0.s, XZR, x20\n"
191 "ld1w { z17.s }, p1/Z, [x10]\n"
192 "cmp x21, #0x0\n"
193 "ld1w { z16.s }, p0/Z, [x10, #1, MUL VL]\n"
194 "addvl x10, x10, #2\n"
195 "ld1w { z19.s }, p1/Z, [x28]\n"
196 ".inst 0x658aaa32 // bfcvt z18.h, p2/M, z17.s\n"
197 "ld1w { z17.s }, p0/Z, [x28, #1, MUL VL]\n"
198 "addvl x28, x28, #2\n"
199 ".inst 0x658aaa10 // bfcvt z16.h, p2/M, z16.s\n"
200 ".inst 0x648aaa72 // bfcvtnt z18.h, p2/M, z19.s\n"
201 ".inst 0x648aaa30 // bfcvtnt z16.h, p2/M, z17.s\n"
202 "st1h { z18.h }, p2, [x9]\n"
203 "st1h { z16.h }, p2, [x9, #1, MUL VL]\n"
204 "add x9, x9, %x[out_stride]\n"
205 "bgt 7b\n"
// Continue while at least one row remains (signed comparison on the decremented height).
206 "cmp %x[height], #0x1\n"
207 "addvl %x[out], %x[out], #2\n"
208 "bge 6b\n"
209 "9:" // Done
210 ".inst 0xd503467f // SMSTOP\n"
// Operands: read-write early-clobber registers first, then read-only inputs, then clobbers.
211 : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out)
212 56 : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width)
213 : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7",
214 "p8", "p9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", "z10",
215 "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25",
216 "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9");
217 56 }
218
219 #endif // Architectural features check.
220