KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 98.0% 98 / 11 / 111
Functions: 85.7% 6 / 0 / 7
Branches: 100.0% 28 / 24 / 52

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"
7
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include "kai/kai_common.h"
13
14 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
15 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
16 static const size_t kai_num_bytes_bias = sizeof(float);
17
18 58512 inline static size_t kai_k_roundedup(size_t k) {
19 // Round up k to be a multiple of 32.
20 58512 size_t kai_k_multiple_of = 32;
21 117024 return kai_roundup(k, kai_k_multiple_of);
22 58512 }
23
24 size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t nr) {
25 return nr;
26 }
27
28 4320 size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n_idx, size_t rhs_stride) {
29 4320 return n_idx * rhs_stride;
30 }
31
32 37776 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) {
33 37776 KAI_UNUSED(kr);
34 37776 KAI_UNUSED(sr);
35 37776 const size_t k_internal = kai_k_roundedup(k);
36
37 KAI_ASSERT((k_internal % 2) == 0);
38
39 75552 return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
40 37776 }
41
42 14688 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
43 size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
44 KAI_ASSERT((n_idx % nr) == 0);
45
46 14688 return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
47 }
48
49 12720 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
50 12720 const size_t num_rows = kai_roundup(n, nr) / nr;
51
52 25440 return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
53 12720 }
54
55 10368 void kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
56 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias,
57 const float* scale, void* rhs_packed, size_t extra_bytes,
58 const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) {
59 KAI_ASSERT(num_groups == 1);
60 KAI_ASSERT(extra_bytes == 0);
61 KAI_ASSERT((kr % sr) == 0);
62 KAI_ASSERT(rhs != NULL);
63 KAI_ASSERT(scale != NULL);
64 KAI_ASSERT(rhs_packed != NULL);
65 KAI_ASSERT(params != NULL);
66 KAI_ASSERT(params->lhs_zero_point == 1);
67 KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8);
68
69 10368 const uint8_t rhs_zero_point = params->rhs_zero_point;
70 10368 const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
71 10368 const size_t k_internal = kai_k_roundedup(k);
72 10368 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
73 10368 const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k) / 2);
74 10368 const size_t block_length_in_bytes = kr / sr;
75 10368 const size_t k_interleaved_v = 16U;
76 10368 const size_t rhs_stride = kai_roundup(k, 2) / 2;
77
78
2/2
✓ Branch 0 taken 10368 times.
✓ Branch 1 taken 81972 times.
92340 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
79 81972 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
80
81 81972 int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2));
82
83 // Initialize to zero the RHS reduction sums
84 81972 memset(sums, 0, nr * sizeof(int32_t));
85
86
2/2
✓ Branch 0 taken 11636736 times.
✓ Branch 1 taken 81972 times.
11718708 for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) {
87 11636736 const size_t block_idx = dst_byte_idx / block_length_in_bytes;
88 11636736 const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes;
89 11636736 const size_t super_block_idx = block_idx / nr;
90 11636736 const size_t nr_idx = block_idx % nr;
91
92 23273472 const size_t k_adjustment =
93 11636736 ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v;
94 11636736 const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment;
95 11636736 const size_t k1_idx = k0_idx + k_interleaved_v;
96 11636736 const size_t n0_idx = dst_row_idx * nr + nr_idx;
97
98 // Clamp the index to avoid out-of-bound reads
99
2/2
✓ Branch 0 taken 10970496 times.
✓ Branch 1 taken 666240 times.
11636736 const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
100
101 11636736 const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride;
102 11636736 const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride;
103
104 11636736 const size_t shift_right_x0 = (k0_idx % 2) * 4;
105 11636736 const size_t shift_right_x1 = (k1_idx % 2) * 4;
106
107
2/2
✓ Branch 0 taken 2877696 times.
✓ Branch 1 taken 8759040 times.
11636736 if (rhs_zero_point == 8) {
108 2877696 uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4;
109 2877696 uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4;
110
111
2/2
✓ Branch 0 taken 194832 times.
✓ Branch 1 taken 2682864 times.
2877696 if (k0_idx < k) {
112 2682864 byte0 = rhs[src_addr_byte0];
113 2682864 }
114
115
2/2
✓ Branch 0 taken 676896 times.
✓ Branch 1 taken 2200800 times.
2877696 if (k1_idx < k) {
116 2200800 byte1 = rhs[src_addr_byte1];
117 2200800 }
118
119 // The following operations where we extract the values from the bytes
120 // can be also written in the following and less efficient manner:
121 /*
122 uint8_t src_x0_lo = 0;
123 uint8_t src_x0_hi = 0;
124
125 if ((k0_idx % 2) == 0) {
126 src_x0_lo = (byte0 & 0x0F);
127 } else {
128 src_x0_lo = (byte0 >> 4);
129 }
130
131 if ((k1_idx % 2) == 0) {
132 src_x0_hi = (byte1 & 0x0F);
133 } else {
134 src_x0_hi = (byte1 >> 4);
135 }
136 */
137 2877696 const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
138 2877696 const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F;
139
140 2877696 sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point;
141
142 2877696 const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
143
144 2877696 *dst_row = dst_qs0 ^ 0x88;
145 2877696 dst_row += sizeof(uint8_t);
146 2877696 } else {
147 8759040 int8_t byte0 = 0;
148 8759040 int8_t byte1 = 0;
149
150 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
151
2/2
✓ Branch 0 taken 792240 times.
✓ Branch 1 taken 7966800 times.
8759040 if (k0_idx < k) {
152 7966800 byte0 = rhs[src_addr_byte0];
153 7966800 }
154
155
2/2
✓ Branch 0 taken 1440288 times.
✓ Branch 1 taken 7318752 times.
8759040 if (k1_idx < k) {
156 7318752 byte1 = rhs[src_addr_byte1];
157 7318752 }
158
159 // The logic behind the following operations where we extract the
160 // values from the bytes is same as unsigned
161
162 8759040 int8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
163 8759040 int8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F;
164
165 8759040 const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
166 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
167
168 8759040 *(int8_t*)dst_row = dst_qs0;
169 8759040 dst_row += sizeof(int8_t);
170
171 8759040 src_x0_lo = kai_ext_sign_i8_i4(src_x0_lo);
172 8759040 src_x0_hi = kai_ext_sign_i8_i4(src_x0_hi);
173 8759040 sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi;
174 8759040 }
175 11636736 }
176
177 // Adjust the reduction sums
178
2/2
✓ Branch 0 taken 417168 times.
✓ Branch 1 taken 81972 times.
499140 for (size_t i = 0; i < nr; ++i) {
179 417168 sums[i] = sums[i] * 16;
180 417168 dst_row += sizeof(int32_t);
181 417168 }
182
183 // Adjust the scales
184
2/2
✓ Branch 0 taken 417168 times.
✓ Branch 1 taken 81972 times.
499140 for (size_t i = 0; i < nr; ++i) {
185 // Clamp the row index to avoid out-of-bound reads
186
2/2
✓ Branch 0 taken 392040 times.
✓ Branch 1 taken 25128 times.
417168 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
187 417168 *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F;
188 417168 dst_row += sizeof(float);
189 417168 }
190
191 // Set the bias
192
2/2
✓ Branch 0 taken 60132 times.
✓ Branch 1 taken 21840 times.
81972 if (bias == NULL) {
193 21840 memset(dst_row, 0, nr * kai_num_bytes_bias);
194 21840 } else {
195
2/2
✓ Branch 0 taken 307632 times.
✓ Branch 1 taken 60132 times.
367764 for (size_t i = 0; i < nr; ++i) {
196 // Clamp the row index to avoid out-of-bound reads
197
2/2
✓ Branch 0 taken 291072 times.
✓ Branch 1 taken 16560 times.
307632 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
198 307632 ((float*)dst_row)[i] = bias[src_row_idx];
199 307632 }
200 }
201 81972 }
202 10368 }
203