Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #if !defined(__aarch64__) && !defined(_M_ARM64) | ||
7 | #error This file must be compiled for AArch64. | ||
8 | #else // Architectural features check. | ||
9 | |||
10 | #include "kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h" | ||
11 | |||
12 | #include <stdint.h> | ||
13 | #include <string.h> | ||
14 | |||
15 | #include "kai/kai_common.h" | ||
16 | |||
// Sizes, in bytes, of the per-row metadata values stored after the packed int4 data of each block of nr rows.
// One float bias per row.
static const size_t kai_num_bytes_bias = sizeof(float);
// One float scale (multiplier) per row.
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
// One int32 sum of the row's values per row.
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
20 | |||
// Returns k rounded up (via kai_roundup) to a multiple of 32.
// The packing routines in this file use the result as the padded K extent of each packed row.
inline static size_t kai_k_roundedup(size_t k) {
    return kai_roundup(k, (size_t)32);
}
26 | |||
// Returns the step to use along the N dimension, which is exactly nr:
// the packing routine in this file consumes the RHS matrix in blocks of nr rows.
size_t kai_get_n_step_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t nr) {
    const size_t n_step = nr;
    return n_step;
}
30 | |||
// Returns the offset of row n_idx in the unpacked RHS matrix.
// The result is n_idx rows times rhs_stride, expressed in the same unit as rhs_stride.
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t n_idx, size_t rhs_stride) {
    const size_t row_offset = rhs_stride * n_idx;
    return row_offset;
}
34 | |||
35 | 560 | size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t k, size_t nr, size_t kr, size_t sr) { | |
36 | 560 | KAI_UNUSED(kr); | |
37 | 560 | KAI_UNUSED(sr); | |
38 | |||
39 | 560 | const size_t k_internal = kai_k_roundedup(k); | |
40 | |||
41 | // multiple of 2 because 2 elements in a byte | ||
42 | − | KAI_ASSERT((k_internal % 2) == 0); | |
43 | |||
44 | 1120 | return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); | |
45 | 560 | } | |
46 | |||
// Returns the byte offset, within the packed RHS buffer, of the block that starts at row n_idx.
// n_idx must be a multiple of nr (asserted): the packed buffer is addressed in whole blocks of nr rows.
size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(
    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
    KAI_ASSERT((n_idx % nr) == 0);

    const size_t block_idx = n_idx / nr;
    const size_t block_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr);

    return block_idx * block_stride;
}
53 | |||
// Returns the total size in bytes of the packed RHS buffer for an n x k matrix.
// n is rounded up to a whole number of nr-row blocks; each block occupies one packed stride.
size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(
    size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
    const size_t num_blocks = kai_roundup(n, nr) / nr;
    const size_t block_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr);

    return num_blocks * block_stride;
}
60 | |||
61 | 160 | void kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( | |
62 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, | ||
63 | const float* scale, void* rhs_packed, size_t extra_bytes, | ||
64 | const struct kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon_params* params) { | ||
65 | 160 | const size_t k_internal = kai_k_roundedup(k); | |
66 | |||
67 | − | KAI_ASSERT((k_internal % kr) == 0); | |
68 | − | KAI_ASSERT(num_groups == 1); | |
69 | − | KAI_ASSERT(extra_bytes == 0); | |
70 | − | KAI_ASSERT((kr % sr) == 0); | |
71 | − | KAI_ASSERT(rhs != NULL); | |
72 | − | KAI_ASSERT(scale != NULL); | |
73 | − | KAI_ASSERT(rhs_packed != NULL); | |
74 | − | KAI_ASSERT(params != NULL); | |
75 | − | KAI_ASSERT(params->lhs_zero_point == 1); | |
76 | − | KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); | |
77 | |||
78 | // Note: The input matrix (rhs) is expected with: | ||
79 | // "k" columns and "n" rows (NxK) | ||
80 | |||
81 | 160 | const int32_t rhs_zero_point = params->rhs_zero_point; | |
82 | 160 | const size_t rhs_stride = kai_roundup(k, 2) / 2; | |
83 | 160 | const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr); | |
84 | 160 | const size_t dst_nr_block_size = nr * kr * sizeof(uint8_t) / 2; | |
85 | |||
86 | // Iterate over n src rows in blocks of nr rows | ||
87 |
2/2✓ Branch 0 taken 160 times.
✓ Branch 1 taken 216 times.
|
376 | for (size_t row_idx = 0; row_idx < n; row_idx += nr) { |
88 | 216 | int8_t* const dst_row = (int8_t*)rhs_packed + ((row_idx / nr) * rhs_packed_stride); | |
89 | |||
90 | 216 | int32_t* const sums = (int32_t*)(dst_row + (nr * (k_internal / 2))); | |
91 | 216 | float* const scaling_factors = (float*)((uint8_t*)sums + (nr * kai_num_bytes_sum_rhs)); | |
92 | // Update destination row pointer | ||
93 | 216 | float* const biases = (float*)((uint8_t*)scaling_factors + (nr * kai_num_bytes_multiplier_rhs)); | |
94 | |||
95 | // initialize sums to 0 | ||
96 | 216 | memset(sums, 0, nr * kai_num_bytes_sum_rhs); | |
97 | |||
98 | // Copy the scaling factors and bias | ||
99 | 216 | size_t rows_left = n - row_idx; | |
100 | // Saving scales. | ||
101 |
2/2✓ Branch 0 taken 96 times.
✓ Branch 1 taken 120 times.
|
216 | if (rows_left >= nr) { |
102 | 96 | memcpy(scaling_factors, &scale[row_idx], nr * kai_num_bytes_multiplier_rhs); | |
103 | 96 | } else { | |
104 | // Fill remaining values | ||
105 | 120 | memcpy(scaling_factors, &scale[row_idx], rows_left * kai_num_bytes_multiplier_rhs); | |
106 | // Set leftover to 0 | ||
107 | 120 | memset(&scaling_factors[rows_left], 0, (nr - rows_left) * kai_num_bytes_multiplier_rhs); | |
108 | } | ||
109 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | if (bias == NULL) { |
110 | // Set bias to 0 | ||
111 | ✗ | memset(biases, 0, nr * kai_num_bytes_bias); | |
112 | ✗ | } else { | |
113 |
2/2✓ Branch 0 taken 96 times.
✓ Branch 1 taken 120 times.
|
216 | if (rows_left >= nr) { |
114 | 96 | memcpy(biases, &bias[row_idx], nr * kai_num_bytes_bias); | |
115 | 96 | } else { | |
116 | // Fill remaining values | ||
117 | 120 | memcpy(biases, &bias[row_idx], rows_left * kai_num_bytes_bias); | |
118 | // Set leftover to 0 | ||
119 | 120 | memset(&biases[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias); | |
120 | } | ||
121 | } | ||
122 | // Iterate over rows in the nr row block | ||
123 |
2/2✓ Branch 0 taken 13824 times.
✓ Branch 1 taken 216 times.
|
14040 | for (size_t nr_block_idx = 0; nr_block_idx < nr; ++nr_block_idx) { |
124 | 13824 | const uint8_t* const src_row = rhs + ((row_idx + nr_block_idx) * rhs_stride); | |
125 | // Go to the first kr block for this row in the nr block | ||
126 | 13824 | int8_t* dst_kr_block = dst_row + (nr_block_idx * kr / 2); | |
127 | |||
128 | 13824 | int32_t sum = 0; | |
129 | |||
130 | // Iterate over k src columns in blocks of kr columns | ||
131 |
2/2✓ Branch 0 taken 6912 times.
✓ Branch 1 taken 6912 times.
|
13824 | if (rhs_zero_point == 8) { |
132 |
2/2✓ Branch 0 taken 108544 times.
✓ Branch 1 taken 6912 times.
|
115456 | for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) { |
133 | // Iterate over columns in the kr block | ||
134 | // Kr checked to be multiple of 2 (because 2 values per byte) | ||
135 |
2/2✓ Branch 0 taken 217088 times.
✓ Branch 1 taken 108544 times.
|
325632 | for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) { |
136 | // We pad dst with 0s if the rounded k or n values have been exceeded | ||
137 |
4/4✓ Branch 0 taken 145088 times.
✓ Branch 1 taken 72000 times.
✓ Branch 2 taken 22368 times.
✓ Branch 3 taken 122720 times.
|
217088 | if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) { |
138 | 94368 | dst_kr_block[kr_block_idx / 2] = 0; | |
139 | 94368 | continue; | |
140 | } | ||
141 | |||
142 | // Load the 2 u4 values from source | ||
143 | 122720 | const uint8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2]; | |
144 | |||
145 | // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
146 | // extract i8 values from the 2 u4 values | ||
147 | 122720 | const int8_t first_value = (dst_byte & 0xF) - rhs_zero_point; | |
148 | 245440 | const int8_t second_value = | |
149 |
2/2✓ Branch 0 taken 1902 times.
✓ Branch 1 taken 120818 times.
|
122720 | col_idx + kr_block_idx + 1 >= k ? 0 : (dst_byte >> 4) - rhs_zero_point; |
150 | |||
151 | // Add the i4 value to the row sum | ||
152 | 122720 | sum += (int32_t)first_value + (int32_t)second_value; | |
153 | |||
154 | // Truncate i8 to i4 and write to dst | ||
155 | 122720 | const uint8_t hi = second_value & 0x0F; | |
156 | 122720 | const uint8_t lo = first_value & 0x0F; | |
157 | 122720 | dst_kr_block[kr_block_idx / 2] = (hi << 4) | lo; | |
158 | // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
159 | 122720 | } | |
160 | |||
161 | // Go to the next kr block for this row in the nr rows | ||
162 | 108544 | dst_kr_block += dst_nr_block_size; | |
163 | 108544 | } | |
164 | 6912 | } else { | |
165 |
2/2✓ Branch 0 taken 108544 times.
✓ Branch 1 taken 6912 times.
|
115456 | for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) { |
166 | // Iterate over columns in the kr block | ||
167 | // Kr checked to be multiple of 2 (because 2 values per byte) | ||
168 |
2/2✓ Branch 0 taken 217088 times.
✓ Branch 1 taken 108544 times.
|
325632 | for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) { |
169 | // We pad dst with 0s if the rounded k or n values have been | ||
170 | // exceeded | ||
171 |
4/4✓ Branch 0 taken 145088 times.
✓ Branch 1 taken 72000 times.
✓ Branch 2 taken 22368 times.
✓ Branch 3 taken 122720 times.
|
217088 | if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) { |
172 | 94368 | dst_kr_block[kr_block_idx / 2] = 0; | |
173 | 94368 | continue; | |
174 | } | ||
175 | |||
176 | // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
177 | // Load the 2 u4 values from source | ||
178 | 122720 | const int8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2]; | |
179 | |||
180 | // extract i8 values from the 2 u4 values, shift first value | ||
181 | // back and forth to get the sign right. | ||
182 | 122720 | const int8_t first_value = kai_ext_sign_i8_i4(dst_byte & 0xF); | |
183 | 245440 | const int8_t second_value = | |
184 |
2/2✓ Branch 0 taken 1902 times.
✓ Branch 1 taken 120818 times.
|
122720 | col_idx + kr_block_idx + 1 >= k ? 0 : kai_ext_sign_i8_i4((dst_byte >> 4) & 0xF); |
185 | |||
186 | // Add the i4 value to the row sum | ||
187 | 122720 | sum += (int32_t)first_value + (int32_t)second_value; | |
188 | |||
189 | // Truncate i8 to i4 and write to dst | ||
190 | 122720 | const uint8_t hi = second_value & 0x0F; | |
191 | 122720 | const uint8_t lo = first_value & 0x0F; | |
192 | 122720 | dst_kr_block[kr_block_idx / 2] = (hi << 4) | lo; | |
193 | // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
194 | 122720 | } | |
195 | |||
196 | // Go to the next kr block for this row in the nr rows | ||
197 | 108544 | dst_kr_block += dst_nr_block_size; | |
198 | 108544 | } | |
199 | } | ||
200 | |||
201 | // save sum | ||
202 | 13824 | sums[nr_block_idx] = sum; | |
203 | 13824 | } | |
204 | 216 | } | |
205 | 160 | } | |
206 | #endif // Architectural features check. | ||
207 |