KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
Date: 2025-10-20 13:18:31
            Coverage   Exec   Excl   Total
Lines:         98.5%     67     11      79
Functions:    100.0%     16      0      16
Branches:      50.0%      1     22      24

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2))
8 #error This file must be compiled for AArch64 with FEAT_SVE2 or FEAT_SME2.
9 #else // Architectural features check.
10
11 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 // Compute args
19 static const size_t kai_m_step = 1;
20 static const size_t kai_n_step = 4; // Multiple of vector length
21 // Packing args
22 static const size_t kai_mr = 1;
23 static const size_t kai_nr = 4; // Multiple of vector length
24 static const size_t kai_kr = 4;
25 static const size_t kai_sr = 2;
26 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
27 static const size_t kai_num_bytes_qvalue_lhs = 1;
28 static const size_t kai_num_bytes_multiplier_lhs = 2;
29 // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
30 // asymmetric))
31 static const size_t kai_recip_num_bytes_qvalue_rhs = 2; // Two int4 values per byte
32 static const size_t kai_num_bytes_multiplier_rhs = 2;
33 // DST format args
34 static const size_t kai_num_bytes_dst_value = 4;
35 // Extra args
36 static const size_t kai_bl = 32;
37
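
With bl = 32, the format args above give fixed per-block sizes; a small worked example (an illustrative sketch, not part of the kernel source):

    #include <assert.h>
    #include <stddef.h>

    int main(void) {
        const size_t bl = 32;
        // LHS block: 32 int8 quantized values + one fp16 scale.
        const size_t lhs_block = (bl * 1) + 2;  // 34 bytes
        // RHS block: 32 int4 values packed two per byte + one fp16 scale.
        const size_t rhs_block = (bl / 2) + 2;  // 18 bytes
        assert(lhs_block == 34);
        assert(rhs_block == 18);
        return 0;
    }
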
38 // Look-up table used for int4->int8 convert
39 static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
40
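The luti4 instructions in the kernel below apply this table in hardware: each 4-bit index i (0..15) maps to the signed value i - 8. A scalar sketch of the same decode (the nibble order is an assumption here; the real layout is fixed by the RHS packing routine):

    #include <stdint.h>

    // Same mapping as the int32_t table above, held as int8.
    static const int8_t lut8[16] = {-8, -7, -6, -5, -4, -3, -2, -1,
                                     0,  1,  2,  3,  4,  5,  6,  7};

    // Decode one packed byte into two int8 values (assumes low nibble first).
    static void decode_qsi4(uint8_t packed, int8_t out[2]) {
        out[0] = lut8[packed & 0x0F];
        out[1] = lut8[packed >> 4];
    }
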
41 30 inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) {
42 30 return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs;
43 }
44
45 30 inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
46 KAI_ASSUME(bl == kai_bl);
47 30 size_t num_bytes_per_block_rhs = (bl / kai_recip_num_bytes_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
48 60 return num_bytes_per_block_rhs;
49 30 }
50
51 64 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
52 KAI_ASSUME(bl == kai_bl);
53 KAI_ASSUME((k % kai_bl) == 0);
54
55 64 return kai_roundup(k, bl) / bl;
56 }
57
58 30 inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) {
59 30 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
60 60 return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl);
61 30 }
62
63 30 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
64 KAI_ASSUME(bl == kai_bl);
65 KAI_ASSUME((k % kai_bl) == 0);
66
67 30 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
68 30 const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
69 30 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
70
71 30 size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row);
72
73 60 return rhs_packed_stride;
74 30 }
75
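Putting the stride helpers together, a worked example (values are illustrative; nr depends on the runtime vector length): for k = 256 and bl = 32 there are 256 / 32 = 8 blocks per row, so with mr = 1 the LHS packed stride is 1 * 8 * 34 = 272 bytes, and on 512-bit SVL hardware (nr = 4 * 16 = 64) the RHS packed stride is 64 * 8 * 18 = 9216 bytes.
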
76 81 size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
77 81 return kai_m_step;
78 }
79
80 81 size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
81 81 return kai_n_step * kai_get_sme_vector_length_u32();
82 }
83
84 108 size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
85 108 return kai_mr;
86 }
87
88 108 size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
89 108 return kai_nr * kai_get_sme_vector_length_u32();
90 }
91
92 72 size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
93 72 return kai_kr;
94 }
95
96 48 size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
97 48 return kai_sr;
98 }
99
100 26 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(
101 size_t m_idx, size_t k, size_t bl) {
102 26 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
103 26 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
104
105 KAI_ASSUME((m_idx % m_step) == 0);
106
107 52 return (m_idx / mr) * kai_get_lhs_packed_stride(k, bl);
108 26 }
109
110 26 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(
111 size_t n_idx, size_t k, size_t bl) {
112 26 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
113 26 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
114
115 KAI_ASSUME((n_idx % nr) == 0);
116
117 52 return (n_idx / n_step) * kai_get_rhs_packed_stride(k, bl);
118 26 }
119
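Callers typically use these getters to shard the output columns, e.g. across threads; a hedged usage sketch (n_idx must be a multiple of n_step, and lhs_packed/rhs_packed come from the matching packing kernels):

    // Locate the RHS panel for the shard starting at column n_idx.
    const size_t rhs_offset =
        kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(n_idx, k, bl);
    const void* rhs_tile = (const uint8_t*)rhs_packed + rhs_offset;
    // rhs_tile, together with lhs_packed and the shard's column count,
    // is then handed to the kai_run_..._sme2_sdot entry point below.
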
120 3 size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(
121 size_t m_idx, size_t n_idx, size_t dst_stride) {
122 3 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
123 3 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
124 KAI_ASSUME((m_idx % m_step) == 0);
125 KAI_ASSUME((n_idx % n_step) == 0);
126
127 6 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
128 3 }
129
130 3 size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(size_t m, size_t n) {
131 3 return m * n * kai_num_bytes_dst_value;
132 }
133
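The DST offset formula assumes a row-major fp32 output where dst_stride is the row stride in bytes; a minimal sketch of the equivalent computation:

    #include <stddef.h>

    // Byte offset of element (m_idx, n_idx) in a row-major f32 matrix
    // with n columns, where dst_stride = n * sizeof(float).
    static size_t dst_offset(size_t m_idx, size_t n_idx, size_t n) {
        return (n_idx * sizeof(float)) + (m_idx * n * sizeof(float));
    }
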
134 4 void kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(
135 size_t m, //
136 size_t n, //
137 size_t k, //
138 size_t bl, //
139 const void* restrict lhs_packed, //
140 const void* restrict rhs_packed, //
141 float* restrict dst, // NOLINT(readability-non-const-parameter)
142 size_t dst_stride_row, //
143 size_t dst_stride_col, //
144 float scalar_min, //
145 float scalar_max) {
146 KAI_ASSUME(dst_stride_col == sizeof(float));
147 KAI_ASSUME(m == 1);
148
149 4 KAI_UNUSED(dst_stride_row);
150 4 KAI_UNUSED(scalar_min);
151 4 KAI_UNUSED(scalar_max);
152
153 1/2 4 if (m == 0) {
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
154 return;
155 }
156
157 4 const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, bl);
158 4 const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, bl);
159 4 const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);
160
161 4 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
162 4 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
163
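// The fp16 block scales are stored at the tail of each packed panel,
// so each scale pointer is panel base + panel stride - total scale bytes.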
164 8 const uint16_t* lhs_scales = (const uint16_t*)((const int8_t*)lhs_packed + lhs_packed_stride -
165 4 (mr * num_blocks) * kai_num_bytes_multiplier_lhs);
166 8 const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + rhs_packed_stride -
167 4 (nr * num_blocks) * kai_num_bytes_multiplier_rhs);
168
169 4 __asm__ volatile(
170 // Switch to streaming mode with ZA enabling
171 " .inst 0xd503477f // smstart \n"
172
173 " ptrue p2.b, all \n"
174 " .inst 0x25607810 // ptrue pn8.h \n"
175
176 " fmov z28.s, #0.0 \n"
177
178 // Initialize ZT0 (Lookup table)
179 " mov x9, %[lut] \n"
180 " .inst 0xe11f8120 // ldr zt0, [x9] \n"
181
182 // Initialize the RHS packed and scale pointers
183 " mov x0, %[rhs_packed] \n"
184 " mov x1, %[rhs_scales] \n"
185
186 // Initialize the DST pointer
187 " mov x5, %[dst] \n"
188
189 // Iterate over n (x0)
190 // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step)
191 " mov x4, #0\n"
192 " mov x17, %[n] \n"
193 " .inst 0x25b16491 // whilelt pn9.s, x4, x17, VLx4 \n"
194
195 " b.none 5f // .LOOP_N_END%= \n"
196
197 " 1: // .LOOP_N_START%=: \n"
198
199 // Initialize the LHS packed and scale pointers
200 " mov x2, %[lhs_packed] \n"
201 " mov x3, %[lhs_scales] \n"
202
203 // Initialize the 4xVL-32bit accumulators to zero
204 " dup z24.s, #0 \n"
205 " dup z25.s, #0 \n"
206 " dup z26.s, #0 \n"
207 " dup z27.s, #0 \n"
208
209 // Initialize the vector selector for ZA array
210 " mov w8, #0 \n"
211
212 // Iterate over all K values
213 // e.g. for(k_idx = 0; k_idx < k; k_idx += bl)
214 " mov x6, #0 \n"
215 " whilelt p1.s, x6, %[k] \n"
216 " b.none 4f // .LOOP_K_END%= \n"
217
218 " 2: // .LOOP_K_START%=: \n"
219 // Zeroing of inner accumulation array
220 " .inst 0xc00800ff // zero {za} \n"
221
222 // Iterate over all values in the block
223 // k_blk_idx = bl
224 // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 16}
225
226 "mov x13, %[bl] \n"
227
228 "3: // .LOOP_BL_START%=: \n"
229 // Load the LHS (int8) quantized values
230 // Load contiguous 16 bytes and replicate.
231 // For GeMV, we do not interleave the LHS M rows.
232 " ld1rqb { z0.b }, p2/z , [x2] \n"
233 " add x2, x2, #16 \n"
234
235 // -- First half
236 // Load the RHS (int4) quantized values
237 " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0] \n"
238
239 // Increment the RHS pointer
240 " addvl x0, x0, #4 \n"
241
242 // Convert Int4 -> Int8
243 " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n"
244 " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n"
245 " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n"
246 " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n"
247
248 // SDOT indexed
249 " .inst 0xc15090a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[0] \n"
250 " .inst 0xc1509520 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[1] \n"
251
252 // -- Second half
253
254 // Load the RHS (int4) quantized values
255 " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0]\n"
256
257 // Increment the RHS pointer
258 " addvl x0, x0, #4 \n"
259
260 // Convert Int4 -> Int8
261 " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n"
262 " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n"
263 " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n"
264 " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n"
265
266 // SDOT indexed
267 " .inst 0xc15098a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[2] \n"
268 " .inst 0xc1509d20 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[3] \n"
269
270 // Decrement the block loop index
271 "subs x13, x13, #16 \n"
272
273 "b.gt 3b // .LOOP_BL_START%= \n"
274
275 // === End of the block loop ===
276
277 // Load Z registers with intermediate values from ZA array
278 " .inst 0xc0060c10 // mova {z16.s - z19.s}, za.s[w8, 0, vgx4] \n"
279
280 // Convert from int32 to float32
281 " .inst 0xc132e210 // scvtf{z16.s - z19.s}, {z16.s - z19.s} \n"
282
283 // Load 1 fp16 LHS scale scalar value and replicate for VL-16-bit
284 " ld1rh z1.h, p2/z, [x3] \n"
285
286 // Increment the LHS scale pointer by 2 (1 x sizeof(fp16))
287 " add x3, x3, #2 \n"
288
289 // Load 2xVL-16bit (fp16) RHS scales.
290 // If VL=512bit, we load 64 fp16 values, which is equal to the number of output columns (n_step) processed
291 " .inst 0xa0402024 // ld1h { z4.h - z5.h }, pn8/z, [x1]\n"
292
293 // Increment the RHS scale pointer
294 " addvl x1, x1, #2 \n"
295
296 // Combine all the LHS and RHS scales
297 " .inst 0xc165d082 //zip { z2.h-z3.h }, z4.h, z5.h\n"
298 " movprfx z4, z28 \n"
299
300 // Multiply two half floating-point vectors and store the result
301 // to a floating-point 32-bit vector
302 " fmlalb z4.s, z1.h, z2.h\n"
303 " movprfx z5, z28 \n"
304 " fmlalb z5.s, z1.h, z3.h\n"
305 " movprfx z6, z28 \n"
306 " fmlalt z6.s, z1.h, z2.h\n"
307 " movprfx z7, z28 \n"
308 " fmlalt z7.s, z1.h, z3.h\n"
309
310 // Multiply the intermediate results by LHS_SCALE x RHS_SCALE
311 // and store in the main floating-point accumulator
312 " fmla z24.s, p2/m, z16.s, z4.s \n"
313 " fmla z25.s, p2/m, z17.s, z5.s \n"
314 " fmla z26.s, p2/m, z18.s, z6.s \n"
315 " fmla z27.s, p2/m, z19.s, z7.s \n"
316
317 // Increment the number of K values processed and
318 // go to the next block
319 " add x6, x6, %[bl] \n"
320 " whilelt p1.s, x6, %[k] \n"
321 " b.first 2b // .LOOP_K_START%= \n"
322 " 4: //.LOOP_K_END%=: \n"
323
324 // Store the results into memory
325 " .inst 0xa060c4b8 // st1w { z24.s-z27.s }, pn9, [x5] \n"
326 " incb x4, all \n"
327 " addvl x5, x5, #4 \n"
328
329 // The new rhs_packed pointer is the current rhs_scales pointer
330 // The new rhs_scales pointer is the current rhs_packed plus the rhs_packed_stride
331 " mov x7, x0 \n"
332
333 // Initialize the rhs_packed pointer
334 " mov x0, x1 \n"
335
336 // Initialize the rhs_scales pointer
337 " add x1, x7, %[rhs_packed_stride] \n"
338
339 " .inst 0x25b16491 // whilelt pn9.s, x4, %[n], VLx4 \n"
340
341 " b.first 1b // .LOOP_N_START%= \n"
342
343 " 5: // .LOOP_N_END%=: \n"
344
345 // Exit streaming mode
346 " .inst 0xd503467f //smstop \n"
347 :
348 4 : [lut] "r"(lut), [dst] "r"(dst), [rhs_packed] "r"(rhs_packed), [rhs_scales] "r"(rhs_scales),
349 4 [lhs_packed] "r"(lhs_packed), [lhs_scales] "r"(lhs_scales), [rhs_packed_stride] "r"(rhs_packed_stride),
350 4 [n] "r"((int64_t)n), [k] "r"(k), [bl] "i"(kai_bl)
351 : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0",
352 "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17",
353 "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x1",
354 "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x13", "x17", "memory", "cc");
355 4 }
356
357 #endif // Architectural features check.
358
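For reference, the arithmetic the assembly implements (an m == 1 GeMV over 32-value quantization blocks) can be written as a scalar C model, reconstructed from the comments above. It operates on unpacked operands with illustrative names; the real kernel reads the tiled packed layouts, keeps the integer accumulator in the ZA array, widens the fp16 scales with fmlalb/fmlalt, and leaves scalar_min/scalar_max unused:

    #include <stddef.h>
    #include <stdint.h>

    // Scalar model of the kernel's math. Assumptions: int4 RHS values are
    // already widened to int8, and block scales are passed as fp32.
    static void gemv_qsi8d32_qsi4c32_ref(
        size_t n, size_t k, size_t bl,
        const int8_t* lhs_q,      // k int8 LHS values
        const float* lhs_scales,  // k / bl scales, one per LHS block
        const int8_t* rhs_q,      // k * n RHS values, row-major
        const float* rhs_scales,  // (k / bl) * n scales, one per block per column
        float* dst) {             // n fp32 outputs
        const size_t num_blocks = k / bl;
        for (size_t col = 0; col < n; ++col) {         // .LOOP_N
            float acc = 0.0f;
            for (size_t b = 0; b < num_blocks; ++b) {  // .LOOP_K
                int32_t iacc = 0;                      // role of the ZA array
                for (size_t i = 0; i < bl; ++i) {      // .LOOP_BL (sdot)
                    const size_t row = (b * bl) + i;
                    iacc += (int32_t)lhs_q[row] * (int32_t)rhs_q[(row * n) + col];
                }
                // scvtf + scale combine + fmla in the kernel
                acc += (float)iacc * lhs_scales[b] * rhs_scales[(b * n) + col];
            }
            dst[col] = acc;  // st1w; the clamp bounds are not applied
        }
    }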