Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) | ||
8 | #error This file must be compiled for AArch64, FEAT_SVE2 or FEAT_SME2. | ||
9 | #else // Architectural features check. | ||
10 | |||
11 | #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" | ||
12 | |||
13 | #include <stddef.h> | ||
14 | #include <stdint.h> | ||
15 | |||
16 | #include "kai/kai_common.h" | ||
17 | |||
// Compute args
// The step getters below multiply these factors by kai_get_sme_vector_length_u32().
static const size_t kai_m_step = 1; // Multiple of vector length
static const size_t kai_n_step = 4; // Multiple of vector length
// Packing args
// The mr/nr getters below multiply these factors by kai_get_sme_vector_length_u32(); kr and sr are returned as-is.
static const size_t kai_mr = 1; // Multiple of vector length
static const size_t kai_nr = 4; // Multiple of vector length
static const size_t kai_kr = 4;
static const size_t kai_sr = 2;
// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
static const size_t kai_num_bytes_qvalue_lhs = 1;
static const size_t kai_num_bytes_multiplier_lhs = 2;
// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
// asymmetric))
// The RHS value size is stored as a reciprocal: the block size helper divides the block length by this constant,
// i.e. each byte carries two RHS values.
static const size_t kai_recip_num_bytes_qvalue_rhs = 2;
static const size_t kai_num_bytes_multiplier_rhs = 2;
// DST format args
static const size_t kai_num_bytes_dst_value = 4;
// Extra args
// Block length (number of K values per quantization block). The helpers below assert bl == kai_bl and the
// assembly in the run function hard-codes the same value (32).
static const size_t kai_bl = 32;

