Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | // nrx4 => this function can take in generic nr values but the input is expected to have a block depth of 4. | ||
8 | // Block depth is calculated as kr / sr. The values of these parameters are defined in the matmul ukernel. | ||
9 | |||
10 | #if !defined(__aarch64__) && !defined(_M_ARM64) | ||
11 | #error This file must be compiled for AArch64. | ||
12 | #else // Architectural features check. | ||
13 | |||
14 | #include "kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h" | ||
15 | |||
16 | #include <arm_neon.h> | ||
17 | #include <stddef.h> | ||
18 | #include <stdint.h> | ||
19 | #include <string.h> | ||
20 | |||
21 | #include "kai/kai_common.h" | ||
22 | |||
// Packing alignment constraints checked by the KAI_ASSERTs in this file:
// the block length bl must be a multiple of 32 and nr a multiple of 4.
static const size_t kai_bl_multiple_of = 32U;
static const size_t kai_nr_multiple_of = 4U;

// Trailing per-row metadata stored after all quantized blocks of a packed
// group: one f32 reduction sum and one f32 bias value per row of N.
static const size_t kai_num_bytes_bias = sizeof(float);
static const size_t kai_num_bytes_sum_rhs = sizeof(float);
27 | |||
28 | 864 | inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { | |
29 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
30 | 864 | return kai_roundup(k, bl) / bl; | |
31 | } | ||
32 | |||
33 | 648 | inline static size_t kai_get_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { | |
34 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
35 | 648 | return (bl / 2) + num_bytes_multiplier_rhs; | |
36 | } | ||
37 | |||
38 | 216 | inline static size_t kai_get_rhs_packed_offset_end_of_all_blocks( | |
39 | size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) { | ||
40 | − | KAI_ASSERT((bl % kr) == 0); | |
41 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
42 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
43 | |||
44 | 216 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
45 | 216 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); | |
46 | |||
47 | 432 | return (nr * num_bytes_per_block * num_blocks_per_row); | |
48 | 216 | } | |
49 | |||
// Step along N between packed groups: the packer emits nr rows of N per group.
size_t kai_get_n_step_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(size_t nr) {
    const size_t n_step = nr;
    return n_step;
}
53 | |||
// Byte offset of row n_idx in the unpacked NxK RHS matrix, where rhs_stride
// is the distance in bytes between consecutive rows.
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
    size_t n_idx,  //
    size_t rhs_stride) {
    const size_t row_offset = rhs_stride * n_idx;
    return row_offset;
}
59 | |||
60 | 432 | size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( | |
61 | size_t k, // | ||
62 | size_t nr, // | ||
63 | size_t kr, // | ||
64 | size_t sr, // | ||
65 | size_t bl, // | ||
66 | enum kai_datatype scale_dt) { | ||
67 | − | KAI_ASSERT((k % bl) == 0); | |
68 | − | KAI_ASSERT((bl % kr) == 0); | |
69 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
70 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
71 | − | KAI_ASSERT(scale_dt == kai_dt_bf16); | |
72 | |||
73 | 432 | KAI_UNUSED(kr); | |
74 | 432 | KAI_UNUSED(sr); | |
75 | |||
76 | 432 | const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); | |
77 | 432 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
78 | 432 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); | |
79 | |||
80 | 864 | return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); | |
81 | 432 | } | |
82 | |||
83 | 216 | size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( | |
84 | size_t n_idx, // | ||
85 | size_t k, // | ||
86 | size_t nr, // | ||
87 | size_t kr, // | ||
88 | size_t sr, // | ||
89 | size_t bl, // | ||
90 | enum kai_datatype scale_dt) { | ||
91 | − | KAI_ASSERT((n_idx % nr) == 0); | |
92 | − | KAI_ASSERT((k % bl) == 0); | |
93 | − | KAI_ASSERT((bl % kr) == 0); | |
94 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
95 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
96 | − | KAI_ASSERT(scale_dt == kai_dt_bf16); | |
97 | |||
98 | 432 | return (n_idx / nr) * | |
99 | 216 | kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt); | |
100 | } | ||
101 | |||
102 | 216 | size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( | |
103 | size_t n, // | ||
104 | size_t k, // | ||
105 | size_t nr, // | ||
106 | size_t kr, // | ||
107 | size_t sr, // | ||
108 | size_t bl, // | ||
109 | enum kai_datatype scale_dt) { | ||
110 | − | KAI_ASSERT((k % bl) == 0); | |
111 | − | KAI_ASSERT((bl % kr) == 0); | |
112 | − | KAI_ASSERT((nr % kai_nr_multiple_of) == 0); | |
113 | − | KAI_ASSERT((bl % kai_bl_multiple_of) == 0); | |
114 | − | KAI_ASSERT(scale_dt == kai_dt_bf16); | |
115 | |||
116 | 216 | const size_t num_rows = kai_roundup(n, nr) / nr; | |
117 | |||
118 | 432 | return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, scale_dt); | |
119 | 216 | } | |
120 | |||
// Packs an NxK RHS matrix of unsigned 4-bit values (two per byte) with
// per-block bf16 scales into the interleaved layout used by the matmul
// micro-kernels, converting values to signed int4 (value - 8) on the way.
//
// Output layout, repeated for every group of nr rows of N. Rows past n - 1
// (padding up to a multiple of nr) are filled by repeating row n - 1:
//   for each of the k / bl quantization blocks:
//     (bl / 2) * nr bytes of reordered signed int4 values
//     nr block scales of num_bytes_multiplier_rhs bytes each
//   nr f32 reduction sums: over all blocks, scale * sum(value - 8)
//   nr f32 bias values (all zero when bias == NULL)
// The bytes written per group equal
// kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(k, nr, kr, sr, bl, kai_dt_bf16).
//
// Parameters:
//   num_groups    must be 1.
//   n, k          RHS rows and columns; k must be a multiple of bl.
//   nr, kr, sr    packing parameters; nr must be a multiple of 4 and kr / sr must be 4.
//   bl            quantization block length; a multiple of 32 and of kr.
//   rhs           source bytes; row r starts at rhs + r * rhs_stride (stride in bytes).
//   bias          optional per-row f32 bias, indexed by row; may be NULL.
//   scale         block scales; block b of row r is read from
//                 (const uint8_t*)scale + r * scale_stride + b * num_bytes_multiplier_rhs.
//   rhs_packed    destination buffer.
//   extra_bytes   must be 0.
//   params        rhs_zero_point must be 8, lhs_zero_point 1, scale_dt kai_dt_bf16.
void kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(
    size_t num_groups,    //
    size_t n,             //
    size_t k,             //
    size_t nr,            //
    size_t kr,            //
    size_t sr,            //
    size_t bl,            //
    const uint8_t* rhs,   //
    size_t rhs_stride,    //
    const float* bias,    //
    const void* scale,    //
    size_t scale_stride,  //
    void* rhs_packed,     //
    size_t extra_bytes,   //
    const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) {
    KAI_ASSERT(num_groups == 1);
    KAI_ASSERT(extra_bytes == 0);
    KAI_ASSERT(rhs != NULL);
    KAI_ASSERT(scale != NULL);
    KAI_ASSERT(rhs_packed != NULL);
    KAI_ASSERT(params != NULL);
    KAI_ASSERT(params->rhs_zero_point == 8);
    KAI_ASSERT(params->lhs_zero_point == 1);

    KAI_ASSERT((k % bl) == 0);
    KAI_ASSERT((bl % kr) == 0);
    KAI_ASSERT((kr % sr) == 0);
    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
    KAI_ASSERT(params->scale_dt == kai_dt_bf16);

    // Note: The input matrix (rhs) is expected with:
    // "k" columns and "n" rows (NxK)
    const enum kai_datatype scale_dt = params->scale_dt;
    const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
    // Offset within a packed group where the reduction sums begin.
    const size_t rhs_packed_offset_end_of_all_blocks =
        kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs);
    const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl);
    // Bytes of int4 values per row of N per quantization block (two values per byte).
    const size_t num_bytes_per_block_k = bl / 2;
    const size_t dst_num_rows = kai_roundup(n, nr);
    // Bytes of one row that are kept contiguous in the packed output; the NEON
    // code below is written for exactly 4.
    const size_t block_length_in_bytes = kr / sr;
    KAI_ASSERT(block_length_in_bytes == 4);

    // Write cursor; advances monotonically through the packed buffer.
    uint8_t* dst_row = (uint8_t*)rhs_packed;

    // One iteration per packed group of nr rows of N.
    for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) {
        float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);

        // Initialize the RHS reduction sums to zero
        memset(sums, 0, nr * kai_num_bytes_sum_rhs);

        // Iterate over the quantized blocks
        for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) {
            // Store the scales after packing all K values in the block
            uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr;
            const uint8_t* scale_ptr = (const uint8_t*)scale + dst_qblock_idx * num_bytes_multiplier_rhs;

            // Copy this block's scale for each of the nr rows; padding rows reuse row n - 1.
            for (size_t i = 0; i < nr; ++i) {
                const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
                const void* src_scales_ptr = scale_ptr + src_row_idx * scale_stride;
                void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs;

                memcpy(
                    dst_scales_ptr,             //
                    src_scales_ptr,             //
                    num_bytes_multiplier_rhs);  //
            }

            // First K index of this quantization block.
            size_t k0_idx_i = dst_qblock_idx * bl;
            const uint8x8_t top_mask = vdup_n_u8(0xF0);
            const uint8x8_t bottom_mask = vdup_n_u8(0x0F);
            // XOR with 0x8 maps each unsigned nibble u to the two's-complement int4 encoding of u - 8.
            const uint32x2_t zero_point_conversion_mask = vdup_n_u32(0x88888888);

            // Process the block in 16-byte (32-value) chunks of each source row.
            for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) {
                // Handle 4 rows of N per iteration.
                for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 4) {
                    // Clamp the indices to avoid out-of-bound reads
                    const size_t n0_idx = KAI_MIN(dst_row_idx + nr_idx, n - 1);
                    const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
                    const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
                    const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);

                    // Scales of the 4 rows, read back from the bf16 values copied above.
                    const float d0 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 0]);
                    const float d1 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 1]);
                    const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]);
                    const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]);

                    // Initialize partial sum taking new zero-point (8) into account
                    // (32 values per 16-byte chunk, each offset by -8).
                    int32_t partial_sum0 = -(32 * 8);
                    int32_t partial_sum1 = -(32 * 8);
                    int32_t partial_sum2 = -(32 * 8);
                    int32_t partial_sum3 = -(32 * 8);

                    const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx);

                    // Load the 16-byte chunk of each of the 4 rows as two 8-byte halves.
                    const uint8x8_t vld0_0 = vld1_u8(src_block_base + n0_idx * rhs_stride);
                    const uint8x8_t vld0_1 = vld1_u8(src_block_base + n0_idx * rhs_stride + 8);
                    const uint8x8_t vld1_0 = vld1_u8(src_block_base + n1_idx * rhs_stride);
                    const uint8x8_t vld1_1 = vld1_u8(src_block_base + n1_idx * rhs_stride + 8);
                    const uint8x8_t vld2_0 = vld1_u8(src_block_base + n2_idx * rhs_stride);
                    const uint8x8_t vld2_1 = vld1_u8(src_block_base + n2_idx * rhs_stride + 8);
                    const uint8x8_t vld3_0 = vld1_u8(src_block_base + n3_idx * rhs_stride);
                    const uint8x8_t vld3_1 = vld1_u8(src_block_base + n3_idx * rhs_stride + 8);

                    // Reorder blocks to give correct packing
                    // Numbering the chunk's 32 source nibbles so that nibble 2i is the low
                    // nibble and 2i + 1 the high nibble of source byte i, output byte j
                    // holds nibble j in its low half and nibble j + 16 in its high half.
                    const uint8x8_t vld0_0_lower = vand_u8(vld0_0, bottom_mask);
                    const uint8x8_t vld0_1_lower = vshl_n_u8(vld0_1, 4);
                    const uint8x8_t vld0_0_upper = vshr_n_u8(vld0_0, 4);
                    const uint8x8_t vld0_1_upper = vand_u8(vld0_1, top_mask);
                    const uint8x8_t vstr0_04 =
                        vorr_u8(vzip1_u8(vld0_0_lower, vld0_0_upper), vzip1_u8(vld0_1_lower, vld0_1_upper));
                    const uint8x8_t vstr0_46 =
                        vorr_u8(vzip2_u8(vld0_0_lower, vld0_0_upper), vzip2_u8(vld0_1_lower, vld0_1_upper));

                    const uint8x8_t vld1_0_lower = vand_u8(vld1_0, bottom_mask);
                    const uint8x8_t vld1_1_lower = vshl_n_u8(vld1_1, 4);
                    const uint8x8_t vld1_0_upper = vshr_n_u8(vld1_0, 4);
                    const uint8x8_t vld1_1_upper = vand_u8(vld1_1, top_mask);
                    const uint8x8_t vstr0_04_1 =
                        vorr_u8(vzip1_u8(vld1_0_lower, vld1_0_upper), vzip1_u8(vld1_1_lower, vld1_1_upper));
                    const uint8x8_t vstr0_46_1 =
                        vorr_u8(vzip2_u8(vld1_0_lower, vld1_0_upper), vzip2_u8(vld1_1_lower, vld1_1_upper));

                    const uint8x8_t vld2_0_lower = vand_u8(vld2_0, bottom_mask);
                    const uint8x8_t vld2_1_lower = vshl_n_u8(vld2_1, 4);
                    const uint8x8_t vld2_0_upper = vshr_n_u8(vld2_0, 4);
                    const uint8x8_t vld2_1_upper = vand_u8(vld2_1, top_mask);
                    const uint8x8_t vstr0_15 =
                        vorr_u8(vzip1_u8(vld2_0_lower, vld2_0_upper), vzip1_u8(vld2_1_lower, vld2_1_upper));
                    const uint8x8_t vstr0_57 =
                        vorr_u8(vzip2_u8(vld2_0_lower, vld2_0_upper), vzip2_u8(vld2_1_lower, vld2_1_upper));

                    const uint8x8_t vld3_0_lower = vand_u8(vld3_0, bottom_mask);
                    const uint8x8_t vld3_1_lower = vshl_n_u8(vld3_1, 4);
                    const uint8x8_t vld3_0_upper = vshr_n_u8(vld3_0, 4);
                    const uint8x8_t vld3_1_upper = vand_u8(vld3_1, top_mask);
                    const uint8x8_t vstr0_15_1 =
                        vorr_u8(vzip1_u8(vld3_0_lower, vld3_0_upper), vzip1_u8(vld3_1_lower, vld3_1_upper));
                    const uint8x8_t vstr0_57_1 =
                        vorr_u8(vzip2_u8(vld3_0_lower, vld3_0_upper), vzip2_u8(vld3_1_lower, vld3_1_upper));

                    // Interleave 4-byte pieces across rows: vstr0_0/vstr0_1 hold bytes 0-3 of
                    // rows n0,n1 / n2,n3; vstr0_2/vstr0_3 bytes 4-7; vstr0_4/vstr0_5 bytes 8-11;
                    // vstr0_6/vstr0_7 bytes 12-15.
                    const uint32x2_t vstr0_0 =
                        vzip1_u32(vreinterpret_u32_u8(vstr0_04), vreinterpret_u32_u8(vstr0_04_1));
                    const uint32x2_t vstr0_4 =
                        vzip1_u32(vreinterpret_u32_u8(vstr0_46), vreinterpret_u32_u8(vstr0_46_1));
                    const uint32x2_t vstr0_2 =
                        vzip2_u32(vreinterpret_u32_u8(vstr0_04), vreinterpret_u32_u8(vstr0_04_1));
                    const uint32x2_t vstr0_6 =
                        vzip2_u32(vreinterpret_u32_u8(vstr0_46), vreinterpret_u32_u8(vstr0_46_1));
                    const uint32x2_t vstr0_1 =
                        vzip1_u32(vreinterpret_u32_u8(vstr0_15), vreinterpret_u32_u8(vstr0_15_1));
                    const uint32x2_t vstr0_5 =
                        vzip1_u32(vreinterpret_u32_u8(vstr0_57), vreinterpret_u32_u8(vstr0_57_1));
                    const uint32x2_t vstr0_3 =
                        vzip2_u32(vreinterpret_u32_u8(vstr0_15), vreinterpret_u32_u8(vstr0_15_1));
                    const uint32x2_t vstr0_7 =
                        vzip2_u32(vreinterpret_u32_u8(vstr0_57), vreinterpret_u32_u8(vstr0_57_1));

                    // Convert to signed int4 and store repacked values
                    // Piece j (j = 0..3) of these 4 rows goes j * nr * block_length_in_bytes
                    // bytes past the cursor, giving each piece its own nr-row-wide stripe.
                    vst1_u32((uint32_t*)dst_row + 0, veor_u32(vstr0_0, zero_point_conversion_mask));
                    vst1_u32((uint32_t*)dst_row + 2, veor_u32(vstr0_1, zero_point_conversion_mask));

                    vst1_u32(
                        (uint32_t*)(dst_row + nr * block_length_in_bytes) + 0,
                        veor_u32(vstr0_2, zero_point_conversion_mask));
                    vst1_u32(
                        (uint32_t*)(dst_row + nr * block_length_in_bytes) + 2,
                        veor_u32(vstr0_3, zero_point_conversion_mask));

                    vst1_u32(
                        (uint32_t*)(dst_row + (2 * nr * block_length_in_bytes)) + 0,
                        veor_u32(vstr0_4, zero_point_conversion_mask));
                    vst1_u32(
                        (uint32_t*)(dst_row + (2 * nr * block_length_in_bytes)) + 2,
                        veor_u32(vstr0_5, zero_point_conversion_mask));

                    vst1_u32(
                        (uint32_t*)(dst_row + (3 * nr * block_length_in_bytes)) + 0,
                        veor_u32(vstr0_6, zero_point_conversion_mask));
                    vst1_u32(
                        (uint32_t*)(dst_row + (3 * nr * block_length_in_bytes)) + 2,
                        veor_u32(vstr0_7, zero_point_conversion_mask));

                    // Calculate and store row sums
                    // Adds the 32 unsigned nibbles of each row's chunk; with the -256
                    // initial value this equals the sum of (value - 8).
                    partial_sum0 += (int32_t)vaddlvq_u16(vaddl_u8(
                        vadd_u8(vld0_0_lower, vand_u8(vld0_1, bottom_mask)),
                        vadd_u8(vld0_0_upper, vshr_n_u8(vld0_1, 4))));
                    partial_sum1 += (int32_t)vaddlvq_u16(vaddl_u8(
                        vadd_u8(vld1_0_lower, vand_u8(vld1_1, bottom_mask)),
                        vadd_u8(vld1_0_upper, vshr_n_u8(vld1_1, 4))));
                    partial_sum2 += (int32_t)vaddlvq_u16(vaddl_u8(
                        vadd_u8(vld2_0_lower, vand_u8(vld2_1, bottom_mask)),
                        vadd_u8(vld2_0_upper, vshr_n_u8(vld2_1, 4))));
                    partial_sum3 += (int32_t)vaddlvq_u16(vaddl_u8(
                        vadd_u8(vld3_0_lower, vand_u8(vld3_1, bottom_mask)),
                        vadd_u8(vld3_0_upper, vshr_n_u8(vld3_1, 4))));

                    // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
                    sums[nr_idx + 0] += (float)partial_sum0 * d0;
                    sums[nr_idx + 1] += (float)partial_sum1 * d1;
                    sums[nr_idx + 2] += (float)partial_sum2 * d2;
                    sums[nr_idx + 3] += (float)partial_sum3 * d3;
                    // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)

                    // Advance within the first stripe to the next 4 rows.
                    dst_row += (4 * block_length_in_bytes);
                }
                // Skip to end of qblock
                // (past the remaining three stripes written for this 16-byte chunk).
                dst_row += 3 * nr * block_length_in_bytes;
            }

            // Move the pointer after scales
            dst_row += num_bytes_multiplier_rhs * nr;
        }

        // Move the pointer after the row sum
        dst_row += kai_num_bytes_sum_rhs * nr;

        // Set the bias
        if (bias == NULL) {
            memset(dst_row, 0, nr * kai_num_bytes_bias);
        } else {
            for (size_t i = 0; i < nr; ++i) {
                // Clamp the row index to avoid out-of-bound reads
                const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1);
                ((float*)dst_row)[i] = bias[src_row_idx];
            }
        }

        // Move the pointer past the bias values, to the start of the next packed group
        dst_row += kai_num_bytes_bias * nr;
    }
}
353 | #endif // Architectural features check. | ||
354 |