KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 100.0% 20 12 32
Functions: 100.0% 6 0 6
Branches: -% 0 24 24

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__)
8 #error This file must be compiled for AArch64.
9 #else // Architectural features check.
10
11 #include "kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 static const size_t kai_nr = 8;  // N block size: the pack routine requires nr == kai_nr, and the packed size rounds n up to it.
19 static const size_t kai_kr = 1;  // K block size: the pack routine requires kr == kai_kr.
20
21 17 size_t kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(void) {  // Reports the N step of this packing routine.
22 17 return kai_nr;  // The step is the fixed block size kai_nr (8).
23 }
24
25 17 size_t kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx) {  // Byte offset of n_idx 32-bit elements; presumably the column offset within a row of the unpacked RHS.
26 KAI_ASSUME(n_idx % kai_nr == 0);  // Precondition: n_idx is a multiple of the block size kai_nr.
27
28 17 return n_idx * sizeof(uint32_t);  // Each element is 4 bytes wide.
29 }
30
31 17 size_t kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx) {  // Byte offset of element n_idx in an array of 32-bit bias values.
32 17 return n_idx * sizeof(uint32_t);  // 4 bytes per bias element; no block-alignment precondition is checked here.
33 }
34
35 34 size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx, size_t k) {  // Byte offset in the packed buffer of the block that starts at column n_idx.
36 KAI_ASSUME(n_idx % kai_nr == 0);  // Precondition: n_idx is a multiple of the block size kai_nr.
37
38 34 return n_idx * (sizeof(uint32_t) + k * sizeof(uint32_t));  // Per column: one 32-bit bias plus k 32-bit values, i.e. (n_idx / kai_nr) times the per-block stride used by the pack routine.
39 }
40
41 17 size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n, size_t k) {  // Size in bytes of the packed buffer for an n-column, k-row RHS.
42 17 return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(kai_roundup(n, kai_nr), k);  // Offset just past the last block; kai_roundup presumably rounds n up to a multiple of kai_nr (defined outside this file - confirm).
43 }
44
45 17 void kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(  // Packs a k-row by n-column RHS of 32-bit elements plus a 32-bit bias vector into blocks of 8 columns: [8 bias values][k rows x 8 values].
46 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,  // rhs_stride: bytes between consecutive rows of rhs.
47 const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) {  // rhs_packed: destination buffer; scale, extra_bytes and params must be NULL / 0 / NULL.
48 KAI_ASSUME(num_groups == 1);  // Only a single group is supported.
49 KAI_ASSUME(nr == kai_nr);  // Block sizes are fixed by this kernel (nr == 8).
50 KAI_ASSUME(kr == kai_kr);  // kr == 1.
51 KAI_ASSUME(sr == 1);
52 KAI_ASSUME(rhs != NULL);
53 KAI_ASSUME(bias != NULL);  // The bias vector is mandatory; it is always read below.
54 KAI_ASSUME(scale == NULL);
55 KAI_ASSUME(rhs_packed != NULL);
56 KAI_ASSUME(extra_bytes == 0);
57 KAI_ASSUME(params == NULL);
58
59 17 size_t height = k;  // Rows still to pack; decremented by the assembly.
60 17 const size_t width = n;  // Number of columns.
61 17 const void* in = rhs;  // Current input row; advanced by the assembly.
62 17 void* out = rhs_packed;  // Current output position inside the first block; advanced by the assembly.
63 17 const size_t in_stride = rhs_stride;
64 17 size_t out_stride = kai_nr * height * sizeof(uint32_t) + kai_nr * sizeof(uint32_t);  // Bytes per packed block: k rows of 8 values plus 8 bias values.
65
66 34 __asm__ __volatile__(
67 "mov x22, %x[width]\n"  // x22 = bias values still to copy
68 "mov x21, %x[out]\n"  // x21 = bias write pointer (start of the current block)
69 "cmp x22, #0x8\n"
70 "blt 2f\n"
71 "1:" // Bias: Full loop
72 "ldr q17, [%x[bias], #0x0]\n"  // copy 8 bias values (32 bytes) into the head of the block
73 "ldr q16, [%x[bias], #0x10]\n"
74 "sub x22, x22, #0x8\n"
75 "add %x[bias], %x[bias], #0x20\n"
76 "cmp x22, #0x8\n"
77 "str q17, [x21, #0x0]\n"
78 "str q16, [x21, #0x10]\n"
79 "add x21, x21, %x[out_stride]\n"  // move on to the next block
80 "bge 1b\n"
81 "cbz x22, 3f\n"
82 "2:" // Bias: Tail loop
83 "ldr w20, [%x[bias], #0x0]\n"  // copy the remaining bias values one 32-bit element at a time
84 "sub x22, x22, #0x1\n"
85 "add %x[bias], %x[bias], #0x4\n"
86 "cmp x22, #0x0\n"
87 "str w20, [x21]\n"  // 32-bit store to match the 32-bit load: a 64-bit store would also zero the next lane and may be misaligned
88 "add x21, x21, #0x4\n"
89 "bgt 2b\n"
90 "3:" // Bias: Done
91 "cmp %x[height], #0x4\n"
92 "add %x[out], %x[out], #0x20\n"  // skip the 32-byte bias header of the first block
93 "blt 12f\n"  // fewer than 4 rows: go straight to the single-row loop
94 "4:" // Main row loop: Head
95 "mov x25, %x[in]\n"  // x25, x22, x21, x20 = pointers to four consecutive input rows
96 "mov x24, %x[width]\n"  // x24 = columns remaining in these rows
97 "mov x23, %x[out]\n"  // x23 = write pointer for these rows
98 "sub %x[height], %x[height], #0x4\n"
99 "add x22, x25, %x[in_stride]\n"
100 "add x21, x22, %x[in_stride]\n"
101 "add x20, x21, %x[in_stride]\n"
102 "cmp x24, #0x8\n"
103 "add %x[in], x20, %x[in_stride]\n"  // advance the input by four rows for the next pass
104 "blt 6f\n"
105 "5:" // Main row loop: Column loop
106 "ldr q23, [x25], #0x10\n"  // load 8 values from each of the four rows
107 "ldr q22, [x22], #0x10\n"
108 "sub x24, x24, #0x8\n"
109 "ldr q21, [x21], #0x10\n"
110 "ldr q20, [x20], #0x10\n"
111 "cmp x24, #0x8\n"
112 "ldr q19, [x25], #0x10\n"
113 "ldr q18, [x22], #0x10\n"
114 "ldr q17, [x21], #0x10\n"
115 "ldr q16, [x20], #0x10\n"
116 "str q23, [x23, #0x0]\n"  // store the rows back to back, 32 bytes per row
117 "str q19, [x23, #0x10]\n"
118 "str q22, [x23, #0x20]\n"
119 "str q18, [x23, #0x30]\n"
120 "str q21, [x23, #0x40]\n"
121 "str q17, [x23, #0x50]\n"
122 "str q20, [x23, #0x60]\n"
123 "str q16, [x23, #0x70]\n"
124 "add x23, x23, %x[out_stride]\n"  // same rows, next block of 8 columns
125 "bge 5b\n"
126 "6:" // Main row loop: Column loop skip
127 "cbz x24, 11f\n"
128 "cmp x24, #0x4\n"
129 "movi v16.4s, #0x0\n"  // partial last block: zero all four 32-byte rows before copying the leftover columns
130 "str q16, [x23, #0x0]\n"
131 "str q16, [x23, #0x10]\n"
132 "str q16, [x23, #0x20]\n"
133 "str q16, [x23, #0x30]\n"
134 "str q16, [x23, #0x40]\n"
135 "str q16, [x23, #0x50]\n"
136 "str q16, [x23, #0x60]\n"
137 "str q16, [x23, #0x70]\n"
138 "blt 8f\n"
139 "7:" // Main row loop: width 4 loop: loop
140 "ldr q19, [x25], #0x10\n"  // copy 4 leftover columns from each row
141 "ldr q18, [x22], #0x10\n"
142 "sub x24, x24, #0x4\n"
143 "ldr q17, [x21], #0x10\n"
144 "ldr q16, [x20], #0x10\n"
145 "cmp x24, #0x4\n"
146 "str q19, [x23, #0x0]\n"
147 "str q18, [x23, #0x20]\n"
148 "str q17, [x23, #0x40]\n"
149 "str q16, [x23, #0x60]\n"
150 "add x23, x23, #0x10\n"
151 "bge 7b\n"
152 "8:" // Main row loop: width 4 loop: skip
153 "cmp x24, #0x1\n"
154 "blt 10f\n"
155 "9:" // Main row loop: width 1 loop: loop
156 "ldr s19, [x25], #0x4\n"  // copy the last columns one element at a time
157 "ldr s18, [x22], #0x4\n"
158 "sub x24, x24, #0x1\n"
159 "ldr s17, [x21], #0x4\n"
160 "ldr s16, [x20], #0x4\n"
161 "cmp x24, #0x1\n"
162 "str s19, [x23, #0x0]\n"
163 "str s18, [x23, #0x20]\n"
164 "str s17, [x23, #0x40]\n"
165 "str s16, [x23, #0x60]\n"
166 "add x23, x23, #0x4\n"
167 "bge 9b\n"
168 "10:" // Main row loop: width 1 loop: skip
169 "11:" // Main row loop: odd col skip
170 "cmp %x[height], #0x4\n"
171 "add %x[out], %x[out], #0x80\n"  // four packed rows of 32 bytes were written per block
172 "bge 4b\n"
173 "cbz %x[height], 21f\n"
174 "12:" // Main loop skip
175 "13:" // Tail row loop: Head
176 "mov x20, %x[width]\n"  // remaining rows (fewer than 4) are packed one at a time
177 "mov x25, %x[in]\n"
178 "mov x23, %x[out]\n"
179 "sub %x[height], %x[height], #0x1\n"
180 "cmp x20, #0x8\n"
181 "add %x[in], x25, %x[in_stride]\n"
182 "blt 15f\n"
183 "14:" // Tail row loop: Column loop
184 "ldr q17, [x25], #0x10\n"  // copy 8 values of this row into each block
185 "sub x20, x20, #0x8\n"
186 "ldr q16, [x25], #0x10\n"
187 "cmp x20, #0x8\n"
188 "str q17, [x23, #0x0]\n"
189 "str q16, [x23, #0x10]\n"
190 "add x23, x23, %x[out_stride]\n"
191 "bge 14b\n"
192 "15:" // Tail row loop: Column loop skip
193 "cbz x20, 20f\n"
194 "cmp x20, #0x4\n"
195 "movi v16.4s, #0x0\n"  // partial last block: zero the 32-byte row before copying the leftover columns
196 "str q16, [x23, #0x0]\n"
197 "str q16, [x23, #0x10]\n"
198 "blt 17f\n"
199 "16:" // Tail row loop: width 4 loop: loop
200 "ldr q16, [x25], #0x10\n"
201 "sub x20, x20, #0x4\n"
202 "cmp x20, #0x4\n"
203 "str q16, [x23, #0x0]\n"
204 "add x23, x23, #0x10\n"
205 "bge 16b\n"
206 "17:" // Tail row loop: width 4 loop: skip
207 "cmp x20, #0x1\n"
208 "blt 19f\n"
209 "18:" // Tail row loop: width 1 loop: loop
210 "ldr s16, [x25], #0x4\n"
211 "sub x20, x20, #0x1\n"
212 "cmp x20, #0x1\n"
213 "str s16, [x23, #0x0]\n"
214 "add x23, x23, #0x4\n"
215 "bge 18b\n"
216 "19:" // Tail row loop: width 1 loop: skip
217 "20:" // Tail row loop: odd col skip
218 "cmp %x[height], #0x1\n"
219 "add %x[out], %x[out], #0x20\n"  // one packed row of 32 bytes was written per block
220 "bge 13b\n"
221 "21:" // Done
222 : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out)  // read-write operands modified by the assembly
223 17 : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [width] "r"(width)  // read-only inputs
224 : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24",  // scratch registers used above
225 "x25");
226 17 }
227
228 #endif // Architectural features check.
229