kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #if !defined(__aarch64__) || !(defined(__ARM_FEATURE_SVE2) || defined(__ARM_FEATURE_SME2)) | ||
| 8 | #error This file must be compiled for AArch64, FEAT_SVE2 or FEAT_SME2. | ||
| 9 | #else // Architectural features check. | ||
| 10 | |||
| 11 | #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" | ||
| 12 | |||
| 13 | #include <stddef.h> | ||
| 14 | #include <stdint.h> | ||
| 15 | |||
| 16 | #include "kai/kai_common.h" | ||
| 17 | |||
| 18 | // Compute args | ||
| 19 | static const size_t kai_m_step = 1; // Multiple of vector length | ||
| 20 | static const size_t kai_n_step = 4; // Multiple of vector length | ||
| 21 | // Packing args | ||
| 22 | static const size_t kai_mr = 1; // Multiple of vector length | ||
| 23 | static const size_t kai_nr = 4; // Multiple of vector length | ||
| 24 | static const size_t kai_kr = 4; | ||
| 25 | static const size_t kai_sr = 2; | ||
| 26 | // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) | ||
| 27 | static const size_t kai_num_bytes_qvalue_lhs = 1; | ||
| 28 | static const size_t kai_num_bytes_multiplier_lhs = 2; | ||
| 29 | // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is | ||
| 30 | // asymmetric)) | ||
| 31 | static const size_t kai_recip_num_bytes_qvalue_rhs = 2; | ||
| 32 | static const size_t kai_num_bytes_multiplier_rhs = 2; | ||
| 33 | // DST format args | ||
| 34 | static const size_t kai_num_bytes_dst_value = 4; | ||
| 35 | // Extra args | ||
| 36 | static const size_t kai_bl = 32; | ||
| 37 | |||
| 38 | // Look-up table used for int4->int8 convert | ||
| 39 | static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7}; | ||
| 40 | |||
| 41 | 98 | inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) { | |
| 42 | 98 | return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs; | |
| 43 | } | ||
| 44 | |||
| 45 | 98 | inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) { | |
| 46 | − | KAI_ASSUME(bl == kai_bl); | |
| 47 | 98 | size_t num_bytes_per_block_rhs = (bl / kai_recip_num_bytes_qvalue_rhs) + kai_num_bytes_multiplier_rhs; | |
| 48 | 196 | return num_bytes_per_block_rhs; | |
| 49 | 98 | } | |
| 50 | |||
| 51 | 230 | inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { | |
| 52 | − | KAI_ASSUME(bl == kai_bl); | |
| 53 | − | KAI_ASSUME((k % kai_bl) == 0); | |
| 54 | |||
| 55 | 230 | return kai_roundup(k, bl) / bl; | |
| 56 | } | ||
| 57 | |||
| 58 | 98 | inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) { | |
| 59 | 98 | const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
| 60 | 196 | return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl); | |
| 61 | 98 | } | |
| 62 | |||
| 63 | 98 | inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) { | |
| 64 | − | KAI_ASSUME(bl == kai_bl); | |
| 65 | − | KAI_ASSUME((k % kai_bl) == 0); | |
| 66 | |||
| 67 | 98 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
| 68 | 98 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl); | |
| 69 | 98 | const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
| 70 | |||
| 71 | 98 | size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row); | |
| 72 | |||
| 73 | 196 | return rhs_packed_stride; | |
| 74 | 98 | } | |
| 75 | |||
| 76 | 192 | size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
| 77 | 192 | return kai_m_step * kai_get_sme_vector_length_u32(); | |
| 78 | } | ||
| 79 | |||
| 80 | 192 | size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
| 81 | 192 | return kai_n_step * kai_get_sme_vector_length_u32(); | |
| 82 | } | ||
| 83 | |||
| 84 | 260 | size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
| 85 | 260 | return kai_mr * kai_get_sme_vector_length_u32(); | |
| 86 | } | ||
| 87 | |||
| 88 | 260 | size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
| 89 | 260 | return kai_nr * kai_get_sme_vector_length_u32(); | |
| 90 | } | ||
| 91 | |||
| 92 | 96 | size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
| 93 | 96 | return kai_kr; | |
| 94 | } | ||
| 95 | |||
| 96 | 64 | size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) { | |
| 97 | 64 | return kai_sr; | |
| 98 | } | ||
| 99 | |||
| 100 | 64 | size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa( | |
| 101 | size_t m_idx, size_t k, size_t bl) { | ||
| 102 | 64 | const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
| 103 | 64 | const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
| 104 | − | KAI_ASSUME((m_idx % m_step) == 0); | |
| 105 | |||
| 106 | 128 | return (m_idx / mr) * kai_get_lhs_packed_stride(k, bl); | |
| 107 | 64 | } | |
| 108 | |||
| 109 | 64 | size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa( | |
| 110 | size_t n_idx, size_t k, size_t bl) { | ||
| 111 | 64 | const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
| 112 | 64 | const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
| 113 | |||
| 114 | − | KAI_ASSUME((n_idx % n_step) == 0); | |
| 115 | |||
| 116 | 128 | return (n_idx / nr) * kai_get_rhs_packed_stride(k, bl); | |
| 117 | 64 | } | |
| 118 | |||
| 119 | 32 | size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa( | |
| 120 | size_t m_idx, size_t n_idx, size_t dst_stride) { | ||
| 121 | 32 | const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
| 122 | 32 | const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
| 123 | − | KAI_ASSUME((m_idx % m_step) == 0); | |
| 124 | − | KAI_ASSUME((n_idx % n_step) == 0); | |
| 125 | |||
| 126 | 64 | return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; | |
| 127 | 32 | } | |
| 128 | |||
| 129 | 32 | size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(size_t m, size_t n) { | |
| 130 | 32 | return m * n * kai_num_bytes_dst_value; | |
| 131 | } | ||
| 132 | |||
| 133 | 34 | void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa( | |
| 134 | size_t m, // | ||
| 135 | size_t n, // | ||
| 136 | size_t k, // | ||
| 137 | size_t bl, // | ||
| 138 | const void* restrict lhs_packed, // | ||
| 139 | const void* restrict rhs_packed, // | ||
| 140 | float* restrict dst, // NOLINT(readability-non-const-parameter) | ||
| 141 | size_t dst_stride_row, // | ||
| 142 | size_t dst_stride_col, // | ||
| 143 | float scalar_min, // | ||
| 144 | float scalar_max) { | ||
| 145 | − | KAI_ASSUME(dst_stride_col == sizeof(float)); | |
| 146 | |||
| 147 | 34 | KAI_UNUSED(scalar_min); | |
| 148 | 34 | KAI_UNUSED(scalar_max); | |
| 149 | |||
| 150 |
1/2✓ Branch 0 taken 34 times.
✗ Branch 1 not taken.
|
34 | if (m == 0) { |
| 151 | ✗ | return; | |
| 152 | } | ||
| 153 | |||
| 154 | typedef struct { | ||
| 155 | size_t lhs_packed_stride; | ||
| 156 | size_t rhs_packed_stride; | ||
| 157 | size_t mr; | ||
| 158 | } KernelArgs; | ||
| 159 | |||
| 160 | 34 | KernelArgs ka; | |
| 161 | |||
| 162 | 34 | const size_t num_blocks = kai_get_num_blocks_per_row(k, bl); | |
| 163 | |||
| 164 | 34 | const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
| 165 | 34 | const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); | |
| 166 | |||
| 167 | 34 | ka.mr = mr; | |
| 168 | 34 | ka.lhs_packed_stride = kai_get_lhs_packed_stride(k, bl); | |
| 169 | 34 | ka.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl); | |
| 170 | |||
| 171 | 68 | const uint16_t* lhs_scales = (const uint16_t*)((const int8_t*)lhs_packed + ka.lhs_packed_stride - | |
| 172 | 34 | (mr * num_blocks) * kai_num_bytes_multiplier_lhs); | |
| 173 | 68 | const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + ka.rhs_packed_stride - | |
| 174 | 34 | (nr * num_blocks) * kai_num_bytes_multiplier_rhs); | |
| 175 | |||
| 176 | 34 | kai_commit_za(); | |
| 177 | |||
| 178 | 68 | __asm__ volatile( | |
| 179 | // Switch to streaming mode with ZA enabling | ||
| 180 | " .inst 0xd503477f // smstart \n" | ||
| 181 | |||
| 182 | // Constants | ||
| 183 | // - SVLs | ||
| 184 | " cntw x14 \n" | ||
| 185 | // - ptrue | ||
| 186 | " ptrue p0.b, all \n" | ||
| 187 | " .inst 0x25a07810 // ptrue pn8.s \n" | ||
| 188 | |||
| 189 | // Predicate for loading fp16 scaling factors | ||
| 190 | " ldr x5, [%x[args_ptr], %[offset_mr]]\n" | ||
| 191 | " lsl x5, x5, #1 \n" | ||
| 192 | " whilelt p4.b, xzr, x5 \n" | ||
| 193 | |||
| 194 | // Initialize ZT0 (Lookup table) | ||
| 195 | " mov x6, %[lut]\n" | ||
| 196 | " .inst 0xe11f80c0 // ldr zt0, [x6] \n" | ||
| 197 | |||
| 198 | // Initialize the RHS packed and scale pointers | ||
| 199 | " mov x16, %[rhs_packed] \n" | ||
| 200 | " mov x17, %[rhs_scales] \n" | ||
| 201 | |||
| 202 | // Iterate over n (x8) | ||
| 203 | // e.g. for(n_idx = 0; n_idx < n; n_idx+=n_step) | ||
| 204 | " mov x8, #0 \n" | ||
| 205 | " mov x0, %[N] \n" | ||
| 206 | " .inst 0x25a06511 // whilelt pn9.s, x8, x0, VLx4 \n" | ||
| 207 | |||
| 208 | " b.none 9f // .LOOP_N_END%= \n" | ||
| 209 | |||
| 210 | " 1: // .LOOP_N_START%=: \n" | ||
| 211 | |||
| 212 | // Iterate over m (x9) | ||
| 213 | // e.g. for(m_rem = m; m_rem > 0; m_rem -= m_step) | ||
| 214 | " mov x9, %[M] \n" | ||
| 215 | |||
| 216 | // Initialize the LHS packed and scale pointers | ||
| 217 | " mov x22, %[lhs_packed] \n" | ||
| 218 | " mov x23, %[lhs_scales] \n" | ||
| 219 | |||
| 220 | // Initialize the DST pointer | ||
| 221 | " mov x24, %[dst] \n" | ||
| 222 | |||
| 223 | " 2: // .LOOP_M_START%=: \n" | ||
| 224 | |||
| 225 | // Address offset for the left and right quantized values | ||
| 226 | " mov x20, #0 \n" | ||
| 227 | " mov x21, #0 \n" | ||
| 228 | |||
| 229 | // Number of output rows to store -> min(SVLs, remaining M rows) | ||
| 230 | " cmp x9, x14 \n" | ||
| 231 | " csel x15, x9, x14, lo \n" | ||
| 232 | " lsl x15, x15, #2 \n" | ||
| 233 | |||
| 234 | // Iterate over all K values | ||
| 235 | // e.g. for(k_idx = 0; k_idx < k; k_idx += bl) | ||
| 236 | " mov x10, %[K] \n" | ||
| 237 | |||
| 238 | // Skip processing if K=0 | ||
| 239 | " cmp x10, #0 \n" | ||
| 240 | " b.eq 8f // .LOOP_K_END%= \n" | ||
| 241 | |||
| 242 | " 3: // .LOOP_K_START%=: \n" | ||
| 243 | |||
| 244 | // Zeroing of ZA accumulator | ||
| 245 | " .inst 0xc00800ff // zero {za} \n" | ||
| 246 | |||
| 247 | // Load the fp16 scaling factors for the right matrix block | ||
| 248 | " .inst 0xa0154220 // ld1w {z0.s - z1.s}, pn8/z, [x17], x21, lsl #2] \n" | ||
| 249 | " .inst 0xc161d000 // zip {z0.h - z1.h}, z0.h, z1.h \n" | ||
| 250 | |||
| 251 | // Iterate over all values in the block | ||
| 252 | // k_blk_idx = bl | ||
| 253 | // e.g. while(k_blk_idx > 0) {... k_blk_idx -= 4} | ||
| 254 | " mov x11, #32\n" | ||
| 255 | |||
| 256 | " 4: // .LOOP_BL_START%=: \n" | ||
| 257 | |||
| 258 | // Load right matrix row | ||
| 259 | " .inst 0xa0144202 // ld1w {z2.s - z3.s}, pn8/z, [x16], x20, lsl #2] \n" | ||
| 260 | |||
| 261 | // Load left matrix column | ||
| 262 | " ld1h {z8.h}, p0/z, [x22, x20, lsl #1] \n" | ||
| 263 | " inch x20, all \n" | ||
| 264 | |||
| 265 | // Convert Int4 -> Int8 | ||
| 266 | " .inst 0xc08a4044 // luti4 {z4.b - z5.b}, zt0, z2[0] \n" | ||
| 267 | " .inst 0xc08a4066 // luti4 {z6.b - z7.b}, zt0, z3[0] \n" | ||
| 268 | |||
| 269 | // Outer-products | ||
| 270 | " .inst 0xa0840100 // smopa za0.s, p0/m, p0/m, z8.b, z4.b \n" | ||
| 271 | " .inst 0xa0850101 // smopa za1.s, p0/m, p0/m, z8.b, z5.b \n" | ||
| 272 | " .inst 0xa0860102 // smopa za2.s, p0/m, p0/m, z8.b, z6.b \n" | ||
| 273 | " .inst 0xa0870103 // smopa za3.s, p0/m, p0/m, z8.b, z7.b \n" | ||
| 274 | |||
| 275 | // Decrement the block loop index | ||
| 276 | " subs x11, x11, #4 \n" | ||
| 277 | |||
| 278 | " b.gt 4b // .LOOP_BL_START%= \n" | ||
| 279 | |||
| 280 | // === End of the block loop === | ||
| 281 | |||
| 282 | // Store loop index | ||
| 283 | " mov w12, #0 \n" | ||
| 284 | |||
| 285 | // Copy destination pointer for store loop | ||
| 286 | " mov x25, x24 \n" | ||
| 287 | |||
| 288 | // Load the fp16 scaling factors for the left matrix block | ||
| 289 | " ld1b {z16.b}, p4/z, [x23, x21] \n" | ||
| 290 | " inch x21, all \n" | ||
| 291 | |||
| 292 | // Predicate used to select one scaling factor from the vector | ||
| 293 | " pfalse p3.b \n" | ||
| 294 | |||
| 295 | " 5: // .LOOP_ZA%=: \n" | ||
| 296 | |||
| 297 | // Select and replicate scaling factor for the left block | ||
| 298 | " pnext p3.h, p0, p3.h \n" | ||
| 299 | " clastb z19.h, p3, z19.h, z16.h \n" | ||
| 300 | |||
| 301 | // Get data from za | ||
| 302 | " .inst 0xc006041c // mova {z28.b-z31.b}, za0h.b[w12, 0:3] \n" | ||
| 303 | " add w12, w12, #4 \n" | ||
| 304 | |||
| 305 | // Convert from int32 to fp32 | ||
| 306 | " .inst 0xc132e39c // scvtf {z28.s-z31.s}, {z28.s-z31.s} \n" | ||
| 307 | |||
| 308 | // Multiply left and right scaling factors | ||
| 309 | " movprfx z8, z18 \n" | ||
| 310 | " fmlalb z8.s, z19.h, z0.h \n" | ||
| 311 | " movprfx z9, z18 \n" | ||
| 312 | " fmlalb z9.s, z19.h, z1.h \n" | ||
| 313 | " movprfx z10, z18 \n" | ||
| 314 | " fmlalt z10.s, z19.h, z0.h \n" | ||
| 315 | " movprfx z11, z18 \n" | ||
| 316 | " fmlalt z11.s, z19.h, z1.h \n" | ||
| 317 | |||
| 318 | " cmp x10, %[K] \n" | ||
| 319 | " b.ne 6f // .ACCUMULATE%= \n" | ||
| 320 | |||
| 321 | // Applying combined scaling factors to processed block | ||
| 322 | " fmul z24.s, z8.s, z28.s \n" | ||
| 323 | " fmul z25.s, z9.s, z29.s \n" | ||
| 324 | " fmul z26.s, z10.s, z30.s \n" | ||
| 325 | " fmul z27.s, z11.s, z31.s \n" | ||
| 326 | |||
| 327 | "b 7f // .STORE%= \n" | ||
| 328 | |||
| 329 | " 6: // .ACCUMULATE%=: \n" | ||
| 330 | // Load intermediate result | ||
| 331 | " .inst 0xa040c738 // ld1w {z24.s-z27.s}, pn9/z, [x25] \n" | ||
| 332 | |||
| 333 | // Multiply the intermediate results by LHS_SCALE x RHS_SCALE | ||
| 334 | // and store in the main floating-point accumulator | ||
| 335 | " fmla z24.s, p0/m, z8.s, z28.s \n" | ||
| 336 | " fmla z25.s, p0/m, z9.s, z29.s \n" | ||
| 337 | " fmla z26.s, p0/m, z10.s, z30.s \n" | ||
| 338 | " fmla z27.s, p0/m, z11.s, z31.s \n" | ||
| 339 | |||
| 340 | "7: // .STORE%=: \n" | ||
| 341 | // Store the results into memory | ||
| 342 | " .inst 0xa060c738 // st1w {z24.s-z27.s}, pn9, [x25] \n" | ||
| 343 | " add x25, x25, %[stride] \n" | ||
| 344 | |||
| 345 | " cmp x12, x15 \n" | ||
| 346 | " blt 5b // .LOOP_ZA%= \n" | ||
| 347 | |||
| 348 | // Decrement K loop index by bl | ||
| 349 | " subs x10, x10, #32 \n" | ||
| 350 | |||
| 351 | " b.gt 3b // .LOOP_K_START%= \n" | ||
| 352 | |||
| 353 | " 8: // .LOOP_K_END%=: \n" | ||
| 354 | |||
| 355 | // === End of the K loop === | ||
| 356 | |||
| 357 | " ldr x5, [%x[args_ptr], %[offset_stride_l]] \n" | ||
| 358 | |||
| 359 | // Increment pointer to the quantized values of the left matrix | ||
| 360 | " add x22, x22, x5\n" | ||
| 361 | |||
| 362 | // Increment pointer to the scaling factors of the left matrix | ||
| 363 | " add x23, x23, x5 \n" | ||
| 364 | |||
| 365 | // Update destination pointer | ||
| 366 | " mov x24, x25 \n" | ||
| 367 | |||
| 368 | // Decrement M loop index | ||
| 369 | " decw x9, all \n" | ||
| 370 | |||
| 371 | " cmp x9, #0 \n" | ||
| 372 | " b.gt 2b // .LOOP_M_START%= \n" | ||
| 373 | |||
| 374 | // === End of M loop === | ||
| 375 | |||
| 376 | // Increment output pointer | ||
| 377 | " incb %[dst], all, mul #4 \n" | ||
| 378 | |||
| 379 | " ldr x5, [%x[args_ptr], %[offset_stride_r]]\n" | ||
| 380 | |||
| 381 | " add x16, x16, x5 \n" | ||
| 382 | " add x17, x17, x5 \n" | ||
| 383 | |||
| 384 | // Increment N loop index | ||
| 385 | " incb x8, all \n" | ||
| 386 | |||
| 387 | " .inst 0x25a06511 // whilelt pn9.s, x8, %[N], VLx4 \n" | ||
| 388 | |||
| 389 | " b.first 1b // .LOOP_N_START%= \n" | ||
| 390 | |||
| 391 | " 9: // .LOOP_N_END%=: \n" | ||
| 392 | |||
| 393 | // === End of N loop === | ||
| 394 | |||
| 395 | // Exit streaming mode | ||
| 396 | " .inst 0xd503467f // smstop \n" | ||
| 397 | : [dst] "+r"(dst), [rhs_packed] "+r"(rhs_packed), [rhs_scales] "+r"(rhs_scales) | ||
| 398 | 34 | : [M] "r"(m), [N] "r"(n), [K] "r"(k), [lhs_packed] "r"(lhs_packed), [lhs_scales] "r"(lhs_scales), | |
| 399 | 34 | [stride] "r"(dst_stride_row), [lut] "r"(lut), [args_ptr] "r"(&ka), | |
| 400 | [offset_stride_l] "I"(offsetof(KernelArgs, lhs_packed_stride)), | ||
| 401 | [offset_stride_r] "I"(offsetof(KernelArgs, rhs_packed_stride)), [offset_mr] "I"(offsetof(KernelArgs, mr)) | ||
| 402 | : "p0", "p1", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", "z1", | ||
| 403 | "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", | ||
| 404 | "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x5", "x6", | ||
| 405 | "x8", "x9", "x10", "x11", "x12", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", | ||
| 406 | "memory", "cc"); | ||
| 407 | 34 | } | |
| 408 | |||
| 409 | #endif // Architectural features check. | ||
| 410 |