KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 89.4% 59 33 99
Functions: 87.5% 7 0 8
Branches: 85.0% 17 66 86

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
7
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include "kai/kai_common.h"
13
// Quantization block length: number of 4-bit values covered by one scale.
static const size_t kai_bl = 32;
// Size in bytes of one per-block scale value (copied as a raw 2-byte quantity).
static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
16
17 628 inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
18 KAI_ASSUME((k % 2) == 0);
19 KAI_ASSUME(bl == kai_bl);
20 628 return kai_roundup(k, bl) / bl;
21 }
22
23 530 inline static size_t kai_num_bytes_per_block(size_t bl) {
24 KAI_ASSUME(bl == kai_bl);
25 530 return (bl / 2) + kai_num_bytes_multiplier;
26 }
27
28 98 inline static size_t kai_rhs_stride(size_t k, size_t bl) {
29 KAI_ASSUME(bl == kai_bl);
30 KAI_ASSUME((k % 2) == 0);
31 KAI_ASSUME((k % bl) == 0);
32
33 98 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
34 98 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
35
36 196 return num_bytes_per_block * num_blocks_per_row;
37 98 }
38
39 432 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(size_t k, size_t nr, size_t kr, size_t bl) {
40 KAI_ASSUME(bl == kai_bl);
41 KAI_ASSUME((k % 2) == 0);
42 KAI_ASSUME((k % kr) == 0);
43 KAI_ASSUME((k % bl) == 0);
44
45 432 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
46 432 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
47
48 864 return nr * (num_bytes_per_block * num_blocks_per_row);
49 432 }
50
// Byte offset of row n_idx in the unpacked RHS, whose rows are rhs_stride bytes apart.
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(size_t n_idx, size_t rhs_stride) {
    const size_t offset = rhs_stride * n_idx;
    return offset;
}
54
55 236 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(
56 size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
57 KAI_ASSUME(bl == kai_bl);
58 KAI_ASSUME((k % 2) == 0);
59 KAI_ASSUME((k % kr) == 0);
60 KAI_ASSUME((k % bl) == 0);
61 KAI_ASSUME((n_idx % nr) == 0);
62
63 236 return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl);
64 }
65
66 98 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(
67 size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
68 KAI_ASSUME(bl == kai_bl);
69 KAI_ASSUME((k % 2) == 0);
70 KAI_ASSUME((k % kr) == 0);
71 KAI_ASSUME((k % bl) == 0);
72
73 98 const size_t num_rows = kai_roundup(n, nr) / nr;
74
75 196 return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl);
76 98 }
77
78 98 void kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(
79 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
80 const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params) {
81 KAI_ASSUME(bl == kai_bl);
82 KAI_ASSUME(num_groups == 1);
83 KAI_ASSUME((k % 2) == 0);
84 KAI_ASSUME((k % kr) == 0);
85 KAI_ASSUME((k % bl) == 0);
86 KAI_ASSUME(bias == NULL);
87 KAI_ASSUME(extra_bytes == 0);
88
89 KAI_ASSUME(sr == 2);
90 KAI_ASSUME(kr >= 1 && kr <= 16);
91 KAI_ASSUME(rhs != NULL);
92 KAI_ASSUME(rhs_packed != NULL);
93 KAI_ASSUME(params != NULL);
94 KAI_ASSUME(params->rhs_zero_point == 8);
95 KAI_ASSUME(params->lhs_zero_point == 1);
96
97 // Note: The input matrix (rhs) is expected with:
98 // "k" columns and "n" rows (NxK)
99
100 98 const size_t rhs_stride = kai_rhs_stride(k, bl);
101 196 const size_t rhs_packed_stride =
102 98 kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(k, nr, kr, bl);
103 98 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
104 98 const size_t num_segments_per_block = bl / kr;
105 98 const size_t num_bytes_per_segment = kr / 2;
106
107
2/2
✓ Branch 0 taken 98 times.
✓ Branch 1 taken 1058 times.
1156 for (size_t y = 0; y < n; y += nr) {
108 1058 const uint8_t* src_row = rhs;
109 1058 uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride;
110
111
2/2
✓ Branch 0 taken 1970 times.
✓ Branch 1 taken 1058 times.
3028 for (size_t x = 0; x < num_blocks_per_row; ++x) {
112 // Store the scales at the end of the block
113 1970 uint8_t* scales = (dst_row);
114
115
2/2
✓ Branch 0 taken 7880 times.
✓ Branch 1 taken 1970 times.
9850 for (size_t i = 0; i < nr; ++i) {
116
2/2
✓ Branch 0 taken 7650 times.
✓ Branch 1 taken 230 times.
7880 const size_t src_row_idx = KAI_MIN(y + i, n - 1);
117 7880 memcpy(
118 scales + i * kai_num_bytes_multiplier, src_row + src_row_idx * rhs_stride,
119 kai_num_bytes_multiplier);
120 7880 }
121 1970 src_row += kai_num_bytes_multiplier;
122
123 1970 dst_row += (kai_num_bytes_multiplier * nr);
124
125 // Store the segments
126
2/2
✓ Branch 0 taken 4928 times.
✓ Branch 1 taken 1970 times.
6898 for (size_t s = 0; s < num_segments_per_block; ++s) {
127
2/2
✓ Branch 0 taken 19712 times.
✓ Branch 1 taken 4928 times.
24640 for (size_t i = 0; i < nr; ++i) {
128
2/2
✓ Branch 0 taken 19128 times.
✓ Branch 1 taken 584 times.
19712 const size_t src_row_idx = KAI_MIN(y + i, n - 1);
129
130
2/2
✓ Branch 0 taken 7904 times.
✓ Branch 1 taken 11808 times.
19712 if (num_bytes_per_segment == sizeof(uint32_t)) {
131 7904 uint32_t tmp = 0;
132 7904 memcpy(&tmp, src_row + src_row_idx * rhs_stride, num_bytes_per_segment);
133 7904 tmp = tmp ^ 0x88888888;
134 7904 memcpy(dst_row + i * num_bytes_per_segment, &tmp, num_bytes_per_segment);
135
1/2
✓ Branch 0 taken 11808 times.
✗ Branch 1 not taken.
19712 } else if (num_bytes_per_segment == sizeof(uint64_t)) {
136 11808 uint64_t tmp = 0;
137 11808 memcpy(&tmp, src_row + src_row_idx * rhs_stride, num_bytes_per_segment);
138 11808 tmp = tmp ^ 0x8888888888888888ULL;
139 11808 memcpy(dst_row + i * num_bytes_per_segment, &tmp, num_bytes_per_segment);
140 11808 } else {
141 memcpy(
142 dst_row + i * num_bytes_per_segment, src_row + src_row_idx * rhs_stride,
143 num_bytes_per_segment);
144
145 for (size_t b = 0; b < num_bytes_per_segment; ++b) {
146 uint8_t qs = dst_row[i * num_bytes_per_segment + b];
147 // Add offset (0x88)
148 dst_row[i * num_bytes_per_segment + b] = qs ^ 0x88;
149 }
150 }
151 19712 }
152
153 4928 src_row += num_bytes_per_segment;
154 4928 dst_row += num_bytes_per_segment * nr;
155 4928 }
156 1970 }
157 1058 }
158 98 }
159