kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | #include "kai_lhs_quant_pack_qsi8d32p_f32.h" | ||
| 7 | |||
| 8 | #include <math.h> | ||
| 9 | #include <stddef.h> | ||
| 10 | #include <stdint.h> | ||
| 11 | |||
| 12 | #include "kai/kai_common.h" | ||
| 13 | |||
| 14 | static const size_t kai_num_bytes_multiplier = sizeof(uint16_t); | ||
| 15 | |||
| 16 | 3724 | inline static size_t kai_num_bytes_per_block(size_t bl) { | |
| 17 | 3724 | return bl * sizeof(int8_t) + kai_num_bytes_multiplier; | |
| 18 | } | ||
| 19 | |||
| 20 | 3724 | inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { | |
| 21 | − | KAI_ASSERT((k % bl) == 0); | |
| 22 | 3724 | return k / bl; | |
| 23 | } | ||
| 24 | |||
| 25 | 3032 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) { | |
| 26 | 3032 | KAI_UNUSED(kr); | |
| 27 | 3032 | return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl); | |
| 28 | } | ||
| 29 | |||
| 30 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qsi8d32p_f32(size_t mr) { | |
| 31 | ✗ | return mr; | |
| 32 | } | ||
| 33 | |||
| 34 | 692 | size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32(size_t m_idx, size_t lhs_stride) { | |
| 35 | 692 | return m_idx * lhs_stride; | |
| 36 | } | ||
| 37 | |||
| 38 | 1648 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32( | |
| 39 | size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
| 40 | − | KAI_ASSUME((k % 2) == 0); | |
| 41 | − | KAI_ASSUME((k % kr) == 0); | |
| 42 | − | KAI_ASSUME((k % bl) == 0); | |
| 43 | |||
| 44 | 1648 | KAI_UNUSED(sr); | |
| 45 | 1648 | KAI_UNUSED(kr); | |
| 46 | |||
| 47 | 1648 | return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl); | |
| 48 | } | ||
| 49 | |||
| 50 | 692 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32( | |
| 51 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
| 52 | − | KAI_ASSUME((k % 2) == 0); | |
| 53 | − | KAI_ASSUME((k % kr) == 0); | |
| 54 | − | KAI_ASSUME((k % bl) == 0); | |
| 55 | |||
| 56 | 692 | KAI_UNUSED(sr); | |
| 57 | 692 | KAI_UNUSED(kr); | |
| 58 | |||
| 59 | 692 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
| 60 | |||
| 61 | 1384 | return num_rows * kai_lhs_packed_stride(k, mr, kr, bl); | |
| 62 | 692 | } | |
| 63 | |||
| 64 | 692 | void kai_run_lhs_quant_pack_qsi8d32p_f32( | |
| 65 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, | ||
| 66 | size_t lhs_stride, void* lhs_packed) { | ||
| 67 | 1/2 ✓ Branch 0 taken 692 times. ✗ Branch 1 not taken. | 692 | if (m == 0) { |
| 68 | ✗ | return; | |
| 69 | } | ||
| 70 | |||
| 71 | 692 | const size_t num_rows = m; | |
| 72 | 692 | const size_t k_block_len = kr / sr; | |
| 73 | 692 | const size_t lhs_packed_stride = kai_lhs_packed_stride(k, mr, kr, bl); | |
| 74 | 692 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
| 75 | 692 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
| 76 | |||
| 77 | 2/2 ✓ Branch 0 taken 12076 times. ✓ Branch 1 taken 692 times. | 12768 | for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { |
| 78 | 12076 | const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride); | |
| 79 | |||
| 80 | 2/2 ✓ Branch 0 taken 23360 times. ✓ Branch 1 taken 12076 times. | 35436 | for (size_t b = 0; b < num_blocks_per_row; ++b) { |
| 81 | 23360 | float abs_max = 0.0F; | |
| 82 | |||
| 83 | 23360 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
| 84 | 23360 | int8_t* dst_ptr = (int8_t*)lhs_packed + (b * mr) * num_bytes_per_block; | |
| 85 | |||
| 86 | 2/2 ✓ Branch 0 taken 747520 times. ✓ Branch 1 taken 23360 times. | 770880 | for (size_t idx_v = 0; idx_v < bl; ++idx_v) { |
| 87 | 747520 | const float val = src_ptr[idx_v]; | |
| 88 | 2/2 ✓ Branch 0 taken 651584 times. ✓ Branch 1 taken 95936 times. | 747520 | abs_max = KAI_MAX(abs_max, fabsf(val)); |
| 89 | 747520 | } | |
| 90 | |||
| 91 | // Calculate scale and reciprocal | ||
| 92 | 23360 | const float scale = abs_max / ((1 << 7) - 1); | |
| 93 | 1/2 ✓ Branch 0 taken 23360 times. ✗ Branch 1 not taken. | 23360 | const float rep_scale = scale ? 1.0F / scale : 0.0F; |
| 94 | |||
| 95 | 23360 | *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale); | |
| 96 | 23360 | dst_ptr += mr * kai_num_bytes_multiplier; | |
| 97 | |||
| 98 | 23360 | dst_ptr += dst_x * k_block_len * sizeof(int8_t); | |
| 99 | |||
| 100 | // Quantize and pack the block | ||
| 101 | 2/2 ✓ Branch 0 taken 102768 times. ✓ Branch 1 taken 23360 times. | 126128 | for (size_t k_idx = 0; k_idx < bl; k_idx += k_block_len) { |
| 102 | 2/2 ✓ Branch 0 taken 747520 times. ✓ Branch 1 taken 102768 times. | 850288 | for (size_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
| 103 | // Clamp at the last valid k-index | ||
| 104 | 2/2 ✓ Branch 0 taken 746728 times. ✓ Branch 1 taken 792 times. | 747520 | const size_t k_idx_start = KAI_MIN(k_idx + k_block_idx, k - 1); |
| 105 | |||
| 106 | 747520 | const float src0_0 = *(src_ptr + k_idx_start); | |
| 107 | |||
| 108 | // Scale the values | ||
| 109 | 747520 | int32_t v0_s32 = (int32_t)(roundf(src0_0 * rep_scale)); | |
| 110 | |||
| 111 | 747520 | *dst_ptr = (int8_t)v0_s32; | |
| 112 | 747520 | dst_ptr += sizeof(int8_t); | |
| 113 | 747520 | } | |
| 114 | 102768 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
| 115 | 102768 | } | |
| 116 | |||
| 117 | 23360 | src_ptr += bl; | |
| 118 | 23360 | } | |
| 119 | // Move to the next row if we have interleaved all Mr rows | ||
| 120 | 2/2 ✓ Branch 0 taken 9040 times. ✓ Branch 1 taken 3036 times. | 12076 | if ((((row_idx + 1) + m_idx_start) % mr) == 0) { |
| 121 | 3036 | lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride); | |
| 122 | 3036 | } | |
| 123 | 12076 | } | |
| 124 | 692 | } | |
| 125 |