kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | #include "kai_lhs_quant_pack_qai8dxp_f32.h" | ||
| 7 | |||
| 8 | #if defined(__aarch64__) | ||
| 9 | #include <arm_neon.h> | ||
| 10 | #endif | ||
| 11 | #include <float.h> | ||
| 12 | #include <math.h> | ||
| 13 | #include <stddef.h> | ||
| 14 | #include <stdint.h> | ||
| 15 | |||
| 16 | #include "kai/kai_common.h" | ||
| 17 | |||
// Bytes occupied by one per-row zero-point offset in the packed buffer (stored as int32_t).
static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
// Bytes occupied by one per-row quantization multiplier in the packed buffer (stored as float).
static const size_t kai_num_bytes_per_multiplier = sizeof(float);
| 20 | |||
// Returns k padded up to the next multiple of 32; this padded length is the
// number of int8 values reserved per row in the packed LHS.
static inline size_t kai_k_roundedup(size_t k) {
    const size_t k_granularity = 32;
    const size_t padded_k = kai_roundup(k, k_granularity);
    return padded_k;
}
| 26 | |||
| 27 | 65550 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) { | |
| 28 | 65550 | KAI_UNUSED(kr); | |
| 29 | 65550 | KAI_UNUSED(sr); | |
| 30 | |||
| 31 | 65550 | const size_t k_internal = kai_k_roundedup(k); | |
| 32 | |||
| 33 | − | KAI_ASSERT((k_internal % 2) == 0); | |
| 34 | |||
| 35 | 131100 | return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); | |
| 36 | 65550 | } | |
| 37 | |||
// Rows are packed in groups of mr, so the step along m is exactly mr.
size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr) {
    const size_t m_step = mr;
    return m_step;
}
| 41 | |||
// Byte offset of row m_idx in the unpacked LHS, where lhs_stride is the row
// stride in bytes.
size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride) {
    const size_t row_offset_bytes = lhs_stride * m_idx;
    return row_offset_bytes;
}
| 45 | |||
// Byte offset, in the packed LHS, of the mr-row group that contains row m_idx.
// The result always points at the start of that group, even when m_idx is not
// a multiple of mr.
size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
    const size_t group_idx = m_idx / mr;
    const size_t group_bytes = kai_lhs_packed_stride(k, mr, kr, sr);
    return group_idx * group_bytes;
}
| 50 | |||
// Total bytes needed to pack m rows. Rows are stored in groups of mr, and a
// trailing partial group still occupies a full group's stride.
size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
    const size_t num_groups = kai_roundup(m, mr) / mr;
    const size_t group_bytes = kai_lhs_packed_stride(k, mr, kr, sr);
    return group_bytes * num_groups;
}
| 56 | |||
// Quantizes m rows of float LHS data to int8 and writes them into lhs_packed.
//
// Each row gets its own asymmetric quantization: a scale (scale0) and an
// integer zero point (nudged_zero_point0), both derived from that row's
// min/max. The row range is always widened to include 0.
//
// Packed layout of one group of mr rows (dst_stride bytes, see kai_lhs_packed_stride):
//   [mr * k_internal int8 values][mr int32 offsets][mr float multipliers]
// where k_internal is k rounded up to a multiple of 32.
// Within the int8 area, each row is cut into chunks of k_block_len = kr / sr
// values. Chunk b of row x (x = position inside the group) lives at byte offset
// (b * mr + x) * k_block_len. That is, chunks of the mr rows are interleaved.
// The int32 offset stored for a row is -nudged_zero_point0. The float stored is
// the reciprocal of scale0, used as the dequantization multiplier.
//
// Parameters:
//   m           - number of rows to quantize; m == 0 is a no-op.
//   k           - number of float values per row.
//   mr, kr, sr  - packing parameters; kr must be divisible by sr (asserted).
//   m_idx_start - global index of the first row; (row + m_idx_start) % mr selects
//                 the row's slot inside its mr-row group.
//   lhs         - source rows; lhs_stride is the row stride in BYTES.
//   lhs_packed  - destination. NOTE(review): presumably the caller has already
//                 offset this to the group containing row m_idx_start, e.g. via
//                 kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32 — confirm.
//
// NOTE(review): if kr == 0 or kr < sr, k_block_len is 0. The assert below does
// not catch that case, and the divisions computing num_blocks_k /
// num_blocks_k_internal then divide by zero.
void kai_run_lhs_quant_pack_qai8dxp_f32(
    size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs,
    size_t lhs_stride, void* restrict lhs_packed) {
    KAI_ASSERT((kr % sr) == 0);

    if (m == 0) {
        return;
    }

    const size_t num_rows = m;

    const float* src_ptr = lhs;

    const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr);
    const size_t k_internal = kai_k_roundedup(k);
    const int32_t k_block_len = (int32_t)(kr / sr);

    // Chunks fully covered by real data, and chunks needed to fill the padded row.
    const int32_t num_blocks_k = (int32_t)(k / k_block_len);
    const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len);

    for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
        float max0 = -FLT_MAX;
        float min0 = FLT_MAX;

        // Find the min/max of this row
        int32_t k_idx = 0;

#if defined(__aarch64__)
        // Vectorized min/max over 8 floats per iteration; the remaining
        // (k % 8) values are handled by the scalar loop below.
        // NOTE(review): vmaxq_f32/vminq_f32 propagate NaN, whereas fmaxf/fminf
        // ignore a NaN operand, so NaN inputs behave differently per path —
        // presumably inputs are assumed finite; confirm.
        float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
        float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);

        for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
            const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx);
            const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx);

            // Calculate the max
            vmax0 = vmaxq_f32(src0_0, vmax0);
            vmax0 = vmaxq_f32(vmax0, src0_1);

            // Calculate the min
            vmin0 = vminq_f32(src0_0, vmin0);
            vmin0 = vminq_f32(vmin0, src0_1);
        }
        // Get the max/min
        max0 = vmaxvq_f32(vmax0);
        min0 = vminvq_f32(vmin0);
