KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c
Date: 2025-10-20 13:18:31
            Coverage   Exec   Excl   Total
Lines:         98.0%     48      7      56
Functions:    100.0%     14      0      14
Branches:      50.0%      3     14      20
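
The percentages are taken over the non-excluded counts: lines 48 / (56 - 7) ≈ 98.0%, functions 14 / 14 = 100.0%, and branches 3 / (20 - 14) = 50.0%.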

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // Do not warn about the overlength string literal used for the inline assembly block
8 #pragma GCC diagnostic ignored "-Woverlength-strings"
9
10 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
11 #error This file must be compiled for AArch64, FEAT_SVE2.
12 #else // Architectural features check
13
14 #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h"
15
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include "kai/kai_common.h"
20
21 static const size_t kai_m_step = 1;
22 static const size_t kai_n_step = 1;
23 static const size_t kai_mr = 1;
24 static const size_t kai_nr = 4; // multiple of vector length
25 static const size_t kai_kr = 4;
26 static const size_t kai_sr = 1;
27 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
28 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
29 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
30 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
31 static const size_t kai_num_bytes_bias_rhs = sizeof(float);
32 static const size_t kai_k_multiple_of = 32;
33
34 483 inline static size_t kai_k_roundedup(size_t k) {
35 // Round up k to be a multiple of 32.
36 483 return kai_roundup(k, kai_k_multiple_of);
37 }
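
kai_roundup() comes from kai/kai_common.h; a minimal sketch of the assumed semantics (rounding up to the next multiple, not quoted from that header):

    // Assumed behaviour of kai_roundup(a, b): the smallest multiple of b that is >= a.
    // With kai_k_multiple_of == 32: kai_k_roundedup(33) == 64, kai_k_roundedup(64) == 64.
    static size_t roundup_sketch(size_t a, size_t b) {
        return ((a + b - 1) / b) * b;
    }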
38
39 201 inline static size_t kai_get_lhs_packed_stride(size_t k) {
40 201 const size_t k_internal = kai_k_roundedup(k);
41
42 KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
43
44 603 return kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() *
45 201 (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
46 201 }
47
48 201 inline static size_t kai_get_rhs_packed_stride(size_t k) {
49 201 const size_t k_internal = kai_k_roundedup(k);
50
51 KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
52
53 603 return kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() *
54 201 ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs);
55 201 }
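
As a worked example of the two stride formulas above, assume a 512-bit streaming vector length (so kai_get_sme_vector_length_u32() returns 16, mr == 1, nr == 4 * 16 == 64) and k == 64, hence k_internal == 64: the LHS packed stride is 1 * (64 * 1 + 4 + 4) = 72 bytes per packed LHS row, and the RHS packed stride is 64 * (64 / 2 + 4 + 4 + 4) = 64 * 44 = 2816 bytes per packed block of nr columns.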
56
57 160 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
58 160 return kai_m_step;
59 }
60
61 682 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
62 682 return kai_nr * kai_get_sme_vector_length_u32();
63 }
64
65 160 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
66 160 return kai_n_step * kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot();
67 }
68
69 321 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
70 // For GEMV, mr must be 1 so the data is read consecutively
71 321 return kai_mr;
72 }
73
74 160 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
75 160 return kai_kr;
76 }
77
78 160 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
79 160 return kai_sr;
80 }
81
82 120 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t m_idx, size_t k) {
83 KAI_ASSERT((m_idx % kai_m_step) == 0);
84
85 120 return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
86 }
87
88 120 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t n_idx, size_t k) {
89 KAI_ASSERT((n_idx % kai_n_step) == 0);
90 120 const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot();
91 240 return (n_idx / nr) * kai_get_rhs_packed_stride(k);
92 120 }
93
94 80 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(
95 size_t m_idx, size_t n_idx, size_t dst_stride) {
96 KAI_ASSERT((m_idx % kai_m_step) == 0);
97 KAI_ASSERT((n_idx % kai_n_step) == 0);
98
99 80 return (n_idx * sizeof(float)) + (m_idx * dst_stride);
100 }
101
102 80 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t m, size_t n) {
103 80 return m * n * sizeof(float);
104 }
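
A caller typically combines the getters above to address individual blocks of packed data and output; a minimal sketch using only functions defined in this file (the loop structure and helper name are illustrative, not the library's dispatch code):

    // Illustrative only: walk the output of a single GEMV row (m == 1) in
    // n_step-wide column blocks and compute the byte offset of each block.
    static void for_each_rhs_block_sketch(size_t n, size_t k, size_t dst_stride_row) {
        const size_t n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot();
        for (size_t n_idx = 0; n_idx < n; n_idx += n_step) {
            const size_t rhs_offset =
                kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(n_idx, k);
            const size_t dst_offset =
                kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(0, n_idx, dst_stride_row);
            // A caller would add these offsets to its packed-RHS and dst base
            // pointers before running the kernel on this block.
            (void)rhs_offset;
            (void)dst_offset;
        }
    }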
105
106 /// LUT indexed by an i4 value, giving its sign-extended i8 value (e.g. -2 = 1110 -> 1111 1110).
107 static const int8_t lut[64] = {0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0,
108 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, -8, 0, 0, 0, -7, 0, 0, 0, -6, 0, 0, 0,
109 -5, 0, 0, 0, -4, 0, 0, 0, -3, 0, 0, 0, -2, 0, 0, 0, -1, 0, 0, 0};
110
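
Each table entry above occupies the low byte of a 32-bit slot, which is why every value is followed by three zero bytes. A standalone sketch that checks the layout (illustrative, not part of the kernel):

    // Illustrative check: index i in 0..15 maps to its sign-extended int8 value at
    // lut[4 * i]; the remaining three bytes of each 32-bit slot are zero.
    static int lut_layout_is_valid_sketch(void) {
        for (int i = 0; i < 16; ++i) {
            const int8_t expected = (int8_t)(i < 8 ? i : i - 16); // sign-extend a 4-bit value
            if (lut[4 * i] != expected) return 0;
            if (lut[4 * i + 1] != 0 || lut[4 * i + 2] != 0 || lut[4 * i + 3] != 0) return 0;
        }
        return 1;
    }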
111 // Optimized for GEMV (matrix vector multiplication => m == 1).
112 // It can also perform a full matmul for compatibility reasons, but should not be used that way.
113 81 void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(
114 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed,
115 float* dst, // NOLINT(readability-non-const-parameter)
116 size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
117 KAI_ASSERT(dst_stride_col == sizeof(float));
118
119 3/6 81 if (m == 0 || n == 0 || k == 0) {
✓ Branch 0 taken 81 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 81 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 81 times.
120 return;
121 }
122
123 // Do function calls and calculations first so we do not overwrite registers we will use
124 81 uint64_t k_internal = kai_k_roundedup(k);
125 81 uint64_t lhs_stride = kai_get_lhs_packed_stride(k);
126 81 uint64_t rhs_stride = kai_get_rhs_packed_stride(k);
127 81 uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot();
128
129 81 uint64_t rhs_row_bytes = nr * k_internal / 2;
130 81 uint64_t lhs_end_ptr = ((uint64_t)lhs_packed) + (m * lhs_stride);
131
132 /*
133 * x11: zero = 0 // MUST BE x8-x11
134 * x15: n initialized as n
135 * x19: nr initialized as nr
136 * x20: lut_ptr initialized as lut
137 * x21: lhs_packed initialized as lhs_packed
138 * x22: n_idx
139 * x23: k_idx
140 * x24: RHS block ptr
141 * x25: RHS end ptr
142 * x26: rhs_packed
143 * x27: dst_ptr
144 * x28: tmp_1
145 */
146
147 162 __asm__ volatile(
148
149 // Setup
150 " .inst 0xd503477f // smstart \n"
151 " mov x11, #0 \n"
152 " mov x15, %[n] \n"
153 " mov x19, %[nr] \n"
154 " mov x21, %[lhs_packed] \n"
155 " mov x20, %[lut] \n"
156 " .inst 0xe11f8280 // ldr zt0, [x20] \n"
157 " ptrue p0.b \n"
158 " .inst 0x25207810 // ptrue pn8.b \n"
159 // predicate to load nr words for the RHS sums and scaling factors (should be exactly all true)
160 " .inst 0x25b36571 // whilelt pn9.s, x11, x19, vlx4 \n"
161 " dup z30.s, %w[scalar_min] \n"
162 " dup z31.s, %w[scalar_max] \n"
163
164 // lhs matrix row loop
165 "1: \n"
166 // Reset rhs matrix ptr
167 " mov x26, %[rhs_packed] \n"
168 // Reset dst_ptr to dst of next GEMV result
169 " mov x27, %[dst_ptr] \n"
170 // Reset n index
171 " mov x22, #0 \n"
172 // whilelt pn12.s, x22, %[n], vlx4
173 " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n"
174
175 // rhs matrix row loop (transposed so theoretical columns)
176 "2: \n"
177
178 // Reset rhs block ptr to start of row
179 " mov x24, x26 \n"
180 " add x25, x26, %[rhs_row_bytes] \n"
181 " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n"
182 " addvl x28, x24, #4 \n"
183 " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n"
184 " mov x23, #0 \n"
185 " whilelt p1.b, x23, %[k_internal] \n"
186 // Zero for sdot accumulation in inner loop
187 " .inst 0xc00800ff // zero {za} \n"
188
189 // before k loop
190 "3: \n"
191
192 // Load lhs
193 " ld1rqb { z0.b }, p1/z , [x21, x23] \n"
194
195 // Load w
196 " .inst 0xa0408b10 // ld1b { z16.b - z19.b }, pn10/z, [x24] \n"
197 " .inst 0xa0418f14 // ld1b {z20.b-z23.b}, pn11/z, [x24,#0x4, mul vl]\n"
198
199 // rhs i4 to i8 and sdot
200 // k block + 0
201 " .inst 0xc08a4218 // luti4 { z24.b, z25.b }, zt0, z16[0] \n"
202 " .inst 0xc08a423a // luti4 { z26.b, z27.b }, zt0, z17[0] \n"
203 " .inst 0xc150f320 // sdot za.s[w11,0, vgx4], {z24.b-z27.b}, z0.b[0]\n"
204 // k block + 1
205 " .inst 0xc08a4244 // luti4 { z4.b, z5.b }, zt0, z18[0] \n"
206 " .inst 0xc08a4266 // luti4 { z6.b, z7.b }, zt0, z19[0] \n"
207 " .inst 0xc150f4a0 // sdot za.s[w11,0, vgx4], {z4.b-z7.b}, z0.b[1] \n"
208 // k block + 2
209 " .inst 0xc08a4288 // luti4 { z8.b, z9.b }, zt0, z20[0] \n"
210 " .inst 0xc08a42aa // luti4 { z10.b, z11.b }, zt0, z21[0] \n"
211 " .inst 0xc150f920 // sdot za.s[w11,0, vgx4], {z8.b-z11.b}, z0.b[2] \n"
212 // k block + 3
213 " .inst 0xc08a42cc // luti4 { z12.b, z13.b }, zt0, z22[0] \n"
214 " .inst 0xc08a42ee // luti4 { z14.b, z15.b }, zt0, z23[0] \n"
215 " .inst 0xc150fda0 // sdot za.s[w11,0, vgx4], {z12.b-z15.b}, z0.b[3]\n"
216
217 // End K block loop
218 " addvl x24, x24, #8 \n"
219 " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n"
220 " addvl x28, x24, #4 \n"
221 " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n"
222 " add x23, x23, #16 \n"
223 " whilelt p1.b, x23, %[k_internal] \n"
224 " b.first 3b \n"
225
226 // Finish of accumulators with scaling factors and zero points
227
228 // Load lhs zero point
229 " add x28, x21, %[k_internal] \n"
230 " ld1rw { z2.s }, p0/z , [x28] \n"
231 // Load lhs scaling factor
232 " ld1rw { z3.s }, p0/z , [x28, #4] \n"
233 // Load rhs sums
234 " add x28, x26, %[rhs_row_bytes] \n"
235 " .inst 0xa040c794 // ld1w { z20.s - z23.s }, pn9/z, [x28] \n"
236 // Load rhs scaling factors
237 " .inst 0xa041c798 // ld1w {z24.s-z27.s}, pn9/z, [x28, #0x4, mul vl]\n"
238 // Load biases
239 " .inst 0xa042c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, #0x8, mul vl]\n"
240
241 // Get accumulated value out of ZA
242 " .inst 0xc0066c04 // mov { z4.d - z7.d }, za.d[w11, 0, vgx4] \n"
243
244 // za contains a * w; the zero-point contribution z * wsum still needs to be added -> mla
245 // zero point * rhs row sum
246 " mla z4.s, p0/m, z20.s, z2.s \n"
247 " mla z5.s, p0/m, z21.s, z2.s \n"
248 " mla z6.s, p0/m, z22.s, z2.s \n"
249 " mla z7.s, p0/m, z23.s, z2.s \n"
250
251 // Convert to float
252 " .inst 0xc132e084 // scvtf { z4.s - z7.s }, { z4.s - z7.s } \n"
253
254 // lhs scaling factor * rhs scaling factor
255 " fmul z24.s, z24.s, z3.s \n"
256 " fmul z25.s, z25.s, z3.s \n"
257 " fmul z26.s, z26.s, z3.s \n"
258 " fmul z27.s, z27.s, z3.s \n"
259
260 // Bias + combined scaling factor * combined accumulator
261 " fmla z12.s, p0/m, z24.s, z4.s \n"
262 " fmla z13.s, p0/m, z25.s, z5.s \n"
263 " fmla z14.s, p0/m, z26.s, z6.s \n"
264 " fmla z15.s, p0/m, z27.s, z7.s \n"
265
266 // Clamp
267 " .inst 0xc1bfcbcc // fclamp { z12.s - z15.s }, z30.s, z31.s \n"
268
269 // Store
270 " .inst 0xa036d36c // st1w {z12.s-z15.s}, pn12, [x27, x22, lsl #2] \n"
271
272 // End rhs row loop
273 " add x26, x26, %[rhs_stride] \n"
274 // nr == SVL in bytes, so addvl #1 advances the n index by nr
275 " addvl x22, x22, #1 \n"
276 // whilelt pn12.s, x22, %[n], vlx4
277 " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n"
278 " b.lt 2b \n"
279
280 // End lhs row loop
281 " add %[dst_ptr], %[dst_ptr], %[dst_stride_row] \n"
282 " add x21, x21, %[lhs_stride] \n"
283 " cmp x21, %[lhs_end_ptr] \n"
284 " b.lt 1b \n"
285
286 " .inst 0xd503467f // smstop \n"
287
288 : [dst_ptr] "+r"(dst)
289 81 : [lut] "r"(lut), [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_packed] "r"(lhs_packed),
290 81 [rhs_packed] "r"(rhs_packed), [dst_stride_row] "r"(dst_stride_row), [scalar_min] "r"(scalar_min),
291 81 [scalar_max] "r"(scalar_max), [k_internal] "r"(k_internal), [lhs_stride] "r"(lhs_stride),
292 81 [rhs_stride] "r"(rhs_stride), [nr] "r"(nr), [rhs_row_bytes] "r"(rhs_row_bytes), [lhs_end_ptr] "r"(lhs_end_ptr)
293 : "x11", "x15", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "p0", "p1", "p8", "p9",
294 "p10", "p11", "p12", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13",
295 "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28",
296 "z29", "z30", "z31",
297 #ifdef __ARM_STATE_ZA
298 "za",
299 #endif
300 #ifdef __ARM_STATE_ZT0
301 "zt0",
302 #endif
303 "memory", "cc");
304 81 }
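
The epilogue comments above ("zero point * rhs row sum", "lhs scaling factor * rhs scaling factor", "Bias + combined scaling factor * combined accumulator", clamp) amount to the following per-element math; a scalar sketch with illustrative names, where acc is the raw int32 dot product taken from ZA and rhs_sum is the per-column sum of the quantized RHS values:

    // Illustrative scalar equivalent of the accumulator finishing step:
    //   dst = clamp(bias + (lhs_scale * rhs_scale) * (acc + lhs_zero_point * rhs_sum))
    static float finish_one_output_sketch(
        int32_t acc, int32_t rhs_sum, int32_t lhs_zero_point, float lhs_scale, float rhs_scale, float bias,
        float scalar_min, float scalar_max) {
        const int32_t corrected = acc + lhs_zero_point * rhs_sum;          // mla
        float result = bias + (lhs_scale * rhs_scale) * (float)corrected;  // scvtf + fmul + fmla
        if (result < scalar_min) result = scalar_min;                      // fclamp
        if (result > scalar_max) result = scalar_max;
        return result;
    }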
305
306 #endif // Architectural features check.
307