//
// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//

// Do not warn about the overlength string literal that makes up the inline assembly block.
#pragma GCC diagnostic ignored "-Woverlength-strings"
9 | |||
10 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) | ||
11 | #error This file must be compiled for AArch64, FEAT_SVE2. | ||
12 | #else // Architectural features check | ||
13 | |||
14 | #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h" | ||
15 | |||
16 | #include <stddef.h> | ||
17 | #include <stdint.h> | ||
18 | |||
19 | #include "kai/kai_common.h" | ||
20 | |||
21 | static const size_t kai_m_step = 1; | ||
22 | static const size_t kai_n_step = 1; | ||
23 | static const size_t kai_mr = 1; | ||
24 | static const size_t kai_nr = 4; // multiple of vector length | ||
25 | static const size_t kai_kr = 4; | ||
26 | static const size_t kai_sr = 1; | ||
27 | static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); | ||
28 | static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); | ||
29 | static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); | ||
30 | static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); | ||
31 | static const size_t kai_num_bytes_bias_rhs = sizeof(float); | ||
32 | static const size_t kai_k_multiple_of = 32; | ||
33 | |||
inline static size_t kai_k_roundedup(size_t k) {
    // Round up k to be a multiple of 32.
    return kai_roundup(k, kai_k_multiple_of);
}
38 | |||
39 | 201 | inline static size_t kai_get_lhs_packed_stride(size_t k) { | |
40 | 201 | const size_t k_internal = kai_k_roundedup(k); | |
41 | |||
42 | − | KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); | |
43 | |||
44 | 603 | return kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() * | |
45 | 201 | (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); | |
46 | 201 | } | |
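
// Worked example for the LHS stride (illustrative numbers, assuming k = 100): k_internal = 128, and with
// mr = 1 the packed LHS row stride is 1 * (128 * sizeof(int8_t) + 4 + 4) = 136 bytes, i.e. 128 bytes of
// quantized int8 data followed by the per-row int32 zero point and float scale read by the kernel below.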
47 | |||
48 | 201 | inline static size_t kai_get_rhs_packed_stride(size_t k) { | |
49 | 201 | const size_t k_internal = kai_k_roundedup(k); | |
50 | |||
51 | − | KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); | |
52 | |||
53 | 603 | return kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() * | |
54 | 201 | ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs); | |
55 | 201 | } | |
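
// Worked example for the RHS stride (illustrative numbers, assuming a 512-bit streaming vector length so
// nr = 4 * 16 = 64, and k = 100): k_internal = 128, so each packed RHS block holds 64 * 128 / 2 = 4096
// bytes of int4 data, followed by nr int32 column sums, nr float scales and nr float biases, giving a
// stride of 64 * (64 + 4 + 4 + 4) = 4864 bytes.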
56 | |||
57 | 160 | size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
58 | 160 | return kai_m_step; | |
59 | } | ||
60 | |||
61 | 682 | size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
62 | 682 | return kai_nr * kai_get_sme_vector_length_u32(); | |
63 | } | ||
64 | |||
65 | 160 | size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
66 | 160 | return kai_n_step * kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(); | |
67 | } | ||
68 | |||
size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) {
    // For GEMV, mr must be 1 so that the data is read consecutively.
    return kai_mr;
}
73 | |||
74 | 160 | size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
75 | 160 | return kai_kr; | |
76 | } | ||
77 | |||
78 | 160 | size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
79 | 160 | return kai_sr; | |
80 | } | ||
81 | |||
82 | 120 | size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t m_idx, size_t k) { | |
83 | − | KAI_ASSERT((m_idx % kai_m_step) == 0); | |
84 | |||
85 | 120 | return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); | |
86 | } | ||
87 | |||
88 | 120 | size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t n_idx, size_t k) { | |
89 | − | KAI_ASSERT((n_idx % kai_n_step) == 0); | |
90 | 120 | const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(); | |
91 | 240 | return (n_idx / nr) * kai_get_rhs_packed_stride(k); | |
92 | 120 | } | |
93 | |||
94 | 80 | size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot( | |
95 | size_t m_idx, size_t n_idx, size_t dst_stride) { | ||
96 | − | KAI_ASSERT((m_idx % kai_m_step) == 0); | |
97 | − | KAI_ASSERT((n_idx % kai_n_step) == 0); | |
98 | |||
99 | 80 | return (n_idx * sizeof(float)) + (m_idx * dst_stride); | |
100 | } | ||
101 | |||
102 | 80 | size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t m, size_t n) { | |
103 | 80 | return m * n * sizeof(float); | |
104 | } | ||
105 | |||
/// LUT indexed by an i4 value, yielding that value sign-extended to i8 (e.g. -2 = 1110 -> 1111 1110).
static const int8_t lut[64] = {0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0,
                               0, 0, 6, 0, 0, 0, 7, 0, 0, 0, -8, 0, 0, 0, -7, 0, 0, 0, -6, 0, 0, 0,
                               -5, 0, 0, 0, -4, 0, 0, 0, -3, 0, 0, 0, -2, 0, 0, 0, -1, 0, 0, 0};
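
// Note: the 16 i8 values occupy the low byte of sixteen 4-byte slots (hence the padding zeros), which
// appears to be the table layout the luti4 lookups below expect. Illustrative check: nibble 0b1110
// (index 14) selects lut[14 * 4] = -2.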
110 | |||
111 | // Optimized for GEMV (matrix vector multiplication => m == 1). | ||
112 | // Does a matmul for compatibility reasons, but should not be used that way. | ||
void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(
    size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed,
    float* dst,  // NOLINT(readability-non-const-parameter)
    size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
    KAI_ASSERT(dst_stride_col == sizeof(float));

    if (m == 0 || n == 0 || k == 0) {
        return;
    }

    // Do function calls and calculations up front so they do not clobber registers used by the
    // inline assembly below.
    uint64_t k_internal = kai_k_roundedup(k);
    uint64_t lhs_stride = kai_get_lhs_packed_stride(k);
    uint64_t rhs_stride = kai_get_rhs_packed_stride(k);
    uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot();

    uint64_t rhs_row_bytes = nr * k_internal / 2;
    uint64_t lhs_end_ptr = ((uint64_t)lhs_packed) + (m * lhs_stride);

    /*
     * x11: zero = 0 // MUST BE x8-x11
     * x15: n initialized as n
     * x19: nr initialized as nr
     * x20: lut_ptr initialized as lut
     * x21: lhs_packed initialized as lhs_packed
     * x22: n_idx
     * x23: k_idx
     * x24: RHS block ptr
     * x25: RHS end ptr
     * x26: rhs_packed
     * x27: dst_ptr
     * x28: tmp_1
     */
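
    // For reference, a scalar sketch of what the assembly computes per output element (illustrative only;
    // the names below are descriptive, not taken from the packed-buffer layouts):
    //
    //   int32_t acc = 0;
    //   for (size_t kk = 0; kk < k_internal; ++kk) {
    //       acc += (int32_t)lhs_q[kk] * (int32_t)rhs_q[kk];      // int8 x int4 products, done by sdot
    //   }
    //   acc += lhs_zero_point * rhs_column_sum;                   // mla with z2
    //   float out = bias + (lhs_scale * rhs_scale) * (float)acc;  // scvtf, fmul, fmla
    //   dst_value = fminf(fmaxf(out, scalar_min), scalar_max);    // fclamp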
146 | |||
147 | 162 | __asm__ volatile( | |
148 | |||
149 | // Setup | ||
150 | " .inst 0xd503477f // smstart \n" | ||
151 | " mov x11, #0 \n" | ||
152 | " mov x15, %[n] \n" | ||
153 | " mov x19, %[nr] \n" | ||
154 | " mov x21, %[lhs_packed] \n" | ||
155 | " mov x20, %[lut] \n" | ||
156 | " .inst 0xe11f8280 // ldr zt0, [x20] \n" | ||
157 | " ptrue p0.b \n" | ||
158 | " .inst 0x25207810 // ptrue pn8.b \n" | ||
159 | // predicate to load nr words for the RHS sums and scaling factors (should be exactly all true) | ||
160 | " .inst 0x25b36571 // whilelt pn9.s, x11, x19, vlx4 \n" | ||
161 | " dup z30.s, %w[scalar_min] \n" | ||
162 | " dup z31.s, %w[scalar_max] \n" | ||
163 | |||
164 | // lhs matrix row loop | ||
165 | "1: \n" | ||
166 | // Reset rhs matrix ptr | ||
167 | " mov x26, %[rhs_packed] \n" | ||
168 | // Reset dst_ptr to dst of next GEMV result | ||
169 | " mov x27, %[dst_ptr] \n" | ||
170 | // Reset n index | ||
171 | " mov x22, #0 \n" | ||
172 | // whilelt pn12.s, x22, %[n], vlx4 | ||
173 | " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n" | ||
174 | |||
175 | // rhs matrix row loop (transposed so theoretical columns) | ||
176 | "2: \n" | ||
177 | |||
178 | // Reset rhs block ptr to start of row | ||
179 | " mov x24, x26 \n" | ||
180 | " add x25, x26, %[rhs_row_bytes] \n" | ||
181 | " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n" | ||
182 | " addvl x28, x24, #4 \n" | ||
183 | " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n" | ||
184 | " mov x23, #0 \n" | ||
185 | " whilelt p1.b, x23, %[k_internal] \n" | ||
186 | // Zero for sdot accumulation in inner loop | ||
187 | " .inst 0xc00800ff // zero {za} \n" | ||
188 | |||
189 | // before k loop | ||
190 | "3: \n" | ||
191 | |||
192 | // Load lhs | ||
193 | " ld1rqb { z0.b }, p1/z , [x21, x23] \n" | ||
194 | |||
195 | // Load w | ||
196 | " .inst 0xa0408b10 // ld1b { z16.b - z19.b }, pn10/z, [x24] \n" | ||
197 | " .inst 0xa0418f14 // ld1b {z20.b-z23.b}, pn11/z, [x24,#0x4, mul vl]\n" | ||
198 | |||
199 | // rhs i4 to i8 and sdot | ||
200 | // k block + 0 | ||
201 | " .inst 0xc08a4218 // luti4 { z24.b, z25.b }, zt0, z16[0] \n" | ||
202 | " .inst 0xc08a423a // luti4 { z26.b, z27.b }, zt0, z17[0] \n" | ||
203 | " .inst 0xc150f320 // sdot za.s[w11,0, vgx4], {z24.b-z27.b}, z0.b[0]\n" | ||
204 | // k block + 1 | ||
205 | " .inst 0xc08a4244 // luti4 { z4.b, z5.b }, zt0, z18[0] \n" | ||
206 | " .inst 0xc08a4266 // luti4 { z6.b, z7.b }, zt0, z19[0] \n" | ||
207 | " .inst 0xc150f4a0 // sdot za.s[w11,0, vgx4], {z4.b-z7.b}, z0.b[1] \n" | ||
208 | // k block + 2 | ||
209 | " .inst 0xc08a4288 // luti4 { z8.b, z9.b }, zt0, z20[0] \n" | ||
210 | " .inst 0xc08a42aa // luti4 { z10.b, z11.b }, zt0, z21[0] \n" | ||
211 | " .inst 0xc150f920 // sdot za.s[w11,0, vgx4], {z8.b-z11.b}, z0.b[2] \n" | ||
212 | // k block + 3 | ||
213 | " .inst 0xc08a42cc // luti4 { z12.b, z13.b }, zt0, z22[0] \n" | ||
214 | " .inst 0xc08a42ee // luti4 { z14.b, z15.b }, zt0, z23[0] \n" | ||
215 | " .inst 0xc150fda0 // sdot za.s[w11,0, vgx4], {z12.b-z15.b}, z0.b[3]\n" | ||
216 | |||
217 | // End K block loop | ||
218 | " addvl x24, x24, #8 \n" | ||
219 | " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n" | ||
220 | " addvl x28, x24, #4 \n" | ||
221 | " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n" | ||
222 | " add x23, x23, #16 \n" | ||
223 | " whilelt p1.b, x23, %[k_internal] \n" | ||
224 | " b.first 3b \n" | ||
225 | |||
        // Finish off the accumulators with scaling factors and zero points
227 | |||
228 | // Load lhs zero point | ||
229 | " add x28, x21, %[k_internal] \n" | ||
230 | " ld1rw { z2.s }, p0/z , [x28] \n" | ||
231 | // Load lhs scaling factor | ||
232 | " ld1rw { z3.s }, p0/z , [x28, #4] \n" | ||
233 | // Load rhs sums | ||
234 | " add x28, x26, %[rhs_row_bytes] \n" | ||
235 | " .inst 0xa040c794 // ld1w { z20.s - z23.s }, pn9/z, [x28] \n" | ||
236 | // Load rhs scaling factors | ||
237 | " .inst 0xa041c798 // ld1w {z24.s-z27.s}, pn9/z, [x28, #0x4, mul vl]\n" | ||
238 | // Load biases | ||
239 | " .inst 0xa042c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, #0x8, mul vl]\n" | ||
240 | |||
241 | // Get accumulated value out of ZA | ||
242 | " .inst 0xc0066c04 // mov { z4.d - z7.d }, za.d[w11, 0, vgx4] \n" | ||
243 | |||
        // ZA holds the a * w accumulators; still add lhs zero point * rhs column sums (z * wsum) via mla
        // zero point * rhs row sum
        " mla z4.s, p0/m, z20.s, z2.s \n"
        " mla z5.s, p0/m, z21.s, z2.s \n"
        " mla z6.s, p0/m, z22.s, z2.s \n"
        " mla z7.s, p0/m, z23.s, z2.s \n"

        // Convert to float
        " .inst 0xc132e084 // scvtf { z4.s - z7.s }, { z4.s - z7.s } \n"

        // lhs scaling factor * rhs scaling factor
        " fmul z24.s, z24.s, z3.s \n"
        " fmul z25.s, z25.s, z3.s \n"
        " fmul z26.s, z26.s, z3.s \n"
        " fmul z27.s, z27.s, z3.s \n"

        // Bias + combined scaling factor * combined accumulator
        " fmla z12.s, p0/m, z24.s, z4.s \n"
        " fmla z13.s, p0/m, z25.s, z5.s \n"
        " fmla z14.s, p0/m, z26.s, z6.s \n"
        " fmla z15.s, p0/m, z27.s, z7.s \n"

        // Clamp
        " .inst 0xc1bfcbcc // fclamp { z12.s - z15.s }, z30.s, z31.s \n"

        // Store
        " .inst 0xa036d36c // st1w {z12.s-z15.s}, pn12, [x27, x22, lsl #2] \n"

        // End rhs row loop
        " add x26, x26, %[rhs_stride] \n"
        // nr == SVL in bytes, so addvl #1 advances the n index by nr
275 | " addvl x22, x22, #1 \n" | ||
276 | // whilelt pn12.s, x22, %[n], vlx4 | ||
277 | " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n" | ||
278 | " b.lt 2b \n" | ||
279 | |||
280 | // End lhs row loop | ||
281 | " add %[dst_ptr], %[dst_ptr], %[dst_stride_row] \n" | ||
282 | " add x21, x21, %[lhs_stride] \n" | ||
283 | " cmp x21, %[lhs_end_ptr] \n" | ||
284 | " b.lt 1b \n" | ||
285 | |||
286 | " .inst 0xd503467f // smstop \n" | ||
287 | |||
288 | : [dst_ptr] "+r"(dst) | ||
289 | 81 | : [lut] "r"(lut), [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_packed] "r"(lhs_packed), | |
290 | 81 | [rhs_packed] "r"(rhs_packed), [dst_stride_row] "r"(dst_stride_row), [scalar_min] "r"(scalar_min), | |
291 | 81 | [scalar_max] "r"(scalar_max), [k_internal] "r"(k_internal), [lhs_stride] "r"(lhs_stride), | |
292 | 81 | [rhs_stride] "r"(rhs_stride), [nr] "r"(nr), [rhs_row_bytes] "r"(rhs_row_bytes), [lhs_end_ptr] "r"(lhs_end_ptr) | |
293 | : "x11", "x15", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "p0", "p1", "p8", "p9", | ||
294 | "p10", "p11", "p12", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", | ||
295 | "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", | ||
296 | "z29", "z30", "z31", | ||
297 | #ifdef __ARM_STATE_ZA | ||
298 | "za", | ||
299 | #endif | ||
300 | #ifdef __ARM_STATE_ZT0 | ||
301 | "zt0", | ||
302 | #endif | ||
303 | "memory", "cc"); | ||
}
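
// Illustrative usage sketch (an assumption about the typical call shape, not part of this file):
// lhs_packed and rhs_packed are assumed to come from the matching KleidiAI packing routines for this
// kernel, dst is an m x n row-major float buffer, and the clamp bounds are left wide open.
//
//     kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(
//         m, n, k, lhs_packed, rhs_packed, dst,
//         n * sizeof(float),   // dst_stride_row
//         sizeof(float),       // dst_stride_col (must be sizeof(float))
//         -FLT_MAX, FLT_MAX);  // needs <float.h>
//
// When splitting work over n, kai_get_rhs_packed_offset_... and kai_get_dst_offset_... above give the
// per-tile offsets into rhs_packed and dst.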
305 | |||
306 | #endif // Architectural features check. | ||
307 |