Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #include "kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" | ||
7 | |||
8 | #include <stddef.h> | ||
9 | #include <stdint.h> | ||
10 | #include <string.h> | ||
11 | |||
12 | #include "kai/kai_common.h" | ||
13 | |||
14 | static const size_t kai_num_bytes_sum_rhs = sizeof(float); | ||
15 | static const size_t kai_num_bytes_bias = sizeof(float); | ||
16 | static const size_t kai_nr_multiple_of = 4; | ||
17 | static const size_t kai_bl_multiple_of = 32; | ||
18 | |||
19 | 3510 | inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { | |
20 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
21 | 3510 | return kai_roundup(k, bl) / bl; | |
22 | } | ||
23 | |||
24 | 3510 | inline static size_t kai_get_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { | |
25 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
26 | 3510 | return (bl / 2) + num_bytes_multiplier_rhs; | |
27 | } | ||
28 | |||
29 | 702 | inline static size_t kai_get_rhs_packed_offset_end_of_all_blocks( | |
30 | size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) { | ||
31 | − | KAI_ASSERT((bl % kr) == 0); | |
32 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
33 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
34 | |||
35 | 702 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
36 | 702 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); | |
37 | |||
38 | 1404 | return (nr * num_bytes_per_block * num_blocks_per_row); | |
39 | 702 | } | |
40 | |||
41 | ✗ | size_t kai_get_n_step_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(size_t nr) { | |
42 | ✗ | return nr; | |
43 | } | ||
44 | |||
45 | 702 | size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( | |
46 | size_t n_idx, // | ||
47 | size_t rhs_stride) { | ||
48 | 702 | KAI_UNUSED(rhs_stride); | |
49 | − | KAI_ASSERT((n_idx % 2) == 0); | |
50 | 702 | return (n_idx / 2) * sizeof(int8_t); | |
51 | } | ||
52 | |||
53 | 2106 | size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( | |
54 | size_t k, // | ||
55 | size_t nr, // | ||
56 | size_t kr, // | ||
57 | size_t sr, // | ||
58 | size_t bl, // | ||
59 | enum kai_datatype scale_dt) { | ||
60 | − | KAI_ASSERT((k % bl) == 0); | |
61 | − | KAI_ASSERT((bl % kr) == 0); | |
62 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
63 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
64 | − | KAI_ASSERT(scale_dt == kai_dt_bf16); | |
65 | |||
66 | 2106 | KAI_UNUSED(kr); | |
67 | 2106 | KAI_UNUSED(sr); | |
68 | |||
69 | 2106 | const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); | |
70 | 2106 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
71 | 2106 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); | |
72 | |||
73 | 4212 | return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); | |
74 | 2106 | } | |
75 | |||
76 | 702 | size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( | |
77 | size_t n_idx, // | ||
78 | size_t k, // | ||
79 | size_t nr, // | ||
80 | size_t kr, // | ||
81 | size_t sr, // | ||
82 | size_t bl, // | ||
83 | enum kai_datatype scale_dt) { | ||
84 | − | KAI_ASSERT((n_idx % nr) == 0); | |
85 | − | KAI_ASSERT((k % bl) == 0); | |
86 | − | KAI_ASSERT((bl % kr) == 0); | |
87 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
88 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
89 | − | KAI_ASSERT(scale_dt == kai_dt_bf16); | |
90 | |||
91 | 702 | return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); | |
92 | } | ||
93 | |||
94 | 702 | size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( | |
95 | size_t n, // | ||
96 | size_t k, // | ||
97 | size_t nr, // | ||
98 | size_t kr, // | ||
99 | size_t sr, // | ||
100 | size_t bl, // | ||
101 | enum kai_datatype scale_dt) { | ||
102 | − | KAI_ASSERT((k % bl) == 0); | |
103 | − | KAI_ASSERT((bl % kr) == 0); | |
104 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
105 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
106 | − | KAI_ASSERT(scale_dt == kai_dt_bf16); | |
107 | |||
108 | 702 | const size_t num_rows = kai_roundup(n, nr) / nr; | |
109 | |||
110 | 1404 | return num_rows * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); | |
111 | 702 | } | |
112 | |||
113 | 702 | void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( | |
114 | size_t num_groups, // | ||
115 | size_t n, // | ||
116 | size_t k, // | ||
117 | size_t nr, // | ||
118 | size_t kr, // | ||
119 | size_t sr, // | ||
120 | size_t bl, // | ||
121 | const uint8_t* rhs, // | ||
122 | size_t rhs_stride, // | ||
123 | const float* bias, // | ||
124 | const void* scale, // | ||
125 | size_t scale_stride, // | ||
126 | void* rhs_packed, // | ||
127 | size_t extra_bytes, // | ||
128 | const struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params* params) { | ||
129 | − | KAI_ASSERT(num_groups == 1); | |
130 | − | KAI_ASSERT(extra_bytes == 0); | |
131 | − | KAI_ASSERT(rhs != NULL); | |
132 | − | KAI_ASSERT(scale != NULL); | |
133 | − | KAI_ASSERT(rhs_packed != NULL); | |
134 | − | KAI_ASSERT(params != NULL); | |
135 | − | KAI_ASSERT(params->rhs_zero_point == 8); | |
136 | − | KAI_ASSERT(params->lhs_zero_point == 1); | |
137 | |||
138 | − | KAI_ASSERT((k % bl) == 0); | |
139 | − | KAI_ASSERT((bl % kr) == 0); | |
140 | − | KAI_ASSERT((kr % sr) == 0); | |
141 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
142 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
143 | − | KAI_ASSERT(params->scale_dt == kai_dt_bf16); | |
144 | |||
145 | // Note: The input matrix (rhs) is expected with: | ||
146 | // "n" columns and "k" rows (kxn) | ||
147 | |||
148 | 702 | const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(params->scale_dt); | |
149 | 1404 | const size_t rhs_packed_stride = | |
150 | 702 | kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, params->scale_dt); | |
151 | 1404 | const size_t rhs_packed_offset_end_of_all_blocks = | |
152 | 702 | kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); | |
153 | 702 | const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
154 | 702 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); | |
155 | 702 | const size_t num_bytes_per_block_k = bl / 2; | |
156 | 702 | const size_t dst_num_rows = kai_roundup(n, nr) / nr; | |
157 | 702 | const size_t k_interleaved_v = 16U; | |
158 | 702 | const size_t block_length_in_bytes = kr / sr; | |
159 | |||
160 | 702 | const int32_t rhs_zero_point = params->rhs_zero_point; | |
161 | |||
162 |
2/2✓ Branch 0 taken 702 times.
✓ Branch 1 taken 4364 times.
|
5066 | for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { |
163 | // Before packing, it keeps the pointer to the first quantized block | ||
164 | 4364 | uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; | |
165 | |||
166 | 4364 | float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); | |
167 | |||
168 | // Initialize the RHS reduction sums to zero | ||
169 | 4364 | memset(sums, 0, nr * kai_num_bytes_sum_rhs); | |
170 | |||
171 | // Iterate over the quantized blocks | ||
172 |
2/2✓ Branch 0 taken 19299 times.
✓ Branch 1 taken 4364 times.
|
23663 | for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) { |
173 | // Store the scales after packing all K values | ||
174 | 19299 | uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr; | |
175 | 19299 | const uint8_t* scale_ptr = scale; | |
176 | |||
177 |
2/2✓ Branch 0 taken 100884 times.
✓ Branch 1 taken 19299 times.
|
120183 | for (size_t i = 0; i < nr; ++i) { |
178 |
2/2✓ Branch 0 taken 95712 times.
✓ Branch 1 taken 5172 times.
|
100884 | const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); |
179 | |||
180 | 100884 | void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs; | |
181 | 201768 | const void* src_scales_ptr = scale_ptr + dst_qblock_idx * num_bytes_multiplier_rhs + // | |
182 | 100884 | (src_row_idx * scale_stride); // | |
183 | |||
184 | 100884 | memcpy( | |
185 | dst_scales_ptr, // | ||
186 | src_scales_ptr, // | ||
187 | num_bytes_multiplier_rhs); // | ||
188 | 100884 | } | |
189 | |||
190 | 19299 | size_t kr_block_idx = 0; | |
191 |
2/2✓ Branch 0 taken 351904 times.
✓ Branch 1 taken 19299 times.
|
371203 | for (size_t dst_byte_idx = 0; dst_byte_idx < nr * num_bytes_per_block_k; |
192 | 351904 | dst_byte_idx += block_length_in_bytes) { | |
193 | 351904 | const size_t super_kr_block_idx = kr_block_idx / nr; | |
194 | 351904 | const size_t nr_idx = kr_block_idx % nr; | |
195 | 351904 | const size_t n0_idx = dst_row_idx * nr + nr_idx; | |
196 | |||
197 | // Clamp the index to avoid out-of-bound reads | ||
198 |
2/2✓ Branch 0 taken 333824 times.
✓ Branch 1 taken 18080 times.
|
351904 | const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); |
199 | |||
200 | 351904 | float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); | |
201 | |||
202 | 703808 | const size_t k_adjustment = | |
203 | 351904 | ((super_kr_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; | |
204 | 351904 | size_t k0_idx = dst_qblock_idx * bl + super_kr_block_idx * block_length_in_bytes + k_adjustment; | |
205 | 351904 | size_t k1_idx = k0_idx + k_interleaved_v; | |
206 | |||
207 | 351904 | float partial_sum = 0.0F; | |
208 | |||
209 |
2/2✓ Branch 0 taken 2152192 times.
✓ Branch 1 taken 351904 times.
|
2504096 | for (size_t block_byte_idx = 0; block_byte_idx < block_length_in_bytes; ++block_byte_idx) { |
210 | 2152192 | const size_t src_addr_byte0 = (n0_valid_idx / 2) + k0_idx * rhs_stride; | |
211 | 2152192 | const size_t src_addr_byte1 = (n0_valid_idx / 2) + k1_idx * rhs_stride; | |
212 | |||
213 | 2152192 | uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; | |
214 | 2152192 | uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; | |
215 | |||
216 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2152192 times.
|
2152192 | if (k0_idx < k) { |
217 | 2152192 | byte0 = rhs[src_addr_byte0]; | |
218 | 2152192 | } | |
219 | |||
220 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2152192 times.
|
2152192 | if (k1_idx < k) { |
221 | 2152192 | byte1 = rhs[src_addr_byte1]; | |
222 | 2152192 | } | |
223 | |||
224 |
2/2✓ Branch 0 taken 1076096 times.
✓ Branch 1 taken 1076096 times.
|
2152192 | if ((n0_idx % 2) == 0) { |
225 | 1076096 | const uint8_t src_x0_lo = (byte0 & 0x0F); | |
226 | 1076096 | const uint8_t src_x0_hi = (byte1 & 0x0F); | |
227 | |||
228 | 1076096 | partial_sum += (float)((int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * rhs_zero_point) * d; | |
229 | |||
230 | 1076096 | const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); | |
231 | |||
232 | 1076096 | dst_row[dst_byte_idx + block_byte_idx] = dst_qs0 ^ 0x88; | |
233 | 1076096 | } else { | |
234 | 1076096 | const uint8_t src_x1_lo = (byte0 >> 4); | |
235 | 1076096 | const uint8_t src_x1_hi = (byte1 >> 4); | |
236 | |||
237 | 1076096 | partial_sum += (float)((int32_t)src_x1_lo + (int32_t)src_x1_hi - 2 * rhs_zero_point) * d; | |
238 | |||
239 | 1076096 | const uint8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4); | |
240 | |||
241 | 1076096 | dst_row[dst_byte_idx + block_byte_idx] = dst_qs1 ^ 0x88; | |
242 | 1076096 | } | |
243 | 2152192 | k0_idx++; | |
244 | 2152192 | k1_idx++; | |
245 | 2152192 | } | |
246 | 351904 | sums[nr_idx] += partial_sum; | |
247 | |||
248 | // Increment the Kr block index | ||
249 | 351904 | kr_block_idx++; | |
250 | 351904 | } | |
251 | // Move the pointer after K values | ||
252 | 19299 | dst_row += num_bytes_per_block * nr; | |
253 | 19299 | } | |
254 | |||
255 | // Move the pointer after the row sum | ||
256 | 4364 | dst_row += kai_num_bytes_sum_rhs * nr; | |
257 | |||
258 | // Set the bias | ||
259 |
1/2✓ Branch 0 taken 4364 times.
✗ Branch 1 not taken.
|
4364 | if (bias == NULL) { |
260 | ✗ | memset(dst_row, 0, nr * kai_num_bytes_bias); | |
261 | ✗ | } else { | |
262 |
2/2✓ Branch 0 taken 22928 times.
✓ Branch 1 taken 4364 times.
|
27292 | for (size_t i = 0; i < nr; ++i) { |
263 | // Clamp the row index to avoid out-of-bound reads | ||
264 |
2/2✓ Branch 0 taken 20780 times.
✓ Branch 1 taken 2148 times.
|
22928 | const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); |
265 | 22928 | ((float*)dst_row)[i] = bias[src_row_idx]; | |
266 | 22928 | } | |
267 | } | ||
268 | 4364 | } | |
269 | 702 | } | |
270 |