Line |
Branch |
Exec |
Source |
1 |
|
|
// |
2 |
|
|
// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> |
3 |
|
|
// |
4 |
|
|
// SPDX-License-Identifier: Apache-2.0 |
5 |
|
|
// |
6 |
|
|
|
7 |
|
|
// Do not flag up inline assembly blocks |
8 |
|
|
#pragma GCC diagnostic ignored "-Woverlength-strings" |
9 |
|
|
|
10 |
|
|
#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) |
11 |
|
|
#error This file must be compiled for AArch64, FEAT_SVE2. |
12 |
|
|
#else // Architectural feature check |
13 |
|
|
|
14 |
|
|
#include "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h" |
15 |
|
|
|
16 |
|
|
#include <stddef.h> |
17 |
|
|
#include <stdint.h> |
18 |
|
|
|
19 |
|
|
#include "kai/kai_common.h" |
20 |
|
|
|
21 |
|
|
// Packing/blocking parameters for this micro-kernel. kai_mr and kai_nr are
// expressed in multiples of the SME vector length (scaled at runtime via
// kai_get_sme_vector_length_u32()).
static const size_t kai_mr = 1;  // multiple of vector length
static const size_t kai_nr = 4;  // multiple of vector length
static const size_t kai_kr = 4;
static const size_t kai_sr = 1;
// Sizes of the per-row metadata fields stored in the packed LHS/RHS buffers
// (quantization scales, LHS zero-point offsets, RHS row sums and biases).
static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
static const size_t kai_num_bytes_bias_rhs = sizeof(float);
// K is always processed in whole chunks of this size; k is rounded up to it.
static const size_t kai_k_multiple_of = 32;

/// Lut to be indexed by i4 resulting in its value in i8 (i.e. -2 = 1110 -> 1111 1110).
/// Each entry occupies 4 bytes (value followed by 3 zero bytes) as required by
/// the LUTI4 table layout loaded into ZT0.
static const int8_t lut[64] = {0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0,
                               0, 0, 6, 0, 0, 0, 7, 0, 0, 0, -8, 0, 0, 0, -7, 0, 0, 0, -6, 0, 0, 0,
                               -5, 0, 0, 0, -4, 0, 0, 0, -3, 0, 0, 0, -2, 0, 0, 0, -1, 0, 0, 0};
