KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c
Date: 2025-10-20 13:18:31
            Coverage   Exec   Excl   Total
Lines:         97.6%     41      7      49
Functions:    100.0%     14      0      14
Branches:      50.0%      1     14      16

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // Do not flag up inline assembly blocks
8 #pragma GCC diagnostic ignored "-Woverlength-strings"
9
10 #if !defined(__ARM_FEATURE_MATMUL_INT8)
11 #error "I8mm extension required to compile this micro-kernel"
12 #else
13 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h"
14
15 #include <arm_neon.h>
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include "kai/kai_common.h"
20
21 static const size_t kai_m_step = 4;
22 static const size_t kai_n_step = 8;
23 static const size_t kai_mr = 4;
24 static const size_t kai_nr = 8;
25 static const size_t kai_kr = 16;
26 static const size_t kai_sr = 2;
27 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
28 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
29 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
30 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
31 static const size_t kai_num_bytes_bias = sizeof(float);
32
33 641 inline static size_t kai_k_roundedup(size_t k) {
34 // Since we pack a float and int32 value at the end of the row,
35 // we must make sure that k is a multiple of 4 for alignment
36 641 size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
37 1282 return kai_roundup(k, kr_sr_roundedup4);
38 641 }
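
Note: the rounding quantum above is worth a worked example. A minimal
sketch, assuming kai_roundup(a, b) from kai_common.h rounds a up to the
next multiple of b (the helper below is illustrative, not the library's):

    // With kai_kr = 16 and kai_sr = 2, the quantum is
    // kai_roundup(16 * 2, 4) = 32, so k rounds up to a multiple of 32.
    static size_t roundup(size_t a, size_t b) {
        return ((a + b - 1) / b) * b;  // assumed kai_roundup semantics
    }
    // roundup(1, 32) == 32, roundup(32, 32) == 32, roundup(200, 32) == 224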
39
40 240 inline static size_t kai_lhs_packed_stride(size_t k) {
41 240 const size_t k_internal = kai_k_roundedup(k);
42
43 KAI_ASSERT((k_internal % 2) == 0);
44
45 480 return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
46 240 }
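
Note: per the constants above, each packed LHS row stores k_internal
int8 values followed by a float multiplier and an int32 offset, and the
stride spans kai_mr = 4 such rows. A worked example, assuming k = 32:

    // k = 32 -> k_internal = 32, so:
    // 4 * (32 * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))
    //   = 4 * (32 + 4 + 4) = 160 bytes per block of 4 packed LHS rows.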
47
48 240 inline static size_t kai_rhs_packed_stride(size_t k) {
49 240 const size_t k_internal = kai_k_roundedup(k);
50
51 KAI_ASSERT((k_internal % 2) == 0);
52
53 480 return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
54 240 }
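
Note: the 4-bit RHS weights are nibble-packed, hence k_internal / 2 data
bytes per column; per the constants above, a float multiplier, an int32
column sum, and a float bias also ride along with each of the kai_nr = 8
columns (their exact placement is the packing kernel's concern). A worked
example, again assuming k = 32:

    // k = 32 -> k_internal = 32, so:
    // 8 * ((32 / 2) + sizeof(float) + sizeof(int32_t) + sizeof(float))
    //   = 8 * (16 + 4 + 4 + 4) = 224 bytes per block of 8 packed RHS columns.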
55
56 320 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) {
57 320 return kai_m_step;
58 }
59
60 320 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) {
61 320 return kai_n_step;
62 }
63
64 240 size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) {
65 240 return kai_mr;
66 }
67
68 240 size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) {
69 240 return kai_nr;
70 }
71
72 320 size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) {
73 320 return kai_kr;
74 }
75
76 320 size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) {
77 320 return kai_sr;
78 }
79
80 240 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k) {
81 KAI_ASSERT((m_idx % kai_m_step) == 0);
82
83 240 return (m_idx / kai_mr) * kai_lhs_packed_stride(k);
84 }
85
86 240 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t n_idx, size_t k) {
87 KAI_ASSERT((n_idx % kai_n_step) == 0);
88
89 240 return (n_idx / kai_nr) * kai_rhs_packed_stride(k);
90 }
91
92 160 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(
93 size_t m_idx, size_t n_idx, size_t dst_stride) {
94 KAI_ASSERT((m_idx % kai_m_step) == 0);
95 KAI_ASSERT((n_idx % kai_n_step) == 0);
96
97 160 return (n_idx * sizeof(float)) + m_idx * dst_stride;
98 }
99
100 160 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n) {
101 160 return m * n * sizeof(float);
102 }
103
104 161 void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(
105 size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed,
106 float* restrict dst, // NOLINT(readability-non-const-parameter)
107 size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
108 KAI_ASSERT(dst_stride_col == sizeof(float));
109
110 1/2 161 if (m == 0) {
✓ Branch 0 taken 161 times.
✗ Branch 1 not taken.
111 return;
112 }
113
114 161 const size_t k_internal = kai_k_roundedup(k);
115
116 161 size_t num_blocks = k_internal / 32;
117
118 161 float clamp_vals[2] = {scalar_min, scalar_max};
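
Note: each pass of the assembly sub-block loop below consumes 32 values
along K, which is why num_blocks = k_internal / 32 and why
kai_k_roundedup pads K to a multiple of 32. A small worked example:

    // k = 200 -> k_internal = kai_roundup(200, 32) = 224
    //         -> num_blocks = 224 / 32 = 7 sub-block iterations per tile.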
119
120 322 __asm__ __volatile__(
121 "mov x28, #0x80\n"
122 "mov x20, #0x20\n"
123 "movi v12.16b, #0xf0\n"
124 "mov x27, %x[m]\n"
125 "madd x28, %x[num_blocks], x28, x20\n"
126 "cbz x27, 11f\n"
127 "1:" // Row loop
128 "mov x26, %x[rhs_packed]\n"
129 "mov x25, %x[n]\n"
130 "add x24, %x[dst], %x[dst_stride_row], LSL #2\n"
131 "2:" // Column loop
132 "mov x21, %x[lhs_packed]\n"
133 "movi v11.4s, #0x0\n"
134 "movi v10.4s, #0x0\n"
135 "mov x20, %x[num_blocks]\n"
136 "movi v9.4s, #0x0\n"
137 "movi v8.4s, #0x0\n"
138 "movi v7.4s, #0x0\n"
139 "movi v6.4s, #0x0\n"
140 "movi v5.4s, #0x0\n"
141 "movi v4.4s, #0x0\n"
142 "3:" // Sub block loop
143 "ldr q3, [x26, #0x0]\n"
144 "ldr q2, [x26, #0x10]\n"
145 "subs x20, x20, #0x1\n"
146 "ldr q1, [x26, #0x20]\n"
147 "ldr q0, [x26, #0x30]\n"
148 "ldr q31, [x21, #0x0]\n"
149 "ldr q30, [x21, #0x10]\n"
150 "ldr q29, [x26, #0x40]\n"
151 "ldr q28, [x26, #0x50]\n"
152 "shl v19.16b, v3.16b, #0x4\n"
153 "shl v18.16b, v2.16b, #0x4\n"
154 "ldr q27, [x26, #0x60]\n"
155 "ldr q26, [x26, #0x70]\n"
156 "shl v17.16b, v1.16b, #0x4\n"
157 "shl v16.16b, v0.16b, #0x4\n"
158 "ldr q25, [x21, #0x20]\n"
159 "ldr q24, [x21, #0x30]\n"
160 "and v3.16b, v3.16b, v12.16b\n"
161 "and v2.16b, v2.16b, v12.16b\n"
162 "ldr q23, [x21, #0x40]\n"
163 "ldr q22, [x21, #0x50]\n"
164 ".inst 0x4e93a7eb // smmla v11.4s, v31.16b, v19.16b\n"
165 ".inst 0x4e92a7e9 // smmla v9.4s, v31.16b, v18.16b\n"
166 "ldr q21, [x21, #0x60]\n"
167 "ldr q20, [x21, #0x70]\n"
168 ".inst 0x4e91a7ea // smmla v10.4s, v31.16b, v17.16b\n"
169 ".inst 0x4e90a7e8 // smmla v8.4s, v31.16b, v16.16b\n"
170 ".inst 0x4e93a7c7 // smmla v7.4s, v30.16b, v19.16b\n"
171 ".inst 0x4e92a7c5 // smmla v5.4s, v30.16b, v18.16b\n"
172 "shl v19.16b, v29.16b, #0x4\n"
173 "add x26, x26, #0x80\n"
174 ".inst 0x4e91a7c6 // smmla v6.4s, v30.16b, v17.16b\n"
175 ".inst 0x4e90a7c4 // smmla v4.4s, v30.16b, v16.16b\n"
176 "shl v18.16b, v28.16b, #0x4\n"
177 "add x21, x21, #0x80\n"
178 "shl v17.16b, v27.16b, #0x4\n"
179 "shl v16.16b, v26.16b, #0x4\n"
180 ".inst 0x4e93a72b // smmla v11.4s, v25.16b, v19.16b\n"
181 "and v1.16b, v1.16b, v12.16b\n"
182 "and v0.16b, v0.16b, v12.16b\n"
183 ".inst 0x4e92a729 // smmla v9.4s, v25.16b, v18.16b\n"
184 ".inst 0x4e93a707 // smmla v7.4s, v24.16b, v19.16b\n"
185 ".inst 0x4e92a705 // smmla v5.4s, v24.16b, v18.16b\n"
186 "and v29.16b, v29.16b, v12.16b\n"
187 ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n"
188 ".inst 0x4e90a728 // smmla v8.4s, v25.16b, v16.16b\n"
189 "and v28.16b, v28.16b, v12.16b\n"
190 ".inst 0x4e91a706 // smmla v6.4s, v24.16b, v17.16b\n"
191 ".inst 0x4e90a704 // smmla v4.4s, v24.16b, v16.16b\n"
192 "and v27.16b, v27.16b, v12.16b\n"
193 ".inst 0x4e83a6eb // smmla v11.4s, v23.16b, v3.16b\n"
194 ".inst 0x4e82a6e9 // smmla v9.4s, v23.16b, v2.16b\n"
195 "and v26.16b, v26.16b, v12.16b\n"
196 ".inst 0x4e83a6c7 // smmla v7.4s, v22.16b, v3.16b\n"
197 ".inst 0x4e82a6c5 // smmla v5.4s, v22.16b, v2.16b\n"
198 ".inst 0x4e81a6ea // smmla v10.4s, v23.16b, v1.16b\n"
199 ".inst 0x4e80a6e8 // smmla v8.4s, v23.16b, v0.16b\n"
200 ".inst 0x4e81a6c6 // smmla v6.4s, v22.16b, v1.16b\n"
201 ".inst 0x4e80a6c4 // smmla v4.4s, v22.16b, v0.16b\n"
202 ".inst 0x4e9da6ab // smmla v11.4s, v21.16b, v29.16b\n"
203 ".inst 0x4e9ca6a9 // smmla v9.4s, v21.16b, v28.16b\n"
204 ".inst 0x4e9da687 // smmla v7.4s, v20.16b, v29.16b\n"
205 ".inst 0x4e9ca685 // smmla v5.4s, v20.16b, v28.16b\n"
206 ".inst 0x4e9ba6aa // smmla v10.4s, v21.16b, v27.16b\n"
207 ".inst 0x4e9aa6a8 // smmla v8.4s, v21.16b, v26.16b\n"
208 ".inst 0x4e9ba686 // smmla v6.4s, v20.16b, v27.16b\n"
209 ".inst 0x4e9aa684 // smmla v4.4s, v20.16b, v26.16b\n"
210 "bgt 3b\n"
211 "ldr q20, [x26, #0x0]\n"
212 "ldr q19, [x26, #0x10]\n"
213 "uzp1 v0.2d, v11.2d, v9.2d\n"
214 "uzp2 v31.2d, v11.2d, v9.2d\n"
215 "ld1 { v18.4s }, [x21]\n"
216 "ldr q17, [x26, #0x20]\n"
217 "uzp1 v30.2d, v10.2d, v8.2d\n"
218 "uzp2 v29.2d, v10.2d, v8.2d\n"
219 "ldr q28, [x26, #0x30]\n"
220 "uzp1 v27.2d, v7.2d, v5.2d\n"
221 "uzp2 v26.2d, v7.2d, v5.2d\n"
222 "add x21, x21, #0x10\n"
223 "ldr q16, [x21, #0x0]\n"
224 "uzp1 v25.2d, v6.2d, v4.2d\n"
225 "uzp2 v24.2d, v6.2d, v4.2d\n"
226 "add x26, x26, #0x40\n"
227 "mla v0.4s, v20.4s, v18.s[0]\n"
228 "mla v30.4s, v19.4s, v18.s[0]\n"
229 "mla v31.4s, v20.4s, v18.s[1]\n"
230 "mla v29.4s, v19.4s, v18.s[1]\n"
231 "mla v27.4s, v20.4s, v18.s[2]\n"
232 "mla v25.4s, v19.4s, v18.s[2]\n"
233 "fmul v23.4s, v17.4s, v16.s[0]\n"
234 "mla v26.4s, v20.4s, v18.s[3]\n"
235 "mla v24.4s, v19.4s, v18.s[3]\n"
236 "fmul v22.4s, v28.4s, v16.s[0]\n"
237 "scvtf v0.4s, v0.4s\n"
238 "scvtf v30.4s, v30.4s\n"
239 "fmul v21.4s, v17.4s, v16.s[1]\n"
240 "scvtf v31.4s, v31.4s\n"
241 "fmul v20.4s, v28.4s, v16.s[1]\n"
242 "scvtf v29.4s, v29.4s\n"
243 "fmul v19.4s, v17.4s, v16.s[2]\n"
244 "scvtf v27.4s, v27.4s\n"
245 "fmul v18.4s, v28.4s, v16.s[2]\n"
246 "scvtf v25.4s, v25.4s\n"
247 "fmul v17.4s, v17.4s, v16.s[3]\n"
248 "scvtf v26.4s, v26.4s\n"
249 "fmul v16.4s, v28.4s, v16.s[3]\n"
250 "scvtf v24.4s, v24.4s\n"
251 "fmul v11.4s, v0.4s, v23.4s\n"
252 "fmul v10.4s, v30.4s, v22.4s\n"
253 "fmul v9.4s, v31.4s, v21.4s\n"
254 "fmul v8.4s, v29.4s, v20.4s\n"
255 "fmul v7.4s, v27.4s, v19.4s\n"
256 "fmul v6.4s, v25.4s, v18.4s\n"
257 "fmul v5.4s, v26.4s, v17.4s\n"
258 "fmul v4.4s, v24.4s, v16.4s\n"
259 "ldr q19, [x26, #0x0]\n"
260 "ldr q18, [x26, #0x10]\n"
261 "add x20, %x[clamp_vals], #0x4\n"
262 "cmp x25, #0x8\n"
263 "ld1r { v17.4s }, [%x[clamp_vals]]\n"
264 "ld1r { v16.4s }, [x20]\n"
265 "add x26, x26, #0x20\n"
266 "fadd v11.4s, v11.4s, v19.4s\n"
267 "fadd v10.4s, v10.4s, v18.4s\n"
268 "fadd v9.4s, v9.4s, v19.4s\n"
269 "fadd v8.4s, v8.4s, v18.4s\n"
270 "fadd v7.4s, v7.4s, v19.4s\n"
271 "fadd v6.4s, v6.4s, v18.4s\n"
272 "fadd v5.4s, v5.4s, v19.4s\n"
273 "fadd v4.4s, v4.4s, v18.4s\n"
274 "fmax v11.4s, v11.4s, v17.4s\n"
275 "fmax v10.4s, v10.4s, v17.4s\n"
276 "fmax v9.4s, v9.4s, v17.4s\n"
277 "fmax v8.4s, v8.4s, v17.4s\n"
278 "fmax v7.4s, v7.4s, v17.4s\n"
279 "fmax v6.4s, v6.4s, v17.4s\n"
280 "fmax v5.4s, v5.4s, v17.4s\n"
281 "fmax v4.4s, v4.4s, v17.4s\n"
282 "fmin v11.4s, v11.4s, v16.4s\n"
283 "fmin v10.4s, v10.4s, v16.4s\n"
284 "fmin v9.4s, v9.4s, v16.4s\n"
285 "fmin v8.4s, v8.4s, v16.4s\n"
286 "fmin v7.4s, v7.4s, v16.4s\n"
287 "fmin v6.4s, v6.4s, v16.4s\n"
288 "fmin v5.4s, v5.4s, v16.4s\n"
289 "fmin v4.4s, v4.4s, v16.4s\n"
290 "blt 5f\n"
291 "mov x20, %x[dst]\n"
292 "cmp x27, #0x1\n"
293 "str q11, [x20, #0x0]\n"
294 "str q10, [x20, #0x10]\n"
295 "add x20, x20, %x[dst_stride_row]\n"
296 "ble 10f\n"
297 "cmp x27, #0x2\n"
298 "str q9, [x20, #0x0]\n"
299 "str q8, [x20, #0x10]\n"
300 "add x20, x20, %x[dst_stride_row]\n"
301 "ble 10f\n"
302 "cmp x27, #0x3\n"
303 "str q7, [x20, #0x0]\n"
304 "str q6, [x20, #0x10]\n"
305 "add x20, x20, %x[dst_stride_row]\n"
306 "ble 10f\n"
307 "str q5, [x20, #0x0]\n"
308 "str q4, [x20, #0x10]\n"
309 "b 10f\n"
310 "5:" // Partial output
311 "mov x23, %x[dst]\n"
312 "cmp x27, #0x1\n"
313 "add x22, x23, %x[dst_stride_row]\n"
314 "csel x22, x22, x23, GT\n"
315 "cmp x27, #0x2\n"
316 "add x21, x23, %x[dst_stride_row], LSL #1\n"
317 "csel x21, x21, x22, GT\n"
318 "cmp x27, #0x3\n"
319 "add x20, x21, %x[dst_stride_row]\n"
320 "csel x20, x20, x21, GT\n"
321 "tbz x25, #2, 7f\n"
322 "st1 { v5.4s }, [x20], #0x10\n"
323 "st1 { v7.4s }, [x21], #0x10\n"
324 "st1 { v9.4s }, [x22], #0x10\n"
325 "st1 { v11.4s }, [x23], #0x10\n"
326 "tbz x25, #1, 6f\n"
327 "st1 { v4.d }[0], [x20], #0x8\n"
328 "st1 { v6.d }[0], [x21], #0x8\n"
329 "st1 { v8.d }[0], [x22], #0x8\n"
330 "st1 { v10.d }[0], [x23], #0x8\n"
331 "tbz x25, #0, 9f\n"
332 "st1 { v4.s }[2], [x20]\n"
333 "st1 { v6.s }[2], [x21]\n"
334 "st1 { v8.s }[2], [x22]\n"
335 "st1 { v10.s }[2], [x23]\n"
336 "b 9f\n"
337 "6:" // Output block 0: partial_1_4
338 "tbz x25, #0, 9f\n"
339 "st1 { v4.s }[0], [x20]\n"
340 "st1 { v6.s }[0], [x21]\n"
341 "st1 { v8.s }[0], [x22]\n"
342 "st1 { v10.s }[0], [x23]\n"
343 "b 9f\n"
344 "7:" // Output block 0: partial_2_0
345 "tbz x25, #1, 8f\n"
346 "st1 { v5.d }[0], [x20], #0x8\n"
347 "st1 { v7.d }[0], [x21], #0x8\n"
348 "st1 { v9.d }[0], [x22], #0x8\n"
349 "st1 { v11.d }[0], [x23], #0x8\n"
350 "tbz x25, #0, 9f\n"
351 "st1 { v5.s }[2], [x20]\n"
352 "st1 { v7.s }[2], [x21]\n"
353 "st1 { v9.s }[2], [x22]\n"
354 "st1 { v11.s }[2], [x23]\n"
355 "b 9f\n"
356 "8:" // Output block 0: partial_1_0
357 "st1 { v5.s }[0], [x20]\n"
358 "st1 { v7.s }[0], [x21]\n"
359 "st1 { v9.s }[0], [x22]\n"
360 "st1 { v11.s }[0], [x23]\n"
361 "9:" // Output block 0: Done
362 "10:" // Output stage exit
363 "subs x25, x25, #0x8\n"
364 "add %x[dst], %x[dst], #0x20\n"
365 "bgt 2b\n"
366 "subs x27, x27, #0x4\n"
367 "add %x[lhs_packed], %x[lhs_packed], x28\n"
368 "mov %x[dst], x24\n"
369 "bgt 1b\n"
370 "11:" // Row loop skip
371 : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed)
372 161 : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n),
373 161 [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed)
374 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17",
375 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20",
376 "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28");
377 161 }
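
Note: the nibble handling in the sub-block loop is compact but worth
spelling out: shifting a packed byte left by 4 isolates the low 4-bit
weight in the high nibble, while AND with 0xF0 (v12) isolates the high
one; both enter SMMLA as int8 values carrying a common factor of 16,
which the later float scaling can absorb. The asm pairs the low nibbles
with an earlier LHS slice of K and the high nibbles with a later one. A
minimal intrinsics sketch of one such step, assuming a +i8mm target
(illustrative only, not the kernel's actual code path):

    #include <arm_neon.h>

    // One unpack-and-accumulate step: w4 packs two int4 weights per byte.
    static inline int32x4_t acc_step(
        int32x4_t acc, int8x16_t lhs_lo, int8x16_t lhs_hi, int8x16_t w4) {
        const int8x16_t mask = vdupq_n_s8((int8_t)0xF0);
        acc = vmmlaq_s32(acc, lhs_lo, vshlq_n_s8(w4, 4));   // low nibbles, earlier K slice
        return vmmlaq_s32(acc, lhs_hi, vandq_s8(w4, mask)); // high nibbles, later K slice
    }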
378 #endif // Architectural feature check
379
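Note: for reference, a minimal sketch of how a caller might drive this
micro-kernel, assuming lhs_packed and rhs_packed were produced by the
matching qai8dxp / qsi4cxp packing kernels (not covered by this report):

    #include <float.h>
    #include <stdlib.h>

    #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h"

    // Run one packed matmul and return a freshly allocated row-major f32 dst.
    static float* run_tile(size_t m, size_t n, size_t k,
                           const void* lhs_packed, const void* rhs_packed) {
        float* dst = malloc(
            kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(m, n));
        kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(
            m, n, k, lhs_packed, rhs_packed, dst,
            n * sizeof(float), sizeof(float), // dst_stride_col must equal sizeof(float)
            -FLT_MAX, FLT_MAX);               // clamp bounds left wide open
        return dst;
    }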