Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if !defined(__aarch64__) && !defined(_M_ARM64) | ||
8 | #error This file must be compiled for AArch64. | ||
9 | #else // Architectural features check. | ||
10 | #include "kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.h" | ||
11 | |||
12 | #include <arm_neon.h> | ||
13 | #include <stdint.h> | ||
14 | #include <string.h> | ||
15 | |||
16 | #include "kai/kai_common.h" | ||
17 | |||
// Size in bytes of the per-row metadata values written into the packed RHS buffer.
static const size_t kai_num_bytes_offset_rhs = sizeof(float);      // one zero point per row, per block
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);  // one scale per row, per block
static const size_t kai_num_bytes_bias = sizeof(float);            // one bias per row
// Divisibility requirements checked with KAI_ASSUME by the functions in this file.
static const size_t kai_bl_multiple_of = 32;  // block length (bl) must be a multiple of this
static const size_t kai_nr_multiple_of = 4;   // nr must be a multiple of this
23 | |||
24 | 3924 | inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { | |
25 | − | KAI_ASSUME((k % 2) == 0); | |
26 | − | KAI_ASSUME((k % bl) == 0); | |
27 | − | KAI_ASSUME((bl % kai_bl_multiple_of) == 0); | |
28 | 3924 | return kai_roundup(k, bl) / bl; | |
29 | } | ||
30 | |||
31 | 5232 | inline static size_t kai_get_num_bytes_per_block(size_t bl) { | |
32 | 5232 | return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs; | |
33 | } | ||
34 | |||
35 | 3924 | inline static size_t kai_get_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) { | |
36 | − | KAI_ASSUME((k % 2) == 0); | |
37 | − | KAI_ASSUME((k % kr) == 0); | |
38 | − | KAI_ASSUME((k % bl) == 0); | |
39 | − | KAI_ASSUME((bl % kr) == 0); | |
40 | − | KAI_ASSUME((bl % kai_bl_multiple_of) == 0); | |
41 | 3924 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
42 | 3924 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl); | |
43 | 7848 | return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias); | |
44 | 3924 | } | |
45 | |||
// Returns the offset of row `n_idx` in the unpacked RHS matrix, computed as
// `n_idx * rhs_stride`. The result is in the same unit as `rhs_stride`
// (presumably bytes -- confirm against the public header).
size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) {
    return n_idx * rhs_stride;
}
49 | |||
// Returns the byte offset, inside the packed RHS buffer, of the group of `nr` rows that starts
// at row `n_idx`.
//
// Expected (checked with KAI_ASSUME): `k` is even and a multiple of both `kr` and `bl`, and
// `n_idx` is a multiple of `nr`.
size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
    size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);
    KAI_ASSUME((n_idx % nr) == 0);
    // `kr` is forwarded to the stride computation, so it needs no "unused" marker.
    return (n_idx / nr) * kai_get_rhs_packed_stride(k, nr, kr, bl);
}
59 | |||
// Returns the total size in bytes of the packed RHS buffer for an RHS matrix with `n` rows of
// `k` values. `n` is rounded up to a whole number of `nr`-row groups.
//
// Expected (checked with KAI_ASSUME): `k` is even and a multiple of both `kr` and `bl`.
size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
    size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);
    // `kr` is forwarded to the stride computation, so it needs no "unused" marker.
    const size_t num_rows = kai_roundup(n, nr) / nr;
    return num_rows * kai_get_rhs_packed_stride(k, nr, kr, bl);
}
69 | |||
// Writes `nr` floats to `dst`: the first min(rows_left, nr) are copied from `src`, the rest
// (rows past the end of the matrix) are set to zero.
static void kai_rhs_pack_copy_f32_zero_padded(float* dst, const float* src, size_t rows_left, size_t nr) {
    const size_t num_valid = (rows_left < nr) ? rows_left : nr;
    memcpy(dst, src, num_valid * sizeof(float));
    if (num_valid < nr) {
        memset(dst + num_valid, 0, (nr - num_valid) * sizeof(float));
    }
}

