kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | #if !defined(__aarch64__) && !defined(_M_ARM64) | ||
| 7 | #error This file must be compiled for AArch64. | ||
| 8 | #else // Architectural features check. | ||
| 9 | |||
| 10 | #include "kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h" | ||
| 11 | |||
| 12 | #include <stdint.h> | ||
| 13 | #include <string.h> | ||
| 14 | |||
| 15 | #include "kai/kai_common.h" | ||
| 16 | |||
| 17 | static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); | ||
| 18 | static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); | ||
| 19 | static const size_t kai_num_bytes_bias = sizeof(float); | ||
| 20 | |||
| 21 | 4560 | inline static size_t kai_k_roundedup(size_t k) { | |
| 22 | // Round up k to be a multiple of 32. | ||
| 23 | 4560 | size_t kai_k_multiple_of = 32; | |
| 24 | 9120 | return kai_roundup(k, kai_k_multiple_of); | |
| 25 | 4560 | } | |
| 26 | |||
| 27 | 600 | size_t kai_get_n_step_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t nr) { | |
| 28 | 600 | return nr; | |
| 29 | } | ||
| 30 | |||
| 31 | 1080 | size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t n_idx, size_t rhs_stride) { | |
| 32 | 1080 | return n_idx * rhs_stride; | |
| 33 | } | ||
| 34 | |||
| 35 | 3480 | size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t k, size_t nr, size_t kr, size_t sr) { | |
| 36 | 3480 | KAI_UNUSED(kr); | |
| 37 | 3480 | KAI_UNUSED(sr); | |
| 38 | |||
| 39 | 3480 | const size_t k_internal = kai_k_roundedup(k); | |
| 40 | |||
| 41 | // k_internal must be a multiple of 2 because two 4-bit elements are packed per byte | ||
| 42 | − | KAI_ASSERT((k_internal % 2) == 0); | |
| 43 | |||
| 44 | 6960 | return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); | |
| 45 | 3480 | } | |
| 46 | |||
| 47 | 1320 | size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( | |
| 48 | size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { | ||
| 49 | − | KAI_ASSERT((n_idx % nr) == 0); | |
| 50 | |||
| 51 | 1320 | return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr); | |
| 52 | } | ||
| 53 | |||
| 54 | 1080 | size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( | |
| 55 | size_t n, size_t k, size_t nr, size_t kr, size_t sr) { | ||
| 56 | 1080 | const size_t num_rows = kai_roundup(n, nr) / nr; | |
| 57 | |||
| 58 | 2160 | return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr); | |
| 59 | 1080 | } | |
| 60 | |||
| 61 | 1080 | void kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( | |
| 62 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, | ||
| 63 | const float* scale, void* rhs_packed, size_t extra_bytes, | ||
| 64 | const struct kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon_params* params) { | ||
| 65 | 1080 | const size_t k_internal = kai_k_roundedup(k); | |
| 66 | |||
| 67 | − | KAI_ASSERT((k_internal % kr) == 0); | |
| 68 | − | KAI_ASSERT(num_groups == 1); | |
| 69 | − | KAI_ASSERT(extra_bytes == 0); | |
| 70 | − | KAI_ASSERT((kr % sr) == 0); | |
| 71 | − | KAI_ASSERT(rhs != NULL); | |
| 72 | − | KAI_ASSERT(scale != NULL); | |
| 73 | − | KAI_ASSERT(rhs_packed != NULL); | |
| 74 | − | KAI_ASSERT(params != NULL); | |
| 75 | − | KAI_ASSERT(params->lhs_zero_point == 1); | |
| 76 | − | KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); | |
| 77 | |||
| 78 | // Note: The input matrix (rhs) is expected with: | ||
| 79 | // "k" columns and "n" rows (NxK) | ||
| 80 | |||
| 81 | 1080 | const int32_t rhs_zero_point = params->rhs_zero_point; | |
| 82 | 1080 | const size_t rhs_stride = kai_roundup(k, 2) / 2; | |
| 83 | 1080 | const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr); | |
| 84 | 1080 | const size_t dst_nr_block_size = nr * kr * sizeof(uint8_t) / 2; | |
| 85 | |||
| 86 | // Iterate over n src rows in blocks of nr rows | ||
| 87 | 2/2 ✓ Branch 0 taken 1080 times. ✓ Branch 1 taken 1474 times. | 2554 | for (size_t row_idx = 0; row_idx < n; row_idx += nr) { |
| 88 | 1474 | int8_t* const dst_row = (int8_t*)rhs_packed + ((row_idx / nr) * rhs_packed_stride); | |
| 89 | |||
| 90 | 1474 | int32_t* const sums = (int32_t*)(dst_row + (nr * (k_internal / 2))); | |
| 91 | 1474 | float* const scaling_factors = (float*)((uint8_t*)sums + (nr * kai_num_bytes_sum_rhs)); | |
| 92 | // Biases are stored immediately after the nr scaling factors | ||
| 93 | 1474 | float* const biases = (float*)((uint8_t*)scaling_factors + (nr * kai_num_bytes_multiplier_rhs)); | |
| 94 | |||
| 95 | // initialize sums to 0 | ||
| 96 | 1474 | memset(sums, 0, nr * kai_num_bytes_sum_rhs); | |
| 97 | |||
| 98 | // Copy the scaling factors and bias | ||
| 99 | 1474 | size_t rows_left = n - row_idx; | |
| 100 | // Saving scales. | ||
| 101 | 2/2 ✓ Branch 0 taken 645 times. ✓ Branch 1 taken 829 times. | 1474 | if (rows_left >= nr) { |
| 102 | 645 | memcpy(scaling_factors, &scale[row_idx], nr * kai_num_bytes_multiplier_rhs); | |
| 103 | 645 | } else { | |
| 104 | // Fill remaining values | ||
| 105 | 829 | memcpy(scaling_factors, &scale[row_idx], rows_left * kai_num_bytes_multiplier_rhs); | |
| 106 | // Set leftover to 0 | ||
| 107 | 829 | memset(&scaling_factors[rows_left], 0, (nr - rows_left) * kai_num_bytes_multiplier_rhs); | |
| 108 | } | ||
| 109 | 2/2 ✓ Branch 0 taken 1245 times. ✓ Branch 1 taken 229 times. | 1474 | if (bias == NULL) { |
| 110 | // Set bias to 0 | ||
| 111 | 229 | memset(biases, 0, nr * kai_num_bytes_bias); | |
| 112 | 229 | } else { | |
| 113 | 2/2 ✓ Branch 0 taken 541 times. ✓ Branch 1 taken 704 times. | 1245 | if (rows_left >= nr) { |
| 114 | 541 | memcpy(biases, &bias[row_idx], nr * kai_num_bytes_bias); | |
| 115 | 541 | } else { | |
| 116 | // Fill remaining values | ||
| 117 | 704 | memcpy(biases, &bias[row_idx], rows_left * kai_num_bytes_bias); | |
| 118 | // Set leftover to 0 | ||
| 119 | 704 | memset(&biases[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias); | |
| 120 | } | ||
| 121 | } | ||
| 122 | // Iterate over rows in the nr row block | ||
| 123 | 2/2 ✓ Branch 0 taken 94336 times. ✓ Branch 1 taken 1474 times. | 95810 | for (size_t nr_block_idx = 0; nr_block_idx < nr; ++nr_block_idx) { |
| 124 | 94336 | const uint8_t* const src_row = rhs + ((row_idx + nr_block_idx) * rhs_stride); | |
| 125 | // Go to the first kr block for this row in the nr block | ||
| 126 | 94336 | int8_t* dst_kr_block = dst_row + (nr_block_idx * kr / 2); | |
| 127 | |||
| 128 | 94336 | int32_t sum = 0; | |
| 129 | |||
| 130 | // Iterate over k src columns in blocks of kr columns | ||
| 131 | 2/2 ✓ Branch 0 taken 73600 times. ✓ Branch 1 taken 20736 times. | 94336 | if (rhs_zero_point == 8) { |
| 132 | 2/2 ✓ Branch 0 taken 1595904 times. ✓ Branch 1 taken 73600 times. | 1669504 | for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) { |
| 133 | // Iterate over columns in the kr block | ||
| 134 | // kr is assumed to be a multiple of 2 (because 2 values per byte); this function does not assert it | ||
| 135 | 2/2 ✓ Branch 0 taken 3191808 times. ✓ Branch 1 taken 1595904 times. | 4787712 | for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) { |
| 136 | // Pad dst with 0s for positions past the actual n rows or k columns (i.e. inside the padding added by rounding up) | ||
| 137 | 4/4 ✓ Branch 0 taken 2143984 times. ✓ Branch 1 taken 1047824 times. ✓ Branch 2 taken 342311 times. ✓ Branch 3 taken 1801673 times. | 3191808 | if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) { |
| 138 | 1390135 | dst_kr_block[kr_block_idx / 2] = 0; | |
| 139 | 1390135 | continue; | |
| 140 | } | ||
| 141 | |||
| 142 | // Load the 2 u4 values from source | ||
| 143 | 1801673 | const uint8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2]; | |
| 144 | |||
| 145 | // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
| 146 | // extract i8 values from the 2 u4 values | ||
| 147 | 1801673 | const int8_t first_value = (dst_byte & 0xF) - rhs_zero_point; | |
| 148 | 3603346 | const int8_t second_value = | |
| 149 | 2/2 ✓ Branch 0 taken 23674 times. ✓ Branch 1 taken 1777999 times. | 1801673 | col_idx + kr_block_idx + 1 >= k ? 0 : (dst_byte >> 4) - rhs_zero_point; |
| 150 | |||
| 151 | // Add the i4 value to the row sum | ||
| 152 | 1801673 | sum += (int32_t)first_value + (int32_t)second_value; | |
| 153 | |||
| 154 | // Truncate i8 to i4 and write to dst | ||
| 155 | 1801673 | const uint8_t hi = second_value & 0x0F; | |
| 156 | 1801673 | const uint8_t lo = first_value & 0x0F; | |
| 157 | 1801673 | dst_kr_block[kr_block_idx / 2] = (hi << 4) | lo; | |
| 158 | // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
| 159 | 1801673 | } | |
| 160 | |||
| 161 | // Go to the next kr block for this row in the nr rows | ||
| 162 | 1595904 | dst_kr_block += dst_nr_block_size; | |
| 163 | 1595904 | } | |
| 164 | 73600 | } else { | |
| 165 | 2/2 ✓ Branch 0 taken 325632 times. ✓ Branch 1 taken 20736 times. | 346368 | for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) { |
| 166 | // Iterate over columns in the kr block | ||
| 167 | // kr is assumed to be a multiple of 2 (because 2 values per byte); this function does not assert it | ||
| 168 | 2/2 ✓ Branch 0 taken 651264 times. ✓ Branch 1 taken 325632 times. | 976896 | for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) { |
| 169 | // Pad dst with 0s for positions past the actual n rows or k columns | ||
| 170 | // (i.e. inside the padding added by rounding up) | ||
| 171 | 4/4 ✓ Branch 0 taken 435264 times. ✓ Branch 1 taken 216000 times. ✓ Branch 2 taken 67104 times. ✓ Branch 3 taken 368160 times. | 651264 | if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) { |
| 172 | 283104 | dst_kr_block[kr_block_idx / 2] = 0; | |
| 173 | 283104 | continue; | |
| 174 | } | ||
| 175 | |||
| 176 | // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
| 177 | // Load the 2 u4 values from source | ||
| 178 | 368160 | const int8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2]; | |
| 179 | |||
| 180 | // extract i8 values from the 2 packed 4-bit values by sign-extending | ||
| 181 | // each nibble with kai_ext_sign_i8_i4. | ||
| 182 | 368160 | const int8_t first_value = kai_ext_sign_i8_i4(dst_byte & 0xF); | |
| 183 | 736320 | const int8_t second_value = | |
| 184 | 2/2 ✓ Branch 0 taken 5706 times. ✓ Branch 1 taken 362454 times. | 368160 | col_idx + kr_block_idx + 1 >= k ? 0 : kai_ext_sign_i8_i4((dst_byte >> 4) & 0xF); |
| 185 | |||
| 186 | // Add the i4 value to the row sum | ||
| 187 | 368160 | sum += (int32_t)first_value + (int32_t)second_value; | |
| 188 | |||
| 189 | // Truncate i8 to i4 and write to dst | ||
| 190 | 368160 | const uint8_t hi = second_value & 0x0F; | |
| 191 | 368160 | const uint8_t lo = first_value & 0x0F; | |
| 192 | 368160 | dst_kr_block[kr_block_idx / 2] = (hi << 4) | lo; | |
| 193 | // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
| 194 | 368160 | } | |
| 195 | |||
| 196 | // Go to the next kr block for this row in the nr rows | ||
| 197 | 325632 | dst_kr_block += dst_nr_block_size; | |
| 198 | 325632 | } | |
| 199 | } | ||
| 200 | |||
| 201 | // save sum | ||
| 202 | 94336 | sums[nr_block_idx] = sum; | |
| 203 | 94336 | } | |
| 204 | 1474 | } | |
| 205 | 1080 | } | |
| 206 | #endif // Architectural features check. | ||
| 207 |