KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 97.7% 84 29 115
Functions: 85.7% 6 0 7
Branches: 90.9% 20 58 80

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64.
9 #else // Architectural features check.
10 #include "kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.h"
11
12 #include <stdint.h>
13 #include <string.h>
14
15 #include "kai/kai_common.h"
16
// Byte sizes of the float values stored next to the packed 4-bit data. The run function in this file writes,
// for every block and output row, one float zero point ("offset") and one float scale ("multiplier"),
// and one float bias per output row.
17 static const size_t kai_num_bytes_offset_rhs = sizeof(float);
18 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
19 static const size_t kai_num_bytes_bias = sizeof(float);
// The block length (bl) is required to be a multiple of this value by the helpers in this file.
20 static const size_t kai_bl_multiple_of = 32;
21
// Returns the number of blocks of bl values contained in one row of k values.
// Preconditions, stated through KAI_ASSUME (declared in kai/kai_common.h): k is even, k is a multiple
// of bl, and bl is a multiple of kai_bl_multiple_of.
22 7848 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
23 KAI_ASSUME((k % 2) == 0);
24 KAI_ASSUME((k % bl) == 0);
25 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
// NOTE(review): kai_roundup is declared outside this file; presumably it rounds k up to a multiple of bl,
// in which case this is simply k / bl under the (k % bl) == 0 assumption above - confirm.
26 7848 return kai_roundup(k, bl) / bl;
27 }
28
// Returns the number of packed bytes one block occupies for a single output row:
// bl / 2 data bytes (two 4-bit values share a byte) plus one float multiplier and one float offset.
29 10464 inline static size_t kai_get_num_bytes_per_block(size_t bl) {
30 10464 return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs;
31 }
32
// Returns the size in bytes of one packed group of nr rows: for each of the nr rows, every block of the
// row (data + multiplier + offset) followed by one bias value.
// Preconditions (KAI_ASSUME): k is even and a multiple of both kr and bl; bl is a multiple of kr and of
// kai_bl_multiple_of. kr is only used in these checks.
33 7848 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) {
34 KAI_ASSUME((k % 2) == 0);
35 KAI_ASSUME((k % kr) == 0);
36 KAI_ASSUME((k % bl) == 0);
37 KAI_ASSUME((bl % kr) == 0);
38 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
39 7848 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
40 7848 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl);
41 15696 return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias);
42 7848 }
43
// Returns the offset of row n_idx in the unpacked RHS input, computed as n_idx * rhs_stride.
// The unit of the result is whatever unit rhs_stride is expressed in (presumably bytes - confirm against
// the header). No validation is performed on either argument.
// NOTE(review): none of this function's lines carries an execution count in this report.
44 size_t kai_get_rhs_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) {
45 return n_idx * rhs_stride;
46 }
47
// Returns the byte offset, inside the packed RHS buffer, of the group of nr rows that starts at row n_idx:
// (n_idx / nr) groups times the packed stride of one group.
// Preconditions (KAI_ASSUME): k is even and a multiple of kr and bl; n_idx is a multiple of nr.
48 2616 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(
49 size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
50 KAI_ASSUME((k % 2) == 0);
51 KAI_ASSUME((k % kr) == 0);
52 KAI_ASSUME((k % bl) == 0);
53 KAI_ASSUME((n_idx % nr) == 0);
// NOTE(review): kr is marked unused here although it is passed to kai_get_rhs_packed_stride on the next
// line, so the marker looks redundant.
54 2616 KAI_UNUSED(kr);
55 2616 return (n_idx / nr) * kai_get_rhs_packed_stride(k, nr, kr, bl);
56 }
57
// Returns the total size in bytes of the packed RHS buffer for an input of n rows and k columns:
// the number of nr-row groups needed to cover n rows times the packed stride of one group.
// Preconditions (KAI_ASSUME): k is even and a multiple of kr and bl.
58 2616 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(
59 size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
60 KAI_ASSUME((k % 2) == 0);
61 KAI_ASSUME((k % kr) == 0);
62 KAI_ASSUME((k % bl) == 0);
// NOTE(review): kr is marked unused although it is forwarded to kai_get_rhs_packed_stride below.
63 2616 KAI_UNUSED(kr);
// Number of nr-row groups; presumably kai_roundup rounds n up to a multiple of nr, so a partial last
// group counts as a full one - confirm against kai_common.h.
64 2616 const size_t num_rows = kai_roundup(n, nr) / nr;
65 5232 return num_rows * kai_get_rhs_packed_stride(k, nr, kr, bl);
66 2616 }
67
// Packs an N x K matrix of 4-bit values, together with per-block float zero points and scales and an
// optional per-row float bias, into the blocked layout written to rhs_packed.
//
// Packed layout, as written by the loops below. For every group of nr rows:
//   for every block of bl values along k:
//     (bl / 2) * nr data bytes, then nr float zero points, then nr float scales
//   then nr float bias values.
//
// Parameters:
//   num_groups   must be 1.
//   n            number of rows of rhs.
//   k            number of columns of rhs; must be even and a multiple of kr and bl.
//   nr           number of rows packed together in one group.
//   kr           must be in [1, 16]; kr / sr is the length in bytes of one contiguous run per row.
//   sr           must be 2.
//   bl           block length along k; must be a multiple of 32.
//   rhs          input values, two 4-bit values per byte, indexed as (row * k + column) / 2. Must not be NULL.
//   zero         float zero points, (k / bl) values per row, rows stored consecutively. Must not be NULL.
//   bias         float bias, one value per row, or NULL to store zeros.
//   scale        float scales, same indexing as zero.
//   rhs_packed   destination buffer. Must not be NULL.
//   extra_bytes  must be 0.
//   params       must not be NULL; rhs_zero_point must be 8 and lhs_zero_point must be 1.
//
// Rows past n - 1 in the last group are filled by re-reading row n - 1.
68 2616 void kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon(
69 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
70 const void* zero, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
71 const struct kai_rhs_pack_nxk_qai4c32p_params* params) {
72 KAI_ASSUME(num_groups == 1);
73 KAI_ASSUME((k % 2) == 0);
74 KAI_ASSUME((k % kr) == 0);
75 KAI_ASSUME((k % bl) == 0);
76 KAI_ASSUME((bl % 32) == 0);
77 KAI_ASSUME(extra_bytes == 0);
78
79 KAI_ASSUME(sr == 2);
// NOTE(review): kr == 1 satisfies this check, but then kr / sr below is 0 and is used as a divisor in the
// packing loop. Requiring kr to be a non-zero multiple of sr would rule that out.
80 KAI_ASSUME(kr >= 1 && kr <= 16);
81 KAI_ASSUME(rhs != NULL);
// NOTE(review): scale is dereferenced in the zero-point/scale loop but, unlike zero, is not checked here.
82 KAI_ASSUME(zero != NULL);
83 KAI_ASSUME(rhs_packed != NULL);
84 KAI_ASSUME(params != NULL);
85 KAI_ASSUME(params->rhs_zero_point == 8);
86 KAI_ASSUME(params->lhs_zero_point == 1);
87
88 // Note: The input matrix (rhs) is expected with:
89 // "k" columns and "n" rows (NxK)
90
91 2616 const size_t num_blocks_per_row = k / bl;
// Row stride of the input counted in 4-bit values, not bytes: addresses are halved when rhs is indexed.
92 2616 const size_t rhs_stride = k;
93 2616 const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, bl);
94
// Bytes taken by one block for all nr rows (data + zero points + scales), and by its data part alone.
95 2616 const size_t dst_packed_block_size = kai_get_num_bytes_per_block(bl) * nr;
96 2616 const size_t dst_block_data_size = (bl / 2) * nr;
97 2616 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
// The bias values of a group are stored after all of its blocks.
98 2616 const size_t dst_bias_offset = num_blocks_per_row * dst_packed_block_size;
// Number of consecutive destination bytes that come from the same source row.
99 2616 const size_t k_block_length_in_bytes = kr / sr;
// Distance, in values along k, between the two 4-bit values combined into one destination byte.
100 2616 const size_t k_interleaved_v = 16U;
101
102 2616 const size_t rhs_zero_point = params->rhs_zero_point;
103
// One iteration per group of nr rows.
104
2/2
✓ Branch 0 taken 2616 times.
✓ Branch 1 taken 38448 times.
41064 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
105 38448 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
106 38448 float* dst_row_bias = (float*)(dst_row + dst_bias_offset);
107
// One iteration per block of bl values along k.
108
2/2
✓ Branch 0 taken 66144 times.
✓ Branch 1 taken 38448 times.
104592 for (size_t block_idx = 0; block_idx < num_blocks_per_row; block_idx++) {
109 66144 uint8_t* block_dst_row = dst_row + block_idx * dst_packed_block_size;
// Within a block, the nr zero points follow the data bytes and the nr scales follow the zero points.
110 66144 float* block_dst_zp = (float*)(block_dst_row + dst_block_data_size);
111 66144 float* block_dst_scale = block_dst_zp + nr;
112
// Produce the data bytes of this block one at a time, in destination order.
113
2/2
✓ Branch 0 taken 5415936 times.
✓ Branch 1 taken 66144 times.
5482080 for (size_t block_byte_idx = 0; block_byte_idx < dst_block_data_size; ++block_byte_idx) {
// Decompose the destination byte index into: which run of k_block_length_in_bytes bytes it belongs to,
// the position inside that run, the row of the group (nr_idx) and the run index along k (super_k_block_idx).
114 5415936 const size_t dst_byte_idx = block_byte_idx;
115 5415936 const size_t k_block_idx = dst_byte_idx / k_block_length_in_bytes;
116 5415936 const size_t k_block_byte_idx = dst_byte_idx % k_block_length_in_bytes;
117 5415936 const size_t super_k_block_idx = k_block_idx / nr;
118 5415936 const size_t nr_idx = k_block_idx % nr;
119
// Every full k_interleaved_v bytes already produced for this row cover 2 * k_interleaved_v source values,
// so the source index skips a further k_interleaved_v values for each of them.
120 10831872 const size_t k_adjustment =
121 5415936 ((k_block_byte_idx + super_k_block_idx * k_block_length_in_bytes) / k_interleaved_v) *
122 k_interleaved_v;
// k0_idx and k1_idx are positions inside the current block; k1_idx is k_interleaved_v values after k0_idx.
123 5415936 const size_t k0_idx = k_block_byte_idx + super_k_block_idx * k_block_length_in_bytes + k_adjustment;
124 5415936 const size_t k1_idx = k0_idx + k_interleaved_v;
125 5415936 const size_t n0_idx = dst_row_idx * nr + nr_idx;
126
127 // Clamp the index to avoid out-of-bound reads
128
2/2
✓ Branch 0 taken 5300096 times.
✓ Branch 1 taken 115840 times.
5415936 const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
129
// Source byte addresses: (row * k + block start + position in block) counted in 4-bit values, halved.
130 5415936 const size_t src_addr_byte0 = (k0_idx + n0_valid_idx * rhs_stride + block_idx * bl) / 2;
131 5415936 const size_t src_addr_byte1 = (k1_idx + n0_valid_idx * rhs_stride + block_idx * bl) / 2;
132
// Fallback bytes holding the zero point in both nibbles, used when the k index check below fails.
133 5415936 uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4;
134 5415936 uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4;
135
// NOTE(review): k0_idx and k1_idx are relative to the current block (block_idx * bl is only added in the
// address computation), yet they are compared against k. The branch records show the fallback path was
// never taken in this run.
136
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 5415936 times.
5415936 if (k0_idx < k) {
137 5415936 byte0 = rhs[src_addr_byte0];
138 5415936 }
139
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 5415936 times.
5415936 if (k1_idx < k) {
140 5415936 byte1 = rhs[src_addr_byte1];
141 5415936 }
142
// An even value index selects the high nibble of the source byte, an odd index the low nibble.
143 5415936 const size_t shift_right_x0 = (k0_idx % 2 == 0) ? 4 : 0;
144 5415936 const size_t shift_right_x1 = (k1_idx % 2 == 0) ? 4 : 0;
145
146 5415936 const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
147 5415936 const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F;
148
// Destination byte: value at k0_idx in the low nibble, value at k1_idx in the high nibble.
149 5415936 const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
150
// XOR with 0x88 flips the most significant bit of each nibble (presumably converting the unsigned
// values with zero point 8 to a signed representation - confirm against the matmul kernel).
151 5415936 *block_dst_row = dst_qs0 ^ 0x88;
152 5415936 block_dst_row += sizeof(uint8_t);
153 5415936 }
154
155 // Adjust the zero points and scales
156
2/2
✓ Branch 0 taken 264576 times.
✓ Branch 1 taken 66144 times.
330720 for (size_t i = 0; i < nr; ++i) {
157 // Clamp the row index to avoid out-of-bound reads
158
2/2
✓ Branch 0 taken 258600 times.
✓ Branch 1 taken 5976 times.
264576 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
159
// zero and scale hold num_blocks_per_row floats per source row.
160 264576 const float* block_zero = (const float*)zero + num_blocks_per_row * src_row_idx;
161 264576 const float* block_scale = (const float*)scale + num_blocks_per_row * src_row_idx;
162
// The zero point is copied unchanged; the scale is stored multiplied by 1/16.
163 264576 *block_dst_zp = block_zero[block_idx];
164 264576 *block_dst_scale = block_scale[block_idx] * 0.0625F;
165
166 264576 block_dst_zp++;
167 264576 block_dst_scale++;
168 264576 }
169 66144 }
170 // Set the bias
171
2/2
✓ Branch 0 taken 19224 times.
✓ Branch 1 taken 19224 times.
38448 if (bias == NULL) {
// No bias supplied: store nr zero floats.
172 19224 memset(dst_row_bias, 0, nr * kai_num_bytes_bias);
173 19224 } else {
174
2/2
✓ Branch 0 taken 76896 times.
✓ Branch 1 taken 19224 times.
96120 for (size_t i = 0; i < nr; ++i) {
175 // Clamp the row index to avoid out-of-bound reads
176
2/2
✓ Branch 0 taken 74732 times.
✓ Branch 1 taken 2164 times.
76896 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
177
178 76896 dst_row_bias[i] = *((const float*)bias + src_row_idx);
179 76896 }
180 }
181 38448 }
182 2616 }
183 #endif
184