KleidiAI Coverage Report


Directory: ./
Coverage legend: low ≥ 0.0%   medium ≥ 75.0%   high ≥ 90.0%
              Coverage    Exec / Excl / Total
Lines:           97.6%      41 /    7 /   49
Functions:      100.0%      14 /    0 /   14
Branches:        50.0%       1 /   14 /   16
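Note: the percentages are computed over non-excluded entries: Lines 41 / (49 - 7) ≈ 97.6%, Functions 14 / 14 = 100.0%, Branches 1 / (16 - 14) = 50.0%. The branch figure matches the single annotated 1/2 branch in the listing below, the m == 0 early-out in kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.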

kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if !defined(__ARM_FEATURE_MATMUL_INT8)
7 #error "I8mm extension required to compile this micro-kernel"
8 #else
9 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h"
10
11 #include <arm_neon.h>
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16
17 static const size_t kai_m_step = 4;
18 static const size_t kai_n_step = 4;
19 static const size_t kai_mr = 4;
20 static const size_t kai_nr = 4;
21 static const size_t kai_kr = 16;
22 static const size_t kai_sr = 2;
23 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
24 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
25 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
26 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
27 static const size_t kai_num_bytes_bias = sizeof(float);
28
29 3846 inline static size_t kai_k_roundedup(size_t k) {
30 // Since we pack a float and int32 value at the end of the row,
31 // we must make sure that k is a multiple of 4 for alignment
32 3846 size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
33 7692 return kai_roundup(k, kr_sr_roundedup4);
34 3846 }
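Note: a minimal standalone sketch of this rounding, assuming kai_roundup(a, b) from kai/kai_common.h rounds a up to the nearest multiple of b. With kai_kr = 16 and kai_sr = 2, kr_sr_roundedup4 is kai_roundup(32, 4) = 32, so k is padded to a multiple of 32:

    #include <stddef.h>
    #include <stdio.h>

    /* Assumed semantics of kai_roundup(): round a up to a multiple of b. */
    static size_t roundup(size_t a, size_t b) {
        return ((a + b - 1) / b) * b;
    }

    int main(void) {
        /* kai_kr * kai_sr = 32 is already a multiple of 4, so the inner
           roundup is a no-op and k is padded to a multiple of 32. */
        printf("%zu %zu %zu\n", roundup(30, 32), roundup(100, 32), roundup(128, 32));
        /* prints: 32 128 128 */
        return 0;
    }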
35
36 1440 inline static size_t kai_lhs_packed_stride(size_t k) {
37 1440 const size_t k_internal = kai_k_roundedup(k);
38
39 KAI_ASSERT((k_internal % 2) == 0);
40
41 2880 return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
42 1440 }
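Note: as a worked example using the rounding above, k = 100 gives k_internal = 128, so the packed LHS stride is kai_mr * (128 * 1 + 4 + 4) = 4 * 136 = 544 bytes: each of the 4 rows carries its int8 data plus one float multiplier and one int32 offset.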
43
44 1440 inline static size_t kai_rhs_packed_stride(size_t k) {
45 1440 const size_t k_internal = kai_k_roundedup(k);
46
47 KAI_ASSERT((k_internal % 2) == 0);
48
49 2880 return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
50 1440 }
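Note: a sketch mirroring both stride computations, with constants copied from the top of this file (illustration only, not a call into the library; roundup() is the assumed kai_roundup() stand-in from the earlier sketch):

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    static size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }

    int main(void) {
        const size_t mr = 4, nr = 4;                /* kai_mr, kai_nr */
        const size_t k_internal = roundup(100, 32); /* kai_k_roundedup(100) == 128 */

        /* LHS block: mr rows of int8 data, each with a float multiplier
           and an int32 offset appended. */
        const size_t lhs = mr * (k_internal * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));

        /* RHS block: nr columns of 4-bit data (two values per byte) plus a
           float multiplier, an int32 sum and a float bias per column. */
        const size_t rhs = nr * (k_internal / 2 + sizeof(float) + sizeof(int32_t) + sizeof(float));

        printf("lhs_stride = %zu, rhs_stride = %zu\n", lhs, rhs); /* 544, 304 */
        return 0;
    }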
51
52 1920 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
53 1920 return kai_m_step;
54 }
55
56 1920 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
57 1920 return kai_n_step;
58 }
59
60 1440 size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
61 1440 return kai_mr;
62 }
63
64 1440 size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
65 1440 return kai_nr;
66 }
67
68 1920 size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
69 1920 return kai_kr;
70 }
71
72 1920 size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
73 1920 return kai_sr;
74 }
75
76 1440 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k) {
77 KAI_ASSERT((m_idx % kai_m_step) == 0);
78
79 1440 return (m_idx / kai_mr) * kai_lhs_packed_stride(k);
80 }
81
82 1440 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k) {
83 KAI_ASSERT((n_idx % kai_n_step) == 0);
84
85 1440 return (n_idx / kai_nr) * kai_rhs_packed_stride(k);
86 }
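Note: continuing the k = 100 example (strides 544 and 304 bytes), m_idx = 8 gives an LHS packed offset of (8 / 4) * 544 = 1088 bytes, and n_idx = 8 gives an RHS packed offset of (8 / 4) * 304 = 608 bytes; both indices must be multiples of the 4-element steps, as the asserts require.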
87
88 960 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(
89 size_t m_idx, size_t n_idx, size_t dst_stride) {
90 KAI_ASSERT((m_idx % kai_m_step) == 0);
91 KAI_ASSERT((n_idx % kai_n_step) == 0);
92
93 960 return (n_idx * sizeof(float)) + m_idx * dst_stride;
94 }
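Note: for a dense row-major float destination, dst_stride is typically n * sizeof(float); e.g. m_idx = 4, n_idx = 8 gives a byte offset of 8 * 4 + 4 * dst_stride.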
95
96 960 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n) {
97 960 return m * n * sizeof(float);
98 }
99
100 966 void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(
101 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed,
102 float* dst, // NOLINT(readability-non-const-parameter)
103 size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
104 KAI_ASSERT(dst_stride_col == sizeof(float));
105
105
1/2
✓ Branch 0 taken 966 times.
✗ Branch 1 not taken.
106 966 if (m == 0) {
107 return;
108 }
109
110 966 const size_t k_internal = kai_k_roundedup(k);
111
112 966 size_t num_blocks = k_internal / 32;
113
114 966 float clamp_vals[2] = {scalar_min, scalar_max};
115
116 1932 __asm__ __volatile__(
117 "mov x28, #0x80\n"
118 "mov x20, #0x20\n"
119 "movi v4.16b, #0xf0\n"
120 "mov x27, %x[m]\n"
121 "madd x28, %x[num_blocks], x28, x20\n"
122 "cbz x27, 9f\n"
123 "1:" // Row loop
124 "mov x26, %x[rhs_packed]\n"
125 "mov x25, %x[n]\n"
126 "add x24, %x[dst], %x[dst_stride_row], LSL #2\n"
127 "2:" // Column loop
128 "mov x21, %x[lhs_packed]\n"
129 "movi v3.4s, #0x0\n"
130 "movi v2.4s, #0x0\n"
131 "mov x20, %x[num_blocks]\n"
132 "movi v1.4s, #0x0\n"
133 "movi v0.4s, #0x0\n"
134 "3:" // Sub block loop
135 "ldr q31, [x26, #0x0]\n"
136 "ldr q30, [x26, #0x10]\n"
137 "subs x20, x20, #0x1\n"
138 "ldr q29, [x21, #0x0]\n"
139 "ldr q28, [x21, #0x10]\n"
140 "ldr q27, [x26, #0x20]\n"
141 "ldr q26, [x26, #0x30]\n"
142 "add x26, x26, #0x40\n"
143 "ldr q25, [x21, #0x20]\n"
144 "ldr q24, [x21, #0x30]\n"
145 "shl v23.16b, v31.16b, #0x4\n"
146 "shl v22.16b, v30.16b, #0x4\n"
147 "ldr q21, [x21, #0x40]\n"
148 "ldr q20, [x21, #0x50]\n"
149 "and v31.16b, v31.16b, v4.16b\n"
150 "and v30.16b, v30.16b, v4.16b\n"
151 "ldr q19, [x21, #0x60]\n"
152 "ldr q18, [x21, #0x70]\n"
153 "shl v17.16b, v27.16b, #0x4\n"
154 "shl v16.16b, v26.16b, #0x4\n"
155 ".inst 0x4e97a7a3 // smmla v3.4s, v29.16b, v23.16b\n"
156 ".inst 0x4e96a7a2 // smmla v2.4s, v29.16b, v22.16b\n"
157 "and v27.16b, v27.16b, v4.16b\n"
158 "add x21, x21, #0x80\n"
159 ".inst 0x4e97a781 // smmla v1.4s, v28.16b, v23.16b\n"
160 ".inst 0x4e96a780 // smmla v0.4s, v28.16b, v22.16b\n"
161 "and v26.16b, v26.16b, v4.16b\n"
162 ".inst 0x4e91a723 // smmla v3.4s, v25.16b, v17.16b\n"
163 ".inst 0x4e90a722 // smmla v2.4s, v25.16b, v16.16b\n"
164 ".inst 0x4e91a701 // smmla v1.4s, v24.16b, v17.16b\n"
165 ".inst 0x4e90a700 // smmla v0.4s, v24.16b, v16.16b\n"
166 ".inst 0x4e9fa6a3 // smmla v3.4s, v21.16b, v31.16b\n"
167 ".inst 0x4e9ea6a2 // smmla v2.4s, v21.16b, v30.16b\n"
168 ".inst 0x4e9fa681 // smmla v1.4s, v20.16b, v31.16b\n"
169 ".inst 0x4e9ea680 // smmla v0.4s, v20.16b, v30.16b\n"
170 ".inst 0x4e9ba663 // smmla v3.4s, v19.16b, v27.16b\n"
171 ".inst 0x4e9aa662 // smmla v2.4s, v19.16b, v26.16b\n"
172 ".inst 0x4e9ba641 // smmla v1.4s, v18.16b, v27.16b\n"
173 ".inst 0x4e9aa640 // smmla v0.4s, v18.16b, v26.16b\n"
174 "bgt 3b\n"
175 "ldr q18, [x26, #0x0]\n"
176 "ld1 { v17.4s }, [x21]\n"
177 "uzp1 v24.2d, v3.2d, v2.2d\n"
178 "uzp2 v23.2d, v3.2d, v2.2d\n"
179 "ldr q22, [x26, #0x10]\n"
180 "uzp1 v21.2d, v1.2d, v0.2d\n"
181 "uzp2 v20.2d, v1.2d, v0.2d\n"
182 "add x21, x21, #0x10\n"
183 "ldr q16, [x21, #0x0]\n"
184 "add x26, x26, #0x20\n"
185 "mla v24.4s, v18.4s, v17.s[0]\n"
186 "mla v23.4s, v18.4s, v17.s[1]\n"
187 "mla v21.4s, v18.4s, v17.s[2]\n"
188 "mla v20.4s, v18.4s, v17.s[3]\n"
189 "fmul v19.4s, v22.4s, v16.s[0]\n"
190 "fmul v18.4s, v22.4s, v16.s[1]\n"
191 "fmul v17.4s, v22.4s, v16.s[2]\n"
192 "fmul v16.4s, v22.4s, v16.s[3]\n"
193 "scvtf v24.4s, v24.4s\n"
194 "scvtf v23.4s, v23.4s\n"
195 "scvtf v21.4s, v21.4s\n"
196 "scvtf v20.4s, v20.4s\n"
197 "fmul v3.4s, v24.4s, v19.4s\n"
198 "fmul v2.4s, v23.4s, v18.4s\n"
199 "fmul v1.4s, v21.4s, v17.4s\n"
200 "fmul v0.4s, v20.4s, v16.4s\n"
201 "ldr q18, [x26, #0x0]\n"
202 "ld1r { v17.4s }, [%x[clamp_vals]]\n"
203 "add x20, %x[clamp_vals], #0x4\n"
204 "cmp x25, #0x4\n"
205 "ld1r { v16.4s }, [x20]\n"
206 "add x26, x26, #0x10\n"
207 "fadd v3.4s, v3.4s, v18.4s\n"
208 "fadd v2.4s, v2.4s, v18.4s\n"
209 "fadd v1.4s, v1.4s, v18.4s\n"
210 "fadd v0.4s, v0.4s, v18.4s\n"
211 "fmax v3.4s, v3.4s, v17.4s\n"
212 "fmax v2.4s, v2.4s, v17.4s\n"
213 "fmax v1.4s, v1.4s, v17.4s\n"
214 "fmax v0.4s, v0.4s, v17.4s\n"
215 "fmin v3.4s, v3.4s, v16.4s\n"
216 "fmin v2.4s, v2.4s, v16.4s\n"
217 "fmin v1.4s, v1.4s, v16.4s\n"
218 "fmin v0.4s, v0.4s, v16.4s\n"
219 "blt 5f\n"
220 "mov x20, %x[dst]\n"
221 "cmp x27, #0x1\n"
222 "str q3, [x20, #0x0]\n"
223 "add x20, x20, %x[dst_stride_row]\n"
224 "ble 8f\n"
225 "cmp x27, #0x2\n"
226 "str q2, [x20, #0x0]\n"
227 "add x20, x20, %x[dst_stride_row]\n"
228 "ble 8f\n"
229 "cmp x27, #0x3\n"
230 "str q1, [x20, #0x0]\n"
231 "add x20, x20, %x[dst_stride_row]\n"
232 "ble 8f\n"
233 "str q0, [x20, #0x0]\n"
234 "b 8f\n"
235 "5:" // Partial output
236 "mov x23, %x[dst]\n"
237 "cmp x27, #0x1\n"
238 "add x22, x23, %x[dst_stride_row]\n"
239 "csel x22, x22, x23, GT\n"
240 "cmp x27, #0x2\n"
241 "add x21, x23, %x[dst_stride_row], LSL #1\n"
242 "csel x21, x21, x22, GT\n"
243 "cmp x27, #0x3\n"
244 "add x20, x21, %x[dst_stride_row]\n"
245 "csel x20, x20, x21, GT\n"
246 "tbz x25, #1, 6f\n"
247 "st1 { v0.d }[0], [x20], #0x8\n"
248 "st1 { v1.d }[0], [x21], #0x8\n"
249 "st1 { v2.d }[0], [x22], #0x8\n"
250 "st1 { v3.d }[0], [x23], #0x8\n"
251 "tbz x25, #0, 7f\n"
252 "st1 { v0.s }[2], [x20]\n"
253 "st1 { v1.s }[2], [x21]\n"
254 "st1 { v2.s }[2], [x22]\n"
255 "st1 { v3.s }[2], [x23]\n"
256 "b 7f\n"
257 "6:" // Output block 0: partial_1_0
258 "st1 { v0.s }[0], [x20]\n"
259 "st1 { v1.s }[0], [x21]\n"
260 "st1 { v2.s }[0], [x22]\n"
261 "st1 { v3.s }[0], [x23]\n"
262 "7:" // Output block 0: Done
263 "8:" // Output stage exit
264 "subs x25, x25, #0x4\n"
265 "add %x[dst], %x[dst], #0x10\n"
266 "bgt 2b\n"
267 "subs x27, x27, #0x4\n"
268 "add %x[lhs_packed], %x[lhs_packed], x28\n"
269 "mov %x[dst], x24\n"
270 "bgt 1b\n"
271 "9:" // Row loop skip
272 : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed)
273 966 : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n),
274 966 [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed)
275 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
276 "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27",
277 "x28");
278 966 }
279 #endif // Architectural feature check
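Note: for context, a hypothetical driver for this micro-kernel. It uses only the entry points listed above; lhs_packed and rhs_packed are assumed to have been produced beforehand by the matching KleidiAI packing micro-kernels (not part of this report):

    #include <float.h>
    #include <stddef.h>
    #include <stdlib.h>

    #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h"

    /* Hypothetical wrapper: multiply packed LHS (m x k) by packed RHS
       (n x k) into a freshly allocated m x n float matrix. */
    static float* run_packed_matmul(
        size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed) {
        const size_t dst_size =
            kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(m, n);
        float* dst = malloc(dst_size);
        if (dst == NULL) {
            return NULL;
        }

        /* One call processes the whole tile grid; passing the widest float
           range effectively disables the clamp. */
        kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(
            m, n, k, lhs_packed, rhs_packed, dst,
            /* dst_stride_row */ n * sizeof(float),
            /* dst_stride_col */ sizeof(float),
            /* scalar_min */ -FLT_MAX, /* scalar_max */ FLT_MAX);
        return dst;
    }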
280