36 |
|
|
|
37 |
|
483 |
inline static size_t kai_k_roundedup(size_t k) { |
38 |
|
|
// Round up k to be a multiple of 32. |
39 |
|
483 |
return kai_roundup(k, kai_k_multiple_of); |
40 |
|
|
} |
41 |
|
|
|
42 |
|
201 |
inline static size_t kai_get_lhs_packed_stride(size_t k) { |
43 |
|
201 |
const size_t k_internal = kai_k_roundedup(k); |
44 |
|
|
|
45 |
|
− |
KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); |
46 |
|
|
|
47 |
|
603 |
return kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa() * |
48 |
|
201 |
(k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); |
49 |
|
201 |
} |
50 |
|
|
|
51 |
|
201 |
inline static size_t kai_get_rhs_packed_stride(size_t k) { |
52 |
|
201 |
const size_t k_internal = kai_k_roundedup(k); |
53 |
|
|
|
54 |
|
− |
KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); |
55 |
|
|
|
56 |
|
603 |
return kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa() * |
57 |
|
201 |
((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs); |
58 |
|
201 |
} |
59 |
|
|
|
60 |
|
360 |
size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) { |
61 |
|
360 |
return kai_mr * kai_get_sme_vector_length_u32(); |
62 |
|
|
} |
63 |
|
|
|
64 |
|
360 |
size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) { |
65 |
|
360 |
return kai_nr * kai_get_sme_vector_length_u32(); |
66 |
|
|
} |
67 |
|
|
|
68 |
|
522 |
size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) { |
69 |
|
522 |
return kai_mr * kai_get_sme_vector_length_u32(); |
70 |
|
|
} |
71 |
|
|
|
72 |
|
522 |
size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) { |
73 |
|
522 |
return kai_nr * kai_get_sme_vector_length_u32(); |
74 |
|
|
} |
75 |
|
|
|
76 |
|
160 |
// Packing parameter kr expected by the LHS/RHS packing routines.
size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
    return kai_kr;
}
79 |
|
|
|
80 |
|
160 |
// Packing parameter sr expected by the LHS/RHS packing routines.
size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
    return kai_sr;
}
83 |
|
|
|
84 |
|
120 |
// Byte offset into the packed LHS of the block containing row m_idx.
// m_idx must be a multiple of the M step so it lands on a block boundary.
size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) {
    KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);

    const size_t block_idx = m_idx / kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
    return block_idx * kai_get_lhs_packed_stride(k);
}
91 |
|
|
|
92 |
|
120 |
// Byte offset into the packed RHS of the block containing column n_idx.
// n_idx must be a multiple of the N step so it lands on a block boundary.
size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k) {
    KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);

    const size_t block_idx = n_idx / kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
    return block_idx * kai_get_rhs_packed_stride(k);
}
99 |
|
|
|
100 |
|
80 |
// Byte offset of element (m_idx, n_idx) in a row-major f32 destination with
// row stride dst_stride (in bytes). Indices must be step-aligned.
size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
    size_t m_idx, size_t n_idx, size_t dst_stride) {
    KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
    KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);

    const size_t row_offset = m_idx * dst_stride;
    const size_t col_offset = n_idx * sizeof(float);
    return row_offset + col_offset;
}
107 |
|
|
|
108 |
|
80 |
// Total byte size of an m x n f32 destination buffer.
size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m, size_t n) {
    const size_t num_elements = m * n;
    return num_elements * sizeof(float);
}
111 |
|
|
|
112 |
|
81 |
void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa( |
113 |
|
|
size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, |
114 |
|
|
float* restrict dst, // NOLINT(readability-non-const-parameter) |
115 |
|
|
size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { |
116 |
|
− |
KAI_ASSERT(dst_stride_col == sizeof(float)); |
117 |
|
− |
KAI_ASSERT(n > 0); |
118 |
|
− |
KAI_ASSERT(m > 0); |
119 |
|
|
|
120 |
|
|
// Constants |
121 |
|
81 |
uint64_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(); |
122 |
|
81 |
uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(); |
123 |
|
81 |
uint64_t lhs_stride = kai_get_lhs_packed_stride(k); |
124 |
|
81 |
uint64_t rhs_stride = kai_get_rhs_packed_stride(k); |
125 |
|
81 |
uint64_t m_blk = (uint64_t)kai_k_roundedup(k) * mr; |
126 |
|
81 |
uint64_t dst_inc = mr * dst_stride_row; |
127 |
|
81 |
float scalar_bounds[2] = {scalar_min, scalar_max}; |
128 |
|
|
|
129 |
|
|
/* --------------------------------------------------- |
130 |
|
|
Registers allocations |
131 |
|
|
x7: Look up table(lut) |
132 |
|
|
x8: RHS base address (rhs) |
133 |
|
|
x9: Destination base address (dst) |
134 |
|
|
x10: LHS pointer (lhs) |
135 |
|
|
x11: RHS pointer (rhs) |
136 |
|
|
x12: Remaining M elements |
137 |
|
|
x13: Remaining N elements |
138 |
|
|
x14: k exit condition (k_cond) |
139 |
|
|
ZA tile index (l_idx) |
140 |
|
|
x15: LHS scaling factor pointer (lhs_sf_ptr) |
141 |
|
|
x16: ZA tile exit condition (l_cnd) |
142 |
|
|
x17: Destination pointer (dst) |
143 |
|
|
x19: Destination outer address (dst) |
144 |
|
|
x20: LHS base address (lhs) |
145 |
|
|
--------------------------------------------------- */ |
146 |
|
81 |
__asm__ volatile( |
147 |
|
|
" .inst 0xd503477f //smstart \n" |
148 |
|
|
" mov x19, %[dst] \n" |
149 |
|
|
" mov x20, %[lhs] \n" |
150 |
|
|
" mov x7, %[lut] \n" |
151 |
|
|
" .inst 0xe11f80e0 //ldr zt0, [x7] \n" |
152 |
|
|
" cntw x7 \n" |
153 |
|
|
" ptrue p2.b \n" |
154 |
|
|
" ld1rw {z30.s}, p2/Z, [%[scalar_bounds]] \n" |
155 |
|
|
" ld1rw {z31.s}, p2/Z, [%[scalar_bounds], #4] \n" |
156 |
|
|
|
157 |
|
|
// M loop head |
158 |
|
|
" mov x12, %[m] \n" |
159 |
|
|
" .inst 0x25ac17e0 //whilelt p0.s, xzr, x12 \n" |
160 |
|
|
"1: \n" |
161 |
|
|
" mov x8, %[rhs] \n" |
162 |
|
|
" mov x9, x19 \n" |
163 |
|
|
" mov x13, %[n] \n" |
164 |
|
|
" cmp x7, x12 \n" |
165 |
|
|
" csel x16, x7, x12, lt \n" |
166 |
|
|
" lsl x16, x16, #2 \n" |
167 |
|
|
|
168 |
|
|
// N loop head |
169 |
|
|
" .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n" |
170 |
|
|
"2: \n" |
171 |
|
|
" mov x10, x20 \n" |
172 |
|
|
" mov x11, x8 \n" |
173 |
|
|
" mov x17, x9 \n" |
174 |
|
|
" .inst 0x25ad67f1 //whilelt pn9.s, xzr, x13, vlx4 \n" |
175 |
|
|
|
176 |
|
|
// K loop |
177 |
|
|
" .inst 0xc00800ff //zero {za} \n" |
178 |
|
|
" add x14, x10, %[m_blk] \n" |
179 |
|
|
"3: \n" |
180 |
|
|
" .inst 0xa540a144 //ld1w { z4.s }, p0/z, [x10] \n" |
181 |
|
|
" .inst 0x042a502a //addvl x10, x10, #1 \n" |
182 |
|
|
" .inst 0xa0402160 //ld1h { z0.h-z1.h }, pn8/z, [x11] \n" |
183 |
|
|
" .inst 0x042b504b //addvl x11, x11, #2 \n" |
184 |
|
|
" .inst 0xc08a4008 //luti4 { z8.b - z9.b }, zt0, z0[0] \n" |
185 |
|
|
" .inst 0xc08a402a //luti4 { z10.b - z11.b }, zt0, z1[0] \n" |
186 |
|
|
" .inst 0xa0884880 //smopa za0.s, p2/m, p2/m, z4.b, z8.b \n" |
187 |
|
|
" .inst 0xa0894881 //smopa za1.s, p2/m, p2/m, z4.b, z9.b \n" |
188 |
|
|
" .inst 0xa08a4882 //smopa za2.s, p2/m, p2/m, z4.b, z10.b\n" |
189 |
|
|
" .inst 0xa08b4883 //smopa za3.s, p2/m, p2/m, z4.b, z11.b\n" |
190 |
|
|
" cmp x10, x14 \n" |
191 |
|
|
" b.lt 3b \n" |
192 |
|
|
|
193 |
|
|
// RHS row sum, scale factor & bias |
194 |
|
|
" .inst 0xa040c560 //ld1w { z0.s-z3.s }, pn9/z, [x11] \n" |
195 |
|
|
" .inst 0xa041c564 //ld1w { z4.s-z7.s }, pn9/z, [x11, #4, mul vl] \n" |
196 |
|
|
" .inst 0xa042c568 //ld1w { z8.s-z11.s }, pn9/z, [x11, #8, mul vl]\n" |
197 |
|
|
" .inst 0x042b518b //addvl x11, x11, #12 \n" |
198 |
|
|
" .inst 0xc132e000 //scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" |
199 |
|
|
|
200 |
|
|
// Store loop |
201 |
|
|
" mov x14, #0 \n" |
202 |
|
|
" addvl x15, x10, #1 \n" |
203 |
|
|
"4: \n" |
204 |
|
|
// Load LHS Row-offset & SF |
205 |
|
|
" ld1rw {z16.s}, p2/z, [x10] \n" |
206 |
|
|
" ld1rw {z17.s}, p2/z, [x15] \n" |
207 |
|
|
" add x10, x10, #4 \n" |
208 |
|
|
" add x15, x15, #4 \n" |
209 |
|
|
" scvtf z16.s, p2/m, z16.s \n" |
210 |
|
|
|
211 |
|
|
// offset x Row-sum |
212 |
|
|
" fmul z24.s, z16.s, z0.s \n" |
213 |
|
|
" fmul z25.s, z16.s, z1.s \n" |
214 |
|
|
" fmul z26.s, z16.s, z2.s \n" |
215 |
|
|
" fmul z27.s, z16.s, z3.s \n" |
216 |
|
|
|
217 |
|
|
// Scaling factors |
218 |
|
|
" fmul z20.s, z17.s, z4.s \n" |
219 |
|
|
" fmul z21.s, z17.s, z5.s \n" |
220 |
|
|
" fmul z22.s, z17.s, z6.s \n" |
221 |
|
|
" fmul z23.s, z17.s, z7.s \n" |
222 |
|
|
|
223 |
|
|
// Result = offset x Row-sum x SFs |
224 |
|
|
" fmul z24.s, z24.s, z20.s \n" |
225 |
|
|
" fmul z25.s, z25.s, z21.s \n" |
226 |
|
|
" fmul z26.s, z26.s, z22.s \n" |
227 |
|
|
" fmul z27.s, z27.s, z23.s \n" |
228 |
|
|
|
229 |
|
|
// Load inner accumulation & convert |
230 |
|
|
" .inst 0xc006440c //mova { z12.b-z15.b }, za0h.b[w14, 0:3]\n" |
231 |
|
|
" .inst 0xc132e18c //scvtf { z12.s-z15.s }, { z12.s-z15.s } \n" |
232 |
|
|
|
233 |
|
|
// Result += iacc x SF |
234 |
|
|
" fmla z24.s, p2/m, z20.s, z12.s \n" |
235 |
|
|
" fmla z25.s, p2/m, z21.s, z13.s \n" |
236 |
|
|
" fmla z26.s, p2/m, z22.s, z14.s \n" |
237 |
|
|
" fmla z27.s, p2/m, z23.s, z15.s \n" |
238 |
|
|
|
239 |
|
|
// Add the bias |
240 |
|
|
" fadd z24.s, p2/m, z24.s, z8.s \n" |
241 |
|
|
" fadd z25.s, p2/m, z25.s, z9.s \n" |
242 |
|
|
" fadd z26.s, p2/m, z26.s, z10.s \n" |
243 |
|
|
" fadd z27.s, p2/m, z27.s, z11.s \n" |
244 |
|
|
|
245 |
|
|
// CLAMP and store |
246 |
|
|
" .inst 0xc1bfcbd8 //fclamp { z24.s-z27.s }, z30.s, z31.s\n" |
247 |
|
|
" .inst 0xa060c638 //st1w { z24.s-z27.s }, pn9, [x17] \n" |
248 |
|
|
|
249 |
|
|
" add x17, x17, %[dst_stride_row] \n" |
250 |
|
|
" add x14, x14, #4 \n" |
251 |
|
|
" cmp x14, x16 \n" |
252 |
|
|
" b.lt 4b \n" |
253 |
|
|
|
254 |
|
|
// N loop tail |
255 |
|
|
" add x8, x8, %[rhs_stride] \n" |
256 |
|
|
" .inst 0x04295089 // addvl x9, x9, #4 \n" |
257 |
|
|
" sub x13, x13, %[nr] \n" |
258 |
|
|
" .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n" |
259 |
|
|
" b.mi 2b \n" |
260 |
|
|
|
261 |
|
|
// M loop tail |
262 |
|
|
" add x20, x20, %[lhs_stride] \n" |
263 |
|
|
" add x19, x19, %[dst_inc] \n" |
264 |
|
|
" sub x12, x12, %[mr] \n" |
265 |
|
|
" whilelt p0.s, xzr, x12 \n" |
266 |
|
|
" b.mi 1b \n" |
267 |
|
|
|
268 |
|
|
"5: \n" |
269 |
|
|
" .inst 0xd503467f //smstop \n" |
270 |
|
|
: |
271 |
|
81 |
: [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_stride] "r"(lhs_stride), [rhs_stride] "r"(rhs_stride), |
272 |
|
81 |
[dst_stride_row] "r"(dst_stride_row), [lut] "r"(lut), [m_blk] "r"(m_blk), [nr] "r"(nr), [mr] "r"(mr), |
273 |
|
81 |
[lhs] "r"(lhs_packed), [rhs] "r"(rhs_packed), [dst_inc] "r"(dst_inc), [scalar_bounds] "r"(scalar_bounds), |
274 |
|
81 |
[dst] "r"(dst) |
275 |
|
|
: "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "p0", "p2", "p8", |
276 |
|
|
"p9", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", |
277 |
|
|
"z16", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z30", "z31", |
278 |
|
|
#ifdef __ARM_STATE_ZA |
279 |
|
|
"za", |
280 |
|
|
#endif |
281 |
|
|
#ifdef __ARM_STATE_ZT0 |
282 |
|
|
"zt0", |
283 |
|
|
#endif |
284 |
|
|
"cc", "memory"); |
285 |
|
81 |
} |
286 |
|
|
|
287 |
|
|
#endif // Architectural feature check |
288 |
|
|
|