#endif
        // Scalar tail (or the whole row on non-aarch64 builds)
        for (; k_idx < (int32_t)k; ++k_idx) {
            const float src0_0 = *(src_ptr + (size_t)k_idx);
            max0 = fmaxf(src0_0, max0);
            min0 = fminf(src0_0, min0);
        }

        // Maximum/minimum int8 values
        const float qmin = (float)INT8_MIN;
        const float qmax = (float)INT8_MAX;

        // Widen the range so that it always contains 0.
        const float rmin0 = fminf(0.0F, min0);
        const float rmax0 = fmaxf(0.0F, max0);
        // Map [rmin0, rmax0] onto the 255 int8 steps; an empty range (all zeros) uses 1.
        const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);

        // Reciprocal to quantize
        // (this reciprocal is what gets stored as the per-row multiplier below)
        const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;

        const float descaled_min0 = rmin0 * scale0;
        const float descaled_max0 = rmax0 * scale0;

        const float zero_point_from_min_error0 = qmin + descaled_min0;
        const float zero_point_from_max_error0 = qmax + descaled_max0;

        // Derive the zero point from either the min end or the max end of the
        // range, depending on the sign of the summed terms above.
        float zero_point0 =
            zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0;

        zero_point0 = fmaxf(zero_point0, qmin);
        zero_point0 = fminf(zero_point0, qmax);

        // Round to nearest integer
        const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);

        // Slot of this row inside its mr-row group
        const size_t dst_x = ((row_idx + m_idx_start) % mr);

        // First chunk of this row starts after the first chunks of the dst_x rows before it
        uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t));

        // Quantize the channels
        int32_t block_idx = 0;

#if defined(__aarch64__)
        // NOTE(review): this path rounds with vcvtnq_s32_f32 (round to nearest,
        // ties to even), while the scalar paths below use roundf (ties away from
        // zero). Exact .5 products can therefore quantize differently depending
        // on architecture and on kr / sr, including within one row (the
        // leftover chunks always use roundf). Verify this is acceptable.
        if (k_block_len == 8) {
            for (; block_idx < num_blocks_k; ++block_idx) {
                // Start of this 8-value chunk in the source row (no clamping needed: the chunk is fully in range)
                const int32_t k_idx_start = block_idx * k_block_len;

                const float32x4_t src_0 = vld1q_f32(src_ptr + k_idx_start);
                const float32x4_t src_1 = vld1q_f32(src_ptr + k_idx_start + 4);

                // Scale the values
                float32x4_t v0_f32 = vmulq_n_f32(src_0, scale0);
                float32x4_t v1_f32 = vmulq_n_f32(src_1, scale0);
                int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32);
                int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32);

                int16x4_t v0_s16 = vqmovn_s32(v0_s32);
                int16x4_t v1_s16 = vqmovn_s32(v1_s32);
                int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);

                // Add zero points
                int16_t nzp_s16 = (int16_t)nudged_zero_point0;
                int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16);
                v_s16 = vaddq_s16(v_s16, vnzp_s16);
                v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
                v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));

                int8x8_t v0_s8 = vqmovn_s16(v_s16);
                vst1_s8((int8_t*)(dst_ptr), v0_s8);
                dst_ptr += 8 * sizeof(int8_t);
                // Skip over the same chunk of the other mr - 1 rows in the group
                dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
            }
        } else
#endif
        {
            // Generic scalar path for chunks fully covered by real data
            for (; block_idx < num_blocks_k; ++block_idx) {
                for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
                    const int32_t k_idx_start = (block_idx * k_block_len) + k_block_idx;

                    const float src0_0 = *(src_ptr + k_idx_start);

                    // Scale the values
                    int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0));

                    v0_s32 = v0_s32 + nudged_zero_point0;
                    v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
                    v0_s32 = KAI_MIN(v0_s32, INT8_MAX);

                    *((int8_t*)(dst_ptr)) = (int8_t)v0_s32;
                    dst_ptr += sizeof(int8_t);
                }
                // Skip over the same chunk of the other mr - 1 rows in the group
                dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
            }
        }

        // Remaining chunks up to k_internal: the partial chunk past num_blocks_k
        // and the padding. Indices beyond k - 1 are clamped, so padding repeats
        // the row's last value.
        for (; block_idx < num_blocks_k_internal; ++block_idx) {
            // left over k
            for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
                // Clamp at the last valid k-index
                const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);

                const float src0_0 = *(src_ptr + k_idx_start);

                // Scale the values
                int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0));

                v0_s32 = v0_s32 + nudged_zero_point0;
                v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
                v0_s32 = KAI_MIN(v0_s32, INT8_MAX);

                *((int8_t*)(dst_ptr)) = (int8_t)v0_s32;
                dst_ptr += sizeof(int8_t);
            }
            dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
        }

        // Jump past the group's int8 area to the per-row offsets
        dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));

        dst_ptr += dst_x * kai_num_bytes_per_offset;

        // LHS offset at the beginning of the row
        *((int32_t*)(dst_ptr)) = -nudged_zero_point0;

        // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier
        KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);

        // Skip the mr offsets to reach this row's multiplier slot
        dst_ptr += mr * kai_num_bytes_per_offset;

        // Store the scale quantization params
        *((float*)(dst_ptr)) = recip_scale0;

        // lhs_stride is in bytes; convert to a float-element step
        src_ptr += (lhs_stride / sizeof(float));

        // Move to the next row if we have interleaved all Mr rows
        if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
            lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
        }
    }
}
| 241 |