KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.0% 98 11 111
Functions: 85.7% 6 0 7
Branches: 100.0% 28 24 52

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"
7
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include "kai/kai_common.h"
13
14 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
15 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
16 static const size_t kai_num_bytes_bias = sizeof(float);
17
18 9352 inline static size_t kai_k_roundedup(size_t k) {
19 // Round up k to be a multiple of 32.
20 9352 size_t kai_k_multiple_of = 32;
21 18704 return kai_roundup(k, kai_k_multiple_of);
22 9352 }
23
24 size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t nr) {
25 return nr;
26 }
27
28 720 size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n_idx, size_t rhs_stride) {
29 720 return n_idx * rhs_stride;
30 }
31
32 6048 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) {
33 6048 KAI_UNUSED(kr);
34 6048 KAI_UNUSED(sr);
35 6048 const size_t k_internal = kai_k_roundedup(k);
36
37 KAI_ASSERT((k_internal % 2) == 0);
38
39 12096 return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
40 6048 }
41
42 2372 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
43 size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
44 KAI_ASSERT((n_idx % nr) == 0);
45
46 2372 return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
47 }
48
49 2024 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
50 2024 const size_t num_rows = kai_roundup(n, nr) / nr;
51
52 4048 return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
53 2024 }
54
55 1652 void kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
56 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias,
57 const float* scale, void* rhs_packed, size_t extra_bytes,
58 const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) {
59 KAI_ASSERT(num_groups == 1);
60 KAI_ASSERT(extra_bytes == 0);
61 KAI_ASSERT((kr % sr) == 0);
62 KAI_ASSERT(rhs != NULL);
63 KAI_ASSERT(scale != NULL);
64 KAI_ASSERT(rhs_packed != NULL);
65 KAI_ASSERT(params != NULL);
66 KAI_ASSERT(params->lhs_zero_point == 1);
67 KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8);
68
69 1652 const uint8_t rhs_zero_point = params->rhs_zero_point;
70 1652 const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr);
71 1652 const size_t k_internal = kai_k_roundedup(k);
72 1652 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
73 1652 const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k) / 2);
74 1652 const size_t block_length_in_bytes = kr / sr;
75 1652 const size_t k_interleaved_v = 16U;
76 1652 const size_t rhs_stride = kai_roundup(k, 2) / 2;
77
78
2/2
✓ Branch 0 taken 1652 times.
✓ Branch 1 taken 13230 times.
14882 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
79 13230 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
80
81 13230 int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2));
82
83 // Zero-initialize the RHS reduction sums
84 13230 memset(sums, 0, nr * sizeof(int32_t));
85
86
2/2
✓ Branch 0 taken 1896960 times.
✓ Branch 1 taken 13230 times.
1910190 for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) {
87 1896960 const size_t block_idx = dst_byte_idx / block_length_in_bytes;
88 1896960 const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes;
89 1896960 const size_t super_block_idx = block_idx / nr;
90 1896960 const size_t nr_idx = block_idx % nr;
91
92 3793920 const size_t k_adjustment =
93 1896960 ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v;
94 1896960 const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment;
95 1896960 const size_t k1_idx = k0_idx + k_interleaved_v;
96 1896960 const size_t n0_idx = dst_row_idx * nr + nr_idx;
97
98 // Clamp the index to avoid out-of-bounds reads
99
2/2
✓ Branch 0 taken 1790336 times.
✓ Branch 1 taken 106624 times.
1896960 const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
100
101 1896960 const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride;
102 1896960 const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride;
103
104 1896960 const size_t shift_right_x0 = (k0_idx % 2) * 4;
105 1896960 const size_t shift_right_x1 = (k1_idx % 2) * 4;
106
107
2/2
✓ Branch 0 taken 476288 times.
✓ Branch 1 taken 1420672 times.
1896960 if (rhs_zero_point == 8) {
108 476288 uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4;
109 476288 uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4;
110
111
2/2
✓ Branch 0 taken 32316 times.
✓ Branch 1 taken 443972 times.
476288 if (k0_idx < k) {
112 443972 byte0 = rhs[src_addr_byte0];
113 443972 }
114
115
2/2
✓ Branch 0 taken 111516 times.
✓ Branch 1 taken 364772 times.
476288 if (k1_idx < k) {
116 364772 byte1 = rhs[src_addr_byte1];
117 364772 }
118
119 // The following operations, which extract the 4-bit values from the bytes,
120 // can also be written in the following, less efficient, manner:
121 /*
122 uint8_t src_x0_lo = 0;
123 uint8_t src_x0_hi = 0;
124
125 if ((k0_idx % 2) == 0) {
126 src_x0_lo = (byte0 & 0x0F);
127 } else {
128 src_x0_lo = (byte0 >> 4);
129 }
130
131 if ((k1_idx % 2) == 0) {
132 src_x0_hi = (byte1 & 0x0F);
133 } else {
134 src_x0_hi = (byte1 >> 4);
135 }
136 */
137 476288 const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
138 476288 const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F;
139
140 476288 sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point;
141
142 476288 const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
143
144 476288 *dst_row = dst_qs0 ^ 0x88;
145 476288 dst_row += sizeof(uint8_t);
146 476288 } else {
147 1420672 int8_t byte0 = 0;
148 1420672 int8_t byte1 = 0;
149
150 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
151
2/2
✓ Branch 0 taken 127564 times.
✓ Branch 1 taken 1293108 times.
1420672 if (k0_idx < k) {
152 1293108 byte0 = rhs[src_addr_byte0];
153 1293108 }
154
155
2/2
✓ Branch 0 taken 234076 times.
✓ Branch 1 taken 1186596 times.
1420672 if (k1_idx < k) {
156 1186596 byte1 = rhs[src_addr_byte1];
157 1186596 }
158
159 // The logic behind the following operations, which extract the
160 // values from the bytes, is the same as in the unsigned case above
161
162 1420672 int8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
163 1420672 int8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F;
164
165 1420672 const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
166 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
167
168 1420672 *(int8_t*)dst_row = dst_qs0;
169 1420672 dst_row += sizeof(int8_t);
170
171 1420672 src_x0_lo = kai_ext_sign_i8_i4(src_x0_lo);
172 1420672 src_x0_hi = kai_ext_sign_i8_i4(src_x0_hi);
173 1420672 sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi;
174 1420672 }
175 1896960 }
176
177 // Adjust the reduction sums
178
2/2
✓ Branch 0 taken 67656 times.
✓ Branch 1 taken 13230 times.
80886 for (size_t i = 0; i < nr; ++i) {
179 67656 sums[i] = sums[i] * 16;
180 67656 dst_row += sizeof(int32_t);
181 67656 }
182
183 // Adjust the scales
184
2/2
✓ Branch 0 taken 67656 times.
✓ Branch 1 taken 13230 times.
80886 for (size_t i = 0; i < nr; ++i) {
185 // Clamp the row index to avoid out-of-bounds reads
186
2/2
✓ Branch 0 taken 63712 times.
✓ Branch 1 taken 3944 times.
67656 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
187 67656 *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F;
188 67656 dst_row += sizeof(float);
189 67656 }
190
191 // Set the bias
192
2/2
✓ Branch 0 taken 3442 times.
✓ Branch 1 taken 9788 times.
13230 if (bias == NULL) {
193 3442 memset(dst_row, 0, nr * kai_num_bytes_bias);
194 3442 } else {
195
2/2
✓ Branch 0 taken 50232 times.
✓ Branch 1 taken 9788 times.
60020 for (size_t i = 0; i < nr; ++i) {
196 // Clamp the row index to avoid out-of-bounds reads
197
2/2
✓ Branch 0 taken 47594 times.
✓ Branch 1 taken 2638 times.
50232 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
198 50232 ((float*)dst_row)[i] = bias[src_row_idx];
199 50232 }
200 }
201 13230 }
202 1652 }
203