KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.2% 107 31 140
Functions: 85.7% 6 0 7
Branches: 100.0% 22 62 84

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64.
9 #else // Architectural features check.
10 #include "kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.h"
11
12 #include <arm_neon.h>
13 #include <stdint.h>
14 #include <string.h>
15
16 #include "kai/kai_common.h"
17
18 static const size_t kai_num_bytes_offset_rhs = sizeof(float);
19 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
20 static const size_t kai_num_bytes_bias = sizeof(float);
21 static const size_t kai_bl_multiple_of = 32;
22 static const size_t kai_nr_multiple_of = 4;
23
24 3924 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
25 KAI_ASSUME((k % 2) == 0);
26 KAI_ASSUME((k % bl) == 0);
27 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
28 3924 return kai_roundup(k, bl) / bl;
29 }
30
31 5232 inline static size_t kai_get_num_bytes_per_block(size_t bl) {
32 5232 return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs;
33 }
34
35 3924 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) {
36 KAI_ASSUME((k % 2) == 0);
37 KAI_ASSUME((k % kr) == 0);
38 KAI_ASSUME((k % bl) == 0);
39 KAI_ASSUME((bl % kr) == 0);
40 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
41 3924 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
42 3924 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl);
43 7848 return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias);
44 3924 }
45
46 size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) {
47 return n_idx * rhs_stride;
48 }
49
50 1308 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
51 size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
52 KAI_ASSUME((k % 2) == 0);
53 KAI_ASSUME((k % kr) == 0);
54 KAI_ASSUME((k % bl) == 0);
55 KAI_ASSUME((n_idx % nr) == 0);
56 1308 KAI_UNUSED(kr);
57 1308 return (n_idx / nr) * kai_get_rhs_packed_stride(k, nr, kr, bl);
58 }
59
60 1308 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
61 size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
62 KAI_ASSUME((k % 2) == 0);
63 KAI_ASSUME((k % kr) == 0);
64 KAI_ASSUME((k % bl) == 0);
65 1308 KAI_UNUSED(kr);
66 1308 const size_t num_rows = kai_roundup(n, nr) / nr;
67 2616 return num_rows * kai_get_rhs_packed_stride(k, nr, kr, bl);
68 1308 }
69
70 1308 void kai_run_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon(
71 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
72 const void* zero, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
73 const struct kai_rhs_pack_nxk_qai4c32p_params* params) {
74 KAI_ASSUME(num_groups == 1);
75 KAI_ASSUME((k % 2) == 0);
76 KAI_ASSUME((k % kr) == 0);
77 KAI_ASSUME((k % bl) == 0);
78 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
79 KAI_ASSUME((nr % kai_nr_multiple_of) == 0);
80 KAI_ASSUME(extra_bytes == 0);
81
82 KAI_ASSUME(sr == 2);
83 KAI_ASSUME(kr / sr == 4);
84 KAI_ASSUME(rhs != NULL);
85 KAI_ASSUME(zero != NULL);
86 KAI_ASSUME(scale != NULL);
87 KAI_ASSUME(rhs_packed != NULL);
88 KAI_ASSUME(params != NULL);
89 KAI_ASSUME(params->rhs_zero_point == 8);
90 KAI_ASSUME(params->lhs_zero_point == 1);
91
92 // Note: The input matrix (rhs) is expected with:
93 // "k" columns and "n" rows (NxK)
94
95 1308 const size_t block_length = kr / sr;
96 1308 const size_t num_blocks_per_row = k / bl;
97 1308 const size_t rhs_stride = k / 2;
98 1308 const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, bl);
99
100 1308 const size_t dst_packed_block_size = kai_get_num_bytes_per_block(bl) * nr;
101 1308 const size_t dst_block_data_size = bl / 2;
102 1308 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
103 1308 const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size;
104 1308 const size_t k_block_length_in_bytes = (block_length * sizeof(uint8_t)) / 2;
105
106
2/2
✓ Branch 0 taken 1308 times.
✓ Branch 1 taken 1640 times.
2948 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
107 1640 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
108 1640 float* dst_row_bias = (float*)(dst_row + dst_bias_offset);
109 1640 size_t row_idx = dst_row_idx * nr;
110 1640 size_t rows_left = n - row_idx;
111
2/2
✓ Branch 0 taken 2572 times.
✓ Branch 1 taken 1640 times.
4212 for (size_t block_idx = 0; block_idx < num_blocks_per_row; block_idx++) {
112 2572 uint8_t* block_dst_row = dst_row + block_idx * dst_packed_block_size;
113 2572 float* block_dst_zp = (float*)(block_dst_row + nr * dst_block_data_size);
114 2572 float* block_dst_scale = block_dst_zp + nr;
115 2572 size_t k_idx = block_idx * bl;
116
2/2
✓ Branch 0 taken 6432 times.
✓ Branch 1 taken 2572 times.
9004 for (size_t dst_byte_idx = 0; dst_byte_idx < dst_block_data_size; dst_byte_idx += 8) {
117
2/2
✓ Branch 0 taken 102912 times.
✓ Branch 1 taken 6432 times.
109344 for (size_t nr_idx = 0; nr_idx <= nr - 4; nr_idx += 4) {
118
2/2
✓ Branch 0 taken 84152 times.
✓ Branch 1 taken 18760 times.
102912 const size_t n0_idx = KAI_MIN(dst_row_idx * nr + nr_idx, n - 1);
119
2/2
✓ Branch 0 taken 84072 times.
✓ Branch 1 taken 18840 times.
102912 const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
120
2/2
✓ Branch 0 taken 83376 times.
✓ Branch 1 taken 19536 times.
102912 const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
121
2/2
✓ Branch 0 taken 79656 times.
✓ Branch 1 taken 23256 times.
102912 const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);
122 102912 const uint8_t* src_addr_byte = rhs + (k_idx / 2) + dst_byte_idx;
123
124 102912 const uint8x8_t vec0_u8 = vld1_u8(src_addr_byte + n0_idx * rhs_stride);
125 102912 const uint8x8_t vec1_u8 = vld1_u8(src_addr_byte + n1_idx * rhs_stride);
126 102912 const uint8x8_t vec2_u8 = vld1_u8(src_addr_byte + n2_idx * rhs_stride);
127 102912 const uint8x8_t vec3_u8 = vld1_u8(src_addr_byte + n3_idx * rhs_stride);
128
129 102912 const uint16x4_t vec0_u16 = vreinterpret_u16_u8(vec0_u8);
130 102912 const uint16x4_t vec1_u16 = vreinterpret_u16_u8(vec1_u8);
131 102912 const uint16x4_t vec2_u16 = vreinterpret_u16_u8(vec2_u8);
132 102912 const uint16x4_t vec3_u16 = vreinterpret_u16_u8(vec3_u8);
133
134 102912 const uint16x4_t vec01_lo_u16 = vzip1_u16(vec0_u16, vec1_u16);
135 102912 const uint16x4_t vec01_hi_u16 = vzip2_u16(vec0_u16, vec1_u16);
136 102912 const uint16x4_t vec23_lo_u16 = vzip1_u16(vec2_u16, vec3_u16);
137 102912 const uint16x4_t vec23_hi_u16 = vzip2_u16(vec2_u16, vec3_u16);
138
139 102912 const uint32x2_t vec01_lo_u32 = vreinterpret_u32_u16(vec01_lo_u16);
140 102912 const uint32x2_t vec01_hi_u32 = vreinterpret_u32_u16(vec01_hi_u16);
141 102912 const uint32x2_t vec23_lo_u32 = vreinterpret_u32_u16(vec23_lo_u16);
142 102912 const uint32x2_t vec23_hi_u32 = vreinterpret_u32_u16(vec23_hi_u16);
143
144 102912 const uint32x2_t vin0_u32 = vzip1_u32(vec01_lo_u32, vec23_lo_u32);
145 102912 const uint32x2_t vin1_u32 = vzip2_u32(vec01_lo_u32, vec23_lo_u32);
146 102912 const uint32x2_t vin2_u32 = vzip1_u32(vec01_hi_u32, vec23_hi_u32);
147 102912 const uint32x2_t vin3_u32 = vzip2_u32(vec01_hi_u32, vec23_hi_u32);
148
149 102912 uint8x8_t vin0_u8 = vreinterpret_u8_u32(vin0_u32);
150 102912 uint8x8_t vin1_u8 = vreinterpret_u8_u32(vin1_u32);
151 102912 uint8x8_t vin2_u8 = vreinterpret_u8_u32(vin2_u32);
152 102912 uint8x8_t vin3_u8 = vreinterpret_u8_u32(vin3_u32);
153
154 102912 const uint8x8_t vin0_s1s = vshr_n_u8(vin0_u8, 4);
155 102912 const uint8x8_t vin1_s1s = vshr_n_u8(vin1_u8, 4);
156 102912 const uint8x8_t vin2_s1s = vshr_n_u8(vin2_u8, 4);
157 102912 const uint8x8_t vin3_s1s = vshr_n_u8(vin3_u8, 4);
158
159 102912 vin0_u8 = vshl_n_u8(vin0_u8, 4);
160 102912 vin1_u8 = vshl_n_u8(vin1_u8, 4);
161 102912 vin2_u8 = vshl_n_u8(vin2_u8, 4);
162 102912 vin3_u8 = vshl_n_u8(vin3_u8, 4);
163
164 102912 vin0_u8 = vorr_u8(vin0_u8, vin0_s1s);
165 102912 vin1_u8 = vorr_u8(vin1_u8, vin1_s1s);
166 102912 vin2_u8 = vorr_u8(vin2_u8, vin2_s1s);
167 102912 vin3_u8 = vorr_u8(vin3_u8, vin3_s1s);
168
169 102912 uint8_t* dst_row_offset = block_dst_row + nr_idx * k_block_length_in_bytes;
170 102912 vst1_u8(dst_row_offset, vin0_u8);
171 102912 vst1_u8(dst_row_offset + nr * k_block_length_in_bytes, vin1_u8);
172 102912 vst1_u8(dst_row_offset + 2 * (nr * k_block_length_in_bytes), vin2_u8);
173 102912 vst1_u8(dst_row_offset + 3 * (nr * k_block_length_in_bytes), vin3_u8);
174 102912 }
175 6432 block_dst_row += nr * sizeof(uint8x8_t);
176 6432 }
177
178 // Adjust the zero points and scales
179
2/2
✓ Branch 0 taken 1760 times.
✓ Branch 1 taken 812 times.
2572 if (rows_left >= nr) {
180 1760 memcpy(block_dst_scale, &((const float*)scale)[row_idx], nr * kai_num_bytes_multiplier_rhs);
181 1760 memcpy(block_dst_zp, &((const float*)zero)[row_idx], nr * kai_num_bytes_offset_rhs);
182 1760 } else {
183 // Fill remaining values
184 812 memcpy(block_dst_scale, &((const float*)scale)[row_idx], rows_left * kai_num_bytes_multiplier_rhs);
185 812 memcpy(block_dst_zp, &((const float*)zero)[row_idx], rows_left * kai_num_bytes_offset_rhs);
186 // Set leftover to 0
187 812 memset(&block_dst_scale[rows_left], 0, (nr - rows_left) * kai_num_bytes_multiplier_rhs);
188 812 memset(&block_dst_zp[rows_left], 0, (nr - rows_left) * kai_num_bytes_offset_rhs);
189 }
190 2572 }
191 // Set the bias
192
2/2
✓ Branch 0 taken 820 times.
✓ Branch 1 taken 820 times.
1640 if (bias == NULL) {
193 820 memset(dst_row_bias, 0, nr * kai_num_bytes_bias);
194 820 } else {
195
2/2
✓ Branch 0 taken 484 times.
✓ Branch 1 taken 336 times.
820 if (rows_left >= nr) {
196 484 memcpy(dst_row_bias, &((const float*)bias)[row_idx], nr * kai_num_bytes_bias);
197 484 } else {
198 // Fill remaining values
199 336 memcpy(dst_row_bias, &((const float*)bias)[row_idx], rows_left * kai_num_bytes_bias);
200 // Set leftover to 0
201 336 memset(&dst_row_bias[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias);
202 }
203 }
204 1640 }
205 1308 }
206 #endif
207