Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if !defined(__aarch64__) && !defined(_M_ARM64) | ||
8 | #error This file must be compiled for AArch64. | ||
9 | #else // Architectural features check. | ||
10 | |||
11 | #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" | ||
12 | |||
13 | #include <stddef.h> | ||
14 | #include <stdint.h> | ||
15 | |||
16 | #include "kai/kai_common.h" | ||
17 | |||
// Number of 4-bit values in one quantization block. Every function in this file that takes a
// block length asserts that it equals this value.
static const size_t kai_bl = 32;

// Number of bytes added to each block on top of its packed 4-bit values (see
// kai_num_bytes_per_block). The packing routine reads a 16-bit scale from the start of each block.
static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
20 | |||
21 | 1956 | static inline void convert_s1s0_s16s0(uint8_t* dst_blk, const uint8_t* src_blk) { | |
22 | // First half | ||
23 |
2/2✓ Branch 0 taken 15648 times.
✓ Branch 1 taken 1956 times.
|
17604 | for (size_t k = 0; k < kai_bl / 2; k += 2) { |
24 | 15648 | dst_blk[k / 2] = src_blk[k] & 0xF; | |
25 | 15648 | dst_blk[k / 2] |= src_blk[k + 1] << 4; | |
26 | 15648 | } | |
27 | |||
28 | // Second half | ||
29 |
2/2✓ Branch 0 taken 1956 times.
✓ Branch 1 taken 15648 times.
|
17604 | for (size_t k = kai_bl / 2; k < kai_bl; k += 2) { |
30 | 15648 | dst_blk[k / 2] = src_blk[k - kai_bl / 2] >> 4; | |
31 | 15648 | dst_blk[k / 2] |= src_blk[k - kai_bl / 2 + 1] & 0xF0; | |
32 | 15648 | } | |
33 | 1956 | } | |
34 | |||
35 | 150 | inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { | |
36 | − | KAI_ASSUME((k % 2) == 0); | |
37 | − | KAI_ASSUME(bl == kai_bl); | |
38 | 150 | return kai_roundup(k, bl) / bl; | |
39 | } | ||
40 | |||
41 | 176 | inline static size_t kai_num_bytes_per_block(size_t bl) { | |
42 | − | KAI_ASSUME(bl == kai_bl); | |
43 | |||
44 | 176 | return (bl / 2) + kai_num_bytes_multiplier; | |
45 | } | ||
46 | |||
47 | 26 | inline static size_t kai_rhs_stride(size_t k, size_t bl) { | |
48 | − | KAI_ASSUME(bl == kai_bl); | |
49 | − | KAI_ASSUME((k % 2) == 0); | |
50 | − | KAI_ASSUME((k % bl) == 0); | |
51 | |||
52 | 26 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
53 | 26 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
54 | |||
55 | 52 | return num_bytes_per_block * num_blocks_per_row; | |
56 | 26 | } | |
57 | |||
58 | 124 | size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon( | |
59 | size_t k, size_t nr, size_t kr, size_t bl) { | ||
60 | − | KAI_ASSUME(bl == kai_bl); | |
61 | − | KAI_ASSUME((k % 2) == 0); | |
62 | − | KAI_ASSUME((k % kr) == 0); | |
63 | − | KAI_ASSUME((k % bl) == 0); | |
64 | |||
65 | 124 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
66 | 124 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
67 | |||
68 | 248 | return nr * (num_bytes_per_block * num_blocks_per_row); | |
69 | 124 | } | |
70 | |||
// Returns the byte offset of row n_idx in the source RHS matrix, given the size in bytes of one
// row (rhs_stride).
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(size_t n_idx, size_t rhs_stride) {
    const size_t byte_offset = n_idx * rhs_stride;

    return byte_offset;
}
74 | |||
75 | 72 | size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon( | |
76 | size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) { | ||
77 | − | KAI_ASSUME(bl == kai_bl); | |
78 | − | KAI_ASSUME((k % 2) == 0); | |
79 | − | KAI_ASSUME((k % kr) == 0); | |
80 | − | KAI_ASSUME((k % bl) == 0); | |
81 | − | KAI_ASSUME((n_idx % nr) == 0); | |
82 | |||
83 | // The scales are stored after all the nr packed quantized values | ||
84 | 72 | return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl); | |
85 | } | ||
86 | |||
87 | 26 | size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon( | |
88 | size_t n, size_t k, size_t nr, size_t kr, size_t bl) { | ||
89 | − | KAI_ASSUME(bl == kai_bl); | |
90 | − | KAI_ASSUME((k % 2) == 0); | |
91 | − | KAI_ASSUME((k % kr) == 0); | |
92 | − | KAI_ASSUME((k % bl) == 0); | |
93 | |||
94 | 26 | const size_t num_rows = kai_roundup(n, nr) / nr; | |
95 | |||
96 | 52 | return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl); | |
97 | 26 | } | |
98 | |||
99 | 26 | void kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon( | |
100 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, | ||
101 | const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params) { | ||
102 | − | KAI_ASSUME(bl == kai_bl); | |
103 | − | KAI_ASSUME(num_groups == 1); | |
104 | − | KAI_ASSUME((k % 2) == 0); | |
105 | − | KAI_ASSUME((k % kr) == 0); | |
106 | − | KAI_ASSUME((k % bl) == 0); | |
107 | − | KAI_ASSUME(bias == NULL); | |
108 | − | KAI_ASSUME(extra_bytes == 0); | |
109 | |||
110 | − | KAI_ASSUME(kr == 4); | |
111 | − | KAI_ASSUME(sr == 2); | |
112 | − | KAI_ASSUME(kr >= 1 && kr <= 16); | |
113 | − | KAI_ASSUME(rhs != NULL); | |
114 | − | KAI_ASSUME(rhs_packed != NULL); | |
115 | − | KAI_ASSUME(params != NULL); | |
116 | − | KAI_ASSUME(params->rhs_zero_point == 8); | |
117 | − | KAI_ASSUME(params->lhs_zero_point == 1); | |
118 | |||
119 | // Note: The input matrix (rhs) is expected with: | ||
120 | // "k" columns and "n" rows (NxK) | ||
121 | |||
122 | 26 | const size_t num_blocks = k / bl; | |
123 | 26 | const size_t rhs_stride = kai_rhs_stride(k, bl); | |
124 | 52 | const size_t rhs_packed_stride = | |
125 | 26 | kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl); | |
126 | 26 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
127 | |||
128 | 26 | uint8_t* rhs_packed_ptr = rhs_packed; | |
129 | |||
130 |
2/2✓ Branch 0 taken 26 times.
✓ Branch 1 taken 1048 times.
|
1074 | for (uint64_t n_idx = 0; n_idx < n; n_idx++) { |
131 | 2096 | uint16_t* rhs_packed_scales = | |
132 | 1048 | (uint16_t*)(rhs_packed_ptr + rhs_packed_stride - (nr * num_blocks * kai_num_bytes_multiplier)); | |
133 | |||
134 |
2/2✓ Branch 0 taken 1956 times.
✓ Branch 1 taken 1048 times.
|
3004 | for (size_t block_idx = 0; block_idx < num_blocks; block_idx++) { |
135 | 1956 | uint8_t blk_s1s0[16]; | |
136 | |||
137 | 3912 | const uint16_t* blk_scale_ptr = | |
138 | 1956 | (const uint16_t*)(rhs + (block_idx * num_bytes_per_block) + n_idx * rhs_stride); | |
139 | 1956 | const uint8_t* blk_s16s0 = (const uint8_t*)blk_scale_ptr + kai_num_bytes_multiplier; | |
140 | |||
141 | 1956 | convert_s1s0_s16s0(blk_s1s0, blk_s16s0); | |
142 | |||
143 |
2/2✓ Branch 0 taken 15648 times.
✓ Branch 1 taken 1956 times.
|
17604 | for (size_t bl4_idx = 0; bl4_idx < bl / 4; bl4_idx++) { |
144 | // Uint16 holds 4 int4 values | ||
145 | 15648 | ((uint16_t*)rhs_packed_ptr)[(block_idx * bl / 4 + bl4_idx) * nr + (n_idx % nr)] = | |
146 | 15648 | ((int16_t*)blk_s1s0)[bl4_idx]; | |
147 | 15648 | } | |
148 | |||
149 | // Num. block (rows) x Nr (cols) | ||
150 | 1956 | rhs_packed_scales[(n_idx % nr) + block_idx * nr] = *blk_scale_ptr; | |
151 | 1956 | } | |
152 | |||
153 |
2/2✓ Branch 0 taken 1040 times.
✓ Branch 1 taken 8 times.
|
1048 | if (((n_idx + 1) % nr) == 0) { |
154 | 8 | rhs_packed_ptr += rhs_packed_stride; | |
155 | 8 | } | |
156 | 1048 | } | |
157 | 26 | } | |
158 | #endif // Architectural features check. | ||
159 |