KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 98.6% 68 / 10 / 79
Functions: 100.0% 16 / 0 / 16
Branches: 50.0% 1 / 20 / 22

kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2))
8 #error This file must be compiled for AArch64, FEAT_SVE2 or FEAT_SME2.
9 #else // Architectural features check.
10
11 #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
// Compute args
static const size_t kai_m_step = 1;  // Multiple of vector length
static const size_t kai_n_step = 4;  // Multiple of vector length
// Packing args
static const size_t kai_mr = 1;  // Multiple of vector length
static const size_t kai_nr = 4;  // Multiple of vector length
static const size_t kai_kr = 4;  // K-dimension packing factor (values grouped along K)
static const size_t kai_sr = 2;  // Split factor applied on top of kr during packing
// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
static const size_t kai_num_bytes_qvalue_lhs = 1;      // int8 quantized values
static const size_t kai_num_bytes_multiplier_lhs = 2;  // fp16 per-block scale
// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
// asymmetric))
// "recip" = reciprocal: RHS values are int4, so TWO values fit in one byte.
static const size_t kai_recip_num_bytes_qvalue_rhs = 2;
static const size_t kai_num_bytes_multiplier_rhs = 2;  // fp16 per-block scale
// DST format args
static const size_t kai_num_bytes_dst_value = 4;  // fp32 output
// Extra args
static const size_t kai_bl = 32;  // Quantization block length along K (the only supported value)