// Packs an NxK RHS matrix (two values per byte) together with per-row zero points, scales and
// bias into `rhs_packed`.
//
// Packed layout, per group of `nr` rows (group size = kai_get_rhs_packed_stride()):
//   for each block of `bl` values:
//     nr * (bl / 2) data bytes, interleaved in 2-byte chunks across the nr rows,
//     nr float zero points, then nr float scales;
//   followed by nr float bias values.
//
// Parameters:
//   num_groups   must be 1.
//   n, k         number of rows and values per row of the RHS matrix.
//   nr, kr, sr   packing parameters; nr must be a multiple of 4, sr must be 2, kr / sr must be 4.
//   bl           block length; a multiple of 32 that divides k.
//   rhs          source matrix, n rows of k / 2 bytes each.
//   zero, scale  read as arrays of float indexed by row; must not be NULL.
//   bias         read as an array of float indexed by row; if NULL the packed bias is zero.
//   rhs_packed   destination buffer; must not be NULL.
//   extra_bytes  must be 0.
//   params       must not be NULL, with rhs_zero_point == 8 and lhs_zero_point == 1.
//
// All of the constraints above are checked with KAI_ASSUME.
void kai_run_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
    size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
    const void* zero, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
    const struct kai_rhs_pack_nxk_qai4c32p_params* params) {
    KAI_ASSUME(num_groups == 1);
    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);
    KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
    KAI_ASSUME((nr % kai_nr_multiple_of) == 0);
    KAI_ASSUME(extra_bytes == 0);

    KAI_ASSUME(sr == 2);
    KAI_ASSUME(kr / sr == 4);
    KAI_ASSUME(rhs != NULL);
    KAI_ASSUME(zero != NULL);
    KAI_ASSUME(scale != NULL);
    KAI_ASSUME(rhs_packed != NULL);
    KAI_ASSUME(params != NULL);
    KAI_ASSUME(params->rhs_zero_point == 8);
    KAI_ASSUME(params->lhs_zero_point == 1);

    // Note: The input matrix (rhs) is expected with:
    // "k" columns and "n" rows (NxK)

    const size_t block_length = kr / sr;
    const size_t num_blocks_per_row = k / bl;
    const size_t rhs_stride = k / 2;  // bytes per source row (two values per byte)
    const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, bl);

    const size_t dst_packed_block_size = kai_get_num_bytes_per_block(bl) * nr;
    const size_t dst_block_data_size = bl / 2;
    const size_t dst_num_rows = kai_roundup(n, nr) / nr;
    const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size;
    const size_t k_block_length_in_bytes = (block_length * sizeof(uint8_t)) / 2;

    for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
        uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
        float* dst_row_bias = (float*)(dst_row + dst_bias_offset);
        const size_t row_idx = dst_row_idx * nr;
        const size_t rows_left = n - row_idx;

        for (size_t block_idx = 0; block_idx < num_blocks_per_row; block_idx++) {
            uint8_t* block_dst_row = dst_row + block_idx * dst_packed_block_size;
            // The metadata pointers are derived before block_dst_row is advanced below.
            float* block_dst_zp = (float*)(block_dst_row + nr * dst_block_data_size);
            float* block_dst_scale = block_dst_zp + nr;
            const size_t k_idx = block_idx * bl;

            // Each iteration consumes 8 source bytes from every row of the group.
            for (size_t dst_byte_idx = 0; dst_byte_idx < dst_block_data_size; dst_byte_idx += 8) {
                // Written as "nr_idx + 4 <= nr" so the bound cannot wrap around for nr < 4.
                for (size_t nr_idx = 0; nr_idx + 4 <= nr; nr_idx += 4) {
                    // Rows past the end of the matrix are clamped to the last valid row.
                    const size_t n0_idx = KAI_MIN(dst_row_idx * nr + nr_idx, n - 1);
                    const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
                    const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
                    const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);
                    const uint8_t* src_addr_byte = rhs + (k_idx / 2) + dst_byte_idx;

                    const uint8x8_t vec0_u8 = vld1_u8(src_addr_byte + n0_idx * rhs_stride);
                    const uint8x8_t vec1_u8 = vld1_u8(src_addr_byte + n1_idx * rhs_stride);
                    const uint8x8_t vec2_u8 = vld1_u8(src_addr_byte + n2_idx * rhs_stride);
                    const uint8x8_t vec3_u8 = vld1_u8(src_addr_byte + n3_idx * rhs_stride);

                    // 4x4 transpose of 16-bit elements: after the two zip stages, vinJ holds the
                    // J-th 2-byte chunk of rows 0..3.
                    const uint16x4_t vec0_u16 = vreinterpret_u16_u8(vec0_u8);
                    const uint16x4_t vec1_u16 = vreinterpret_u16_u8(vec1_u8);
                    const uint16x4_t vec2_u16 = vreinterpret_u16_u8(vec2_u8);
                    const uint16x4_t vec3_u16 = vreinterpret_u16_u8(vec3_u8);

                    const uint16x4_t vec01_lo_u16 = vzip1_u16(vec0_u16, vec1_u16);
                    const uint16x4_t vec01_hi_u16 = vzip2_u16(vec0_u16, vec1_u16);
                    const uint16x4_t vec23_lo_u16 = vzip1_u16(vec2_u16, vec3_u16);
                    const uint16x4_t vec23_hi_u16 = vzip2_u16(vec2_u16, vec3_u16);

                    const uint32x2_t vec01_lo_u32 = vreinterpret_u32_u16(vec01_lo_u16);
                    const uint32x2_t vec01_hi_u32 = vreinterpret_u32_u16(vec01_hi_u16);
                    const uint32x2_t vec23_lo_u32 = vreinterpret_u32_u16(vec23_lo_u16);
                    const uint32x2_t vec23_hi_u32 = vreinterpret_u32_u16(vec23_hi_u16);

                    const uint32x2_t vin0_u32 = vzip1_u32(vec01_lo_u32, vec23_lo_u32);
                    const uint32x2_t vin1_u32 = vzip2_u32(vec01_lo_u32, vec23_lo_u32);
                    const uint32x2_t vin2_u32 = vzip1_u32(vec01_hi_u32, vec23_hi_u32);
                    const uint32x2_t vin3_u32 = vzip2_u32(vec01_hi_u32, vec23_hi_u32);

                    uint8x8_t vin0_u8 = vreinterpret_u8_u32(vin0_u32);
                    uint8x8_t vin1_u8 = vreinterpret_u8_u32(vin1_u32);
                    uint8x8_t vin2_u8 = vreinterpret_u8_u32(vin2_u32);
                    uint8x8_t vin3_u8 = vreinterpret_u8_u32(vin3_u32);

                    // Swap the two nibbles of every byte: (x << 4) | (x >> 4).
                    const uint8x8_t vin0_s1s = vshr_n_u8(vin0_u8, 4);
                    const uint8x8_t vin1_s1s = vshr_n_u8(vin1_u8, 4);
                    const uint8x8_t vin2_s1s = vshr_n_u8(vin2_u8, 4);
                    const uint8x8_t vin3_s1s = vshr_n_u8(vin3_u8, 4);

                    vin0_u8 = vshl_n_u8(vin0_u8, 4);
                    vin1_u8 = vshl_n_u8(vin1_u8, 4);
                    vin2_u8 = vshl_n_u8(vin2_u8, 4);
                    vin3_u8 = vshl_n_u8(vin3_u8, 4);

                    vin0_u8 = vorr_u8(vin0_u8, vin0_s1s);
                    vin1_u8 = vorr_u8(vin1_u8, vin1_s1s);
                    vin2_u8 = vorr_u8(vin2_u8, vin2_s1s);
                    vin3_u8 = vorr_u8(vin3_u8, vin3_s1s);

                    // Chunk J of all nr rows is stored contiguously; consecutive chunks are
                    // nr * k_block_length_in_bytes bytes apart.
                    uint8_t* dst_row_offset = block_dst_row + nr_idx * k_block_length_in_bytes;
                    vst1_u8(dst_row_offset, vin0_u8);
                    vst1_u8(dst_row_offset + nr * k_block_length_in_bytes, vin1_u8);
                    vst1_u8(dst_row_offset + 2 * (nr * k_block_length_in_bytes), vin2_u8);
                    vst1_u8(dst_row_offset + 3 * (nr * k_block_length_in_bytes), vin3_u8);
                }
                block_dst_row += nr * sizeof(uint8x8_t);
            }

            // Adjust the zero points and scales; entries for rows past `n` are set to 0.
            // NOTE(review): scale/zero are indexed by row only, so every block of a row receives
            // the same values -- confirm this matches the layout the caller provides.
            kai_rhs_pack_copy_f32_zero_padded(block_dst_scale, &((const float*)scale)[row_idx], rows_left, nr);
            kai_rhs_pack_copy_f32_zero_padded(block_dst_zp, &((const float*)zero)[row_idx], rows_left, nr);
        }

        // Set the bias
        if (bias == NULL) {
            memset(dst_row_bias, 0, nr * kai_num_bytes_bias);
        } else {
            kai_rhs_pack_copy_f32_zero_padded(dst_row_bias, &((const float*)bias)[row_idx], rows_left, nr);
        }
    }
}
206 | #endif | ||
207 |