kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) | ||
| 8 | #error This file must be compiled for AArch64, FEAT_SVE2. | ||
| 9 | #else // Architectural features check. | ||
| 10 | |||
| 11 | #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" | ||
| 12 | |||
| 13 | #include <stddef.h> | ||
| 14 | #include <stdint.h> | ||
| 15 | #include <string.h> | ||
| 16 | |||
| 17 | #include "kai/kai_common.h" | ||
| 18 | |||
| 19 | static const size_t kai_nr = 2; | ||
| 20 | static const size_t kai_kr = 2; | ||
| 21 | static const size_t kai_num_bytes_input = 4; | ||
| 22 | static const size_t kai_num_bytes_output = 2; | ||
| 23 | static const size_t kai_num_bytes_bias = 4; | ||
| 24 | |||
| 25 | 1230 | size_t kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(void) { | |
| 26 | 1230 | return kai_nr * kai_get_sme_vector_length_u16() / kai_kr; | |
| 27 | } | ||
| 28 | |||
| 29 | 150 | size_t kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx) { | |
| 30 | − | KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() == 0); | |
| 31 | |||
| 32 | 150 | return n_idx * kai_num_bytes_input; | |
| 33 | } | ||
| 34 | |||
| 35 | ✗ | size_t kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx) { | |
| 36 | ✗ | return n_idx * kai_num_bytes_bias; | |
| 37 | } | ||
| 38 | |||
| 39 | 360 | size_t kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t k) { | |
| 40 | 720 | return kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() * | |
| 41 | 360 | (kai_num_bytes_bias + kai_roundup(k, kai_kr) * kai_num_bytes_output); | |
| 42 | } | ||
| 43 | |||
| 44 | 180 | size_t kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n_idx, size_t k) { | |
| 45 | − | KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme() == 0); | |
| 46 | |||
| 47 | 180 | const size_t block_idx = n_idx / kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(); | |
| 48 | 360 | return block_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(k); | |
| 49 | 180 | } | |
| 50 | |||
| 51 | 180 | size_t kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(size_t n, size_t k) { | |
| 52 | 180 | const size_t n_rounded_up = kai_roundup(n, kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme()); | |
| 53 | 360 | return kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(n_rounded_up, k); | |
| 54 | 180 | } | |
| 55 | |||
| 56 | 180 | void kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme( | |
| 57 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, | ||
| 58 | const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { | ||
| 59 | − | KAI_ASSUME(num_groups == 1); | |
| 60 | − | KAI_ASSUME(nr == kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme()); | |
| 61 | − | KAI_ASSUME(kr == kai_kr); | |
| 62 | − | KAI_ASSUME(sr == 1); | |
| 63 | − | KAI_ASSUME(rhs != NULL); | |
| 64 | − | KAI_ASSUME(bias != NULL); | |
| 65 | − | KAI_ASSUME(scale == NULL); | |
| 66 | − | KAI_ASSUME(rhs_packed != NULL); | |
| 67 | − | KAI_ASSUME(extra_bytes == 0); | |
| 68 | − | KAI_ASSUME(params == NULL); | |
| 69 | |||
| 70 | 180 | size_t height = k; | |
| 71 | 180 | const size_t width = n; | |
| 72 | 180 | const void* in = rhs; | |
| 73 | 180 | void* out = rhs_packed; | |
| 74 | 180 | const size_t in_stride = rhs_stride; | |
| 75 | 180 | const float* pad_row = rhs; | |
| 76 | |||
| 77 | 180 | size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(height); | |
| 78 | |||
| 79 | 180 | kai_commit_za(); | |
| 80 | |||
| 81 | 360 | __asm__ __volatile__( | |
| 82 | ".inst 0xd503477f // SMSTART ZA\n" | ||
| 83 | "mov x22, %x[out]\n" | ||
| 84 | "mov x21, %x[width]\n" | ||
| 85 | "ptrue p2.b\n" | ||
| 86 | "1:" // Bias: Full loop | ||
| 87 | "mov x20, x21\n" | ||
| 88 | "decw x21, ALL, MUL #2\n" | ||
| 89 | "whilelt p1.s, XZR, x20\n" | ||
| 90 | "decw x20\n" | ||
| 91 | "whilelt p0.s, XZR, x20\n" | ||
| 92 | "cmp x21, #0x0\n" | ||
| 93 | "ld1w { z17.s }, p1/Z, [%x[bias]]\n" | ||
| 94 | "ld1w { z16.s }, p0/Z, [%x[bias], #1, MUL VL]\n" | ||
| 95 | "incb %x[bias], ALL, MUL #2\n" | ||
| 96 | "st1w { z17.s }, p2, [x22]\n" | ||
| 97 | "st1w { z16.s }, p2, [x22, #1, MUL VL]\n" | ||
| 98 | "add x22, x22, %x[out_stride]\n" | ||
| 99 | "bgt 1b\n" | ||
| 100 | "cmp %x[height], #0x8\n" | ||
| 101 | "incb %x[out], ALL, MUL #2\n" | ||
| 102 | "blt 5f\n" | ||
| 103 | "2:" // Main row loop: Head | ||
| 104 | "mov x10, %x[in]\n" | ||
| 105 | "mov x9, %x[out]\n" | ||
| 106 | "add x28, x10, %x[in_stride]\n" | ||
| 107 | "sub %x[height], %x[height], #0x8\n" | ||
| 108 | "add x27, x28, %x[in_stride]\n" | ||
| 109 | "mov x26, %x[width]\n" | ||
| 110 | "add x25, x27, %x[in_stride]\n" | ||
| 111 | "add x24, x25, %x[in_stride]\n" | ||
| 112 | "add x23, x24, %x[in_stride]\n" | ||
| 113 | "add x22, x23, %x[in_stride]\n" | ||
| 114 | "add x21, x22, %x[in_stride]\n" | ||
| 115 | "add %x[in], x21, %x[in_stride]\n" | ||
| 116 | "3:" // Main row loop: Column loop | ||
| 117 | "mov x20, x26\n" | ||
| 118 | "decw x26, ALL, MUL #2\n" | ||
| 119 | "whilelt p1.s, XZR, x20\n" | ||
| 120 | "decw x20\n" | ||
| 121 | "whilelt p0.s, XZR, x20\n" | ||
| 122 | "ld1w { z19.s }, p1/Z, [x10]\n" | ||
| 123 | "cmp x26, #0x0\n" | ||
| 124 | "ld1w { z18.s }, p0/Z, [x10, #1, MUL VL]\n" | ||
| 125 | "addvl x10, x10, #2\n" | ||
| 126 | "ld1w { z17.s }, p1/Z, [x27]\n" | ||
| 127 | "ld1w { z16.s }, p0/Z, [x27, #1, MUL VL]\n" | ||
| 128 | ".inst 0x658aaa7b // bfcvt z27.h, p2/M, z19.s\n" | ||
| 129 | "addvl x27, x27, #2\n" | ||
| 130 | "ld1w { z19.s }, p1/Z, [x24]\n" | ||
| 131 | ".inst 0x658aaa5a // bfcvt z26.h, p2/M, z18.s\n" | ||
| 132 | "ld1w { z18.s }, p0/Z, [x24, #1, MUL VL]\n" | ||
| 133 | ".inst 0x658aaa39 // bfcvt z25.h, p2/M, z17.s\n" | ||
| 134 | "addvl x24, x24, #2\n" | ||
| 135 | "ld1w { z17.s }, p1/Z, [x22]\n" | ||
| 136 | ".inst 0x658aaa18 // bfcvt z24.h, p2/M, z16.s\n" | ||
| 137 | "ld1w { z16.s }, p0/Z, [x22, #1, MUL VL]\n" | ||
| 138 | ".inst 0x658aaa77 // bfcvt z23.h, p2/M, z19.s\n" | ||
| 139 | "addvl x22, x22, #2\n" | ||
| 140 | ".inst 0x658aaa56 // bfcvt z22.h, p2/M, z18.s\n" | ||
| 141 | "ld1w { z19.s }, p1/Z, [x28]\n" | ||
| 142 | ".inst 0x658aaa35 // bfcvt z21.h, p2/M, z17.s\n" | ||
| 143 | "ld1w { z18.s }, p0/Z, [x28, #1, MUL VL]\n" | ||
| 144 | "addvl x28, x28, #2\n" | ||
| 145 | ".inst 0x658aaa14 // bfcvt z20.h, p2/M, z16.s\n" | ||
| 146 | "ld1w { z17.s }, p1/Z, [x25]\n" | ||
| 147 | "ld1w { z16.s }, p0/Z, [x25, #1, MUL VL]\n" | ||
| 148 | "addvl x25, x25, #2\n" | ||
| 149 | ".inst 0x648aaa7b // bfcvtnt z27.h, p2/M, z19.s\n" | ||
| 150 | "ld1w { z19.s }, p1/Z, [x23]\n" | ||
| 151 | ".inst 0x648aaa5a // bfcvtnt z26.h, p2/M, z18.s\n" | ||
| 152 | "ld1w { z18.s }, p0/Z, [x23, #1, MUL VL]\n" | ||
| 153 | "addvl x23, x23, #2\n" | ||
| 154 | ".inst 0x648aaa39 // bfcvtnt z25.h, p2/M, z17.s\n" | ||
| 155 | "ld1w { z17.s }, p1/Z, [x21]\n" | ||
| 156 | ".inst 0x648aaa18 // bfcvtnt z24.h, p2/M, z16.s\n" | ||
| 157 | "ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n" | ||
| 158 | "addvl x21, x21, #2\n" | ||
| 159 | ".inst 0x648aaa77 // bfcvtnt z23.h, p2/M, z19.s\n" | ||
| 160 | "st1h { z27.h }, p2, [x9]\n" | ||
| 161 | ".inst 0x648aaa56 // bfcvtnt z22.h, p2/M, z18.s\n" | ||
| 162 | "st1h { z26.h }, p2, [x9, #1, MUL VL]\n" | ||
| 163 | ".inst 0x648aaa35 // bfcvtnt z21.h, p2/M, z17.s\n" | ||
| 164 | "st1h { z25.h }, p2, [x9, #2, MUL VL]\n" | ||
| 165 | ".inst 0x648aaa14 // bfcvtnt z20.h, p2/M, z16.s\n" | ||
| 166 | "st1h { z24.h }, p2, [x9, #3, MUL VL]\n" | ||
| 167 | "st1h { z23.h }, p2, [x9, #4, MUL VL]\n" | ||
| 168 | "st1h { z22.h }, p2, [x9, #5, MUL VL]\n" | ||
| 169 | "st1h { z21.h }, p2, [x9, #6, MUL VL]\n" | ||
| 170 | "st1h { z20.h }, p2, [x9, #7, MUL VL]\n" | ||
| 171 | "add x9, x9, %x[out_stride]\n" | ||
| 172 | "bgt 3b\n" | ||
| 173 | "cmp %x[height], #0x8\n" | ||
| 174 | "addvl %x[out], %x[out], #8\n" | ||
| 175 | "bge 2b\n" | ||
| 176 | "cbz %x[height], 9f\n" | ||
| 177 | "5:" // Main loop skip | ||
| 178 | "6:" // Tail row loop: Head | ||
| 179 | "mov x10, %x[in]\n" | ||
| 180 | "cmp %x[height], #0x1\n" | ||
| 181 | "add x28, x10, %x[in_stride]\n" | ||
| 182 | "mov x9, %x[out]\n" | ||
| 183 | "add %x[in], x28, %x[in_stride]\n" | ||
| 184 | "csel x28, x28, %x[pad_row], GT\n" | ||
| 185 | "sub %x[height], %x[height], #0x2\n" | ||
| 186 | "mov x21, %x[width]\n" | ||
| 187 | "7:" // Tail row loop: Column loop | ||
| 188 | "mov x20, x21\n" | ||
| 189 | "decw x21, ALL, MUL #2\n" | ||
| 190 | "whilelt p1.s, XZR, x20\n" | ||
| 191 | "decw x20\n" | ||
| 192 | "whilelt p0.s, XZR, x20\n" | ||
| 193 | "ld1w { z17.s }, p1/Z, [x10]\n" | ||
| 194 | "cmp x21, #0x0\n" | ||
| 195 | "ld1w { z16.s }, p0/Z, [x10, #1, MUL VL]\n" | ||
| 196 | "addvl x10, x10, #2\n" | ||
| 197 | "ld1w { z19.s }, p1/Z, [x28]\n" | ||
| 198 | ".inst 0x658aaa32 // bfcvt z18.h, p2/M, z17.s\n" | ||
| 199 | "ld1w { z17.s }, p0/Z, [x28, #1, MUL VL]\n" | ||
| 200 | "addvl x28, x28, #2\n" | ||
| 201 | ".inst 0x658aaa10 // bfcvt z16.h, p2/M, z16.s\n" | ||
| 202 | ".inst 0x648aaa72 // bfcvtnt z18.h, p2/M, z19.s\n" | ||
| 203 | ".inst 0x648aaa30 // bfcvtnt z16.h, p2/M, z17.s\n" | ||
| 204 | "st1h { z18.h }, p2, [x9]\n" | ||
| 205 | "st1h { z16.h }, p2, [x9, #1, MUL VL]\n" | ||
| 206 | "add x9, x9, %x[out_stride]\n" | ||
| 207 | "bgt 7b\n" | ||
| 208 | "cmp %x[height], #0x1\n" | ||
| 209 | "addvl %x[out], %x[out], #2\n" | ||
| 210 | "bge 6b\n" | ||
| 211 | "9:" // Done | ||
| 212 | ".inst 0xd503467f // SMSTOP\n" | ||
| 213 | : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) | ||
| 214 | 180 | : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row), [width] "r"(width) | |
| 215 | : "cc", "memory", "p0", "p1", "p10", "p11", "p12", "p13", "p14", "p15", "p2", "p3", "p4", "p5", "p6", "p7", | ||
| 216 | "p8", "p9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9", "z0", "z1", "z10", | ||
| 217 | "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z2", "z20", "z21", "z22", "z23", "z24", "z25", | ||
| 218 | "z26", "z27", "z28", "z29", "z3", "z30", "z31", "z4", "z5", "z6", "z7", "z8", "z9"); | ||
| 219 | 180 | } | |
| 220 | |||
| 221 | #endif // Architectural features check. | ||
| 222 |