KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c
Date: 2025-10-20 13:18:31
             Coverage   Exec   Excl   Total
Lines:         100.0%     50      9      59
Functions:     100.0%     14      0      14
Branches:           -%      0     18      18

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // Do not flag up inline assembly blocks
8 #pragma GCC diagnostic ignored "-Woverlength-strings"
9
10 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
11 #error This file must be compiled for AArch64, FEAT_SVE2.
12 #else // Architectural feature check
13
14 #include "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h"
15
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include "kai/kai_common.h"
20
21 static const size_t kai_mr = 1; // multiple of vector length
22 static const size_t kai_nr = 4; // multiple of vector length
23 static const size_t kai_kr = 4;
24 static const size_t kai_sr = 1;
25 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
26 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
27 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
28 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
29 static const size_t kai_num_bytes_bias_rhs = sizeof(float);
30 static const size_t kai_k_multiple_of = 32;
31
32 /// LUT indexed by an int4 value, giving its sign-extended int8 value (e.g. -2 = 1110 -> 1111 1110).
33 static const int8_t lut[64] = {0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0,
34 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, -8, 0, 0, 0, -7, 0, 0, 0, -6, 0, 0, 0,
35 -5, 0, 0, 0, -4, 0, 0, 0, -3, 0, 0, 0, -2, 0, 0, 0, -1, 0, 0, 0};
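
For illustration only (not part of the compiled file): a minimal scalar sketch of the mapping this table encodes, assuming the lut array above is in scope and using a hypothetical helper name. Each int4 index selects every fourth byte of the table, yielding the sign-extended int8 value that the LUTI4 instructions in the kernel below produce for a whole vector of nibbles at once.

    // Hypothetical scalar reference for the table lookup; illustration only.
    static inline int8_t decode_i4(uint8_t nibble) {
        // nibble holds a 4-bit index (0..15); table entries are 4 bytes apart.
        return lut[(size_t)(nibble & 0xF) * 4];
    }
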
36
37 483 inline static size_t kai_k_roundedup(size_t k) {
38 // Round up k to be a multiple of 32.
39 483 return kai_roundup(k, kai_k_multiple_of);
40 }
41
42 201 inline static size_t kai_get_lhs_packed_stride(size_t k) {
43 201 const size_t k_internal = kai_k_roundedup(k);
44
45 KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
46
47 603 return kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa() *
48 201 (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
49 201 }
50
51 201 inline static size_t kai_get_rhs_packed_stride(size_t k) {
52 201 const size_t k_internal = kai_k_roundedup(k);
53
54 KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
55
56 603 return kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa() *
57 201 ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs);
58 201 }
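
As a worked example (assuming, for illustration, a 512-bit streaming vector length, so the mr/nr getters below return 16 and 64): for k = 100, k_internal rounds up to 128, giving an LHS packed stride of 16 * (128 + 4 + 4) = 2176 bytes and an RHS packed stride of 64 * (128/2 + 4 + 4 + 4) = 4864 bytes.
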
59
60 360 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
61 360 return kai_mr * kai_get_sme_vector_length_u32();
62 }
63
64 360 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
65 360 return kai_nr * kai_get_sme_vector_length_u32();
66 }
67
68 522 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
69 522 return kai_mr * kai_get_sme_vector_length_u32();
70 }
71
72 522 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
73 522 return kai_nr * kai_get_sme_vector_length_u32();
74 }
75
76 160 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
77 160 return kai_kr;
78 }
79
80 160 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
81 160 return kai_sr;
82 }
83
84 120 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) {
85 KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
86
87 120 const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
88
89 240 return (m_idx / mr) * kai_get_lhs_packed_stride(k);
90 120 }
91
92 120 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k) {
93 KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
94
95 120 const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
96
97 240 return (n_idx / nr) * kai_get_rhs_packed_stride(k);
98 120 }
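
Continuing the example above (k = 100, 512-bit streaming vector length): m_idx = 32 gives an LHS packed offset of (32 / 16) * 2176 = 4352 bytes, and n_idx = 128 gives an RHS packed offset of (128 / 64) * 4864 = 9728 bytes.
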
99
100 80 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
101 size_t m_idx, size_t n_idx, size_t dst_stride) {
102 KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
103 KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
104
105 80 return (n_idx * sizeof(float) + m_idx * dst_stride);
106 }
107
108 80 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m, size_t n) {
109 80 return m * n * sizeof(float);
110 }
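
A hedged usage sketch (not part of this file): the offset helpers above locate the packed operands and the destination for a work tile starting at (m_idx, n_idx), which is then handed to the run function defined below. run_tile is a hypothetical caller-side helper; the packed buffers, tile sizes and dst_stride_row are assumed to come from the caller's packing and scheduling code, and FLT_MAX requires <float.h>.

    // Hypothetical helper, illustration only: process one work tile whose
    // top-left output corner is (m_idx, n_idx).
    static void run_tile(
        size_t m_idx, size_t n_idx, size_t m_tile, size_t n_tile, size_t k,
        const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row) {
        const void* lhs_tile = (const char*)lhs_packed +
            kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(m_idx, k);
        const void* rhs_tile = (const char*)rhs_packed +
            kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(n_idx, k);
        float* dst_tile = (float*)((char*)dst +
            kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(m_idx, n_idx, dst_stride_row));

        kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
            m_tile, n_tile, k, lhs_tile, rhs_tile, dst_tile,
            dst_stride_row, sizeof(float), -FLT_MAX, FLT_MAX); // no effective clamp
    }
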
111
112 81 void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
113 size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed,
114 float* restrict dst, // NOLINT(readability-non-const-parameter)
115 size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
116 KAI_ASSERT(dst_stride_col == sizeof(float));
117 KAI_ASSERT(n > 0);
118 KAI_ASSERT(m > 0);
119
120 // Constants
121 81 uint64_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
122 81 uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
123 81 uint64_t lhs_stride = kai_get_lhs_packed_stride(k);
124 81 uint64_t rhs_stride = kai_get_rhs_packed_stride(k);
125 81 uint64_t m_blk = (uint64_t)kai_k_roundedup(k) * mr;
126 81 uint64_t dst_inc = mr * dst_stride_row;
127 81 float scalar_bounds[2] = {scalar_min, scalar_max};
128
129 /* ---------------------------------------------------
130 Register allocation
131 x7: Look-up table (lut), then vector length in 32-bit words (cntw)
132 x8: RHS base address (rhs)
133 x9: Destination base address (dst)
134 x10: LHS pointer (lhs)
135 x11: RHS pointer (rhs)
136 x12: Remaining M elements
137 x13: Remaining N elements
138 x14: k exit condition (k_cond), reused as
139      ZA tile index (l_idx) in the store loop
140 x15: LHS scaling factor pointer (lhs_sf_ptr)
141 x16: ZA tile exit condition (l_cnd)
142 x17: Destination pointer (dst)
143 x19: Destination outer address (dst)
144 x20: LHS base address (lhs)
145 --------------------------------------------------- */
146 81 __asm__ volatile(
147 " .inst 0xd503477f //smstart \n"
148 " mov x19, %[dst] \n"
149 " mov x20, %[lhs] \n"
150 " mov x7, %[lut] \n"
151 " .inst 0xe11f80e0 //ldr zt0, [x7] \n"
152 " cntw x7 \n"
153 " ptrue p2.b \n"
154 " ld1rw {z30.s}, p2/Z, [%[scalar_bounds]] \n"
155 " ld1rw {z31.s}, p2/Z, [%[scalar_bounds], #4] \n"
156
157 // M loop head
158 " mov x12, %[m] \n"
159 " .inst 0x25ac17e0 //whilelt p0.s, xzr, x12 \n"
160 "1: \n"
161 " mov x8, %[rhs] \n"
162 " mov x9, x19 \n"
163 " mov x13, %[n] \n"
164 " cmp x7, x12 \n"
165 " csel x16, x7, x12, lt \n"
166 " lsl x16, x16, #2 \n"
167
168 // N loop head
169 " .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n"
170 "2: \n"
171 " mov x10, x20 \n"
172 " mov x11, x8 \n"
173 " mov x17, x9 \n"
174 " .inst 0x25ad67f1 //whilelt pn9.s, xzr, x13, vlx4 \n"
175
176 // K loop
177 " .inst 0xc00800ff //zero {za} \n"
178 " add x14, x10, %[m_blk] \n"
179 "3: \n"
180 " .inst 0xa540a144 //ld1w { z4.s }, p0/z, [x10] \n"
181 " .inst 0x042a502a //addvl x10, x10, #1 \n"
182 " .inst 0xa0402160 //ld1h { z0.h-z1.h }, pn8/z, [x11] \n"
183 " .inst 0x042b504b //addvl x11, x11, #2 \n"
184 " .inst 0xc08a4008 //luti4 { z8.b - z9.b }, zt0, z0[0] \n"
185 " .inst 0xc08a402a //luti4 { z10.b - z11.b }, zt0, z1[0] \n"
186 " .inst 0xa0884880 //smopa za0.s, p2/m, p2/m, z4.b, z8.b \n"
187 " .inst 0xa0894881 //smopa za1.s, p2/m, p2/m, z4.b, z9.b \n"
188 " .inst 0xa08a4882 //smopa za2.s, p2/m, p2/m, z4.b, z10.b\n"
189 " .inst 0xa08b4883 //smopa za3.s, p2/m, p2/m, z4.b, z11.b\n"
190 " cmp x10, x14 \n"
191 " b.lt 3b \n"
192
193 // RHS row sum, scale factor & bias
194 " .inst 0xa040c560 //ld1w { z0.s-z3.s }, pn9/z, [x11] \n"
195 " .inst 0xa041c564 //ld1w { z4.s-z7.s }, pn9/z, [x11, #4, mul vl] \n"
196 " .inst 0xa042c568 //ld1w { z8.s-z11.s }, pn9/z, [x11, #8, mul vl]\n"
197 " .inst 0x042b518b //addvl x11, x11, #12 \n"
198 " .inst 0xc132e000 //scvtf { z0.s-z3.s }, { z0.s-z3.s }\n"
199
200 // Store loop
201 " mov x14, #0 \n"
202 " addvl x15, x10, #1 \n"
203 "4: \n"
204 // Load LHS Row-offset & SF
205 " ld1rw {z16.s}, p2/z, [x10] \n"
206 " ld1rw {z17.s}, p2/z, [x15] \n"
207 " add x10, x10, #4 \n"
208 " add x15, x15, #4 \n"
209 " scvtf z16.s, p2/m, z16.s \n"
210
211 // offset x Row-sum
212 " fmul z24.s, z16.s, z0.s \n"
213 " fmul z25.s, z16.s, z1.s \n"
214 " fmul z26.s, z16.s, z2.s \n"
215 " fmul z27.s, z16.s, z3.s \n"
216
217 // Scaling factors
218 " fmul z20.s, z17.s, z4.s \n"
219 " fmul z21.s, z17.s, z5.s \n"
220 " fmul z22.s, z17.s, z6.s \n"
221 " fmul z23.s, z17.s, z7.s \n"
222
223 // Result = offset x Row-sum x SFs
224 " fmul z24.s, z24.s, z20.s \n"
225 " fmul z25.s, z25.s, z21.s \n"
226 " fmul z26.s, z26.s, z22.s \n"
227 " fmul z27.s, z27.s, z23.s \n"
228
229 // Load inner accumulation & convert
230 " .inst 0xc006440c //mova { z12.b-z15.b }, za0h.b[w14, 0:3]\n"
231 " .inst 0xc132e18c //scvtf { z12.s-z15.s }, { z12.s-z15.s } \n"
232
233 // Result += iacc x SF
234 " fmla z24.s, p2/m, z20.s, z12.s \n"
235 " fmla z25.s, p2/m, z21.s, z13.s \n"
236 " fmla z26.s, p2/m, z22.s, z14.s \n"
237 " fmla z27.s, p2/m, z23.s, z15.s \n"
238
239 // Add the bias
240 " fadd z24.s, p2/m, z24.s, z8.s \n"
241 " fadd z25.s, p2/m, z25.s, z9.s \n"
242 " fadd z26.s, p2/m, z26.s, z10.s \n"
243 " fadd z27.s, p2/m, z27.s, z11.s \n"
244
245 // CLAMP and store
246 " .inst 0xc1bfcbd8 //fclamp { z24.s-z27.s }, z30.s, z31.s\n"
247 " .inst 0xa060c638 //st1w { z24.s-z27.s }, pn9, [x17] \n"
248
249 " add x17, x17, %[dst_stride_row] \n"
250 " add x14, x14, #4 \n"
251 " cmp x14, x16 \n"
252 " b.lt 4b \n"
253
254 // N loop tail
255 " add x8, x8, %[rhs_stride] \n"
256 " .inst 0x04295089 // addvl x9, x9, #4 \n"
257 " sub x13, x13, %[nr] \n"
258 " .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n"
259 " b.mi 2b \n"
260
261 // M loop tail
262 " add x20, x20, %[lhs_stride] \n"
263 " add x19, x19, %[dst_inc] \n"
264 " sub x12, x12, %[mr] \n"
265 " whilelt p0.s, xzr, x12 \n"
266 " b.mi 1b \n"
267
268 "5: \n"
269 " .inst 0xd503467f //smstop \n"
270 :
271 81 : [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_stride] "r"(lhs_stride), [rhs_stride] "r"(rhs_stride),
272 81 [dst_stride_row] "r"(dst_stride_row), [lut] "r"(lut), [m_blk] "r"(m_blk), [nr] "r"(nr), [mr] "r"(mr),
273 81 [lhs] "r"(lhs_packed), [rhs] "r"(rhs_packed), [dst_inc] "r"(dst_inc), [scalar_bounds] "r"(scalar_bounds),
274 81 [dst] "r"(dst)
275 : "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "p0", "p2", "p8",
276 "p9", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
277 "z16", "z17", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z30", "z31",
278 #ifdef __ARM_STATE_ZA
279 "za",
280 #endif
281 #ifdef __ARM_STATE_ZT0
282 "zt0",
283 #endif
284 "cc", "memory");
285 81 }
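
For readers following the store loop above, a hedged scalar model of the per-element epilogue it implements (the fmul/fmla/fadd/fclamp sequence); finalize_element is an illustrative name, not a KleidiAI API:

    // Illustration only: one output element, given the quantities the loop loads.
    static inline float finalize_element(
        int32_t iacc,       // int32 accumulator produced by the SMOPA K loop
        int32_t lhs_offset, // per-row value loaded from the packed LHS (x10)
        float lhs_sf,       // per-row LHS scale factor (x15)
        int32_t rhs_sum,    // per-column RHS row sum
        float rhs_sf,       // per-column RHS scale factor
        float bias, float scalar_min, float scalar_max) {
        const float sf = lhs_sf * rhs_sf;
        float acc = (float)lhs_offset * (float)rhs_sum * sf + (float)iacc * sf + bias;
        if (acc < scalar_min) acc = scalar_min;
        if (acc > scalar_max) acc = scalar_max;
        return acc;
    }
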
286
287 #endif // Architectural feature check
288