KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 95.5% 84 12 100
Functions: 85.7% 6 0 7
Branches: 96.9% 31 26 58

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if !defined(__aarch64__) && !defined(_M_ARM64)
7 #error This file must be compiled for AArch64.
8 #else // Architectural features check.
9
10 #include "kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h"
11
12 #include <stdint.h>
13 #include <string.h>
14
15 #include "kai/kai_common.h"
16
// Sizes (in bytes) of the per-row values stored after the packed int4 data of each nr-row block.
static const size_t kai_num_bytes_bias = sizeof(float);            // one float bias per row
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);  // one float scale per row
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);       // one int32 row sum per row
20
// Rounds the depth k up to a multiple of 32. The packed RHS stores int4 values over this
// padded depth; the padding columns are written as zeros by the pack routine.
// `static` comes first: a storage-class specifier placed after other declaration specifiers
// (as in `inline static`) is an obsolescent form in ISO C (C11 6.11.5).
static inline size_t kai_k_roundedup(size_t k) {
    const size_t kai_k_multiple_of = 32;
    return kai_roundup(k, kai_k_multiple_of);
}
26
// Step along N between packing start rows: packing always works on whole blocks of nr rows,
// so the step is nr itself.
size_t kai_get_n_step_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t nr) {
    const size_t n_step = nr;
    return n_step;
}
30
// Offset of row n_idx in the unpacked NxK source: n_idx rows of rhs_stride each
// (rhs_stride is presumably in bytes, matching how the pack routine strides the source).
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t n_idx, size_t rhs_stride) {
    const size_t row_offset = rhs_stride * n_idx;
    return row_offset;
}
34
35 560 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(size_t k, size_t nr, size_t kr, size_t sr) {
36 560 KAI_UNUSED(kr);
37 560 KAI_UNUSED(sr);
38
39 560 const size_t k_internal = kai_k_roundedup(k);
40
41 // multiple of 2 because 2 elements in a byte
42 KAI_ASSERT((k_internal % 2) == 0);
43
44 1120 return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
45 560 }
46
// Byte offset, inside the packed RHS, of the nr-row block that starts at row n_idx.
// n_idx must be a multiple of nr, since packing proceeds in whole nr-row blocks.
size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(
    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
    // nr is the divisor of both the modulo and the division below; zero would be UB.
    KAI_ASSERT(nr > 0);
    KAI_ASSERT((n_idx % nr) == 0);

    const size_t block_idx = n_idx / nr;
    return block_idx * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr);
}
53
// Total size in bytes of the packed RHS for n rows: n is rounded up to whole nr-row blocks
// and each block occupies one packed stride (partial last blocks are zero-padded by the
// pack routine).
size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(
    size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
    // nr is used as a divisor below; zero would be UB.
    KAI_ASSERT(nr > 0);

    // Number of nr-row blocks (not rows): ceil(n / nr).
    const size_t num_blocks = kai_roundup(n, nr) / nr;

    return num_blocks * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr);
}
60
61 160 void kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(
62 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias,
63 const float* scale, void* rhs_packed, size_t extra_bytes,
64 const struct kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon_params* params) {
65 160 const size_t k_internal = kai_k_roundedup(k);
66
67 KAI_ASSERT((k_internal % kr) == 0);
68 KAI_ASSERT(num_groups == 1);
69 KAI_ASSERT(extra_bytes == 0);
70 KAI_ASSERT((kr % sr) == 0);
71 KAI_ASSERT(rhs != NULL);
72 KAI_ASSERT(scale != NULL);
73 KAI_ASSERT(rhs_packed != NULL);
74 KAI_ASSERT(params != NULL);
75 KAI_ASSERT(params->lhs_zero_point == 1);
76 KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8);
77
78 // Note: The input matrix (rhs) is expected with:
79 // "k" columns and "n" rows (NxK)
80
81 160 const int32_t rhs_zero_point = params->rhs_zero_point;
82 160 const size_t rhs_stride = kai_roundup(k, 2) / 2;
83 160 const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(k, nr, kr, sr);
84 160 const size_t dst_nr_block_size = nr * kr * sizeof(uint8_t) / 2;
85
86 // Iterate over n src rows in blocks of nr rows
87
2/2
✓ Branch 0 taken 160 times.
✓ Branch 1 taken 216 times.
376 for (size_t row_idx = 0; row_idx < n; row_idx += nr) {
88 216 int8_t* const dst_row = (int8_t*)rhs_packed + ((row_idx / nr) * rhs_packed_stride);
89
90 216 int32_t* const sums = (int32_t*)(dst_row + (nr * (k_internal / 2)));
91 216 float* const scaling_factors = (float*)((uint8_t*)sums + (nr * kai_num_bytes_sum_rhs));
92 // Update destination row pointer
93 216 float* const biases = (float*)((uint8_t*)scaling_factors + (nr * kai_num_bytes_multiplier_rhs));
94
95 // initialize sums to 0
96 216 memset(sums, 0, nr * kai_num_bytes_sum_rhs);
97
98 // Copy the scaling factors and bias
99 216 size_t rows_left = n - row_idx;
100 // Saving scales.
101
2/2
✓ Branch 0 taken 96 times.
✓ Branch 1 taken 120 times.
216 if (rows_left >= nr) {
102 96 memcpy(scaling_factors, &scale[row_idx], nr * kai_num_bytes_multiplier_rhs);
103 96 } else {
104 // Fill remaining values
105 120 memcpy(scaling_factors, &scale[row_idx], rows_left * kai_num_bytes_multiplier_rhs);
106 // Set leftover to 0
107 120 memset(&scaling_factors[rows_left], 0, (nr - rows_left) * kai_num_bytes_multiplier_rhs);
108 }
109
1/2
✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
216 if (bias == NULL) {
110 // Set bias to 0
111 memset(biases, 0, nr * kai_num_bytes_bias);
112 } else {
113
2/2
✓ Branch 0 taken 96 times.
✓ Branch 1 taken 120 times.
216 if (rows_left >= nr) {
114 96 memcpy(biases, &bias[row_idx], nr * kai_num_bytes_bias);
115 96 } else {
116 // Fill remaining values
117 120 memcpy(biases, &bias[row_idx], rows_left * kai_num_bytes_bias);
118 // Set leftover to 0
119 120 memset(&biases[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias);
120 }
121 }
122 // Iterate over rows in the nr row block
123
2/2
✓ Branch 0 taken 13824 times.
✓ Branch 1 taken 216 times.
14040 for (size_t nr_block_idx = 0; nr_block_idx < nr; ++nr_block_idx) {
124 13824 const uint8_t* const src_row = rhs + ((row_idx + nr_block_idx) * rhs_stride);
125 // Go to the first kr block for this row in the nr block
126 13824 int8_t* dst_kr_block = dst_row + (nr_block_idx * kr / 2);
127
128 13824 int32_t sum = 0;
129
130 // Iterate over k src columns in blocks of kr columns
131
2/2
✓ Branch 0 taken 6912 times.
✓ Branch 1 taken 6912 times.
13824 if (rhs_zero_point == 8) {
132
2/2
✓ Branch 0 taken 108544 times.
✓ Branch 1 taken 6912 times.
115456 for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) {
133 // Iterate over columns in the kr block
134 // Kr checked to be multiple of 2 (because 2 values per byte)
135
2/2
✓ Branch 0 taken 217088 times.
✓ Branch 1 taken 108544 times.
325632 for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) {
136 // We pad dst with 0s if the rounded k or n values have been exceeded
137
4/4
✓ Branch 0 taken 145088 times.
✓ Branch 1 taken 72000 times.
✓ Branch 2 taken 22368 times.
✓ Branch 3 taken 122720 times.
217088 if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) {
138 94368 dst_kr_block[kr_block_idx / 2] = 0;
139 94368 continue;
140 }
141
142 // Load the 2 u4 values from source
143 122720 const uint8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2];
144
145 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
146 // extract i8 values from the 2 u4 values
147 122720 const int8_t first_value = (dst_byte & 0xF) - rhs_zero_point;
148 245440 const int8_t second_value =
149
2/2
✓ Branch 0 taken 1902 times.
✓ Branch 1 taken 120818 times.
122720 col_idx + kr_block_idx + 1 >= k ? 0 : (dst_byte >> 4) - rhs_zero_point;
150
151 // Add the i4 value to the row sum
152 122720 sum += (int32_t)first_value + (int32_t)second_value;
153
154 // Truncate i8 to i4 and write to dst
155 122720 const uint8_t hi = second_value & 0x0F;
156 122720 const uint8_t lo = first_value & 0x0F;
157 122720 dst_kr_block[kr_block_idx / 2] = (hi << 4) | lo;
158 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
159 122720 }
160
161 // Go to the next kr block for this row in the nr rows
162 108544 dst_kr_block += dst_nr_block_size;
163 108544 }
164 6912 } else {
165
2/2
✓ Branch 0 taken 108544 times.
✓ Branch 1 taken 6912 times.
115456 for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) {
166 // Iterate over columns in the kr block
167 // Kr checked to be multiple of 2 (because 2 values per byte)
168
2/2
✓ Branch 0 taken 217088 times.
✓ Branch 1 taken 108544 times.
325632 for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) {
169 // We pad dst with 0s if the rounded k or n values have been
170 // exceeded
171
4/4
✓ Branch 0 taken 145088 times.
✓ Branch 1 taken 72000 times.
✓ Branch 2 taken 22368 times.
✓ Branch 3 taken 122720 times.
217088 if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) {
172 94368 dst_kr_block[kr_block_idx / 2] = 0;
173 94368 continue;
174 }
175
176 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
177 // Load the 2 u4 values from source
178 122720 const int8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2];
179
180 // extract i8 values from the 2 u4 values, shift first value
181 // back and forth to get the sign right.
182 122720 const int8_t first_value = kai_ext_sign_i8_i4(dst_byte & 0xF);
183 245440 const int8_t second_value =
184
2/2
✓ Branch 0 taken 1902 times.
✓ Branch 1 taken 120818 times.
122720 col_idx + kr_block_idx + 1 >= k ? 0 : kai_ext_sign_i8_i4((dst_byte >> 4) & 0xF);
185
186 // Add the i4 value to the row sum
187 122720 sum += (int32_t)first_value + (int32_t)second_value;
188
189 // Truncate i8 to i4 and write to dst
190 122720 const uint8_t hi = second_value & 0x0F;
191 122720 const uint8_t lo = first_value & 0x0F;
192 122720 dst_kr_block[kr_block_idx / 2] = (hi << 4) | lo;
193 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
194 122720 }
195
196 // Go to the next kr block for this row in the nr rows
197 108544 dst_kr_block += dst_nr_block_size;
198 108544 }
199 }
200
201 // save sum
202 13824 sums[nr_block_idx] = sum;
203 13824 }
204 216 }
205 160 }
206 #endif // Architectural features check.
207