KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.5% 67 10 78
Functions: 100.0% 16 0 16
Branches: 50.0% 1 20 22

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2))
8 #error This file must be compiled for AArch64, FEAT_SVE2 or FEAT_SME2.
9 #else // Architectural features check.
10
11 #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 // Compute args
19 static const size_t kai_m_step = 1; // Multiple of vector length
20 static const size_t kai_n_step = 4; // Multiple of vector length
21 // Packing args
22 static const size_t kai_mr = 1; // Multiple of vector length
23 static const size_t kai_nr = 4; // Multiple of vector length
24 static const size_t kai_kr = 4;
25 static const size_t kai_sr = 2;
26 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
27 static const size_t kai_num_bytes_qvalue_lhs = 1;
28 static const size_t kai_num_bytes_multiplier_lhs = 2;
29 // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
30 // asymmetric))
31 static const size_t kai_recip_num_bytes_qvalue_rhs = 2;
32 static const size_t kai_num_bytes_multiplier_rhs = 2;
33 // DST format args
34 static const size_t kai_num_bytes_dst_value = 4;
35 // Extra args
36 static const size_t kai_bl = 32;
37
38 // Look-up table used for int4->int8 convert
39 static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
40
41 70 inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) {
42 70 return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs;
43 }
44
45 70 inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
46 KAI_ASSUME(bl == kai_bl);
47 70 size_t num_bytes_per_block_rhs = (bl / kai_recip_num_bytes_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
48 140 return num_bytes_per_block_rhs;
49 70 }
50
51 164 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
52 KAI_ASSUME(bl == kai_bl);
53 KAI_ASSUME((k % kai_bl) == 0);
54
55 164 return kai_roundup(k, bl) / bl;
56 }
57
58 70 inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) {
59 70 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
60 140 return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl);
61 70 }
62
63 70 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
64 KAI_ASSUME(bl == kai_bl);
65 KAI_ASSUME((k % kai_bl) == 0);
66
67 70 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
68 70 const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
69 70 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
70
71 70 size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row);
72
73 140 return rhs_packed_stride;
74 70 }
75
76 141 size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
77 141 return kai_m_step * kai_get_sme_vector_length_u32();
78 }
79
80 141 size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
81 141 return kai_n_step * kai_get_sme_vector_length_u32();
82 }
83
84 188 size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
85 188 return kai_mr * kai_get_sme_vector_length_u32();
86 }
87
88 188 size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
89 188 return kai_nr * kai_get_sme_vector_length_u32();
90 }
91
92 72 size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
93 72 return kai_kr;
94 }
95
96 48 size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
97 48 return kai_sr;
98 }
99
100 46 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
101 size_t m_idx, size_t k, size_t bl) {
102 46 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
103 46 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
104 KAI_ASSUME((m_idx % m_step) == 0);
105
106 92 return (m_idx / mr) * kai_get_lhs_packed_stride(k, bl);
107 46 }
108
109 46 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
110 size_t n_idx, size_t k, size_t bl) {
111 46 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
112 46 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
113
114 KAI_ASSUME((n_idx % n_step) == 0);
115
116 92 return (n_idx / nr) * kai_get_rhs_packed_stride(k, bl);
117 46 }
118
119 23 size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
120 size_t m_idx, size_t n_idx, size_t dst_stride) {
121 23 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
122 23 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
123 KAI_ASSUME((m_idx % m_step) == 0);
124 KAI_ASSUME((n_idx % n_step) == 0);
125
126 46 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
127 23 }
128
129 23 size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(size_t m, size_t n) {
130 23 return m * n * kai_num_bytes_dst_value;
131 }
132
133 24 void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
134 size_t m, //
135 size_t n, //
136 size_t k, //
137 size_t bl, //
138 const void* restrict lhs_packed, //
139 const void* restrict rhs_packed, //
140 float* restrict dst, // NOLINT(readability-non-const-parameter)
141 size_t dst_stride_row, //
142 size_t dst_stride_col, //
143 float scalar_min, //
144 float scalar_max) {
145 KAI_ASSUME(dst_stride_col == sizeof(float));
146
147 24 KAI_UNUSED(scalar_min);
148 24 KAI_UNUSED(scalar_max);
149
150
1/2
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
24 if (m == 0) {
151 return;
152 }
153
154 typedef struct {
155 size_t lhs_packed_stride;
156 size_t rhs_packed_stride;
157 size_t mr;
158 } KernelArgs;
159
160 24 KernelArgs ka;
161
162 24 const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);
163
164 24 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
165 24 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
166
167 24 ka.mr = mr;
168 24 ka.lhs_packed_stride = kai_get_lhs_packed_stride(k, bl);
169 24 ka.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl);
170
171 48 const uint16_t* lhs_scales = (const uint16_t*)((const int8_t*)lhs_packed + ka.lhs_packed_stride -
172 24 (mr * num_blocks) * kai_num_bytes_multiplier_lhs);
173 48 const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + ka.rhs_packed_stride -
174 24 (nr * num_blocks) * kai_num_bytes_multiplier_rhs);
175
176 48 __asm__ volatile(
177 // Switch to streaming mode with ZA enabling
178 " .inst 0xd503477f // smstart \n"
179
180 // Constants
181 // - SVLs
182 " cntw x14 \n"
183 // - ptrue
184 " ptrue p0.b, all \n"
185 " .inst 0x25a07810 // ptrue pn8.s \n"
186
187 // Predicate for loading fp16 scaling factors
188 " ldr x5, [%x[args_ptr], %[offset_mr]]\n"
189 " lsl x5, x5, #1 \n"
190 " whilelt p4.b, xzr, x5 \n"
191
192 // Initialize ZT0 (Lookup table)
193 " mov x6, %[lut]\n"
194 " .inst 0xe11f80c0 // ldr zt0, [x6] \n"
195
196 // Initialize the RHS packed and scale pointers
197 " mov x16, %[rhs_packed] \n"
198 " mov x17, %[rhs_scales] \n"
199
200 // Iterate over n (x8)
201 // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step)
202 " mov x8, #0 \n"
203 " mov x0, %[N] \n"
204 " .inst 0x25a06511 // whilelt pn9.s, x8, x0, VLx4 \n"
205
206 " b.none 9f // .LOOP_N_END%= \n"
207
208 " 1: // .LOOP_N_START%=: \n"
209
210 // Iterate over m (x9)
211 // e.g. for(m_idx = 0; m_idx < m; m_idx+=m_step)
212 " mov x9, %[M] \n"
213
214 // Initialize the LHS packed and scale pointers
215 " mov x22, %[lhs_packed] \n"
216 " mov x23, %[lhs_scales] \n"
217
218 // Initialize the DST pointer
219 " mov x24, %[dst] \n"
220
221 " 2: // .LOOP_M_START%=: \n"
222
223 // Address offset for the left and right quantized values
224 " mov x20, #0 \n"
225 " mov x21, #0 \n"
226
227 // Number of output rows to store -> min(SVLs, loop M index)
228 " cmp x9, x14 \n"
229 " csel x15, x9, x14, lo \n"
230 " lsl x15, x15, #2 \n"
231
232 // Iterate over all K values
233 // e.g. for(k_idx = 0; k_idx < k; k_idx += bl)
234 " mov x10, %[K] \n"
235
236 // Skip processing if K=0
237 " cmp x10, #0 \n"
238 " b.eq 8f // .LOOP_K_END%= \n"
239
240 " 3: // .LOOP_K_START%=: \n"
241
242 // Zeroing of ZA accumulator
243 " .inst 0xc00800ff // zero {za} \n"
244
245 // Load the fp16 scaling factors for the right matrix block
246 " .inst 0xa0154220 // ld1w {z0.s - z1.s}, pn8/z, [x17, x21, lsl #2] \n"
247 " .inst 0xc161d000 // zip {z0.h - z1.h}, z0.h, z1.h \n"
248
249 // Iterate over all values in the block
250 // k_blk_idx = bl
251 // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 4}
252 " mov x11, #32\n"
253
254 " 4: // .LOOP_BL_START%=: \n"
255
256 // Load right matrix row
257 " .inst 0xa0144202 // ld1w {z2.s - z3.s}, pn8/z, [x16, x20, lsl #2] \n"
258
259 // Load left matrix column
260 " ld1h {z8.h}, p0/z, [x22, x20, lsl #1] \n"
261 " inch x20, all \n"
262
263 // Convert Int4 -> Int8
264 " .inst 0xc08a4044 // luti4 {z4.b - z5.b}, zt0, z2[0] \n"
265 " .inst 0xc08a4066 // luti4 {z6.b - z7.b}, zt0, z3[0] \n"
266
267 // Outer-products
268 " .inst 0xa0840100 // smopa za0.s, p0/m, p0/m, z8.b, z4.b \n"
269 " .inst 0xa0850101 // smopa za1.s, p0/m, p0/m, z8.b, z5.b \n"
270 " .inst 0xa0860102 // smopa za2.s, p0/m, p0/m, z8.b, z6.b \n"
271 " .inst 0xa0870103 // smopa za3.s, p0/m, p0/m, z8.b, z7.b \n"
272
273 // Decrement the block loop index
274 " subs x11, x11, #4 \n"
275
276 " b.gt 4b // .LOOP_BL_START%= \n"
277
278 // === End of the block loop ===
279
280 // Store loop index
281 " mov w12, #0 \n"
282
283 // Copy destination pointer for store loop
284 " mov x25, x24 \n"
285
286 // Load the fp16 scaling factors for the left matrix block
287 " ld1b {z16.b}, p4/z, [x23, x21] \n"
288 " inch x21, all \n"
289
290 // Predicate for the selection of a scaling among the vector
291 " pfalse p3.b \n"
292
293 " 5: // .LOOP_ZA%=: \n"
294
295 // Select and replicate scaling factor for the left block
296 " pnext p3.h, p0, p3.h \n"
297 " clastb z19.h, p3, z19.h, z16.h \n"
298
299 // Get data from za
300 " .inst 0xc006041c // mova {z28.b-z31.b}, za0h.b[w12, 0:3] \n"
301 " add w12, w12, #4 \n"
302
303 // Convert from int32 to fp32
304 " .inst 0xc132e39c // scvtf {z28.s-z31.s}, {z28.s-z31.s} \n"
305
306 // Multiply left and right scaling factors
307 " movprfx z8, z18 \n"
308 " fmlalb z8.s, z19.h, z0.h \n"
309 " movprfx z9, z18 \n"
310 " fmlalb z9.s, z19.h, z1.h \n"
311 " movprfx z10, z18 \n"
312 " fmlalt z10.s, z19.h, z0.h \n"
313 " movprfx z11, z18 \n"
314 " fmlalt z11.s, z19.h, z1.h \n"
315
316 " cmp x10, %[K] \n"
317 " b.ne 6f // .ACCUMULATE%= \n"
318
319 // Applying combined scaling factors to processed block
320 " fmul z24.s, z8.s, z28.s \n"
321 " fmul z25.s, z9.s, z29.s \n"
322 " fmul z26.s, z10.s, z30.s \n"
323 " fmul z27.s, z11.s, z31.s \n"
324
325 "b 7f // .STORE%= \n"
326
327 " 6: // .ACCUMULATE%=: \n"
328 // Load intermediate result
329 " .inst 0xa040c738 // ld1w {z24.s-z27.s}, pn9/z, [x25] \n"
330
331 // Multiply the intermediate results by LHS_SCALE x RHS_SCALE
332 // and store in the main floating-point accumulator
333 " fmla z24.s, p0/m, z8.s, z28.s \n"
334 " fmla z25.s, p0/m, z9.s, z29.s \n"
335 " fmla z26.s, p0/m, z10.s, z30.s \n"
336 " fmla z27.s, p0/m, z11.s, z31.s \n"
337
338 "7: // .STORE%=: \n"
339 // Store the results into memory
340 " .inst 0xa060c738 // st1w {z24.s-z27.s}, pn9, [x25] \n"
341 " add x25, x25, %[stride] \n"
342
343 " cmp x12, x15 \n"
344 " blt 5b // .LOOP_ZA%= \n"
345
346 // Decrement K loop index by bl
347 " subs x10, x10, #32 \n"
348
349 " b.gt 3b // .LOOP_K_START%= \n"
350
351 " 8: // .LOOP_K_END%=: \n"
352
353 // === End of the K loop ===
354
355 " ldr x5, [%x[args_ptr], %[offset_stride_l]] \n"
356
357 // Increment pointer to the quantized values of the left matrix
358 " add x22, x22, x5\n"
359
360 // Increment pointer to the scaling factors of the left matrix
361 " add x23, x23, x5 \n"
362
363 // Update destination pointer
364 " mov x24, x25 \n"
365
366 // Decrement M loop index
367 " decw x9, all \n"
368
369 " cmp x9, #0 \n"
370 " b.gt 2b // .LOOP_M_START%= \n"
371
372 // === End of M loop ===
373
374 // Increment output pointer
375 " incb %[dst], all, mul #4 \n"
376
377 " ldr x5, [%x[args_ptr], %[offset_stride_r]]\n"
378
379 " add x16, x16, x5 \n"
380 " add x17, x17, x5 \n"
381
382 // Increment N loop index
383 " incb x8, all \n"
384
385 " .inst 0x25a06511 // whilelt pn9.s, x8, x0, VLx4 \n"
386
387 " b.first 1b // .LOOP_N_START%= \n"
388
389 " 9: // .LOOP_N_END%=: \n"
390
391 // === End of N loop ===
392
393 // Exit streaming mode
394 " .inst 0xd503467f // smstop \n"
395 : [dst] "+r"(dst), [rhs_packed] "+r"(rhs_packed), [rhs_scales] "+r"(rhs_scales)
396 24 : [M] "r"(m), [N] "r"(n), [K] "r"(k), [lhs_packed] "r"(lhs_packed), [lhs_scales] "r"(lhs_scales),
397 24 [stride] "r"(dst_stride_row), [lut] "r"(lut), [args_ptr] "r"(&ka),
398 [offset_stride_l] "I"(offsetof(KernelArgs, lhs_packed_stride)),
399 [offset_stride_r] "I"(offsetof(KernelArgs, rhs_packed_stride)), [offset_mr] "I"(offsetof(KernelArgs, mr))
400 : "p0", "p1", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", "z1",
401 "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18",
402 "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x5", "x6",
403 "x8", "x9", "x10", "x11", "x12", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25",
404 "memory", "cc");
405 24 }
406
407 #endif // Architectural features check.
408