KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 51 / 9 / 60
Functions: 100.0% 14 / 0 / 14
Branches: -% 0 / 18 / 18

kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // Do not flag up inline assembly blocks
8 #pragma GCC diagnostic ignored "-Woverlength-strings"
9
10 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
11 #error This file must be compiled for AArch64, FEAT_SVE2.
12 #else // Architectural feature check
13
14 #include "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h"
15
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include "kai/kai_common.h"
20
21 static const size_t kai_mr = 1; // multiple of vector length
22 static const size_t kai_nr = 4; // multiple of vector length
23 static const size_t kai_kr = 4;
24 static const size_t kai_sr = 1;
25 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
26 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
27 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
28 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
29 static const size_t kai_num_bytes_bias_rhs = sizeof(float);
30 static const size_t kai_k_multiple_of = 32;
31
32 /// LUT indexed by an i4 value, yielding its sign-extended i8 value (e.g. -2 = 1110 -> 1111 1110).
33 static const int8_t lut[64] = {0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0,
34 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, -8, 0, 0, 0, -7, 0, 0, 0, -6, 0, 0, 0,
35 -5, 0, 0, 0, -4, 0, 0, 0, -3, 0, 0, 0, -2, 0, 0, 0, -1, 0, 0, 0};
36
37 2946 inline static size_t kai_k_roundedup(size_t k) {
38 // Round up k to be a multiple of 32.
39 2946 return kai_roundup(k, kai_k_multiple_of);
40 }
41
42 1202 inline static size_t kai_get_lhs_packed_stride(size_t k) {
43 1202 const size_t k_internal = kai_k_roundedup(k);
44
45 KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
46
47 3606 return kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa() *
48 1202 (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
49 1202 }
50
51 1202 inline static size_t kai_get_rhs_packed_stride(size_t k) {
52 1202 const size_t k_internal = kai_k_roundedup(k);
53
54 KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
55
56 3606 return kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa() *
57 1202 ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs);
58 1202 }
59
60 1980 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
61 1980 return kai_mr * kai_get_sme_vector_length_u32();
62 }
63
64 1980 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
65 1980 return kai_nr * kai_get_sme_vector_length_u32();
66 }
67
68 2864 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
69 2864 return kai_mr * kai_get_sme_vector_length_u32();
70 }
71
72 2864 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
73 2864 return kai_nr * kai_get_sme_vector_length_u32();
74 }
75
76 580 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
77 580 return kai_kr;
78 }
79
80 580 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
81 580 return kai_sr;
82 }
83
84 660 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) {
85 KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
86
87 660 const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
88
89 1320 return (m_idx / mr) * kai_get_lhs_packed_stride(k);
90 660 }
91
92 660 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k) {
93 KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
94
95 660 const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
96
97 1320 return (n_idx / nr) * kai_get_rhs_packed_stride(k);
98 660 }
99
100 540 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
101 size_t m_idx, size_t n_idx, size_t dst_stride) {
102 KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
103 KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
104
105 540 return (n_idx * sizeof(float) + m_idx * dst_stride);
106 }
107
108 540 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m, size_t n) {
109 540 return m * n * sizeof(float);
110 }
111
112 542 void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
113 size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed,
114 float* restrict dst, // NOLINT(readability-non-const-parameter)
115 size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
116 KAI_ASSERT(dst_stride_col == sizeof(float));
117 KAI_ASSERT(n > 0);
118 KAI_ASSERT(m > 0);
119
120 // Constants
121 542 uint64_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
122 542 uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
123 542 uint64_t lhs_stride = kai_get_lhs_packed_stride(k);
124 542 uint64_t rhs_stride = kai_get_rhs_packed_stride(k);
125 542 uint64_t m_blk = (uint64_t)kai_k_roundedup(k) * mr;
126 542 uint64_t dst_inc = mr * dst_stride_row;
127 542 float scalar_bounds[2] = {scalar_min, scalar_max};
128
129 542 kai_commit_za();
130
131 /* ---------------------------------------------------
132 Registers allocations
133 x7: Look up table(lut)
134 x8: RHS base address (rhs)
135 x9: Destination base address (dst)
136 x10: LHS pointer (lhs)
137 x11: RHS pointer (rhs)
138 x12: Remaining M elements
139 x13: Remaining N elements
140 x14: k exit condition (k_cond)
141 ZA tile index (l_idx)
142 x15: LHS scaling factor pointer (lhs_sf_ptr)
143 x16: ZA tile exit condition (l_cnd)
144 x17: Destination pointer (dst)
145 x19: Destination outer address (dst)
146 x20: LHS base address (lhs)
147 --------------------------------------------------- */
148 542 __asm__ volatile(
149 " .inst 0xd503477f //smstart \n"
150 " mov x19, %[dst] \n"
151 " mov x20, %[lhs] \n"
152 " mov x7, %[lut] \n"
153 " .inst 0xe11f80e0 //ldr zt0, [x7] \n"
154 " cntw x7 \n"
155 " ptrue p2.b \n"
156 " ld1rw {z30.s}, p2/Z, [%[scalar_bounds]] \n"
157 " ld1rw {z31.s}, p2/Z, [%[scalar_bounds], #4] \n"
158
159 // M loop head
160 " mov x12, %[m] \n"
161 " .inst 0x25ac17e0 //whilelt p0.s, xzr, x12 \n"
162 "1: \n"
163 " mov x8, %[rhs] \n"
164 " mov x9, x19 \n"
165 " mov x13, %[n] \n"
166 " cmp x7, x12 \n"
167 " csel x16, x7, x12, lt \n"
168 " lsl x16, x16, #2 \n"
169
170 // N loop head
171 " .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n"
172 "2: \n"
173 " mov x10, x20 \n"
174 " mov x11, x8 \n"
175 " mov x17, x9 \n"
176 " .inst 0x25ad67f1 //whilelt pn9.s, xzr, x13, vlx4 \n"
177
178 // K loop
179 " .inst 0xc00800ff //zero {za} \n"
180 " add x14, x10, %[m_blk] \n"
181 "3: \n"
182 " .inst 0xa540a144 //ld1w { z4.s }, p0/z, [x10] \n"
183 " .inst 0x042a502a //addvl x10, x10, #1 \n"
184 " .inst 0xa0402160 //ld1h { z0.h-z1.h }, pn8/z, [x11] \n"
185 " .inst 0x042b504b //addvl x11, x11, #2 \n"
186 " .inst 0xc08a4008 //luti4 { z8.b - z9.b }, zt0, z0[0] \n"
187 " .inst 0xc08a402a //luti4 { z10.b - z11.b }, zt0, z1[0] \n"
188 " .inst 0xa0884880 //smopa za0.s, p2/m, p2/m, z4.b, z8.b \n"
189 " .inst 0xa0894881 //smopa za1.s, p2/m, p2/m, z4.b, z9.b \n"
190 " .inst 0xa08a4882 //smopa za2.s, p2/m, p2/m, z4.b, z10.b\n"
191 " .inst 0xa08b4883 //smopa za3.s, p2/m, p2/m, z4.b, z11.b\n"
192 " cmp x10, x14 \n"
193 " b.lt 3b \n"
194
195 // RHS row sum, scale factor & bias
196 " .inst 0xa040c560 //ld1w { z0.s-z3.s }, pn9/z, [x11] \n"
197 " .inst 0xa041c564 //ld1w { z4.s-z7.s }, pn9/z, [x11, #4, mul vl] \n"
198 " .inst 0xa042c568 //ld1w { z8.s-z11.s }, pn9/z, [x11, #8, mul vl]\n"
199 " .inst 0x042b518b //addvl x11, x11, #12 \n"
200 " .inst 0xc132e000 //scvtf { z0.s-z3.s }, { z0.s-z3.s }\n"
201
202 // Store loop
203 " mov x14, #0 \n"
204 " addvl x15, x10, #1 \n"
205 "4: \n"
206 // Load LHS Row-offset & SF
207 " ld1rw {z16.s}, p2/z, [x10] \n"
208 " ld1rw {z17.s}, p2/z, [x15] \n"
209 " add x10, x10, #4 \n"
210 " add x15, x15, #4 \n"
211 " scvtf z16.s, p2/m, z16.s \n"
212
213 // offset x Row-sum
214 " fmul z24.s, z16.s, z0.s \n"
215 " fmul z25.s, z16.s, z1.s \n"
216 " fmul z26.s, z16.s, z2.s \n"
217 " fmul z27.s, z16.s, z3.s \n"
218
219 // Scaling factors
220 " fmul z20.s, z17.s, z4.s \n"
221 " fmul z21.s, z17.s, z5.s \n"
222 " fmul z22.s, z17.s, z6.s \n"
223 " fmul z23.s, z17.s, z7.s \n"
224
225 // Result = offset x Row-sum x SFs
226 " fmul z24.s, z24.s, z20.s \n"
227 " fmul z25.s, z25.s, z21.s \n"
228 " fmul z26.s, z26.s, z22.s \n"
229 " fmul z27.s, z27.s, z23.s \n"
230
231 // Load inner accumulation & convert
232 " .inst 0xc006440c //mova { z12.b-z15.b }, za0h.b[w14, 0:3]\n"
233 " .inst 0xc132e18c //scvtf { z12.s-z15.s }, { z12.s-z15.s } \n"
234
235 // Result += iacc x SF
236 " fmla z24.s, p2/m, z20.s, z12.s \n"
237 " fmla z25.s, p2/m, z21.s, z13.s \n"
238 " fmla z26.s, p2/m, z22.s, z14.s \n"
239 " fmla z27.s, p2/m, z23.s, z15.s \n"
240
241 // Add the bias
242 " fadd z24.s, p2/m, z24.s, z8.s \n"
243 " fadd z25.s, p2/m, z25.s, z9.s \n"
244 " fadd z26.s, p2/m, z26.s, z10.s \n"
245 " fadd z27.s, p2/m, z27.s, z11.s \n"
246
247 // CLAMP and store
248 " .inst 0xc1bfcbd8 //fclamp { z24.s-z27.s }, z30.s, z31.s\n"
249 " .inst 0xa060c638 //st1w { z24.s-z27.s }, pn9, [x17] \n"
250
251 " add x17, x17, %[dst_stride_row] \n"
252 " add x14, x14, #4 \n"
253 " cmp x14, x16 \n"
254 " b.lt 4b \n"
255
256 // N loop tail
257 " add x8, x8, %[rhs_stride] \n"
258 " .inst 0x04295089 // addvl x9, x9, #4 \n"
259 " sub x13, x13, %[nr] \n"
260 " .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n"
261 " b.mi 2b \n"
262
263 // M loop tail
264 " add x20, x20, %[lhs_stride] \n"
265 " add x19, x19, %[dst_inc] \n"
266 " sub x12, x12, %[mr] \n"
267 " whilelt p0.s, xzr, x12 \n"
268 " b.mi 1b \n"
269
270 "5: \n"
271 " .inst 0xd503467f //smstop \n"
272 :
273 542 : [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_stride] "r"(lhs_stride), [rhs_stride] "r"(rhs_stride),
274 542 [dst_stride_row] "r"(dst_stride_row), [lut] "r"(lut), [m_blk] "r"(m_blk), [nr] "r"(nr), [mr] "r"(mr),
275 542 [lhs] "r"(lhs_packed), [rhs] "r"(rhs_packed), [dst_inc] "r"(dst_inc), [scalar_bounds] "r"(scalar_bounds),
276 542 [dst] "r"(dst)
277 : "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "p0", "p2", "p8",
278 "p9", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
279 "z16", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z30", "z31",
280 #ifdef __ARM_STATE_ZA
281 "za",
282 #endif
283 #ifdef __ARM_STATE_ZT0
284 "zt0",
285 #endif
286 "cc", "memory");
287 542 }
288
289 #endif // Architectural feature check
290