kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | #include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h" | ||
| 7 | |||
| 8 | #include <arm_neon.h> | ||
| 9 | #include <stddef.h> | ||
| 10 | #include <stdint.h> | ||
| 11 | |||
| 12 | #include "kai/kai_common.h" | ||
| 13 | |||
| 14 | static const size_t kai_num_bytes_multiplier = sizeof(uint16_t); | ||
| 15 | |||
| 16 | 920 | inline static size_t kai_num_bytes_per_block(size_t bl) { | |
| 17 | 920 | return bl * sizeof(int8_t) + kai_num_bytes_multiplier; | |
| 18 | } | ||
| 19 | |||
| 20 | 920 | inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { | |
| 21 | − | KAI_ASSERT((k % bl) == 0); | |
| 22 | 920 | return k / bl; | |
| 23 | } | ||
| 24 | |||
| 25 | 736 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) { | |
| 26 | 736 | KAI_UNUSED(kr); | |
| 27 | 736 | return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl); | |
| 28 | } | ||
| 29 | |||
| 30 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t mr) { | |
| 31 | ✗ | return mr; | |
| 32 | } | ||
| 33 | |||
| 34 | 184 | size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon(size_t m_idx, size_t lhs_stride) { | |
| 35 | 184 | return m_idx * lhs_stride; | |
| 36 | } | ||
| 37 | |||
| 38 | 368 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon( | |
| 39 | size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
| 40 | − | KAI_ASSUME((k % 2) == 0); | |
| 41 | − | KAI_ASSUME((k % kr) == 0); | |
| 42 | − | KAI_ASSUME((k % bl) == 0); | |
| 43 | |||
| 44 | 368 | KAI_UNUSED(sr); | |
| 45 | 368 | KAI_UNUSED(kr); | |
| 46 | |||
| 47 | 368 | return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl); | |
| 48 | } | ||
| 49 | |||
| 50 | 184 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon( | |
| 51 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
| 52 | − | KAI_ASSUME((k % 2) == 0); | |
| 53 | − | KAI_ASSUME((k % kr) == 0); | |
| 54 | − | KAI_ASSUME((k % bl) == 0); | |
| 55 | |||
| 56 | 184 | KAI_UNUSED(sr); | |
| 57 | 184 | KAI_UNUSED(kr); | |
| 58 | |||
| 59 | 184 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
| 60 | |||
| 61 | 368 | return num_rows * kai_lhs_packed_stride(k, mr, kr, bl); | |
| 62 | 184 | } | |
| 63 | |||
| 64 | 184 | void kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon( | |
| 65 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, | ||
| 66 | size_t lhs_stride, void* lhs_packed) { | ||
| 67 |
1/2✓ Branch 0 taken 184 times.
✗ Branch 1 not taken.
|
184 | if (m == 0) { |
| 68 | ✗ | return; | |
| 69 | } | ||
| 70 | |||
| 71 | − | KAI_ASSUME(bl == 32); | |
| 72 | − | KAI_ASSUME(mr == 4); | |
| 73 | − | KAI_ASSUME(kr == 16); | |
| 74 | − | KAI_ASSUME(sr == 2); | |
| 75 | |||
| 76 | 184 | const size_t local_bl = 32; | |
| 77 | 184 | const size_t local_mr = 4; | |
| 78 | 184 | const size_t local_kr = 16; | |
| 79 | 184 | const size_t local_sr = 2; | |
| 80 | 184 | const size_t num_rows = m; | |
| 81 | 184 | const size_t k_block_len = local_kr / local_sr; | |
| 82 | 184 | const size_t lhs_packed_stride = kai_lhs_packed_stride(k, local_mr, local_kr, local_bl); | |
| 83 | 184 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, local_bl); | |
| 84 | 184 | const size_t num_bytes_per_block = kai_num_bytes_per_block(local_bl); | |
| 85 | |||
| 86 | 184 | size_t row_idx = 0; | |
| 87 | |||
| 88 | 184 | const size_t write_mem_increment = 2 * k_block_len * sizeof(int8_t); | |
| 89 | 184 | const size_t read_mem_increment = num_blocks_per_row * local_bl * sizeof(int8_t); | |
| 90 | |||
| 91 |
2/2✓ Branch 0 taken 72 times.
✓ Branch 1 taken 112 times.
|
184 | if (num_rows >= 4) { |
| 92 |
2/2✓ Branch 0 taken 1032 times.
✓ Branch 1 taken 112 times.
|
1144 | for (; row_idx + 4 <= num_rows; row_idx += 4) { |
| 93 | 1032 | const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride); | |
| 94 | |||
| 95 |
2/2✓ Branch 0 taken 2040 times.
✓ Branch 1 taken 1032 times.
|
3072 | for (size_t b = 0; b < num_blocks_per_row; ++b) { |
| 96 | 2040 | const size_t dst_x = ((row_idx + m_idx_start) % local_mr); | |
| 97 | 2040 | int8_t* dst_ptr = (int8_t*)lhs_packed + (b * local_mr) * num_bytes_per_block; | |
| 98 | |||
| 99 | 2040 | float abs_max_0 = 0.0F; | |
| 100 | 2040 | float abs_max_1 = 0.0F; | |
| 101 | 2040 | float abs_max_2 = 0.0F; | |
| 102 | 2040 | float abs_max_3 = 0.0F; | |
| 103 | |||
| 104 | 2040 | float32x4_t v_currentmax_0 = vdupq_n_f32(0); | |
| 105 | 2040 | float32x4_t v_currentmax_1 = vdupq_n_f32(0); | |
| 106 | 2040 | float32x4_t v_currentmax_2 = vdupq_n_f32(0); | |
| 107 | 2040 | float32x4_t v_currentmax_3 = vdupq_n_f32(0); | |
| 108 | |||
| 109 |
2/2✓ Branch 0 taken 16320 times.
✓ Branch 1 taken 2040 times.
|
18360 | for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) { |
| 110 | 16320 | const float32x4_t v_f32_maxvals_0 = vld1q_f32(src_ptr + idx_v); | |
| 111 | 16320 | const float32x4_t v_f32_abs_values_0 = vabsq_f32(v_f32_maxvals_0); | |
| 112 | 16320 | v_currentmax_0 = vmaxq_f32(v_f32_abs_values_0, v_currentmax_0); | |
| 113 | 16320 | const float32x4_t v_f32_maxvals_1 = vld1q_f32(src_ptr + idx_v + read_mem_increment); | |
| 114 | 16320 | const float32x4_t v_f32_abs_values_1 = vabsq_f32(v_f32_maxvals_1); | |
| 115 | 16320 | v_currentmax_1 = vmaxq_f32(v_f32_abs_values_1, v_currentmax_1); | |
| 116 | 16320 | const float32x4_t v_f32_maxvals_2 = vld1q_f32(src_ptr + idx_v + 2 * read_mem_increment); | |
| 117 | 16320 | const float32x4_t v_f32_abs_values_2 = vabsq_f32(v_f32_maxvals_2); | |
| 118 | 16320 | v_currentmax_2 = vmaxq_f32(v_f32_abs_values_2, v_currentmax_2); | |
| 119 | 16320 | const float32x4_t v_f32_maxvals_3 = vld1q_f32(src_ptr + idx_v + 3 * read_mem_increment); | |
| 120 | 16320 | const float32x4_t v_f32_abs_values_3 = vabsq_f32(v_f32_maxvals_3); | |
| 121 | 16320 | v_currentmax_3 = vmaxq_f32(v_f32_abs_values_3, v_currentmax_3); | |
| 122 | 16320 | } | |
| 123 | |||
| 124 | 2040 | abs_max_0 = vmaxvq_f32(v_currentmax_0); | |
| 125 | 2040 | abs_max_1 = vmaxvq_f32(v_currentmax_1); | |
| 126 | 2040 | abs_max_2 = vmaxvq_f32(v_currentmax_2); | |
| 127 | 2040 | abs_max_3 = vmaxvq_f32(v_currentmax_3); | |
| 128 | |||
| 129 | 2040 | float32x4_t abs_max_vec = vdupq_n_f32(abs_max_0); | |
| 130 | 2040 | abs_max_vec = vsetq_lane_f32(abs_max_1, abs_max_vec, 1); | |
| 131 | 2040 | abs_max_vec = vsetq_lane_f32(abs_max_2, abs_max_vec, 2); | |
| 132 | 2040 | abs_max_vec = vsetq_lane_f32(abs_max_3, abs_max_vec, 3); | |
| 133 | |||
| 134 | // Calculate scale and reciprocals | ||
| 135 | 2040 | const float32x4_t scales = vdivq_f32(abs_max_vec, vdupq_n_f32((1 << 7) - 1)); | |
| 136 | 2040 | const uint32x4_t valid_scales = vmvnq_u32(vceqq_f32(scales, vdupq_n_f32(0.0F))); | |
| 137 | 2040 | const float32x4_t reciprocals = vdivq_f32(vdupq_n_f32(1.0F), scales); | |
| 138 | 2040 | const float32x4_t rep_scales = vbslq_f32(valid_scales, reciprocals, vdupq_n_f32(0.0F)); | |
| 139 | 2040 | const float16x4_t f16_scales = vcvt_f16_f32(scales); | |
| 140 | |||
| 141 | 2040 | vst1_u16((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier), vreinterpret_u16_f16(f16_scales)); | |
| 142 | |||
| 143 | 2040 | dst_ptr += local_mr * kai_num_bytes_multiplier; | |
| 144 | |||
| 145 | 2040 | dst_ptr += dst_x * k_block_len * sizeof(int8_t); | |
| 146 | |||
| 147 | // Quantize and pack the blocks | ||
| 148 |
2/2✓ Branch 0 taken 4080 times.
✓ Branch 1 taken 2040 times.
|
6120 | for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) { |
| 149 | // Row 1 blocks | ||
| 150 | 4080 | const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx); | |
| 151 | 4080 | const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, vgetq_lane_f32(rep_scales, 0)); | |
| 152 | 4080 | const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1); | |
| 153 | |||
| 154 | 4080 | const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4); | |
| 155 | 4080 | const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, vgetq_lane_f32(rep_scales, 0)); | |
| 156 | 4080 | const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2); | |
| 157 | |||
| 158 | 8160 | const int16x8_t v_full_i16_block1 = | |
| 159 | 4080 | vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2)); | |
| 160 | |||
| 161 | 4080 | const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8); | |
| 162 | 4080 | const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, vgetq_lane_f32(rep_scales, 0)); | |
| 163 | 4080 | const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3); | |
| 164 | |||
| 165 | 4080 | const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12); | |
| 166 | 4080 | const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, vgetq_lane_f32(rep_scales, 0)); | |
| 167 | 4080 | const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4); | |
| 168 | |||
| 169 | 8160 | const int16x8_t v_full_i16_block2 = | |
| 170 | 4080 | vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4)); | |
| 171 | |||
| 172 | // Row 2 blocks | ||
| 173 | 4080 | const float32x4_t v_f32_block5 = vld1q_f32(src_ptr + k_idx + read_mem_increment); | |
| 174 | 4080 | const float32x4_t v_f32_sblock5 = vmulq_n_f32(v_f32_block5, vgetq_lane_f32(rep_scales, 1)); | |
| 175 | 4080 | const int32x4_t v_i32_block5 = vcvtnq_s32_f32(v_f32_sblock5); | |
| 176 | |||
| 177 | 4080 | const float32x4_t v_f32_block6 = vld1q_f32(src_ptr + k_idx + 4 + read_mem_increment); | |
| 178 | 4080 | const float32x4_t v_f32_sblock6 = vmulq_n_f32(v_f32_block6, vgetq_lane_f32(rep_scales, 1)); | |
| 179 | 4080 | const int32x4_t v_i32_block6 = vcvtnq_s32_f32(v_f32_sblock6); | |
| 180 | |||
| 181 | 8160 | const int16x8_t v_full_i16_block3 = | |
| 182 | 4080 | vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block5), vreinterpretq_s16_s32(v_i32_block6)); | |
| 183 | |||
| 184 | 4080 | const float32x4_t v_f32_block7 = vld1q_f32(src_ptr + k_idx + 8 + read_mem_increment); | |
| 185 | 4080 | const float32x4_t v_f32_sblock7 = vmulq_n_f32(v_f32_block7, vgetq_lane_f32(rep_scales, 1)); | |
| 186 | 4080 | const int32x4_t v_i32_block7 = vcvtnq_s32_f32(v_f32_sblock7); | |
| 187 | |||
| 188 | 4080 | const float32x4_t v_f32_block8 = vld1q_f32(src_ptr + k_idx + 12 + read_mem_increment); | |
| 189 | 4080 | const float32x4_t v_f32_sblock8 = vmulq_n_f32(v_f32_block8, vgetq_lane_f32(rep_scales, 1)); | |
| 190 | 4080 | const int32x4_t v_i32_block8 = vcvtnq_s32_f32(v_f32_sblock8); | |
| 191 | |||
| 192 | 8160 | const int16x8_t v_full_i16_block4 = | |
| 193 | 4080 | vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block7), vreinterpretq_s16_s32(v_i32_block8)); | |
| 194 | |||
| 195 | // Row 3 blocks | ||
| 196 | 4080 | const float32x4_t v_f32_block9 = vld1q_f32(src_ptr + k_idx + 2 * read_mem_increment); | |
| 197 | 4080 | const float32x4_t v_f32_sblock9 = vmulq_n_f32(v_f32_block9, vgetq_lane_f32(rep_scales, 2)); | |
| 198 | 4080 | const int32x4_t v_i32_block9 = vcvtnq_s32_f32(v_f32_sblock9); | |
| 199 | |||
| 200 | 4080 | const float32x4_t v_f32_blockA = vld1q_f32(src_ptr + k_idx + 4 + 2 * read_mem_increment); | |
| 201 | 4080 | const float32x4_t v_f32_sblockA = vmulq_n_f32(v_f32_blockA, vgetq_lane_f32(rep_scales, 2)); | |
| 202 | 4080 | const int32x4_t v_i32_blockA = vcvtnq_s32_f32(v_f32_sblockA); | |
| 203 | |||
| 204 | 8160 | const int16x8_t v_full_i16_block5 = | |
| 205 | 4080 | vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block9), vreinterpretq_s16_s32(v_i32_blockA)); | |
| 206 | |||
| 207 | 4080 | const float32x4_t v_f32_blockB = vld1q_f32(src_ptr + k_idx + 8 + 2 * read_mem_increment); | |
| 208 | 4080 | const float32x4_t v_f32_sblockB = vmulq_n_f32(v_f32_blockB, vgetq_lane_f32(rep_scales, 2)); | |
| 209 | 4080 | const int32x4_t v_i32_blockB = vcvtnq_s32_f32(v_f32_sblockB); | |
| 210 | |||
| 211 | 4080 | const float32x4_t v_f32_blockC = vld1q_f32(src_ptr + k_idx + 12 + 2 * read_mem_increment); | |
| 212 | 4080 | const float32x4_t v_f32_sblockC = vmulq_n_f32(v_f32_blockC, vgetq_lane_f32(rep_scales, 2)); | |
| 213 | 4080 | const int32x4_t v_i32_blockC = vcvtnq_s32_f32(v_f32_sblockC); | |
| 214 | |||
| 215 | 8160 | const int16x8_t v_full_i16_block6 = | |
| 216 | 4080 | vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockB), vreinterpretq_s16_s32(v_i32_blockC)); | |
| 217 | |||
| 218 | // Row 4 blocks | ||
| 219 | 4080 | const float32x4_t v_f32_blockD = vld1q_f32(src_ptr + k_idx + 3 * read_mem_increment); | |
| 220 | 4080 | const float32x4_t v_f32_sblockD = vmulq_n_f32(v_f32_blockD, vgetq_lane_f32(rep_scales, 3)); | |
| 221 | 4080 | const int32x4_t v_i32_blockD = vcvtnq_s32_f32(v_f32_sblockD); | |
| 222 | |||
| 223 | 4080 | const float32x4_t v_f32_blockE = vld1q_f32(src_ptr + k_idx + 4 + 3 * read_mem_increment); | |
| 224 | 4080 | const float32x4_t v_f32_sblockE = vmulq_n_f32(v_f32_blockE, vgetq_lane_f32(rep_scales, 3)); | |
| 225 | 4080 | const int32x4_t v_i32_blockE = vcvtnq_s32_f32(v_f32_sblockE); | |
| 226 | |||
| 227 | 8160 | const int16x8_t v_full_i16_block7 = | |
| 228 | 4080 | vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockD), vreinterpretq_s16_s32(v_i32_blockE)); | |
| 229 | |||
| 230 | 4080 | const float32x4_t v_f32_blockF = vld1q_f32(src_ptr + k_idx + 8 + 3 * read_mem_increment); | |
| 231 | 4080 | const float32x4_t v_f32_sblockF = vmulq_n_f32(v_f32_blockF, vgetq_lane_f32(rep_scales, 3)); | |
| 232 | 4080 | const int32x4_t v_i32_blockF = vcvtnq_s32_f32(v_f32_sblockF); | |
| 233 | |||
| 234 | 4080 | const float32x4_t v_f32_block0 = vld1q_f32(src_ptr + k_idx + 12 + 3 * read_mem_increment); | |
| 235 | 4080 | const float32x4_t v_f32_sblock0 = vmulq_n_f32(v_f32_block0, vgetq_lane_f32(rep_scales, 3)); | |
| 236 | 4080 | const int32x4_t v_i32_block0 = vcvtnq_s32_f32(v_f32_sblock0); | |
| 237 | |||
| 238 | 8160 | const int16x8_t v_full_i16_block8 = | |
| 239 | 4080 | vuzp1q_s16(vreinterpretq_s16_s32(v_i32_blockF), vreinterpretq_s16_s32(v_i32_block0)); | |
| 240 | |||
| 241 | 8160 | const int8x16_t v_i8_block1_3 = | |
| 242 | 4080 | vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block3)); | |
| 243 | 4080 | vst1q_s8(dst_ptr, v_i8_block1_3); | |
| 244 | 4080 | dst_ptr += write_mem_increment; | |
| 245 | |||
| 246 | 8160 | const int8x16_t v_i8_block5_7 = | |
| 247 | 4080 | vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block5), vreinterpretq_s8_s16(v_full_i16_block7)); | |
| 248 | 4080 | vst1q_s8(dst_ptr, v_i8_block5_7); | |
| 249 | 4080 | dst_ptr += write_mem_increment; | |
| 250 | |||
| 251 | 8160 | const int8x16_t v_i8_block2_4 = | |
| 252 | 4080 | vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block2), vreinterpretq_s8_s16(v_full_i16_block4)); | |
| 253 | 4080 | vst1q_s8(dst_ptr, v_i8_block2_4); | |
| 254 | 4080 | dst_ptr += write_mem_increment; | |
| 255 | |||
| 256 | 8160 | const int8x16_t v_i8_block6_8 = | |
| 257 | 4080 | vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block6), vreinterpretq_s8_s16(v_full_i16_block8)); | |
| 258 | 4080 | vst1q_s8(dst_ptr, v_i8_block6_8); | |
| 259 | 4080 | dst_ptr += write_mem_increment; | |
| 260 | 4080 | } | |
| 261 | 2040 | src_ptr += local_bl; | |
| 262 | 2040 | } | |
| 263 | 1032 | lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride); | |
| 264 | 1032 | } | |
| 265 | 112 | } | |
| 266 |
2/2✓ Branch 0 taken 48 times.
✓ Branch 1 taken 136 times.
|
184 | if (num_rows % 4 != 0) { |
| 267 |
2/2✓ Branch 0 taken 152 times.
✓ Branch 1 taken 136 times.
|
288 | for (; row_idx < num_rows; ++row_idx) { |
| 268 | 152 | const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride); | |
| 269 | |||
| 270 |
2/2✓ Branch 0 taken 208 times.
✓ Branch 1 taken 152 times.
|
360 | for (size_t b = 0; b < num_blocks_per_row; ++b) { |
| 271 | 208 | float abs_max = 0.0F; | |
| 272 | |||
| 273 | 208 | const size_t dst_x = ((row_idx + m_idx_start) % local_mr); | |
| 274 | 208 | int8_t* dst_ptr = (int8_t*)lhs_packed + (b * local_mr) * num_bytes_per_block; | |
| 275 | |||
| 276 | 208 | float32x4_t v_f32_abs_values; | |
| 277 | 208 | float32x4_t v_f32_maxvals; | |
| 278 | 208 | float32x4_t v_currentmax = vdupq_n_f32(0); | |
| 279 | |||
| 280 |
2/2✓ Branch 0 taken 1664 times.
✓ Branch 1 taken 208 times.
|
1872 | for (size_t idx_v = 0; idx_v < local_bl; idx_v += 4) { |
| 281 | 1664 | v_f32_maxvals = vld1q_f32(src_ptr + idx_v); | |
| 282 | 1664 | v_f32_abs_values = vabsq_f32(v_f32_maxvals); | |
| 283 | 1664 | v_currentmax = vmaxq_f32(v_f32_abs_values, v_currentmax); | |
| 284 | 1664 | } | |
| 285 | 208 | abs_max = vmaxvq_f32(v_currentmax); | |
| 286 | |||
| 287 | // Calculate scale and reciprocal | ||
| 288 | 208 | const float scale = abs_max / ((1 << 7) - 1); | |
| 289 |
1/2✓ Branch 0 taken 208 times.
✗ Branch 1 not taken.
|
208 | const float rep_scale = scale ? 1.0F / scale : 0.0F; |
| 290 | |||
| 291 | 208 | *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale); | |
| 292 | 208 | dst_ptr += local_mr * kai_num_bytes_multiplier; | |
| 293 | |||
| 294 | 208 | dst_ptr += dst_x * k_block_len * sizeof(int8_t); | |
| 295 | |||
| 296 | // Quantize and pack the block | ||
| 297 |
2/2✓ Branch 0 taken 416 times.
✓ Branch 1 taken 208 times.
|
624 | for (size_t k_idx = 0; k_idx < local_bl; k_idx += k_block_len * 2) { |
| 298 | 416 | const float32x4_t v_f32_block1 = vld1q_f32(src_ptr + k_idx); | |
| 299 | 416 | const float32x4_t v_f32_sblock1 = vmulq_n_f32(v_f32_block1, rep_scale); | |
| 300 | 416 | const int32x4_t v_i32_block1 = vcvtnq_s32_f32(v_f32_sblock1); | |
| 301 | |||
| 302 | 416 | const float32x4_t v_f32_block2 = vld1q_f32(src_ptr + k_idx + 4); | |
| 303 | 416 | const float32x4_t v_f32_sblock2 = vmulq_n_f32(v_f32_block2, rep_scale); | |
| 304 | 416 | const int32x4_t v_i32_block2 = vcvtnq_s32_f32(v_f32_sblock2); | |
| 305 | |||
| 306 | 832 | const int16x8_t v_full_i16_block1 = | |
| 307 | 416 | vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block1), vreinterpretq_s16_s32(v_i32_block2)); | |
| 308 | |||
| 309 | 416 | const float32x4_t v_f32_block3 = vld1q_f32(src_ptr + k_idx + 8); | |
| 310 | 416 | const float32x4_t v_f32_sblock3 = vmulq_n_f32(v_f32_block3, rep_scale); | |
| 311 | 416 | const int32x4_t v_i32_block3 = vcvtnq_s32_f32(v_f32_sblock3); | |
| 312 | |||
| 313 | 416 | const float32x4_t v_f32_block4 = vld1q_f32(src_ptr + k_idx + 12); | |
| 314 | 416 | const float32x4_t v_f32_sblock4 = vmulq_n_f32(v_f32_block4, rep_scale); | |
| 315 | 416 | const int32x4_t v_i32_block4 = vcvtnq_s32_f32(v_f32_sblock4); | |
| 316 | |||
| 317 | 832 | const int16x8_t v_full_i16_block2 = | |
| 318 | 416 | vuzp1q_s16(vreinterpretq_s16_s32(v_i32_block3), vreinterpretq_s16_s32(v_i32_block4)); | |
| 319 | |||
| 320 | 832 | const int8x16_t v_full_i8_block = | |
| 321 | 416 | vuzp1q_s8(vreinterpretq_s8_s16(v_full_i16_block1), vreinterpretq_s8_s16(v_full_i16_block2)); | |
| 322 | |||
| 323 | 416 | vst1_s8(dst_ptr, vget_low_s8(v_full_i8_block)); | |
| 324 | 416 | dst_ptr += 8 * sizeof(int8_t); | |
| 325 | 416 | dst_ptr += (local_mr - 1) * k_block_len * sizeof(int8_t); | |
| 326 | |||
| 327 | 416 | vst1_s8(dst_ptr, vget_high_s8(v_full_i8_block)); | |
| 328 | 416 | dst_ptr += 8 * sizeof(int8_t); | |
| 329 | 416 | dst_ptr += (local_mr - 1) * k_block_len * sizeof(int8_t); | |
| 330 | 416 | } | |
| 331 | 208 | src_ptr += local_bl; | |
| 332 | 208 | } | |
| 333 | // Move to the next row if we have interleaved all Mr rows | ||
| 334 |
1/2✓ Branch 0 taken 152 times.
✗ Branch 1 not taken.
|
152 | if ((((row_idx + 1) + m_idx_start) % local_mr) == 0) { |
| 335 | ✗ | lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride); | |
| 336 | ✗ | } | |
| 337 | 152 | } | |
| 338 | 136 | } | |
| 339 | 184 | } | |
| 340 |