Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h" | ||
7 | |||
8 | #include <arm_neon.h> | ||
9 | #include <stddef.h> | ||
10 | #include <stdint.h> | ||
11 | |||
12 | #include "kai/kai_common.h" | ||
13 | |||
// Size in bytes of the per-block scale factor stored next to the quantized data.
static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);

// Returns the number of bytes one packed block occupies:
// `bl` int8 quantized values plus one scale factor.
inline static size_t kai_num_bytes_per_block(size_t bl) {
    const size_t num_bytes_data = bl * sizeof(int8_t);
    return num_bytes_data + kai_num_bytes_multiplier;
}
19 | |||
// Returns how many quantization blocks of length `bl` make up a row of `k` values.
// `k` is asserted to be an exact multiple of `bl`.
inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
    KAI_ASSERT((k % bl) == 0);
    const size_t num_blocks = k / bl;
    return num_blocks;
}
24 | |||
// Returns the size in bytes of one packed group of `mr` rows:
// mr * (blocks per row) * (bytes per block). `kr` does not affect the result.
inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) {
    KAI_UNUSED(kr);
    const size_t blocks_per_row = kai_num_blocks_per_row(k, bl);
    const size_t bytes_per_block = kai_num_bytes_per_block(bl);
    return mr * blocks_per_row * bytes_per_block;
}
29 | |||
// Returns the step along M for this micro-kernel. The value is always 1;
// `mr` is accepted for interface compatibility and ignored.
size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t mr) {
    KAI_UNUSED(mr);
    const size_t m_step = 1;
    return m_step;
}
34 | |||
// Returns the offset of row `m_idx` in the unpacked LHS matrix, expressed in
// the same unit as `lhs_stride` (the run function applies it as a byte offset).
size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t m_idx, size_t lhs_stride) {
    const size_t row_offset = lhs_stride * m_idx;
    return row_offset;
}
38 | |||
// Returns the byte offset, inside the packed LHS buffer, of the group of `mr`
// rows that contains row `m_idx`.
//
// `k` must be even and a multiple of both `kr` and `bl`. `sr` is ignored.
size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
    size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
    KAI_UNUSED(sr);
    KAI_UNUSED(kr);

    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);

    const size_t group_idx = m_idx / mr;
    const size_t group_stride = kai_lhs_packed_stride(k, mr, kr, bl);

    return group_idx * group_stride;
}
50 | |||
// Returns the size in bytes of the packed LHS buffer for an `m` x `k` matrix.
// `m` is rounded up to a whole number of `mr`-row groups.
//
// `k` must be even and a multiple of both `kr` and `bl`. `sr` is ignored.
size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
    size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
    KAI_UNUSED(sr);
    KAI_UNUSED(kr);

    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);

    const size_t num_row_groups = kai_roundup(m, mr) / mr;
    const size_t group_stride = kai_lhs_packed_stride(k, mr, kr, bl);

    return num_row_groups * group_stride;
}
64 | |||
// Returns the largest absolute value among the 32 floats starting at `src`.
inline static float kai_block_abs_max_f32(const float* src) {
    const size_t block_len = 32;
    float32x4_t v_currentmax = vdupq_n_f32(0);

    for (size_t idx_v = 0; idx_v < block_len; idx_v += 4) {
        const float32x4_t v_f32_abs_values = vabsq_f32(vld1q_f32(src + idx_v));
        v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax);
    }

    return vmaxvq_f32(v_currentmax);
}

// Scales the 8 floats at `src` by `rep_scale`, rounds them to the nearest
// integer and returns them as 8 x int16 lanes (the low halves of the int32
// results). The caller narrows further to int8 with vuzp1q_s8.
inline static int16x8_t kai_quantize_8_f32(const float* src, float rep_scale) {
    const int32x4_t v_i32_lo = vcvtnq_s32_f32(vmulq_n_f32(vld1q_f32(src), rep_scale));
    const int32x4_t v_i32_hi = vcvtnq_s32_f32(vmulq_n_f32(vld1q_f32(src + 4), rep_scale));

    return vuzp1q_s16(vreinterpretq_s16_s32(v_i32_lo), vreinterpretq_s16_s32(v_i32_hi));
}

