KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 98.1% 105 / 31 / 138
Functions: 85.7% 6 / 0 / 7
Branches: 100.0% 24 / 62 / 86

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64.
9 #else // Architectural features check.
10 #include "kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.h"
11
12 #include <arm_neon.h>
13 #include <stdint.h>
14 #include <string.h>
15
16 #include "kai/kai_common.h"
17
// Sizes of the per-block quantization scalars and of the per-row bias in the packed
// layout; all of them are stored as f32.
static const size_t kai_num_bytes_offset_rhs = sizeof(float);
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
static const size_t kai_num_bytes_bias = sizeof(float);

// Shape constraints enforced by the packing routines below:
// the block length `bl` must be a multiple of 32 and `nr` a multiple of 4.
static const size_t kai_bl_multiple_of = 32;
static const size_t kai_nr_multiple_of = 4;
23
24 12852 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
25 KAI_ASSUME((k % 2) == 0);
26 KAI_ASSUME((k % bl) == 0);
27 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
28 12852 return kai_roundup(k, bl) / bl;
29 }
30
31 17136 inline static size_t kai_get_num_bytes_per_block(size_t bl) {
32 17136 return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs;
33 }
34
35 12852 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) {
36 KAI_ASSUME((k % 2) == 0);
37 KAI_ASSUME((k % kr) == 0);
38 KAI_ASSUME((k % bl) == 0);
39 KAI_ASSUME((bl % kr) == 0);
40 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
41 12852 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
42 12852 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl);
43 25704 return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias);
44 12852 }
45
// Byte offset of row `n_idx` in the unpacked NxK RHS matrix whose rows are
// `rhs_stride` bytes apart.
size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) {
    const size_t row_offset = rhs_stride * n_idx;
    return row_offset;
}
49
// Byte offset, inside the packed RHS buffer, of the packed row-group that starts at RHS
// row `n_idx`. `n_idx` must be a multiple of `nr`.
size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
    size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);
    KAI_ASSUME((n_idx % nr) == 0);
    KAI_UNUSED(kr);

    const size_t row_group_idx = n_idx / nr;
    const size_t row_group_stride = kai_get_rhs_packed_stride(k, nr, kr, bl);
    return row_group_idx * row_group_stride;
}
59
// Total size in bytes of the packed RHS buffer for `n` rows: `n` is rounded up to a
// whole number of `nr`-row groups, each occupying one packed stride.
size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
    size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
    KAI_ASSUME((k % 2) == 0);
    KAI_ASSUME((k % kr) == 0);
    KAI_ASSUME((k % bl) == 0);
    KAI_UNUSED(kr);

    const size_t num_row_groups = kai_roundup(n, nr) / nr;
    const size_t row_group_stride = kai_get_rhs_packed_stride(k, nr, kr, bl);
    return num_row_groups * row_group_stride;
}
69
70 4284 void kai_run_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
71 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
72 const void* zero, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
73 const struct kai_rhs_pack_nxk_qai4c32p_params* params) {
74 KAI_ASSUME(num_groups == 1);
75 KAI_ASSUME((k % 2) == 0);
76 KAI_ASSUME((k % kr) == 0);
77 KAI_ASSUME((k % bl) == 0);
78 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
79 KAI_ASSUME((nr % kai_nr_multiple_of) == 0);
80 KAI_ASSUME(extra_bytes == 0);
81
82 KAI_ASSUME(sr == 2);
83 KAI_ASSUME(kr / sr == 4);
84 KAI_ASSUME(rhs != NULL);
85 KAI_ASSUME(zero != NULL);
86 KAI_ASSUME(scale != NULL);
87 KAI_ASSUME(rhs_packed != NULL);
88 KAI_ASSUME(params != NULL);
89 KAI_ASSUME(params->rhs_zero_point == 8);
90 KAI_ASSUME(params->lhs_zero_point == 1);
91
92 // Note: The input matrix (rhs) is expected with:
93 // "k" columns and "n" rows (NxK)
94
95 4284 const size_t block_length = kr / sr;
96 4284 const size_t num_blocks_per_row = k / bl;
97 4284 const size_t rhs_stride = k / 2;
98 4284 const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, bl);
99
100 4284 const size_t dst_packed_block_size = kai_get_num_bytes_per_block(bl) * nr;
101 4284 const size_t dst_block_data_size = bl / 2;
102 4284 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
103 4284 const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size;
104 4284 const size_t k_block_length_in_bytes = (block_length * sizeof(uint8_t)) / 2;
105
106
2/2
✓ Branch 0 taken 4284 times.
✓ Branch 1 taken 5376 times.
9660 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
107 5376 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
108 5376 float* dst_row_bias = (float*)(dst_row + dst_bias_offset);
109 5376 size_t row_idx = dst_row_idx * nr;
110 5376 size_t rows_left = n - row_idx;
111
2/2
✓ Branch 0 taken 8484 times.
✓ Branch 1 taken 5376 times.
13860 for (size_t block_idx = 0; block_idx < num_blocks_per_row; block_idx++) {
112 8484 uint8_t* block_dst_row = dst_row + block_idx * dst_packed_block_size;
113 8484 float* block_dst_zp = (float*)(block_dst_row + nr * dst_block_data_size);
114 8484 float* block_dst_scale = block_dst_zp + nr;
115 8484 size_t k_idx = block_idx * bl;
116
2/2
✓ Branch 0 taken 21168 times.
✓ Branch 1 taken 8484 times.
29652 for (size_t dst_byte_idx = 0; dst_byte_idx < dst_block_data_size; dst_byte_idx += 8) {
117
2/2
✓ Branch 0 taken 338688 times.
✓ Branch 1 taken 21168 times.
359856 for (size_t nr_idx = 0; nr_idx <= nr - 4; nr_idx += 4) {
118
2/2
✓ Branch 0 taken 278040 times.
✓ Branch 1 taken 60648 times.
338688 const size_t n0_idx = KAI_MIN(dst_row_idx * nr + nr_idx, n - 1);
119
2/2
✓ Branch 0 taken 277704 times.
✓ Branch 1 taken 60984 times.
338688 const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
120
2/2
✓ Branch 0 taken 275520 times.
✓ Branch 1 taken 63168 times.
338688 const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
121
2/2
✓ Branch 0 taken 263256 times.
✓ Branch 1 taken 75432 times.
338688 const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);
122 338688 const uint8_t* src_addr_byte = rhs + (k_idx / 2) + dst_byte_idx;
123
124 338688 const uint8x8_t vec0_u8 = vld1_u8(src_addr_byte + n0_idx * rhs_stride);
125 338688 const uint8x8_t vec1_u8 = vld1_u8(src_addr_byte + n1_idx * rhs_stride);
126 338688 const uint8x8_t vec2_u8 = vld1_u8(src_addr_byte + n2_idx * rhs_stride);
127 338688 const uint8x8_t vec3_u8 = vld1_u8(src_addr_byte + n3_idx * rhs_stride);
128
129 338688 const uint16x4_t vec0_u16 = vreinterpret_u16_u8(vec0_u8);
130 338688 const uint16x4_t vec1_u16 = vreinterpret_u16_u8(vec1_u8);
131 338688 const uint16x4_t vec2_u16 = vreinterpret_u16_u8(vec2_u8);
132 338688 const uint16x4_t vec3_u16 = vreinterpret_u16_u8(vec3_u8);
133
134 338688 const uint16x4_t vec01_lo_u16 = vzip1_u16(vec0_u16, vec1_u16);
135 338688 const uint16x4_t vec01_hi_u16 = vzip2_u16(vec0_u16, vec1_u16);
136 338688 const uint16x4_t vec23_lo_u16 = vzip1_u16(vec2_u16, vec3_u16);
137 338688 const uint16x4_t vec23_hi_u16 = vzip2_u16(vec2_u16, vec3_u16);
138
139 338688 const uint32x2_t vec01_lo_u32 = vreinterpret_u32_u16(vec01_lo_u16);
140 338688 const uint32x2_t vec01_hi_u32 = vreinterpret_u32_u16(vec01_hi_u16);
141 338688 const uint32x2_t vec23_lo_u32 = vreinterpret_u32_u16(vec23_lo_u16);
142 338688 const uint32x2_t vec23_hi_u32 = vreinterpret_u32_u16(vec23_hi_u16);
143
144 338688 const uint32x2_t vin0_u32 = vzip1_u32(vec01_lo_u32, vec23_lo_u32);
145 338688 const uint32x2_t vin1_u32 = vzip2_u32(vec01_lo_u32, vec23_lo_u32);
146 338688 const uint32x2_t vin2_u32 = vzip1_u32(vec01_hi_u32, vec23_hi_u32);
147 338688 const uint32x2_t vin3_u32 = vzip2_u32(vec01_hi_u32, vec23_hi_u32);
148
149 338688 uint8x8_t vin0_u8 = vreinterpret_u8_u32(vin0_u32);
150 338688 uint8x8_t vin1_u8 = vreinterpret_u8_u32(vin1_u32);
151 338688 uint8x8_t vin2_u8 = vreinterpret_u8_u32(vin2_u32);
152 338688 uint8x8_t vin3_u8 = vreinterpret_u8_u32(vin3_u32);
153
154 338688 const uint8x8_t vin0_s1s = vshr_n_u8(vin0_u8, 4);
155 338688 const uint8x8_t vin1_s1s = vshr_n_u8(vin1_u8, 4);
156 338688 const uint8x8_t vin2_s1s = vshr_n_u8(vin2_u8, 4);
157 338688 const uint8x8_t vin3_s1s = vshr_n_u8(vin3_u8, 4);
158
159 338688 vin0_u8 = vshl_n_u8(vin0_u8, 4);
160 338688 vin1_u8 = vshl_n_u8(vin1_u8, 4);
161 338688 vin2_u8 = vshl_n_u8(vin2_u8, 4);
162 338688 vin3_u8 = vshl_n_u8(vin3_u8, 4);
163
164 338688 vin0_u8 = vorr_u8(vin0_u8, vin0_s1s);
165 338688 vin1_u8 = vorr_u8(vin1_u8, vin1_s1s);
166 338688 vin2_u8 = vorr_u8(vin2_u8, vin2_s1s);
167 338688 vin3_u8 = vorr_u8(vin3_u8, vin3_s1s);
168
169 338688 uint8_t* dst_row_offset = block_dst_row + nr_idx * k_block_length_in_bytes;
170 338688 vst1_u8(dst_row_offset, vin0_u8);
171 338688 vst1_u8(dst_row_offset + nr * k_block_length_in_bytes, vin1_u8);
172 338688 vst1_u8(dst_row_offset + 2 * (nr * k_block_length_in_bytes), vin2_u8);
173 338688 vst1_u8(dst_row_offset + 3 * (nr * k_block_length_in_bytes), vin3_u8);
174 338688 }
175 21168 block_dst_row += nr * sizeof(uint8x8_t);
176 21168 }
177
178 // Adjust the zero points and scales
179
2/2
✓ Branch 0 taken 542976 times.
✓ Branch 1 taken 8484 times.
551460 for (size_t i = 0; i < nr; ++i) {
180
2/2
✓ Branch 0 taken 427644 times.
✓ Branch 1 taken 115332 times.
542976 const size_t src_row_idx = KAI_MIN(row_idx + i, n - 1);
181 542976 const size_t src_idx = src_row_idx * num_blocks_per_row + block_idx;
182
183 542976 block_dst_scale[i] = ((const float*)scale)[src_idx];
184 542976 block_dst_zp[i] = ((const float*)zero)[src_idx];
185 542976 }
186 8484 }
187 // Set the bias
188
2/2
✓ Branch 0 taken 2688 times.
✓ Branch 1 taken 2688 times.
5376 if (bias == NULL) {
189 2688 memset(dst_row_bias, 0, nr * kai_num_bytes_bias);
190 2688 } else {
191
2/2
✓ Branch 0 taken 1596 times.
✓ Branch 1 taken 1092 times.
2688 if (rows_left >= nr) {
192 1596 memcpy(dst_row_bias, &((const float*)bias)[row_idx], nr * kai_num_bytes_bias);
193 1596 } else {
194 // Fill remaining values
195 1092 memcpy(dst_row_bias, &((const float*)bias)[row_idx], rows_left * kai_num_bytes_bias);
196 // Set leftover to 0
197 1092 memset(&dst_row_bias[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias);
198 }
199 }
200 5376 }
201 4284 }
202 #endif
203