kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" | ||
| 7 | |||
| 8 | #include <stddef.h> | ||
| 9 | #include <stdint.h> | ||
| 10 | #include <string.h> | ||
| 11 | |||
| 12 | #include "kai/kai_common.h" | ||
| 13 | |||
// Number of 4-bit values per quantization block; the only block length this packer accepts.
static const size_t kai_bl = 32;
// Size in bytes of the scale that precedes the quantized values of each block.
static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
| 16 | |||
| 17 | 5520 | inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { | |
| 18 | − | KAI_ASSUME((k % 2) == 0); | |
| 19 | − | KAI_ASSUME(bl == kai_bl); | |
| 20 | 5520 | return kai_roundup(k, bl) / bl; | |
| 21 | } | ||
| 22 | |||
| 23 | 4644 | inline static size_t kai_num_bytes_per_block(size_t bl) { | |
| 24 | − | KAI_ASSUME(bl == kai_bl); | |
| 25 | 4644 | return (bl / 2) + kai_num_bytes_multiplier; | |
| 26 | } | ||
| 27 | |||
| 28 | 876 | inline static size_t kai_rhs_stride(size_t k, size_t bl) { | |
| 29 | − | KAI_ASSUME(bl == kai_bl); | |
| 30 | − | KAI_ASSUME((k % 2) == 0); | |
| 31 | − | KAI_ASSUME((k % bl) == 0); | |
| 32 | |||
| 33 | 876 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
| 34 | 876 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
| 35 | |||
| 36 | 1752 | return num_bytes_per_block * num_blocks_per_row; | |
| 37 | 876 | } | |
| 38 | |||
| 39 | 3768 | size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(size_t k, size_t nr, size_t kr, size_t bl) { | |
| 40 | − | KAI_ASSUME(bl == kai_bl); | |
| 41 | − | KAI_ASSUME((k % 2) == 0); | |
| 42 | − | KAI_ASSUME((k % kr) == 0); | |
| 43 | − | KAI_ASSUME((k % bl) == 0); | |
| 44 | |||
| 45 | 3768 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
| 46 | 3768 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
| 47 | |||
| 48 | 7536 | return nr * (num_bytes_per_block * num_blocks_per_row); | |
| 49 | 3768 | } | |
| 50 | |||
// Byte offset of source row n_idx; rows are rhs_stride bytes apart.
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(size_t n_idx, size_t rhs_stride) {
    const size_t byte_offset = rhs_stride * n_idx;
    return byte_offset;
}
| 54 | |||
| 55 | 2016 | size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( | |
| 56 | size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) { | ||
| 57 | − | KAI_ASSUME(bl == kai_bl); | |
| 58 | − | KAI_ASSUME((k % 2) == 0); | |
| 59 | − | KAI_ASSUME((k % kr) == 0); | |
| 60 | − | KAI_ASSUME((k % bl) == 0); | |
| 61 | − | KAI_ASSUME((n_idx % nr) == 0); | |
| 62 | |||
| 63 | 2016 | return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl); | |
| 64 | } | ||
| 65 | |||
| 66 | 876 | size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( | |
| 67 | size_t n, size_t k, size_t nr, size_t kr, size_t bl) { | ||
| 68 | − | KAI_ASSUME(bl == kai_bl); | |
| 69 | − | KAI_ASSUME((k % 2) == 0); | |
| 70 | − | KAI_ASSUME((k % kr) == 0); | |
| 71 | − | KAI_ASSUME((k % bl) == 0); | |
| 72 | |||
| 73 | 876 | const size_t num_rows = kai_roundup(n, nr) / nr; | |
| 74 | |||
| 75 | 1752 | return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl); | |
| 76 | 876 | } | |
| 77 | |||
| 78 | 876 | void kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( | |
| 79 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, | ||
| 80 | const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params) { | ||
| 81 | − | KAI_ASSUME(bl == kai_bl); | |
| 82 | − | KAI_ASSUME(num_groups == 1); | |
| 83 | − | KAI_ASSUME((k % 2) == 0); | |
| 84 | − | KAI_ASSUME((k % kr) == 0); | |
| 85 | − | KAI_ASSUME((k % bl) == 0); | |
| 86 | − | KAI_ASSUME(bias == NULL); | |
| 87 | − | KAI_ASSUME(extra_bytes == 0); | |
| 88 | |||
| 89 | − | KAI_ASSUME(sr == 2); | |
| 90 | − | KAI_ASSUME(kr >= 1 && kr <= 16); | |
| 91 | − | KAI_ASSUME(rhs != NULL); | |
| 92 | − | KAI_ASSUME(rhs_packed != NULL); | |
| 93 | − | KAI_ASSUME(params != NULL); | |
| 94 | − | KAI_ASSUME(params->rhs_zero_point == 8); | |
| 95 | − | KAI_ASSUME(params->lhs_zero_point == 1); | |
| 96 | |||
| 97 | // Note: The input matrix (rhs) is expected with: | ||
| 98 | // "k" columns and "n" rows (NxK) | ||
| 99 | |||
| 100 | 876 | const size_t rhs_stride = kai_rhs_stride(k, bl); | |
| 101 | 1752 | const size_t rhs_packed_stride = | |
| 102 | 876 | kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl); | |
| 103 | 876 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
| 104 | 876 | const size_t num_segments_per_block = bl / kr; | |
| 105 | 876 | const size_t num_bytes_per_segment = kr / 2; | |
| 106 | |||
| 107 |
2/2✓ Branch 0 taken 876 times.
✓ Branch 1 taken 8308 times.
|
9184 | for (size_t y = 0; y < n; y += nr) { |
| 108 | 8308 | const uint8_t* src_row = rhs; | |
| 109 | 8308 | uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride; | |
| 110 | |||
| 111 |
2/2✓ Branch 0 taken 14076 times.
✓ Branch 1 taken 8308 times.
|
22384 | for (size_t x = 0; x < num_blocks_per_row; ++x) { |
| 112 | // Store the scales at the end of the block | ||
| 113 | 14076 | uint8_t* scales = (dst_row); | |
| 114 | |||
| 115 |
2/2✓ Branch 0 taken 61088 times.
✓ Branch 1 taken 14076 times.
|
75164 | for (size_t i = 0; i < nr; ++i) { |
| 116 |
2/2✓ Branch 0 taken 58104 times.
✓ Branch 1 taken 2984 times.
|
61088 | const size_t src_row_idx = KAI_MIN(y + i, n - 1); |
| 117 | 61088 | memcpy( | |
| 118 | 35328 | scales + i * kai_num_bytes_multiplier, src_row + src_row_idx * rhs_stride, | |
| 119 | kai_num_bytes_multiplier); | ||
| 120 | 61088 | } | |
| 121 | 14076 | src_row += kai_num_bytes_multiplier; | |
| 122 | |||
| 123 | 14076 | dst_row += (kai_num_bytes_multiplier * nr); | |
| 124 | |||
| 125 | // Store the segments | ||
| 126 |
2/2✓ Branch 0 taken 30976 times.
✓ Branch 1 taken 14076 times.
|
45052 | for (size_t s = 0; s < num_segments_per_block; ++s) { |
| 127 |
2/2✓ Branch 0 taken 134400 times.
✓ Branch 1 taken 30976 times.
|
165376 | for (size_t i = 0; i < nr; ++i) { |
| 128 |
2/2✓ Branch 0 taken 127512 times.
✓ Branch 1 taken 6888 times.
|
134400 | const size_t src_row_idx = KAI_MIN(y + i, n - 1); |
| 129 | |||
| 130 |
2/2✓ Branch 0 taken 24448 times.
✓ Branch 1 taken 109952 times.
|
134400 | if (num_bytes_per_segment == sizeof(uint32_t)) { |
| 131 | 24448 | uint32_t tmp = 0; | |
| 132 | 24448 | memcpy(&tmp, src_row + src_row_idx * rhs_stride, num_bytes_per_segment); | |
| 133 | 24448 | tmp = tmp ^ 0x88888888; | |
| 134 | 24448 | memcpy(dst_row + i * num_bytes_per_segment, &tmp, num_bytes_per_segment); | |
| 135 |
1/2✓ Branch 0 taken 109952 times.
✗ Branch 1 not taken.
|
134400 | } else if (num_bytes_per_segment == sizeof(uint64_t)) { |
| 136 | 109952 | uint64_t tmp = 0; | |
| 137 | 109952 | memcpy(&tmp, src_row + src_row_idx * rhs_stride, num_bytes_per_segment); | |
| 138 | 109952 | tmp = tmp ^ 0x8888888888888888ULL; | |
| 139 | 109952 | memcpy(dst_row + i * num_bytes_per_segment, &tmp, num_bytes_per_segment); | |
| 140 | 109952 | } else { | |
| 141 | ✗ | memcpy( | |
| 142 | ✗ | dst_row + i * num_bytes_per_segment, src_row + src_row_idx * rhs_stride, | |
| 143 | ✗ | num_bytes_per_segment); | |
| 144 | |||
| 145 | ✗ | for (size_t b = 0; b < num_bytes_per_segment; ++b) { | |
| 146 | ✗ | uint8_t qs = dst_row[i * num_bytes_per_segment + b]; | |
| 147 | // Add offset (0x88) | ||
| 148 | ✗ | dst_row[i * num_bytes_per_segment + b] = qs ^ 0x88; | |
| 149 | ✗ | } | |
| 150 | } | ||
| 151 | 134400 | } | |
| 152 | |||
| 153 | 30976 | src_row += num_bytes_per_segment; | |
| 154 | 30976 | dst_row += num_bytes_per_segment * nr; | |
| 155 | 30976 | } | |
| 156 | 14076 | } | |
| 157 | 8308 | } | |
| 158 | 876 | } | |
| 159 |