kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | // Do not flag up inline assembly blocks | ||
| 8 | #pragma GCC diagnostic ignored "-Woverlength-strings" | ||
| 9 | |||
| 10 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2) | ||
| 11 | #error This file must be compiled for AArch64, FEAT_SVE2. | ||
| 12 | #else // Architectural features check. | ||
| 13 | #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" | ||
| 14 | |||
| 15 | #include <stddef.h> | ||
| 16 | #include <stdint.h> | ||
| 17 | |||
| 18 | #include "kai/kai_common.h" | ||
| 19 | |||
| 20 | enum { | ||
| 21 | MR = 2, | ||
| 22 | KR = 2, | ||
| 23 | MAX_M_STEP = MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR, | ||
| 24 | SR = 1, | ||
| 25 | }; | ||
| 26 | |||
| 27 | 750 | static size_t kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme(void) { | |
| 28 | 750 | return MR * kai_get_sme_vector_length_u16() / KR; | |
| 29 | } | ||
| 30 | |||
// Returns the row step used when packing: it equals mr.
// The caller-supplied mr is expected to match kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme().
size_t kai_get_m_step_lhs_pack_bf16p2vlx2_f32_sme(size_t mr) {
    const size_t m_step = kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme();
    KAI_ASSUME(mr == m_step);
    KAI_UNUSED(mr);
    return m_step;
}
| 36 | |||
// Returns the byte offset of row m_idx in the unpacked LHS matrix, given its row
// stride in bytes. m_idx is expected to be a multiple of mr.
size_t kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme(size_t m_idx, size_t lhs_stride_row) {
    KAI_ASSUME(m_idx % kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme() == 0);

    const size_t row_offset_bytes = lhs_stride_row * m_idx;
    return row_offset_bytes;
}
| 42 | |||
| 43 | ✗ | size_t kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | |
| 44 | − | KAI_ASSUME(m_idx % kai_get_m_step_lhs_pack_bf16p2vlx2_f32_sme(mr) == 0); | |
| 45 | − | KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()); | |
| 46 | − | KAI_ASSUME(kr == KR); | |
| 47 | − | KAI_ASSUME(sr == SR); | |
| 48 | |||
| 49 | ✗ | KAI_UNUSED(mr); | |
| 50 | ✗ | KAI_UNUSED(kr); | |
| 51 | ✗ | KAI_UNUSED(sr); | |
| 52 | |||
| 53 | ✗ | return m_idx * kai_roundup(k, KR) * sizeof(uint16_t); | |
| 54 | } | ||
| 55 | |||
| 56 | 150 | size_t kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
| 57 | − | KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()); | |
| 58 | − | KAI_ASSUME(kr == KR); | |
| 59 | − | KAI_ASSUME(sr == SR); | |
| 60 | |||
| 61 | 150 | KAI_UNUSED(mr); | |
| 62 | 150 | KAI_UNUSED(kr); | |
| 63 | 150 | KAI_UNUSED(sr); | |
| 64 | 150 | return kai_roundup(m, kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()) * kai_roundup(k, KR) * sizeof(uint16_t); | |
| 65 | } | ||
| 66 | |||
| 67 | 150 | void kai_run_lhs_pack_bf16p2vlx2_f32_sme( | |
| 68 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride_row, | ||
| 69 | void* lhs_packed) { | ||
| 70 | − | KAI_ASSUME(mr == kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme()); | |
| 71 | − | KAI_ASSUME(kr == KR); | |
| 72 | − | KAI_ASSUME(sr == SR); | |
| 73 | − | KAI_ASSUME(m_idx_start == 0); | |
| 74 | − | KAI_ASSUME(lhs != NULL); | |
| 75 | − | KAI_ASSUME(lhs_packed != NULL); | |
| 76 | |||
| 77 | 150 | const size_t m_step = kai_get_mr_lhs_pack_bf16p2vlx2_f32_sme(); | |
| 78 | 150 | const size_t width = k; | |
| 79 | |||
| 80 | − | KAI_ASSERT(m_step <= MAX_M_STEP); | |
| 81 | 150 | const uint8_t* in[MAX_M_STEP]; | |
| 82 | |||
| 83 | 150 | uint8_t* out_base = lhs_packed; | |
| 84 | 150 | const uint8_t* lhs_ptr = lhs; | |
| 85 | |||
| 86 | 150 | kai_commit_za(); | |
| 87 | |||
| 88 |
2/2✓ Branch 0 taken 150 times.
✓ Branch 1 taken 471 times.
|
621 | for (size_t i_m = 0; i_m < m; i_m += m_step) { |
| 89 |
2/2✓ Branch 0 taken 120 times.
✓ Branch 1 taken 351 times.
|
471 | const size_t height = KAI_MIN(m - i_m, m_step); |
| 90 | 471 | void* out = out_base; | |
| 91 | 471 | out_base += m_step * kai_roundup(k, KR) * sizeof(uint16_t); | |
| 92 | |||
| 93 |
2/2✓ Branch 0 taken 12258 times.
✓ Branch 1 taken 471 times.
|
12729 | for (size_t y = 0; y < height; y++) { |
| 94 | 12258 | in[y] = lhs_ptr + (i_m + y) * lhs_stride_row; | |
| 95 | 12258 | } | |
| 96 | |||
| 97 | 942 | __asm__ __volatile__( | |
| 98 | ".inst 0xd503477f // SMSTART ZA\n" | ||
| 99 | "sub x10, %x[width], #0x1\n" | ||
| 100 | "mov x9, #0x0\n" | ||
| 101 | "cntw x22, ALL, MUL #2\n" | ||
| 102 | "cntw x28\n" | ||
| 103 | "cntw x21, ALL, MUL #2\n" | ||
| 104 | "sub x20, x22, #0x1\n" | ||
| 105 | ".inst 0x25207815 // ptrue pn13.b\n" | ||
| 106 | "whilelt p12.s, XZR, %x[height]\n" | ||
| 107 | "whilelt p11.s, x28, %x[height]\n" | ||
| 108 | "add x10, x10, x21\n" | ||
| 109 | "ands x27, %x[width], x20\n" | ||
| 110 | "udiv x10, x10, x21\n" | ||
| 111 | "csel x27, x27, x22, NE\n" | ||
| 112 | "and x26, x10, #0x1\n" | ||
| 113 | "sub x10, x10, #0x1\n" | ||
| 114 | "add x27, x27, #0x1\n" | ||
| 115 | "mov x20, %x[width]\n" | ||
| 116 | "mov x25, %x[in]\n" | ||
| 117 | "ptrue p0.b\n" | ||
| 118 | "mov x24, %x[outptr_raw]\n" | ||
| 119 | "mov x23, #0x0\n" | ||
| 120 | "lsr x10, x10, #0x1\n" | ||
| 121 | "lsr x27, x27, #0x1\n" | ||
| 122 | "mov x12, #0x0\n" | ||
| 123 | ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n" | ||
| 124 | "add x22, x25, x28, LSL #3\n" | ||
| 125 | "1:" // Width loop: Preamble: Loop | ||
| 126 | "ldr x21, [x25], #0x8\n" | ||
| 127 | ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n" | ||
| 128 | ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n" | ||
| 129 | "ldr x20, [x22], #0x8\n" | ||
| 130 | ".inst 0xa01746b4 // ld1w { z20.s-z21.s }, pn9.s/Z, [x21, x23, LSL #2]\n" | ||
| 131 | ".inst 0xa017428c // ld1w { z12.s-z13.s }, pn8.s/Z, [x20, x23, LSL #2]\n" | ||
| 132 | ".inst 0xc160e294 // bfcvt z20.h, { z20.s-z21.s }\n" | ||
| 133 | ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n" | ||
| 134 | ".inst 0xc0800280 // mova za0h.s[x12], p0/M, z20.s\n" | ||
| 135 | ".inst 0xc0800184 // mova za1h.s[x12], p0/M, z12.s\n" | ||
| 136 | "add x12, x12, #0x1\n" | ||
| 137 | "cmp x12, x28\n" | ||
| 138 | "blt 1b\n" | ||
| 139 | "incw x23, ALL, MUL #2\n" | ||
| 140 | "incw x9, ALL, MUL #2\n" | ||
| 141 | "cbz x10, 5f\n" | ||
| 142 | "2:" // Width loop | ||
| 143 | "mov x20, %x[width]\n" | ||
| 144 | "mov x25, %x[in]\n" | ||
| 145 | "mov x12, #0x0\n" | ||
| 146 | ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n" | ||
| 147 | "add x22, x25, x28, LSL #3\n" | ||
| 148 | "3:" // Width loop: Odd: Loop | ||
| 149 | "ldr x21, [x25], #0x8\n" | ||
| 150 | ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n" | ||
| 151 | ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n" | ||
| 152 | ".inst 0xc0828007 // mova z7.s, p0/M, za0v.s[x12]\n" | ||
| 153 | "ldr x20, [x22], #0x8\n" | ||
| 154 | ".inst 0xc082808f // mova z15.s, p0/M, za1v.s[x12]\n" | ||
| 155 | ".inst 0xa01746b6 // ld1w { z22.s-z23.s }, pn9.s/Z, [x21, x23, LSL #2]\n" | ||
| 156 | ".inst 0xa017429a // ld1w { z26.s-z27.s }, pn8.s/Z, [x20, x23, LSL #2]\n" | ||
| 157 | ".inst 0xa1605707 // st1w { z7.s, z15.s }, pn13.b, [x24]\n" | ||
| 158 | "addvl x24, x24, #2\n" | ||
| 159 | ".inst 0xc160e2d6 // bfcvt z22.h, { z22.s-z23.s }\n" | ||
| 160 | ".inst 0xc160e35a // bfcvt z26.h, { z26.s-z27.s }\n" | ||
| 161 | ".inst 0xc08002c8 // mova za2h.s[x12], p0/M, z22.s\n" | ||
| 162 | ".inst 0xc080034c // mova za3h.s[x12], p0/M, z26.s\n" | ||
| 163 | "add x12, x12, #0x1\n" | ||
| 164 | "cmp x12, x28\n" | ||
| 165 | "blt 3b\n" | ||
| 166 | "incw x9, ALL, MUL #2\n" | ||
| 167 | "mov x20, %x[width]\n" | ||
| 168 | "mov x25, %x[in]\n" | ||
| 169 | "incw x23, ALL, MUL #2\n" | ||
| 170 | "mov x12, #0x0\n" | ||
| 171 | ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n" | ||
| 172 | "add x22, x25, x28, LSL #3\n" | ||
| 173 | "4:" // Width loop: Even: Loop | ||
| 174 | "ldr x21, [x25], #0x8\n" | ||
| 175 | ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n" | ||
| 176 | ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n" | ||
| 177 | ".inst 0xc0828108 // mova z8.s, p0/M, za2v.s[x12]\n" | ||
| 178 | "ldr x20, [x22], #0x8\n" | ||
| 179 | ".inst 0xc0828189 // mova z9.s, p0/M, za3v.s[x12]\n" | ||
| 180 | ".inst 0xa01746ae // ld1w { z14.s-z15.s }, pn9.s/Z, [x21, x23, LSL #2]\n" | ||
| 181 | ".inst 0xa017428c // ld1w { z12.s-z13.s }, pn8.s/Z, [x20, x23, LSL #2]\n" | ||
| 182 | ".inst 0xa0605708 // st1w { z8.s-z9.s }, pn13.b, [x24]\n" | ||
| 183 | "addvl x24, x24, #2\n" | ||
| 184 | ".inst 0xc160e1ce // bfcvt z14.h, { z14.s-z15.s }\n" | ||
| 185 | ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n" | ||
| 186 | ".inst 0xc08001c0 // mova za0h.s[x12], p0/M, z14.s\n" | ||
| 187 | ".inst 0xc0800184 // mova za1h.s[x12], p0/M, z12.s\n" | ||
| 188 | "add x12, x12, #0x1\n" | ||
| 189 | "cmp x12, x28\n" | ||
| 190 | "blt 4b\n" | ||
| 191 | "subs x10, x10, #0x1\n" | ||
| 192 | "incw x23, ALL, MUL #2\n" | ||
| 193 | "incw x9, ALL, MUL #2\n" | ||
| 194 | "bgt 2b\n" | ||
| 195 | "5:" // Width loop: Tails | ||
| 196 | "cbnz x26, 8f\n" | ||
| 197 | "mov x20, %x[width]\n" | ||
| 198 | "mov x25, %x[in]\n" | ||
| 199 | "mov x12, #0x0\n" | ||
| 200 | ".inst 0x25b44532 // whilelt pn10.s, x9, x20, VLx2\n" | ||
| 201 | "add x22, x25, x28, LSL #3\n" | ||
| 202 | "6:" // Width loop: Tails: Even: Odd: Loop | ||
| 203 | "ldr x21, [x25], #0x8\n" | ||
| 204 | ".inst 0x25306989 // psel p9.s, p10.s/Z, p12.s[w12]\n" | ||
| 205 | ".inst 0x25306968 // psel p8.s, p10.s/Z, p11.s[w12]\n" | ||
| 206 | ".inst 0xc0828003 // mova z3.s, p0/M, za0v.s[x12]\n" | ||
| 207 | "ldr x20, [x22], #0x8\n" | ||
| 208 | ".inst 0xc082808b // mova z11.s, p0/M, za1v.s[x12]\n" | ||
| 209 | ".inst 0xa01746ac // ld1w { z12.s-z13.s }, pn9.s/Z, [x21, x23, LSL #2]\n" | ||
| 210 | ".inst 0xa017428e // ld1w { z14.s-z15.s }, pn8.s/Z, [x20, x23, LSL #2]\n" | ||
| 211 | ".inst 0xa1605703 // st1w { z3.s, z11.s }, pn13.b, [x24]\n" | ||
| 212 | "addvl x24, x24, #2\n" | ||
| 213 | ".inst 0xc160e18c // bfcvt z12.h, { z12.s-z13.s }\n" | ||
| 214 | ".inst 0xc160e1ce // bfcvt z14.h, { z14.s-z15.s }\n" | ||
| 215 | ".inst 0xc0800188 // mova za2h.s[x12], p0/M, z12.s\n" | ||
| 216 | ".inst 0xc08001cc // mova za3h.s[x12], p0/M, z14.s\n" | ||
| 217 | "add x12, x12, #0x1\n" | ||
| 218 | "cmp x12, x28\n" | ||
| 219 | "blt 6b\n" | ||
| 220 | "mov x12, #0x0\n" | ||
| 221 | "7:" // Width loop: Tails: Even: Even: Loop | ||
| 222 | ".inst 0xc082810e // mova z14.s, p0/M, za2v.s[x12]\n" | ||
| 223 | ".inst 0xc082818f // mova z15.s, p0/M, za3v.s[x12]\n" | ||
| 224 | "add x12, x12, #0x1\n" | ||
| 225 | "cmp x12, x27\n" | ||
| 226 | ".inst 0xa060570e // st1w { z14.s-z15.s }, pn13.b, [x24]\n" | ||
| 227 | "addvl x24, x24, #2\n" | ||
| 228 | "blt 7b\n" | ||
| 229 | "b 10f\n" | ||
| 230 | "8:" // Width loop: Tails: Odd | ||
| 231 | "mov x12, #0x0\n" | ||
| 232 | "9:" // Width loop: Tails: Odd: Loop | ||
| 233 | ".inst 0xc0828014 // mova z20.s, p0/M, za0v.s[x12]\n" | ||
| 234 | ".inst 0xc0828095 // mova z21.s, p0/M, za1v.s[x12]\n" | ||
| 235 | "add x12, x12, #0x1\n" | ||
| 236 | "cmp x12, x27\n" | ||
| 237 | ".inst 0xa0605714 // st1w { z20.s-z21.s }, pn13.b, [x24]\n" | ||
| 238 | "addvl x24, x24, #2\n" | ||
| 239 | "blt 9b\n" | ||
| 240 | "10:" // End | ||
| 241 | "mov %x[outptr_raw], x24\n" | ||
| 242 | ".inst 0xd503467f // SMSTOP\n" | ||
| 243 | : [outptr_raw] "+&r"(out) | ||
| 244 | 471 | : [height] "r"(height), [in] "r"(in), [width] "r"(width) | |
| 245 | : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", | ||
| 246 | "p14", "p15", "x9", "x10", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", | ||
| 247 | "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", | ||
| 248 | "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); | ||
| 249 | 471 | } | |
| 250 | 150 | } | |
| 251 | |||
| 252 | #endif // Architectural features check. | ||
| 253 |