Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #include "kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h" | ||
7 | |||
8 | #include <stddef.h> | ||
9 | #include <stdint.h> | ||
10 | #include <string.h> | ||
11 | |||
12 | #include "kai/kai_common.h" | ||
13 | |||
14 | static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); | ||
15 | static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); | ||
16 | static const size_t kai_num_bytes_bias = sizeof(float); | ||
17 | |||
18 | 5088 | inline static size_t kai_k_roundedup(size_t k) { | |
19 | // Round up k to be a multiple of 32. | ||
20 | 5088 | size_t kai_k_multiple_of = 32; | |
21 | 10176 | return kai_roundup(k, kai_k_multiple_of); | |
22 | 5088 | } | |
23 | |||
24 | ✗ | size_t kai_get_n_step_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t nr) { | |
25 | ✗ | return nr; | |
26 | } | ||
27 | |||
28 | ✗ | size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t n_idx, size_t rhs_stride) { | |
29 | ✗ | KAI_UNUSED(rhs_stride); | |
30 | − | KAI_ASSERT((n_idx % 2) == 0); | |
31 | ✗ | return (n_idx / 2) * sizeof(int8_t); | |
32 | } | ||
33 | |||
34 | 2904 | size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { | |
35 | 2904 | KAI_UNUSED(kr); | |
36 | 2904 | KAI_UNUSED(sr); | |
37 | |||
38 | 2904 | const size_t k_internal = kai_k_roundedup(k); | |
39 | |||
40 | − | KAI_ASSERT((k_internal % 2) == 0); | |
41 | |||
42 | 5808 | return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); | |
43 | 2904 | } | |
44 | |||
45 | 1092 | size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0( | |
46 | size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { | ||
47 | − | KAI_ASSERT((n_idx % nr) == 0); | |
48 | |||
49 | 1092 | return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); | |
50 | } | ||
51 | |||
52 | 720 | size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { | |
53 | 720 | const size_t num_rows = kai_roundup(n, nr) / nr; | |
54 | |||
55 | 1440 | return num_rows * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); | |
56 | 720 | } | |
57 | |||
58 | 1092 | void kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0( | |
59 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, | ||
60 | const float* scale, void* rhs_packed, size_t extra_bytes, | ||
61 | const struct kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params* params) { | ||
62 | − | KAI_ASSERT(num_groups == 1); | |
63 | − | KAI_ASSERT(extra_bytes == 0); | |
64 | − | KAI_ASSERT((kr % sr) == 0); | |
65 | − | KAI_ASSERT(rhs != NULL); | |
66 | − | KAI_ASSERT(scale != NULL); | |
67 | − | KAI_ASSERT(rhs_packed != NULL); | |
68 | − | KAI_ASSERT(params != NULL); | |
69 | − | KAI_ASSERT(params->lhs_zero_point == 1); | |
70 | − | KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); | |
71 | |||
72 | 1092 | const uint8_t rhs_zero_point = params->rhs_zero_point; | |
73 | 1092 | const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); | |
74 | 1092 | const size_t k_internal = kai_k_roundedup(k); | |
75 | 1092 | const size_t dst_num_rows = kai_roundup(n, nr) / nr; | |
76 | 1092 | const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k) / 2); | |
77 | 1092 | const size_t block_length_in_bytes = kr / sr; | |
78 | 1092 | const size_t k_interleaved_v = 16U; | |
79 | 1092 | const size_t rhs_stride = kai_roundup(n, 2) / 2; | |
80 | |||
81 |
2/2✓ Branch 0 taken 1092 times.
✓ Branch 1 taken 14252 times.
|
15344 | for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { |
82 | 14252 | uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; | |
83 | |||
84 | 14252 | int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); | |
85 | |||
86 | // Initialize to zero the RHS reduction sums | ||
87 | 14252 | memset(sums, 0, nr * sizeof(int32_t)); | |
88 | |||
89 |
2/2✓ Branch 0 taken 2272256 times.
✓ Branch 1 taken 14252 times.
|
2286508 | for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { |
90 | 2272256 | const size_t block_idx = dst_byte_idx / block_length_in_bytes; | |
91 | 2272256 | const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; | |
92 | 2272256 | const size_t super_block_idx = block_idx / nr; | |
93 | 2272256 | const size_t nr_idx = block_idx % nr; | |
94 | |||
95 | 4544512 | const size_t k_adjustment = | |
96 | 2272256 | ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; | |
97 | 2272256 | const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; | |
98 | 2272256 | const size_t k1_idx = k0_idx + k_interleaved_v; | |
99 | 2272256 | const size_t n0_idx = dst_row_idx * nr + nr_idx; | |
100 | |||
101 | // Clamp the index to avoid out-of-bound reads | ||
102 |
2/2✓ Branch 0 taken 2182720 times.
✓ Branch 1 taken 89536 times.
|
2272256 | const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); |
103 | |||
104 | 2272256 | const size_t src_addr_byte0 = (n0_valid_idx / 2) + k0_idx * rhs_stride; | |
105 | 2272256 | const size_t src_addr_byte1 = (n0_valid_idx / 2) + k1_idx * rhs_stride; | |
106 | |||
107 | 2272256 | const size_t shift_right_x0 = (n0_idx % 2) * 4; | |
108 | |||
109 |
2/2✓ Branch 0 taken 920576 times.
✓ Branch 1 taken 1351680 times.
|
2272256 | if (rhs_zero_point == 8) { |
110 | 920576 | uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; | |
111 | 920576 | uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; | |
112 | |||
113 |
2/2✓ Branch 0 taken 58080 times.
✓ Branch 1 taken 862496 times.
|
920576 | if (k0_idx < k) { |
114 | 862496 | byte0 = rhs[src_addr_byte0]; | |
115 | 862496 | } | |
116 | |||
117 |
2/2✓ Branch 0 taken 212432 times.
✓ Branch 1 taken 708144 times.
|
920576 | if (k1_idx < k) { |
118 | 708144 | byte1 = rhs[src_addr_byte1]; | |
119 | 708144 | } | |
120 | |||
121 | // The following operations where we extract the values from the bytes | ||
122 | // can be also written in the following and less efficient manner: | ||
123 | /* | ||
124 | uint8_t src_x0_lo = 0; | ||
125 | uint8_t src_x0_hi = 0; | ||
126 | |||
127 | if ((n0_idx % 2) == 0) { | ||
128 | src_x0_lo = (byte0 & 0x0F); | ||
129 | } else { | ||
130 | src_x0_lo = (byte0 >> 4); | ||
131 | } | ||
132 | |||
133 | if ((n0_idx % 2) == 0) { | ||
134 | src_x0_hi = (byte1 & 0x0F); | ||
135 | } else { | ||
136 | src_x0_hi = (byte1 >> 4); | ||
137 | } | ||
138 | */ | ||
139 | |||
140 | 920576 | const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; | |
141 | 920576 | const uint8_t src_x0_hi = (byte1 >> shift_right_x0) & 0x0F; | |
142 | |||
143 | 920576 | sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; | |
144 | |||
145 | 920576 | const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); | |
146 | |||
147 | 920576 | *dst_row = dst_qs0 ^ 0x88; | |
148 | 920576 | dst_row += sizeof(uint8_t); | |
149 | 920576 | } else { | |
150 | 1351680 | int8_t byte0 = 0; | |
151 | 1351680 | int8_t byte1 = 0; | |
152 | |||
153 | // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
154 |
2/2✓ Branch 0 taken 98848 times.
✓ Branch 1 taken 1252832 times.
|
1351680 | if (k0_idx < k) { |
155 | 1252832 | byte0 = rhs[src_addr_byte0]; | |
156 | 1252832 | } | |
157 | |||
158 |
2/2✓ Branch 0 taken 261936 times.
✓ Branch 1 taken 1089744 times.
|
1351680 | if (k1_idx < k) { |
159 | 1089744 | byte1 = rhs[src_addr_byte1]; | |
160 | 1089744 | } | |
161 | |||
162 | // The logic behind the following operations where we extract the | ||
163 | // values from the bytes is same as unsigned | ||
164 | |||
165 | 1351680 | int8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; | |
166 | 1351680 | int8_t src_x0_hi = (byte1 >> shift_right_x0) & 0x0F; | |
167 | |||
168 | 1351680 | const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); | |
169 | // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
170 | |||
171 | 1351680 | *(int8_t*)dst_row = dst_qs0; | |
172 | 1351680 | dst_row += sizeof(int8_t); | |
173 | |||
174 | 1351680 | src_x0_lo = kai_ext_sign_i8_i4(src_x0_lo); | |
175 | 1351680 | src_x0_hi = kai_ext_sign_i8_i4(src_x0_hi); | |
176 | 1351680 | sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi; | |
177 | 1351680 | } | |
178 | 2272256 | } | |
179 | |||
180 | // Adjust the reduction sums | ||
181 |
2/2✓ Branch 0 taken 78656 times.
✓ Branch 1 taken 14252 times.
|
92908 | for (size_t i = 0; i < nr; ++i) { |
182 | 78656 | sums[i] = sums[i] * 16; | |
183 | 78656 | dst_row += sizeof(int32_t); | |
184 | 78656 | } | |
185 | |||
186 | // Adjust the scales | ||
187 |
2/2✓ Branch 0 taken 78656 times.
✓ Branch 1 taken 14252 times.
|
92908 | for (size_t i = 0; i < nr; ++i) { |
188 | // Clamp the row index to avoid out-of-bound reads | ||
189 |
2/2✓ Branch 0 taken 75584 times.
✓ Branch 1 taken 3072 times.
|
78656 | const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); |
190 | 78656 | *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; | |
191 | 78656 | dst_row += sizeof(float); | |
192 | 78656 | } | |
193 | |||
194 | // Set the bias | ||
195 |
2/2✓ Branch 0 taken 914 times.
✓ Branch 1 taken 13338 times.
|
14252 | if (bias == NULL) { |
196 | 914 | memset(dst_row, 0, nr * kai_num_bytes_bias); | |
197 | 914 | } else { | |
198 |
2/2✓ Branch 0 taken 71344 times.
✓ Branch 1 taken 13338 times.
|
84682 | for (size_t i = 0; i < nr; ++i) { |
199 | // Clamp the row index to avoid out-of-bound reads | ||
200 |
2/2✓ Branch 0 taken 69040 times.
✓ Branch 1 taken 2304 times.
|
71344 | const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); |
201 | 71344 | ((float*)dst_row)[i] = bias[src_row_idx]; | |
202 | 71344 | } | |
203 | } | ||
204 | 14252 | } | |
205 | 1092 | } | |
206 |