kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #if !defined(__aarch64__) && !defined(_M_ARM64) | ||
| 8 | #error This file must be compiled for AArch64. | ||
| 9 | #else // Architectural features check. | ||
| 10 | #include "kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.h" | ||
| 11 | |||
| 12 | #include <arm_neon.h> | ||
| 13 | #include <stdint.h> | ||
| 14 | #include <string.h> | ||
| 15 | |||
| 16 | #include "kai/kai_common.h" | ||
| 17 | |||
// Per-block quantization metadata sizes: each block stores one f32 zero-point
// (offset) and one f32 scale (multiplier) alongside the packed 4-bit data.
static const size_t kai_num_bytes_offset_rhs = sizeof(float);
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
// One f32 bias value per output row, appended after all packed blocks of a row group.
static const size_t kai_num_bytes_bias = sizeof(float);
// The quantization block length (bl) must be a multiple of 32.
static const size_t kai_bl_multiple_of = 32;
// The row-block height (nr) must be a multiple of 4: the packing loop
// interleaves 4 source rows at a time.
static const size_t kai_nr_multiple_of = 4;
| 23 | |||
| 24 | 12852 | inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { | |
| 25 | − | KAI_ASSUME((k % 2) == 0); | |
| 26 | − | KAI_ASSUME((k % bl) == 0); | |
| 27 | − | KAI_ASSUME((bl % kai_bl_multiple_of) == 0); | |
| 28 | 12852 | return kai_roundup(k, bl) / bl; | |
| 29 | } | ||
| 30 | |||
| 31 | 17136 | inline static size_t kai_get_num_bytes_per_block(size_t bl) { | |
| 32 | 17136 | return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; | |
| 33 | } | ||
| 34 | |||
| 35 | 12852 | inline static size_t kai_get_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) { | |
| 36 | − | KAI_ASSUME((k % 2) == 0); | |
| 37 | − | KAI_ASSUME((k % kr) == 0); | |
| 38 | − | KAI_ASSUME((k % bl) == 0); | |
| 39 | − | KAI_ASSUME((bl % kr) == 0); | |
| 40 | − | KAI_ASSUME((bl % kai_bl_multiple_of) == 0); | |
| 41 | 12852 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
| 42 | 12852 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl); | |
| 43 | 25704 | return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias); | |
| 44 | 12852 | } | |
| 45 | |||
// Byte offset of row n_idx in the source (unpacked) RHS matrix.
size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) {
    const size_t row_offset = rhs_stride * n_idx;
    return row_offset;
}
| 49 | |||
// Byte offset into the packed RHS buffer of the row group that starts at row n_idx.
// n_idx must be a multiple of nr (offsets are only defined at group boundaries).
size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
    size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);
    KAI_ASSUME((n_idx % nr) == 0);
    // Fix: the previous KAI_UNUSED(kr) was misleading — kr is forwarded to the
    // stride helper on the next line, so it is not unused.
    return (n_idx / nr) * kai_get_rhs_packed_stride(k, nr, kr, bl);
}
| 59 | |||
// Total size in bytes of the packed RHS buffer: one packed stride per group of
// nr rows, with n rounded up to a whole number of groups.
size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
    size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);
    // Fix: the previous KAI_UNUSED(kr) was misleading — kr is forwarded to the
    // stride helper below, so it is not unused.
    const size_t num_rows = kai_roundup(n, nr) / nr;
    return num_rows * kai_get_rhs_packed_stride(k, nr, kr, bl);
}
| 69 | |||
// Packs an NxK RHS matrix of unsigned 4-bit values (two per byte) into the
// blocked/interleaved layout consumed by the matmul micro-kernels.
//
// Output layout per group of nr rows (one "packed row"):
//   [for each of the k/bl blocks: nr * bl/2 bytes of interleaved 4-bit data,
//    then nr f32 zero-points, then nr f32 scales] followed by nr f32 biases.
//
// Parameters:
//   num_groups  - must be 1.
//   n, k        - RHS dimensions (n rows, k columns).
//   nr, kr, sr  - packing parameters; sr must be 2 and kr/sr must be 4.
//   bl          - quantization block length along k (multiple of 32, divides k).
//   rhs         - source matrix, two 4-bit values per byte, row stride k/2 bytes.
//   zero, scale - per-block f32 zero-points and scales, n * (k/bl) values each.
//   bias        - optional per-row f32 bias (NULL -> zeros are written).
//   rhs_packed  - destination buffer (see kai_get_rhs_packed_size_...).
//   extra_bytes - must be 0.
//   params      - must have rhs_zero_point == 8 and lhs_zero_point == 1.
void kai_run_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
    size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
    const void* zero, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
    const struct kai_rhs_pack_nxk_qai4c32p_params* params) {
    KAI_ASSUME(num_groups == 1);
    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);
    KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
    KAI_ASSUME((nr % kai_nr_multiple_of) == 0);
    KAI_ASSUME(extra_bytes == 0);

    KAI_ASSUME(sr == 2);
    KAI_ASSUME(kr / sr == 4);
    KAI_ASSUME(rhs != NULL);
    KAI_ASSUME(zero != NULL);
    KAI_ASSUME(scale != NULL);
    KAI_ASSUME(rhs_packed != NULL);
    KAI_ASSUME(params != NULL);
    KAI_ASSUME(params->rhs_zero_point == 8);
    KAI_ASSUME(params->lhs_zero_point == 1);

    // Note: The input matrix (rhs) is expected with:
    // "k" columns and "n" rows (NxK)

    const size_t block_length = kr / sr;          // 4 k-values per packed segment (asserted above)
    const size_t num_blocks_per_row = k / bl;     // quantization blocks along k
    const size_t rhs_stride = k / 2;              // source row stride in bytes (2 nibbles/byte)
    const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, bl);

    const size_t dst_packed_block_size = kai_get_num_bytes_per_block(bl) * nr;  // data + zp + scale, for nr rows
    const size_t dst_block_data_size = bl / 2;    // packed 4-bit data bytes per row per block
    const size_t dst_num_rows = kai_roundup(n, nr) / nr;  // number of nr-row groups
    const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size;  // biases follow all blocks
    const size_t k_block_length_in_bytes = (block_length * sizeof(uint8_t)) / 2;  // 2 bytes per segment

    // One iteration per group of nr destination rows.
    for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
        uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
        float* dst_row_bias = (float*)(dst_row + dst_bias_offset);
        size_t row_idx = dst_row_idx * nr;
        size_t rows_left = n - row_idx;
        for (size_t block_idx = 0; block_idx < num_blocks_per_row; block_idx++) {
            uint8_t* block_dst_row = dst_row + block_idx * dst_packed_block_size;
            // Per-block metadata sits right after the nr * bl/2 data bytes:
            // nr zero-points, then nr scales.
            float* block_dst_zp = (float*)(block_dst_row + nr * dst_block_data_size);
            float* block_dst_scale = block_dst_zp + nr;
            size_t k_idx = block_idx * bl;
            // Consume 8 source bytes (16 4-bit values) per row per iteration.
            for (size_t dst_byte_idx = 0; dst_byte_idx < dst_block_data_size; dst_byte_idx += 8) {
                // Interleave 4 consecutive source rows at a time (nr % 4 == 0 asserted).
                for (size_t nr_idx = 0; nr_idx <= nr - 4; nr_idx += 4) {
                    // Clamp to the last row so padding rows replicate row n-1;
                    // their results are overwritten/ignored via the metadata anyway.
                    const size_t n0_idx = KAI_MIN(dst_row_idx * nr + nr_idx, n - 1);
                    const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
                    const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
                    const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);
                    const uint8_t* src_addr_byte = rhs + (k_idx / 2) + dst_byte_idx;

                    // Load 8 bytes from each of the 4 rows.
                    const uint8x8_t vec0_u8 = vld1_u8(src_addr_byte + n0_idx * rhs_stride);
                    const uint8x8_t vec1_u8 = vld1_u8(src_addr_byte + n1_idx * rhs_stride);
                    const uint8x8_t vec2_u8 = vld1_u8(src_addr_byte + n2_idx * rhs_stride);
                    const uint8x8_t vec3_u8 = vld1_u8(src_addr_byte + n3_idx * rhs_stride);

                    // 4x transpose at 16-bit (byte-pair) granularity: zip rows 0/1
                    // and 2/3 at u16, then zip the results at u32. Each final 32-bit
                    // lane holds one byte-pair from each of the 4 rows.
                    const uint16x4_t vec0_u16 = vreinterpret_u16_u8(vec0_u8);
                    const uint16x4_t vec1_u16 = vreinterpret_u16_u8(vec1_u8);
                    const uint16x4_t vec2_u16 = vreinterpret_u16_u8(vec2_u8);
                    const uint16x4_t vec3_u16 = vreinterpret_u16_u8(vec3_u8);

                    const uint16x4_t vec01_lo_u16 = vzip1_u16(vec0_u16, vec1_u16);
                    const uint16x4_t vec01_hi_u16 = vzip2_u16(vec0_u16, vec1_u16);
                    const uint16x4_t vec23_lo_u16 = vzip1_u16(vec2_u16, vec3_u16);
                    const uint16x4_t vec23_hi_u16 = vzip2_u16(vec2_u16, vec3_u16);

                    const uint32x2_t vec01_lo_u32 = vreinterpret_u32_u16(vec01_lo_u16);
                    const uint32x2_t vec01_hi_u32 = vreinterpret_u32_u16(vec01_hi_u16);
                    const uint32x2_t vec23_lo_u32 = vreinterpret_u32_u16(vec23_lo_u16);
                    const uint32x2_t vec23_hi_u32 = vreinterpret_u32_u16(vec23_hi_u16);

                    const uint32x2_t vin0_u32 = vzip1_u32(vec01_lo_u32, vec23_lo_u32);
                    const uint32x2_t vin1_u32 = vzip2_u32(vec01_lo_u32, vec23_lo_u32);
                    const uint32x2_t vin2_u32 = vzip1_u32(vec01_hi_u32, vec23_hi_u32);
                    const uint32x2_t vin3_u32 = vzip2_u32(vec01_hi_u32, vec23_hi_u32);

                    uint8x8_t vin0_u8 = vreinterpret_u8_u32(vin0_u32);
                    uint8x8_t vin1_u8 = vreinterpret_u8_u32(vin1_u32);
                    uint8x8_t vin2_u8 = vreinterpret_u8_u32(vin2_u32);
                    uint8x8_t vin3_u8 = vreinterpret_u8_u32(vin3_u32);

                    // Swap the two 4-bit values within every byte (the "s1s0" -> "s0s1"
                    // nibble reordering): high nibble -> low, low nibble -> high, then OR.
                    const uint8x8_t vin0_s1s = vshr_n_u8(vin0_u8, 4);
                    const uint8x8_t vin1_s1s = vshr_n_u8(vin1_u8, 4);
                    const uint8x8_t vin2_s1s = vshr_n_u8(vin2_u8, 4);
                    const uint8x8_t vin3_s1s = vshr_n_u8(vin3_u8, 4);

                    vin0_u8 = vshl_n_u8(vin0_u8, 4);
                    vin1_u8 = vshl_n_u8(vin1_u8, 4);
                    vin2_u8 = vshl_n_u8(vin2_u8, 4);
                    vin3_u8 = vshl_n_u8(vin3_u8, 4);

                    vin0_u8 = vorr_u8(vin0_u8, vin0_s1s);
                    vin1_u8 = vorr_u8(vin1_u8, vin1_s1s);
                    vin2_u8 = vorr_u8(vin2_u8, vin2_s1s);
                    vin3_u8 = vorr_u8(vin3_u8, vin3_s1s);

                    // Scatter the 4 transposed vectors: vectors for successive
                    // k-segments are nr * k_block_length_in_bytes apart.
                    uint8_t* dst_row_offset = block_dst_row + nr_idx * k_block_length_in_bytes;
                    vst1_u8(dst_row_offset, vin0_u8);
                    vst1_u8(dst_row_offset + nr * k_block_length_in_bytes, vin1_u8);
                    vst1_u8(dst_row_offset + 2 * (nr * k_block_length_in_bytes), vin2_u8);
                    vst1_u8(dst_row_offset + 3 * (nr * k_block_length_in_bytes), vin3_u8);
                }
                // Advance past the nr * 8 bytes written for this 8-byte source chunk.
                block_dst_row += nr * sizeof(uint8x8_t);
            }

            // Adjust the zero points and scales
            for (size_t i = 0; i < nr; ++i) {
                // Clamp padding rows to the last valid row, matching the data path.
                const size_t src_row_idx = KAI_MIN(row_idx + i, n - 1);
                const size_t src_idx = src_row_idx * num_blocks_per_row + block_idx;

                block_dst_scale[i] = ((const float*)scale)[src_idx];
                block_dst_zp[i] = ((const float*)zero)[src_idx];
            }
        }
        // Set the bias (zeros when no bias is provided).
        if (bias == NULL) {
            memset(dst_row_bias, 0, nr * kai_num_bytes_bias);
        } else {
            if (rows_left >= nr) {
                memcpy(dst_row_bias, &((const float*)bias)[row_idx], nr * kai_num_bytes_bias);
            } else {
                // Fill remaining values
                memcpy(dst_row_bias, &((const float*)bias)[row_idx], rows_left * kai_num_bytes_bias);
                // Set leftover to 0
                memset(&dst_row_bias[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias);
            }
        }
    }
}
| 202 | #endif | ||
| 203 |