KleidiAI Coverage Report


Directory: ./
Coverage legend: low: ≥ 0%   medium: ≥ 75.0%   high: ≥ 90.0%

             Coverage    Exec / Excl / Total
Lines:          98.0%      49 /    7 /    57
Functions:     100.0%      14 /    0 /    14
Branches:       50.0%       3 /   14 /    20
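(Percentages are taken over non-excluded items, i.e. Exec / (Total - Excl):
lines 49 / (57 - 7) = 98.0%, functions 14 / 14 = 100.0%, branches
3 / (20 - 14) = 50.0%.)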

kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // Do not warn about the overlength string literals that make up the inline assembly blocks
8 #pragma GCC diagnostic ignored "-Woverlength-strings"
9
10 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
11 #error This file must be compiled for AArch64, FEAT_SVE2.
12 #else // Architectural features check
13
14 #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h"
15
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include "kai/kai_common.h"
20
21 static const size_t kai_m_step = 1;
22 static const size_t kai_n_step = 1;
23 static const size_t kai_mr = 1;
24 static const size_t kai_nr = 4; // multiple of vector length
25 static const size_t kai_kr = 4;
26 static const size_t kai_sr = 1;
27 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
28 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
29 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
30 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
31 static const size_t kai_num_bytes_bias_rhs = sizeof(float);
32 static const size_t kai_k_multiple_of = 32;
33
34 2946 inline static size_t kai_k_roundedup(size_t k) {
35 // Round up k to be a multiple of 32.
36 2946 return kai_roundup(k, kai_k_multiple_of);
37 }
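
The padding this implies can be sketched in plain C (roundup_sketch is a
hypothetical stand-in, assuming kai_roundup() in kai_common.h is the usual
round-up-to-multiple helper; its real definition is not part of this report):

    #include <stddef.h>

    // Hypothetical stand-in for kai_roundup(), for illustration only.
    static size_t roundup_sketch(size_t a, size_t b) {
        return ((a + b - 1) / b) * b;
    }
    // roundup_sketch(1, 32)   == 32   -- any non-zero k pads to at least 32
    // roundup_sketch(32, 32)  == 32   -- exact multiples are unchanged
    // roundup_sketch(100, 32) == 128  -- k = 100 is padded to 128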
38
39 1202 inline static size_t kai_get_lhs_packed_stride(size_t k) {
40 1202 const size_t k_internal = kai_k_roundedup(k);
41
42 KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
43
44 3606 return kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() *
45 1202 (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
46 1202 }
47
48 1202 inline static size_t kai_get_rhs_packed_stride(size_t k) {
49 1202 const size_t k_internal = kai_k_roundedup(k);
50
51 KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
52
53 3606 return kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() *
54 1202 ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs);
55 1202 }
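
A worked example of the two strides (illustrative values, assuming a 512-bit
streaming vector length so that kai_get_sme_vector_length_u32() == 16,
mr == 1 and nr == 4 * 16 == 64):

    // k          = 100
    // k_internal = kai_k_roundedup(100)          == 128
    // lhs stride = 1 * (128 * 1 + 4 + 4)         == 136 bytes per packed row
    // rhs stride = 64 * (128 / 2 + 4 + 4 + 4)    == 4864 bytes per packed row

The k_internal / 2 term reflects that the RHS weights are packed two 4-bit
values per byte; each of the nr columns then carries an int32 sum, a float
scale and a float bias.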
56
57 780 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
58 780 return kai_m_step;
59 }
60
61 3644 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
62 3644 return kai_nr * kai_get_sme_vector_length_u32();
63 }
64
65 780 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
66 780 return kai_n_step * kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot();
67 }
68
69 1662 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
70 // For GEMV, mr must be 1 so the LHS data is read consecutively
71 1662 return kai_mr;
72 }
73
74 580 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
75 580 return kai_kr;
76 }
77
78 580 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
79 580 return kai_sr;
80 }
81
82 660 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t m_idx, size_t k) {
83 KAI_ASSERT((m_idx % kai_m_step) == 0);
84
85 660 return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
86 }
87
88 660 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t n_idx, size_t k) {
89 KAI_ASSERT((n_idx % kai_n_step) == 0);
90 660 const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot();
91 1320 return (n_idx / nr) * kai_get_rhs_packed_stride(k);
92 660 }
93
94 540 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(
95 size_t m_idx, size_t n_idx, size_t dst_stride) {
96 KAI_ASSERT((m_idx % kai_m_step) == 0);
97 KAI_ASSERT((n_idx % kai_n_step) == 0);
98
99 540 return (n_idx * sizeof(float)) + (m_idx * dst_stride);
100 }
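
A small worked example of the destination addressing (values chosen purely
for illustration): for a row-major f32 output with n == 64 columns and
dst_stride == 64 * sizeof(float) == 256 bytes, element (m_idx = 2, n_idx = 8)
starts at byte 8 * 4 + 2 * 256 == 544.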
101
102 540 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t m, size_t n) {
103 540 return m * n * sizeof(float);
104 }
105
106 /// LUT indexed by an i4 value, yielding its sign-extended i8 value (e.g. -2 = 1110 -> 1111 1110).
107 static const int8_t lut[64] = {0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0,
108 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, -8, 0, 0, 0, -7, 0, 0, 0, -6, 0, 0, 0,
109 -5, 0, 0, 0, -4, 0, 0, 0, -3, 0, 0, 0, -2, 0, 0, 0, -1, 0, 0, 0};
110
111 // Optimized for GEMV (matrix-vector multiplication, i.e. m == 1).
112 // Also handles a general matmul for compatibility, but should not be used that way.
113 542 void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(
114 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed,
115 float* dst, // NOLINT(readability-non-const-parameter)
116 size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
117 KAI_ASSERT(dst_stride_col == sizeof(float));
118
3/6
✓ Branch 0 taken 542 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 542 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 542 times.
119 542 if (m == 0 || n == 0 || k == 0) {
120 return;
121 }
122
123 // Do function calls and calculations first so they cannot clobber registers used by the inline assembly
124 542 uint64_t k_internal = kai_k_roundedup(k);
125 542 uint64_t lhs_stride = kai_get_lhs_packed_stride(k);
126 542 uint64_t rhs_stride = kai_get_rhs_packed_stride(k);
127 542 uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot();
128
129 542 uint64_t rhs_row_bytes = nr * k_internal / 2;
130 542 uint64_t lhs_end_ptr = ((uint64_t)lhs_packed) + (m * lhs_stride);
131
132 542 kai_commit_za();
133
134 /*
135 * x11: zero = 0 // MUST BE x8-x11
136 * x15: n initialized as n
137 * x19: nr initialized as nr
138 * x20: lut_ptr initialized as lut
139 * x21: lhs_packed initialized as lhs_packed
140 * x22: n_idx
141 * x23: k_idx
142 * x24: RHS block ptr
143 * x25: RHS end ptr
144 * x26: rhs_packed
145 * x27: dst_ptr
146 * x28: tmp_1
147 */
148
149 1084 __asm__ volatile(
150
151 // Setup
152 " .inst 0xd503477f // smstart \n"
153 " mov x11, #0 \n"
154 " mov x15, %[n] \n"
155 " mov x19, %[nr] \n"
156 " mov x21, %[lhs_packed] \n"
157 " mov x20, %[lut] \n"
158 " .inst 0xe11f8280 // ldr zt0, [x20] \n"
159 " ptrue p0.b \n"
160 " .inst 0x25207810 // ptrue pn8.b \n"
161 // predicate to load nr words for the RHS sums and scaling factors (should be exactly all true)
162 " .inst 0x25b36571 // whilelt pn9.s, x11, x19, vlx4 \n"
163 " dup z30.s, %w[scalar_min] \n"
164 " dup z31.s, %w[scalar_max] \n"
165
166 // lhs matrix row loop
167 "1: \n"
168 // Reset rhs matrix ptr
169 " mov x26, %[rhs_packed] \n"
170 // Reset dst_ptr to dst of next GEMV result
171 " mov x27, %[dst_ptr] \n"
172 // Reset n index
173 " mov x22, #0 \n"
174 // whilelt pn12.s, x22, %[n], vlx4
175 " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n"
176
177 // rhs matrix row loop (transposed so theoretical columns)
178 "2: \n"
179
180 // Reset rhs block ptr to start of row
181 " mov x24, x26 \n"
182 " add x25, x26, %[rhs_row_bytes] \n"
183 " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n"
184 " addvl x28, x24, #4 \n"
185 " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n"
186 " mov x23, #0 \n"
187 " whilelt p1.b, x23, %[k_internal] \n"
188 // Zero for sdot accumulation in inner loop
189 " .inst 0xc00800ff // zero {za} \n"
190
191 // before k loop
192 "3: \n"
193
194 // Load lhs
195 " ld1rqb { z0.b }, p1/z , [x21, x23] \n"
196
197 // Load w
198 " .inst 0xa0408b10 // ld1b { z16.b - z19.b }, pn10/z, [x24] \n"
199 " .inst 0xa0418f14 // ld1b {z20.b-z23.b}, pn11/z, [x24,#0x4, mul vl]\n"
200
201 // rhs i4 to i8 and sdot
202 // k block + 0
203 " .inst 0xc08a4218 // luti4 { z24.b, z25.b }, zt0, z16[0] \n"
204 " .inst 0xc08a423a // luti4 { z26.b, z27.b }, zt0, z17[0] \n"
205 " .inst 0xc150f320 // sdot za.s[w11,0, vgx4], {z24.b-z27.b}, z0.b[0]\n"
206 // k block + 1
207 " .inst 0xc08a4244 // luti4 { z4.b, z5.b }, zt0, z18[0] \n"
208 " .inst 0xc08a4266 // luti4 { z6.b, z7.b }, zt0, z19[0] \n"
209 " .inst 0xc150f4a0 // sdot za.s[w11,0, vgx4], {z4.b-z7.b}, z0.b[1] \n"
210 // k block + 2
211 " .inst 0xc08a4288 // luti4 { z8.b, z9.b }, zt0, z20[0] \n"
212 " .inst 0xc08a42aa // luti4 { z10.b, z11.b }, zt0, z21[0] \n"
213 " .inst 0xc150f920 // sdot za.s[w11,0, vgx4], {z8.b-z11.b}, z0.b[2] \n"
214 // k block + 3
215 " .inst 0xc08a42cc // luti4 { z12.b, z13.b }, zt0, z22[0] \n"
216 " .inst 0xc08a42ee // luti4 { z14.b, z15.b }, zt0, z23[0] \n"
217 " .inst 0xc150fda0 // sdot za.s[w11,0, vgx4], {z12.b-z15.b}, z0.b[3]\n"
218
219 // End K block loop
220 " addvl x24, x24, #8 \n"
221 " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n"
222 " addvl x28, x24, #4 \n"
223 " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n"
224 " add x23, x23, #16 \n"
225 " whilelt p1.b, x23, %[k_internal] \n"
226 " b.first 3b \n"
227
228 // Finish of accumulators with scaling factors and zero points
229
230 // Load lhs zero point
231 " add x28, x21, %[k_internal] \n"
232 " ld1rw { z2.s }, p0/z , [x28] \n"
233 // Load lhs scaling factor
234 " ld1rw { z3.s }, p0/z , [x28, #4] \n"
235 // Load rhs sums
236 " add x28, x26, %[rhs_row_bytes] \n"
237 " .inst 0xa040c794 // ld1w { z20.s - z23.s }, pn9/z, [x28] \n"
238 // Load rhs scaling factors
239 " .inst 0xa041c798 // ld1w {z24.s-z27.s}, pn9/z, [x28, #0x4, mul vl]\n"
240 // Load biases
241 " .inst 0xa042c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, #0x8, mul vl]\n"
242
243 // Get accumulated value out of ZA
244 " .inst 0xc0066c04 // mov { z4.d - z7.d }, za.d[w11, 0, vgx4] \n"
245
246 // za contains a * w; still need to add zero_point * rhs_row_sum on top via mla
247 // zero point * rhs row sum
248 " mla z4.s, p0/m, z20.s, z2.s \n"
249 " mla z5.s, p0/m, z21.s, z2.s \n"
250 " mla z6.s, p0/m, z22.s, z2.s \n"
251 " mla z7.s, p0/m, z23.s, z2.s \n"
252
253 // Convert to float
254 " .inst 0xc132e084 // scvtf { z4.s - z7.s }, { z4.s - z7.s } \n"
255
256 // lhs scaling factor * rhs scaling factor
257 " fmul z24.s, z24.s, z3.s \n"
258 " fmul z25.s, z25.s, z3.s \n"
259 " fmul z26.s, z26.s, z3.s \n"
260 " fmul z27.s, z27.s, z3.s \n"
261
262 // Bias + combined scaling factor * combined accumulator
263 " fmla z12.s, p0/m, z24.s, z4.s \n"
264 " fmla z13.s, p0/m, z25.s, z5.s \n"
265 " fmla z14.s, p0/m, z26.s, z6.s \n"
266 " fmla z15.s, p0/m, z27.s, z7.s \n"
267
268 // Clamp
269 " .inst 0xc1bfcbcc // fclamp { z12.s - z15.s }, z30.s, z31.s \n"
270
271 // Store
272 " .inst 0xa036d36c // st1w {z12.s-z15.s}, pn12, [x27, x22, lsl #2] \n"
273
274 // End rhs row loop
275 " add x26, x26, %[rhs_stride] \n"
276 // nr == svl_b, so addvl #1 advances the n index by nr elements
277 " addvl x22, x22, #1 \n"
278 // whilelt pn12.s, x22, %[n], vlx4
279 " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n"
280 " b.lt 2b \n"
281
282 // End lhs row loop
283 " add %[dst_ptr], %[dst_ptr], %[dst_stride_row] \n"
284 " add x21, x21, %[lhs_stride] \n"
285 " cmp x21, %[lhs_end_ptr] \n"
286 " b.lt 1b \n"
287
288 " .inst 0xd503467f // smstop \n"
289
290 : [dst_ptr] "+r"(dst)
291 542 : [lut] "r"(lut), [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_packed] "r"(lhs_packed),
292 542 [rhs_packed] "r"(rhs_packed), [dst_stride_row] "r"(dst_stride_row), [scalar_min] "r"(scalar_min),
293 542 [scalar_max] "r"(scalar_max), [k_internal] "r"(k_internal), [lhs_stride] "r"(lhs_stride),
294 542 [rhs_stride] "r"(rhs_stride), [nr] "r"(nr), [rhs_row_bytes] "r"(rhs_row_bytes), [lhs_end_ptr] "r"(lhs_end_ptr)
295 : "x11", "x15", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "p0", "p1", "p8", "p9",
296 "p10", "p11", "p12", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13",
297 "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28",
298 "z29", "z30", "z31",
299 #ifdef __ARM_STATE_ZA
300 "za",
301 #endif
302 #ifdef __ARM_STATE_ZT0
303 "zt0",
304 #endif
305 "memory", "cc");
306 542 }
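
Reading the assembly comments end to end, each output element is produced by
the standard per-channel dequantization identity. A scalar sketch of that
finishing step (a model for illustration, not the kernel's own code; the
packed-buffer traversal and the dot product itself are omitted):

    #include <stdint.h>

    // Scalar model of the accumulator finish for one output element.
    // acc     - int32 dot product of the int8 LHS row and int4 RHS column
    // rhs_sum - int32 sum of that RHS column's unscaled weights
    static float finish_element(
        int32_t acc, int32_t rhs_sum, int32_t lhs_zero_point, float lhs_scale,
        float rhs_scale, float bias, float scalar_min, float scalar_max) {
        // mla: fold the LHS zero point into the integer accumulator.
        acc += lhs_zero_point * rhs_sum;
        // scvtf + fmul + fmla: convert, apply the combined scale, add bias.
        float out = bias + (lhs_scale * rhs_scale) * (float)acc;
        // fclamp: clamp to the requested output range.
        if (out < scalar_min) out = scalar_min;
        if (out > scalar_max) out = scalar_max;
        return out;
    }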
307
308 #endif // Architectural features check.
309