Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" | ||
7 | |||
8 | #include <stddef.h> | ||
9 | #include <stdint.h> | ||
10 | #include <string.h> | ||
11 | |||
12 | #include "kai/kai_common.h" | ||
13 | |||
14 | static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); | ||
15 | static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); | ||
16 | static const size_t kai_num_bytes_bias = sizeof(float); | ||
17 | |||
18 | 9352 | inline static size_t kai_k_roundedup(size_t k) { | |
19 | // Round up k to be a multiple of 32. | ||
20 | 9352 | size_t kai_k_multiple_of = 32; | |
21 | 18704 | return kai_roundup(k, kai_k_multiple_of); | |
22 | 9352 | } | |
23 | |||
24 | ✗ | size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t nr) { | |
25 | ✗ | return nr; | |
26 | } | ||
27 | |||
28 | 720 | size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n_idx, size_t rhs_stride) { | |
29 | 720 | return n_idx * rhs_stride; | |
30 | } | ||
31 | |||
32 | 6048 | size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { | |
33 | 6048 | KAI_UNUSED(kr); | |
34 | 6048 | KAI_UNUSED(sr); | |
35 | 6048 | const size_t k_internal = kai_k_roundedup(k); | |
36 | |||
37 | − | KAI_ASSERT((k_internal % 2) == 0); | |
38 | |||
39 | 12096 | return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); | |
40 | 6048 | } | |
41 | |||
42 | 2372 | size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( | |
43 | size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { | ||
44 | − | KAI_ASSERT((n_idx % nr) == 0); | |
45 | |||
46 | 2372 | return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); | |
47 | } | ||
48 | |||
49 | 2024 | size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { | |
50 | 2024 | const size_t num_rows = kai_roundup(n, nr) / nr; | |
51 | |||
52 | 4048 | return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); | |
53 | 2024 | } | |
54 | |||
55 | 1652 | void kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( | |
56 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, | ||
57 | const float* scale, void* rhs_packed, size_t extra_bytes, | ||
58 | const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { | ||
59 | − | KAI_ASSERT(num_groups == 1); | |
60 | − | KAI_ASSERT(extra_bytes == 0); | |
61 | − | KAI_ASSERT((kr % sr) == 0); | |
62 | − | KAI_ASSERT(rhs != NULL); | |
63 | − | KAI_ASSERT(scale != NULL); | |
64 | − | KAI_ASSERT(rhs_packed != NULL); | |
65 | − | KAI_ASSERT(params != NULL); | |
66 | − | KAI_ASSERT(params->lhs_zero_point == 1); | |
67 | − | KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); | |
68 | |||
69 | 1652 | const uint8_t rhs_zero_point = params->rhs_zero_point; | |
70 | 1652 | const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); | |
71 | 1652 | const size_t k_internal = kai_k_roundedup(k); | |
72 | 1652 | const size_t dst_num_rows = kai_roundup(n, nr) / nr; | |
73 | 1652 | const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k) / 2); | |
74 | 1652 | const size_t block_length_in_bytes = kr / sr; | |
75 | 1652 | const size_t k_interleaved_v = 16U; | |
76 | 1652 | const size_t rhs_stride = kai_roundup(k, 2) / 2; | |
77 | |||
78 |
2/2✓ Branch 0 taken 1652 times.
✓ Branch 1 taken 13230 times.
|
14882 | for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { |
79 | 13230 | uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; | |
80 | |||
81 | 13230 | int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); | |
82 | |||
83 | // Initialize to zero the RHS reduction sums | ||
84 | 13230 | memset(sums, 0, nr * sizeof(int32_t)); | |
85 | |||
86 |
2/2✓ Branch 0 taken 1896960 times.
✓ Branch 1 taken 13230 times.
|
1910190 | for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { |
87 | 1896960 | const size_t block_idx = dst_byte_idx / block_length_in_bytes; | |
88 | 1896960 | const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; | |
89 | 1896960 | const size_t super_block_idx = block_idx / nr; | |
90 | 1896960 | const size_t nr_idx = block_idx % nr; | |
91 | |||
92 | 3793920 | const size_t k_adjustment = | |
93 | 1896960 | ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; | |
94 | 1896960 | const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; | |
95 | 1896960 | const size_t k1_idx = k0_idx + k_interleaved_v; | |
96 | 1896960 | const size_t n0_idx = dst_row_idx * nr + nr_idx; | |
97 | |||
98 | // Clamp the index to avoid out-of-bound reads | ||
99 |
2/2✓ Branch 0 taken 1790336 times.
✓ Branch 1 taken 106624 times.
|
1896960 | const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); |
100 | |||
101 | 1896960 | const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; | |
102 | 1896960 | const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; | |
103 | |||
104 | 1896960 | const size_t shift_right_x0 = (k0_idx % 2) * 4; | |
105 | 1896960 | const size_t shift_right_x1 = (k1_idx % 2) * 4; | |
106 | |||
107 |
2/2✓ Branch 0 taken 476288 times.
✓ Branch 1 taken 1420672 times.
|
1896960 | if (rhs_zero_point == 8) { |
108 | 476288 | uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; | |
109 | 476288 | uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; | |
110 | |||
111 |
2/2✓ Branch 0 taken 32316 times.
✓ Branch 1 taken 443972 times.
|
476288 | if (k0_idx < k) { |
112 | 443972 | byte0 = rhs[src_addr_byte0]; | |
113 | 443972 | } | |
114 | |||
115 |
2/2✓ Branch 0 taken 111516 times.
✓ Branch 1 taken 364772 times.
|
476288 | if (k1_idx < k) { |
116 | 364772 | byte1 = rhs[src_addr_byte1]; | |
117 | 364772 | } | |
118 | |||
119 | // The following operations where we extract the values from the bytes | ||
120 | // can be also written in the following and less efficient manner: | ||
121 | /* | ||
122 | uint8_t src_x0_lo = 0; | ||
123 | uint8_t src_x0_hi = 0; | ||
124 | |||
125 | if ((k0_idx % 2) == 0) { | ||
126 | src_x0_lo = (byte0 & 0x0F); | ||
127 | } else { | ||
128 | src_x0_lo = (byte0 >> 4); | ||
129 | } | ||
130 | |||
131 | if ((k1_idx % 2) == 0) { | ||
132 | src_x0_hi = (byte1 & 0x0F); | ||
133 | } else { | ||
134 | src_x0_hi = (byte1 >> 4); | ||
135 | } | ||
136 | */ | ||
137 | 476288 | const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; | |
138 | 476288 | const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; | |
139 | |||
140 | 476288 | sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; | |
141 | |||
142 | 476288 | const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); | |
143 | |||
144 | 476288 | *dst_row = dst_qs0 ^ 0x88; | |
145 | 476288 | dst_row += sizeof(uint8_t); | |
146 | 476288 | } else { | |
147 | 1420672 | int8_t byte0 = 0; | |
148 | 1420672 | int8_t byte1 = 0; | |
149 | |||
150 | // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
151 |
2/2✓ Branch 0 taken 127564 times.
✓ Branch 1 taken 1293108 times.
|
1420672 | if (k0_idx < k) { |
152 | 1293108 | byte0 = rhs[src_addr_byte0]; | |
153 | 1293108 | } | |
154 | |||
155 |
2/2✓ Branch 0 taken 234076 times.
✓ Branch 1 taken 1186596 times.
|
1420672 | if (k1_idx < k) { |
156 | 1186596 | byte1 = rhs[src_addr_byte1]; | |
157 | 1186596 | } | |
158 | |||
159 | // The logic behind the following operations where we extract the | ||
160 | // values from the bytes is same as unsigned | ||
161 | |||
162 | 1420672 | int8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; | |
163 | 1420672 | int8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; | |
164 | |||
165 | 1420672 | const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); | |
166 | // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
167 | |||
168 | 1420672 | *(int8_t*)dst_row = dst_qs0; | |
169 | 1420672 | dst_row += sizeof(int8_t); | |
170 | |||
171 | 1420672 | src_x0_lo = kai_ext_sign_i8_i4(src_x0_lo); | |
172 | 1420672 | src_x0_hi = kai_ext_sign_i8_i4(src_x0_hi); | |
173 | 1420672 | sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi; | |
174 | 1420672 | } | |
175 | 1896960 | } | |
176 | |||
177 | // Adjust the reduction sums | ||
178 |
2/2✓ Branch 0 taken 67656 times.
✓ Branch 1 taken 13230 times.
|
80886 | for (size_t i = 0; i < nr; ++i) { |
179 | 67656 | sums[i] = sums[i] * 16; | |
180 | 67656 | dst_row += sizeof(int32_t); | |
181 | 67656 | } | |
182 | |||
183 | // Adjust the scales | ||
184 |
2/2✓ Branch 0 taken 67656 times.
✓ Branch 1 taken 13230 times.
|
80886 | for (size_t i = 0; i < nr; ++i) { |
185 | // Clamp the row index to avoid out-of-bound reads | ||
186 |
2/2✓ Branch 0 taken 63712 times.
✓ Branch 1 taken 3944 times.
|
67656 | const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); |
187 | 67656 | *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; | |
188 | 67656 | dst_row += sizeof(float); | |
189 | 67656 | } | |
190 | |||
191 | // Set the bias | ||
192 |
2/2✓ Branch 0 taken 3442 times.
✓ Branch 1 taken 9788 times.
|
13230 | if (bias == NULL) { |
193 | 3442 | memset(dst_row, 0, nr * kai_num_bytes_bias); | |
194 | 3442 | } else { | |
195 |
2/2✓ Branch 0 taken 50232 times.
✓ Branch 1 taken 9788 times.
|
60020 | for (size_t i = 0; i < nr; ++i) { |
196 | // Clamp the row index to avoid out-of-bound reads | ||
197 |
2/2✓ Branch 0 taken 47594 times.
✓ Branch 1 taken 2638 times.
|
50232 | const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); |
198 | 50232 | ((float*)dst_row)[i] = bias[src_row_idx]; | |
199 | 50232 | } | |
200 | } | ||
201 | 13230 | } | |
202 | 1652 | } | |
203 |