KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 94.2% 65 / 8 / 77
Functions: 71.4% 5 / 0 / 7
Branches: 100.0% 22 / 16 / 38

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
7
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include "kai/kai_common.h"
13
14 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
15 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
16 static const size_t kai_num_bytes_bias = sizeof(float);
17
18 28644 inline static size_t kai_k_roundedup(size_t k) {
19 // Round up k to be a multiple of 32.
20 28644 size_t kai_k_multiple_of = 32;
21 57288 return kai_roundup(k, kai_k_multiple_of);
22 28644 }
23
24 size_t kai_get_n_step_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(size_t nr) {
25 return nr;
26 }
27
28 size_t kai_get_rhs_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(size_t n_idx, size_t rhs_stride) {
29 return n_idx * rhs_stride;
30 }
31
32 22176 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(size_t k, size_t nr, size_t kr, size_t sr) {
33 22176 KAI_UNUSED(kr);
34 22176 KAI_UNUSED(sr);
35 22176 const size_t k_internal = kai_k_roundedup(k);
36
37 44352 return nr * (k_internal + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
38 22176 }
39
40 9240 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(
41 size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
42 KAI_ASSERT((n_idx % nr) == 0);
43
44 9240 return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
45 }
46
47 6468 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
48 6468 const size_t num_rows = kai_roundup(n, nr) / nr;
49
50 12936 return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
51 6468 }
52
53 6468 void kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(
54 size_t num_groups, //
55 size_t n, //
56 size_t k, //
57 size_t nr, //
58 size_t kr, //
59 size_t sr, //
60 const int8_t* rhs, //
61 const float* bias, //
62 const float* scale, //
63 void* rhs_packed, //
64 size_t extra_bytes, const struct kai_rhs_pack_qsi8cx_params* params) {
65 KAI_ASSERT(num_groups == 1);
66 KAI_ASSERT(extra_bytes == 0);
67 KAI_ASSERT(sr == 1);
68 KAI_ASSERT(rhs != NULL);
69 KAI_ASSERT(scale != NULL);
70 KAI_ASSERT(rhs_packed != NULL);
71 KAI_ASSERT(params != NULL);
72
73 6468 const int32_t lhs_zero_point = params->lhs_zero_point;
74 6468 const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(k, nr, kr, sr);
75 6468 const size_t k_internal = kai_k_roundedup(k);
76 6468 const size_t dst_num_rows = kai_roundup(n, nr) / nr;
77 6468 const size_t dst_num_bytes_per_row = nr * k_internal;
78 6468 const size_t rhs_stride = k;
79
80
2/2
✓ Branch 0 taken 6468 times.
✓ Branch 1 taken 65772 times.
72240 for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
81 65772 uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
82
83 65772 int32_t* sums = (int32_t*)(dst_row + nr * k_internal);
84
85 // Initialize to zero the RHS reduction sums
86 65772 memset(sums, 0, nr * sizeof(int32_t));
87
88
2/2
✓ Branch 0 taken 3700032 times.
✓ Branch 1 taken 65772 times.
3765804 for (size_t dst_offset = 0; dst_offset < dst_num_bytes_per_row; dst_offset += kr) {
89 3700032 const size_t block_idx = dst_offset / kr;
90 3700032 const size_t nr_idx = block_idx % nr;
91 3700032 const size_t super_block_idx = block_idx / nr;
92
93 3700032 const size_t k0_idx = super_block_idx * kr;
94 3700032 const size_t n0_idx = dst_row_idx * nr + nr_idx;
95
96 // Clamp the index to avoid out-of-bound reads
97
2/2
✓ Branch 0 taken 3193344 times.
✓ Branch 1 taken 506688 times.
3700032 const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
98
99 3700032 const size_t src_offset = n0_valid_idx * rhs_stride;
100
101 3700032 int32_t partial_sum = 0;
102
103 // Get the partial reduction sum
104
2/2
✓ Branch 0 taken 18127872 times.
✓ Branch 1 taken 3700032 times.
21827904 for (size_t i = 0; i < kr; i++) {
105 18127872 const size_t k0_valid_idx = k0_idx + i;
106 18127872 int8_t v = 0;
107
2/2
✓ Branch 0 taken 5308800 times.
✓ Branch 1 taken 12819072 times.
18127872 if (k0_valid_idx < k) {
108 12819072 v = rhs[src_offset + k0_valid_idx];
109 12819072 }
110 18127872 ((int8_t*)dst_row)[i] = v;
111 18127872 partial_sum += v;
112 18127872 }
113
114 3700032 sums[nr_idx] += partial_sum * lhs_zero_point;
115
116 3700032 dst_row += kr;
117 3700032 }
118
119 // Adjust the reduction sums
120
2/2
✓ Branch 0 taken 348768 times.
✓ Branch 1 taken 65772 times.
414540 for (size_t i = 0; i < nr; ++i) {
121 348768 dst_row += sizeof(int32_t);
122 348768 }
123
124 // Adjust the scales
125
2/2
✓ Branch 0 taken 348768 times.
✓ Branch 1 taken 65772 times.
414540 for (size_t i = 0; i < nr; ++i) {
126 // Clamp the row index to avoid out-of-bound reads
127
2/2
✓ Branch 0 taken 306600 times.
✓ Branch 1 taken 42168 times.
348768 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
128 348768 *((float*)(dst_row)) = scale[src_row_idx];
129 348768 dst_row += sizeof(float);
130 348768 }
131
132 // Set the bias
133
2/2
✓ Branch 0 taken 49476 times.
✓ Branch 1 taken 16296 times.
65772 if (bias == NULL) {
134 16296 memset(dst_row, 0, nr * kai_num_bytes_bias);
135 16296 } else {
136
2/2
✓ Branch 0 taken 283584 times.
✓ Branch 1 taken 49476 times.
333060 for (size_t i = 0; i < nr; ++i) {
137 // Clamp the row index to avoid out-of-bound reads
138
2/2
✓ Branch 0 taken 245028 times.
✓ Branch 1 taken 38556 times.
283584 const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
139 283584 ((float*)dst_row)[i] = bias[src_row_idx];
140 283584 }
141 }
142 65772 }
143 6468 }
144