KleidiAI Coverage Report


Directory: ./
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%

            Coverage   Exec / Excl / Total
Lines:       87.0%       60 /   33 /   102
Functions:   87.5%        7 /    0 /     8
Branches:    85.0%       17 /   64 /    84
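
The percentages appear to be computed against the non-excluded counts, i.e. Exec / (Total - Excl):

    Lines:      60 / (102 - 33) = 60 / 69 ≈ 87.0%
    Functions:   7 / (  8 -  0) =  7 /  8 = 87.5%
    Branches:   17 / ( 84 - 64) = 17 / 20 = 85.0%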

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
7
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include "kai/kai_common.h"
13
14 static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
15 static const size_t kai_bl = 32;
16
17 5520 inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
18 KAI_ASSUME((k % 2) == 0);
19 KAI_ASSUME(bl == kai_bl);
20 5520 return kai_roundup(k, bl) / bl;
21 }
22
23 4644 inline static size_t kai_num_bytes_per_block(size_t bl) {
24 KAI_ASSUME(bl == kai_bl);
25 4644 return (bl / 2) + kai_num_bytes_multiplier;
26 }
27
28 876 inline static size_t kai_rhs_stride(size_t k, size_t bl) {
29 KAI_ASSUME(bl == kai_bl);
30 KAI_ASSUME((k % 2) == 0);
31 KAI_ASSUME((k % bl) == 0);
32
33 876 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
34 876 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
35
36 1752 return num_bytes_per_block * num_blocks_per_row;
37 876 }
38
39 3768 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(size_t k, size_t nr, size_t kr, size_t bl) {
40 KAI_ASSUME(bl == kai_bl);
41 KAI_ASSUME((k % 2) == 0);
42 KAI_ASSUME((k % kr) == 0);
43 KAI_ASSUME((k % bl) == 0);
44
45 3768 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
46 3768 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
47
48 7536 return nr * (num_bytes_per_block * num_blocks_per_row);
49 3768 }
50
51 size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(size_t n_idx, size_t rhs_stride) {
52 return n_idx * rhs_stride;
53 }
54
55 2016 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(
56 size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
57 KAI_ASSUME(bl == kai_bl);
58 KAI_ASSUME((k % 2) == 0);
59 KAI_ASSUME((k % kr) == 0);
60 KAI_ASSUME((k % bl) == 0);
61 KAI_ASSUME((n_idx % nr) == 0);
62
63 2016 return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl);
64 }
65
66 876 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(
67 size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
68 KAI_ASSUME(bl == kai_bl);
69 KAI_ASSUME((k % 2) == 0);
70 KAI_ASSUME((k % kr) == 0);
71 KAI_ASSUME((k % bl) == 0);
72
73 876 const size_t num_rows = kai_roundup(n, nr) / nr;
74
75 1752 return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl);
76 876 }
77
78 876 void kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(
79 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
80 const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params) {
81 KAI_ASSUME(bl == kai_bl);
82 KAI_ASSUME(num_groups == 1);
83 KAI_ASSUME((k % 2) == 0);
84 KAI_ASSUME((k % kr) == 0);
85 KAI_ASSUME((k % bl) == 0);
86 KAI_ASSUME(bias == NULL);
87 KAI_ASSUME(extra_bytes == 0);
88
89 KAI_ASSUME(sr == 2);
90 KAI_ASSUME(kr >= 1 && kr <= 16);
91 KAI_ASSUME(rhs != NULL);
92 KAI_ASSUME(rhs_packed != NULL);
93 KAI_ASSUME(params != NULL);
94 KAI_ASSUME(params->rhs_zero_point == 8);
95 KAI_ASSUME(params->lhs_zero_point == 1);
96
97 // Note: The input matrix (rhs) is expected with:
98 // "k" columns and "n" rows (NxK)
99
100 876 const size_t rhs_stride = kai_rhs_stride(k, bl);
101 1752 const size_t rhs_packed_stride =
102 876 kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl);
103 876 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
104 876 const size_t num_segments_per_block = bl / kr;
105 876 const size_t num_bytes_per_segment = kr / 2;
106
107 2/2 9184 for (size_t y = 0; y < n; y += nr) {
    ✓ Branch 0 taken 876 times.
    ✓ Branch 1 taken 8308 times.
108 8308 const uint8_t* src_row = rhs;
109 8308 uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride;
110
111 2/2 22384 for (size_t x = 0; x < num_blocks_per_row; ++x) {
    ✓ Branch 0 taken 14076 times.
    ✓ Branch 1 taken 8308 times.
112 // Store the scales at the start of the block
113 14076 uint8_t* scales = (dst_row);
114
115 2/2 75164 for (size_t i = 0; i < nr; ++i) {
    ✓ Branch 0 taken 61088 times.
    ✓ Branch 1 taken 14076 times.
116 2/2 61088 const size_t src_row_idx = KAI_MIN(y + i, n - 1);
    ✓ Branch 0 taken 58104 times.
    ✓ Branch 1 taken 2984 times.
117 61088 memcpy(
118 35328 scales + i * kai_num_bytes_multiplier, src_row + src_row_idx * rhs_stride,
119 kai_num_bytes_multiplier);
120 61088 }
121 14076 src_row += kai_num_bytes_multiplier;
122
123 14076 dst_row += (kai_num_bytes_multiplier * nr);
124
125 // Store the segments
126 2/2 45052 for (size_t s = 0; s < num_segments_per_block; ++s) {
    ✓ Branch 0 taken 30976 times.
    ✓ Branch 1 taken 14076 times.
127 2/2 165376 for (size_t i = 0; i < nr; ++i) {
    ✓ Branch 0 taken 134400 times.
    ✓ Branch 1 taken 30976 times.
128 2/2 134400 const size_t src_row_idx = KAI_MIN(y + i, n - 1);
    ✓ Branch 0 taken 127512 times.
    ✓ Branch 1 taken 6888 times.
129
130 2/2 134400 if (num_bytes_per_segment == sizeof(uint32_t)) {
    ✓ Branch 0 taken 24448 times.
    ✓ Branch 1 taken 109952 times.
131 24448 uint32_t tmp = 0;
132 24448 memcpy(&tmp, src_row + src_row_idx * rhs_stride, num_bytes_per_segment);
133 24448 tmp = tmp ^ 0x88888888;
134 24448 memcpy(dst_row + i * num_bytes_per_segment, &tmp, num_bytes_per_segment);
135 1/2 134400 } else if (num_bytes_per_segment == sizeof(uint64_t)) {
    ✓ Branch 0 taken 109952 times.
    ✗ Branch 1 not taken.
136 109952 uint64_t tmp = 0;
137 109952 memcpy(&tmp, src_row + src_row_idx * rhs_stride, num_bytes_per_segment);
138 109952 tmp = tmp ^ 0x8888888888888888ULL;
139 109952 memcpy(dst_row + i * num_bytes_per_segment, &tmp, num_bytes_per_segment);
140 109952 } else {
141 memcpy(
142 dst_row + i * num_bytes_per_segment, src_row + src_row_idx * rhs_stride,
143 num_bytes_per_segment);
144
145 for (size_t b = 0; b < num_bytes_per_segment; ++b) {
146 uint8_t qs = dst_row[i * num_bytes_per_segment + b];
147 // Add offset (0x88)
148 dst_row[i * num_bytes_per_segment + b] = qs ^ 0x88;
149 }
150 }
151 134400 }
152
153 30976 src_row += num_bytes_per_segment;
154 30976 dst_row += num_bytes_per_segment * nr;
155 30976 }
156 14076 }
157 8308 }
158 876 }
159
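
Reference sketch (not part of the report): the packed stride returned by kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0 is nr * (k / bl) * (bl / 2 + 2) bytes; for example, with k = 256, bl = 32 and nr = 4 that is 4 * 8 * 18 = 576 bytes per row group. A minimal, hypothetical calling sequence is sketched below; nr = 4 and kr = 16 are illustrative choices (any 1 <= kr <= 16 with k % kr == 0 works), while bl = 32, sr = 2, num_groups = 1, bias = NULL, extra_bytes = 0 and the zero points are mandated by the kernel's KAI_ASSUME checks.

    #include <stddef.h>
    #include <stdint.h>
    #include <stdlib.h>

    #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"

    // Hypothetical example, not from the KleidiAI sources: pack an NxK matrix of
    // 4-bit values with one f16 scale per 32-value block.
    static void pack_rhs_example(size_t n, size_t k, const uint8_t* rhs_native) {
        const size_t nr = 4;   // illustrative; must match the consuming matmul ukernel
        const size_t kr = 16;  // illustrative; 1 <= kr <= 16 and k % kr == 0
        const size_t sr = 2;   // required: KAI_ASSUME(sr == 2)
        const size_t bl = 32;  // required: KAI_ASSUME(bl == kai_bl)

        // Zero points are fixed by the kernel's KAI_ASSUME checks.
        struct kai_rhs_pack_qs4cxs1s0_param params = {.lhs_zero_point = 1, .rhs_zero_point = 8};

        const size_t packed_size =
            kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n, k, nr, kr, bl);
        void* rhs_packed = malloc(packed_size);  // error handling omitted for brevity

        // rhs_native is assumed to hold n rows of k/bl blocks, each block being an f16
        // scale followed by bl/2 bytes of packed 4-bit values (see kai_rhs_stride above).
        kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(
            /*num_groups=*/1, n, k, nr, kr, sr, bl, rhs_native,
            /*bias=*/NULL, rhs_packed, /*extra_bytes=*/0, &params);

        // ... hand rhs_packed to the matching matmul micro-kernel ...
        free(rhs_packed);
    }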