kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | // Do not flag up inline assembly blocks | ||
| 8 | #pragma GCC diagnostic ignored "-Woverlength-strings" | ||
| 9 | |||
| 10 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) | ||
| 11 | #error This file must be compiled for AArch64, FEAT_SVE2. | ||
| 12 | #else // Architectural feature check | ||
| 13 | |||
| 14 | #include "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h" | ||
| 15 | |||
| 16 | #include <stddef.h> | ||
| 17 | #include <stdint.h> | ||
| 18 | |||
| 19 | #include "kai/kai_common.h" | ||
| 20 | |||
// Tile geometry. mr and nr are expressed as multiples of the vector length;
// the public getters scale them by kai_get_sme_vector_length_u32().
static const size_t kai_mr = 1;
static const size_t kai_nr = 4;
static const size_t kai_kr = 4;
static const size_t kai_sr = 1;

// K is rounded up to a multiple of this value before any stride is computed.
static const size_t kai_k_multiple_of = 32;

// Byte sizes of the metadata stored next to the quantized values in the packed buffers.
static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
static const size_t kai_num_bytes_bias_rhs = sizeof(float);

/// Look-up table indexed by a 4-bit value, yielding that value sign-extended to 8 bits
/// (e.g. -2 = 1110 -> 1111 1110).
/// The table holds 16 entries of 4 bytes each; only the first byte of every entry is non-zero.
static const int8_t lut[64] = {
    0,  0, 0, 0,  // 0000
    1,  0, 0, 0,  // 0001
    2,  0, 0, 0,  // 0010
    3,  0, 0, 0,  // 0011
    4,  0, 0, 0,  // 0100
    5,  0, 0, 0,  // 0101
    6,  0, 0, 0,  // 0110
    7,  0, 0, 0,  // 0111
    -8, 0, 0, 0,  // 1000
    -7, 0, 0, 0,  // 1001
    -6, 0, 0, 0,  // 1010
    -5, 0, 0, 0,  // 1011
    -4, 0, 0, 0,  // 1100
    -3, 0, 0, 0,  // 1101
    -2, 0, 0, 0,  // 1110
    -1, 0, 0, 0,  // 1111
};
| 36 | |||
| 37 | 2946 | inline static size_t kai_k_roundedup(size_t k) { | |
| 38 | // Round up k to be a multiple of 32. | ||
| 39 | 2946 | return kai_roundup(k, kai_k_multiple_of); | |
| 40 | } | ||
| 41 | |||
| 42 | 1202 | inline static size_t kai_get_lhs_packed_stride(size_t k) { | |
| 43 | 1202 | const size_t k_internal = kai_k_roundedup(k); | |
| 44 | |||
| 45 | − | KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); | |
| 46 | |||
| 47 | 3606 | return kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa() * | |
| 48 | 1202 | (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); | |
| 49 | 1202 | } | |
| 50 | |||
| 51 | 1202 | inline static size_t kai_get_rhs_packed_stride(size_t k) { | |
| 52 | 1202 | const size_t k_internal = kai_k_roundedup(k); | |
| 53 | |||
| 54 | − | KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); | |
| 55 | |||
| 56 | 3606 | return kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa() * | |
| 57 | 1202 | ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs); | |
| 58 | 1202 | } | |
| 59 | |||
| 60 | 1980 | size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) { | |
| 61 | 1980 | return kai_mr * kai_get_sme_vector_length_u32(); | |
| 62 | } | ||
| 63 | |||
| 64 | 1980 | size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) { | |
| 65 | 1980 | return kai_nr * kai_get_sme_vector_length_u32(); | |
| 66 | } | ||
| 67 | |||
| 68 | 2864 | size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) { | |
| 69 | 2864 | return kai_mr * kai_get_sme_vector_length_u32(); | |
| 70 | } | ||
| 71 | |||
| 72 | 2864 | size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) { | |
| 73 | 2864 | return kai_nr * kai_get_sme_vector_length_u32(); | |
| 74 | } | ||
| 75 | |||
| 76 | 580 | size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) { | |
| 77 | 580 | return kai_kr; | |
| 78 | } | ||
| 79 | |||
| 80 | 580 | size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) { | |
| 81 | 580 | return kai_sr; | |
| 82 | } | ||
| 83 | |||
/// Byte offset into the packed LHS buffer of the block containing row m_idx.
///
/// @param m_idx Row index; asserted to be a multiple of the M step.
/// @param k     Reduction depth, used to derive the packed block stride.
size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) {
    KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);

    const size_t rows_per_block = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
    const size_t block_idx = m_idx / rows_per_block;
    const size_t block_bytes = kai_get_lhs_packed_stride(k);

    return block_idx * block_bytes;
}
| 91 | |||
/// Byte offset into the packed RHS buffer of the block containing column n_idx.
///
/// @param n_idx Column index; asserted to be a multiple of the N step.
/// @param k     Reduction depth, used to derive the packed block stride.
size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k) {
    KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);

    const size_t cols_per_block = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
    const size_t block_idx = n_idx / cols_per_block;
    const size_t block_bytes = kai_get_rhs_packed_stride(k);

    return block_idx * block_bytes;
}
| 99 | |||
/// Byte offset of element (m_idx, n_idx) in the f32 destination matrix.
///
/// @param m_idx      Row index; asserted to be a multiple of the M step.
/// @param n_idx      Column index; asserted to be a multiple of the N step.
/// @param dst_stride Distance in bytes between consecutive destination rows.
size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
    size_t m_idx, size_t n_idx, size_t dst_stride) {
    KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
    KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);

    const size_t row_bytes = m_idx * dst_stride;
    const size_t col_bytes = n_idx * sizeof(float);

    return col_bytes + row_bytes;
}
| 107 | |||
/// Size in bytes of a densely stored m x n f32 destination matrix.
size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m, size_t n) {
    const size_t num_elements = m * n;
    return num_elements * sizeof(float);
}
| 111 | |||
| 112 | 542 | void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa( | |
| 113 | size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, | ||
| 114 | float* restrict dst, // NOLINT(readability-non-const-parameter) | ||
| 115 | size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { | ||
| 116 | − | KAI_ASSERT(dst_stride_col == sizeof(float)); | |
| 117 | − | KAI_ASSERT(n > 0); | |
| 118 | − | KAI_ASSERT(m > 0); | |
| 119 | |||
| 120 | // Constants | ||
| 121 | 542 | uint64_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(); | |
| 122 | 542 | uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(); | |
| 123 | 542 | uint64_t lhs_stride = kai_get_lhs_packed_stride(k); | |
| 124 | 542 | uint64_t rhs_stride = kai_get_rhs_packed_stride(k); | |
| 125 | 542 | uint64_t m_blk = (uint64_t)kai_k_roundedup(k) * mr; | |
| 126 | 542 | uint64_t dst_inc = mr * dst_stride_row; | |
| 127 | 542 | float scalar_bounds[2] = {scalar_min, scalar_max}; | |
| 128 | |||
| 129 | 542 | kai_commit_za(); | |
| 130 | |||
| 131 | /* --------------------------------------------------- | ||
| 132 | Registers allocations | ||
| 133 | x7: Look up table(lut) | ||
| 134 | x8: RHS base address (rhs) | ||
| 135 | x9: Destination base address (dst) | ||
| 136 | x10: LHS pointer (lhs) | ||
| 137 | x11: RHS pointer (rhs) | ||
| 138 | x12: Remaining M elements | ||
| 139 | x13: Remaining N elements | ||
| 140 | x14: k exit condition (k_cond) | ||
| 141 | ZA tile index (l_idx) | ||
| 142 | x15: LHS scaling factor pointer (lhs_sf_ptr) | ||
| 143 | x16: ZA tile exit condition (l_cnd) | ||
| 144 | x17: Destination pointer (dst) | ||
| 145 | x19: Destination outer address (dst) | ||
| 146 | x20: LHS base address (lhs) | ||
| 147 | --------------------------------------------------- */ | ||
| 148 | 542 | __asm__ volatile( | |
| 149 | " .inst 0xd503477f //smstart \n" | ||
| 150 | " mov x19, %[dst] \n" | ||
| 151 | " mov x20, %[lhs] \n" | ||
| 152 | " mov x7, %[lut] \n" | ||
| 153 | " .inst 0xe11f80e0 //ldr zt0, [x7] \n" | ||
| 154 | " cntw x7 \n" | ||
| 155 | " ptrue p2.b \n" | ||
| 156 | " ld1rw {z30.s}, p2/Z, [%[scalar_bounds]] \n" | ||
| 157 | " ld1rw {z31.s}, p2/Z, [%[scalar_bounds], #4] \n" | ||
| 158 | |||
| 159 | // M loop head | ||
| 160 | " mov x12, %[m] \n" | ||
| 161 | " .inst 0x25ac17e0 //whilelt p0.s, xzr, x12 \n" | ||
| 162 | "1: \n" | ||
| 163 | " mov x8, %[rhs] \n" | ||
| 164 | " mov x9, x19 \n" | ||
| 165 | " mov x13, %[n] \n" | ||
| 166 | " cmp x7, x12 \n" | ||
| 167 | " csel x16, x7, x12, lt \n" | ||
| 168 | " lsl x16, x16, #2 \n" | ||
| 169 | |||
| 170 | // N loop head | ||
| 171 | " .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n" | ||
| 172 | "2: \n" | ||
| 173 | " mov x10, x20 \n" | ||
| 174 | " mov x11, x8 \n" | ||
| 175 | " mov x17, x9 \n" | ||
| 176 | " .inst 0x25ad67f1 //whilelt pn9.s, xzr, x13, vlx4 \n" | ||
| 177 | |||
| 178 | // K loop | ||
| 179 | " .inst 0xc00800ff //zero {za} \n" | ||
| 180 | " add x14, x10, %[m_blk] \n" | ||
| 181 | "3: \n" | ||
| 182 | " .inst 0xa540a144 //ld1w { z4.s }, p0/z, [x10] \n" | ||
| 183 | " .inst 0x042a502a //addvl x10, x10, #1 \n" | ||
| 184 | " .inst 0xa0402160 //ld1h { z0.h-z1.h }, pn8/z, [x11] \n" | ||
| 185 | " .inst 0x042b504b //addvl x11, x11, #2 \n" | ||
| 186 | " .inst 0xc08a4008 //luti4 { z8.b - z9.b }, zt0, z0[0] \n" | ||
| 187 | " .inst 0xc08a402a //luti4 { z10.b - z11.b }, zt0, z1[0] \n" | ||
| 188 | " .inst 0xa0884880 //smopa za0.s, p2/m, p2/m, z4.b, z8.b \n" | ||
| 189 | " .inst 0xa0894881 //smopa za1.s, p2/m, p2/m, z4.b, z9.b \n" | ||
| 190 | " .inst 0xa08a4882 //smopa za2.s, p2/m, p2/m, z4.b, z10.b\n" | ||
| 191 | " .inst 0xa08b4883 //smopa za3.s, p2/m, p2/m, z4.b, z11.b\n" | ||
| 192 | " cmp x10, x14 \n" | ||
| 193 | " b.lt 3b \n" | ||
| 194 | |||
| 195 | // RHS row sum, scale factor & bias | ||
| 196 | " .inst 0xa040c560 //ld1w { z0.s-z3.s }, pn9/z, [x11] \n" | ||
| 197 | " .inst 0xa041c564 //ld1w { z4.s-z7.s }, pn9/z, [x11, #4, mul vl] \n" | ||
| 198 | " .inst 0xa042c568 //ld1w { z8.s-z11.s }, pn9/z, [x11, #8, mul vl]\n" | ||
| 199 | " .inst 0x042b518b //addvl x11, x11, #12 \n" | ||
| 200 | " .inst 0xc132e000 //scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" | ||
| 201 | |||
| 202 | // Store loop | ||
| 203 | " mov x14, #0 \n" | ||
| 204 | " addvl x15, x10, #1 \n" | ||
| 205 | "4: \n" | ||
| 206 | // Load LHS Row-offset & SF | ||
| 207 | " ld1rw {z16.s}, p2/z, [x10] \n" | ||
| 208 | " ld1rw {z17.s}, p2/z, [x15] \n" | ||
| 209 | " add x10, x10, #4 \n" | ||
| 210 | " add x15, x15, #4 \n" | ||
| 211 | " scvtf z16.s, p2/m, z16.s \n" | ||
| 212 | |||
| 213 | // offset x Row-sum | ||
| 214 | " fmul z24.s, z16.s, z0.s \n" | ||
| 215 | " fmul z25.s, z16.s, z1.s \n" | ||
| 216 | " fmul z26.s, z16.s, z2.s \n" | ||
| 217 | " fmul z27.s, z16.s, z3.s \n" | ||
| 218 | |||
| 219 | // Scaling factors | ||
| 220 | " fmul z20.s, z17.s, z4.s \n" | ||
| 221 | " fmul z21.s, z17.s, z5.s \n" | ||
| 222 | " fmul z22.s, z17.s, z6.s \n" | ||
| 223 | " fmul z23.s, z17.s, z7.s \n" | ||
| 224 | |||
| 225 | // Result = offset x Row-sum x SFs | ||
| 226 | " fmul z24.s, z24.s, z20.s \n" | ||
| 227 | " fmul z25.s, z25.s, z21.s \n" | ||
| 228 | " fmul z26.s, z26.s, z22.s \n" | ||
| 229 | " fmul z27.s, z27.s, z23.s \n" | ||
| 230 | |||
| 231 | // Load inner accumulation & convert | ||
| 232 | " .inst 0xc006440c //mova { z12.b-z15.b }, za0h.b[w14, 0:3]\n" | ||
| 233 | " .inst 0xc132e18c //scvtf { z12.s-z15.s }, { z12.s-z15.s } \n" | ||
| 234 | |||
| 235 | // Result += iacc x SF | ||
| 236 | " fmla z24.s, p2/m, z20.s, z12.s \n" | ||
| 237 | " fmla z25.s, p2/m, z21.s, z13.s \n" | ||
| 238 | " fmla z26.s, p2/m, z22.s, z14.s \n" | ||
| 239 | " fmla z27.s, p2/m, z23.s, z15.s \n" | ||
| 240 | |||
| 241 | // Add the bias | ||
| 242 | " fadd z24.s, p2/m, z24.s, z8.s \n" | ||
| 243 | " fadd z25.s, p2/m, z25.s, z9.s \n" | ||
| 244 | " fadd z26.s, p2/m, z26.s, z10.s \n" | ||
| 245 | " fadd z27.s, p2/m, z27.s, z11.s \n" | ||
| 246 | |||
| 247 | // CLAMP and store | ||
| 248 | " .inst 0xc1bfcbd8 //fclamp { z24.s-z27.s }, z30.s, z31.s\n" | ||
| 249 | " .inst 0xa060c638 //st1w { z24.s-z27.s }, pn9, [x17] \n" | ||
| 250 | |||
| 251 | " add x17, x17, %[dst_stride_row] \n" | ||
| 252 | " add x14, x14, #4 \n" | ||
| 253 | " cmp x14, x16 \n" | ||
| 254 | " b.lt 4b \n" | ||
| 255 | |||
| 256 | // N loop tail | ||
| 257 | " add x8, x8, %[rhs_stride] \n" | ||
| 258 | " .inst 0x04295089 // addvl x9, x9, #4 \n" | ||
| 259 | " sub x13, x13, %[nr] \n" | ||
| 260 | " .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n" | ||
| 261 | " b.mi 2b \n" | ||
| 262 | |||
| 263 | // M loop tail | ||
| 264 | " add x20, x20, %[lhs_stride] \n" | ||
| 265 | " add x19, x19, %[dst_inc] \n" | ||
| 266 | " sub x12, x12, %[mr] \n" | ||
| 267 | " whilelt p0.s, xzr, x12 \n" | ||
| 268 | " b.mi 1b \n" | ||
| 269 | |||
| 270 | "5: \n" | ||
| 271 | " .inst 0xd503467f //smstop \n" | ||
| 272 | : | ||
| 273 | 542 | : [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_stride] "r"(lhs_stride), [rhs_stride] "r"(rhs_stride), | |
| 274 | 542 | [dst_stride_row] "r"(dst_stride_row), [lut] "r"(lut), [m_blk] "r"(m_blk), [nr] "r"(nr), [mr] "r"(mr), | |
| 275 | 542 | [lhs] "r"(lhs_packed), [rhs] "r"(rhs_packed), [dst_inc] "r"(dst_inc), [scalar_bounds] "r"(scalar_bounds), | |
| 276 | 542 | [dst] "r"(dst) | |
| 277 | : "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "p0", "p2", "p8", | ||
| 278 | "p9", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", | ||
| 279 | "z16", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z30", "z31", | ||
| 280 | #ifdef __ARM_STATE_ZA | ||
| 281 | "za", | ||
| 282 | #endif | ||
| 283 | #ifdef __ARM_STATE_ZT0 | ||
| 284 | "zt0", | ||
| 285 | #endif | ||
| 286 | "cc", "memory"); | ||
| 287 | 542 | } | |
| 288 | |||
| 289 | #endif // Architectural feature check | ||
| 290 |