kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #if !defined(__aarch64__) && !defined(_M_ARM64) | ||
| 8 | #error This file must be compiled for AArch64. | ||
| 9 | #else // Architectural features check. | ||
| 10 | |||
| 11 | #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" | ||
| 12 | |||
| 13 | #include <math.h> | ||
| 14 | #include <stddef.h> | ||
| 15 | #include <stdint.h> | ||
| 16 | |||
| 17 | #include "kai/kai_common.h" | ||
| 18 | |||
| 19 | static const size_t kai_num_bytes_multiplier = sizeof(uint16_t); | ||
| 20 | |||
| 21 | 196 | inline static size_t kai_num_bytes_per_block(size_t bl) { | |
| 22 | 196 | return bl * sizeof(int8_t) + kai_num_bytes_multiplier; | |
| 23 | } | ||
| 24 | |||
| 25 | 240 | inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { | |
| 26 | − | KAI_ASSERT((k % bl) == 0); | |
| 27 | 240 | return k / bl; | |
| 28 | } | ||
| 29 | |||
| 30 | 196 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) { | |
| 31 | 196 | KAI_UNUSED(kr); | |
| 32 | |||
| 33 | 196 | return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl); | |
| 34 | } | ||
| 35 | |||
| 36 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qsi8d32p_f32_neon(size_t mr) { | |
| 37 | ✗ | return mr; | |
| 38 | } | ||
| 39 | |||
| 40 | 44 | size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon(size_t m_idx, size_t lhs_stride) { | |
| 41 | 44 | return m_idx * lhs_stride; | |
| 42 | } | ||
| 43 | |||
| 44 | 108 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon( | |
| 45 | size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
| 46 | − | KAI_ASSUME((k % 2) == 0); | |
| 47 | − | KAI_ASSUME((k % kr) == 0); | |
| 48 | − | KAI_ASSUME((k % bl) == 0); | |
| 49 | − | KAI_ASSUME((m_idx % mr) == 0); | |
| 50 | |||
| 51 | 108 | KAI_UNUSED(sr); | |
| 52 | 108 | KAI_UNUSED(kr); | |
| 53 | |||
| 54 | // The scales are stored after all the mr packed quantized values | ||
| 55 | 108 | return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl); | |
| 56 | } | ||
| 57 | |||
| 58 | 44 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon( | |
| 59 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
| 60 | − | KAI_ASSUME((k % 2) == 0); | |
| 61 | − | KAI_ASSUME((k % kr) == 0); | |
| 62 | − | KAI_ASSUME((k % bl) == 0); | |
| 63 | |||
| 64 | 44 | KAI_UNUSED(sr); | |
| 65 | 44 | KAI_UNUSED(kr); | |
| 66 | |||
| 67 | 44 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
| 68 | |||
| 69 | 88 | return (num_rows * kai_lhs_packed_stride(k, mr, kr, bl)); | |
| 70 | 44 | } | |
| 71 | |||
| 72 | 44 | void kai_run_lhs_quant_pack_qsi8d32p_f32_neon( | |
| 73 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, | ||
| 74 | size_t lhs_stride, void* lhs_packed) { | ||
| 75 | − | KAI_ASSUME((bl % kr) == 0); | |
| 76 | − | KAI_ASSUME((k % bl) == 0); | |
| 77 | − | KAI_ASSUME(kr == 4); | |
| 78 | − | KAI_ASSUME(bl == 32); | |
| 79 | 44 | KAI_UNUSED(sr); | |
| 80 | 44 | KAI_UNUSED(m_idx_start); | |
| 81 | 44 | KAI_UNUSED(lhs_stride); | |
| 82 | |||
| 83 |
1/2✓ Branch 0 taken 44 times.
✗ Branch 1 not taken.
|
44 | if (m == 0) { |
| 84 | ✗ | return; | |
| 85 | } | ||
| 86 | |||
| 87 | 44 | const size_t num_blocks = kai_num_blocks_per_row(k, bl); | |
| 88 | 44 | const size_t lhs_packed_stride = kai_lhs_packed_stride(k, mr, kr, bl); | |
| 89 | |||
| 90 | 44 | const float* lhs_ptr = lhs; | |
| 91 | 44 | int8_t* lhs_packed_start_ptr = lhs_packed; | |
| 92 | |||
| 93 |
2/2✓ Branch 0 taken 616 times.
✓ Branch 1 taken 44 times.
|
660 | for (size_t m_idx = 0; m_idx < m; m_idx++) { |
| 94 | 616 | int8_t* lhs_packed_ptr = lhs_packed_start_ptr; | |
| 95 | 1232 | uint16_t* lhs_packed_scales = | |
| 96 | 616 | (uint16_t*)(lhs_packed_ptr + lhs_packed_stride - ((mr * num_blocks) * kai_num_bytes_multiplier)); | |
| 97 | |||
| 98 | 616 | lhs_packed_ptr += (m_idx % mr) * kr; | |
| 99 | 616 | lhs_packed_scales += (m_idx % mr); | |
| 100 | |||
| 101 |
2/2✓ Branch 0 taken 1148 times.
✓ Branch 1 taken 616 times.
|
1764 | for (size_t block_idx = 0; block_idx < num_blocks; block_idx++) { |
| 102 | // Maximum absolute value of the block elements | ||
| 103 | 1148 | float amax = 0.0F; | |
| 104 | |||
| 105 |
2/2✓ Branch 0 taken 36736 times.
✓ Branch 1 taken 1148 times.
|
37884 | for (size_t bl_idx = 0; bl_idx < bl; bl_idx++) { |
| 106 |
2/2✓ Branch 0 taken 31952 times.
✓ Branch 1 taken 4784 times.
|
36736 | amax = KAI_MAX(amax, fabsf(lhs_ptr[bl_idx])); |
| 107 | 36736 | } | |
| 108 | |||
| 109 | 1148 | const float sf = amax / ((1 << 7) - 1); | |
| 110 | |||
| 111 |
1/2✓ Branch 0 taken 1148 times.
✗ Branch 1 not taken.
|
1148 | const float sf_inv = sf ? 1.0F / sf : 0.0F; |
| 112 | |||
| 113 |
2/2✓ Branch 0 taken 9184 times.
✓ Branch 1 taken 1148 times.
|
10332 | for (size_t bl_idx = 0; bl_idx < bl; bl_idx += kr) { |
| 114 |
2/2✓ Branch 0 taken 36736 times.
✓ Branch 1 taken 9184 times.
|
45920 | for (size_t kr_idx = 0; kr_idx < kr; ++kr_idx) { |
| 115 | 36736 | int32_t v0_s32 = (int32_t)(roundf(lhs_ptr[kr_idx] * sf_inv)); | |
| 116 | 36736 | lhs_packed_ptr[kr_idx] = (int8_t)v0_s32; | |
| 117 | 36736 | } | |
| 118 | 9184 | lhs_ptr += kr; | |
| 119 | 9184 | lhs_packed_ptr += mr * kr; | |
| 120 | 9184 | } | |
| 121 | |||
| 122 | // Num_blocks (rows) x Mr (cols) | ||
| 123 | 1148 | lhs_packed_scales[0] = kai_cast_f16_f32(sf); | |
| 124 | |||
| 125 | 1148 | lhs_packed_scales += mr; | |
| 126 | 1148 | } | |
| 127 |
2/2✓ Branch 0 taken 576 times.
✓ Branch 1 taken 40 times.
|
616 | if (((m_idx + 1) % mr) == 0) { |
| 128 | 40 | lhs_packed_start_ptr += lhs_packed_stride; | |
| 129 | 40 | } | |
| 130 | 616 | } | |
| 131 | 44 | } | |
| 132 | #endif // Architectural features check. | ||
| 133 |