kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || \ | ||
| 8 | !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) | ||
| 9 | #error This file must be compiled for AArch64, FEAT_FP16. | ||
| 10 | #else // Architectural features check. | ||
| 11 | |||
| 12 | #include "kai_lhs_quant_pack_qai8dxp_f16_neon.h" | ||
| 13 | |||
| 14 | #include <arm_fp16.h> | ||
| 15 | #include <arm_neon.h> | ||
| 16 | #include <float.h> | ||
| 17 | #include <math.h> | ||
| 18 | #include <stddef.h> | ||
| 19 | #include <stdint.h> | ||
| 20 | |||
| 21 | #include "kai/kai_common.h" | ||
| 22 | #define FLT16_MAX 65504.0 | ||
| 23 | #define FLT16_MIN (-65504.0F) | ||
| 24 | |||
| 25 | static const size_t kai_num_bytes_per_multiplier = sizeof(float); | ||
| 26 | static const size_t kai_num_bytes_per_offset = sizeof(int32_t); | ||
| 27 | |||
| 28 | 29568 | inline static size_t kai_k_roundedup(size_t k) { | |
| 29 | // Round up k to be a multiple of 32. | ||
| 30 | 29568 | size_t kai_k_multiple_of = 32; | |
| 31 | 59136 | return kai_roundup(k, kai_k_multiple_of); | |
| 32 | 29568 | } | |
| 33 | |||
| 34 | 22176 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) { | |
| 35 | 22176 | KAI_UNUSED(kr); | |
| 36 | 22176 | KAI_UNUSED(sr); | |
| 37 | |||
| 38 | 22176 | const size_t k_internal = kai_k_roundedup(k); | |
| 39 | |||
| 40 | − | KAI_ASSERT((k_internal % 2) == 0); | |
| 41 | |||
| 42 | 44352 | return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); | |
| 43 | 22176 | } | |
| 44 | |||
| 45 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f16_neon(size_t mr) { | |
| 46 | ✗ | return mr; | |
| 47 | } | ||
| 48 | |||
| 49 | 7392 | size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(size_t m_idx, size_t lhs_stride) { | |
| 50 | 7392 | return m_idx * lhs_stride; | |
| 51 | } | ||
| 52 | |||
| 53 | 7392 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon( | |
| 54 | size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | ||
| 55 | // It always points to the beginning of the row | ||
| 56 | 7392 | return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr); | |
| 57 | } | ||
| 58 | |||
| 59 | 7392 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
| 60 | 7392 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
| 61 | |||
| 62 | 14784 | return num_rows * kai_lhs_packed_stride(k, mr, kr, sr); | |
| 63 | 7392 | } | |
| 64 | |||
| 65 | 7392 | void kai_run_lhs_quant_pack_qai8dxp_f16_neon( | |
| 66 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs, | ||
| 67 | size_t lhs_stride, void* restrict lhs_packed) { | ||
| 68 | − | KAI_ASSERT((kr % sr) == 0); | |
| 69 | − | KAI_ASSUME((kr / sr == 8) || (kr / sr == 4)); | |
| 70 | |||
| 71 |
1/2✓ Branch 0 taken 7392 times.
✗ Branch 1 not taken.
|
7392 | if (m == 0) { |
| 72 | ✗ | return; | |
| 73 | } | ||
| 74 | |||
| 75 | 7392 | const size_t num_rows = m; | |
| 76 | |||
| 77 | 7392 | float16_t const* src_ptr = (float16_t const*)lhs; | |
| 78 | |||
| 79 | 7392 | const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr); | |
| 80 | 7392 | const size_t k_internal = kai_k_roundedup(k); | |
| 81 | 7392 | const int32_t k_block_len = (int32_t)(kr / sr); | |
| 82 | |||
| 83 | 7392 | const int32_t num_blocks_k = (int32_t)(k / k_block_len); | |
| 84 | 7392 | const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len); | |
| 85 | 7392 | const size_t lhs_row_length = lhs_stride / sizeof(float16_t); | |
| 86 | |||
| 87 | 7392 | const float16x8_t vmax = vdupq_n_f16((float16_t)FLT16_MIN); | |
| 88 | 7392 | const float16x8_t vmin = vdupq_n_f16((float16_t)FLT16_MAX); | |
| 89 | |||
| 90 | // As we load 8-element vectors, limit vectorized loop to avoid reading out-of-bounds | ||
| 91 | 7392 | const int32_t blocks_lim_k = num_blocks_k - (8 / k_block_len); | |
| 92 | |||
| 93 | 7392 | size_t row_idx = 0; | |
| 94 | |||
| 95 | // Improved performance with 4x loop unrolling where packing parameters allow | ||
| 96 |
2/2✓ Branch 0 taken 2016 times.
✓ Branch 1 taken 5376 times.
|
7392 | if (mr == 4) { |
| 97 |
2/2✓ Branch 0 taken 14976 times.
✓ Branch 1 taken 5376 times.
|
20352 | for (; row_idx + 3 < m; row_idx += 4) { |
| 98 | // Find min/max for each channel | ||
| 99 | 14976 | int32_t k_idx = 0; | |
| 100 | 14976 | float16x8_t vmax0 = vmax; | |
| 101 | 14976 | float16x8_t vmin0 = vmin; | |
| 102 | 14976 | float16x8_t vmax1 = vmax; | |
| 103 | 14976 | float16x8_t vmin1 = vmin; | |
| 104 | 14976 | float16x8_t vmax2 = vmax; | |
| 105 | 14976 | float16x8_t vmin2 = vmin; | |
| 106 | 14976 | float16x8_t vmax3 = vmax; | |
| 107 | 14976 | float16x8_t vmin3 = vmin; | |
| 108 | |||
| 109 |
2/2✓ Branch 0 taken 92256 times.
✓ Branch 1 taken 14976 times.
|
107232 | for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { |
| 110 | 92256 | const float16x8_t src0 = vld1q_f16(src_ptr + k_idx); | |
| 111 | 92256 | const float16x8_t src1 = vld1q_f16(src_ptr + k_idx + lhs_row_length); | |
| 112 | 92256 | const float16x8_t src2 = vld1q_f16(src_ptr + k_idx + (2 * lhs_row_length)); | |
| 113 | 92256 | const float16x8_t src3 = vld1q_f16(src_ptr + k_idx + (3 * lhs_row_length)); | |
| 114 | |||
| 115 | 92256 | vmax0 = vmaxq_f16(src0, vmax0); | |
| 116 | 92256 | vmax1 = vmaxq_f16(src1, vmax1); | |
| 117 | 92256 | vmax2 = vmaxq_f16(src2, vmax2); | |
| 118 | 92256 | vmax3 = vmaxq_f16(src3, vmax3); | |
| 119 | 92256 | vmin0 = vminq_f16(src0, vmin0); | |
| 120 | 92256 | vmin1 = vminq_f16(src1, vmin1); | |
| 121 | 92256 | vmin2 = vminq_f16(src2, vmin2); | |
| 122 | 92256 | vmin3 = vminq_f16(src3, vmin3); | |
| 123 | 92256 | } | |
| 124 | |||
| 125 | 14976 | float16_t max0 = vmaxvq_f16(vmax0); | |
| 126 | 14976 | float16_t min0 = vminvq_f16(vmin0); | |
| 127 | 14976 | float16_t max1 = vmaxvq_f16(vmax1); | |
| 128 | 14976 | float16_t min1 = vminvq_f16(vmin1); | |
| 129 | 14976 | float16_t max2 = vmaxvq_f16(vmax2); | |
| 130 | 14976 | float16_t min2 = vminvq_f16(vmin2); | |
| 131 | 14976 | float16_t max3 = vmaxvq_f16(vmax3); | |
| 132 | 14976 | float16_t min3 = vminvq_f16(vmin3); | |
| 133 | // Process leftover elements with a scalar loop. | ||
| 134 |
2/2✓ Branch 0 taken 13776 times.
✓ Branch 1 taken 14976 times.
|
28752 | for (; k_idx < (int32_t)k; ++k_idx) { |
| 135 | 13776 | const float16_t src0 = *(src_ptr + (size_t)k_idx); | |
| 136 | 13776 | max0 = vmaxh_f16(src0, max0); | |
| 137 | 13776 | min0 = vminh_f16(src0, min0); | |
| 138 | 13776 | const float16_t src1 = *(src_ptr + (size_t)k_idx + lhs_row_length); | |
| 139 | 13776 | max1 = vmaxh_f16(src1, max1); | |
| 140 | 13776 | min1 = vminh_f16(src1, min1); | |
| 141 | 13776 | const float16_t src2 = *(src_ptr + (size_t)k_idx + (2 * lhs_row_length)); | |
| 142 | 13776 | max2 = vmaxh_f16(src2, max2); | |
| 143 | 13776 | min2 = vminh_f16(src2, min2); | |
| 144 | 13776 | const float16_t src3 = *(src_ptr + (size_t)k_idx + (3 * lhs_row_length)); | |
| 145 | 13776 | max3 = vmaxh_f16(src3, max3); | |
| 146 | 13776 | min3 = vminh_f16(src3, min3); | |
| 147 | 13776 | } | |
| 148 | |||
| 149 | // Maximum/minimum int8 values | ||
| 150 | 14976 | const float qmin = (float)INT8_MIN; | |
| 151 | 14976 | const float qmax = (float)INT8_MAX; | |
| 152 | |||
| 153 | 14976 | const float rmin0 = fminf(0.0F, min0); | |
| 154 | 14976 | const float rmax0 = fmaxf(0.0F, max0); | |
| 155 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 14976 times.
|
14976 | const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); |
| 156 | 14976 | const float rmin1 = fminf(0.0F, min1); | |
| 157 | 14976 | const float rmax1 = fmaxf(0.0F, max1); | |
| 158 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 14976 times.
|
14976 | const float scale1 = rmin1 == rmax1 ? 1.F : (qmax - qmin) / (rmax1 - rmin1); |
| 159 | 14976 | const float rmin2 = fminf(0.0F, min2); | |
| 160 | 14976 | const float rmax2 = fmaxf(0.0F, max2); | |
| 161 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 14976 times.
|
14976 | const float scale2 = rmin2 == rmax2 ? 1.F : (qmax - qmin) / (rmax2 - rmin2); |
| 162 | 14976 | const float rmin3 = fminf(0.0F, min3); | |
| 163 | 14976 | const float rmax3 = fmaxf(0.0F, max3); | |
| 164 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 14976 times.
|
14976 | const float scale3 = rmin3 == rmax3 ? 1.F : (qmax - qmin) / (rmax3 - rmin3); |
| 165 | |||
| 166 | // Reciprocal to quantize | ||
| 167 |
1/2✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
|
14976 | const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; |
| 168 |
1/2✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
|
14976 | const float recip_scale1 = scale1 ? 1.0F / scale1 : 0.0F; |
| 169 |
1/2✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
|
14976 | const float recip_scale2 = scale2 ? 1.0F / scale2 : 0.0F; |
| 170 |
1/2✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
|
14976 | const float recip_scale3 = scale3 ? 1.0F / scale3 : 0.0F; |
| 171 | |||
| 172 | 14976 | const float descaled_min0 = rmin0 * scale0; | |
| 173 | 14976 | const float descaled_max0 = rmax0 * scale0; | |
| 174 | 14976 | const float descaled_min1 = rmin1 * scale1; | |
| 175 | 14976 | const float descaled_max1 = rmax1 * scale1; | |
| 176 | 14976 | const float descaled_min2 = rmin2 * scale2; | |
| 177 | 14976 | const float descaled_max2 = rmax2 * scale2; | |
| 178 | 14976 | const float descaled_min3 = rmin3 * scale3; | |
| 179 | 14976 | const float descaled_max3 = rmax3 * scale3; | |
| 180 | |||
| 181 | 14976 | const float zero_point_from_min_error0 = qmin + descaled_min0; | |
| 182 | 14976 | const float zero_point_from_max_error0 = qmax + descaled_max0; | |
| 183 | 14976 | const float zero_point_from_min_error1 = qmin + descaled_min1; | |
| 184 | 14976 | const float zero_point_from_max_error1 = qmax + descaled_max1; | |
| 185 | 14976 | const float zero_point_from_min_error2 = qmin + descaled_min2; | |
| 186 | 14976 | const float zero_point_from_max_error2 = qmax + descaled_max2; | |
| 187 | 14976 | const float zero_point_from_min_error3 = qmin + descaled_min3; | |
| 188 | 14976 | const float zero_point_from_max_error3 = qmax + descaled_max3; | |
| 189 | |||
| 190 |
1/2✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
|
14976 | float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 |
| 191 | ✗ | : qmax - descaled_max0; | |
| 192 |
1/2✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
|
14976 | float zero_point1 = (zero_point_from_min_error1 + zero_point_from_max_error1 > 0) ? qmin - descaled_min1 |
| 193 | ✗ | : qmax - descaled_max1; | |
| 194 |
1/2✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
|
14976 | float zero_point2 = (zero_point_from_min_error2 + zero_point_from_max_error2 > 0) ? qmin - descaled_min2 |
| 195 | ✗ | : qmax - descaled_max2; | |
| 196 |
1/2✓ Branch 0 taken 14976 times.
✗ Branch 1 not taken.
|
14976 | float zero_point3 = (zero_point_from_min_error3 + zero_point_from_max_error3 > 0) ? qmin - descaled_min3 |
| 197 | ✗ | : qmax - descaled_max3; | |
| 198 | |||
| 199 | 14976 | zero_point0 = fmaxf(zero_point0, qmin); | |
| 200 | 14976 | zero_point0 = fminf(zero_point0, qmax); | |
| 201 | 14976 | zero_point1 = fmaxf(zero_point1, qmin); | |
| 202 | 14976 | zero_point1 = fminf(zero_point1, qmax); | |
| 203 | 14976 | zero_point2 = fmaxf(zero_point2, qmin); | |
| 204 | 14976 | zero_point2 = fminf(zero_point2, qmax); | |
| 205 | 14976 | zero_point3 = fmaxf(zero_point3, qmin); | |
| 206 | 14976 | zero_point3 = fminf(zero_point3, qmax); | |
| 207 | |||
| 208 | // Round to nearest integer | ||
| 209 | 14976 | const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); | |
| 210 | 14976 | const int32_t nudged_zero_point1 = (int32_t)rintf(zero_point1); | |
| 211 | 14976 | const int32_t nudged_zero_point2 = (int32_t)rintf(zero_point2); | |
| 212 | 14976 | const int32_t nudged_zero_point3 = (int32_t)rintf(zero_point3); | |
| 213 | |||
| 214 | 14976 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
| 215 | |||
| 216 | 14976 | uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len); | |
| 217 | |||
| 218 | // Quantize the channels | ||
| 219 | 14976 | int32_t block_idx = 0; | |
| 220 | 14976 | const int32_t block_incr = 8 / k_block_len; | |
| 221 | |||
| 222 |
2/2✓ Branch 0 taken 92256 times.
✓ Branch 1 taken 14976 times.
|
107232 | for (; block_idx <= blocks_lim_k; block_idx += block_incr) { |
| 223 | // Clamp at the last valid k-index | ||
| 224 | 92256 | const int32_t k_idx_start = block_idx * k_block_len; | |
| 225 | |||
| 226 | 92256 | const float16x8_t src0 = vld1q_f16(src_ptr + k_idx_start); | |
| 227 | 92256 | const float16x8_t src1 = vld1q_f16(src_ptr + k_idx_start + lhs_row_length); | |
| 228 | 92256 | const float16x8_t src2 = vld1q_f16(src_ptr + k_idx_start + (2 * lhs_row_length)); | |
| 229 | 92256 | const float16x8_t src3 = vld1q_f16(src_ptr + k_idx_start + (3 * lhs_row_length)); | |
| 230 | |||
| 231 | // Scale the values. | ||
| 232 | 92256 | const int32x4_t v0_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src0)), scale0)); | |
| 233 | 92256 | const int32x4_t v0_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src0), scale0)); | |
| 234 | 92256 | const int32x4_t v1_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src1)), scale1)); | |
| 235 | 92256 | const int32x4_t v1_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src1), scale1)); | |
| 236 | 92256 | const int32x4_t v2_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src2)), scale2)); | |
| 237 | 92256 | const int32x4_t v2_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src2), scale2)); | |
| 238 | 92256 | const int32x4_t v3_0_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src3)), scale3)); | |
| 239 | 92256 | const int32x4_t v3_1_s32 = vcvtq_s32_f32(vmulq_n_f32(vcvt_high_f32_f16(src3), scale3)); | |
| 240 | |||
| 241 | 92256 | const int16x4_t v0_0_s16 = vqmovn_s32(v0_0_s32); | |
| 242 | 92256 | const int16x4_t v0_1_s16 = vqmovn_s32(v0_1_s32); | |
| 243 | 92256 | const int16x4_t v1_0_s16 = vqmovn_s32(v1_0_s32); | |
| 244 | 92256 | const int16x4_t v1_1_s16 = vqmovn_s32(v1_1_s32); | |
| 245 | 92256 | const int16x4_t v2_0_s16 = vqmovn_s32(v2_0_s32); | |
| 246 | 92256 | const int16x4_t v2_1_s16 = vqmovn_s32(v2_1_s32); | |
| 247 | 92256 | const int16x4_t v3_0_s16 = vqmovn_s32(v3_0_s32); | |
| 248 | 92256 | const int16x4_t v3_1_s16 = vqmovn_s32(v3_1_s32); | |
| 249 | |||
| 250 | 92256 | int16x8_t v0_s16; | |
| 251 | 92256 | int16x8_t v1_s16; | |
| 252 | 92256 | int16x8_t v2_s16; | |
| 253 | 92256 | int16x8_t v3_s16; | |
| 254 |
2/2✓ Branch 0 taken 46128 times.
✓ Branch 1 taken 46128 times.
|
92256 | if (k_block_len == 8) { |
| 255 | 46128 | v0_s16 = vcombine_s16(v0_0_s16, v0_1_s16); | |
| 256 | 46128 | v1_s16 = vcombine_s16(v1_0_s16, v1_1_s16); | |
| 257 | 46128 | v2_s16 = vcombine_s16(v2_0_s16, v2_1_s16); | |
| 258 | 46128 | v3_s16 = vcombine_s16(v3_0_s16, v3_1_s16); | |
| 259 | 46128 | } else { // k_block_len == 4 | |
| 260 | 46128 | v0_s16 = vcombine_s16(v0_0_s16, v1_0_s16); | |
| 261 | 46128 | v1_s16 = vcombine_s16(v2_0_s16, v3_0_s16); | |
| 262 | 46128 | v2_s16 = vcombine_s16(v0_1_s16, v1_1_s16); | |
| 263 | 46128 | v3_s16 = vcombine_s16(v2_1_s16, v3_1_s16); | |
| 264 | } | ||
| 265 | |||
| 266 | // Add zero points. | ||
| 267 | 92256 | const int16x8_t vnzp0 = vdupq_n_s16((int16_t)nudged_zero_point0); | |
| 268 | 92256 | const int16x8_t vnzp1 = vdupq_n_s16((int16_t)nudged_zero_point1); | |
| 269 | 92256 | const int16x8_t vnzp2 = vdupq_n_s16((int16_t)nudged_zero_point2); | |
| 270 | 92256 | const int16x8_t vnzp3 = vdupq_n_s16((int16_t)nudged_zero_point3); | |
| 271 | |||
| 272 | 92256 | v0_s16 = vaddq_s16(v0_s16, vnzp0); | |
| 273 | 92256 | v0_s16 = vmaxq_s16(v0_s16, vdupq_n_s16(INT8_MIN)); | |
| 274 | 92256 | v0_s16 = vminq_s16(v0_s16, vdupq_n_s16(INT8_MAX)); | |
| 275 | 92256 | v1_s16 = vaddq_s16(v1_s16, vnzp1); | |
| 276 | 92256 | v1_s16 = vmaxq_s16(v1_s16, vdupq_n_s16(INT8_MIN)); | |
| 277 | 92256 | v1_s16 = vminq_s16(v1_s16, vdupq_n_s16(INT8_MAX)); | |
| 278 | 92256 | v2_s16 = vaddq_s16(v2_s16, vnzp2); | |
| 279 | 92256 | v2_s16 = vmaxq_s16(v2_s16, vdupq_n_s16(INT8_MIN)); | |
| 280 | 92256 | v2_s16 = vminq_s16(v2_s16, vdupq_n_s16(INT8_MAX)); | |
| 281 | 92256 | v3_s16 = vaddq_s16(v3_s16, vnzp3); | |
| 282 | 92256 | v3_s16 = vmaxq_s16(v3_s16, vdupq_n_s16(INT8_MIN)); | |
| 283 | 92256 | v3_s16 = vminq_s16(v3_s16, vdupq_n_s16(INT8_MAX)); | |
| 284 | |||
| 285 | 92256 | int8x8_t v0_s8 = vqmovn_s16(v0_s16); | |
| 286 | 92256 | int8x8_t v1_s8 = vqmovn_s16(v1_s16); | |
| 287 | 92256 | int8x8_t v2_s8 = vqmovn_s16(v2_s16); | |
| 288 | 92256 | int8x8_t v3_s8 = vqmovn_s16(v3_s16); | |
| 289 | |||
| 290 | 92256 | vst1_s8((int8_t*)(dst_ptr), v0_s8); | |
| 291 | 92256 | vst1_s8((int8_t*)(dst_ptr + sizeof(int8x8_t)), v1_s8); | |
| 292 | 92256 | vst1_s8((int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)), v2_s8); | |
| 293 | 92256 | vst1_s8((int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)), v3_s8); | |
| 294 | 92256 | dst_ptr += block_incr * mr * k_block_len * sizeof(int8_t); | |
| 295 | 92256 | } | |
| 296 | |||
| 297 |
2/2✓ Branch 0 taken 19152 times.
✓ Branch 1 taken 14976 times.
|
34128 | for (; block_idx < num_blocks_k_internal; ++block_idx) { |
| 298 | // left over k | ||
| 299 |
2/2✓ Branch 0 taken 102144 times.
✓ Branch 1 taken 19152 times.
|
121296 | for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
| 300 | // Clamp at the last valid k-index. | ||
| 301 |
2/2✓ Branch 0 taken 10080 times.
✓ Branch 1 taken 92064 times.
|
102144 | const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); |
| 302 | |||
| 303 | 102144 | const float src0 = (float)(*(src_ptr + k_idx_start)); | |
| 304 | 102144 | const float src1 = (float)(*(src_ptr + k_idx_start + lhs_row_length)); | |
| 305 | 102144 | const float src2 = (float)(*(src_ptr + k_idx_start + (2 * lhs_row_length))); | |
| 306 | 102144 | const float src3 = (float)(*(src_ptr + k_idx_start + (3 * lhs_row_length))); | |
| 307 | |||
| 308 | // Scale the value. | ||
| 309 | 102144 | int32_t d0_s32 = (int32_t)(roundf(src0 * scale0)); | |
| 310 | 102144 | int32_t d1_s32 = (int32_t)(roundf(src1 * scale1)); | |
| 311 | 102144 | int32_t d2_s32 = (int32_t)(roundf(src2 * scale2)); | |
| 312 | 102144 | int32_t d3_s32 = (int32_t)(roundf(src3 * scale3)); | |
| 313 | |||
| 314 | 102144 | d0_s32 = d0_s32 + nudged_zero_point0; | |
| 315 |
1/2✓ Branch 0 taken 102144 times.
✗ Branch 1 not taken.
|
102144 | d0_s32 = KAI_MAX(d0_s32, INT8_MIN); |
| 316 |
2/2✓ Branch 0 taken 101724 times.
✓ Branch 1 taken 420 times.
|
102144 | d0_s32 = KAI_MIN(d0_s32, INT8_MAX); |
| 317 | |||
| 318 | 102144 | d1_s32 = d1_s32 + nudged_zero_point1; | |
| 319 |
2/2✓ Branch 0 taken 102088 times.
✓ Branch 1 taken 56 times.
|
102144 | d1_s32 = KAI_MAX(d1_s32, INT8_MIN); |
| 320 |
2/2✓ Branch 0 taken 101164 times.
✓ Branch 1 taken 980 times.
|
102144 | d1_s32 = KAI_MIN(d1_s32, INT8_MAX); |
| 321 | |||
| 322 | 102144 | d2_s32 = d2_s32 + nudged_zero_point2; | |
| 323 |
2/2✓ Branch 0 taken 102116 times.
✓ Branch 1 taken 28 times.
|
102144 | d2_s32 = KAI_MAX(d2_s32, INT8_MIN); |
| 324 |
2/2✓ Branch 0 taken 101360 times.
✓ Branch 1 taken 784 times.
|
102144 | d2_s32 = KAI_MIN(d2_s32, INT8_MAX); |
| 325 | |||
| 326 | 102144 | d3_s32 = d3_s32 + nudged_zero_point3; | |
| 327 |
2/2✓ Branch 0 taken 101248 times.
✓ Branch 1 taken 896 times.
|
102144 | d3_s32 = KAI_MAX(d3_s32, INT8_MIN); |
| 328 |
2/2✓ Branch 0 taken 100492 times.
✓ Branch 1 taken 1652 times.
|
102144 | d3_s32 = KAI_MIN(d3_s32, INT8_MAX); |
| 329 | |||
| 330 | 102144 | *(int8_t*)dst_ptr = (int8_t)d0_s32; | |
| 331 | 102144 | *(int8_t*)(dst_ptr + k_block_len * sizeof(int8_t)) = (int8_t)d1_s32; | |
| 332 | 102144 | *(int8_t*)(dst_ptr + 2 * (k_block_len * sizeof(int8_t))) = (int8_t)d2_s32; | |
| 333 | 102144 | *(int8_t*)(dst_ptr + 3 * (k_block_len * sizeof(int8_t))) = (int8_t)d3_s32; | |
| 334 | 102144 | dst_ptr += sizeof(int8_t); | |
| 335 | 102144 | } | |
| 336 | 19152 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
| 337 | 19152 | } | |
| 338 | |||
| 339 | 14976 | uint8_t* dst_base = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); | |
| 340 | |||
| 341 | 14976 | dst_ptr = dst_base + dst_x * kai_num_bytes_per_offset; | |
| 342 | |||
| 343 | // LHS offset at the beginning of the row. | ||
| 344 | 14976 | *((int32_t*)(dst_ptr)) = -nudged_zero_point0; | |
| 345 | 14976 | *((int32_t*)(dst_ptr + kai_num_bytes_per_offset)) = -nudged_zero_point1; | |
| 346 | 14976 | *((int32_t*)(dst_ptr + 2 * kai_num_bytes_per_offset)) = -nudged_zero_point2; | |
| 347 | 14976 | *((int32_t*)(dst_ptr + 3 * kai_num_bytes_per_offset)) = -nudged_zero_point3; | |
| 348 | |||
| 349 | // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. | ||
| 350 | − | KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); | |
| 351 | |||
| 352 | 14976 | dst_ptr += mr * kai_num_bytes_per_offset; | |
| 353 | |||
| 354 | // Store the scale quantization params. | ||
| 355 | 14976 | *((float*)(dst_ptr)) = recip_scale0; | |
| 356 | 14976 | *((float*)(dst_ptr + kai_num_bytes_per_multiplier)) = recip_scale1; | |
| 357 | 14976 | *((float*)(dst_ptr + 2 * kai_num_bytes_per_multiplier)) = recip_scale2; | |
| 358 | 14976 | *((float*)(dst_ptr + 3 * kai_num_bytes_per_multiplier)) = recip_scale3; | |
| 359 | |||
| 360 | // Update src_ptr. Note: now lhs contains fp16 values (2 bytes each). | ||
| 361 | 14976 | src_ptr += (4 * lhs_row_length); | |
| 362 | |||
| 363 | // Move to the next row as we have interleaved all Mr rows. | ||
| 364 | 14976 | lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); | |
| 365 | 14976 | } | |
| 366 | 5376 | } | |
| 367 | |||
| 368 |
2/2✓ Branch 0 taken 7680 times.
✓ Branch 1 taken 7392 times.
|
15072 | for (; row_idx < num_rows; ++row_idx) { |
| 369 | // Find min/max for each channel | ||
| 370 | 7680 | int32_t k_idx = 0; | |
| 371 | 7680 | float16x8_t vmax0 = vmax; | |
| 372 | 7680 | float16x8_t vmin0 = vmin; | |
| 373 | |||
| 374 |
2/2✓ Branch 0 taken 38592 times.
✓ Branch 1 taken 7680 times.
|
46272 | for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { |
| 375 | 38592 | const float16x8_t src0_0 = vld1q_f16(src_ptr + (size_t)k_idx); | |
| 376 | 38592 | vmax0 = vmaxq_f16(vmax0, src0_0); | |
| 377 | 38592 | vmin0 = vminq_f16(vmin0, src0_0); | |
| 378 | 38592 | } | |
| 379 | // Get the max/min | ||
| 380 | 7680 | float16_t max0 = vmaxvq_f16(vmax0); | |
| 381 | 7680 | float16_t min0 = vminvq_f16(vmin0); | |
| 382 | |||
| 383 |
2/2✓ Branch 0 taken 13776 times.
✓ Branch 1 taken 7680 times.
|
21456 | for (; k_idx < (int32_t)k; ++k_idx) { |
| 384 | 13776 | const float16_t src0 = *(src_ptr + (size_t)k_idx); | |
| 385 | 13776 | max0 = vmaxh_f16(src0, max0); | |
| 386 | 13776 | min0 = vminh_f16(src0, min0); | |
| 387 | 13776 | } | |
| 388 | |||
| 389 | // Maximum/minimum int8 values | ||
| 390 | 7680 | const float qmin = (float)INT8_MIN; | |
| 391 | 7680 | const float qmax = (float)INT8_MAX; | |
| 392 | |||
| 393 | 7680 | const float rmin0 = fminf(0.0F, min0); | |
| 394 | 7680 | const float rmax0 = fmaxf(0.0F, max0); | |
| 395 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 7680 times.
|
7680 | const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); |
| 396 | |||
| 397 | // Reciprocal to quantize | ||
| 398 |
1/2✓ Branch 0 taken 7680 times.
✗ Branch 1 not taken.
|
7680 | const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; |
| 399 | |||
| 400 | 7680 | const float descaled_min0 = rmin0 * scale0; | |
| 401 | 7680 | const float descaled_max0 = rmax0 * scale0; | |
| 402 | |||
| 403 | 7680 | const float zero_point_from_min_error0 = qmin + descaled_min0; | |
| 404 | 7680 | const float zero_point_from_max_error0 = qmax + descaled_max0; | |
| 405 | |||
| 406 | 15360 | float zero_point0 = | |
| 407 |
1/2✓ Branch 0 taken 7680 times.
✗ Branch 1 not taken.
|
7680 | zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; |
| 408 | |||
| 409 | 7680 | zero_point0 = fmaxf(zero_point0, qmin); | |
| 410 | 7680 | zero_point0 = fminf(zero_point0, qmax); | |
| 411 | |||
| 412 | // Round to nearest integer | ||
| 413 | 7680 | const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); | |
| 414 | |||
| 415 | 7680 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
| 416 | |||
| 417 | 7680 | uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); | |
| 418 | |||
| 419 | // Quantize the channels | ||
| 420 | 7680 | int32_t block_idx = 0; | |
| 421 | |||
| 422 |
2/2✓ Branch 0 taken 55056 times.
✓ Branch 1 taken 7680 times.
|
62736 | for (; block_idx <= blocks_lim_k; ++block_idx) { |
| 423 | 55056 | const int32_t k_idx_start = block_idx * k_block_len; | |
| 424 | |||
| 425 | 55056 | const float16x8_t src_0 = vld1q_f16(src_ptr + k_idx_start); | |
| 426 | |||
| 427 | // Scale the values | ||
| 428 | 55056 | const float32x4_t v0_f32 = vmulq_n_f32(vcvt_f32_f16(vget_low_f16(src_0)), scale0); | |
| 429 | 55056 | const float32x4_t v1_f32 = vmulq_n_f32(vcvt_high_f32_f16(src_0), scale0); | |
| 430 | 55056 | const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); | |
| 431 | 55056 | const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); | |
| 432 | |||
| 433 | 55056 | const int16x4_t v0_s16 = vqmovn_s32(v0_s32); | |
| 434 | 55056 | const int16x4_t v1_s16 = vqmovn_s32(v1_s32); | |
| 435 | 55056 | int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); | |
| 436 | |||
| 437 | // Add zero points | ||
| 438 | 55056 | int16_t nzp_s16 = (int16_t)nudged_zero_point0; | |
| 439 | 55056 | int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16); | |
| 440 | 55056 | v_s16 = vaddq_s16(v_s16, vnzp_s16); | |
| 441 | 55056 | v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); | |
| 442 | 55056 | v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); | |
| 443 | |||
| 444 | 55056 | int8x8_t v_s8 = vqmovn_s16(v_s16); | |
| 445 | 55056 | vst1_s8((int8_t*)(dst_ptr), v_s8); | |
| 446 | 55056 | dst_ptr += mr * k_block_len * sizeof(int8_t); | |
| 447 | 55056 | } | |
| 448 | |||
| 449 |
2/2✓ Branch 0 taken 18960 times.
✓ Branch 1 taken 7680 times.
|
26640 | for (; block_idx < num_blocks_k_internal; ++block_idx) { |
| 450 | // left over k | ||
| 451 |
2/2✓ Branch 0 taken 97344 times.
✓ Branch 1 taken 18960 times.
|
116304 | for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
| 452 | // Clamp at the last valid k-index | ||
| 453 |
2/2✓ Branch 0 taken 19416 times.
✓ Branch 1 taken 77928 times.
|
97344 | const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); |
| 454 | |||
| 455 | 97344 | const float src0 = (float)(*(src_ptr + k_idx_start)); | |
| 456 | |||
| 457 | // Scale the values | ||
| 458 | 97344 | int32_t d0_s32 = (int32_t)(roundf(src0 * scale0)); | |
| 459 | |||
| 460 | 97344 | d0_s32 = d0_s32 + nudged_zero_point0; | |
| 461 |
2/2✓ Branch 0 taken 97316 times.
✓ Branch 1 taken 28 times.
|
97344 | d0_s32 = KAI_MAX(d0_s32, INT8_MIN); |
| 462 |
2/2✓ Branch 0 taken 96490 times.
✓ Branch 1 taken 854 times.
|
97344 | d0_s32 = KAI_MIN(d0_s32, INT8_MAX); |
| 463 | |||
| 464 | 97344 | *((int8_t*)(dst_ptr)) = (int8_t)d0_s32; | |
| 465 | 97344 | dst_ptr += sizeof(int8_t); | |
| 466 | 97344 | } | |
| 467 | 18960 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
| 468 | 18960 | } | |
| 469 | |||
| 470 | 7680 | dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); | |
| 471 | |||
| 472 | 7680 | dst_ptr += dst_x * kai_num_bytes_per_offset; | |
| 473 | |||
| 474 | // LHS offset at the beginning of the row | ||
| 475 | 7680 | *((int32_t*)(dst_ptr)) = -nudged_zero_point0; | |
| 476 | |||
| 477 | // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier | ||
| 478 | − | KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); | |
| 479 | |||
| 480 | 7680 | dst_ptr += mr * kai_num_bytes_per_offset; | |
| 481 | |||
| 482 | // Store the scale quantization params | ||
| 483 | 7680 | *((float*)(dst_ptr)) = recip_scale0; | |
| 484 | |||
| 485 | 7680 | src_ptr += lhs_row_length; | |
| 486 | |||
| 487 | // Move to the next row if we have interleaved all Mr rows | ||
| 488 |
2/2✓ Branch 0 taken 5664 times.
✓ Branch 1 taken 2016 times.
|
7680 | if ((((row_idx + 1) + m_idx_start) % mr) == 0) { |
| 489 | 2016 | lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); | |
| 490 | 2016 | } | |
| 491 | 7680 | } | |
| 492 | 7392 | } | |
| 493 | |||
| 494 | #endif // Architectural features check. | ||
| 495 |