KleidiAI Coverage Report


Directory: ./
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%
Columns: Coverage | Executed / Excluded / Total
Lines: 96.6% 57 / 34 / 93
Functions: 88.9% 8 / 0 / 9
Branches: 100.0% 12 / 66 / 78

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64.
9 #else // Architectural features check.
10
11 #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
// Block length handled by this packer: number of 4-bit values per quantization block.
static const size_t kai_bl = 32;
// Size in bytes of one per-block scale (scales are stored as 16-bit values).
static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);

// Re-orders one block of packed 4-bit values from the s16s0 nibble ordering to the
// s1s0 nibble ordering (names as used in this file).
//
// src_blk holds kai_bl / 2 bytes. For every source byte pair (lo, hi) at 2*i and 2*i + 1:
//   dst_blk[i]              = (lo & 0x0F) | (hi & 0x0F) << 4   -> low nibbles
//   dst_blk[i + kai_bl / 4] = (lo >> 4)   | (hi & 0xF0)        -> high nibbles
// so the first half of dst_blk gathers the low nibbles and the second half the high nibbles.
// dst_blk must provide kai_bl / 2 bytes; the caller in this file passes a separate stack
// buffer, overlapping buffers are not supported.
static inline void convert_s1s0_s16s0(uint8_t* dst_blk, const uint8_t* src_blk) {
    const size_t half = kai_bl / 4;

    for (size_t i = 0; i < half; ++i) {
        const uint8_t lo = src_blk[2 * i];
        const uint8_t hi = src_blk[2 * i + 1];

        dst_blk[i] = (uint8_t)((lo & 0x0F) | (hi << 4));
        dst_blk[i + half] = (uint8_t)((lo >> 4) | (hi & 0xF0));
    }
}
34
35 240 inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
36 KAI_ASSUME((k % 2) == 0);
37 KAI_ASSUME(bl == kai_bl);
38 240 return kai_roundup(k, bl) / bl;
39 }
40
41 284 inline static size_t kai_num_bytes_per_block(size_t bl) {
42 KAI_ASSUME(bl == kai_bl);
43
44 284 return (bl / 2) + kai_num_bytes_multiplier;
45 }
46
47 44 inline static size_t kai_rhs_stride(size_t k, size_t bl) {
48 KAI_ASSUME(bl == kai_bl);
49 KAI_ASSUME((k % 2) == 0);
50 KAI_ASSUME((k % bl) == 0);
51
52 44 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
53 44 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
54
55 88 return num_bytes_per_block * num_blocks_per_row;
56 44 }
57
58 196 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(
59 size_t k, size_t nr, size_t kr, size_t bl) {
60 KAI_ASSUME(bl == kai_bl);
61 KAI_ASSUME((k % 2) == 0);
62 KAI_ASSUME((k % kr) == 0);
63 KAI_ASSUME((k % bl) == 0);
64
65 196 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
66 196 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
67
68 392 return nr * (num_bytes_per_block * num_blocks_per_row);
69 196 }
70
// Returns the byte offset of row n_idx in the unpacked RHS matrix, whose rows are laid out
// back-to-back rhs_stride bytes apart.
size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(size_t n_idx, size_t rhs_stride) {
    const size_t row_offset = rhs_stride * n_idx;
    return row_offset;
}
74
75 108 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(
76 size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
77 KAI_ASSUME(bl == kai_bl);
78 KAI_ASSUME((k % 2) == 0);
79 KAI_ASSUME((k % kr) == 0);
80 KAI_ASSUME((k % bl) == 0);
81 KAI_ASSUME((n_idx % nr) == 0);
82
83 // The scales are stored after all the nr packed quantized values
84 108 return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl);
85 }
86
87 44 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(
88 size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
89 KAI_ASSUME(bl == kai_bl);
90 KAI_ASSUME((k % 2) == 0);
91 KAI_ASSUME((k % kr) == 0);
92 KAI_ASSUME((k % bl) == 0);
93
94 44 const size_t num_rows = kai_roundup(n, nr) / nr;
95
96 88 return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl);
97 44 }
98
99 44 void kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(
100 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
101 const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params) {
102 KAI_ASSUME(bl == kai_bl);
103 KAI_ASSUME(num_groups == 1);
104 KAI_ASSUME((k % 2) == 0);
105 KAI_ASSUME((k % kr) == 0);
106 KAI_ASSUME((k % bl) == 0);
107 KAI_ASSUME(bias == NULL);
108 KAI_ASSUME(extra_bytes == 0);
109
110 KAI_ASSUME(kr == 4);
111 KAI_ASSUME(sr == 2);
112 KAI_ASSUME(kr >= 1 && kr <= 16);
113 KAI_ASSUME(rhs != NULL);
114 KAI_ASSUME(rhs_packed != NULL);
115 KAI_ASSUME(params != NULL);
116 KAI_ASSUME(params->rhs_zero_point == 8);
117 KAI_ASSUME(params->lhs_zero_point == 1);
118
119 // Note: The input matrix (rhs) is expected with:
120 // "k" columns and "n" rows (NxK)
121
122 44 const size_t num_blocks = k / bl;
123 44 const size_t rhs_stride = kai_rhs_stride(k, bl);
124 88 const size_t rhs_packed_stride =
125 44 kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon(k, nr, kr, bl);
126 44 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
127
128 44 uint8_t* rhs_packed_ptr = rhs_packed;
129
130
2/2
✓ Branch 0 taken 44 times.
✓ Branch 1 taken 1636 times.
1680 for (uint64_t n_idx = 0; n_idx < n; n_idx++) {
131 3272 uint16_t* rhs_packed_scales =
132 1636 (uint16_t*)(rhs_packed_ptr + rhs_packed_stride - (nr * num_blocks * kai_num_bytes_multiplier));
133
134
2/2
✓ Branch 0 taken 2544 times.
✓ Branch 1 taken 1636 times.
4180 for (size_t block_idx = 0; block_idx < num_blocks; block_idx++) {
135 2544 uint8_t blk_s1s0[16];
136
137 5088 const uint16_t* blk_scale_ptr =
138 2544 (const uint16_t*)(rhs + (block_idx * num_bytes_per_block) + n_idx * rhs_stride);
139 2544 const uint8_t* blk_s16s0 = (const uint8_t*)blk_scale_ptr + kai_num_bytes_multiplier;
140
141 2544 convert_s1s0_s16s0(blk_s1s0, blk_s16s0);
142
143
2/2
✓ Branch 0 taken 20352 times.
✓ Branch 1 taken 2544 times.
22896 for (size_t bl4_idx = 0; bl4_idx < bl / 4; bl4_idx++) {
144 // Uint16 holds 4 int4 values
145 20352 ((uint16_t*)rhs_packed_ptr)[(block_idx * bl / 4 + bl4_idx) * nr + (n_idx % nr)] =
146 20352 ((int16_t*)blk_s1s0)[bl4_idx];
147 20352 }
148
149 // Num. block (rows) x Nr (cols)
150 2544 rhs_packed_scales[(n_idx % nr) + block_idx * nr] = *blk_scale_ptr;
151 2544 }
152
153
2/2
✓ Branch 0 taken 1628 times.
✓ Branch 1 taken 8 times.
1636 if (((n_idx + 1) % nr) == 0) {
154 8 rhs_packed_ptr += rhs_packed_stride;
155 8 }
156 1636 }
157 44 }
158 #endif // Architectural features check.
159