// Look-up table used for int4->int8 convert
// Entry i holds the value (i - 8). The run function loads this table into ZT0.
static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
40 | |||
41 | 70 | inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { | |
42 | 70 | return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs; | |
43 | } | ||
44 | |||
45 | 70 | inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { | |
46 | − | KAI_ASSUME(bl == kai_bl); | |
47 | 70 | size_t num_bytes_per_block_rhs = (bl / kai_recip_num_bytes_qvalue_rhs) + kai_num_bytes_multiplier_rhs; | |
48 | 140 | return num_bytes_per_block_rhs; | |
49 | 70 | } | |
50 | |||
51 | 164 | inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { | |
52 | − | KAI_ASSUME(bl == kai_bl); | |
53 | − | KAI_ASSUME((k % kai_bl) == 0); | |
54 | |||
55 | 164 | return kai_roundup(k, bl) / bl; | |
56 | } | ||
57 | |||
// Size in bytes of one packed LHS stripe: mr rows, each made of (blocks per row) LHS blocks.
static inline size_t kai_get_lhs_packed_stride(size_t k, size_t bl) {
    const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
    const size_t num_bytes_per_block = kai_get_num_bytes_per_block_lhs(bl);

    return mr * num_blocks_per_row * num_bytes_per_block;
}
62 | |||
63 | 70 | inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { | |
64 | − | KAI_ASSUME(bl == kai_bl); | |
65 | − | KAI_ASSUME((k % kai_bl) == 0); | |
66 | |||
67 | 70 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
68 | 70 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); | |
69 | 70 | const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
70 | |||
71 | 70 | size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row); | |
72 | |||
73 | 140 | return rhs_packed_stride; | |
74 | 70 | } | |
75 | |||
76 | 141 | size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
77 | 141 | return kai_m_step * kai_get_sme_vector_length_u32(); | |
78 | } | ||
79 | |||
80 | 141 | size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
81 | 141 | return kai_n_step * kai_get_sme_vector_length_u32(); | |
82 | } | ||
83 | |||
84 | 188 | size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
85 | 188 | return kai_mr * kai_get_sme_vector_length_u32(); | |
86 | } | ||
87 | |||
88 | 188 | size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
89 | 188 | return kai_nr * kai_get_sme_vector_length_u32(); | |
90 | } | ||
91 | |||
92 | 72 | size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
93 | 72 | return kai_kr; | |
94 | } | ||
95 | |||
96 | 48 | size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
97 | 48 | return kai_sr; | |
98 | } | ||
99 | |||
// Byte offset into the packed LHS buffer for row index `m_idx`.
// `m_idx` must be a multiple of the m step; `k` and `bl` are forwarded to the LHS stride helper.
size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
    size_t m_idx, size_t k, size_t bl) {
    const size_t step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    const size_t rows_per_stripe = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    KAI_ASSUME((m_idx % step) == 0);

    // Index of the packed stripe that contains row m_idx.
    const size_t stripe_idx = m_idx / rows_per_stripe;
    return stripe_idx * kai_get_lhs_packed_stride(k, bl);
}
108 | |||
// Byte offset into the packed RHS buffer for column index `n_idx`.
// `n_idx` must be a multiple of the n step; `k` and `bl` are forwarded to the RHS stride helper.
size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
    size_t n_idx, size_t k, size_t bl) {
    const size_t step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    const size_t cols_per_stripe = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
    KAI_ASSUME((n_idx % step) == 0);

    // Index of the packed stripe that contains column n_idx.
    const size_t stripe_idx = n_idx / cols_per_stripe;
    return stripe_idx * kai_get_rhs_packed_stride(k, bl);
}
118 | |||
119 | 23 | size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa( | |
120 | size_t m_idx, size_t n_idx, size_t dst_stride) { | ||
121 | 23 | const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
122 | 23 | const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
123 | − | KAI_ASSUME((m_idx % m_step) == 0); | |
124 | − | KAI_ASSUME((n_idx % n_step) == 0); | |
125 | |||
126 | 46 | return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; | |
127 | 23 | } | |
128 | |||
129 | 23 | size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(size_t m, size_t n) { | |
130 | 23 | return m * n * kai_num_bytes_dst_value; | |
131 | } | ||
132 | |||
133 | 24 | void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa( | |
134 | size_t m, // | ||
135 | size_t n, // | ||
136 | size_t k, // | ||
137 | size_t bl, // | ||
138 | const void* restrict lhs_packed, // | ||
139 | const void* restrict rhs_packed, // | ||
140 | float* restrict dst, // NOLINT(readability-non-const-parameter) | ||
141 | size_t dst_stride_row, // | ||
142 | size_t dst_stride_col, // | ||
143 | float scalar_min, // | ||
144 | float scalar_max) { | ||
145 | − | KAI_ASSUME(dst_stride_col == sizeof(float)); | |
146 | |||
147 | 24 | KAI_UNUSED(scalar_min); | |
148 | 24 | KAI_UNUSED(scalar_max); | |
149 | |||
150 |
1/2✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
|
24 | if (m == 0) { |
151 | ✗ | return; | |
152 | } | ||
153 | |||
154 | typedef struct { | ||
155 | size_t lhs_packed_stride; | ||
156 | size_t rhs_packed_stride; | ||
157 | size_t mr; | ||
158 | } KernelArgs; | ||
159 | |||
160 | 24 | KernelArgs ka; | |
161 | |||
162 | 24 | const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); | |
163 | |||
164 | 24 | const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
165 | 24 | const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
166 | |||
167 | 24 | ka.mr = mr; | |
168 | 24 | ka.lhs_packed_stride = kai_get_lhs_packed_stride(k, bl); | |
169 | 24 | ka.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); | |
170 | |||
171 | 48 | const uint16_t* lhs_scales = (const uint16_t*)((const int8_t*)lhs_packed + ka.lhs_packed_stride - | |
172 | 24 | (mr * num_blocks) * kai_num_bytes_multiplier_lhs); | |
173 | 48 | const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + ka.rhs_packed_stride - | |
174 | 24 | (nr * num_blocks) * kai_num_bytes_multiplier_rhs); | |
175 | |||
176 | 48 | __asm__ volatile( | |
177 | // Switch to streaming mode with ZA enabling | ||
178 | " .inst 0xd503477f // smstart \n" | ||
179 | |||
180 | // Constants | ||
181 | // - SVLs | ||
182 | " cntw x14 \n" | ||
183 | // - ptrue | ||
184 | " ptrue p0.b, all \n" | ||
185 | " .inst 0x25a07810 // ptrue pn8.s \n" | ||
186 | |||
187 | // Predicate for loading fp16 scaling factors | ||
188 | " ldr x5, [%x[args_ptr], %[offset_mr]]\n" | ||
189 | " lsl x5, x5, #1 \n" | ||
190 | " whilelt p4.b, xzr, x5 \n" | ||
191 | |||
192 | // Initialize ZT0 (Lookup table) | ||
193 | " mov x6, %[lut]\n" | ||
194 | " .inst 0xe11f80c0 // ldr zt0, [x6] \n" | ||
195 | |||
196 | // Initialize the RHS packes and scale pointers | ||
197 | " mov x16, %[rhs_packed] \n" | ||
198 | " mov x17, %[rhs_scales] \n" | ||
199 | |||
200 | // Iterate over n (x8) | ||
201 | // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step) | ||
202 | " mov x8, #0 \n" | ||
203 | " mov x0, %[N] \n" | ||
204 | " .inst 0x25a06511 // whilelt pn9.s, x8, x0, VLx4 \n" | ||
205 | |||
206 | " b.none 9f // .LOOP_N_END%= \n" | ||
207 | |||
208 | " 1: // .LOOP_N_START%=: \n" | ||
209 | |||
210 | // Iterate over m (x9) | ||
211 | // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step) | ||
212 | " mov x9, %[M] \n" | ||
213 | |||
214 | // Initialize the LHS packed and scale pointers | ||
215 | " mov x22, %[lhs_packed] \n" | ||
216 | " mov x23, %[lhs_scales] \n" | ||
217 | |||
218 | // Initialize the DST pointer | ||
219 | " mov x24, %[dst] \n" | ||
220 | |||
221 | " 2: // .LOOP_M_START%=: \n" | ||
222 | |||
223 | // Address offset for the left and right quantized values | ||
224 | " mov x20, #0 \n" | ||
225 | " mov x21, #0 \n" | ||
226 | |||
227 | // Number of output rows to store -> min(SVLh, loop M index) | ||
228 | " cmp x9, x14 \n" | ||
229 | " csel x15, x9, x14, lo \n" | ||
230 | " lsl x15, x15, #2 \n" | ||
231 | |||
232 | // Iterate over all K values | ||
233 | // e.g. for(k_idx = 0; k_idx < k; k_idx += bl) | ||
234 | " mov x10, %[K] \n" | ||
235 | |||
236 | // Skip processing if K=0 | ||
237 | " cmp x10, #0 \n" | ||
238 | " b.eq 8f // .LOOP_K_END%= \n" | ||
239 | |||
240 | " 3: // .LOOP_K_START%=: \n" | ||
241 | |||
242 | // Zeroing of ZA accumulator | ||
243 | " .inst 0xc00800ff // zero {za} \n" | ||
244 | |||
245 | // Load the fp16 scaling factors for the right matrix block | ||
246 | " .inst 0xa0154220 // ld1w {z0.s - z1.s}, pn8/z, [x17], x21, lsl #2] \n" | ||
247 | " .inst 0xc161d000 // zip {z0.h - z1.h}, z0.h, z1.h \n" | ||
248 | |||
249 | // Iterate over all values in the block | ||
250 | // k_blk_idx = bl | ||
251 | // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 4} | ||
252 | " mov x11, #32\n" | ||
253 | |||
254 | " 4: // .LOOP_BL_START%=: \n" | ||
255 | |||
256 | // Load right matrix row | ||
257 | " .inst 0xa0144202 // ld1w {z2.s - z3.s}, pn8/z, [x16], x20, lsl #2] \n" | ||
258 | |||
259 | // Load left matrix column | ||
260 | " ld1h {z8.h}, p0/z, [x22, x20, lsl #1] \n" | ||
261 | " inch x20, all \n" | ||
262 | |||
263 | // Convert Int4 -> Int8 | ||
264 | " .inst 0xc08a4044 // luti4 {z4.b - z5.b}, zt0, z2[0] \n" | ||
265 | " .inst 0xc08a4066 // luti4 {z6.b - z7.b}, zt0, z3[0] \n" | ||
266 | |||
267 | // Outer-products | ||
268 | " .inst 0xa0840100 // smopa za0.s, p0/m, p0/m, z8.b, z4.b \n" | ||
269 | " .inst 0xa0850101 // smopa za1.s, p0/m, p0/m, z8.b, z5.b \n" | ||
270 | " .inst 0xa0860102 // smopa za2.s, p0/m, p0/m, z8.b, z6.b \n" | ||
271 | " .inst 0xa0870103 // smopa za3.s, p0/m, p0/m, z8.b, z7.b \n" | ||
272 | |||
273 | // Decrement the block loop index | ||
274 | " subs x11, x11, #4 \n" | ||
275 | |||
276 | " b.gt 4b // .LOOP_BL_START%= \n" | ||
277 | |||
278 | // === End of the block loop === | ||
279 | |||
280 | // Store loop index | ||
281 | " mov w12, #0 \n" | ||
282 | |||
283 | // Copy destination pointer for store loop | ||
284 | " mov x25, x24 \n" | ||
285 | |||
286 | // Load the fp16 scaling factors for the left matrix block | ||
287 | " ld1b {z16.b}, p4/z, [x23, x21] \n" | ||
288 | " inch x21, all \n" | ||
289 | |||
290 | // Predicate for the selection of a scaling among the vector | ||
291 | " pfalse p3.b \n" | ||
292 | |||
293 | " 5: // .LOOP_ZA%=: \n" | ||
294 | |||
295 | // Select and replicate scaling factor for the right block | ||
296 | " pnext p3.h, p0, p3.h \n" | ||
297 | " clastb z19.h, p3, z19.h, z16.h \n" | ||
298 | |||
299 | // Get data from za | ||
300 | " .inst 0xc006041c // mova {z28.b-z31.b}, za0h.b[w12, 0:3] \n" | ||
301 | " add w12, w12, #4 \n" | ||
302 | |||
303 | // Convert from int32 to fp32 | ||
304 | " .inst 0xc132e39c // scvtf {z28.s-z31.s}, {z28.s-z31.s} \n" | ||
305 | |||
306 | // Multiply left and right scaling factors | ||
307 | " movprfx z8, z18 \n" | ||
308 | " fmlalb z8.s, z19.h, z0.h \n" | ||
309 | " movprfx z9, z18 \n" | ||
310 | " fmlalb z9.s, z19.h, z1.h \n" | ||
311 | " movprfx z10, z18 \n" | ||
312 | " fmlalt z10.s, z19.h, z0.h \n" | ||
313 | " movprfx z11, z18 \n" | ||
314 | " fmlalt z11.s, z19.h, z1.h \n" | ||
315 | |||
316 | " cmp x10, %[K] \n" | ||
317 | " b.ne 6f // .ACCUMULATE%= \n" | ||
318 | |||
319 | // Applying combined scaling factors to processed block | ||
320 | " fmul z24.s, z8.s, z28.s \n" | ||
321 | " fmul z25.s, z9.s, z29.s \n" | ||
322 | " fmul z26.s, z10.s, z30.s \n" | ||
323 | " fmul z27.s, z11.s, z31.s \n" | ||
324 | |||
325 | "b 7f // .STORE%= \n" | ||
326 | |||
327 | " 6: // .ACCUMULATE%=: \n" | ||
328 | // Load intermediate result | ||
329 | " .inst 0xa040c738 // ld1w {z24.s-z27.s}, pn9/z, [x25] \n" | ||
330 | |||
331 | // Multiply the intermediate results by LHS_SCALE x RHS_SCALE | ||
332 | // and store in the main floating-point accumulator | ||
333 | " fmla z24.s, p0/m, z8.s, z28.s \n" | ||
334 | " fmla z25.s, p0/m, z9.s, z29.s \n" | ||
335 | " fmla z26.s, p0/m, z10.s, z30.s \n" | ||
336 | " fmla z27.s, p0/m, z11.s, z31.s \n" | ||
337 | |||
338 | "7: // .STORE%=: \n" | ||
339 | // Store the results into memory | ||
340 | " .inst 0xa060c738 // st1w {z24.s-z27.s}, pn9, [x25] \n" | ||
341 | " add x25, x25, %[stride] \n" | ||
342 | |||
343 | " cmp x12, x15 \n" | ||
344 | " blt 5b // .LOOP_ZA%= \n" | ||
345 | |||
346 | // Decrement K loop index by bl | ||
347 | " subs x10, x10, #32 \n" | ||
348 | |||
349 | " b.gt 3b // .LOOP_K_START%= \n" | ||
350 | |||
351 | " 8: // .LOOP_K_END%=: \n" | ||
352 | |||
353 | // === End of the K loop === | ||
354 | |||
355 | " ldr x5, [%x[args_ptr], %[offset_stride_l]] \n" | ||
356 | |||
357 | // Increment pointer to the quantized values of the right matrix | ||
358 | " add x22, x22, x5\n" | ||
359 | |||
360 | // Increment pointer to the scaling factors of the right matrix | ||
361 | " add x23, x23, x5 \n" | ||
362 | |||
363 | // Update destination pointer | ||
364 | " mov x24, x25 \n" | ||
365 | |||
366 | // Decrement M loop index | ||
367 | " decw x9, all \n" | ||
368 | |||
369 | " cmp x9, #0 \n" | ||
370 | " b.gt 2b // .LOOP_M_START%= \n" | ||
371 | |||
372 | // === End of M loop === | ||
373 | |||
374 | // Increment output pointer | ||
375 | " incb %[dst], all, mul #4 \n" | ||
376 | |||
377 | " ldr x5, [%x[args_ptr], %[offset_stride_r]]\n" | ||
378 | |||
379 | " add x16, x16, x5 \n" | ||
380 | " add x17, x17, x5 \n" | ||
381 | |||
382 | // Increment N loop index | ||
383 | " incb x8, all \n" | ||
384 | |||
385 | " .inst 0x25a06511 // whilelt pn9.s, x8, %[N], VLx4 \n" | ||
386 | |||
387 | " b.first 1b // .LOOP_N_START%= \n" | ||
388 | |||
389 | " 9: // .LOOP_N_END%=: \n" | ||
390 | |||
391 | // === End of N loop === | ||
392 | |||
393 | // Exit streaming mode | ||
394 | " .inst 0xd503467f // smstop \n" | ||
395 | : [dst] "+r"(dst), [rhs_packed] "+r"(rhs_packed), [rhs_scales] "+r"(rhs_scales) | ||
396 | 24 | : [M] "r"(m), [N] "r"(n), [K] "r"(k), [lhs_packed] "r"(lhs_packed), [lhs_scales] "r"(lhs_scales), | |
397 | 24 | [stride] "r"(dst_stride_row), [lut] "r"(lut), [args_ptr] "r"(&ka), | |
398 | [offset_stride_l] "I"(offsetof(KernelArgs, lhs_packed_stride)), | ||
399 | [offset_stride_r] "I"(offsetof(KernelArgs, rhs_packed_stride)), [offset_mr] "I"(offsetof(KernelArgs, mr)) | ||
400 | : "p0", "p1", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", "z1", | ||
401 | "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", | ||
402 | "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x5", "x6", | ||
403 | "x8", "x9", "x10", "x11", "x12", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", | ||
404 | "memory", "cc"); | ||
405 | 24 | } | |
406 | |||
407 | #endif // Architectural features check. | ||
408 |