KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 97.6% 41 7 49
Functions: 100.0% 14 0 14
Branches: 50.0% 1 14 16

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if !defined(__ARM_FEATURE_MATMUL_INT8)
7 #error "I8mm extension required to compile this micro-kernel"
8 #else
9 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h"
10
11 #include <arm_neon.h>
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16
// Granularity of the m_idx / n_idx values accepted by the offset helpers
// (they assert divisibility by these steps).
static const size_t kai_m_step = 4;
static const size_t kai_n_step = 4;

// Packing geometry: kai_mr rows / kai_nr columns per packed group, and
// kai_kr * kai_sr K-values per packed sub-block.
static const size_t kai_mr = 4;
static const size_t kai_nr = 4;
static const size_t kai_kr = 16;
static const size_t kai_sr = 2;

// Sizes of the scalars stored after the quantized data of each packed row
// (LHS) or column (RHS).
static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
static const size_t kai_num_bytes_bias = sizeof(float);
28
29 641 inline static size_t kai_k_roundedup(size_t k) {
30 // Since we pack a float and int32 value at the end of the row,
31 // we must make sure that k is a multiple of 4 for alignment
32 641 size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
33 1282 return kai_roundup(k, kr_sr_roundedup4);
34 641 }
35
36 240 inline static size_t kai_lhs_packed_stride(size_t k) {
37 240 const size_t k_internal = kai_k_roundedup(k);
38
39 KAI_ASSERT((k_internal % 2) == 0);
40
41 480 return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
42 240 }
43
44 240 inline static size_t kai_rhs_packed_stride(size_t k) {
45 240 const size_t k_internal = kai_k_roundedup(k);
46
47 KAI_ASSERT((k_internal % 2) == 0);
48
49 480 return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
50 240 }
51
52 320 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
53 320 return kai_m_step;
54 }
55
56 320 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
57 320 return kai_n_step;
58 }
59
60 240 size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
61 240 return kai_mr;
62 }
63
64 240 size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
65 240 return kai_nr;
66 }
67
68 320 size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
69 320 return kai_kr;
70 }
71
72 320 size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
73 320 return kai_sr;
74 }
75
76 240 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k) {
77 KAI_ASSERT((m_idx % kai_m_step) == 0);
78
79 240 return (m_idx / kai_mr) * kai_lhs_packed_stride(k);
80 }
81
82 240 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k) {
83 KAI_ASSERT((n_idx % kai_n_step) == 0);
84
85 240 return (n_idx / kai_nr) * kai_rhs_packed_stride(k);
86 }
87
88 160 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(
89 size_t m_idx, size_t n_idx, size_t dst_stride) {
90 KAI_ASSERT((m_idx % kai_m_step) == 0);
91 KAI_ASSERT((n_idx % kai_n_step) == 0);
92
93 160 return (n_idx * sizeof(float)) + m_idx * dst_stride;
94 }
95
// Returns the size in bytes of a dense m x n float destination matrix.
size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n) {
    const size_t element_count = m * n;
    return element_count * sizeof(float);
}
99
// Computes dst = clamp(lhs_packed x rhs_packed, scalar_min, scalar_max) for an m x n
// float output.
//
// m, n, k:        output rows, output columns and (unrounded) reduction depth.
// lhs_packed:     packed LHS buffer laid out in groups of kai_mr rows.
// rhs_packed:     packed RHS buffer laid out in groups of kai_nr columns.
// dst:            float output; rows are dst_stride_row bytes apart.
// dst_stride_col: must equal sizeof(float) (asserted).
// scalar_min/max: bounds the results are clamped to.
void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(
    size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed,
    float* dst,  // NOLINT(readability-non-const-parameter)
    size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
    KAI_ASSERT(dst_stride_col == sizeof(float));

    // An empty output needs no work. The n == 0 check is required, not cosmetic:
    // the assembly enters its column loop (label 2) without testing n. With n == 0 it
    // would read packed RHS data and take the partial-output path (label 5 -> 6),
    // storing one float per row into dst.
    if (m == 0 || n == 0) {
        return;
    }

    const size_t k_internal = kai_k_roundedup(k);

    // Each iteration of the sub-block loop (label 3) consumes 0x80 LHS bytes
    // (4 rows x 32 int8) and 0x40 RHS bytes (4 columns x 32 4-bit values).
    // NOTE(review): that loop tests its counter only at the bottom, so it runs once
    // even when num_blocks == 0 (k == 0). Callers presumably must pass k > 0 -- confirm.
    size_t num_blocks = k_internal / 32;

    // Read by ld1r in the epilogue: [0] feeds fmax (lower bound), [1] feeds fmin (upper bound).
    float clamp_vals[2] = {scalar_min, scalar_max};

    // Register roles:
    //   x27 rows left, x25 columns left in the current row block,
    //   x21 LHS cursor, x26 RHS cursor,
    //   x24 dst of the next 4-row block,
    //   x28 bytes per packed 4-row LHS block (num_blocks * 0x80 + 0x20).
    //   v0-v3: int32 accumulators filled by smmla (rearranged with uzp1/uzp2 afterwards).
    //   v4: 0xF0 mask used with shl #4 / and to split the 4-bit RHS values.
    // Epilogue:
    //   1. add the int32 LHS x RHS correction term (mla);
    //   2. convert to float;
    //   3. multiply by the LHS and RHS float multipliers;
    //   4. add the bias;
    //   5. clamp to [scalar_min, scalar_max].
    // Full 4-column tiles are stored with str. A 1-3 column tail goes through labels 5-7.
    // Rows past m are redirected (csel) onto the last valid row pointer.
    __asm__ __volatile__(
        "mov x28, #0x80\n"
        "mov x20, #0x20\n"
        "movi v4.16b, #0xf0\n"
        "mov x27, %x[m]\n"
        "madd x28, %x[num_blocks], x28, x20\n"
        "cbz x27, 9f\n"
        "1:"  // Row loop
        "mov x26, %x[rhs_packed]\n"
        "mov x25, %x[n]\n"
        "add x24, %x[dst], %x[dst_stride_row], LSL #2\n"
        "2:"  // Column loop
        "mov x21, %x[lhs_packed]\n"
        "movi v3.4s, #0x0\n"
        "movi v2.4s, #0x0\n"
        "mov x20, %x[num_blocks]\n"
        "movi v1.4s, #0x0\n"
        "movi v0.4s, #0x0\n"
        "3:"  // Sub block loop
        "ldr q31, [x26, #0x0]\n"
        "ldr q30, [x26, #0x10]\n"
        "subs x20, x20, #0x1\n"
        "ldr q29, [x21, #0x0]\n"
        "ldr q28, [x21, #0x10]\n"
        "ldr q27, [x26, #0x20]\n"
        "ldr q26, [x26, #0x30]\n"
        "add x26, x26, #0x40\n"
        "ldr q25, [x21, #0x20]\n"
        "ldr q24, [x21, #0x30]\n"
        "shl v23.16b, v31.16b, #0x4\n"
        "shl v22.16b, v30.16b, #0x4\n"
        "ldr q21, [x21, #0x40]\n"
        "ldr q20, [x21, #0x50]\n"
        "and v31.16b, v31.16b, v4.16b\n"
        "and v30.16b, v30.16b, v4.16b\n"
        "ldr q19, [x21, #0x60]\n"
        "ldr q18, [x21, #0x70]\n"
        "shl v17.16b, v27.16b, #0x4\n"
        "shl v16.16b, v26.16b, #0x4\n"
        ".inst 0x4e97a7a3 // smmla v3.4s, v29.16b, v23.16b\n"
        ".inst 0x4e96a7a2 // smmla v2.4s, v29.16b, v22.16b\n"
        "and v27.16b, v27.16b, v4.16b\n"
        "add x21, x21, #0x80\n"
        ".inst 0x4e97a781 // smmla v1.4s, v28.16b, v23.16b\n"
        ".inst 0x4e96a780 // smmla v0.4s, v28.16b, v22.16b\n"
        "and v26.16b, v26.16b, v4.16b\n"
        ".inst 0x4e91a723 // smmla v3.4s, v25.16b, v17.16b\n"
        ".inst 0x4e90a722 // smmla v2.4s, v25.16b, v16.16b\n"
        ".inst 0x4e91a701 // smmla v1.4s, v24.16b, v17.16b\n"
        ".inst 0x4e90a700 // smmla v0.4s, v24.16b, v16.16b\n"
        ".inst 0x4e9fa6a3 // smmla v3.4s, v21.16b, v31.16b\n"
        ".inst 0x4e9ea6a2 // smmla v2.4s, v21.16b, v30.16b\n"
        ".inst 0x4e9fa681 // smmla v1.4s, v20.16b, v31.16b\n"
        ".inst 0x4e9ea680 // smmla v0.4s, v20.16b, v30.16b\n"
        ".inst 0x4e9ba663 // smmla v3.4s, v19.16b, v27.16b\n"
        ".inst 0x4e9aa662 // smmla v2.4s, v19.16b, v26.16b\n"
        ".inst 0x4e9ba641 // smmla v1.4s, v18.16b, v27.16b\n"
        ".inst 0x4e9aa640 // smmla v0.4s, v18.16b, v26.16b\n"
        "bgt 3b\n"
        "ldr q18, [x26, #0x0]\n"
        "ld1 { v17.4s }, [x21]\n"
        "uzp1 v24.2d, v3.2d, v2.2d\n"
        "uzp2 v23.2d, v3.2d, v2.2d\n"
        "ldr q22, [x26, #0x10]\n"
        "uzp1 v21.2d, v1.2d, v0.2d\n"
        "uzp2 v20.2d, v1.2d, v0.2d\n"
        "add x21, x21, #0x10\n"
        "ldr q16, [x21, #0x0]\n"
        "add x26, x26, #0x20\n"
        "mla v24.4s, v18.4s, v17.s[0]\n"
        "mla v23.4s, v18.4s, v17.s[1]\n"
        "mla v21.4s, v18.4s, v17.s[2]\n"
        "mla v20.4s, v18.4s, v17.s[3]\n"
        "fmul v19.4s, v22.4s, v16.s[0]\n"
        "fmul v18.4s, v22.4s, v16.s[1]\n"
        "fmul v17.4s, v22.4s, v16.s[2]\n"
        "fmul v16.4s, v22.4s, v16.s[3]\n"
        "scvtf v24.4s, v24.4s\n"
        "scvtf v23.4s, v23.4s\n"
        "scvtf v21.4s, v21.4s\n"
        "scvtf v20.4s, v20.4s\n"
        "fmul v3.4s, v24.4s, v19.4s\n"
        "fmul v2.4s, v23.4s, v18.4s\n"
        "fmul v1.4s, v21.4s, v17.4s\n"
        "fmul v0.4s, v20.4s, v16.4s\n"
        "ldr q18, [x26, #0x0]\n"
        "ld1r { v17.4s }, [%x[clamp_vals]]\n"
        "add x20, %x[clamp_vals], #0x4\n"
        "cmp x25, #0x4\n"
        "ld1r { v16.4s }, [x20]\n"
        "add x26, x26, #0x10\n"
        "fadd v3.4s, v3.4s, v18.4s\n"
        "fadd v2.4s, v2.4s, v18.4s\n"
        "fadd v1.4s, v1.4s, v18.4s\n"
        "fadd v0.4s, v0.4s, v18.4s\n"
        "fmax v3.4s, v3.4s, v17.4s\n"
        "fmax v2.4s, v2.4s, v17.4s\n"
        "fmax v1.4s, v1.4s, v17.4s\n"
        "fmax v0.4s, v0.4s, v17.4s\n"
        "fmin v3.4s, v3.4s, v16.4s\n"
        "fmin v2.4s, v2.4s, v16.4s\n"
        "fmin v1.4s, v1.4s, v16.4s\n"
        "fmin v0.4s, v0.4s, v16.4s\n"
        "blt 5f\n"
        "mov x20, %x[dst]\n"
        "cmp x27, #0x1\n"
        "str q3, [x20, #0x0]\n"
        "add x20, x20, %x[dst_stride_row]\n"
        "ble 8f\n"
        "cmp x27, #0x2\n"
        "str q2, [x20, #0x0]\n"
        "add x20, x20, %x[dst_stride_row]\n"
        "ble 8f\n"
        "cmp x27, #0x3\n"
        "str q1, [x20, #0x0]\n"
        "add x20, x20, %x[dst_stride_row]\n"
        "ble 8f\n"
        "str q0, [x20, #0x0]\n"
        "b 8f\n"
        "5:"  // Partial output
        "mov x23, %x[dst]\n"
        "cmp x27, #0x1\n"
        "add x22, x23, %x[dst_stride_row]\n"
        "csel x22, x22, x23, GT\n"
        "cmp x27, #0x2\n"
        "add x21, x23, %x[dst_stride_row], LSL #1\n"
        "csel x21, x21, x22, GT\n"
        "cmp x27, #0x3\n"
        "add x20, x21, %x[dst_stride_row]\n"
        "csel x20, x20, x21, GT\n"
        "tbz x25, #1, 6f\n"
        "st1 { v0.d }[0], [x20], #0x8\n"
        "st1 { v1.d }[0], [x21], #0x8\n"
        "st1 { v2.d }[0], [x22], #0x8\n"
        "st1 { v3.d }[0], [x23], #0x8\n"
        "tbz x25, #0, 7f\n"
        "st1 { v0.s }[2], [x20]\n"
        "st1 { v1.s }[2], [x21]\n"
        "st1 { v2.s }[2], [x22]\n"
        "st1 { v3.s }[2], [x23]\n"
        "b 7f\n"
        "6:"  // Output block 0: partial_1_0
        "st1 { v0.s }[0], [x20]\n"
        "st1 { v1.s }[0], [x21]\n"
        "st1 { v2.s }[0], [x22]\n"
        "st1 { v3.s }[0], [x23]\n"
        "7:"  // Output block 0: Done
        "8:"  // Output stage exit
        "subs x25, x25, #0x4\n"
        "add %x[dst], %x[dst], #0x10\n"
        "bgt 2b\n"
        "subs x27, x27, #0x4\n"
        "add %x[lhs_packed], %x[lhs_packed], x28\n"
        "mov %x[dst], x24\n"
        "bgt 1b\n"
        "9:"  // Row loop skip
        : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed)
        : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n),
          [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed)
        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
          "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27",
          "x28");
}
279 #endif // Architectural feature check
280