KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
Date: 2025-10-20 13:18:31
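Coverage summary (coverage %, executed, excluded, total):
Lines: 96.6% (57 executed, 34 excluded, 93 total)
Functions: 88.9% (8 executed, 0 excluded, 9 total)
Branches: 100.0% (12 executed, 68 excluded, 80 total)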

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64.
9 #else // Architectural features check.
10
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include "kai/kai_common.h"
17
// Block length: number of 4-bit values per quantization block (kai_bl / 2 bytes of packed data).
// This packing routine only supports this single block length.
static const size_t kai_bl = 32;
// Size in bytes of the 16-bit scale stored at the start of every block.
static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
20
21 1956 static inline void convert_s1s0_s16s0(uint8_t* dst_blk, const uint8_t* src_blk) {
22 // First half
23
2/2
✓ Branch 0 taken 15648 times.
✓ Branch 1 taken 1956 times.
17604 for (size_t k = 0; k < kai_bl / 2; k += 2) {
24 15648 dst_blk[k / 2] = src_blk[k] & 0xF;
25 15648 dst_blk[k / 2] |= src_blk[k + 1] << 4;
26 15648 }
27
28 // Second half
29
2/2
✓ Branch 0 taken 1956 times.
✓ Branch 1 taken 15648 times.
17604 for (size_t k = kai_bl / 2; k < kai_bl; k += 2) {
30 15648 dst_blk[k / 2] = src_blk[k - kai_bl / 2] >> 4;
31 15648 dst_blk[k / 2] |= src_blk[k - kai_bl / 2 + 1] & 0xF0;
32 15648 }
33 1956 }
34
35 150 inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
36 KAI_ASSUME((k % 2) == 0);
37 KAI_ASSUME(bl == kai_bl);
38 150 return kai_roundup(k, bl) / bl;
39 }
40
41 176 inline static size_t kai_num_bytes_per_block(size_t bl) {
42 KAI_ASSUME(bl == kai_bl);
43
44 176 return (bl / 2) + kai_num_bytes_multiplier;
45 }
46
47 26 inline static size_t kai_rhs_stride(size_t k, size_t bl) {
48 KAI_ASSUME(bl == kai_bl);
49 KAI_ASSUME((k % 2) == 0);
50 KAI_ASSUME((k % bl) == 0);
51
52 26 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
53 26 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
54
55 52 return num_bytes_per_block * num_blocks_per_row;
56 26 }
57
58 124 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(
59 size_t k, size_t nr, size_t kr, size_t bl) {
60 KAI_ASSUME(bl == kai_bl);
61 KAI_ASSUME((k % 2) == 0);
62 KAI_ASSUME((k % kr) == 0);
63 KAI_ASSUME((k % bl) == 0);
64
65 124 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
66 124 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
67
68 248 return nr * (num_bytes_per_block * num_blocks_per_row);
69 124 }
70
// Byte offset of row n_idx in the unpacked rhs matrix, given the matrix row stride in bytes.
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(size_t n_idx, size_t rhs_stride) {
    const size_t offset = rhs_stride * n_idx;
    return offset;
}
74
75 72 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(
76 size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
77 KAI_ASSUME(bl == kai_bl);
78 KAI_ASSUME((k % 2) == 0);
79 KAI_ASSUME((k % kr) == 0);
80 KAI_ASSUME((k % bl) == 0);
81 KAI_ASSUME((n_idx % nr) == 0);
82
83 // The scales are stored after all the nr packed quantized values
84 72 return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl);
85 }
86
87 26 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(
88 size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
89 KAI_ASSUME(bl == kai_bl);
90 KAI_ASSUME((k % 2) == 0);
91 KAI_ASSUME((k % kr) == 0);
92 KAI_ASSUME((k % bl) == 0);
93
94 26 const size_t num_rows = kai_roundup(n, nr) / nr;
95
96 52 return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl);
97 26 }
98
99 26 void kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(
100 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
101 const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params) {
102 KAI_ASSUME(bl == kai_bl);
103 KAI_ASSUME(num_groups == 1);
104 KAI_ASSUME((k % 2) == 0);
105 KAI_ASSUME((k % kr) == 0);
106 KAI_ASSUME((k % bl) == 0);
107 KAI_ASSUME(bias == NULL);
108 KAI_ASSUME(extra_bytes == 0);
109
110 KAI_ASSUME(kr == 4);
111 KAI_ASSUME(sr == 2);
112 KAI_ASSUME(kr >= 1 && kr <= 16);
113 KAI_ASSUME(rhs != NULL);
114 KAI_ASSUME(rhs_packed != NULL);
115 KAI_ASSUME(params != NULL);
116 KAI_ASSUME(params->rhs_zero_point == 8);
117 KAI_ASSUME(params->lhs_zero_point == 1);
118
119 // Note: The input matrix (rhs) is expected with:
120 // "k" columns and "n" rows (NxK)
121
122 26 const size_t num_blocks = k / bl;
123 26 const size_t rhs_stride = kai_rhs_stride(k, bl);
124 52 const size_t rhs_packed_stride =
125 26 kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl);
126 26 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
127
128 26 uint8_t* rhs_packed_ptr = rhs_packed;
129
130
2/2
✓ Branch 0 taken 26 times.
✓ Branch 1 taken 1048 times.
1074 for (uint64_t n_idx = 0; n_idx < n; n_idx++) {
131 2096 uint16_t* rhs_packed_scales =
132 1048 (uint16_t*)(rhs_packed_ptr + rhs_packed_stride - (nr * num_blocks * kai_num_bytes_multiplier));
133
134
2/2
✓ Branch 0 taken 1956 times.
✓ Branch 1 taken 1048 times.
3004 for (size_t block_idx = 0; block_idx < num_blocks; block_idx++) {
135 1956 uint8_t blk_s1s0[16];
136
137 3912 const uint16_t* blk_scale_ptr =
138 1956 (const uint16_t*)(rhs + (block_idx * num_bytes_per_block) + n_idx * rhs_stride);
139 1956 const uint8_t* blk_s16s0 = (const uint8_t*)blk_scale_ptr + kai_num_bytes_multiplier;
140
141 1956 convert_s1s0_s16s0(blk_s1s0, blk_s16s0);
142
143
2/2
✓ Branch 0 taken 15648 times.
✓ Branch 1 taken 1956 times.
17604 for (size_t bl4_idx = 0; bl4_idx < bl / 4; bl4_idx++) {
144 // Uint16 holds 4 int4 values
145 15648 ((uint16_t*)rhs_packed_ptr)[(block_idx * bl / 4 + bl4_idx) * nr + (n_idx % nr)] =
146 15648 ((int16_t*)blk_s1s0)[bl4_idx];
147 15648 }
148
149 // Num. block (rows) x Nr (cols)
150 1956 rhs_packed_scales[(n_idx % nr) + block_idx * nr] = *blk_scale_ptr;
151 1956 }
152
153
2/2
✓ Branch 0 taken 1040 times.
✓ Branch 1 taken 8 times.
1048 if (((n_idx + 1) % nr) == 0) {
154 8 rhs_packed_ptr += rhs_packed_stride;
155 8 }
156 1048 }
157 26 }
158 #endif // Architectural features check.
159