KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 95.0% 95 12 112
Functions: 71.4% 5 0 7
Branches: 100.0% 28 26 54

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h"
7
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include "kai/kai_common.h"
13
// Size in bytes of each per-column field that follows the packed int4 values
// in every nr-column block of the packed RHS.
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);       // reduction sum
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);  // scale
static const size_t kai_num_bytes_bias = sizeof(float);            // bias
17
// Returns k rounded up to the next multiple of 32, the k granularity used by the packed layout.
static inline size_t kai_k_roundedup(size_t k) {
    const size_t k_granularity = 32;
    return kai_roundup(k, k_granularity);
}
23
// The RHS is packed in groups of nr columns, so the step along N equals nr.
size_t kai_get_n_step_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t nr) {
    const size_t n_step = nr;
    return n_step;
}
27
// Byte offset of column n_idx within a row of the unpacked K x N RHS.
// Two int4 values share one byte along N, so n_idx must be even.
// The offset is the same for every row, so rhs_stride is not needed.
size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t n_idx, size_t rhs_stride) {
    KAI_UNUSED(rhs_stride);
    KAI_ASSERT((n_idx % 2) == 0);

    const size_t num_bytes = n_idx / 2;
    return num_bytes * sizeof(int8_t);
}
33
34 2904 size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) {
35 2904 KAI_UNUSED(kr);
36 2904 KAI_UNUSED(sr);
37
38 2904 const size_t k_internal = kai_k_roundedup(k);
39
40 KAI_ASSERT((k_internal % 2) == 0);
41
42 5808 return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
43 2904 }
44
// Byte offset in the packed RHS of the nr-column block that starts at column n_idx.
// n_idx must be a multiple of nr.
size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(
    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
    KAI_ASSERT((n_idx % nr) == 0);

    const size_t block_idx = n_idx / nr;
    const size_t block_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
    return block_idx * block_stride;
}
51
// Total size in bytes of the packed RHS.
// n is padded up to a multiple of nr, and every group of nr columns occupies
// one packed stride.
size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
    const size_t num_blocks = kai_roundup(n, nr) / nr;
    const size_t block_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
    return num_blocks * block_stride;
}
57
58 1092 void kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(
59 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias,
60 const float* scale, void* rhs_packed, size_t extra_bytes,
61 const struct kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params* params) {
62 KAI_ASSERT(num_groups == 1);
63 KAI_ASSERT(extra_bytes == 0);
64 KAI_ASSERT((kr % sr) == 0);
65 KAI_ASSERT(rhs != NULL);
66 KAI_ASSERT(scale != NULL);
67 KAI_ASSERT(rhs_packed != NULL);
68 KAI_ASSERT(params != NULL);
69 KAI_ASSERT(params->lhs_zero_point == 1);
70 KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8);
71
72 1092 const uint8_t rhs_zero_point = params->rhs_zero_point;
73 1092 const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
74 1092 const size_t k_internal = kai_k_roundedup(k);
75 1092 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
76 1092 const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k) / 2);
77 1092 const size_t block_length_in_bytes = kr / sr;
78 1092 const size_t k_interleaved_v = 16U;
79 1092 const size_t rhs_stride = kai_roundup(n, 2) / 2;
80
81
2/2
✓ Branch 0 taken 1092 times.
✓ Branch 1 taken 14252 times.
15344 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
82 14252 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
83
84 14252 int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2));
85
86 // Initialize to zero the RHS reduction sums
87 14252 memset(sums, 0, nr * sizeof(int32_t));
88
89
2/2
✓ Branch 0 taken 2272256 times.
✓ Branch 1 taken 14252 times.
2286508 for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) {
90 2272256 const size_t block_idx = dst_byte_idx / block_length_in_bytes;
91 2272256 const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes;
92 2272256 const size_t super_block_idx = block_idx / nr;
93 2272256 const size_t nr_idx = block_idx % nr;
94
95 4544512 const size_t k_adjustment =
96 2272256 ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v;
97 2272256 const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment;
98 2272256 const size_t k1_idx = k0_idx + k_interleaved_v;
99 2272256 const size_t n0_idx = dst_row_idx * nr + nr_idx;
100
101 // Clamp the index to avoid out-of-bound reads
102
2/2
✓ Branch 0 taken 2182720 times.
✓ Branch 1 taken 89536 times.
2272256 const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
103
104 2272256 const size_t src_addr_byte0 = (n0_valid_idx / 2) + k0_idx * rhs_stride;
105 2272256 const size_t src_addr_byte1 = (n0_valid_idx / 2) + k1_idx * rhs_stride;
106
107 2272256 const size_t shift_right_x0 = (n0_idx % 2) * 4;
108
109
2/2
✓ Branch 0 taken 920576 times.
✓ Branch 1 taken 1351680 times.
2272256 if (rhs_zero_point == 8) {
110 920576 uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4;
111 920576 uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4;
112
113
2/2
✓ Branch 0 taken 58080 times.
✓ Branch 1 taken 862496 times.
920576 if (k0_idx < k) {
114 862496 byte0 = rhs[src_addr_byte0];
115 862496 }
116
117
2/2
✓ Branch 0 taken 212432 times.
✓ Branch 1 taken 708144 times.
920576 if (k1_idx < k) {
118 708144 byte1 = rhs[src_addr_byte1];
119 708144 }
120
121 // The following operations where we extract the values from the bytes
122 // can be also written in the following and less efficient manner:
123 /*
124 uint8_t src_x0_lo = 0;
125 uint8_t src_x0_hi = 0;
126
127 if ((n0_idx % 2) == 0) {
128 src_x0_lo = (byte0 & 0x0F);
129 } else {
130 src_x0_lo = (byte0 >> 4);
131 }
132
133 if ((n0_idx % 2) == 0) {
134 src_x0_hi = (byte1 & 0x0F);
135 } else {
136 src_x0_hi = (byte1 >> 4);
137 }
138 */
139
140 920576 const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
141 920576 const uint8_t src_x0_hi = (byte1 >> shift_right_x0) & 0x0F;
142
143 920576 sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point;
144
145 920576 const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
146
147 920576 *dst_row = dst_qs0 ^ 0x88;
148 920576 dst_row += sizeof(uint8_t);
149 920576 } else {
150 1351680 int8_t byte0 = 0;
151 1351680 int8_t byte1 = 0;
152
153 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
154
2/2
✓ Branch 0 taken 98848 times.
✓ Branch 1 taken 1252832 times.
1351680 if (k0_idx < k) {
155 1252832 byte0 = rhs[src_addr_byte0];
156 1252832 }
157
158
2/2
✓ Branch 0 taken 261936 times.
✓ Branch 1 taken 1089744 times.
1351680 if (k1_idx < k) {
159 1089744 byte1 = rhs[src_addr_byte1];
160 1089744 }
161
162 // The logic behind the following operations where we extract the
163 // values from the bytes is same as unsigned
164
165 1351680 int8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
166 1351680 int8_t src_x0_hi = (byte1 >> shift_right_x0) & 0x0F;
167
168 1351680 const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
169 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
170
171 1351680 *(int8_t*)dst_row = dst_qs0;
172 1351680 dst_row += sizeof(int8_t);
173
174 1351680 src_x0_lo = kai_ext_sign_i8_i4(src_x0_lo);
175 1351680 src_x0_hi = kai_ext_sign_i8_i4(src_x0_hi);
176 1351680 sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi;
177 1351680 }
178 2272256 }
179
180 // Adjust the reduction sums
181
2/2
✓ Branch 0 taken 78656 times.
✓ Branch 1 taken 14252 times.
92908 for (size_t i = 0; i < nr; ++i) {
182 78656 sums[i] = sums[i] * 16;
183 78656 dst_row += sizeof(int32_t);
184 78656 }
185
186 // Adjust the scales
187
2/2
✓ Branch 0 taken 78656 times.
✓ Branch 1 taken 14252 times.
92908 for (size_t i = 0; i < nr; ++i) {
188 // Clamp the row index to avoid out-of-bound reads
189
2/2
✓ Branch 0 taken 75584 times.
✓ Branch 1 taken 3072 times.
78656 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
190 78656 *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F;
191 78656 dst_row += sizeof(float);
192 78656 }
193
194 // Set the bias
195
2/2
✓ Branch 0 taken 914 times.
✓ Branch 1 taken 13338 times.
14252 if (bias == NULL) {
196 914 memset(dst_row, 0, nr * kai_num_bytes_bias);
197 914 } else {
198
2/2
✓ Branch 0 taken 71344 times.
✓ Branch 1 taken 13338 times.
84682 for (size_t i = 0; i < nr; ++i) {
199 // Clamp the row index to avoid out-of-bound reads
200
2/2
✓ Branch 0 taken 69040 times.
✓ Branch 1 taken 2304 times.
71344 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
201 71344 ((float*)dst_row)[i] = bias[src_row_idx];
202 71344 }
203 }
204 14252 }
205 1092 }
206