Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" | ||
7 | |||
8 | #include <stddef.h> | ||
9 | #include <stdint.h> | ||
10 | #include <string.h> | ||
11 | |||
12 | #include "kai/kai_common.h" | ||
13 | |||
// Size in bytes of one stored scale value (sizeof(uint16_t)). It is added to the byte count of
// every block and is used as the copy size when scales are transferred by the packing routine.
14 | static const size_t kai_num_bytes_multiplier = sizeof(uint16_t); | ||
// The only block length accepted in this file: every function that takes "bl" checks bl == kai_bl.
15 | static const size_t kai_bl = 32; | ||
16 | |||
17 | 628 | inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { | |
18 | − | KAI_ASSUME((k % 2) == 0); | |
19 | − | KAI_ASSUME(bl == kai_bl); | |
20 | 628 | return kai_roundup(k, bl) / bl; | |
21 | } | ||
22 | |||
23 | 530 | inline static size_t kai_num_bytes_per_block(size_t bl) { | |
24 | − | KAI_ASSUME(bl == kai_bl); | |
25 | 530 | return (bl / 2) + kai_num_bytes_multiplier; | |
26 | } | ||
27 | |||
28 | 98 | inline static size_t kai_rhs_stride(size_t k, size_t bl) { | |
29 | − | KAI_ASSUME(bl == kai_bl); | |
30 | − | KAI_ASSUME((k % 2) == 0); | |
31 | − | KAI_ASSUME((k % bl) == 0); | |
32 | |||
33 | 98 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
34 | 98 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
35 | |||
36 | 196 | return num_bytes_per_block * num_blocks_per_row; | |
37 | 98 | } | |
38 | |||
39 | 432 | size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(size_t k, size_t nr, size_t kr, size_t bl) { | |
40 | − | KAI_ASSUME(bl == kai_bl); | |
41 | − | KAI_ASSUME((k % 2) == 0); | |
42 | − | KAI_ASSUME((k % kr) == 0); | |
43 | − | KAI_ASSUME((k % bl) == 0); | |
44 | |||
45 | 432 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
46 | 432 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
47 | |||
48 | 864 | return nr * (num_bytes_per_block * num_blocks_per_row); | |
49 | 432 | } | |
50 | |||
// Returns the byte offset of row n_idx in the source RHS matrix.
//
// n_idx      : row index.
// rhs_stride : number of bytes between two consecutive rows.
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(size_t n_idx, size_t rhs_stride) {
    const size_t offset = rhs_stride * n_idx;
    return offset;
}
54 | |||
55 | 236 | size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( | |
56 | size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) { | ||
57 | − | KAI_ASSUME(bl == kai_bl); | |
58 | − | KAI_ASSUME((k % 2) == 0); | |
59 | − | KAI_ASSUME((k % kr) == 0); | |
60 | − | KAI_ASSUME((k % bl) == 0); | |
61 | − | KAI_ASSUME((n_idx % nr) == 0); | |
62 | |||
63 | 236 | return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl); | |
64 | } | ||
65 | |||
66 | 98 | size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( | |
67 | size_t n, size_t k, size_t nr, size_t kr, size_t bl) { | ||
68 | − | KAI_ASSUME(bl == kai_bl); | |
69 | − | KAI_ASSUME((k % 2) == 0); | |
70 | − | KAI_ASSUME((k % kr) == 0); | |
71 | − | KAI_ASSUME((k % bl) == 0); | |
72 | |||
73 | 98 | const size_t num_rows = kai_roundup(n, nr) / nr; | |
74 | |||
75 | 196 | return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl); | |
76 | 98 | } | |
77 | |||
78 | 98 | void kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( | |
79 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, | ||
80 | const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params) { | ||
81 | − | KAI_ASSUME(bl == kai_bl); | |
82 | − | KAI_ASSUME(num_groups == 1); | |
83 | − | KAI_ASSUME((k % 2) == 0); | |
84 | − | KAI_ASSUME((k % kr) == 0); | |
85 | − | KAI_ASSUME((k % bl) == 0); | |
86 | − | KAI_ASSUME(bias == NULL); | |
87 | − | KAI_ASSUME(extra_bytes == 0); | |
88 | |||
89 | − | KAI_ASSUME(sr == 2); | |
90 | − | KAI_ASSUME(kr >= 1 && kr <= 16); | |
91 | − | KAI_ASSUME(rhs != NULL); | |
92 | − | KAI_ASSUME(rhs_packed != NULL); | |
93 | − | KAI_ASSUME(params != NULL); | |
94 | − | KAI_ASSUME(params->rhs_zero_point == 8); | |
95 | − | KAI_ASSUME(params->lhs_zero_point == 1); | |
96 | |||
97 | // Note: The input matrix (rhs) is expected with: | ||
98 | // "k" columns and "n" rows (NxK) | ||
99 | |||
100 | 98 | const size_t rhs_stride = kai_rhs_stride(k, bl); | |
101 | 196 | const size_t rhs_packed_stride = | |
102 | 98 | kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl); | |
103 | 98 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
104 | 98 | const size_t num_segments_per_block = bl / kr; | |
105 | 98 | const size_t num_bytes_per_segment = kr / 2; | |
106 | |||
107 |
2/2✓ Branch 0 taken 98 times.
✓ Branch 1 taken 1058 times.
|
1156 | for (size_t y = 0; y < n; y += nr) { |
108 | 1058 | const uint8_t* src_row = rhs; | |
109 | 1058 | uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride; | |
110 | |||
111 |
2/2✓ Branch 0 taken 1970 times.
✓ Branch 1 taken 1058 times.
|
3028 | for (size_t x = 0; x < num_blocks_per_row; ++x) { |
112 | // Store the scales at the end of the block | ||
113 | 1970 | uint8_t* scales = (dst_row); | |
114 | |||
115 |
2/2✓ Branch 0 taken 7880 times.
✓ Branch 1 taken 1970 times.
|
9850 | for (size_t i = 0; i < nr; ++i) { |
116 |
2/2✓ Branch 0 taken 7650 times.
✓ Branch 1 taken 230 times.
|
7880 | const size_t src_row_idx = KAI_MIN(y + i, n - 1); |
117 | 7880 | memcpy( | |
118 | scales + i * kai_num_bytes_multiplier, src_row + src_row_idx * rhs_stride, | ||
119 | kai_num_bytes_multiplier); | ||
120 | 7880 | } | |
121 | 1970 | src_row += kai_num_bytes_multiplier; | |
122 | |||
123 | 1970 | dst_row += (kai_num_bytes_multiplier * nr); | |
124 | |||
125 | // Store the segments | ||
126 |
2/2✓ Branch 0 taken 4928 times.
✓ Branch 1 taken 1970 times.
|
6898 | for (size_t s = 0; s < num_segments_per_block; ++s) { |
127 |
2/2✓ Branch 0 taken 19712 times.
✓ Branch 1 taken 4928 times.
|
24640 | for (size_t i = 0; i < nr; ++i) { |
128 |
2/2✓ Branch 0 taken 19128 times.
✓ Branch 1 taken 584 times.
|
19712 | const size_t src_row_idx = KAI_MIN(y + i, n - 1); |
129 | |||
130 |
2/2✓ Branch 0 taken 7904 times.
✓ Branch 1 taken 11808 times.
|
19712 | if (num_bytes_per_segment == sizeof(uint32_t)) { |
131 | 7904 | uint32_t tmp = 0; | |
132 | 7904 | memcpy(&tmp, src_row + src_row_idx * rhs_stride, num_bytes_per_segment); | |
133 | 7904 | tmp = tmp ^ 0x88888888; | |
134 | 7904 | memcpy(dst_row + i * num_bytes_per_segment, &tmp, num_bytes_per_segment); | |
135 |
1/2✓ Branch 0 taken 11808 times.
✗ Branch 1 not taken.
|
19712 | } else if (num_bytes_per_segment == sizeof(uint64_t)) { |
136 | 11808 | uint64_t tmp = 0; | |
137 | 11808 | memcpy(&tmp, src_row + src_row_idx * rhs_stride, num_bytes_per_segment); | |
138 | 11808 | tmp = tmp ^ 0x8888888888888888ULL; | |
139 | 11808 | memcpy(dst_row + i * num_bytes_per_segment, &tmp, num_bytes_per_segment); | |
140 | 11808 | } else { | |
141 | ✗ | memcpy( | |
142 | dst_row + i * num_bytes_per_segment, src_row + src_row_idx * rhs_stride, | ||
143 | num_bytes_per_segment); | ||
144 | |||
145 | ✗ | for (size_t b = 0; b < num_bytes_per_segment; ++b) { | |
146 | ✗ | uint8_t qs = dst_row[i * num_bytes_per_segment + b]; | |
147 | // Add offset (0x88) | ||
148 | ✗ | dst_row[i * num_bytes_per_segment + b] = qs ^ 0x88; | |
149 | ✗ | } | |
150 | } | ||
151 | 19712 | } | |
152 | |||
153 | 4928 | src_row += num_bytes_per_segment; | |
154 | 4928 | dst_row += num_bytes_per_segment * nr; | |
155 | 4928 | } | |
156 | 1970 | } | |
157 | 1058 | } | |
158 | 98 | } | |
159 |