KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 95.0% 95 / 12 / 112
Functions: 71.4% 5 / 0 / 7
Branches: 100.0% 28 / 26 / 54

kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h"
7
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include "kai/kai_common.h"
13
// Byte sizes of the per-column trailer fields that follow the packed 4-bit data
// in every packed block of nr columns.
static const size_t kai_num_bytes_bias = sizeof(float);            // float bias per column
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);  // float scale per column
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);       // int32 reduction sum per column
17
// Returns k rounded up to the next multiple of 32: the padded K length used by the
// packed layout of this routine.
//
// Declared `static inline` (storage-class specifier first); the former `inline static`
// ordering is obsolescent per C11 6.11.5 and triggers -Wold-style-declaration.
static inline size_t kai_k_roundedup(size_t k) {
    const size_t kai_k_multiple_of = 32;
    return kai_roundup(k, kai_k_multiple_of);
}
23
// Returns the step along N between consecutive packed blocks: each packed block
// covers exactly nr columns, so the step equals nr.
size_t kai_get_n_step_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t nr) {
    const size_t n_step = nr;
    return n_step;
}
27
// Returns the byte offset of column n_idx inside a row of the unpacked RHS.
// Two 4-bit values share one byte along N, so the offset is n_idx / 2.
// rhs_stride is not needed to locate a column and is ignored.
size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t n_idx, size_t rhs_stride) {
    KAI_UNUSED(rhs_stride);
    // The column must start on a byte boundary (an even column index).
    KAI_ASSERT((n_idx % 2) == 0);

    const size_t byte_offset = n_idx / 2;
    return byte_offset * sizeof(int8_t);
}
33
34 17664 size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) {
35 17664 KAI_UNUSED(kr);
36 17664 KAI_UNUSED(sr);
37
38 17664 const size_t k_internal = kai_k_roundedup(k);
39
40 KAI_ASSERT((k_internal % 2) == 0);
41
42 35328 return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
43 17664 }
44
// Returns the byte offset of the packed block that starts at column n_idx.
// n_idx must be the first column of a block (a multiple of nr).
size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(
    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
    KAI_ASSERT((n_idx % nr) == 0);

    const size_t block_idx = n_idx / nr;
    const size_t block_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
    return block_idx * block_stride;
}
51
// Returns the total size in bytes of the packed RHS for n columns.
// n is rounded up to whole blocks of nr columns; the last block is padded.
size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
    const size_t num_blocks = kai_roundup(n, nr) / nr;
    const size_t block_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
    return block_stride * num_blocks;
}
57
58 6672 void kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(
59 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias,
60 const float* scale, void* rhs_packed, size_t extra_bytes,
61 const struct kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params* params) {
62 KAI_ASSERT(num_groups == 1);
63 KAI_ASSERT(extra_bytes == 0);
64 KAI_ASSERT((kr % sr) == 0);
65 KAI_ASSERT(rhs != NULL);
66 KAI_ASSERT(scale != NULL);
67 KAI_ASSERT(rhs_packed != NULL);
68 KAI_ASSERT(params != NULL);
69 KAI_ASSERT(params->lhs_zero_point == 1);
70 KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8);
71
72 6672 const uint8_t rhs_zero_point = params->rhs_zero_point;
73 6672 const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
74 6672 const size_t k_internal = kai_k_roundedup(k);
75 6672 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
76 6672 const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k) / 2);
77 6672 const size_t block_length_in_bytes = kr / sr;
78 6672 const size_t k_interleaved_v = 16U;
79 6672 const size_t rhs_stride = kai_roundup(n, 2) / 2;
80
81
2/2
✓ Branch 0 taken 6672 times.
✓ Branch 1 taken 85632 times.
92304 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
82 85632 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
83
84 85632 int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2));
85
86 // Initialize to zero the RHS reduction sums
87 85632 memset(sums, 0, nr * sizeof(int32_t));
88
89
2/2
✓ Branch 0 taken 13648896 times.
✓ Branch 1 taken 85632 times.
13734528 for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) {
90 13648896 const size_t block_idx = dst_byte_idx / block_length_in_bytes;
91 13648896 const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes;
92 13648896 const size_t super_block_idx = block_idx / nr;
93 13648896 const size_t nr_idx = block_idx % nr;
94
95 27297792 const size_t k_adjustment =
96 13648896 ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v;
97 13648896 const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment;
98 13648896 const size_t k1_idx = k0_idx + k_interleaved_v;
99 13648896 const size_t n0_idx = dst_row_idx * nr + nr_idx;
100
101 // Clamp the index to avoid out-of-bound reads
102
2/2
✓ Branch 0 taken 13100544 times.
✓ Branch 1 taken 548352 times.
13648896 const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
103
104 13648896 const size_t src_addr_byte0 = (n0_valid_idx / 2) + k0_idx * rhs_stride;
105 13648896 const size_t src_addr_byte1 = (n0_valid_idx / 2) + k1_idx * rhs_stride;
106
107 13648896 const size_t shift_right_x0 = (n0_idx % 2) * 4;
108
109
2/2
✓ Branch 0 taken 5523456 times.
✓ Branch 1 taken 8125440 times.
13648896 if (rhs_zero_point == 8) {
110 5523456 uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4;
111 5523456 uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4;
112
113
2/2
✓ Branch 0 taken 348480 times.
✓ Branch 1 taken 5174976 times.
5523456 if (k0_idx < k) {
114 5174976 byte0 = rhs[src_addr_byte0];
115 5174976 }
116
117
2/2
✓ Branch 0 taken 1274592 times.
✓ Branch 1 taken 4248864 times.
5523456 if (k1_idx < k) {
118 4248864 byte1 = rhs[src_addr_byte1];
119 4248864 }
120
121 // The following operations where we extract the values from the bytes
122 // can be also written in the following and less efficient manner:
123 /*
124 uint8_t src_x0_lo = 0;
125 uint8_t src_x0_hi = 0;
126
127 if ((n0_idx % 2) == 0) {
128 src_x0_lo = (byte0 & 0x0F);
129 } else {
130 src_x0_lo = (byte0 >> 4);
131 }
132
133 if ((n0_idx % 2) == 0) {
134 src_x0_hi = (byte1 & 0x0F);
135 } else {
136 src_x0_hi = (byte1 >> 4);
137 }
138 */
139
140 5523456 const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
141 5523456 const uint8_t src_x0_hi = (byte1 >> shift_right_x0) & 0x0F;
142
143 5523456 sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point;
144
145 5523456 const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
146
147 5523456 *dst_row = dst_qs0 ^ 0x88;
148 5523456 dst_row += sizeof(uint8_t);
149 5523456 } else {
150 8125440 int8_t byte0 = 0;
151 8125440 int8_t byte1 = 0;
152
153 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
154
2/2
✓ Branch 0 taken 593088 times.
✓ Branch 1 taken 7532352 times.
8125440 if (k0_idx < k) {
155 7532352 byte0 = rhs[src_addr_byte0];
156 7532352 }
157
158
2/2
✓ Branch 0 taken 1571616 times.
✓ Branch 1 taken 6553824 times.
8125440 if (k1_idx < k) {
159 6553824 byte1 = rhs[src_addr_byte1];
160 6553824 }
161
162 // The logic behind the following operations where we extract the
163 // values from the bytes is same as unsigned
164
165 8125440 int8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
166 8125440 int8_t src_x0_hi = (byte1 >> shift_right_x0) & 0x0F;
167
168 8125440 const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
169 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
170
171 8125440 *(int8_t*)dst_row = dst_qs0;
172 8125440 dst_row += sizeof(int8_t);
173
174 8125440 src_x0_lo = kai_ext_sign_i8_i4(src_x0_lo);
175 8125440 src_x0_hi = kai_ext_sign_i8_i4(src_x0_hi);
176 8125440 sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi;
177 8125440 }
178 13648896 }
179
180 // Adjust the reduction sums
181
2/2
✓ Branch 0 taken 472896 times.
✓ Branch 1 taken 85632 times.
558528 for (size_t i = 0; i < nr; ++i) {
182 472896 sums[i] = sums[i] * 16;
183 472896 dst_row += sizeof(int32_t);
184 472896 }
185
186 // Adjust the scales
187
2/2
✓ Branch 0 taken 472896 times.
✓ Branch 1 taken 85632 times.
558528 for (size_t i = 0; i < nr; ++i) {
188 // Clamp the row index to avoid out-of-bound reads
189
2/2
✓ Branch 0 taken 453768 times.
✓ Branch 1 taken 19128 times.
472896 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
190 472896 *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F;
191 472896 dst_row += sizeof(float);
192 472896 }
193
194 // Set the bias
195
2/2
✓ Branch 0 taken 80088 times.
✓ Branch 1 taken 5544 times.
85632 if (bias == NULL) {
196 5544 memset(dst_row, 0, nr * kai_num_bytes_bias);
197 5544 } else {
198
2/2
✓ Branch 0 taken 428544 times.
✓ Branch 1 taken 80088 times.
508632 for (size_t i = 0; i < nr; ++i) {
199 // Clamp the row index to avoid out-of-bound reads
200
2/2
✓ Branch 0 taken 414372 times.
✓ Branch 1 taken 14172 times.
428544 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
201 428544 ((float*)dst_row)[i] = bias[src_row_idx];
202 428544 }
203 }
204 85632 }
205 6672 }
206