kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #if !defined(__aarch64__) && !defined(_M_ARM64) | ||
| 8 | #error This file must be compiled for AArch64. | ||
| 9 | #else // Architectural features check. | ||
| 10 | |||
| 11 | #include "kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h" | ||
| 12 | |||
| 13 | #include <arm_neon.h> | ||
| 14 | #include <float.h> | ||
| 15 | #include <stddef.h> | ||
| 16 | #include <stdint.h> | ||
| 17 | |||
| 18 | #include "kai/kai_common.h" | ||
| 19 | |||
// Per-row parameters appended after each packed block: one float "multiplier"
// (the reciprocal quantization scale) and one float "sum".
static const size_t kai_num_bytes_multiplier = sizeof(float);
static const size_t kai_num_bytes_sum = sizeof(float);

// The block length bl must be a multiple of this value (checked in the run function).
static const size_t kai_bl_multiple_of = 32;
| 23 | |||
| 24 | 20857032 | inline static size_t kai_get_num_bytes_per_block(size_t bl) { | |
| 25 | 20857032 | return bl * sizeof(int8_t) + kai_num_bytes_multiplier + kai_num_bytes_sum; | |
| 26 | } | ||
| 27 | |||
// Number of bl-sized blocks in a row of k elements; k must be a multiple of bl.
static inline size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
    KAI_ASSERT((k % bl) == 0);
    const size_t num_blocks = k / bl;
    return num_blocks;
}
| 32 | |||
// Bytes occupied by one group of mr packed rows: mr rows times all blocks of a
// row times the per-block size. kr does not affect the stride.
static inline size_t kai_get_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) {
    KAI_UNUSED(kr);

    const size_t blocks_per_row = kai_get_num_blocks_per_row(k, bl);
    const size_t bytes_per_block = kai_get_num_bytes_per_block(bl);
    return mr * blocks_per_row * bytes_per_block;
}
| 38 | |||
// Row step for this packer: rows are packed in groups of mr, so the step is mr.
size_t kai_get_m_step_lhs_quant_pack_qsi8d32pscalef32_f32_neon(size_t mr) {
    const size_t m_step = mr;
    return m_step;
}
| 42 | |||
// Byte offset of row m_idx in the unpacked float LHS, given a row stride in bytes.
size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(size_t m_idx, size_t lhs_stride) {
    const size_t byte_offset = lhs_stride * m_idx;
    return byte_offset;
}
| 46 | |||
// Byte offset of the packed mr-row group that starts at row m_idx.
// m_idx must be a multiple of mr; sr does not affect the layout.
size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
    size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);
    KAI_ASSUME((m_idx % mr) == 0);

    KAI_UNUSED(sr);

    const size_t group_idx = m_idx / mr;
    return group_idx * kai_get_lhs_packed_stride(k, mr, kr, bl);
}
| 57 | |||
// Total bytes needed for the packed LHS: m is rounded up to a whole number of
// mr-row groups, each of which occupies one packed stride.
size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
    size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);

    KAI_UNUSED(sr);

    const size_t num_row_groups = kai_roundup(m, mr) / mr;
    const size_t group_stride = kai_get_lhs_packed_stride(k, mr, kr, bl);
    return num_row_groups * group_stride;
}
// Returns the largest absolute value among the bl floats starting at src.
// bl must be a multiple of 8 (the caller requires bl % kai_bl_multiple_of == 0).
static inline float kai_lhs_pack_block_absmax(const float* src, size_t bl) {
    float32x4_t vabsmax = vdupq_n_f32(-FLT_MAX);
    for (size_t i = 0; i < bl; i += 8) {
        const float32x4_t src0_0 = vld1q_f32(src + i);
        const float32x4_t src0_1 = vld1q_f32(src + i + 4);
        vabsmax = vmaxq_f32(vabsq_f32(src0_0), vmaxq_f32(vabsmax, vabsq_f32(src0_1)));
    }
    return vmaxvq_f32(vabsmax);
}

// Quantizes 8 floats at src (multiply by scale, round to nearest, clamp to the
// int8 range), stores the 8 int8 results at dst and returns their sum.
static inline int32_t kai_lhs_pack_quant_x8(const float* src, float scale, int8_t* dst) {
    const float32x4_t v0_f32 = vmulq_n_f32(vld1q_f32(src), scale);
    const float32x4_t v1_f32 = vmulq_n_f32(vld1q_f32(src + 4), scale);

    const int16x4_t v0_s16 = vqmovn_s32(vcvtnq_s32_f32(v0_f32));
    const int16x4_t v1_s16 = vqmovn_s32(vcvtnq_s32_f32(v1_f32));
    int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);

    // Defensive clamp to [INT8_MIN, INT8_MAX] before summing and narrowing.
    v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
    v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));

    vst1_s8(dst, vqmovn_s16(v_s16));
    return (int32_t)vaddvq_s16(v_s16);
}

