KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 97.7% 84 / 29 / 115
Functions: 85.7% 6 / 0 / 7
Branches: 90.9% 20 / 56 / 78

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64.
9 #else // Architectural features check.
10 #include "kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.h"
11
12 #include <stdint.h>
13 #include <string.h>
14
15 #include "kai/kai_common.h"
16
17 static const size_t kai_num_bytes_offset_rhs = sizeof(float);
18 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
19 static const size_t kai_num_bytes_bias = sizeof(float);
20 static const size_t kai_bl_multiple_of = 32;
21
22 51408 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
23 KAI_ASSUME((k % 2) == 0);
24 KAI_ASSUME((k % bl) == 0);
25 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
26 51408 return kai_roundup(k, bl) / bl;
27 }
28
29 68544 inline static size_t kai_get_num_bytes_per_block(size_t bl) {
30 68544 return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs;
31 }
32
33 51408 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) {
34 KAI_ASSUME((k % 2) == 0);
35 KAI_ASSUME((k % kr) == 0);
36 KAI_ASSUME((k % bl) == 0);
37 KAI_ASSUME((bl % kr) == 0);
38 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
39 51408 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
40 51408 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl);
41 102816 return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias);
42 51408 }
43
44 size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) {
45 return n_idx * rhs_stride;
46 }
47
48 17136 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(
49 size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
50 KAI_ASSUME((k % 2) == 0);
51 KAI_ASSUME((k % kr) == 0);
52 KAI_ASSUME((k % bl) == 0);
53 KAI_ASSUME((n_idx % nr) == 0);
54 17136 KAI_UNUSED(kr);
55 17136 return (n_idx / nr) * kai_get_rhs_packed_stride(k, nr, kr, bl);
56 }
57
58 17136 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(
59 size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
60 KAI_ASSUME((k % 2) == 0);
61 KAI_ASSUME((k % kr) == 0);
62 KAI_ASSUME((k % bl) == 0);
63 17136 KAI_UNUSED(kr);
64 17136 const size_t num_rows = kai_roundup(n, nr) / nr;
65 34272 return num_rows * kai_get_rhs_packed_stride(k, nr, kr, bl);
66 17136 }
67
68 17136 void kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(
69 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
70 const void* zero, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
71 const struct kai_rhs_pack_nxk_qai4c32p_params* params) {
72 KAI_ASSUME(num_groups == 1);
73 KAI_ASSUME((k % 2) == 0);
74 KAI_ASSUME((k % kr) == 0);
75 KAI_ASSUME((k % bl) == 0);
76 KAI_ASSUME((bl % 32) == 0);
77 KAI_ASSUME(extra_bytes == 0);
78
79 KAI_ASSUME(sr == 2);
80 KAI_ASSUME(kr >= 1 && kr <= 16);
81 KAI_ASSUME(rhs != NULL);
82 KAI_ASSUME(zero != NULL);
83 KAI_ASSUME(rhs_packed != NULL);
84 KAI_ASSUME(params != NULL);
85 KAI_ASSUME(params->rhs_zero_point == 8);
86 KAI_ASSUME(params->lhs_zero_point == 1);
87
88 // Note: The input matrix (rhs) is expected with:
89 // "k" columns and "n" rows (NxK)
90
91 17136 const size_t num_blocks_per_row = k / bl;
92 17136 const size_t rhs_stride = k;
93 17136 const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, bl);
94
95 17136 const size_t dst_packed_block_size = kai_get_num_bytes_per_block(bl) * nr;
96 17136 const size_t dst_block_data_size = (bl / 2) * nr;
97 17136 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
98 17136 const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size;
99 17136 const size_t k_block_length_in_bytes = kr / sr;
100 17136 const size_t k_interleaved_v = 16U;
101
102 17136 const size_t rhs_zero_point = params->rhs_zero_point;
103
104
2/2
✓ Branch 0 taken 17136 times.
✓ Branch 1 taken 251328 times.
268464 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
105 251328 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
106 251328 float* dst_row_bias = (float*)(dst_row + dst_bias_offset);
107
108
2/2
✓ Branch 0 taken 437472 times.
✓ Branch 1 taken 251328 times.
688800 for (size_t block_idx = 0; block_idx < num_blocks_per_row; block_idx++) {
109 437472 uint8_t* block_dst_row = dst_row + block_idx * dst_packed_block_size;
110 437472 float* block_dst_zp = (float*)(block_dst_row + dst_block_data_size);
111 437472 float* block_dst_scale = block_dst_zp + nr;
112
113
2/2
✓ Branch 0 taken 35782656 times.
✓ Branch 1 taken 437472 times.
36220128 for (size_t block_byte_idx = 0; block_byte_idx < dst_block_data_size; ++block_byte_idx) {
114 35782656 const size_t dst_byte_idx = block_byte_idx;
115 35782656 const size_t k_block_idx = dst_byte_idx / k_block_length_in_bytes;
116 35782656 const size_t k_block_byte_idx = dst_byte_idx % k_block_length_in_bytes;
117 35782656 const size_t super_k_block_idx = k_block_idx / nr;
118 35782656 const size_t nr_idx = k_block_idx % nr;
119
120 71565312 const size_t k_adjustment =
121 35782656 ((k_block_byte_idx + super_k_block_idx * k_block_length_in_bytes) / k_interleaved_v) *
122 k_interleaved_v;
123 35782656 const size_t k0_idx = k_block_byte_idx + super_k_block_idx * k_block_length_in_bytes + k_adjustment;
124 35782656 const size_t k1_idx = k0_idx + k_interleaved_v;
125 35782656 const size_t n0_idx = dst_row_idx * nr + nr_idx;
126
127 // Clamp the index to avoid out-of-bound reads
128
2/2
✓ Branch 0 taken 35024640 times.
✓ Branch 1 taken 758016 times.
35782656 const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
129
130 35782656 const size_t src_addr_byte0 = (k0_idx + n0_valid_idx * rhs_stride + block_idx * bl) / 2;
131 35782656 const size_t src_addr_byte1 = (k1_idx + n0_valid_idx * rhs_stride + block_idx * bl) / 2;
132
133 35782656 uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4;
134 35782656 uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4;
135
136
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 35782656 times.
35782656 if (k0_idx < k) {
137 35782656 byte0 = rhs[src_addr_byte0];
138 35782656 }
139
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 35782656 times.
35782656 if (k1_idx < k) {
140 35782656 byte1 = rhs[src_addr_byte1];
141 35782656 }
142
143 35782656 const size_t shift_right_x0 = (k0_idx % 2 == 0) ? 4 : 0;
144 35782656 const size_t shift_right_x1 = (k1_idx % 2 == 0) ? 4 : 0;
145
146 35782656 const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
147 35782656 const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F;
148
149 35782656 const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
150
151 35782656 *block_dst_row = dst_qs0 ^ 0x88;
152 35782656 block_dst_row += sizeof(uint8_t);
153 35782656 }
154
155 // Adjust the zero points and scales
156
2/2
✓ Branch 0 taken 1749888 times.
✓ Branch 1 taken 437472 times.
2187360 for (size_t i = 0; i < nr; ++i) {
157 // Clamp the row index to avoid out-of-bound reads
158
2/2
✓ Branch 0 taken 1710576 times.
✓ Branch 1 taken 39312 times.
1749888 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
159
160 1749888 const float* block_zero = (const float*)zero + num_blocks_per_row * src_row_idx;
161 1749888 const float* block_scale = (const float*)scale + num_blocks_per_row * src_row_idx;
162
163 1749888 *block_dst_zp = block_zero[block_idx];
164 1749888 *block_dst_scale = block_scale[block_idx] * 0.0625F;
165
166 1749888 block_dst_zp++;
167 1749888 block_dst_scale++;
168 1749888 }
169 437472 }
170 // Set the bias
171
2/2
✓ Branch 0 taken 125664 times.
✓ Branch 1 taken 125664 times.
251328 if (bias == NULL) {
172 125664 memset(dst_row_bias, 0, nr * kai_num_bytes_bias);
173 125664 } else {
174
2/2
✓ Branch 0 taken 502656 times.
✓ Branch 1 taken 125664 times.
628320 for (size_t i = 0; i < nr; ++i) {
175 // Clamp the row index to avoid out-of-bound reads
176
2/2
✓ Branch 0 taken 488376 times.
✓ Branch 1 taken 14280 times.
502656 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
177
178 502656 dst_row_bias[i] = *((const float*)bias + src_row_idx);
179 502656 }
180 }
181 251328 }
182 17136 }
183 #endif
184