KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 94.2% 65 8 77
Functions: 71.4% 5 0 7
Branches: 100.0% 22 16 38

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
7
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include "kai/kai_common.h"
13
// Byte sizes of the per-output-channel metadata stored after the interleaved int8
// values of every packed row block, in this order: sums, scales, biases.
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);       // int32 row reduction sum
static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);  // float per-channel scale
static const size_t kai_num_bytes_bias = sizeof(float);            // float per-channel bias
17
// Returns k padded up to the next multiple of 32. This is the number of int8
// values stored per output channel in the packed RHS. Padding columns are
// written as zeros by the packing routine.
static inline size_t kai_k_roundedup(size_t k) {
    const size_t k_granule = 32;
    return kai_roundup(k, k_granule);
}
23
// Returns the N step of the packed RHS. Output channels are packed in row blocks
// of nr channels, so the step is nr itself.
size_t kai_get_n_step_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(size_t nr) {
    const size_t n_step = nr;
    return n_step;
}
27
// Returns the offset of row n_idx in the unpacked RHS, given a row stride of
// rhs_stride (same unit as the stride).
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(size_t n_idx, size_t rhs_stride) {
    const size_t row_offset = rhs_stride * n_idx;
    return row_offset;
}
31
32 4024 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(size_t k, size_t nr, size_t kr, size_t sr) {
33 4024 KAI_UNUSED(kr);
34 4024 KAI_UNUSED(sr);
35 4024 const size_t k_internal = kai_k_roundedup(k);
36
37 8048 return nr * (k_internal + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
38 4024 }
39
// Returns the byte offset in the packed RHS of the row block that starts at
// output channel n_idx. n_idx must be a multiple of nr (asserted).
size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(
    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
    KAI_ASSERT((n_idx % nr) == 0);

    const size_t row_block_idx = n_idx / nr;
    const size_t row_block_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
    return row_block_idx * row_block_stride;
}
46
// Returns the total size in bytes of the packed RHS for n output channels.
// n is rounded up to whole row blocks of nr channels; the padding channels
// occupy space just like real ones.
size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
    const size_t row_block_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
    const size_t num_row_blocks = kai_roundup(n, nr) / nr;
    return num_row_blocks * row_block_stride;
}
52
53 1136 void kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(
54 size_t num_groups, //
55 size_t n, //
56 size_t k, //
57 size_t nr, //
58 size_t kr, //
59 size_t sr, //
60 const int8_t* rhs, //
61 const float* bias, //
62 const float* scale, //
63 void* rhs_packed, //
64 size_t extra_bytes, const struct kai_rhs_pack_qsi8cx_params* params) {
65 KAI_ASSERT(num_groups == 1);
66 KAI_ASSERT(extra_bytes == 0);
67 KAI_ASSERT(sr == 1);
68 KAI_ASSERT(rhs != NULL);
69 KAI_ASSERT(scale != NULL);
70 KAI_ASSERT(rhs_packed != NULL);
71 KAI_ASSERT(params != NULL);
72
73 1136 const int32_t lhs_zero_point = params->lhs_zero_point;
74 1136 const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
75 1136 const size_t k_internal = kai_k_roundedup(k);
76 1136 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
77 1136 const size_t dst_num_bytes_per_row = nr * k_internal;
78 1136 const size_t rhs_stride = k;
79
80
2/2
✓ Branch 0 taken 1136 times.
✓ Branch 1 taken 10436 times.
11572 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
81 10436 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
82
83 10436 int32_t* sums = (int32_t*)(dst_row + nr * k_internal);
84
85 // Initialize to zero the RHS reduction sums
86 10436 memset(sums, 0, nr * sizeof(int32_t));
87
88
2/2
✓ Branch 0 taken 767072 times.
✓ Branch 1 taken 10436 times.
777508 for (size_t dst_offset = 0; dst_offset < dst_num_bytes_per_row; dst_offset += kr) {
89 767072 const size_t block_idx = dst_offset / kr;
90 767072 const size_t nr_idx = block_idx % nr;
91 767072 const size_t super_block_idx = block_idx / nr;
92
93 767072 const size_t k0_idx = super_block_idx * kr;
94 767072 const size_t n0_idx = dst_row_idx * nr + nr_idx;
95
96 // Clamp the index to avoid out-of-bound reads
97
2/2
✓ Branch 0 taken 625872 times.
✓ Branch 1 taken 141200 times.
767072 const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
98
99 767072 const size_t src_offset = n0_valid_idx * rhs_stride;
100
101 767072 int32_t partial_sum = 0;
102
103 // Get the partial reduction sum
104
2/2
✓ Branch 0 taken 3588608 times.
✓ Branch 1 taken 767072 times.
4355680 for (size_t i = 0; i < kr; i++) {
105 3588608 const size_t k0_valid_idx = k0_idx + i;
106 3588608 int8_t v = 0;
107
2/2
✓ Branch 0 taken 1115216 times.
✓ Branch 1 taken 2473392 times.
3588608 if (k0_valid_idx < k) {
108 2473392 v = rhs[src_offset + k0_valid_idx];
109 2473392 }
110 3588608 ((int8_t*)dst_row)[i] = v;
111 3588608 partial_sum += v;
112 3588608 }
113
114 767072 sums[nr_idx] += partial_sum * lhs_zero_point;
115
116 767072 dst_row += kr;
117 767072 }
118
119 // Adjust the reduction sums
120
2/2
✓ Branch 0 taken 68384 times.
✓ Branch 1 taken 10436 times.
78820 for (size_t i = 0; i < nr; ++i) {
121 68384 dst_row += sizeof(int32_t);
122 68384 }
123
124 // Adjust the scales
125
2/2
✓ Branch 0 taken 68384 times.
✓ Branch 1 taken 10436 times.
78820 for (size_t i = 0; i < nr; ++i) {
126 // Clamp the row index to avoid out-of-bound reads
127
2/2
✓ Branch 0 taken 57164 times.
✓ Branch 1 taken 11220 times.
68384 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
128 68384 *((float*)(dst_row)) = scale[src_row_idx];
129 68384 dst_row += sizeof(float);
130 68384 }
131
132 // Set the bias
133
2/2
✓ Branch 0 taken 7908 times.
✓ Branch 1 taken 2528 times.
10436 if (bias == NULL) {
134 2528 memset(dst_row, 0, nr * kai_num_bytes_bias);
135 2528 } else {
136
2/2
✓ Branch 0 taken 58272 times.
✓ Branch 1 taken 7908 times.
66180 for (size_t i = 0; i < nr; ++i) {
137 // Clamp the row index to avoid out-of-bound reads
138
2/2
✓ Branch 0 taken 47590 times.
✓ Branch 1 taken 10682 times.
58272 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
139 58272 ((float*)dst_row)[i] = bias[src_row_idx];
140 58272 }
141 }
142 10436 }
143 1136 }
144