// Quantizes 4 floats at src the same way as kai_lhs_pack_quant_x8, stores the
// 4 int8 results at dst and returns their sum.
static inline int32_t kai_lhs_pack_quant_x4(const float* src, float scale, int8_t* dst) {
    const float32x2_t v0_f32 = vmul_n_f32(vld1_f32(src), scale);
    const float32x2_t v1_f32 = vmul_n_f32(vld1_f32(src + 2), scale);

    const int32x2_t v0_s32 = vcvtn_s32_f32(v0_f32);
    const int32x2_t v1_s32 = vcvtn_s32_f32(v1_f32);
    int16x4_t v_s16 = vqmovn_s32(vcombine_s32(v0_s32, v1_s32));

    // Defensive clamp to [INT8_MIN, INT8_MAX] before summing and narrowing.
    v_s16 = vmax_s16(v_s16, vdup_n_s16(INT8_MIN));
    v_s16 = vmin_s16(v_s16, vdup_n_s16(INT8_MAX));

    dst[0] = vqmovnh_s16(vget_lane_s16(v_s16, 0));
    dst[1] = vqmovnh_s16(vget_lane_s16(v_s16, 1));
    dst[2] = vqmovnh_s16(vget_lane_s16(v_s16, 2));
    dst[3] = vqmovnh_s16(vget_lane_s16(v_s16, 3));
    return (int32_t)vaddv_s16(v_s16);
}

// Quantizes m rows of float LHS data to int8 with one scale per bl-sized block
// and packs them in groups of mr rows.
//
// Parameters:
//   m            number of rows to pack; nothing is written when m == 0.
//   k            row length in floats; must be a multiple of bl.
//   bl           block length; must be a multiple of kr and of kai_bl_multiple_of.
//   mr           number of rows interleaved in one packed group.
//   kr, sr       kr / sr (must be 4 or 8) is the number of consecutive int8
//                values written for one row before moving to the next row.
//   m_idx_start  index of the first row; source row r is read from
//                lhs + (r + m_idx_start) * lhs_stride and lands in slot
//                (r + m_idx_start) % mr of its group.
//   lhs          float source matrix.
//   lhs_stride   distance in bytes between consecutive source rows.
//   lhs_packed   destination. NOTE(review): presumably points at the packed group
//                containing row m_idx_start (see the packed-offset helper); confirm with callers.
//
// Layout of one block inside an mr-row group (mr * kai_get_num_bytes_per_block(bl) bytes):
//   - mr * bl int8 values, interleaved across rows in kr/sr-wide chunks;
//   - mr floats: per-row sum of the block's int8 values times the reciprocal scale;
//   - mr floats: per-row reciprocal scale (absmax / 127, or 0 for an all-zero block).
void kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
    size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
    size_t lhs_stride, void* lhs_packed) {
    KAI_ASSERT((kr % sr) == 0);
    KAI_ASSUME((bl % kr) == 0);
    KAI_ASSUME((k % bl) == 0);
    KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
    KAI_ASSERT(((kr / sr) == 8) || ((kr / sr) == 4));

    if (m == 0) {
        return;
    }

    const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, mr, kr, bl);
    const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
    const size_t mr_block_size = mr * kai_get_num_bytes_per_block(bl);
    const size_t k_block_len = kr / sr;
    const float qmax = (float)INT8_MAX;

    int8_t* packed_group = (int8_t*)lhs_packed;

    for (size_t row_idx = 0; row_idx < m; ++row_idx) {
        const size_t src_row = row_idx + m_idx_start;
        const float* row_src_ptr = (const float*)((const uint8_t*)lhs + src_row * lhs_stride);
        const size_t dst_idx = src_row % mr;

        for (size_t blk_idx = 0; blk_idx < num_blocks_per_row; ++blk_idx) {
            const float* src_ptr = row_src_ptr + blk_idx * bl;
            int8_t* block_base = packed_group + blk_idx * mr_block_size;
            int8_t* dst_ptr = block_base + dst_idx * k_block_len;
            int8_t* sum_ptr = block_base + bl * mr + dst_idx * kai_num_bytes_sum;
            int8_t* scale_ptr = sum_ptr + mr * kai_num_bytes_sum;

            const float absmax = kai_lhs_pack_block_absmax(src_ptr, bl);
            const float scale0 = absmax == 0.0F ? 0.0F : qmax / absmax;
            const float recip_scale0 = scale0 != 0.0F ? 1.0F / scale0 : 0.0F;

            // k is a multiple of bl and k_idx stays below bl, so every chunk lies
            // inside the current block; no clamping against k is needed.
            int32_t qsum = 0;
            for (size_t k_idx = 0; k_idx + k_block_len <= bl; k_idx += k_block_len) {
                if (k_block_len == 8) {
                    qsum += kai_lhs_pack_quant_x8(src_ptr + k_idx, scale0, dst_ptr);
                } else if (k_block_len == 4) {
                    qsum += kai_lhs_pack_quant_x4(src_ptr + k_idx, scale0, dst_ptr);
                }
                // Skip over this row's chunk and the chunks of the other mr - 1 rows.
                dst_ptr += mr * k_block_len;
            }

            *((float*)sum_ptr) = ((float)qsum) * recip_scale0;
            *((float*)scale_ptr) = recip_scale0;
        }

        // Move to the next packed group once all mr slots of the current one are filled.
        if (((src_row + 1) % mr) == 0) {
            packed_group += lhs_packed_stride;
        }
    }
}
| 181 | #endif // Architectural features check. | ||
| 182 |