// Quantizes and packs four consecutive LHS rows that fill one complete packed
// tile (slots 0..3). `src_row0` points at the first row; the following rows are
// located `lhs_stride` bytes apart. `dst_tile` points at the start of the tile.
static void kai_quant_pack_4_rows(
    const float* src_row0, size_t lhs_stride, size_t num_blocks_per_row, int8_t* dst_tile) {
    const size_t local_bl = 32;
    const size_t local_mr = 4;
    const size_t k_block_len = 8;
    const size_t num_bytes_per_block = kai_num_bytes_per_block(local_bl);
    const size_t write_mem_increment = 2 * k_block_len * sizeof(int8_t);

    // Row pointers are derived from the byte stride so that padded rows
    // (lhs_stride > k * sizeof(float)) are read correctly.
    const float* src_0 = src_row0;
    const float* src_1 = (const float*)((const uint8_t*)src_0 + lhs_stride);
    const float* src_2 = (const float*)((const uint8_t*)src_1 + lhs_stride);
    const float* src_3 = (const float*)((const uint8_t*)src_2 + lhs_stride);

    for (size_t b = 0; b < num_blocks_per_row; ++b) {
        int8_t* dst_ptr = dst_tile + (b * local_mr) * num_bytes_per_block;

        float32x4_t abs_max_vec = vdupq_n_f32(kai_block_abs_max_f32(src_0));
        abs_max_vec = vsetq_lane_f32(kai_block_abs_max_f32(src_1), abs_max_vec, 1);
        abs_max_vec = vsetq_lane_f32(kai_block_abs_max_f32(src_2), abs_max_vec, 2);
        abs_max_vec = vsetq_lane_f32(kai_block_abs_max_f32(src_3), abs_max_vec, 3);

        // Calculate scales and reciprocals; a zero scale maps to a zero reciprocal
        const float32x4_t scales = vdivq_f32(abs_max_vec, vdupq_n_f32((1 << 7) - 1));
        const uint32x4_t valid_scales = vmvnq_u32(vceqq_f32(scales, vdupq_n_f32(0.0F)));
        const float32x4_t reciprocals = vdivq_f32(vdupq_n_f32(1.0F), scales);
        const float32x4_t rep_scales = vbslq_f32(valid_scales, reciprocals, vdupq_n_f32(0.0F));

        // The four f16 scales sit at the start of the block group
        vst1_u16((uint16_t*)dst_ptr, vreinterpret_u16_f16(vcvt_f16_f32(scales)));
        dst_ptr += local_mr * kai_num_bytes_multiplier;

        const float rep_scale_0 = vgetq_lane_f32(rep_scales, 0);
        const float rep_scale_1 = vgetq_lane_f32(rep_scales, 1);
        const float rep_scale_2 = vgetq_lane_f32(rep_scales, 2);
        const float rep_scale_3 = vgetq_lane_f32(rep_scales, 3);

        // Quantize and pack the blocks, 16 values per row per iteration
        for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) {
            const int16x8_t v_row0_lo = kai_quantize_8_f32(src_0 + k_idx, rep_scale_0);
            const int16x8_t v_row0_hi = kai_quantize_8_f32(src_0 + k_idx + 8, rep_scale_0);
            const int16x8_t v_row1_lo = kai_quantize_8_f32(src_1 + k_idx, rep_scale_1);
            const int16x8_t v_row1_hi = kai_quantize_8_f32(src_1 + k_idx + 8, rep_scale_1);
            const int16x8_t v_row2_lo = kai_quantize_8_f32(src_2 + k_idx, rep_scale_2);
            const int16x8_t v_row2_hi = kai_quantize_8_f32(src_2 + k_idx + 8, rep_scale_2);
            const int16x8_t v_row3_lo = kai_quantize_8_f32(src_3 + k_idx, rep_scale_3);
            const int16x8_t v_row3_hi = kai_quantize_8_f32(src_3 + k_idx + 8, rep_scale_3);

            // First 8 values of rows 0 and 1
            vst1q_s8(dst_ptr, vuzp1q_s8(vreinterpretq_s8_s16(v_row0_lo), vreinterpretq_s8_s16(v_row1_lo)));
            dst_ptr += write_mem_increment;

            // First 8 values of rows 2 and 3
            vst1q_s8(dst_ptr, vuzp1q_s8(vreinterpretq_s8_s16(v_row2_lo), vreinterpretq_s8_s16(v_row3_lo)));
            dst_ptr += write_mem_increment;

            // Next 8 values of rows 0 and 1
            vst1q_s8(dst_ptr, vuzp1q_s8(vreinterpretq_s8_s16(v_row0_hi), vreinterpretq_s8_s16(v_row1_hi)));
            dst_ptr += write_mem_increment;

            // Next 8 values of rows 2 and 3
            vst1q_s8(dst_ptr, vuzp1q_s8(vreinterpretq_s8_s16(v_row2_hi), vreinterpretq_s8_s16(v_row3_hi)));
            dst_ptr += write_mem_increment;
        }

        src_0 += local_bl;
        src_1 += local_bl;
        src_2 += local_bl;
        src_3 += local_bl;
    }
}

