kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | #if (!defined(__aarch64__) && !defined(_M_ARM64)) | ||
| 7 | #error This file must be compiled for AArch64. | ||
| 8 | #else // Architectural features check. | ||
| 9 | |||
| 10 | #include "kai_lhs_quant_pack_qai8dxp_bf16_neon.h" | ||
| 11 | |||
| 12 | #include <arm_neon.h> | ||
| 13 | #endif | ||
| 14 | #include <float.h> | ||
| 15 | #include <math.h> | ||
| 16 | #include <stddef.h> | ||
| 17 | #include <stdint.h> | ||
| 18 | |||
| 19 | #include "kai/kai_common.h" | ||
| 20 | |||
| 21 | static const size_t kai_num_bytes_per_multiplier = sizeof(float); | ||
| 22 | static const size_t kai_num_bytes_per_offset = sizeof(int32_t); | ||
| 23 | |||
| 24 | 19776 | inline static size_t kai_k_roundedup(size_t k) { | |
| 25 | // Round up k to be a multiple of 32. | ||
| 26 | static const size_t kai_k_multiple_of = 32; | ||
| 27 | 19776 | return kai_roundup(k, kai_k_multiple_of); | |
| 28 | } | ||
| 29 | |||
| 30 | 14880 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr) { | |
| 31 | 14880 | const size_t k_internal = kai_k_roundedup(k); | |
| 32 | |||
| 33 | − | KAI_ASSERT((k_internal % 2) == 0); | |
| 34 | |||
| 35 | 29760 | return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); | |
| 36 | 14880 | } | |
| 37 | |||
| 38 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qai8dxp_bf16_neon(size_t mr) { | |
| 39 | ✗ | return mr; | |
| 40 | } | ||
| 41 | |||
| 42 | 4896 | size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(size_t m_idx, size_t lhs_stride) { | |
| 43 | 4896 | return m_idx * lhs_stride; | |
| 44 | } | ||
| 45 | |||
| 46 | 5088 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon( | |
| 47 | size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | ||
| 48 | 5088 | KAI_UNUSED(kr); | |
| 49 | 5088 | KAI_UNUSED(sr); | |
| 50 | // It always points to the beginning of the row | ||
| 51 | 5088 | return (m_idx / mr) * kai_lhs_packed_stride(k, mr); | |
| 52 | } | ||
| 53 | |||
| 54 | 4896 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
| 55 | 4896 | KAI_UNUSED(kr); | |
| 56 | 4896 | KAI_UNUSED(sr); | |
| 57 | 4896 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
| 58 | |||
| 59 | 9792 | return num_rows * kai_lhs_packed_stride(k, mr); | |
| 60 | 4896 | } | |
| 61 | |||
| 62 | // Note: The lhs parameter type has been changed from float* to void*. | ||
| 63 | // The bfloat16 values (packed in 16 bits) will be converted to float32. | ||
| 64 | 4896 | void kai_run_lhs_quant_pack_qai8dxp_bf16_neon( | |
| 65 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* restrict lhs, | ||
| 66 | size_t lhs_stride, void* restrict lhs_packed) { | ||
| 67 | − | KAI_ASSERT((kr % sr) == 0); | |
| 68 | |||
| 69 | 1/2 ✓ Branch 0 taken 4896 times. ✗ Branch 1 not taken. | 4896 | if (m == 0) { |
| 70 | ✗ | return; | |
| 71 | } | ||
| 72 | |||
| 73 | // Now lhs is assumed to contain bfloat16 values encoded in uint16_t. | ||
| 74 | 4896 | const uint16_t* src_ptr = (uint16_t const*)lhs; | |
| 75 | |||
| 76 | 4896 | const size_t dst_stride = kai_lhs_packed_stride(k, mr); | |
| 77 | 4896 | const size_t k_internal = kai_k_roundedup(k); | |
| 78 | 4896 | const int32_t k_block_len = (int32_t)(kr / sr); | |
| 79 | − | KAI_ASSERT(k_block_len == 8); | |
| 80 | |||
| 81 | 4896 | const int32_t num_blocks_k = (int32_t)(k / k_block_len); | |
| 82 | 4896 | const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len); | |
| 83 | |||
| 84 | 4896 | size_t row_idx = 0; | |
| 85 | |||
| 86 | 2/2 ✓ Branch 0 taken 2448 times. ✓ Branch 1 taken 2448 times. | 4896 | if (mr == 4) { |
| 87 | 2/2 ✓ Branch 0 taken 11352 times. ✓ Branch 1 taken 2448 times. | 13800 | for (; row_idx + 3 < m; row_idx += 4) { |
| 88 | 11352 | float max0 = -FLT_MAX; | |
| 89 | 11352 | float min0 = FLT_MAX; | |
| 90 | 11352 | float max1 = -FLT_MAX; | |
| 91 | 11352 | float min1 = FLT_MAX; | |
| 92 | 11352 | float max2 = -FLT_MAX; | |
| 93 | 11352 | float min2 = FLT_MAX; | |
| 94 | 11352 | float max3 = -FLT_MAX; | |
| 95 | 11352 | float min3 = FLT_MAX; | |
| 96 | |||
| 97 | // Find min/max for each channel | ||
| 98 | 11352 | int32_t k_idx = 0; | |
| 99 | 11352 | float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); | |
| 100 | 11352 | float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); | |
| 101 | 11352 | float32x4_t vmax1 = vmax0; | |
| 102 | 11352 | float32x4_t vmin1 = vmin0; | |
| 103 | 11352 | float32x4_t vmax2 = vmax0; | |
| 104 | 11352 | float32x4_t vmin2 = vmin0; | |
| 105 | 11352 | float32x4_t vmax3 = vmax0; | |
| 106 | 11352 | float32x4_t vmin3 = vmin0; | |
| 107 | 11352 | const uint16x8_t zero = vdupq_n_u16(0); | |
| 108 | // Process 8 bfloat16 values per iteration. | ||
| 109 | 2/2 ✓ Branch 0 taken 105264 times. ✓ Branch 1 taken 11352 times. | 116616 | for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { |
| 110 | // Load eight bfloat16 values. | ||
| 111 | 105264 | const uint16x8_t bf16_vec_0 = vld1q_u16(src_ptr + k_idx); | |
| 112 | 105264 | const uint16x8_t bf16_vec_1 = vld1q_u16(src_ptr + k_idx + (lhs_stride / sizeof(uint16_t))); | |
| 113 | 105264 | const uint16x8_t bf16_vec_2 = vld1q_u16(src_ptr + k_idx + (2 * (lhs_stride / sizeof(uint16_t)))); | |
| 114 | 105264 | const uint16x8_t bf16_vec_3 = vld1q_u16(src_ptr + k_idx + (3 * (lhs_stride / sizeof(uint16_t)))); | |
| 115 | |||
| 116 | 105264 | const uint16x8_t bf16_vec1_0 = vzip1q_u16(zero, bf16_vec_0); | |
| 117 | 105264 | const uint16x8_t bf16_vec2_0 = vzip2q_u16(zero, bf16_vec_0); | |
| 118 | 105264 | const uint16x8_t bf16_vec1_1 = vzip1q_u16(zero, bf16_vec_1); | |
| 119 | 105264 | const uint16x8_t bf16_vec2_1 = vzip2q_u16(zero, bf16_vec_1); | |
| 120 | 105264 | const uint16x8_t bf16_vec1_2 = vzip1q_u16(zero, bf16_vec_2); | |
| 121 | 105264 | const uint16x8_t bf16_vec2_2 = vzip2q_u16(zero, bf16_vec_2); | |
| 122 | 105264 | const uint16x8_t bf16_vec1_3 = vzip1q_u16(zero, bf16_vec_3); | |
| 123 | 105264 | const uint16x8_t bf16_vec2_3 = vzip2q_u16(zero, bf16_vec_3); | |
| 124 | |||
| 125 | 105264 | const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1_0); | |
| 126 | 105264 | const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2_0); | |
| 127 | 105264 | const float32x4_t src1_0 = vreinterpretq_f32_u16(bf16_vec1_1); | |
| 128 | 105264 | const float32x4_t src1_1 = vreinterpretq_f32_u16(bf16_vec2_1); | |
| 129 | 105264 | const float32x4_t src2_0 = vreinterpretq_f32_u16(bf16_vec1_2); | |
| 130 | 105264 | const float32x4_t src2_1 = vreinterpretq_f32_u16(bf16_vec2_2); | |
| 131 | 105264 | const float32x4_t src3_0 = vreinterpretq_f32_u16(bf16_vec1_3); | |
| 132 | 105264 | const float32x4_t src3_1 = vreinterpretq_f32_u16(bf16_vec2_3); | |
| 133 | |||
| 134 | // Calculate the maximum | ||
| 135 | 105264 | vmax0 = vmaxq_f32(src0_0, vmax0); | |
| 136 | 105264 | vmax0 = vmaxq_f32(vmax0, src0_1); | |
| 137 | 105264 | vmax1 = vmaxq_f32(src1_0, vmax1); | |
| 138 | 105264 | vmax1 = vmaxq_f32(vmax1, src1_1); | |
| 139 | 105264 | vmax2 = vmaxq_f32(src2_0, vmax2); | |
| 140 | 105264 | vmax2 = vmaxq_f32(vmax2, src2_1); | |
| 141 | 105264 | vmax3 = vmaxq_f32(src3_0, vmax3); | |
| 142 | 105264 | vmax3 = vmaxq_f32(vmax3, src3_1); | |
| 143 | |||
| 144 | // Calculate the minimum | ||
| 145 | 105264 | vmin0 = vminq_f32(src0_0, vmin0); | |
| 146 | 105264 | vmin0 = vminq_f32(vmin0, src0_1); | |
| 147 | 105264 | vmin1 = vminq_f32(src1_0, vmin1); | |
| 148 | 105264 | vmin1 = vminq_f32(vmin1, src1_1); | |
| 149 | 105264 | vmin2 = vminq_f32(src2_0, vmin2); | |
| 150 | 105264 | vmin2 = vminq_f32(vmin2, src2_1); | |
| 151 | 105264 | vmin3 = vminq_f32(src3_0, vmin3); | |
| 152 | 105264 | vmin3 = vminq_f32(vmin3, src3_1); | |
| 153 | 105264 | } | |
| 154 | // Get the max/min scalar values. | ||
| 155 | 11352 | max0 = vmaxvq_f32(vmax0); | |
| 156 | 11352 | min0 = vminvq_f32(vmin0); | |
| 157 | 11352 | max1 = vmaxvq_f32(vmax1); | |
| 158 | 11352 | min1 = vminvq_f32(vmin1); | |
| 159 | 11352 | max2 = vmaxvq_f32(vmax2); | |
| 160 | 11352 | min2 = vminvq_f32(vmin2); | |
| 161 | 11352 | max3 = vmaxvq_f32(vmax3); | |
| 162 | 11352 | min3 = vminvq_f32(vmin3); | |
| 163 | // Process leftover elements with a scalar loop. | ||
| 164 | 2/2 ✓ Branch 0 taken 21168 times. ✓ Branch 1 taken 11352 times. | 32520 | for (; k_idx < (int32_t)k; ++k_idx) { |
| 165 | 21168 | const float src0 = kai_cast_f32_bf16(*(src_ptr + k_idx)); | |
| 166 | 21168 | max0 = fmaxf(src0, max0); | |
| 167 | 21168 | min0 = fminf(src0, min0); | |
| 168 | 21168 | const float src1 = kai_cast_f32_bf16(*(src_ptr + k_idx + (lhs_stride / sizeof(uint16_t)))); | |
| 169 | 21168 | max1 = fmaxf(src1, max1); | |
| 170 | 21168 | min1 = fminf(src1, min1); | |
| 171 | 21168 | const float src2 = kai_cast_f32_bf16(*(src_ptr + k_idx + (2 * (lhs_stride / sizeof(uint16_t))))); | |
| 172 | 21168 | max2 = fmaxf(src2, max2); | |
| 173 | 21168 | min2 = fminf(src2, min2); | |
| 174 | 21168 | const float src3 = kai_cast_f32_bf16(*(src_ptr + k_idx + (3 * (lhs_stride / sizeof(uint16_t))))); | |
| 175 | 21168 | max3 = fmaxf(src3, max3); | |
| 176 | 21168 | min3 = fminf(src3, min3); | |
| 177 | 21168 | } | |
| 178 | |||
| 179 | // Maximum/minimum int8 values | ||
| 180 | 11352 | const float qmin = (float)INT8_MIN; | |
| 181 | 11352 | const float qmax = (float)INT8_MAX; | |
| 182 | |||
| 183 | 11352 | const float rmin0 = fminf(0.0F, min0); | |
| 184 | 11352 | const float rmax0 = fmaxf(0.0F, max0); | |
| 185 | 1/2 ✗ Branch 0 not taken. ✓ Branch 1 taken 11352 times. | 11352 | const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); |
| 186 | 11352 | const float rmin1 = fminf(0.0F, min1); | |
| 187 | 11352 | const float rmax1 = fmaxf(0.0F, max1); | |
| 188 | 1/2 ✗ Branch 0 not taken. ✓ Branch 1 taken 11352 times. | 11352 | const float scale1 = rmin1 == rmax1 ? 1.F : (qmax - qmin) / (rmax1 - rmin1); |
| 189 | 11352 | const float rmin2 = fminf(0.0F, min2); | |
| 190 | 11352 | const float rmax2 = fmaxf(0.0F, max2); | |
| 191 | 1/2 ✗ Branch 0 not taken. ✓ Branch 1 taken 11352 times. | 11352 | const float scale2 = rmin2 == rmax2 ? 1.F : (qmax - qmin) / (rmax2 - rmin2); |
| 192 | 11352 | const float rmin3 = fminf(0.0F, min3); | |
| 193 | 11352 | const float rmax3 = fmaxf(0.0F, max3); | |
| 194 | 1/2 ✗ Branch 0 not taken. ✓ Branch 1 taken 11352 times. | 11352 | const float scale3 = rmin3 == rmax3 ? 1.F : (qmax - qmin) / (rmax3 - rmin3); |
| 195 | |||
| 196 | // Reciprocal to quantize | ||
| 197 | 1/2 ✓ Branch 0 taken 11352 times. ✗ Branch 1 not taken. | 11352 | const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; |
| 198 | 1/2 ✓ Branch 0 taken 11352 times. ✗ Branch 1 not taken. | 11352 | const float recip_scale1 = scale1 ? 1.0F / scale1 : 0.0F; |
| 199 | 1/2 ✓ Branch 0 taken 11352 times. ✗ Branch 1 not taken. | 11352 | const float recip_scale2 = scale2 ? 1.0F / scale2 : 0.0F; |
| 200 | 1/2 ✓ Branch 0 taken 11352 times. ✗ Branch 1 not taken. | 11352 | const float recip_scale3 = scale3 ? 1.0F / scale3 : 0.0F; |
| 201 | |||
| 202 | 11352 | const float descaled_min0 = rmin0 * scale0; | |
| 203 | 11352 | const float descaled_max0 = rmax0 * scale0; | |
| 204 | 11352 | const float descaled_min1 = rmin1 * scale1; | |
| 205 | 11352 | const float descaled_max1 = rmax1 * scale1; | |
| 206 | 11352 | const float descaled_min2 = rmin2 * scale2; | |
| 207 | 11352 | const float descaled_max2 = rmax2 * scale2; | |
| 208 | 11352 | const float descaled_min3 = rmin3 * scale3; | |
| 209 | 11352 | const float descaled_max3 = rmax3 * scale3; | |
| 210 | |||
| 211 | 11352 | const float zero_point_from_min_error0 = qmin + descaled_min0; | |
| 212 | 11352 | const float zero_point_from_max_error0 = qmax + descaled_max0; | |
| 213 | 11352 | const float zero_point_from_min_error1 = qmin + descaled_min1; | |
| 214 | 11352 | const float zero_point_from_max_error1 = qmax + descaled_max1; | |
| 215 | 11352 | const float zero_point_from_min_error2 = qmin + descaled_min2; | |
| 216 | 11352 | const float zero_point_from_max_error2 = qmax + descaled_max2; | |
| 217 | 11352 | const float zero_point_from_min_error3 = qmin + descaled_min3; | |
| 218 | 11352 | const float zero_point_from_max_error3 = qmax + descaled_max3; | |
| 219 | |||
| 220 | 1/2 ✓ Branch 0 taken 11352 times. ✗ Branch 1 not taken. | 11352 | float zero_point0 = (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 |
| 221 | ✗ | : qmax - descaled_max0; | |
| 222 | 1/2 ✓ Branch 0 taken 11352 times. ✗ Branch 1 not taken. | 11352 | float zero_point1 = (zero_point_from_min_error1 + zero_point_from_max_error1 > 0) ? qmin - descaled_min1 |
| 223 | ✗ | : qmax - descaled_max1; | |
| 224 | 1/2 ✓ Branch 0 taken 11352 times. ✗ Branch 1 not taken. | 11352 | float zero_point2 = (zero_point_from_min_error2 + zero_point_from_max_error2 > 0) ? qmin - descaled_min2 |
| 225 | ✗ | : qmax - descaled_max2; | |
| 226 | 1/2 ✓ Branch 0 taken 11352 times. ✗ Branch 1 not taken. | 11352 | float zero_point3 = (zero_point_from_min_error3 + zero_point_from_max_error3 > 0) ? qmin - descaled_min3 |
| 227 | ✗ | : qmax - descaled_max3; | |
| 228 | |||
| 229 | 11352 | zero_point0 = fmaxf(zero_point0, qmin); | |
| 230 | 11352 | zero_point0 = fminf(zero_point0, qmax); | |
| 231 | 11352 | zero_point1 = fmaxf(zero_point1, qmin); | |
| 232 | 11352 | zero_point1 = fminf(zero_point1, qmax); | |
| 233 | 11352 | zero_point2 = fmaxf(zero_point2, qmin); | |
| 234 | 11352 | zero_point2 = fminf(zero_point2, qmax); | |
| 235 | 11352 | zero_point3 = fmaxf(zero_point3, qmin); | |
| 236 | 11352 | zero_point3 = fminf(zero_point3, qmax); | |
| 237 | |||
| 238 | // Round to nearest integer | ||
| 239 | 11352 | const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); | |
| 240 | 11352 | const int32_t nudged_zero_point1 = (int32_t)rintf(zero_point1); | |
| 241 | 11352 | const int32_t nudged_zero_point2 = (int32_t)rintf(zero_point2); | |
| 242 | 11352 | const int32_t nudged_zero_point3 = (int32_t)rintf(zero_point3); | |
| 243 | |||
| 244 | 11352 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
| 245 | |||
| 246 | 11352 | uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len); | |
| 247 | |||
| 248 | // Quantize the channels | ||
| 249 | 11352 | int32_t block_idx = 0; | |
| 250 | |||
| 251 | 2/2 ✓ Branch 0 taken 105264 times. ✓ Branch 1 taken 11352 times. | 116616 | for (; block_idx < num_blocks_k; ++block_idx) { |
| 252 | // Clamp at the last valid k-index | ||
| 253 | 105264 | const int32_t k_idx_start = block_idx * k_block_len; | |
| 254 | |||
| 255 | // Load eight bfloat16 values and convert them to float32. | ||
| 256 | 105264 | const uint16x8_t bf16_vec_0 = vld1q_u16(src_ptr + k_idx_start); | |
| 257 | 105264 | const uint16x8_t bf16_vec_1 = vld1q_u16(src_ptr + k_idx_start + (lhs_stride / sizeof(uint16_t))); | |
| 258 | 105264 | const uint16x8_t bf16_vec_2 = vld1q_u16(src_ptr + k_idx_start + (2 * (lhs_stride / sizeof(uint16_t)))); | |
| 259 | 105264 | const uint16x8_t bf16_vec_3 = vld1q_u16(src_ptr + k_idx_start + (3 * (lhs_stride / sizeof(uint16_t)))); | |
| 260 | 105264 | const uint16x8_t bf16_vec1_0 = vzip1q_u16(zero, bf16_vec_0); | |
| 261 | 105264 | const uint16x8_t bf16_vec2_0 = vzip2q_u16(zero, bf16_vec_0); | |
| 262 | 105264 | const uint16x8_t bf16_vec1_1 = vzip1q_u16(zero, bf16_vec_1); | |
| 263 | 105264 | const uint16x8_t bf16_vec2_1 = vzip2q_u16(zero, bf16_vec_1); | |
| 264 | 105264 | const uint16x8_t bf16_vec1_2 = vzip1q_u16(zero, bf16_vec_2); | |
| 265 | 105264 | const uint16x8_t bf16_vec2_2 = vzip2q_u16(zero, bf16_vec_2); | |
| 266 | 105264 | const uint16x8_t bf16_vec1_3 = vzip1q_u16(zero, bf16_vec_3); | |
| 267 | 105264 | const uint16x8_t bf16_vec2_3 = vzip2q_u16(zero, bf16_vec_3); | |
| 268 | 105264 | const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1_0); | |
| 269 | 105264 | const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2_0); | |
| 270 | 105264 | const float32x4_t src1_0 = vreinterpretq_f32_u16(bf16_vec1_1); | |
| 271 | 105264 | const float32x4_t src1_1 = vreinterpretq_f32_u16(bf16_vec2_1); | |
| 272 | 105264 | const float32x4_t src2_0 = vreinterpretq_f32_u16(bf16_vec1_2); | |
| 273 | 105264 | const float32x4_t src2_1 = vreinterpretq_f32_u16(bf16_vec2_2); | |
| 274 | 105264 | const float32x4_t src3_0 = vreinterpretq_f32_u16(bf16_vec1_3); | |
| 275 | 105264 | const float32x4_t src3_1 = vreinterpretq_f32_u16(bf16_vec2_3); | |
| 276 | |||
| 277 | // Scale the values. | ||
| 278 | 105264 | const int16x4_t v0_0 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src0_0, scale0))); | |
| 279 | 105264 | const int16x4_t v1_0 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src0_1, scale0))); | |
| 280 | 105264 | const int16x4_t v0_1 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src1_0, scale1))); | |
| 281 | 105264 | const int16x4_t v1_1 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src1_1, scale1))); | |
| 282 | 105264 | const int16x4_t v0_2 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src2_0, scale2))); | |
| 283 | 105264 | const int16x4_t v1_2 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src2_1, scale2))); | |
| 284 | 105264 | const int16x4_t v0_3 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src3_0, scale3))); | |
| 285 | 105264 | const int16x4_t v1_3 = vqmovn_s32(vcvtnq_s32_f32(vmulq_n_f32(src3_1, scale3))); | |
| 286 | |||
| 287 | 105264 | int16x8_t v0_s16 = vcombine_s16(v0_0, v1_0); | |
| 288 | 105264 | int16x8_t v1_s16 = vcombine_s16(v0_1, v1_1); | |
| 289 | 105264 | int16x8_t v2_s16 = vcombine_s16(v0_2, v1_2); | |
| 290 | 105264 | int16x8_t v3_s16 = vcombine_s16(v0_3, v1_3); | |
| 291 | |||
| 292 | // Add zero points. | ||
| 293 | 105264 | const int16x8_t vnzp0 = vdupq_n_s16((int16_t)nudged_zero_point0); | |
| 294 | 105264 | const int16x8_t vnzp1 = vdupq_n_s16((int16_t)nudged_zero_point1); | |
| 295 | 105264 | const int16x8_t vnzp2 = vdupq_n_s16((int16_t)nudged_zero_point2); | |
| 296 | 105264 | const int16x8_t vnzp3 = vdupq_n_s16((int16_t)nudged_zero_point3); | |
| 297 | |||
| 298 | 105264 | v0_s16 = vaddq_s16(v0_s16, vnzp0); | |
| 299 | 105264 | v0_s16 = vmaxq_s16(v0_s16, vdupq_n_s16(INT8_MIN)); | |
| 300 | 105264 | v0_s16 = vminq_s16(v0_s16, vdupq_n_s16(INT8_MAX)); | |
| 301 | 105264 | v1_s16 = vaddq_s16(v1_s16, vnzp1); | |
| 302 | 105264 | v1_s16 = vmaxq_s16(v1_s16, vdupq_n_s16(INT8_MIN)); | |
| 303 | 105264 | v1_s16 = vminq_s16(v1_s16, vdupq_n_s16(INT8_MAX)); | |
| 304 | 105264 | v2_s16 = vaddq_s16(v2_s16, vnzp2); | |
| 305 | 105264 | v2_s16 = vmaxq_s16(v2_s16, vdupq_n_s16(INT8_MIN)); | |
| 306 | 105264 | v2_s16 = vminq_s16(v2_s16, vdupq_n_s16(INT8_MAX)); | |
| 307 | 105264 | v3_s16 = vaddq_s16(v3_s16, vnzp3); | |
| 308 | 105264 | v3_s16 = vmaxq_s16(v3_s16, vdupq_n_s16(INT8_MIN)); | |
| 309 | 105264 | v3_s16 = vminq_s16(v3_s16, vdupq_n_s16(INT8_MAX)); | |
| 310 | |||
| 311 | 105264 | const int8x8_t v0_s8 = vqmovn_s16(v0_s16); | |
| 312 | 105264 | const int8x8_t v1_s8 = vqmovn_s16(v1_s16); | |
| 313 | 105264 | const int8x8_t v2_s8 = vqmovn_s16(v2_s16); | |
| 314 | 105264 | const int8x8_t v3_s8 = vqmovn_s16(v3_s16); | |
| 315 | |||
| 316 | 105264 | vst1_s8((int8_t*)(dst_ptr), v0_s8); | |
| 317 | 105264 | vst1_s8((int8_t*)(dst_ptr + sizeof(int8x8_t)), v1_s8); | |
| 318 | 105264 | vst1_s8((int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)), v2_s8); | |
| 319 | 105264 | vst1_s8((int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)), v3_s8); | |
| 320 | 105264 | dst_ptr += 4 * sizeof(int8x8_t); | |
| 321 | 105264 | } | |
| 322 | |||
| 323 | 2/2 ✓ Branch 0 taken 11760 times. ✓ Branch 1 taken 11352 times. | 23112 | for (; block_idx < num_blocks_k_internal; ++block_idx) { |
| 324 | // Left over k | ||
| 325 | 2/2 ✓ Branch 0 taken 94080 times. ✓ Branch 1 taken 11760 times. | 105840 | for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
| 326 | // Clamp at the last valid k-index. | ||
| 327 | 2/2 ✓ Branch 0 taken 16464 times. ✓ Branch 1 taken 77616 times. | 94080 | const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); |
| 328 | |||
| 329 | 94080 | const float src0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); | |
| 330 | 94080 | const float src1 = kai_cast_f32_bf16(*(src_ptr + k_idx_start + (lhs_stride / sizeof(uint16_t)))); | |
| 331 | 188160 | const float src2 = | |
| 332 | 94080 | kai_cast_f32_bf16(*(src_ptr + k_idx_start + (2 * (lhs_stride / sizeof(uint16_t))))); | |
| 333 | 188160 | const float src3 = | |
| 334 | 94080 | kai_cast_f32_bf16(*(src_ptr + k_idx_start + (3 * (lhs_stride / sizeof(uint16_t))))); | |
| 335 | |||
| 336 | // Scale the value. | ||
| 337 | 94080 | int32_t v0_s32 = (int32_t)(roundf(src0 * scale0)); | |
| 338 | 94080 | int32_t v1_s32 = (int32_t)(roundf(src1 * scale1)); | |
| 339 | 94080 | int32_t v2_s32 = (int32_t)(roundf(src2 * scale2)); | |
| 340 | 94080 | int32_t v3_s32 = (int32_t)(roundf(src3 * scale3)); | |
| 341 | |||
| 342 | 94080 | v0_s32 = v0_s32 + nudged_zero_point0; | |
| 343 | 2/2 ✓ Branch 0 taken 94052 times. ✓ Branch 1 taken 28 times. | 94080 | v0_s32 = KAI_MAX(v0_s32, INT8_MIN); |
| 344 | 2/2 ✓ Branch 0 taken 92864 times. ✓ Branch 1 taken 1216 times. | 94080 | v0_s32 = KAI_MIN(v0_s32, INT8_MAX); |
| 345 | |||
| 346 | 94080 | v1_s32 = v1_s32 + nudged_zero_point1; | |
| 347 | 2/2 ✓ Branch 0 taken 93996 times. ✓ Branch 1 taken 84 times. | 94080 | v1_s32 = KAI_MAX(v1_s32, INT8_MIN); |
| 348 | 2/2 ✓ Branch 0 taken 92244 times. ✓ Branch 1 taken 1836 times. | 94080 | v1_s32 = KAI_MIN(v1_s32, INT8_MAX); |
| 349 | |||
| 350 | 94080 | v2_s32 = v2_s32 + nudged_zero_point2; | |
| 351 | 2/2 ✓ Branch 0 taken 93988 times. ✓ Branch 1 taken 92 times. | 94080 | v2_s32 = KAI_MAX(v2_s32, INT8_MIN); |
| 352 | 2/2 ✓ Branch 0 taken 92816 times. ✓ Branch 1 taken 1264 times. | 94080 | v2_s32 = KAI_MIN(v2_s32, INT8_MAX); |
| 353 | |||
| 354 | 94080 | v3_s32 = v3_s32 + nudged_zero_point3; | |
| 355 | 1/2 ✓ Branch 0 taken 94080 times. ✗ Branch 1 not taken. | 94080 | v3_s32 = KAI_MAX(v3_s32, INT8_MIN); |
| 356 | 2/2 ✓ Branch 0 taken 92932 times. ✓ Branch 1 taken 1148 times. | 94080 | v3_s32 = KAI_MIN(v3_s32, INT8_MAX); |
| 357 | |||
| 358 | 94080 | *(int8_t*)dst_ptr = (int8_t)v0_s32; | |
| 359 | 94080 | *(int8_t*)(dst_ptr + sizeof(int8x8_t)) = (int8_t)v1_s32; | |
| 360 | 94080 | *(int8_t*)(dst_ptr + 2 * sizeof(int8x8_t)) = (int8_t)v2_s32; | |
| 361 | 94080 | *(int8_t*)(dst_ptr + 3 * sizeof(int8x8_t)) = (int8_t)v3_s32; | |
| 362 | |||
| 363 | 94080 | dst_ptr += sizeof(int8_t); | |
| 364 | 94080 | } | |
| 365 | 11760 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
| 366 | 11760 | } | |
| 367 | |||
| 368 | 11352 | uint8_t* dst_base = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); | |
| 369 | |||
| 370 | 11352 | dst_ptr = dst_base + dst_x * kai_num_bytes_per_offset; | |
| 371 | |||
| 372 | // LHS offset at the beginning of the row. | ||
| 373 | 11352 | *((int32_t*)(dst_ptr)) = -nudged_zero_point0; | |
| 374 | 11352 | *((int32_t*)(dst_ptr + kai_num_bytes_per_offset)) = -nudged_zero_point1; | |
| 375 | 11352 | *((int32_t*)(dst_ptr + 2 * kai_num_bytes_per_offset)) = -nudged_zero_point2; | |
| 376 | 11352 | *((int32_t*)(dst_ptr + 3 * kai_num_bytes_per_offset)) = -nudged_zero_point3; | |
| 377 | |||
| 378 | // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. | ||
| 379 | − | KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); | |
| 380 | |||
| 381 | 11352 | dst_ptr += mr * kai_num_bytes_per_offset; | |
| 382 | |||
| 383 | // Store the scale quantization params. | ||
| 384 | 11352 | *((float*)(dst_ptr)) = recip_scale0; | |
| 385 | 11352 | *((float*)(dst_ptr + kai_num_bytes_per_multiplier)) = recip_scale1; | |
| 386 | 11352 | *((float*)(dst_ptr + 2 * kai_num_bytes_per_multiplier)) = recip_scale2; | |
| 387 | 11352 | *((float*)(dst_ptr + 3 * kai_num_bytes_per_multiplier)) = recip_scale3; | |
| 388 | |||
| 389 | // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each). | ||
| 390 | 11352 | src_ptr += (4 * lhs_stride / sizeof(uint16_t)); | |
| 391 | |||
| 392 | // Move to the next row as we have interleaved all Mr rows. | ||
| 393 | 11352 | lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); | |
| 394 | 11352 | } | |
| 395 | 2448 | } | |
| 396 | |||
| 397 | 2/2 ✓ Branch 0 taken 48240 times. ✓ Branch 1 taken 4896 times. | 53136 | for (; row_idx < m; ++row_idx) { |
| 398 | 48240 | float max0 = -FLT_MAX; | |
| 399 | 48240 | float min0 = FLT_MAX; | |
| 400 | |||
| 401 | // Find min/max for each channel | ||
| 402 | 48240 | int32_t k_idx = 0; | |
| 403 | 48240 | float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); | |
| 404 | 48240 | float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); | |
| 405 | 48240 | const uint16x8_t zero = vdupq_n_u16(0); | |
| 406 | // Process 8 bfloat16 values per iteration. | ||
| 407 | 2/2 ✓ Branch 0 taken 431304 times. ✓ Branch 1 taken 48240 times. | 479544 | for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { |
| 408 | // Load eight bfloat16 values. | ||
| 409 | 431304 | const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx); | |
| 410 | 431304 | const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); | |
| 411 | 431304 | const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); | |
| 412 | 431304 | const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); | |
| 413 | 431304 | const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); | |
| 414 | |||
| 415 | // Calculate the maximum | ||
| 416 | 431304 | vmax0 = vmaxq_f32(src0_0, vmax0); | |
| 417 | 431304 | vmax0 = vmaxq_f32(vmax0, src0_1); | |
| 418 | |||
| 419 | // Calculate the minimum | ||
| 420 | 431304 | vmin0 = vminq_f32(src0_0, vmin0); | |
| 421 | 431304 | vmin0 = vminq_f32(vmin0, src0_1); | |
| 422 | 431304 | } | |
| 423 | // Get the max/min scalar values. | ||
| 424 | 48240 | max0 = vmaxvq_f32(vmax0); | |
| 425 | 48240 | min0 = vminvq_f32(vmin0); | |
| 426 | // Process leftover elements with a scalar loop. | ||
| 427 | 2/2 ✓ Branch 0 taken 86616 times. ✓ Branch 1 taken 48240 times. | 134856 | for (; k_idx < (int32_t)k; ++k_idx) { |
| 428 | 86616 | const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx)); | |
| 429 | 86616 | max0 = fmaxf(src0_0, max0); | |
| 430 | 86616 | min0 = fminf(src0_0, min0); | |
| 431 | 86616 | } | |
| 432 | |||
| 433 | // Maximum/minimum int8 values | ||
| 434 | 48240 | const float qmin = (float)INT8_MIN; | |
| 435 | 48240 | const float qmax = (float)INT8_MAX; | |
| 436 | |||
| 437 | 48240 | const float rmin0 = fminf(0.0F, min0); | |
| 438 | 48240 | const float rmax0 = fmaxf(0.0F, max0); | |
| 439 | 1/2 ✗ Branch 0 not taken. ✓ Branch 1 taken 48240 times. | 48240 | const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); |
| 440 | |||
| 441 | // Reciprocal to quantize | ||
| 442 | 1/2 ✓ Branch 0 taken 48240 times. ✗ Branch 1 not taken. | 48240 | const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; |
| 443 | |||
| 444 | 48240 | const float descaled_min0 = rmin0 * scale0; | |
| 445 | 48240 | const float descaled_max0 = rmax0 * scale0; | |
| 446 | |||
| 447 | 48240 | const float zero_point_from_min_error0 = qmin + descaled_min0; | |
| 448 | 48240 | const float zero_point_from_max_error0 = qmax + descaled_max0; | |
| 449 | |||
| 450 | 96480 | float zero_point0 = | |
| 451 | 1/2 ✓ Branch 0 taken 48240 times. ✗ Branch 1 not taken. | 48240 | (zero_point_from_min_error0 + zero_point_from_max_error0 > 0) ? qmin - descaled_min0 : qmax - descaled_max0; |
| 452 | |||
| 453 | 48240 | zero_point0 = fmaxf(zero_point0, qmin); | |
| 454 | 48240 | zero_point0 = fminf(zero_point0, qmax); | |
| 455 | |||
| 456 | // Round to nearest integer | ||
| 457 | 48240 | const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); | |
| 458 | |||
| 459 | 48240 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
| 460 | |||
| 461 | 48240 | uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); | |
| 462 | |||
| 463 | // Quantize the channels | ||
| 464 | 48240 | int32_t block_idx = 0; | |
| 465 | |||
| 466 | 2/2 ✓ Branch 0 taken 431304 times. ✓ Branch 1 taken 48240 times. | 479544 | for (; block_idx < num_blocks_k; ++block_idx) { |
| 467 | // Clamp at the last valid k-index | ||
| 468 | 431304 | const int32_t k_idx_start = block_idx * k_block_len; | |
| 469 | |||
| 470 | // Load eight bfloat16 values and convert them to float32. | ||
| 471 | 431304 | const uint16x8_t bf16_vec = vld1q_u16(src_ptr + k_idx_start); | |
| 472 | 431304 | const uint16x8_t bf16_vec1 = vzip1q_u16(zero, bf16_vec); | |
| 473 | 431304 | const uint16x8_t bf16_vec2 = vzip2q_u16(zero, bf16_vec); | |
| 474 | 431304 | const float32x4_t src0_0 = vreinterpretq_f32_u16(bf16_vec1); | |
| 475 | 431304 | const float32x4_t src0_1 = vreinterpretq_f32_u16(bf16_vec2); | |
| 476 | |||
| 477 | // Scale the values. | ||
| 478 | 431304 | const float32x4_t v0_f32 = vmulq_n_f32(src0_0, scale0); | |
| 479 | 431304 | const float32x4_t v1_f32 = vmulq_n_f32(src0_1, scale0); | |
| 480 | 431304 | const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); | |
| 481 | 431304 | const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); | |
| 482 | |||
| 483 | 431304 | const int16x4_t v0_s16 = vqmovn_s32(v0_s32); | |
| 484 | 431304 | const int16x4_t v1_s16 = vqmovn_s32(v1_s32); | |
| 485 | 431304 | int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); | |
| 486 | |||
| 487 | // Add zero points. | ||
| 488 | 431304 | int16_t nzp_s16 = (int16_t)nudged_zero_point0; | |
| 489 | 431304 | int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16); | |
| 490 | 431304 | v_s16 = vaddq_s16(v_s16, vnzp_s16); | |
| 491 | 431304 | v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); | |
| 492 | 431304 | v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); | |
| 493 | |||
| 494 | 431304 | const int8x8_t v0_s8 = vqmovn_s16(v_s16); | |
| 495 | 431304 | vst1_s8((int8_t*)(dst_ptr), v0_s8); | |
| 496 | 431304 | dst_ptr += 8 * sizeof(int8_t); | |
| 497 | 431304 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
| 498 | 431304 | } | |
| 499 | |||
| 500 | 2/2 ✓ Branch 0 taken 48120 times. ✓ Branch 1 taken 48240 times. | 96360 | for (; block_idx < num_blocks_k_internal; ++block_idx) { |
| 501 | // Left over k | ||
| 502 | 2/2 ✓ Branch 0 taken 384960 times. ✓ Branch 1 taken 48120 times. | 433080 | for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
| 503 | // Clamp at the last valid k-index. | ||
| 504 | 2/2 ✓ Branch 0 taken 67368 times. ✓ Branch 1 taken 317592 times. | 384960 | const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1); |
| 505 | |||
| 506 | 384960 | const float src0_0 = kai_cast_f32_bf16(*(src_ptr + k_idx_start)); | |
| 507 | |||
| 508 | // Scale the value. | ||
| 509 | 384960 | int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); | |
| 510 | |||
| 511 | 384960 | v0_s32 = v0_s32 + nudged_zero_point0; | |
| 512 | 2/2 ✓ Branch 0 taken 384756 times. ✓ Branch 1 taken 204 times. | 384960 | v0_s32 = KAI_MAX(v0_s32, INT8_MIN); |
| 513 | 2/2 ✓ Branch 0 taken 379420 times. ✓ Branch 1 taken 5540 times. | 384960 | v0_s32 = KAI_MIN(v0_s32, INT8_MAX); |
| 514 | |||
| 515 | 384960 | *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; | |
| 516 | 384960 | dst_ptr += sizeof(int8_t); | |
| 517 | 384960 | } | |
| 518 | 48120 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
| 519 | 48120 | } | |
| 520 | |||
| 521 | 48240 | dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); | |
| 522 | |||
| 523 | 48240 | dst_ptr += dst_x * kai_num_bytes_per_offset; | |
| 524 | |||
| 525 | // LHS offset at the beginning of the row. | ||
| 526 | 48240 | *((int32_t*)(dst_ptr)) = -nudged_zero_point0; | |
| 527 | |||
| 528 | // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier. | ||
| 529 | − | KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); | |
| 530 | |||
| 531 | 48240 | dst_ptr += mr * kai_num_bytes_per_offset; | |
| 532 | |||
| 533 | // Store the scale quantization params. | ||
| 534 | 48240 | *((float*)(dst_ptr)) = recip_scale0; | |
| 535 | |||
| 536 | // Update src_ptr. Note: now lhs contains bfloat16 values (2 bytes each). | ||
| 537 | 48240 | src_ptr += (lhs_stride / sizeof(uint16_t)); | |
| 538 | |||
| 539 | // Move to the next row if we have interleaved all Mr rows. | ||
| 540 | 2/2 ✓ Branch 0 taken 2304 times. ✓ Branch 1 taken 45936 times. | 48240 | if ((((row_idx + 1) + m_idx_start) % mr) == 0) { |
| 541 | 45936 | lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); | |
| 542 | 45936 | } | |
| 543 | 48240 | } | |
| 544 | 4896 | } | |
| 545 | |||
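
The listing above packs bfloat16 LHS rows into the qai8dx layout: per row, it reduces min/max with NEON, derives an asymmetric int8 scale and nudged zero point, quantizes the row in `kr/sr`-element blocks interleaved across `mr` rows, and appends `-zero_point` and `1/scale` after the quantized data. The NEON paths widen bf16 to f32 by zipping zero halfwords below each bf16 lane (`vzip1q_u16`/`vzip2q_u16` against a zero vector) and reinterpreting the result as `float32x4_t`. The standalone sketch below models only the per-row scalar math for reference; `bf16_to_f32`, `quant_row_qai8dx`, and the sample values are illustrative helpers, not part of the KleidiAI API, and the sketch omits the `mr`-row interleaving and the padding of `k` to a multiple of 32.

```c
#include <float.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Widen one bfloat16 value (stored as uint16_t) to float32 by placing it in the
// upper 16 bits of an IEEE-754 float32 word; this is the scalar equivalent of the
// zip-with-zero trick used in the NEON loops above.
static float bf16_to_f32(uint16_t v) {
    const uint32_t bits = (uint32_t)v << 16;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

// Scalar model of the per-row dynamic quantization: derive an asymmetric int8
// scale / zero-point pair from the row's min/max (forced to include 0), then
// quantize every element. The real kernel stores -zero_point and 1/scale
// alongside the quantized data in the packed buffer.
static void quant_row_qai8dx(
    const uint16_t* row_bf16, size_t k, int8_t* out_q, float* out_recip_scale, int32_t* out_neg_zero_point) {
    float vmin = FLT_MAX;
    float vmax = -FLT_MAX;
    for (size_t i = 0; i < k; ++i) {
        const float v = bf16_to_f32(row_bf16[i]);
        vmin = fminf(vmin, v);
        vmax = fmaxf(vmax, v);
    }

    const float qmin = (float)INT8_MIN;
    const float qmax = (float)INT8_MAX;

    // Force the quantized range to include 0.0f so that zero is exactly representable.
    const float rmin = fminf(0.0F, vmin);
    const float rmax = fmaxf(0.0F, vmax);
    const float scale = (rmin == rmax) ? 1.0F : (qmax - qmin) / (rmax - rmin);
    const float recip_scale = (scale != 0.0F) ? 1.0F / scale : 0.0F;

    // Nudge the zero point towards whichever end of the range gives less error,
    // clamp it to [qmin, qmax] and round it to the nearest integer.
    const float descaled_min = rmin * scale;
    const float descaled_max = rmax * scale;
    float zero_point =
        ((qmin + descaled_min) + (qmax + descaled_max) > 0) ? qmin - descaled_min : qmax - descaled_max;
    zero_point = fminf(fmaxf(zero_point, qmin), qmax);
    const int32_t nudged_zero_point = (int32_t)rintf(zero_point);

    // Quantize: round(x * scale) + zero_point, clamped to the int8 range.
    for (size_t i = 0; i < k; ++i) {
        int32_t q = (int32_t)roundf(bf16_to_f32(row_bf16[i]) * scale) + nudged_zero_point;
        q = q < INT8_MIN ? INT8_MIN : q;
        q = q > INT8_MAX ? INT8_MAX : q;
        out_q[i] = (int8_t)q;
    }

    *out_recip_scale = recip_scale;
    *out_neg_zero_point = -nudged_zero_point;
}

int main(void) {
    // bf16 bit patterns for 1.0f, -0.5f, 2.0f and 0.0f.
    const uint16_t row[4] = {0x3F80, 0xBF00, 0x4000, 0x0000};
    int8_t q[4];
    float recip_scale;
    int32_t neg_zero_point;

    quant_row_qai8dx(row, 4, q, &recip_scale, &neg_zero_point);

    for (size_t i = 0; i < 4; ++i) {
        printf("q[%zu] = %d\n", i, (int)q[i]);
    }
    printf("recip_scale = %f, -zero_point = %d\n", (double)recip_scale, neg_zero_point);
    return 0;
}
```

Any C99 compiler can build and run this sketch; it prints the quantized row together with the `-zero_point` / `1/scale` pair that the packing routine would store next to the int8 data.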