// Look-up table used for int4->int8 convert (index = unsigned 4-bit nibble,
// value = corresponding signed int; consumed by LUTI4 via ZT0 in the kernel).
static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
40
41 98 inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) {
42 98 return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs;
43 }
44
45 98 inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
46 KAI_ASSUME(bl == kai_bl);
47 98 size_t num_bytes_per_block_rhs = (bl / kai_recip_num_bytes_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
48 196 return num_bytes_per_block_rhs;
49 98 }
50
51 230 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
52 KAI_ASSUME(bl == kai_bl);
53 KAI_ASSUME((k % kai_bl) == 0);
54
55 230 return kai_roundup(k, bl) / bl;
56 }
57
// Byte stride between consecutive mr-row groups in the packed LHS buffer.
inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) {
    const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);
    const size_t block_bytes = kai_get_num_bytes_per_block_lhs(bl);
    const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    return mr * num_blocks * block_bytes;
}
62
63 98 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
64 KAI_ASSUME(bl == kai_bl);
65 KAI_ASSUME((k % kai_bl) == 0);
66
67 98 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
68 98 const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
69 98 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
70
71 98 size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row);
72
73 196 return rhs_packed_stride;
74 98 }
75
76 192 size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
77 192 return kai_m_step * kai_get_sme_vector_length_u32();
78 }
79
80 192 size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
81 192 return kai_n_step * kai_get_sme_vector_length_u32();
82 }
83
84 260 size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
85 260 return kai_mr * kai_get_sme_vector_length_u32();
86 }
87
88 260 size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
89 260 return kai_nr * kai_get_sme_vector_length_u32();
90 }
91
92 96 size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
93 96 return kai_kr;
94 }
95
96 64 size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
97 64 return kai_sr;
98 }
99
// Byte offset into the packed LHS buffer for the row group starting at m_idx.
// m_idx must be a multiple of m_step.
size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
    size_t m_idx, size_t k, size_t bl) {
    const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    KAI_ASSUME((m_idx % m_step) == 0);

    const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    const size_t row_group_idx = m_idx / mr;
    return row_group_idx * kai_get_lhs_packed_stride(k, bl);
}
108
// Byte offset into the packed RHS buffer for the column group starting at n_idx.
// n_idx must be a multiple of n_step.
size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
    size_t n_idx, size_t k, size_t bl) {
    const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    KAI_ASSUME((n_idx % n_step) == 0);

    const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    const size_t col_group_idx = n_idx / nr;
    return col_group_idx * kai_get_rhs_packed_stride(k, bl);
}
118
119 32 size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
120 size_t m_idx, size_t n_idx, size_t dst_stride) {
121 32 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
122 32 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
123 KAI_ASSUME((m_idx % m_step) == 0);
124 KAI_ASSUME((n_idx % n_step) == 0);
125
126 64 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
127 32 }
128
129 32 size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(size_t m, size_t n) {
130 32 return m * n * kai_num_bytes_dst_value;
131 }
132
// Runs the block-quantized matmul micro-kernel on the SME2 MOPA engine:
// DST (f32) = LHS (qsi8d32p packed) * RHS (qsi4c32p packed), accumulating
// per-bl-block partial products scaled by the fp16 LHS/RHS block scales.
// NOTE(review): scalar_min/scalar_max are accepted but explicitly unused below,
// so no min/max clamping is applied despite the "clamp" in the kernel name —
// confirm whether clamping is intended to happen elsewhere.
void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
    size_t m,                         //
    size_t n,                         //
    size_t k,                         //
    size_t bl,                        //
    const void* restrict lhs_packed,  //
    const void* restrict rhs_packed,  //
    float* restrict dst,              // NOLINT(readability-non-const-parameter)
    size_t dst_stride_row,            //
    size_t dst_stride_col,            //
    float scalar_min,                 //
    float scalar_max) {
    KAI_ASSUME(dst_stride_col == sizeof(float));

    KAI_UNUSED(scalar_min);
    KAI_UNUSED(scalar_max);

    if (m == 0) {
        return;
    }

    // Arguments passed to the inline assembly via a single base pointer;
    // field offsets are taken with offsetof() below.
    typedef struct {
        size_t lhs_packed_stride;
        size_t rhs_packed_stride;
        size_t mr;
    } KernelArgs;

    KernelArgs ka;

    const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);

    const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();

    ka.mr = mr;
    ka.lhs_packed_stride = kai_get_lhs_packed_stride(k, bl);
    ka.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl);

    // The fp16 block scales are packed at the tail of each stride-sized group;
    // back up from the end of the first group to locate them.
    const uint16_t* lhs_scales = (const uint16_t*)((const int8_t*)lhs_packed + ka.lhs_packed_stride -
                                                   (mr * num_blocks) * kai_num_bytes_multiplier_lhs);
    const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + ka.rhs_packed_stride -
                                                   (nr * num_blocks) * kai_num_bytes_multiplier_rhs);

    kai_commit_za();

    __asm__ volatile(
        // Switch to streaming mode with ZA enabling.
        // NOTE(review): SMSTART also zeroes the Z registers; the kernel relies on
        // z18 staying zero (it is never written) as the movprfx zero-source below.
        " .inst 0xd503477f // smstart \n"

        // Constants
        // - SVLs
        " cntw x14 \n"
        // - ptrue
        " ptrue p0.b, all \n"
        " .inst 0x25a07810 // ptrue pn8.s \n"

        // Predicate for loading fp16 scaling factors (mr halfwords -> mr*2 bytes)
        " ldr x5, [%x[args_ptr], %[offset_mr]]\n"
        " lsl x5, x5, #1 \n"
        " whilelt p4.b, xzr, x5 \n"

        // Initialize ZT0 (Lookup table for the int4->int8 conversion)
        " mov x6, %[lut]\n"
        " .inst 0xe11f80c0 // ldr zt0, [x6] \n"

        // Initialize the RHS packed and scale pointers
        " mov x16, %[rhs_packed] \n"
        " mov x17, %[rhs_scales] \n"

        // Iterate over n (x8)
        // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step)
        " mov x8, #0 \n"
        " mov x0, %[N] \n"
        " .inst 0x25a06511 // whilelt pn9.s, x8, x0, VLx4 \n"

        " b.none 9f // .LOOP_N_END%= \n"

        " 1: // .LOOP_N_START%=: \n"

        // Iterate over m (x9, counted down from M by one vector length per pass)
        // e.g. for(m_idx = 0; m_idx < m; m_idx+=m_step)
        " mov x9, %[M] \n"

        // Initialize the LHS packed and scale pointers
        " mov x22, %[lhs_packed] \n"
        " mov x23, %[lhs_scales] \n"

        // Initialize the DST pointer
        " mov x24, %[dst] \n"

        " 2: // .LOOP_M_START%=: \n"

        // Running offsets: x20 indexes the quantized data (LHS and RHS),
        // x21 indexes the fp16 scale arrays
        " mov x20, #0 \n"
        " mov x21, #0 \n"

        // Number of output rows to store -> min(SVLs, remaining M), scaled by 4
        // for the ZA row-quad store loop below
        " cmp x9, x14 \n"
        " csel x15, x9, x14, lo \n"
        " lsl x15, x15, #2 \n"

        // Iterate over all K values
        // e.g. for(k_idx = 0; k_idx < k; k_idx += bl)
        " mov x10, %[K] \n"

        // Skip processing if K=0
        " cmp x10, #0 \n"
        " b.eq 8f // .LOOP_K_END%= \n"

        " 3: // .LOOP_K_START%=: \n"

        // Zeroing of ZA accumulator
        " .inst 0xc00800ff // zero {za} \n"

        // Load the fp16 scaling factors for the right matrix block
        " .inst 0xa0154220 // ld1w {z0.s - z1.s}, pn8/z, [x17], x21, lsl #2] \n"
        " .inst 0xc161d000 // zip {z0.h - z1.h}, z0.h, z1.h \n"

        // Iterate over all values in the block
        // k_blk_idx = bl
        // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 4}
        " mov x11, #32\n"

        " 4: // .LOOP_BL_START%=: \n"

        // Load right matrix row (packed int4 data)
        " .inst 0xa0144202 // ld1w {z2.s - z3.s}, pn8/z, [x16], x20, lsl #2] \n"

        // Load left matrix column (int8 data)
        " ld1h {z8.h}, p0/z, [x22, x20, lsl #1] \n"
        " inch x20, all \n"

        // Convert Int4 -> Int8 via the ZT0 lookup table
        " .inst 0xc08a4044 // luti4 {z4.b - z5.b}, zt0, z2[0] \n"
        " .inst 0xc08a4066 // luti4 {z6.b - z7.b}, zt0, z3[0] \n"

        // Outer-products accumulating into the four ZA quadrants
        " .inst 0xa0840100 // smopa za0.s, p0/m, p0/m, z8.b, z4.b \n"
        " .inst 0xa0850101 // smopa za1.s, p0/m, p0/m, z8.b, z5.b \n"
        " .inst 0xa0860102 // smopa za2.s, p0/m, p0/m, z8.b, z6.b \n"
        " .inst 0xa0870103 // smopa za3.s, p0/m, p0/m, z8.b, z7.b \n"

        // Decrement the block loop index
        " subs x11, x11, #4 \n"

        " b.gt 4b // .LOOP_BL_START%= \n"

        // === End of the block loop ===

        // Store loop index
        " mov w12, #0 \n"

        // Copy destination pointer for store loop
        " mov x25, x24 \n"

        // Load the fp16 scaling factors for the left matrix block
        " ld1b {z16.b}, p4/z, [x23, x21] \n"
        " inch x21, all \n"

        // Predicate for the selection of a scaling among the vector
        " pfalse p3.b \n"

        " 5: // .LOOP_ZA%=: \n"

        // Select and replicate the per-row scaling factor for the left block
        // (z16 holds the LHS scales loaded above)
        " pnext p3.h, p0, p3.h \n"
        " clastb z19.h, p3, z19.h, z16.h \n"

        // Get data from za
        " .inst 0xc006041c // mova {z28.b-z31.b}, za0h.b[w12, 0:3] \n"
        " add w12, w12, #4 \n"

        // Convert from int32 to fp32
        " .inst 0xc132e39c // scvtf {z28.s-z31.s}, {z28.s-z31.s} \n"

        // Multiply left and right scaling factors (z18 is zero, so movprfx
        // zero-initializes each accumulator before the widening FMLAL)
        " movprfx z8, z18 \n"
        " fmlalb z8.s, z19.h, z0.h \n"
        " movprfx z9, z18 \n"
        " fmlalb z9.s, z19.h, z1.h \n"
        " movprfx z10, z18 \n"
        " fmlalt z10.s, z19.h, z0.h \n"
        " movprfx z11, z18 \n"
        " fmlalt z11.s, z19.h, z1.h \n"

        // First K block (x10 still equals K): scale and overwrite DST;
        // later blocks: accumulate into the partial results already in DST
        " cmp x10, %[K] \n"
        " b.ne 6f // .ACCUMULATE%= \n"

        // Applying combined scaling factors to processed block
        " fmul z24.s, z8.s, z28.s \n"
        " fmul z25.s, z9.s, z29.s \n"
        " fmul z26.s, z10.s, z30.s \n"
        " fmul z27.s, z11.s, z31.s \n"

        "b 7f // .STORE%= \n"

        " 6: // .ACCUMULATE%=: \n"
        // Load intermediate result
        " .inst 0xa040c738 // ld1w {z24.s-z27.s}, pn9/z, [x25] \n"

        // Scale the block results by LHS_SCALE x RHS_SCALE and add them
        // to the intermediate results
        " fmla z24.s, p0/m, z8.s, z28.s \n"
        " fmla z25.s, p0/m, z9.s, z29.s \n"
        " fmla z26.s, p0/m, z10.s, z30.s \n"
        " fmla z27.s, p0/m, z11.s, z31.s \n"

        "7: // .STORE%=: \n"
        // Store the results into memory
        " .inst 0xa060c738 // st1w {z24.s-z27.s}, pn9, [x25] \n"
        " add x25, x25, %[stride] \n"

        " cmp x12, x15 \n"
        " blt 5b // .LOOP_ZA%= \n"

        // Decrement K loop index by bl
        " subs x10, x10, #32 \n"

        " b.gt 3b // .LOOP_K_START%= \n"

        " 8: // .LOOP_K_END%=: \n"

        // === End of the K loop ===

        " ldr x5, [%x[args_ptr], %[offset_stride_l]] \n"

        // Increment pointer to the quantized values of the left matrix
        " add x22, x22, x5\n"

        // Increment pointer to the scaling factors of the left matrix
        // (scales sit inside each stride-sized group, so the same stride applies)
        " add x23, x23, x5 \n"

        // Update destination pointer
        " mov x24, x25 \n"

        // Decrement M loop index
        " decw x9, all \n"

        " cmp x9, #0 \n"
        " b.gt 2b // .LOOP_M_START%= \n"

        // === End of M loop ===

        // Increment output pointer
        " incb %[dst], all, mul #4 \n"

        " ldr x5, [%x[args_ptr], %[offset_stride_r]]\n"

        " add x16, x16, x5 \n"
        " add x17, x17, x5 \n"

        // Increment N loop index
        " incb x8, all \n"

        " .inst 0x25a06511 // whilelt pn9.s, x8, %[N], VLx4 \n"

        " b.first 1b // .LOOP_N_START%= \n"

        " 9: // .LOOP_N_END%=: \n"

        // === End of N loop ===

        // Exit streaming mode
        " .inst 0xd503467f // smstop \n"
        : [dst] "+r"(dst), [rhs_packed] "+r"(rhs_packed), [rhs_scales] "+r"(rhs_scales)
        : [M] "r"(m), [N] "r"(n), [K] "r"(k), [lhs_packed] "r"(lhs_packed), [lhs_scales] "r"(lhs_scales),
          [stride] "r"(dst_stride_row), [lut] "r"(lut), [args_ptr] "r"(&ka),
          [offset_stride_l] "I"(offsetof(KernelArgs, lhs_packed_stride)),
          [offset_stride_r] "I"(offsetof(KernelArgs, rhs_packed_stride)), [offset_mr] "I"(offsetof(KernelArgs, mr))
        : "p0", "p1", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", "z1",
          "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18",
          "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x5", "x6",
          "x8", "x9", "x10", "x11", "x12", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25",
          "memory", "cc");
}
408
409 #endif // Architectural features check.
410