// Quantizes and packs a single LHS row into slot `dst_x` (0..3) of the packed
// tile starting at `dst_tile`.
static void kai_quant_pack_1_row(const float* src_row, size_t num_blocks_per_row, size_t dst_x, int8_t* dst_tile) {
    const size_t local_bl = 32;
    const size_t local_mr = 4;
    const size_t k_block_len = 8;
    const size_t num_bytes_per_block = kai_num_bytes_per_block(local_bl);

    for (size_t b = 0; b < num_blocks_per_row; ++b) {
        int8_t* dst_ptr = dst_tile + (b * local_mr) * num_bytes_per_block;

        // Calculate scale and reciprocal; a zero scale maps to a zero reciprocal
        const float abs_max = kai_block_abs_max_f32(src_row);
        const float scale = abs_max / ((1 << 7) - 1);
        const float rep_scale = (scale != 0.0F) ? 1.0F / scale : 0.0F;

        *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale);

        // Skip the scales of all Mr rows, then move to this row's data slot
        dst_ptr += local_mr * kai_num_bytes_multiplier;
        dst_ptr += dst_x * k_block_len * sizeof(int8_t);

        // Quantize and pack the block, 16 values per iteration
        for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) {
            const int16x8_t v_i16_lo = kai_quantize_8_f32(src_row + k_idx, rep_scale);
            const int16x8_t v_i16_hi = kai_quantize_8_f32(src_row + k_idx + 8, rep_scale);
            const int8x16_t v_full_i8_block =
                vuzp1q_s8(vreinterpretq_s8_s16(v_i16_lo), vreinterpretq_s8_s16(v_i16_hi));

            // Consecutive 8-value chunks of one row are Mr chunks apart
            vst1_s8(dst_ptr, vget_low_s8(v_full_i8_block));
            dst_ptr += local_mr * k_block_len * sizeof(int8_t);

            vst1_s8(dst_ptr, vget_high_s8(v_full_i8_block));
            dst_ptr += local_mr * k_block_len * sizeof(int8_t);
        }

        src_row += local_bl;
    }
}

// Quantizes `m` rows of an F32 LHS matrix to int8 with one f16 scale per block
// of 32 values, and writes them in the packed layout produced by this file:
// for each group of 4 rows and each block, 4 f16 scales followed by the int8
// data interleaved in chunks of 8 values per row.
//
// m            Number of rows to pack. Nothing is done when it is 0.
// k            Number of columns; expected to be a multiple of 32.
// bl           Block length; must be 32.
// mr           Rows per packed tile; must be 4.
// kr, sr       Must be 16 and 2 respectively.
// m_idx_start  Index of the first row to pack, relative to `lhs`.
// lhs          Base of the F32 LHS matrix.
// lhs_stride   Distance in bytes between consecutive LHS rows.
// lhs_packed   Start of the packed tile that contains row `m_idx_start`.
//
// Four rows are packed at once only when they fill a whole tile, i.e. the
// current row index is a multiple of 4 and at least 4 rows remain. All other
// rows are packed one at a time, so any `m_idx_start` is handled.
void kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(
    size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
    size_t lhs_stride, void* lhs_packed) {
    if (m == 0) {
        return;
    }

    KAI_ASSUME(bl == 32);
    KAI_ASSUME(mr == 4);
    KAI_ASSUME(kr == 16);
    KAI_ASSUME(sr == 2);

    const size_t local_bl = 32;
    const size_t local_mr = 4;
    const size_t local_kr = 16;
    const size_t num_rows = m;
    const size_t lhs_packed_stride = kai_lhs_packed_stride(k, local_mr, local_kr, local_bl);
    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, local_bl);

    int8_t* dst_tile = (int8_t*)lhs_packed;
    size_t row_idx = 0;

    while (row_idx < num_rows) {
        const size_t dst_x = ((row_idx + m_idx_start) % local_mr);
        const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);

        if (dst_x == 0 && (num_rows - row_idx) >= local_mr) {
            // A whole tile is available: pack Mr rows in one go
            kai_quant_pack_4_rows(src_ptr, lhs_stride, num_blocks_per_row, dst_tile);
            dst_tile += lhs_packed_stride;
            row_idx += local_mr;
        } else {
            kai_quant_pack_1_row(src_ptr, num_blocks_per_row, dst_x, dst_tile);
            // Move to the next tile if we have interleaved all Mr rows
            if (dst_x == local_mr - 1) {
                dst_tile += lhs_packed_stride;
            }
            ++row_idx;
        }
    }
}
341 |