KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 90.0% 63 / 8 / 78
Functions: 71.4% 5 / 0 / 7
Branches: 95.5% 21 / 16 / 38

kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"
7
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include "kai/kai_common.h"
13
14 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
15 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
16 static const size_t kai_num_bytes_bias = sizeof(float);
17
18 13860 inline static size_t kai_k_roundedup(size_t k) {
19 // Round up k to be a multiple of 32.
20 13860 size_t kai_k_multiple_of = 32;
21 27720 return kai_roundup(k, kai_k_multiple_of);
22 13860 }
23
24 size_t kai_get_n_step_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(size_t nr) {
25 return nr;
26 }
27
28 size_t kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(size_t n_idx, size_t rhs_stride) {
29 KAI_UNUSED(rhs_stride);
30 return n_idx;
31 }
32
33 11088 size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(size_t k, size_t nr, size_t kr, size_t sr) {
34 11088 KAI_UNUSED(kr);
35 11088 KAI_UNUSED(sr);
36 11088 const size_t k_internal = kai_k_roundedup(k);
37
38 22176 return nr * (k_internal + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
39 11088 }
40
41 5544 size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(
42 size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
43 KAI_ASSERT((n_idx % nr) == 0);
44
45 5544 return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
46 }
47
48 2772 size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
49 2772 const size_t num_rows = kai_roundup(n, nr) / nr;
50
51 5544 return num_rows * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
52 2772 }
53
54 2772 void kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(
55 size_t num_groups, //
56 size_t n, //
57 size_t k, //
58 size_t nr, //
59 size_t kr, //
60 size_t sr, //
61 const int8_t* rhs, //
62 const float* bias, //
63 const float* scale, //
64 void* rhs_packed, //
65 size_t extra_bytes, const struct kai_rhs_pack_qsi8cx_params* params) {
66 KAI_ASSERT(num_groups == 1);
67 KAI_ASSERT(extra_bytes == 0);
68 KAI_ASSERT(sr == 1);
69 KAI_ASSERT(rhs != NULL);
70 KAI_ASSERT(scale != NULL);
71 KAI_ASSERT(rhs_packed != NULL);
72 KAI_ASSERT(params != NULL);
73
74 2772 const int32_t lhs_zero_point = params->lhs_zero_point;
75 2772 const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
76 2772 const size_t k_internal = kai_k_roundedup(k);
77 2772 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
78 2772 const size_t dst_num_bytes_per_row = nr * k_internal;
79 2772 const size_t rhs_stride = n;
80
81
2/2
✓ Branch 0 taken 2772 times.
✓ Branch 1 taken 33180 times.
35952 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
82 33180 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
83
84 33180 int32_t* sums = (int32_t*)(dst_row + nr * k_internal);
85
86 // Initialize to zero the RHS reduction sums
87 33180 memset(sums, 0, nr * sizeof(int32_t));
88
89
2/2
✓ Branch 0 taken 2470272 times.
✓ Branch 1 taken 33180 times.
2503452 for (size_t dst_offset = 0; dst_offset < dst_num_bytes_per_row; dst_offset += kr) {
90 2470272 const size_t block_idx = dst_offset / kr;
91 2470272 const size_t nr_idx = block_idx % nr;
92 2470272 const size_t super_block_idx = block_idx / nr;
93
94 2470272 const size_t k0_idx = super_block_idx * kr;
95 2470272 const size_t n0_idx = dst_row_idx * nr + nr_idx;
96
97 // Clamp the index to avoid out-of-bound reads
98
2/2
✓ Branch 0 taken 2026080 times.
✓ Branch 1 taken 444192 times.
2470272 const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
99
100 2470272 const size_t src_offset = n0_valid_idx;
101
102 2470272 int32_t partial_sum = 0;
103
104 // Get the partial reduction sum
105
2/2
✓ Branch 0 taken 11569152 times.
✓ Branch 1 taken 2470272 times.
14039424 for (size_t i = 0; i < kr; i++) {
106 11569152 const size_t k0_valid_idx = k0_idx + i;
107 11569152 int8_t v = 0;
108
2/2
✓ Branch 0 taken 4489632 times.
✓ Branch 1 taken 7079520 times.
11569152 if (k0_valid_idx < k) {
109 7079520 v = rhs[src_offset + (k0_valid_idx * rhs_stride)];
110 7079520 }
111
112 11569152 ((int8_t*)dst_row)[i] = v;
113 11569152 partial_sum += v;
114 11569152 }
115
116 2470272 sums[nr_idx] += partial_sum * lhs_zero_point;
117
118 2470272 dst_row += kr;
119 2470272 }
120
121 // Adjust the reduction sums
122
2/2
✓ Branch 0 taken 218400 times.
✓ Branch 1 taken 33180 times.
251580 for (size_t i = 0; i < nr; ++i) {
123 218400 dst_row += sizeof(int32_t);
124 218400 }
125
126 // Adjust the scales
127
2/2
✓ Branch 0 taken 218400 times.
✓ Branch 1 taken 33180 times.
251580 for (size_t i = 0; i < nr; ++i) {
128 // Clamp the row index to avoid out-of-bound reads
129
2/2
✓ Branch 0 taken 183456 times.
✓ Branch 1 taken 34944 times.
218400 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
130 218400 *((float*)(dst_row)) = scale[src_row_idx];
131 218400 dst_row += sizeof(float);
132 218400 }
133
134 // Set the bias
135
1/2
✓ Branch 0 taken 33180 times.
✗ Branch 1 not taken.
33180 if (bias == NULL) {
136 memset(dst_row, 0, nr * kai_num_bytes_bias);
137 } else {
138
2/2
✓ Branch 0 taken 218400 times.
✓ Branch 1 taken 33180 times.
251580 for (size_t i = 0; i < nr; ++i) {
139 // Clamp the row index to avoid out-of-bound reads
140
2/2
✓ Branch 0 taken 183456 times.
✓ Branch 1 taken 34944 times.
218400 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
141 218400 ((float*)dst_row)[i] = bias[src_row_idx];
142 218400 }
143 }
144 33180 }
145 2772 }
146