KleidiAI Coverage Report

Directory: ./
Coverage legend: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%
            Coverage   Exec / Excl / Total
Lines:      98.6%      68 / 11 / 80
Functions:  100.0%     16 / 0 / 16
Branches:   50.0%      1 / 22 / 24
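(Percentages are computed over the non-excluded entries: Lines = 68 / (80 − 11) ≈ 98.6%; Functions = 16 / 16 = 100.0%; Branches = 1 / (24 − 22) = 50.0%.)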

kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2))
8 #error This file must be compiled for AArch64 with FEAT_SVE2 or FEAT_SME2 enabled.
9 #else // Architectural features check.
10
11 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 // Compute args
19 static const size_t kai_m_step = 1;
20 static const size_t kai_n_step = 4; // Multiple of vector length
21 // Packing args
22 static const size_t kai_mr = 1;
23 static const size_t kai_nr = 4; // Multiple of vector length
24 static const size_t kai_kr = 4;
25 static const size_t kai_sr = 2;
26 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
27 static const size_t kai_num_bytes_qvalue_lhs = 1;
28 static const size_t kai_num_bytes_multiplier_lhs = 2;
29 // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
30 // asymmetric))
31 static const size_t kai_recip_num_bytes_qvalue_rhs = 2;
32 static const size_t kai_num_bytes_multiplier_rhs = 2;
33 // DST format args
34 static const size_t kai_num_bytes_dst_value = 4;
35 // Extra args
36 static const size_t kai_bl = 32;
37
38 // Look-up table used for int4->int8 conversion
39 static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
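(The table maps a 4-bit index v to v − 8: the packed int4 operands are effectively stored offset by 8, and the LUTI4 instructions below recover the signed int8 values.)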
40
41 58 inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) {
42 58 return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs;
43 }
44
45 58 inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
46 KAI_ASSUME(bl == kai_bl);
47 58 size_t num_bytes_per_block_rhs = (bl / kai_recip_num_bytes_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
48 116 return num_bytes_per_block_rhs;
49 58 }
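(Worked example with bl = 32: an LHS block takes 32 * 1 + 2 = 34 bytes, i.e. 32 int8 values plus one fp16 scale; an RHS block takes 32 / 2 + 2 = 18 bytes, i.e. 16 bytes of packed int4 values plus one fp16 scale.)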
50
51 130 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
52 KAI_ASSUME(bl == kai_bl);
53 KAI_ASSUME((k % kai_bl) == 0);
54
55 130 return kai_roundup(k, bl) / bl;
56 }
57
58 58 inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) {
59 58 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
60 116 return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl);
61 58 }
62
63 58 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
64 KAI_ASSUME(bl == kai_bl);
65 KAI_ASSUME((k % kai_bl) == 0);
66
67 58 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
68 58 const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
69 58 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
70
71 58 size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row);
72
73 116 return rhs_packed_stride;
74 58 }
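(Worked example, assuming a 512-bit streaming vector length so that nr = 4 * 16 = 64: with k = 256 and bl = 32 there are 8 blocks per row, giving an LHS packed stride of 1 * 8 * 34 = 272 bytes and an RHS packed stride of 64 * (18 * 8) = 9216 bytes.)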
75
76 132 size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
77 132 return kai_m_step;
78 }
79
80 132 size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
81 132 return kai_n_step * kai_get_sme_vector_length_u32();
82 }
83
84 180 size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
85 180 return kai_mr;
86 }
87
88 180 size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
89 180 return kai_nr * kai_get_sme_vector_length_u32();
90 }
91
92 96 size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
93 96 return kai_kr;
94 }
95
96 64 size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) {
97 64 return kai_sr;
98 }
99
100 44 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(
101 size_t m_idx, size_t k, size_t bl) {
102 44 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
103 44 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
104
105 KAI_ASSUME((m_idx % m_step) == 0);
106
107 88 return (m_idx / mr) * kai_get_lhs_packed_stride(k, bl);
108 44 }
109
110 44 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(
111 size_t n_idx, size_t k, size_t bl) {
112 44 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
113 44 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
114
115 KAI_ASSUME((n_idx % nr) == 0);
116
117 88 return (n_idx / n_step) * kai_get_rhs_packed_stride(k, bl);
118 44 }
119
120 12 size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(
121 size_t m_idx, size_t n_idx, size_t dst_stride) {
122 12 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
123 12 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
124 KAI_ASSUME((m_idx % m_step) == 0);
125 KAI_ASSUME((n_idx % n_step) == 0);
126
127 24 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
128 12 }
129
130 12 size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(size_t m, size_t n) {
131 12 return m * n * kai_num_bytes_dst_value;
132 }
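(With the same assumed 512-bit vector length, m_step = 1 and n_step = 4 * 16 = 64; a GeMV with m = 1 and n = 64 therefore needs 1 * 64 * 4 = 256 bytes of f32 output, and the tile at n_idx = 64 starts at byte offset 64 * 4 = 256 within a row.)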
133
134 14 void kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(
135 size_t m, //
136 size_t n, //
137 size_t k, //
138 size_t bl, //
139 const void* restrict lhs_packed, //
140 const void* restrict rhs_packed, //
141 float* restrict dst, // NOLINT(readability-non-const-parameter)
142 size_t dst_stride_row, //
143 size_t dst_stride_col, //
144 float scalar_min, //
145 float scalar_max) {
146 KAI_ASSUME(dst_stride_col == sizeof(float));
147 KAI_ASSUME(m == 1);
148
149 14 KAI_UNUSED(dst_stride_row);
150 14 KAI_UNUSED(scalar_min);
151 14 KAI_UNUSED(scalar_max);
152
153
153 1/2 14 if (m == 0) {
✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
154 return;
155 }
156
157 14 const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, bl);
158 14 const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, bl);
159 14 const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);
160
161 14 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
162 14 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot();
163
164 28 const uint16_t* lhs_scales = (const uint16_t*)((const int8_t*)lhs_packed + lhs_packed_stride -
165 14 (mr * num_blocks) * kai_num_bytes_multiplier_lhs);
166 28 const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + rhs_packed_stride -
167 14 (nr * num_blocks) * kai_num_bytes_multiplier_rhs);
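(In both packed layouts, each mr- or nr-wide panel stores its quantized block values first and the per-block fp16 scales at the tail, so the scale pointers are derived as panel base + panel stride − size of the scales.)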
168
169 14 kai_commit_za();
170
171 14 __asm__ volatile(
172 // Switch to streaming mode with ZA enabling
173 " .inst 0xd503477f // smstart \n"
174
175 " ptrue p2.b, all \n"
176 " .inst 0x25607810 // ptrue pn8.h \n"
177
178 " fmov z28.s, #0.0 \n"
179
180 // Initialize ZT0 (Lookup table)
181 " mov x9, %[lut] \n"
182 " .inst 0xe11f8120 // ldr zt0, [x9] \n"
183
184 // Initialize the RHS packed and scale pointers
185 " mov x0, %[rhs_packed] \n"
186 " mov x1, %[rhs_scales] \n"
187
188 // Initialize the DST pointer
189 " mov x5, %[dst] \n"
190
191 // Iterate over n (x0)
192 // e.g. for(n_idx = 0; n_idx < n; n_idx += n_step)
193 " mov x4, #0\n"
194 " mov x17, %[n] \n"
195 " .inst 0x25b16491 // whilelt pn9.s, x4, x17, VLx4 \n"
196
197 " b.none 5f // .LOOP_N_END%= \n"
198
199 " 1: // .LOOP_N_START%=: \n"
200
201 // Initialize the LHS packed and scale pointers
202 " mov x2, %[lhs_packed] \n"
203 " mov x3, %[lhs_scales] \n"
204
205 // Initialize the 4xVL-32bit accumulators to zero
206 " dup z24.s, #0 \n"
207 " dup z25.s, #0 \n"
208 " dup z26.s, #0 \n"
209 " dup z27.s, #0 \n"
210
211 // Initialize the vector selector for ZA array
212 " mov w8, #0 \n"
213
214 // Iterate over all K values
215 // e.g. for(k_idx = 0; k_idx < k; k_idx += bl)
216 " mov x6, #0 \n"
217 " whilelt p1.s, x6, %[k] \n"
218 " b.none 4f // .LOOP_K_END%= \n"
219
220 " 2: // .LOOP_K_START%=: \n"
221 // Zeroing of inner accumulation array
222 " .inst 0xc00800ff // zero {za} \n"
223
224 // Iterate over all values in the block
225 // k_blk_idx = bl
226 // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 16}
227
228 "mov x13, %[bl] \n"
229
230 "3: // .LOOP_BL_START%=: \n"
231 // Load the LHS (int8) quantized values
232 // Load contiguous 16 bytes and replicate.
233 // For GeMV, we do not interleave the LHS M rows.
234 " ld1rqb { z0.b }, p2/z , [x2] \n"
235 " add x2, x2, #16 \n"
236
237 // -- First half
238 // Load the RHS (int4) quantized values
239 " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0] \n"
240
241 // Increment the RHS pointer
242 " addvl x0, x0, #4 \n"
243
244 // Convert Int4 -> Int8
245 " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n"
246 " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n"
247 " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n"
248 " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n"
249
250 // SDOT indexed
251 " .inst 0xc15090a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[0] \n"
252 " .inst 0xc1509520 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[1] \n"
253
254 // -- Second half
255
256 // Load the RHS (int4) quantized values
257 " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0]\n"
258
259 // Increment the RHS pointer
260 " addvl x0, x0, #4 \n"
261
262 // Convert Int4 -> Int8
263 " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n"
264 " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n"
265 " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n"
266 " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n"
267
268 // SDOT indexed
269 " .inst 0xc15098a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[2] \n"
270 " .inst 0xc1509d20 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[3] \n"
271
272 // Decrement the block loop index
273 "subs x13, x13, #16 \n"
274
275 "b.gt 3b // .LOOP_BL_START%= \n"
276
277 // === End of the block loop ===
278
279 // Load Z registers with intermediate values from ZA array
280 " .inst 0xc0060c10 // mova {z16.s - z19.s}, za.s[w8, 0, vgx4] \n"
281
282 // Convert from int32 to float32
283 " .inst 0xc132e210 // scvtf{z16.s - z19.s}, {z16.s - z19.s} \n"
284
285 // Load 1 fp16 LHS scale scalar value and replicate for VL-16-bit
286 " ld1rh z1.h, p2/z, [x3] \n"
287
288 // Increment the LHS scale pointer by 2 (1 x sizeof(fp16))
289 " add x3, x3, #2 \n"
290
291 // Load 2xVL-16bit (fp16) RHS scales.
292 // If VL=512bit, we load 64 fp16 values, which is equal to the number of output columns (n_step) processed
293 " .inst 0xa0402024 // ld1h { z4.h - z5.h }, pn8/z, [x1]\n"
294
295 // Increment the RHS scale pointer
296 " addvl x1, x1, #2 \n"
297
298 // Combine all the LHS and RHS scales
299 " .inst 0xc165d082 //zip { z2.h-z3.h }, z4.h, z5.h\n"
300 " movprfx z4, z28 \n"
301
302 // Multiply two half floating-point vectors and store the result
303 // to a floating-point 32-bit vector
304 " fmlalb z4.s, z1.h, z2.h\n"
305 " movprfx z5, z28 \n"
306 " fmlalb z5.s, z1.h, z3.h\n"
307 " movprfx z6, z28 \n"
308 " fmlalt z6.s, z1.h, z2.h\n"
309 " movprfx z7, z28 \n"
310 " fmlalt z7.s, z1.h, z3.h\n"
311
312 // Multiply the intermediate results by LHS_SCALE x RHS_SCALE
313 // and store in the main floating-point accumulator
314 " fmla z24.s, p2/m, z16.s, z4.s \n"
315 " fmla z25.s, p2/m, z17.s, z5.s \n"
316 " fmla z26.s, p2/m, z18.s, z6.s \n"
317 " fmla z27.s, p2/m, z19.s, z7.s \n"
318
319 // Increment the number of K values processed and
320 // go to the next block
321 " add x6, x6, %[bl] \n"
322 " whilelt p1.s, x6, %[k] \n"
323 " b.first 2b // .LOOP_K_START%= \n"
324 " 4: //.LOOP_K_END%=: \n"
325
326 // Store the results into memory
327 " .inst 0xa060c4b8 // st1w { z24.s-z27.s }, pn9, [x5] \n"
328 " incb x4, all \n"
329 " addvl x5, x5, #4 \n"
330
331 // The new rhs_packed pointer is the current rhs_scales pointer
332 // The new rhs_scales pointer is the current rhs_packed plus the rhs_packed_stride
333 " mov x7, x0 \n"
334
335 // Initialize the rhs_packed pointer
336 " mov x0, x1 \n"
337
338 // Initialize the rhs_scales pointer
339 " add x1, x7, %[rhs_packed_stride] \n"
340
341 " .inst 0x25b16491 // whilelt pn9.s, x4, %[n], VLx4 \n"
342
343 " b.first 1b // .LOOP_N_START%= \n"
344
345 " 5: // .LOOP_N_END%=: \n"
346
347 // Exit streaming mode
348 " .inst 0xd503467f //smstop \n"
349 :
350 14 : [lut] "r"(lut), [dst] "r"(dst), [rhs_packed] "r"(rhs_packed), [rhs_scales] "r"(rhs_scales),
351 14 [lhs_packed] "r"(lhs_packed), [lhs_scales] "r"(lhs_scales), [rhs_packed_stride] "r"(rhs_packed_stride),
352 14 [n] "r"((int64_t)n), [k] "r"(k), [bl] "i"(kai_bl)
353 : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0",
354 "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17",
355 "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x1",
356 "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x13", "x17", "memory", "cc");
357 14 }
358
359 #endif // Architectural features check.
360
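For reference, a minimal calling sketch for the micro-kernel covered above (C, not part of the generated report). It assumes the lhs_packed/rhs_packed buffers were produced by the matching qsi8d32p / qsi4c32p packing micro-kernels (not shown), that k is a multiple of the 32-element block length, and that run_gemv_f32 and its parameter names are hypothetical; only functions declared in this file's header are called.

#include <float.h>
#include <stddef.h>
#include <stdint.h>

#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"

// Hypothetical single-call GeMV driver. This micro-kernel requires m == 1
// and k % 32 == 0; the packed buffers must come from the matching packing
// micro-kernels.
static void run_gemv_f32(
    size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst) {
    const size_t bl = 32;  // fixed block length (kai_bl)
    const size_t dst_stride_row = n * sizeof(float);

    // With m_idx == n_idx == 0 these offsets are 0; a tiled caller would
    // advance m_idx/n_idx in multiples of m_step/n_step instead.
    const size_t lhs_offset =
        kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(0, k, bl);
    const size_t rhs_offset =
        kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(0, k, bl);
    const size_t dst_offset =
        kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(0, 0, dst_stride_row);

    kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(
        /* m */ 1, n, k, bl,
        (const uint8_t*)lhs_packed + lhs_offset,
        (const uint8_t*)rhs_packed + rhs_offset,
        (float*)((uint8_t*)dst + dst_offset),
        dst_stride_row,
        /* dst_stride_col */ sizeof(float),
        -FLT_MAX, FLT_MAX);  // clamp bounds; unused by this m == 1 path
}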