kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | | | `//` |
| 2 | | | `// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>` |
| 3 | | | `//` |
| 4 | | | `// SPDX-License-Identifier: Apache-2.0` |
| 5 | | | `//` |
| 6 | | | |
| 7 | | | `#if !defined(__aarch64__)` |
| 8 | | | `#error This file must be compiled for AArch64.` |
| 9 | | | `#else // Architectural features check.` |
| 10 | | | |
| 11 | | | `#include "kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h"` |
| 12 | | | |
| 13 | | | `#include <stddef.h>` |
| 14 | | | `#include <stdint.h>` |
| 15 | | | |
| 16 | | | `#include "kai/kai_common.h"` |
| 17 | | | |
| 18 | | | `static const size_t kai_nr = 8;` |
| 19 | | | `static const size_t kai_kr = 1;` |
| 20 | | | |
| 21 | | 108 | `size_t kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(void) {` |
| 22 | | 108 | `return kai_nr;` |
| 23 | | | `}` |
| 24 | | | |
| 25 | | 108 | `size_t kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx) {` |
| 26 | | − | `KAI_ASSUME(n_idx % kai_nr == 0);` |
| 27 | | | |
| 28 | | 108 | `return n_idx * sizeof(uint32_t);` |
| 29 | | | `}` |
| 30 | | | |
| 31 | | 108 | `size_t kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx) {` |
| 32 | | 108 | `return n_idx * sizeof(uint32_t);` |
| 33 | | | `}` |
| 34 | | | |
| 35 | | 216 | `size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n_idx, size_t k) {` |
| 36 | | − | `KAI_ASSUME(n_idx % kai_nr == 0);` |
| 37 | | | |
| 38 | | 216 | `return n_idx * (sizeof(uint32_t) + k * sizeof(uint32_t));` |
| 39 | | | `}` |
| 40 | | | |
| 41 | | 108 | `size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(size_t n, size_t k) {` |
| 42 | | 108 | `return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(kai_roundup(n, kai_nr), k);` |
| 43 | | | `}` |
| 44 | | | |
| 45 | | 108 | `void kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon(` |
| 46 | | | `size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,` |
| 47 | | | `const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) {` |
| 48 | | − | `KAI_ASSUME(num_groups == 1);` |
| 49 | | − | `KAI_ASSUME(nr == kai_nr);` |
| 50 | | − | `KAI_ASSUME(kr == kai_kr);` |
| 51 | | − | `KAI_ASSUME(sr == 1);` |
| 52 | | − | `KAI_ASSUME(rhs != NULL);` |
| 53 | | − | `KAI_ASSUME(bias != NULL);` |
| 54 | | − | `KAI_ASSUME(scale == NULL);` |
| 55 | | − | `KAI_ASSUME(rhs_packed != NULL);` |
| 56 | | − | `KAI_ASSUME(extra_bytes == 0);` |
| 57 | | − | `KAI_ASSUME(params == NULL);` |
| 58 | | | |
| 59 | | 108 | `size_t height = k;` |
| 60 | | 108 | `const size_t width = n;` |
| 61 | | 108 | `const void* in = rhs;` |
| 62 | | 108 | `void* out = rhs_packed;` |
| 63 | | 108 | `const size_t in_stride = rhs_stride;` |
| 64 | | 108 | `size_t out_stride = kai_nr * height * sizeof(uint32_t) + kai_nr * sizeof(uint32_t);` |
| 65 | | | |
| 66 | | 216 | `__asm__ __volatile__(` |
| 67 | | | `"mov x22, %x[width]\n"` |
| 68 | | | `"mov x21, %x[out]\n"` |
| 69 | | | `"cmp x22, #0x8\n"` |
| 70 | | | `"blt 2f\n"` |
| 71 | | | `"1:" // Bias: Full loop` |
| 72 | | | `"ldr q17, [%x[bias], #0x0]\n"` |
| 73 | | | `"ldr q16, [%x[bias], #0x10]\n"` |
| 74 | | | `"sub x22, x22, #0x8\n"` |
| 75 | | | `"add %x[bias], %x[bias], #0x20\n"` |
| 76 | | | `"cmp x22, #0x8\n"` |
| 77 | | | `"str q17, [x21, #0x0]\n"` |
| 78 | | | `"str q16, [x21, #0x10]\n"` |
| 79 | | | `"add x21, x21, %x[out_stride]\n"` |
| 80 | | | `"bge 1b\n"` |
| 81 | | | `"cbz x22, 3f\n"` |
| 82 | | | `"2:" // Bias: Tail loop` |
| 83 | | | `"ldr w20, [%x[bias], #0x0]\n"` |
| 84 | | | `"sub x22, x22, #0x1\n"` |
| 85 | | | `"add %x[bias], %x[bias], #0x4\n"` |
| 86 | | | `"cmp x22, #0x0\n"` |
| 87 | | | `"str x20, [x21]\n"` |
| 88 | | | `"add x21, x21, #0x4\n"` |
| 89 | | | `"bgt 2b\n"` |
| 90 | | | `"3:" // Bias: Done` |
| 91 | | | `"cmp %x[height], #0x4\n"` |
| 92 | | | `"add %x[out], %x[out], #0x20\n"` |
| 93 | | | `"blt 12f\n"` |
| 94 | | | `"4:" // Main row loop: Head` |
| 95 | | | `"mov x25, %x[in]\n"` |
| 96 | | | `"mov x24, %x[width]\n"` |
| 97 | | | `"mov x23, %x[out]\n"` |
| 98 | | | `"sub %x[height], %x[height], #0x4\n"` |
| 99 | | | `"add x22, x25, %x[in_stride]\n"` |
| 100 | | | `"add x21, x22, %x[in_stride]\n"` |
| 101 | | | `"add x20, x21, %x[in_stride]\n"` |
| 102 | | | `"cmp x24, #0x8\n"` |
| 103 | | | `"add %x[in], x20, %x[in_stride]\n"` |
| 104 | | | `"blt 6f\n"` |
| 105 | | | `"5:" // Main row loop: Column loop` |
| 106 | | | `"ldr q23, [x25], #0x10\n"` |
| 107 | | | `"ldr q22, [x22], #0x10\n"` |
| 108 | | | `"sub x24, x24, #0x8\n"` |
| 109 | | | `"ldr q21, [x21], #0x10\n"` |
| 110 | | | `"ldr q20, [x20], #0x10\n"` |
| 111 | | | `"cmp x24, #0x8\n"` |
| 112 | | | `"ldr q19, [x25], #0x10\n"` |
| 113 | | | `"ldr q18, [x22], #0x10\n"` |
| 114 | | | `"ldr q17, [x21], #0x10\n"` |
| 115 | | | `"ldr q16, [x20], #0x10\n"` |
| 116 | | | `"str q23, [x23, #0x0]\n"` |
| 117 | | | `"str q19, [x23, #0x10]\n"` |
| 118 | | | `"str q22, [x23, #0x20]\n"` |
| 119 | | | `"str q18, [x23, #0x30]\n"` |
| 120 | | | `"str q21, [x23, #0x40]\n"` |
| 121 | | | `"str q17, [x23, #0x50]\n"` |
| 122 | | | `"str q20, [x23, #0x60]\n"` |
| 123 | | | `"str q16, [x23, #0x70]\n"` |
| 124 | | | `"add x23, x23, %x[out_stride]\n"` |
| 125 | | | `"bge 5b\n"` |
| 126 | | | `"6:" // Main row loop: Column loop skip` |
| 127 | | | `"cbz x24, 11f\n"` |
| 128 | | | `"cmp x24, #0x4\n"` |
| 129 | | | `"movi v16.4s, #0x0\n"` |
| 130 | | | `"str q16, [x23, #0x0]\n"` |
| 131 | | | `"str q16, [x23, #0x10]\n"` |
| 132 | | | `"str q16, [x23, #0x20]\n"` |
| 133 | | | `"str q16, [x23, #0x30]\n"` |
| 134 | | | `"str q16, [x23, #0x40]\n"` |
| 135 | | | `"str q16, [x23, #0x50]\n"` |
| 136 | | | `"str q16, [x23, #0x60]\n"` |
| 137 | | | `"str q16, [x23, #0x70]\n"` |
| 138 | | | `"blt 8f\n"` |
| 139 | | | `"7:" // Main row loop: width 4 loop: loop` |
| 140 | | | `"ldr q19, [x25], #0x10\n"` |
| 141 | | | `"ldr q18, [x22], #0x10\n"` |
| 142 | | | `"sub x24, x24, #0x4\n"` |
| 143 | | | `"ldr q17, [x21], #0x10\n"` |
| 144 | | | `"ldr q16, [x20], #0x10\n"` |
| 145 | | | `"cmp x24, #0x4\n"` |
| 146 | | | `"str q19, [x23, #0x0]\n"` |
| 147 | | | `"str q18, [x23, #0x20]\n"` |
| 148 | | | `"str q17, [x23, #0x40]\n"` |
| 149 | | | `"str q16, [x23, #0x60]\n"` |
| 150 | | | `"add x23, x23, #0x10\n"` |
| 151 | | | `"bge 7b\n"` |
| 152 | | | `"8:" // Main row loop: width 4 loop: skip` |
| 153 | | | `"cmp x24, #0x1\n"` |
| 154 | | | `"blt 10f\n"` |
| 155 | | | `"9:" // Main row loop: width 1 loop: loop` |
| 156 | | | `"ldr s19, [x25], #0x4\n"` |
| 157 | | | `"ldr s18, [x22], #0x4\n"` |
| 158 | | | `"sub x24, x24, #0x1\n"` |
| 159 | | | `"ldr s17, [x21], #0x4\n"` |
| 160 | | | `"ldr s16, [x20], #0x4\n"` |
| 161 | | | `"cmp x24, #0x1\n"` |
| 162 | | | `"str s19, [x23, #0x0]\n"` |
| 163 | | | `"str s18, [x23, #0x20]\n"` |
| 164 | | | `"str s17, [x23, #0x40]\n"` |
| 165 | | | `"str s16, [x23, #0x60]\n"` |
| 166 | | | `"add x23, x23, #0x4\n"` |
| 167 | | | `"bge 9b\n"` |
| 168 | | | `"10:" // Main row loop: width 1 loop: skip` |
| 169 | | | `"11:" // Main row loop: odd col skip` |
| 170 | | | `"cmp %x[height], #0x4\n"` |
| 171 | | | `"add %x[out], %x[out], #0x80\n"` |
| 172 | | | `"bge 4b\n"` |
| 173 | | | `"cbz %x[height], 21f\n"` |
| 174 | | | `"12:" // Main loop skip` |
| 175 | | | `"13:" // Tail row loop: Head` |
| 176 | | | `"mov x20, %x[width]\n"` |
| 177 | | | `"mov x25, %x[in]\n"` |
| 178 | | | `"mov x23, %x[out]\n"` |
| 179 | | | `"sub %x[height], %x[height], #0x1\n"` |
| 180 | | | `"cmp x20, #0x8\n"` |
| 181 | | | `"add %x[in], x25, %x[in_stride]\n"` |
| 182 | | | `"blt 15f\n"` |
| 183 | | | `"14:" // Tail row loop: Column loop` |
| 184 | | | `"ldr q17, [x25], #0x10\n"` |
| 185 | | | `"sub x20, x20, #0x8\n"` |
| 186 | | | `"ldr q16, [x25], #0x10\n"` |
| 187 | | | `"cmp x20, #0x8\n"` |
| 188 | | | `"str q17, [x23, #0x0]\n"` |
| 189 | | | `"str q16, [x23, #0x10]\n"` |
| 190 | | | `"add x23, x23, %x[out_stride]\n"` |
| 191 | | | `"bge 14b\n"` |
| 192 | | | `"15:" // Tail row loop: Column loop skip` |
| 193 | | | `"cbz x20, 20f\n"` |
| 194 | | | `"cmp x20, #0x4\n"` |
| 195 | | | `"movi v16.4s, #0x0\n"` |
| 196 | | | `"str q16, [x23, #0x0]\n"` |
| 197 | | | `"str q16, [x23, #0x10]\n"` |
| 198 | | | `"blt 17f\n"` |
| 199 | | | `"16:" // Tail row loop: width 4 loop: loop` |
| 200 | | | `"ldr q16, [x25], #0x10\n"` |
| 201 | | | `"sub x20, x20, #0x4\n"` |
| 202 | | | `"cmp x20, #0x4\n"` |
| 203 | | | `"str q16, [x23, #0x0]\n"` |
| 204 | | | `"add x23, x23, #0x10\n"` |
| 205 | | | `"bge 16b\n"` |
| 206 | | | `"17:" // Tail row loop: width 4 loop: skip` |
| 207 | | | `"cmp x20, #0x1\n"` |
| 208 | | | `"blt 19f\n"` |
| 209 | | | `"18:" // Tail row loop: width 1 loop: loop` |
| 210 | | | `"ldr s16, [x25], #0x4\n"` |
| 211 | | | `"sub x20, x20, #0x1\n"` |
| 212 | | | `"cmp x20, #0x1\n"` |
| 213 | | | `"str s16, [x23, #0x0]\n"` |
| 214 | | | `"add x23, x23, #0x4\n"` |
| 215 | | | `"bge 18b\n"` |
| 216 | | | `"19:" // Tail row loop: width 1 loop: skip` |
| 217 | | | `"20:" // Tail row loop: odd col skip` |
| 218 | | | `"cmp %x[height], #0x1\n"` |
| 219 | | | `"add %x[out], %x[out], #0x20\n"` |
| 220 | | | `"bge 13b\n"` |
| 221 | | | `"21:" // Done` |
| 222 | | | `: [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out)` |
| 223 | | 108 | `: [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [width] "r"(width)` |
| 224 | | | `: "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24",` |
| 225 | | | `"x25");` |
| 226 | | 108 | `}` |
| 227 | | | |
| 228 | | | `#endif // Architectural features check.` |
| 229 | | | |