KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 88 / 12 / 100
Functions: 100.0% 7 / 0 / 7
Branches: 100.0% 32 / 26 / 58

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if !defined(__aarch64__) && !defined(_M_ARM64)
7 #error This file must be compiled for AArch64.
8 #else // Architectural features check.
9
10 #include "kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h"
11
12 #include <stdint.h>
13 #include <string.h>
14
15 #include "kai/kai_common.h"
16
// Sizes, in bytes, of the three per-row trailing fields that follow the packed 4-bit data
// of every nr-row block: an int32 row sum, an fp32 scale (multiplier) and an fp32 bias.
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);         // row sum of the i4 values
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);    // per-row scale
static const size_t kai_num_bytes_bias = sizeof(float);              // per-row bias
20
// Rounds k up to the next multiple of 32, the K granularity of the packed RHS layout.
//
// Declared `static inline` (storage class first): placing a storage-class specifier after
// other declaration specifiers, as in `inline static`, is obsolescent in ISO C.
static inline size_t kai_k_roundedup(size_t k) {
    const size_t k_multiple_of = 32;
    return kai_roundup(k, k_multiple_of);
}
26
// The packed RHS is organised in blocks of nr rows, so one packing step covers nr rows of N.
size_t kai_get_n_step_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t nr) {
    const size_t n_step = nr;
    return n_step;
}
30
// Byte offset of row n_idx in the unpacked NxK RHS, whose consecutive rows are rhs_stride
// bytes apart.
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t n_idx, size_t rhs_stride) {
    const size_t byte_offset = rhs_stride * n_idx;
    return byte_offset;
}
34
35 3480 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t k, size_t nr, size_t kr, size_t sr) {
36 3480 KAI_UNUSED(kr);
37 3480 KAI_UNUSED(sr);
38
39 3480 const size_t k_internal = kai_k_roundedup(k);
40
41 // multiple of 2 because 2 elements in a byte
42 KAI_ASSERT((k_internal % 2) == 0);
43
44 6960 return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
45 3480 }
46
// Byte offset, inside the packed RHS, of the nr-row block that starts at row n_idx.
// n_idx must be a multiple of nr.
size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(
    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
    KAI_ASSERT((n_idx % nr) == 0);

    const size_t block_idx = n_idx / nr;
    const size_t block_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr);

    return block_idx * block_stride;
}
53
// Total size in bytes of the packed RHS. n is rounded up to a whole number of nr-row blocks,
// and each block occupies one packed stride.
size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(
    size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
    const size_t num_blocks = kai_roundup(n, nr) / nr;
    const size_t block_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr);

    return num_blocks * block_stride;
}
60
61 1080 void kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(
62 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias,
63 const float* scale, void* rhs_packed, size_t extra_bytes,
64 const struct kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon_params* params) {
65 1080 const size_t k_internal = kai_k_roundedup(k);
66
67 KAI_ASSERT((k_internal % kr) == 0);
68 KAI_ASSERT(num_groups == 1);
69 KAI_ASSERT(extra_bytes == 0);
70 KAI_ASSERT((kr % sr) == 0);
71 KAI_ASSERT(rhs != NULL);
72 KAI_ASSERT(scale != NULL);
73 KAI_ASSERT(rhs_packed != NULL);
74 KAI_ASSERT(params != NULL);
75 KAI_ASSERT(params->lhs_zero_point == 1);
76 KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8);
77
78 // Note: The input matrix (rhs) is expected with:
79 // "k" columns and "n" rows (NxK)
80
81 1080 const int32_t rhs_zero_point = params->rhs_zero_point;
82 1080 const size_t rhs_stride = kai_roundup(k, 2) / 2;
83 1080 const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr);
84 1080 const size_t dst_nr_block_size = nr * kr * sizeof(uint8_t) / 2;
85
86 // Iterate over n src rows in blocks of nr rows
87
2/2
✓ Branch 0 taken 1080 times.
✓ Branch 1 taken 1474 times.
2554 for (size_t row_idx = 0; row_idx < n; row_idx += nr) {
88 1474 int8_t* const dst_row = (int8_t*)rhs_packed + ((row_idx / nr) * rhs_packed_stride);
89
90 1474 int32_t* const sums = (int32_t*)(dst_row + (nr * (k_internal / 2)));
91 1474 float* const scaling_factors = (float*)((uint8_t*)sums + (nr * kai_num_bytes_sum_rhs));
92 // Update destination row pointer
93 1474 float* const biases = (float*)((uint8_t*)scaling_factors + (nr * kai_num_bytes_multiplier_rhs));
94
95 // initialize sums to 0
96 1474 memset(sums, 0, nr * kai_num_bytes_sum_rhs);
97
98 // Copy the scaling factors and bias
99 1474 size_t rows_left = n - row_idx;
100 // Saving scales.
101
2/2
✓ Branch 0 taken 645 times.
✓ Branch 1 taken 829 times.
1474 if (rows_left >= nr) {
102 645 memcpy(scaling_factors, &scale[row_idx], nr * kai_num_bytes_multiplier_rhs);
103 645 } else {
104 // Fill remaining values
105 829 memcpy(scaling_factors, &scale[row_idx], rows_left * kai_num_bytes_multiplier_rhs);
106 // Set leftover to 0
107 829 memset(&scaling_factors[rows_left], 0, (nr - rows_left) * kai_num_bytes_multiplier_rhs);
108 }
109
2/2
✓ Branch 0 taken 1245 times.
✓ Branch 1 taken 229 times.
1474 if (bias == NULL) {
110 // Set bias to 0
111 229 memset(biases, 0, nr * kai_num_bytes_bias);
112 229 } else {
113
2/2
✓ Branch 0 taken 541 times.
✓ Branch 1 taken 704 times.
1245 if (rows_left >= nr) {
114 541 memcpy(biases, &bias[row_idx], nr * kai_num_bytes_bias);
115 541 } else {
116 // Fill remaining values
117 704 memcpy(biases, &bias[row_idx], rows_left * kai_num_bytes_bias);
118 // Set leftover to 0
119 704 memset(&biases[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias);
120 }
121 }
122 // Iterate over rows in the nr row block
123
2/2
✓ Branch 0 taken 94336 times.
✓ Branch 1 taken 1474 times.
95810 for (size_t nr_block_idx = 0; nr_block_idx < nr; ++nr_block_idx) {
124 94336 const uint8_t* const src_row = rhs + ((row_idx + nr_block_idx) * rhs_stride);
125 // Go to the first kr block for this row in the nr block
126 94336 int8_t* dst_kr_block = dst_row + (nr_block_idx * kr / 2);
127
128 94336 int32_t sum = 0;
129
130 // Iterate over k src columns in blocks of kr columns
131
2/2
✓ Branch 0 taken 73600 times.
✓ Branch 1 taken 20736 times.
94336 if (rhs_zero_point == 8) {
132
2/2
✓ Branch 0 taken 1595904 times.
✓ Branch 1 taken 73600 times.
1669504 for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) {
133 // Iterate over columns in the kr block
134 // Kr checked to be multiple of 2 (because 2 values per byte)
135
2/2
✓ Branch 0 taken 3191808 times.
✓ Branch 1 taken 1595904 times.
4787712 for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) {
136 // We pad dst with 0s if the rounded k or n values have been exceeded
137
4/4
✓ Branch 0 taken 2143984 times.
✓ Branch 1 taken 1047824 times.
✓ Branch 2 taken 342311 times.
✓ Branch 3 taken 1801673 times.
3191808 if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) {
138 1390135 dst_kr_block[kr_block_idx / 2] = 0;
139 1390135 continue;
140 }
141
142 // Load the 2 u4 values from source
143 1801673 const uint8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2];
144
145 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
146 // extract i8 values from the 2 u4 values
147 1801673 const int8_t first_value = (dst_byte & 0xF) - rhs_zero_point;
148 3603346 const int8_t second_value =
149
2/2
✓ Branch 0 taken 23674 times.
✓ Branch 1 taken 1777999 times.
1801673 col_idx + kr_block_idx + 1 >= k ? 0 : (dst_byte >> 4) - rhs_zero_point;
150
151 // Add the i4 value to the row sum
152 1801673 sum += (int32_t)first_value + (int32_t)second_value;
153
154 // Truncate i8 to i4 and write to dst
155 1801673 const uint8_t hi = second_value & 0x0F;
156 1801673 const uint8_t lo = first_value & 0x0F;
157 1801673 dst_kr_block[kr_block_idx / 2] = (hi << 4) | lo;
158 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
159 1801673 }
160
161 // Go to the next kr block for this row in the nr rows
162 1595904 dst_kr_block += dst_nr_block_size;
163 1595904 }
164 73600 } else {
165
2/2
✓ Branch 0 taken 325632 times.
✓ Branch 1 taken 20736 times.
346368 for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) {
166 // Iterate over columns in the kr block
167 // Kr checked to be multiple of 2 (because 2 values per byte)
168
2/2
✓ Branch 0 taken 651264 times.
✓ Branch 1 taken 325632 times.
976896 for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) {
169 // We pad dst with 0s if the rounded k or n values have been
170 // exceeded
171
4/4
✓ Branch 0 taken 435264 times.
✓ Branch 1 taken 216000 times.
✓ Branch 2 taken 67104 times.
✓ Branch 3 taken 368160 times.
651264 if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) {
172 283104 dst_kr_block[kr_block_idx / 2] = 0;
173 283104 continue;
174 }
175
176 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
177 // Load the 2 u4 values from source
178 368160 const int8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2];
179
180 // extract i8 values from the 2 u4 values, shift first value
181 // back and forth to get the sign right.
182 368160 const int8_t first_value = kai_ext_sign_i8_i4(dst_byte & 0xF);
183 736320 const int8_t second_value =
184
2/2
✓ Branch 0 taken 5706 times.
✓ Branch 1 taken 362454 times.
368160 col_idx + kr_block_idx + 1 >= k ? 0 : kai_ext_sign_i8_i4((dst_byte >> 4) & 0xF);
185
186 // Add the i4 value to the row sum
187 368160 sum += (int32_t)first_value + (int32_t)second_value;
188
189 // Truncate i8 to i4 and write to dst
190 368160 const uint8_t hi = second_value & 0x0F;
191 368160 const uint8_t lo = first_value & 0x0F;
192 368160 dst_kr_block[kr_block_idx / 2] = (hi << 4) | lo;
193 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
194 368160 }
195
196 // Go to the next kr block for this row in the nr rows
197 325632 dst_kr_block += dst_nr_block_size;
198 325632 }
199 }
200
201 // save sum
202 94336 sums[nr_block_idx] = sum;
203 94336 }
204 1474 }
205 1080 }
206 #endif // Architectural features check.
207