kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #if !defined(__aarch64__) && !defined(_M_ARM64) | ||
| 8 | #error This file must be compiled for AArch64. | ||
| 9 | #else // Architectural features check. | ||
| 10 | |||
| 11 | #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" | ||
| 12 | |||
| 13 | #include <stddef.h> | ||
| 14 | #include <stdint.h> | ||
| 15 | |||
| 16 | #include "kai/kai_common.h" | ||
| 17 | |||
| 18 | static const size_t kai_num_bytes_multiplier = sizeof(uint16_t); | ||
| 19 | static const size_t kai_bl = 32; | ||
| 20 | |||
| 21 | 2544 | static inline void convert_s1s0_s16s0(uint8_t* dst_blk, const uint8_t* src_blk) { | |
| 22 | // First half | ||
| 23 |
2/2✓ Branch 0 taken 20352 times.
✓ Branch 1 taken 2544 times.
|
22896 | for (size_t k = 0; k < kai_bl / 2; k += 2) { |
| 24 | 20352 | dst_blk[k / 2] = src_blk[k] & 0xF; | |
| 25 | 20352 | dst_blk[k / 2] |= src_blk[k + 1] << 4; | |
| 26 | 20352 | } | |
| 27 | |||
| 28 | // Second half | ||
| 29 |
2/2✓ Branch 0 taken 2544 times.
✓ Branch 1 taken 20352 times.
|
22896 | for (size_t k = kai_bl / 2; k < kai_bl; k += 2) { |
| 30 | 20352 | dst_blk[k / 2] = src_blk[k - kai_bl / 2] >> 4; | |
| 31 | 20352 | dst_blk[k / 2] |= src_blk[k - kai_bl / 2 + 1] & 0xF0; | |
| 32 | 20352 | } | |
| 33 | 2544 | } | |
| 34 | |||
| 35 | 240 | inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { | |
| 36 | − | KAI_ASSUME((k % 2) == 0); | |
| 37 | − | KAI_ASSUME(bl == kai_bl); | |
| 38 | 240 | return kai_roundup(k, bl) / bl; | |
| 39 | } | ||
| 40 | |||
| 41 | 284 | inline static size_t kai_num_bytes_per_block(size_t bl) { | |
| 42 | − | KAI_ASSUME(bl == kai_bl); | |
| 43 | |||
| 44 | 284 | return (bl / 2) + kai_num_bytes_multiplier; | |
| 45 | } | ||
| 46 | |||
| 47 | 44 | inline static size_t kai_rhs_stride(size_t k, size_t bl) { | |
| 48 | − | KAI_ASSUME(bl == kai_bl); | |
| 49 | − | KAI_ASSUME((k % 2) == 0); | |
| 50 | − | KAI_ASSUME((k % bl) == 0); | |
| 51 | |||
| 52 | 44 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
| 53 | 44 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
| 54 | |||
| 55 | 88 | return num_bytes_per_block * num_blocks_per_row; | |
| 56 | 44 | } | |
| 57 | |||
| 58 | 196 | size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon( | |
| 59 | size_t k, size_t nr, size_t kr, size_t bl) { | ||
| 60 | − | KAI_ASSUME(bl == kai_bl); | |
| 61 | − | KAI_ASSUME((k % 2) == 0); | |
| 62 | − | KAI_ASSUME((k % kr) == 0); | |
| 63 | − | KAI_ASSUME((k % bl) == 0); | |
| 64 | |||
| 65 | 196 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
| 66 | 196 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
| 67 | |||
| 68 | 392 | return nr * (num_bytes_per_block * num_blocks_per_row); | |
| 69 | 196 | } | |
| 70 | |||
| 71 | ✗ | size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(size_t n_idx, size_t rhs_stride) { | |
| 72 | ✗ | return n_idx * rhs_stride; | |
| 73 | } | ||
| 74 | |||
| 75 | 108 | size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon( | |
| 76 | size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) { | ||
| 77 | − | KAI_ASSUME(bl == kai_bl); | |
| 78 | − | KAI_ASSUME((k % 2) == 0); | |
| 79 | − | KAI_ASSUME((k % kr) == 0); | |
| 80 | − | KAI_ASSUME((k % bl) == 0); | |
| 81 | − | KAI_ASSUME((n_idx % nr) == 0); | |
| 82 | |||
| 83 | // The scales are stored after all the nr packed quantized values | ||
| 84 | 108 | return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl); | |
| 85 | } | ||
| 86 | |||
| 87 | 44 | size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon( | |
| 88 | size_t n, size_t k, size_t nr, size_t kr, size_t bl) { | ||
| 89 | − | KAI_ASSUME(bl == kai_bl); | |
| 90 | − | KAI_ASSUME((k % 2) == 0); | |
| 91 | − | KAI_ASSUME((k % kr) == 0); | |
| 92 | − | KAI_ASSUME((k % bl) == 0); | |
| 93 | |||
| 94 | 44 | const size_t num_rows = kai_roundup(n, nr) / nr; | |
| 95 | |||
| 96 | 88 | return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl); | |
| 97 | 44 | } | |
| 98 | |||
| 99 | 44 | void kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon( | |
| 100 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, | ||
| 101 | const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params) { | ||
| 102 | − | KAI_ASSUME(bl == kai_bl); | |
| 103 | − | KAI_ASSUME(num_groups == 1); | |
| 104 | − | KAI_ASSUME((k % 2) == 0); | |
| 105 | − | KAI_ASSUME((k % kr) == 0); | |
| 106 | − | KAI_ASSUME((k % bl) == 0); | |
| 107 | − | KAI_ASSUME(bias == NULL); | |
| 108 | − | KAI_ASSUME(extra_bytes == 0); | |
| 109 | |||
| 110 | − | KAI_ASSUME(kr == 4); | |
| 111 | − | KAI_ASSUME(sr == 2); | |
| 112 | − | KAI_ASSUME(kr >= 1 && kr <= 16); | |
| 113 | − | KAI_ASSUME(rhs != NULL); | |
| 114 | − | KAI_ASSUME(rhs_packed != NULL); | |
| 115 | − | KAI_ASSUME(params != NULL); | |
| 116 | − | KAI_ASSUME(params->rhs_zero_point == 8); | |
| 117 | − | KAI_ASSUME(params->lhs_zero_point == 1); | |
| 118 | |||
| 119 | // Note: The input matrix (rhs) is expected with: | ||
| 120 | // "k" columns and "n" rows (NxK) | ||
| 121 | |||
| 122 | 44 | const size_t num_blocks = k / bl; | |
| 123 | 44 | const size_t rhs_stride = kai_rhs_stride(k, bl); | |
| 124 | 88 | const size_t rhs_packed_stride = | |
| 125 | 44 | kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl); | |
| 126 | 44 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
| 127 | |||
| 128 | 44 | uint8_t* rhs_packed_ptr = rhs_packed; | |
| 129 | |||
| 130 |
2/2✓ Branch 0 taken 44 times.
✓ Branch 1 taken 1636 times.
|
1680 | for (uint64_t n_idx = 0; n_idx < n; n_idx++) { |
| 131 | 3272 | uint16_t* rhs_packed_scales = | |
| 132 | 1636 | (uint16_t*)(rhs_packed_ptr + rhs_packed_stride - (nr * num_blocks * kai_num_bytes_multiplier)); | |
| 133 | |||
| 134 |
2/2✓ Branch 0 taken 2544 times.
✓ Branch 1 taken 1636 times.
|
4180 | for (size_t block_idx = 0; block_idx < num_blocks; block_idx++) { |
| 135 | 2544 | uint8_t blk_s1s0[16]; | |
| 136 | |||
| 137 | 5088 | const uint16_t* blk_scale_ptr = | |
| 138 | 2544 | (const uint16_t*)(rhs + (block_idx * num_bytes_per_block) + n_idx * rhs_stride); | |
| 139 | 2544 | const uint8_t* blk_s16s0 = (const uint8_t*)blk_scale_ptr + kai_num_bytes_multiplier; | |
| 140 | |||
| 141 | 2544 | convert_s1s0_s16s0(blk_s1s0, blk_s16s0); | |
| 142 | |||
| 143 |
2/2✓ Branch 0 taken 20352 times.
✓ Branch 1 taken 2544 times.
|
22896 | for (size_t bl4_idx = 0; bl4_idx < bl / 4; bl4_idx++) { |
| 144 | // Uint16 holds 4 int4 values | ||
| 145 | 20352 | ((uint16_t*)rhs_packed_ptr)[(block_idx * bl / 4 + bl4_idx) * nr + (n_idx % nr)] = | |
| 146 | 20352 | ((int16_t*)blk_s1s0)[bl4_idx]; | |
| 147 | 20352 | } | |
| 148 | |||
| 149 | // Num. block (rows) x Nr (cols) | ||
| 150 | 2544 | rhs_packed_scales[(n_idx % nr) + block_idx * nr] = *blk_scale_ptr; | |
| 151 | 2544 | } | |
| 152 | |||
| 153 |
2/2✓ Branch 0 taken 1628 times.
✓ Branch 1 taken 8 times.
|
1636 | if (((n_idx + 1) % nr) == 0) { |
| 154 | 8 | rhs_packed_ptr += rhs_packed_stride; | |
| 155 | 8 | } | |
| 156 | 1636 | } | |
| 157 | 44 | } | |
| 158 | #endif // Architectural features check. | ||
| 159 |