Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) | ||
8 | #error This file must be compiled for AArch64 with FEAT_SVE2 or FEAT_SME2 enabled. | ||
9 | #else // Architectural features check. | ||
10 | |||
11 | #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" | ||
12 | |||
13 | #include <stddef.h> | ||
14 | #include <stdint.h> | ||
15 | |||
16 | #include "kai/kai_common.h" | ||
17 | |||
18 | // Compute args | ||
19 | static const size_t kai_m_step = 1; | ||
20 | static const size_t kai_n_step = 4; // Multiple of vector length | ||
21 | // Packing args | ||
22 | static const size_t kai_mr = 1; | ||
23 | static const size_t kai_nr = 4; // Multiple of vector length | ||
24 | static const size_t kai_kr = 4; | ||
25 | static const size_t kai_sr = 2; | ||
26 | // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) | ||
27 | static const size_t kai_num_bytes_qvalue_lhs = 1; | ||
28 | static const size_t kai_num_bytes_multiplier_lhs = 2; | ||
29 | // RHS format args (reciprocal num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is | ||
30 | // asymmetric)) | ||
31 | static const size_t kai_recip_num_bytes_qvalue_rhs = 2; | ||
32 | static const size_t kai_num_bytes_multiplier_rhs = 2; | ||
33 | // DST format args | ||
34 | static const size_t kai_num_bytes_dst_value = 4; | ||
35 | // Extra args | ||
36 | static const size_t kai_bl = 32; | ||
37 | |||
38 | // Look-up table used for the int4->int8 conversion | ||
39 | static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7}; | ||
40 | |||
41 | 30 | inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { | |
42 | 30 | return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs; | |
43 | } | ||
44 | |||
45 | 30 | inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { | |
46 | − | KAI_ASSUME(bl == kai_bl); | |
47 | 30 | size_t num_bytes_per_block_rhs = (bl / kai_recip_num_bytes_qvalue_rhs) + kai_num_bytes_multiplier_rhs; | |
48 | 60 | return num_bytes_per_block_rhs; | |
49 | 30 | } | |
50 | |||
51 | 64 | inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { | |
52 | − | KAI_ASSUME(bl == kai_bl); | |
53 | − | KAI_ASSUME((k % kai_bl) == 0); | |
54 | |||
55 | 64 | return kai_roundup(k, bl) / bl; | |
56 | } | ||
57 | |||
58 | 30 | inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) { | |
59 | 30 | const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); | |
60 | 60 | return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl); | |
61 | 30 | } | |
62 | |||
63 | 30 | inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { | |
64 | − | KAI_ASSUME(bl == kai_bl); | |
65 | − | KAI_ASSUME((k % kai_bl) == 0); | |
66 | |||
67 | 30 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
68 | 30 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); | |
69 | 30 | const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); | |
70 | |||
71 | 30 | size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row); | |
72 | |||
73 | 60 | return rhs_packed_stride; | |
74 | 30 | } | |
75 | |||
76 | 81 | size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) { | |
77 | 81 | return kai_m_step; | |
78 | } | ||
79 | |||
80 | 81 | size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) { | |
81 | 81 | return kai_n_step * kai_get_sme_vector_length_u32(); | |
82 | } | ||
83 | |||
84 | 108 | size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) { | |
85 | 108 | return kai_mr; | |
86 | } | ||
87 | |||
88 | 108 | size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) { | |
89 | 108 | return kai_nr * kai_get_sme_vector_length_u32(); | |
90 | } | ||
91 | |||
92 | 72 | size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) { | |
93 | 72 | return kai_kr; | |
94 | } | ||
95 | |||
96 | 48 | size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(void) { | |
97 | 48 | return kai_sr; | |
98 | } | ||
99 | |||
100 | 26 | size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot( | |
101 | size_t m_idx, size_t k, size_t bl) { | ||
102 | 26 | const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); | |
103 | 26 | const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); | |
104 | |||
105 | − | KAI_ASSUME((m_idx % m_step) == 0); | |
106 | |||
107 | 52 | return (m_idx / mr) * kai_get_lhs_packed_stride(k, bl); | |
108 | 26 | } | |
109 | |||
110 | 26 | size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot( | |
111 | size_t n_idx, size_t k, size_t bl) { | ||
112 | 26 | const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); | |
113 | 26 | const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); | |
114 | |||
115 | − | KAI_ASSUME((n_idx % nr) == 0); | |
116 | |||
117 | 52 | return (n_idx / n_step) * kai_get_rhs_packed_stride(k, bl); | |
118 | 26 | } | |
119 | |||
120 | 3 | size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot( | |
121 | size_t m_idx, size_t n_idx, size_t dst_stride) { | ||
122 | 3 | const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); | |
123 | 3 | const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); | |
124 | − | KAI_ASSUME((m_idx % m_step) == 0); | |
125 | − | KAI_ASSUME((n_idx % n_step) == 0); | |
126 | |||
127 | 6 | return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; | |
128 | 3 | } | |
129 | |||
130 | 3 | size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(size_t m, size_t n) { | |
131 | 3 | return m * n * kai_num_bytes_dst_value; | |
132 | } | ||
133 | |||
134 | 4 | void kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot( | |
135 | size_t m, // | ||
136 | size_t n, // | ||
137 | size_t k, // | ||
138 | size_t bl, // | ||
139 | const void* restrict lhs_packed, // | ||
140 | const void* restrict rhs_packed, // | ||
141 | float* restrict dst, // NOLINT(readability-non-const-parameter) | ||
142 | size_t dst_stride_row, // | ||
143 | size_t dst_stride_col, // | ||
144 | float scalar_min, // | ||
145 | float scalar_max) { | ||
146 | − | KAI_ASSUME(dst_stride_col == sizeof(float)); | |
147 | − | KAI_ASSUME(m == 1); | |
148 | |||
149 | 4 | KAI_UNUSED(dst_stride_row); | |
150 | 4 | KAI_UNUSED(scalar_min); | |
151 | 4 | KAI_UNUSED(scalar_max); | |
152 | |||
153 | 1/2 ✓ Branch 0 taken 4 times. ✗ Branch 1 not taken. | 4 | if (m == 0) { |
154 | ✗ | return; | |
155 | } | ||
156 | |||
157 | 4 | const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, bl); | |
158 | 4 | const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); | |
159 | 4 | const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); | |
160 | |||
161 | 4 | const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); | |
162 | 4 | const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); | |
163 | |||
164 | 8 | const uint16_t* lhs_scales = (const uint16_t*)((const int8_t*)lhs_packed + lhs_packed_stride - | |
165 | 4 | (mr * num_blocks) * kai_num_bytes_multiplier_lhs); | |
166 | 8 | const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + rhs_packed_stride - | |
167 | 4 | (nr * num_blocks) * kai_num_bytes_multiplier_rhs); | |
168 | |||
169 | 4 | __asm__ volatile( | |
170 | // Switch to streaming mode with ZA enabled | ||
171 | " .inst 0xd503477f // smstart \n" | ||
172 | |||
173 | " ptrue p2.b, all \n" | ||
174 | " .inst 0x25607810 // ptrue pn8.h \n" | ||
175 | |||
176 | " fmov z28.s, #0.0 \n" | ||
177 | |||
178 | // Initialize ZT0 (Lookup table) | ||
179 | " mov x9, %[lut] \n" | ||
180 | " .inst 0xe11f8120 // ldr zt0, [x9] \n" | ||
181 | |||
182 | // Initialize the RHS packed and scale pointers | ||
183 | " mov x0, %[rhs_packed] \n" | ||
184 | " mov x1, %[rhs_scales] \n" | ||
185 | |||
186 | // Initialize the DST pointer | ||
187 | " mov x5, %[dst] \n" | ||
188 | |||
189 | // Iterate over n (x0) | ||
190 | // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step) | ||
191 | " mov x4, #0\n" | ||
192 | " mov x17, %[n] \n" | ||
193 | " .inst 0x25b16491 // whilelt pn9.s, x4, x17, VLx4 \n" | ||
194 | |||
195 | " b.none 5f // .LOOP_N_END%= \n" | ||
196 | |||
197 | " 1: // .LOOP_N_START%=: \n" | ||
198 | |||
199 | // Initialize the LHS packed and scale pointers | ||
200 | " mov x2, %[lhs_packed] \n" | ||
201 | " mov x3, %[lhs_scales] \n" | ||
202 | |||
203 | // Initialize the 4xVL-32bit accumulators to zero | ||
204 | " dup z24.s, #0 \n" | ||
205 | " dup z25.s, #0 \n" | ||
206 | " dup z26.s, #0 \n" | ||
207 | " dup z27.s, #0 \n" | ||
208 | |||
209 | // Initialize the vector selector for ZA array | ||
210 | " mov w8, #0 \n" | ||
211 | |||
212 | // Iterate over all K values | ||
213 | // e.g. for(k_idx = 0; k_idx < k; k_idx += bl) | ||
214 | " mov x6, #0 \n" | ||
215 | " whilelt p1.s, x6, %[k] \n" | ||
216 | " b.none 4f // .LOOP_K_END%= \n" | ||
217 | |||
218 | " 2: // .LOOP_K_START%=: \n" | ||
219 | // Zeroing of inner accumulation array | ||
220 | " .inst 0xc00800ff // zero {za} \n" | ||
221 | |||
222 | // Iterate over all values in the block | ||
223 | // k_blk_idx = bl | ||
224 | // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 16} | ||
225 | |||
226 | "mov x13, %[bl] \n" | ||
227 | |||
228 | "3: // .LOOP_BL_START%=: \n" | ||
229 | // Load the LHS (int8) quantized values | ||
230 | // Load contiguous 16 bytes and replicate. | ||
231 | // For GeMV, we do not interleave the LHS M rows. | ||
232 | " ld1rqb { z0.b }, p2/z , [x2] \n" | ||
233 | " add x2, x2, #16 \n" | ||
234 | |||
235 | // -- First half | ||
236 | // Load the RHS (int4) quantized values | ||
237 | " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0] \n" | ||
238 | |||
239 | // Increment the RHS pointer | ||
240 | " addvl x0, x0, #4 \n" | ||
241 | |||
242 | // Convert Int4 -> Int8 | ||
243 | " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n" | ||
244 | " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n" | ||
245 | " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n" | ||
246 | " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n" | ||
247 | |||
248 | // SDOT indexed | ||
249 | " .inst 0xc15090a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[0] \n" | ||
250 | " .inst 0xc1509520 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[1] \n" | ||
251 | |||
252 | // -- Second half | ||
253 | |||
254 | // Load the RHS (int4) quantized values | ||
255 | " .inst 0xa040a00c // ld1h { z12.h - z15.h }, pn8/z, [x0]\n" | ||
256 | |||
257 | // Increment the RHS pointer | ||
258 | " addvl x0, x0, #4 \n" | ||
259 | |||
260 | // Convert Int4 -> Int8 | ||
261 | " .inst 0xc08a4184 // luti4 { z4.b, z5.b }, zt0, z12[0] \n" | ||
262 | " .inst 0xc08a41a6 // luti4 { z6.b, z7.b }, zt0, z13[0] \n" | ||
263 | " .inst 0xc08a41c8 // luti4 { z8.b, z9.b }, zt0, z14[0] \n" | ||
264 | " .inst 0xc08a41ea // luti4 { z10.b, z11.b }, zt0, z15[0] \n" | ||
265 | |||
266 | // SDOT indexed | ||
267 | " .inst 0xc15098a0 // sdot za.s[w8, 0, vgx4], {z4.b - z7.b}, z0.b[2] \n" | ||
268 | " .inst 0xc1509d20 // sdot za.s[w8, 0, vgx4], {z8.b - z11.b}, z0.b[3] \n" | ||
269 | |||
270 | // Decrement the block loop index | ||
271 | "subs x13, x13, #16 \n" | ||
272 | |||
273 | "b.gt 3b // .LOOP_BL_START%= \n" | ||
274 | |||
275 | // === End of the block loop === | ||
276 | |||
277 | // Load Z registers with intermediate values from ZA array | ||
278 | " .inst 0xc0060c10 // mova {z16.s - z19.s}, za.s[w8, 0, vgx4] \n" | ||
279 | |||
280 | // Convert from int32 to float32 | ||
281 | " .inst 0xc132e210 // scvtf{z16.s - z19.s}, {z16.s - z19.s} \n" | ||
282 | |||
283 | // Load 1 fp16 LHS scale value and replicate it across all 16-bit lanes | ||
284 | " ld1rh z1.h, p2/z, [x3] \n" | ||
285 | |||
286 | // Increment the LHS scale pointer by 2 (1 x sizeof(fp16)) | ||
287 | " add x3, x3, #2 \n" | ||
288 | |||
289 | // Load 2xVL-16bit (fp16) RHS scales. | ||
290 | // If VL=512bit, we load 64 fp16 values, which is equal to the number of output columns (n_step) processed | ||
291 | " .inst 0xa0402024 // ld1h { z4.h - z5.h }, pn8/z, [x1]\n" | ||
292 | |||
293 | // Increment the RHS scale pointer | ||
294 | " addvl x1, x1, #2 \n" | ||
295 | |||
296 | // Combine all the LHS and RHS scales | ||
297 | " .inst 0xc165d082 //zip { z2.h-z3.h }, z4.h, z5.h\n" | ||
298 | " movprfx z4, z28 \n" | ||
299 | |||
300 | // Multiply two half floating-point vectors and store the result | ||
301 | // to a floating-point 32-bit vector | ||
302 | " fmlalb z4.s, z1.h, z2.h\n" | ||
303 | " movprfx z5, z28 \n" | ||
304 | " fmlalb z5.s, z1.h, z3.h\n" | ||
305 | " movprfx z6, z28 \n" | ||
306 | " fmlalt z6.s, z1.h, z2.h\n" | ||
307 | " movprfx z7, z28 \n" | ||
308 | " fmlalt z7.s, z1.h, z3.h\n" | ||
309 | |||
310 | // Multiply the intermediate results by LHS_SCALE x RHS_SCALE | ||
311 | // and store in the main floating-point accumulator | ||
312 | " fmla z24.s, p2/m, z16.s, z4.s \n" | ||
313 | " fmla z25.s, p2/m, z17.s, z5.s \n" | ||
314 | " fmla z26.s, p2/m, z18.s, z6.s \n" | ||
315 | " fmla z27.s, p2/m, z19.s, z7.s \n" | ||
316 | |||
317 | // Increment the number of K values processed and | ||
318 | // go to the next block | ||
319 | " add x6, x6, %[bl] \n" | ||
320 | " whilelt p1.s, x6, %[k] \n" | ||
321 | " b.first 2b // .LOOP_K_START%= \n" | ||
322 | " 4: //.LOOP_K_END%=: \n" | ||
323 | |||
324 | // Store the results into memory | ||
325 | " .inst 0xa060c4b8 // st1w { z24.s-z27.s }, pn9, [x5] \n" | ||
326 | " incb x4, all \n" | ||
327 | " addvl x5, x5, #4 \n" | ||
328 | |||
329 | // The new rhs_packed pointer is the current rhs_scales pointer | ||
330 | // The new rhs_scales pointer is the current rhs_packed plus the rhs_packed_stride | ||
331 | " mov x7, x0 \n" | ||
332 | |||
333 | // Initialize the rhs_packed pointer | ||
334 | " mov x0, x1 \n" | ||
335 | |||
336 | // Initialize the rhs_scales pointer | ||
337 | " add x1, x7, %[rhs_packed_stride] \n" | ||
338 | |||
339 | " .inst 0x25b16491 // whilelt pn9.s, x4, %[n], VLx4 \n" | ||
340 | |||
341 | " b.first 1b // .LOOP_N_START%= \n" | ||
342 | |||
343 | " 5: // .LOOP_N_END%=: \n" | ||
344 | |||
345 | // Exit streaming mode | ||
346 | " .inst 0xd503467f //smstop \n" | ||
347 | : | ||
348 | 4 | : [lut] "r"(lut), [dst] "r"(dst), [rhs_packed] "r"(rhs_packed), [rhs_scales] "r"(rhs_scales), | |
349 | 4 | [lhs_packed] "r"(lhs_packed), [lhs_scales] "r"(lhs_scales), [rhs_packed_stride] "r"(rhs_packed_stride), | |
350 | 4 | [n] "r"((int64_t)n), [k] "r"(k), [bl] "i"(kai_bl) | |
351 | : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", | ||
352 | "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", | ||
353 | "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x1", | ||
354 | "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x13", "x17", "memory", "cc"); | ||
355 | 4 | } | |
356 | |||
357 | #endif // Architectural features check. | ||
358 |
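
The kernel above implements a GEMV (`m == 1`) over block-quantized operands: for each 32-value block, the packed int4 RHS codes are expanded to int8 through ZT0 (initialized from `lut`), accumulated against the int8 LHS with SDOT into ZA, and the resulting int32 partial sums are scaled by the product of the per-block fp16 LHS and RHS scales before being added to the fp32 accumulators. The scalar model below is a minimal sketch of that arithmetic only, assuming unpacked operands (one int8 LHS value per entry, one 4-bit RHS code per byte, scales already widened to `float`); the real kernel instead consumes the packed layouts produced by the companion packing routines, where two int4 values share a byte and the fp16 scales sit at the end of each packed stride. The function and parameter names are hypothetical.

```c
#include <stddef.h>
#include <stdint.h>

// Same value mapping as the kernel's ZT0 look-up table.
static const int32_t ref_lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};

// Scalar reference of the block-wise dequantized dot product (m == 1, bl == 32).
static void ref_gemv_qsi8d32_qsi4c32(
    size_t n, size_t k, size_t bl,
    const int8_t* lhs_q,       // k int8 values
    const float* lhs_scales,   // k / bl scales (stored as fp16 in the packed format)
    const uint8_t* rhs_codes,  // n * k unsigned 4-bit codes, one per byte in this sketch
    const float* rhs_scales,   // n * (k / bl) scales (stored as fp16 in the packed format)
    float* dst) {              // n fp32 outputs
    const size_t num_blocks = k / bl;
    for (size_t col = 0; col < n; ++col) {
        float acc = 0.0F;
        for (size_t b = 0; b < num_blocks; ++b) {
            // Integer accumulation within one block: this is what the SDOT
            // instructions build up in the ZA array.
            int32_t iacc = 0;
            for (size_t i = 0; i < bl; ++i) {
                const size_t idx = (b * bl) + i;
                iacc += (int32_t)lhs_q[idx] * ref_lut[rhs_codes[(col * k) + idx] & 0xF];
            }
            // Per-block scaling: mirrors the scvtf + fmlalb/fmlalt + fmla sequence.
            acc += (float)iacc * lhs_scales[b] * rhs_scales[(col * num_blocks) + b];
        }
        dst[col] = acc;
    }
}
```

As in the kernel, `bl` is expected to be 32 (`kai_bl`) and `k` a multiple of it, so `num_blocks` matches `kai_get_num_blocks_per_row(k, bl)`; the clamp bounds (`scalar_min`/`scalar_max`) are not applied, consistent with the `KAI_UNUSED` markers in the listing.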