Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" | ||
7 | |||
8 | #include <stddef.h> | ||
9 | #include <stdint.h> | ||
10 | #include <string.h> | ||
11 | |||
12 | #include "kai/kai_common.h" | ||
13 | |||
14 | static const size_t kai_num_bytes_sum_rhs = sizeof(float); | ||
15 | static const size_t kai_num_bytes_bias = sizeof(float); | ||
16 | static const size_t kai_nr_multiple_of = 4; | ||
17 | static const size_t kai_bl_multiple_of = 32; | ||
18 | |||
19 | 2808 | inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { | |
20 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
21 | 2808 | return kai_roundup(k, bl) / bl; | |
22 | } | ||
23 | |||
24 | 2106 | inline static size_t kai_get_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { | |
25 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
26 | 2106 | return (bl / 2) + num_bytes_multiplier_rhs; | |
27 | } | ||
28 | |||
29 | 702 | inline static size_t kai_get_rhs_packed_offset_end_of_all_blocks( | |
30 | size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) { | ||
31 | − | KAI_ASSERT((bl % kr) == 0); | |
32 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
33 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
34 | |||
35 | 702 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
36 | 702 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); | |
37 | |||
38 | 1404 | return (nr * num_bytes_per_block * num_blocks_per_row); | |
39 | 702 | } | |
40 | |||
41 | ✗ | size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(size_t nr) { | |
42 | ✗ | return nr; | |
43 | } | ||
44 | |||
45 | 702 | size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( | |
46 | size_t n_idx, // | ||
47 | size_t rhs_stride) { | ||
48 | 702 | return n_idx * rhs_stride; | |
49 | } | ||
50 | |||
51 | 1404 | size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( | |
52 | size_t k, // | ||
53 | size_t nr, // | ||
54 | size_t kr, // | ||
55 | size_t sr, // | ||
56 | size_t bl, // | ||
57 | enum kai_datatype scale_dt) { | ||
58 | − | KAI_ASSERT((k % bl) == 0); | |
59 | − | KAI_ASSERT((bl % kr) == 0); | |
60 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
61 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
62 | − | KAI_ASSERT(scale_dt == kai_dt_bf16); | |
63 | |||
64 | 1404 | KAI_UNUSED(kr); | |
65 | 1404 | KAI_UNUSED(sr); | |
66 | |||
67 | 1404 | const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); | |
68 | 1404 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
69 | 1404 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); | |
70 | |||
71 | 2808 | return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); | |
72 | 1404 | } | |
73 | |||
74 | 702 | size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( | |
75 | size_t n_idx, // | ||
76 | size_t k, // | ||
77 | size_t nr, // | ||
78 | size_t kr, // | ||
79 | size_t sr, // | ||
80 | size_t bl, // | ||
81 | enum kai_datatype scale_dt) { | ||
82 | − | KAI_ASSERT((n_idx % nr) == 0); | |
83 | − | KAI_ASSERT((k % bl) == 0); | |
84 | − | KAI_ASSERT((bl % kr) == 0); | |
85 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
86 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
87 | − | KAI_ASSERT(scale_dt == kai_dt_bf16); | |
88 | |||
89 | 702 | return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); | |
90 | } | ||
91 | |||
92 | 702 | size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( | |
93 | size_t n, // | ||
94 | size_t k, // | ||
95 | size_t nr, // | ||
96 | size_t kr, // | ||
97 | size_t sr, // | ||
98 | size_t bl, // | ||
99 | enum kai_datatype scale_dt) { | ||
100 | − | KAI_ASSERT((k % bl) == 0); | |
101 | − | KAI_ASSERT((bl % kr) == 0); | |
102 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
103 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
104 | − | KAI_ASSERT(scale_dt == kai_dt_bf16); | |
105 | |||
106 | 702 | const size_t num_rows = kai_roundup(n, nr) / nr; | |
107 | |||
108 | 1404 | return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); | |
109 | 702 | } | |
110 | |||
111 | 702 | void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( | |
112 | size_t num_groups, // | ||
113 | size_t n, // | ||
114 | size_t k, // | ||
115 | size_t nr, // | ||
116 | size_t kr, // | ||
117 | size_t sr, // | ||
118 | size_t bl, // | ||
119 | const uint8_t* rhs, // | ||
120 | size_t rhs_stride, // | ||
121 | const float* bias, // | ||
122 | const void* scale, // | ||
123 | size_t scale_stride, // | ||
124 | void* rhs_packed, // | ||
125 | size_t extra_bytes, // | ||
126 | const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) { | ||
127 | − | KAI_ASSERT(num_groups == 1); | |
128 | − | KAI_ASSERT(extra_bytes == 0); | |
129 | − | KAI_ASSERT(rhs != NULL); | |
130 | − | KAI_ASSERT(scale != NULL); | |
131 | − | KAI_ASSERT(rhs_packed != NULL); | |
132 | − | KAI_ASSERT(params != NULL); | |
133 | − | KAI_ASSERT(params->rhs_zero_point == 8); | |
134 | − | KAI_ASSERT(params->lhs_zero_point == 1); | |
135 | |||
136 | − | KAI_ASSERT((k % bl) == 0); | |
137 | − | KAI_ASSERT((bl % kr) == 0); | |
138 | − | KAI_ASSERT((kr % sr) == 0); | |
139 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
140 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
141 | − | KAI_ASSERT(params->scale_dt == kai_dt_bf16); | |
142 | |||
143 | // Note: The input matrix (rhs) is expected with: | ||
144 | // "k" columns and "n" rows (NxK) | ||
145 | 702 | const enum kai_datatype scale_dt = params->scale_dt; | |
146 | 702 | const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); | |
147 | 1404 | const size_t rhs_packed_offset_end_of_all_blocks = | |
148 | 702 | kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); | |
149 | 702 | const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
150 | 702 | const size_t num_bytes_per_block_k = bl / 2; | |
151 | 702 | const size_t dst_num_rows = kai_roundup(n, nr); | |
152 | 702 | const size_t block_length_in_bytes = kr / sr; | |
153 | |||
154 | 702 | uint8_t* dst_row = (uint8_t*)rhs_packed; | |
155 | |||
156 |
2/2✓ Branch 0 taken 702 times.
✓ Branch 1 taken 4364 times.
|
5066 | for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) { |
157 | 4364 | float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); | |
158 | |||
159 | // Initialize the RHS reduction sums to zero | ||
160 | 4364 | memset(sums, 0, nr * kai_num_bytes_sum_rhs); | |
161 | |||
162 | // Iterate over the quantized blocks | ||
163 |
2/2✓ Branch 0 taken 19299 times.
✓ Branch 1 taken 4364 times.
|
23663 | for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) { |
164 | // Store the scales after packing all K values in the block | ||
165 | 19299 | uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr; | |
166 | 19299 | const uint8_t* scale_ptr = (const uint8_t*)scale + dst_qblock_idx * num_bytes_multiplier_rhs; | |
167 | |||
168 |
2/2✓ Branch 0 taken 100884 times.
✓ Branch 1 taken 19299 times.
|
120183 | for (size_t i = 0; i < nr; ++i) { |
169 |
2/2✓ Branch 0 taken 95712 times.
✓ Branch 1 taken 5172 times.
|
100884 | const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1); |
170 | 100884 | const void* src_scales_ptr = scale_ptr + src_row_idx * scale_stride; | |
171 | 100884 | void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs; | |
172 | |||
173 | 100884 | memcpy( | |
174 | dst_scales_ptr, // | ||
175 | src_scales_ptr, // | ||
176 | num_bytes_multiplier_rhs); // | ||
177 | 100884 | } | |
178 | |||
179 | 19299 | size_t k0_idx_i = dst_qblock_idx * bl; | |
180 | |||
181 |
2/2✓ Branch 0 taken 25732 times.
✓ Branch 1 taken 19299 times.
|
45031 | for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) { |
182 |
2/2✓ Branch 0 taken 66920 times.
✓ Branch 1 taken 25732 times.
|
92652 | for (size_t segment_idx = 0; segment_idx < 16 / block_length_in_bytes; ++segment_idx) { |
183 |
2/2✓ Branch 0 taken 351904 times.
✓ Branch 1 taken 66920 times.
|
418824 | for (size_t nr_idx = 0; nr_idx < nr; ++nr_idx) { |
184 | 351904 | const size_t n0_idx = dst_row_idx + nr_idx; | |
185 | |||
186 | // Two int4 values are stored in one byte. | ||
187 | // The lower order part of the byte (low) holds the first nibble (K-index + 0). | ||
188 | // The higher order part of the byte (high) holds the second nibble (K-index + 16). | ||
189 | 351904 | size_t k0_idx = k0_idx_i; | |
190 | 351904 | size_t k1_idx = k0_idx_i + 16; | |
191 | |||
192 | // Clamp the index to avoid out-of-bound reads | ||
193 |
2/2✓ Branch 0 taken 333824 times.
✓ Branch 1 taken 18080 times.
|
351904 | const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); |
194 | 351904 | float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); | |
195 | |||
196 | 351904 | int32_t partial_sum = 0; | |
197 | |||
198 | 351904 | size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; | |
199 | |||
200 |
2/2✓ Branch 0 taken 1076096 times.
✓ Branch 1 taken 351904 times.
|
1428000 | for (size_t block_byte_idx = 0; block_byte_idx < block_length_in_bytes; block_byte_idx += 2) { |
201 | // Initialize the byte with the zero-point (8) | ||
202 | // e.g. uint8_t byte0 = 8 | 8 << 4 | ||
203 | 1076096 | uint8_t byte0 = 136; | |
204 | 1076096 | uint8_t byte1 = 136; | |
205 | 1076096 | uint8_t byte2 = 136; | |
206 | 1076096 | uint8_t byte3 = 136; | |
207 | |||
208 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1076096 times.
|
1076096 | if (k0_idx < k) { |
209 | 1076096 | byte0 = rhs[src_addr_byte0]; | |
210 | 1076096 | } | |
211 | |||
212 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1076096 times.
|
1076096 | if (k1_idx < k) { |
213 | 1076096 | byte1 = rhs[src_addr_byte0 + 8]; | |
214 | 1076096 | } | |
215 | |||
216 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1076096 times.
|
1076096 | if (k0_idx + 1 < k) { |
217 | 1076096 | byte2 = byte0; | |
218 | 1076096 | } | |
219 | |||
220 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1076096 times.
|
1076096 | if (k1_idx + 1 < k) { |
221 | 1076096 | byte3 = byte1; | |
222 | 1076096 | } | |
223 | |||
224 | 1076096 | k0_idx += 2; | |
225 | 1076096 | k1_idx += 2; | |
226 | |||
227 | 1076096 | const uint8_t src_x0_lo = byte0 & 0x0F; | |
228 | 1076096 | const uint8_t src_x0_hi = byte1 & 0x0F; | |
229 | 1076096 | const uint8_t src_x1_lo = (byte2 >> 4) & 0x0F; | |
230 | 1076096 | const uint8_t src_x1_hi = (byte3 >> 4) & 0x0F; | |
231 | |||
232 | 1076096 | partial_sum += (int32_t)src_x0_lo; | |
233 | 1076096 | partial_sum += (int32_t)src_x0_hi; | |
234 | 1076096 | partial_sum += (int32_t)src_x1_lo; | |
235 | 1076096 | partial_sum += (int32_t)src_x1_hi; | |
236 | 1076096 | partial_sum -= 32; // 4 * zero_point (8) | |
237 | |||
238 | 2152192 | const uint16_t dst_q = | |
239 | 1076096 | ((src_x0_lo)) | ((src_x0_hi) << 4) | ((src_x1_lo) << 8) | ((src_x1_hi) << 12); | |
240 | |||
241 | 1076096 | *((uint16_t*)dst_row) = dst_q ^ 0x8888; | |
242 | |||
243 | 1076096 | dst_row += 2; | |
244 | 1076096 | src_addr_byte0 += 1; | |
245 | 1076096 | } | |
246 | // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
247 | 351904 | sums[nr_idx] += (float)partial_sum * d; | |
248 | // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) | ||
249 | 351904 | } | |
250 | |||
251 | 66920 | k0_idx_i += block_length_in_bytes; | |
252 | 66920 | } | |
253 | 25732 | k0_idx_i += 16; | |
254 | 25732 | } | |
255 | // Move the pointer after scales | ||
256 | 19299 | dst_row += num_bytes_multiplier_rhs * nr; | |
257 | 19299 | } | |
258 | |||
259 | // Move the pointer after the row sum | ||
260 | 4364 | dst_row += kai_num_bytes_sum_rhs * nr; | |
261 | |||
262 | // Set the bias | ||
263 |
1/2✓ Branch 0 taken 4364 times.
✗ Branch 1 not taken.
|
4364 | if (bias == NULL) { |
264 | ✗ | memset(dst_row, 0, nr * kai_num_bytes_bias); | |
265 | ✗ | } else { | |
266 |
2/2✓ Branch 0 taken 22928 times.
✓ Branch 1 taken 4364 times.
|
27292 | for (size_t i = 0; i < nr; ++i) { |
267 | // Clamp the row index to avoid out-of-bound reads | ||
268 |
2/2✓ Branch 0 taken 20780 times.
✓ Branch 1 taken 2148 times.
|
22928 | const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1); |
269 | 22928 | ((float*)dst_row)[i] = bias[src_row_idx]; | |
270 | 22928 | } | |
271 | } | ||
272 | // Move the pointer after the bias | ||
273 | 4364 | dst_row += kai_num_bytes_bias * nr; | |
274 | 4364 | } | |
275 | 702 | } | |
276 |