kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | // Do not flag up inline assembly blocks | ||
| 8 | #pragma GCC diagnostic ignored "-Woverlength-strings" | ||
| 9 | |||
| 10 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) | ||
| 11 | #error This file must be compiled for AArch64, FEAT_SVE2. | ||
| 12 | #else // Architectural features check | ||
| 13 | |||
| 14 | #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h" | ||
| 15 | |||
| 16 | #include <stddef.h> | ||
| 17 | #include <stdint.h> | ||
| 18 | |||
| 19 | #include "kai/kai_common.h" | ||
| 20 | |||
| 21 | static const size_t kai_m_step = 1; | ||
| 22 | static const size_t kai_n_step = 1; | ||
| 23 | static const size_t kai_mr = 1; | ||
| 24 | static const size_t kai_nr = 4; // multiple of vector length | ||
| 25 | static const size_t kai_kr = 4; | ||
| 26 | static const size_t kai_sr = 1; | ||
| 27 | static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); | ||
| 28 | static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); | ||
| 29 | static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); | ||
| 30 | static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); | ||
| 31 | static const size_t kai_num_bytes_bias_rhs = sizeof(float); | ||
| 32 | static const size_t kai_k_multiple_of = 32; | ||
| 33 | |||
| 34 | 2946 | inline static size_t kai_k_roundedup(size_t k) { | |
| 35 | // Round up k to be a multiple of 32. | ||
| 36 | 2946 | return kai_roundup(k, kai_k_multiple_of); | |
| 37 | } | ||
| 38 | |||
| 39 | 1202 | inline static size_t kai_get_lhs_packed_stride(size_t k) { | |
| 40 | 1202 | const size_t k_internal = kai_k_roundedup(k); | |
| 41 | |||
| 42 | − | KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); | |
| 43 | |||
| 44 | 3606 | return kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() * | |
| 45 | 1202 | (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); | |
| 46 | 1202 | } | |
| 47 | |||
| 48 | 1202 | inline static size_t kai_get_rhs_packed_stride(size_t k) { | |
| 49 | 1202 | const size_t k_internal = kai_k_roundedup(k); | |
| 50 | |||
| 51 | − | KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); | |
| 52 | |||
| 53 | 3606 | return kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() * | |
| 54 | 1202 | ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs); | |
| 55 | 1202 | } | |
| 56 | |||
| 57 | 780 | size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
| 58 | 780 | return kai_m_step; | |
| 59 | } | ||
| 60 | |||
| 61 | 3644 | size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
| 62 | 3644 | return kai_nr * kai_get_sme_vector_length_u32(); | |
| 63 | } | ||
| 64 | |||
| 65 | 780 | size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
| 66 | 780 | return kai_n_step * kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(); | |
| 67 | } | ||
| 68 | |||
| 69 | 1662 | size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
| 70 | // For GEMV, mr must be 1 to read the data consecutively | ||
| 71 | 1662 | return kai_mr; | |
| 72 | } | ||
| 73 | |||
| 74 | 580 | size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
| 75 | 580 | return kai_kr; | |
| 76 | } | ||
| 77 | |||
| 78 | 580 | size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(void) { | |
| 79 | 580 | return kai_sr; | |
| 80 | } | ||
| 81 | |||
| 82 | 660 | size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t m_idx, size_t k) { | |
| 83 | − | KAI_ASSERT((m_idx % kai_m_step) == 0); | |
| 84 | |||
| 85 | 660 | return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); | |
| 86 | } | ||
| 87 | |||
| 88 | 660 | size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t n_idx, size_t k) { | |
| 89 | − | KAI_ASSERT((n_idx % kai_n_step) == 0); | |
| 90 | 660 | const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(); | |
| 91 | 1320 | return (n_idx / nr) * kai_get_rhs_packed_stride(k); | |
| 92 | 660 | } | |
| 93 | |||
| 94 | 540 | size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot( | |
| 95 | size_t m_idx, size_t n_idx, size_t dst_stride) { | ||
| 96 | − | KAI_ASSERT((m_idx % kai_m_step) == 0); | |
| 97 | − | KAI_ASSERT((n_idx % kai_n_step) == 0); | |
| 98 | |||
| 99 | 540 | return (n_idx * sizeof(float)) + (m_idx * dst_stride); | |
| 100 | } | ||
| 101 | |||
| 102 | 540 | size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(size_t m, size_t n) { | |
| 103 | 540 | return m * n * sizeof(float); | |
| 104 | } | ||
| 105 | |||
| 106 | /// Lut to be indexed by i4 resulting in its value in i8 (i.e. -2 = 1110 -> 1111 1110). | ||
| 107 | static const int8_t lut[64] = {0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0, | ||
| 108 | 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, -8, 0, 0, 0, -7, 0, 0, 0, -6, 0, 0, 0, | ||
| 109 | -5, 0, 0, 0, -4, 0, 0, 0, -3, 0, 0, 0, -2, 0, 0, 0, -1, 0, 0, 0}; | ||
| 110 | |||
| 111 | // Optimized for GEMV (matrix vector multiplication => m == 1). | ||
| 112 | // Does a matmul for compatibility reasons, but should not be used that way. | ||
| 113 | 542 | void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot( | |
| 114 | size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, | ||
| 115 | float* dst, // NOLINT(readability-non-const-parameter) | ||
| 116 | size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { | ||
| 117 | − | KAI_ASSERT(dst_stride_col == sizeof(float)); | |
| 118 | |||
| 119 |
3/6✓ Branch 0 taken 542 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 542 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 542 times.
|
542 | if (m == 0 || n == 0 || k == 0) { |
| 120 | ✗ | return; | |
| 121 | } | ||
| 122 | |||
| 123 | // Do function calls and calculations first to not overwrite registers we will use | ||
| 124 | 542 | uint64_t k_internal = kai_k_roundedup(k); | |
| 125 | 542 | uint64_t lhs_stride = kai_get_lhs_packed_stride(k); | |
| 126 | 542 | uint64_t rhs_stride = kai_get_rhs_packed_stride(k); | |
| 127 | 542 | uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot(); | |
| 128 | |||
| 129 | 542 | uint64_t rhs_row_bytes = nr * k_internal / 2; | |
| 130 | 542 | uint64_t lhs_end_ptr = ((uint64_t)lhs_packed) + (m * lhs_stride); | |
| 131 | |||
| 132 | 542 | kai_commit_za(); | |
| 133 | |||
| 134 | /* | ||
| 135 | * x11: zero = 0 // MUST BE x8-x11 | ||
| 136 | * x15: n initialized as n | ||
| 137 | * x19: nr initialized as nr | ||
| 138 | * x20: lut_ptr initialized as lut | ||
| 139 | * x21: lhs_packed initialized as lhs_packed | ||
| 140 | * x22: n_idx | ||
| 141 | * x23: k_idx | ||
| 142 | * x24: RHS block ptr | ||
| 143 | * x25: RHS end ptr | ||
| 144 | * x26: rhs_packed | ||
| 145 | * x27: dst_ptr | ||
| 146 | * x28: tmp_1 | ||
| 147 | */ | ||
| 148 | |||
| 149 | 1084 | __asm__ volatile( | |
| 150 | |||
| 151 | // Setup | ||
| 152 | " .inst 0xd503477f // smstart \n" | ||
| 153 | " mov x11, #0 \n" | ||
| 154 | " mov x15, %[n] \n" | ||
| 155 | " mov x19, %[nr] \n" | ||
| 156 | " mov x21, %[lhs_packed] \n" | ||
| 157 | " mov x20, %[lut] \n" | ||
| 158 | " .inst 0xe11f8280 // ldr zt0, [x20] \n" | ||
| 159 | " ptrue p0.b \n" | ||
| 160 | " .inst 0x25207810 // ptrue pn8.b \n" | ||
| 161 | // predicate to load nr words for the RHS sums and scaling factors (should be exactly all true) | ||
| 162 | " .inst 0x25b36571 // whilelt pn9.s, x11, x19, vlx4 \n" | ||
| 163 | " dup z30.s, %w[scalar_min] \n" | ||
| 164 | " dup z31.s, %w[scalar_max] \n" | ||
| 165 | |||
| 166 | // lhs matrix row loop | ||
| 167 | "1: \n" | ||
| 168 | // Reset rhs matrix ptr | ||
| 169 | " mov x26, %[rhs_packed] \n" | ||
| 170 | // Reset dst_ptr to dst of next GEMV result | ||
| 171 | " mov x27, %[dst_ptr] \n" | ||
| 172 | // Reset n index | ||
| 173 | " mov x22, #0 \n" | ||
| 174 | // whilelt pn12.s, x22, %[n], vlx4 | ||
| 175 | " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n" | ||
| 176 | |||
| 177 | // rhs matrix row loop (transposed so theoretical columns) | ||
| 178 | "2: \n" | ||
| 179 | |||
| 180 | // Reset rhs block ptr to start of row | ||
| 181 | " mov x24, x26 \n" | ||
| 182 | " add x25, x26, %[rhs_row_bytes] \n" | ||
| 183 | " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n" | ||
| 184 | " addvl x28, x24, #4 \n" | ||
| 185 | " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n" | ||
| 186 | " mov x23, #0 \n" | ||
| 187 | " whilelt p1.b, x23, %[k_internal] \n" | ||
| 188 | // Zero for sdot accumulation in inner loop | ||
| 189 | " .inst 0xc00800ff // zero {za} \n" | ||
| 190 | |||
| 191 | // before k loop | ||
| 192 | "3: \n" | ||
| 193 | |||
| 194 | // Load lhs | ||
| 195 | " ld1rqb { z0.b }, p1/z , [x21, x23] \n" | ||
| 196 | |||
| 197 | // Load w | ||
| 198 | " .inst 0xa0408b10 // ld1b { z16.b - z19.b }, pn10/z, [x24] \n" | ||
| 199 | " .inst 0xa0418f14 // ld1b {z20.b-z23.b}, pn11/z, [x24,#0x4, mul vl]\n" | ||
| 200 | |||
| 201 | // rhs i4 to i8 and sdot | ||
| 202 | // k block + 0 | ||
| 203 | " .inst 0xc08a4218 // luti4 { z24.b, z25.b }, zt0, z16[0] \n" | ||
| 204 | " .inst 0xc08a423a // luti4 { z26.b, z27.b }, zt0, z17[0] \n" | ||
| 205 | " .inst 0xc150f320 // sdot za.s[w11,0, vgx4], {z24.b-z27.b}, z0.b[0]\n" | ||
| 206 | // k block + 1 | ||
| 207 | " .inst 0xc08a4244 // luti4 { z4.b, z5.b }, zt0, z18[0] \n" | ||
| 208 | " .inst 0xc08a4266 // luti4 { z6.b, z7.b }, zt0, z19[0] \n" | ||
| 209 | " .inst 0xc150f4a0 // sdot za.s[w11,0, vgx4], {z4.b-z7.b}, z0.b[1] \n" | ||
| 210 | // k block + 2 | ||
| 211 | " .inst 0xc08a4288 // luti4 { z8.b, z9.b }, zt0, z20[0] \n" | ||
| 212 | " .inst 0xc08a42aa // luti4 { z10.b, z11.b }, zt0, z21[0] \n" | ||
| 213 | " .inst 0xc150f920 // sdot za.s[w11,0, vgx4], {z8.b-z11.b}, z0.b[2] \n" | ||
| 214 | // k block + 3 | ||
| 215 | " .inst 0xc08a42cc // luti4 { z12.b, z13.b }, zt0, z22[0] \n" | ||
| 216 | " .inst 0xc08a42ee // luti4 { z14.b, z15.b }, zt0, z23[0] \n" | ||
| 217 | " .inst 0xc150fda0 // sdot za.s[w11,0, vgx4], {z12.b-z15.b}, z0.b[3]\n" | ||
| 218 | |||
| 219 | // End K block loop | ||
| 220 | " addvl x24, x24, #8 \n" | ||
| 221 | " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n" | ||
| 222 | " addvl x28, x24, #4 \n" | ||
| 223 | " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n" | ||
| 224 | " add x23, x23, #16 \n" | ||
| 225 | " whilelt p1.b, x23, %[k_internal] \n" | ||
| 226 | " b.first 3b \n" | ||
| 227 | |||
| 228 | // Finish off accumulators with scaling factors and zero points | ||
| 229 | |||
| 230 | // Load lhs zero point | ||
| 231 | " add x28, x21, %[k_internal] \n" | ||
| 232 | " ld1rw { z2.s }, p0/z , [x28] \n" | ||
| 233 | // Load lhs scaling factor | ||
| 234 | " ld1rw { z3.s }, p0/z , [x28, #4] \n" | ||
| 235 | // Load rhs sums | ||
| 236 | " add x28, x26, %[rhs_row_bytes] \n" | ||
| 237 | " .inst 0xa040c794 // ld1w { z20.s - z23.s }, pn9/z, [x28] \n" | ||
| 238 | // Load rhs scaling factors | ||
| 239 | " .inst 0xa041c798 // ld1w {z24.s-z27.s}, pn9/z, [x28, #0x4, mul vl]\n" | ||
| 240 | // Load biases | ||
| 241 | " .inst 0xa042c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, #0x8, mul vl]\n" | ||
| 242 | |||
| 243 | // Get accumulated value out of ZA | ||
| 244 | " .inst 0xc0066c04 // mov { z4.d - z7.d }, za.d[w11, 0, vgx4] \n" | ||
| 245 | |||
| 246 | // za contains a * w; the missing z * wsum term is added below via mla | ||
| 247 | // zero point * rhs row sum | ||
| 248 | " mla z4.s, p0/m, z20.s, z2.s \n" | ||
| 249 | " mla z5.s, p0/m, z21.s, z2.s \n" | ||
| 250 | " mla z6.s, p0/m, z22.s, z2.s \n" | ||
| 251 | " mla z7.s, p0/m, z23.s, z2.s \n" | ||
| 252 | |||
| 253 | // Convert to float | ||
| 254 | " .inst 0xc132e084 // scvtf { z4.s - z7.s }, { z4.s - z7.s } \n" | ||
| 255 | |||
| 256 | // lhs scaling factor * rhs scaling factor | ||
| 257 | " fmul z24.s, z24.s, z3.s \n" | ||
| 258 | " fmul z25.s, z25.s, z3.s \n" | ||
| 259 | " fmul z26.s, z26.s, z3.s \n" | ||
| 260 | " fmul z27.s, z27.s, z3.s \n" | ||
| 261 | |||
| 262 | // Bias + combined scaling factor * combined accumulator | ||
| 263 | " fmla z12.s, p0/m, z24.s, z4.s \n" | ||
| 264 | " fmla z13.s, p0/m, z25.s, z5.s \n" | ||
| 265 | " fmla z14.s, p0/m, z26.s, z6.s \n" | ||
| 266 | " fmla z15.s, p0/m, z27.s, z7.s \n" | ||
| 267 | |||
| 268 | // Clamp | ||
| 269 | " .inst 0xc1bfcbcc // fclamp { z12.s - z15.s }, z30.s, z31.s \n" | ||
| 270 | |||
| 271 | // Store | ||
| 272 | " .inst 0xa036d36c // st1w {z12.s-z15.s}, pn12, [x27, x22, lsl #2] \n" | ||
| 273 | |||
| 274 | // End rhs row loop | ||
| 275 | " add x26, x26, %[rhs_stride] \n" | ||
| 276 | // nr == svlb | ||
| 277 | " addvl x22, x22, #1 \n" | ||
| 278 | // whilelt pn12.s, x22, %[n], vlx4 | ||
| 279 | " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n" | ||
| 280 | " b.lt 2b \n" | ||
| 281 | |||
| 282 | // End lhs row loop | ||
| 283 | " add %[dst_ptr], %[dst_ptr], %[dst_stride_row] \n" | ||
| 284 | " add x21, x21, %[lhs_stride] \n" | ||
| 285 | " cmp x21, %[lhs_end_ptr] \n" | ||
| 286 | " b.lt 1b \n" | ||
| 287 | |||
| 288 | " .inst 0xd503467f // smstop \n" | ||
| 289 | |||
| 290 | : [dst_ptr] "+r"(dst) | ||
| 291 | 542 | : [lut] "r"(lut), [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_packed] "r"(lhs_packed), | |
| 292 | 542 | [rhs_packed] "r"(rhs_packed), [dst_stride_row] "r"(dst_stride_row), [scalar_min] "r"(scalar_min), | |
| 293 | 542 | [scalar_max] "r"(scalar_max), [k_internal] "r"(k_internal), [lhs_stride] "r"(lhs_stride), | |
| 294 | 542 | [rhs_stride] "r"(rhs_stride), [nr] "r"(nr), [rhs_row_bytes] "r"(rhs_row_bytes), [lhs_end_ptr] "r"(lhs_end_ptr) | |
| 295 | : "x11", "x15", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "p0", "p1", "p8", "p9", | ||
| 296 | "p10", "p11", "p12", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", | ||
| 297 | "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", | ||
| 298 | "z29", "z30", "z31", | ||
| 299 | #ifdef __ARM_STATE_ZA | ||
| 300 | "za", | ||
| 301 | #endif | ||
| 302 | #ifdef __ARM_STATE_ZT0 | ||
| 303 | "zt0", | ||
| 304 | #endif | ||
| 305 | "memory", "cc"); | ||
| 306 | 542 | } | |
| 307 | |||
| 308 | #endif // Architectural features check. | ||
| 309 |