KleidiAI Coverage Report


Directory: ./
Coverage legend: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 73 / 13 / 86
Functions: 100.0% 16 / 0 / 16
Branches: -% 0 / 26 / 26

kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 // Do not flag up inline assembly blocks
8
9 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
10 #error This file must be compiled for AArch64, FEAT_SVE2.
11 #else // Architectural features check
12
13 #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot.h"
14
15 #include <stddef.h>
16 #include <stdint.h>
17
18 #include "kai/kai_common.h"
19
20 #define KAI_LUT_NENTRIES 64
21
// Lut to be indexed by i4 resulting in its value in i8 (i.e. -2 = 1110 -> 1111 1110).
// Only the first byte of each 4-byte group carries a value; the 4-byte stride
// presumably matches the element width expected by the assembly kernel's
// table-lookup instruction — NOTE(review): verify against the .S file.
static const int8_t lut[KAI_LUT_NENTRIES] = {
    // clang-format off
    0, 0, 0, 0,
    1, 0, 0, 0,
    2, 0, 0, 0,
    3, 0, 0, 0,
    4, 0, 0, 0,
    5, 0, 0, 0,
    6, 0, 0, 0,
    7, 0, 0, 0,
    -8, 0, 0, 0,
    -7, 0, 0, 0,
    -6, 0, 0, 0,
    -5, 0, 0, 0,
    -4, 0, 0, 0,
    -3, 0, 0, 0,
    -2, 0, 0, 0,
    -1, 0, 0, 0
    // clang-format on
};
43
// Argument block passed by pointer to the assembly micro-kernel
// kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot.
// The trailing comments give each field's byte offset; the assembly addresses
// fields by these fixed offsets, so do NOT reorder, resize, or insert fields
// without updating the assembly implementation to match.
// NOTE(review): offsets assume an LP64 target (8-byte pointers/size_t) — the
// architecture guard above restricts builds to AArch64, where this holds.
typedef struct {
    float* dst;              // 0 ( 0x00 ) destination matrix (row 0, col 0 of the tile)
    size_t dst_stride_row;   // 8 ( 0x08 ) bytes between consecutive dst rows
    const int8_t* lut;       // 16 ( 0x10 ) i4 -> i8 sign-extension table (see `lut` above)
    size_t m;                // 24 ( 0x18 )
    size_t n;                // 32 ( 0x20 )
    size_t k;                // 40 ( 0x28 )
    const void* lhs_packed;  // 48 ( 0x30 )
    const void* rhs_packed;  // 56 ( 0x38 )
    float scalar_max;        // 64 ( 0x40 ) upper clamp bound
    float scalar_min;        // 68 ( 0x44 ) lower clamp bound
    size_t k_internal;       // 72 ( 0x48 ) k rounded up to kai_k_multiple_of
    size_t lhs_stride;       // 80 ( 0x50 ) bytes per packed LHS row block
    size_t rhs_stride;       // 88 ( 0x58 ) bytes per packed RHS row block
    size_t nr;               // 96 ( 0x60 ) packed RHS block width in columns
    size_t rhs_row_bytes;    // 104 ( 0x68 ) quantized-data bytes per packed RHS row
    size_t lhs_end_ptr;      // 112 ( 0x70 ) address one past the packed LHS data
    size_t bl;               // 120 ( 0x78 ) quantization block length
} KernelArgs;
63
64 void kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(KernelArgs* args_ptr);
65
// Compute args
static const size_t kai_m_step = 1;
static const size_t kai_n_step = 4;  // multiple of vector length
// Packing args
static const size_t kai_mr = 1;
static const size_t kai_nr = 4;  // multiple of vector length
static const size_t kai_kr = 8;
static const size_t kai_sr = 2;
// LHS format args (int8 values, scale multiplier and zero-point offset per row)
static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
// RHS format args
static const size_t kai_num_bytes_recip_qvalue_rhs = 2;              // int4: 2 values per byte
static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); // BF16 scale per column
static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);         // rsum per column
static const size_t kai_num_bytes_bias_rhs = sizeof(float);          // bias per column
// DST format args (f32 output)
static const size_t kai_num_bytes_dst_value = 4;
// Extra args
static const size_t kai_k_multiple_of = 32;  // K padding granularity for packed LHS
static const size_t kai_bl = 32;             // minimum/granule quantization block length
87
88 166 static size_t kai_k_roundedup(const size_t k) {
89 // Round up k to be a multiple of 32.
90 166 return kai_roundup(k, kai_k_multiple_of);
91 }
92
93 42 static size_t kai_get_lhs_packed_stride(const size_t k) {
94 42 const size_t k_internal = kai_k_roundedup(k);
95 126 return kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot() *
96 42 (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
97 42 }
98
99 124 static size_t kai_get_num_bytes_per_block_rhs(const size_t bl) {
100 KAI_ASSUME((bl % kai_bl) == 0);
101 124 const size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
102 248 return num_bytes_per_block_rhs;
103 124 }
104
105 22 static size_t kai_get_num_blocks_per_row(const size_t k, const size_t bl) {
106 KAI_ASSUME((bl % kai_bl) == 0);
107 22 return kai_roundup(k, bl) / bl;
108 }
109
110 102 static size_t kai_get_rhs_packed_stride(const size_t k, const size_t bl) {
111 KAI_ASSUME((bl % kai_bl) == 0);
112 102 const size_t k_internal = kai_k_roundedup(k);
113 102 const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot();
114 102 const size_t num_blocks_per_row = kai_roundup(k_internal, bl) / bl;
115
116 // bytes_per_block: int4 packed weights (bl/2 bytes) + per-block scale bytes
117 102 const size_t bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
118 102 size_t rhs_packed_stride = nr * (num_blocks_per_row * bytes_per_block);
119 102 rhs_packed_stride += nr * kai_num_bytes_sum_rhs; // per-column rsum
120 102 rhs_packed_stride += nr * kai_num_bytes_bias_rhs; // per-column bias
121
122 204 return rhs_packed_stride;
123 102 }
124
125 320 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(void) {
126 320 return kai_m_step;
127 }
128
129 440 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(void) {
130 440 return kai_n_step * kai_get_sme_vector_length_u32();
131 }
132
133 142 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(void) {
134 // For gemv mr must be 1 to consecutively read the data
135 142 return kai_mr;
136 }
137
138 224 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(void) {
139 224 return kai_nr * kai_get_sme_vector_length_u32();
140 }
141
142 120 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(void) {
143 120 return kai_kr;
144 }
145
146 120 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(void) {
147 120 return kai_sr;
148 }
149
150 20 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(
151 const size_t m_idx, const size_t k) {
152 KAI_ASSUME((m_idx % kai_m_step) == 0);
153 20 return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
154 }
155
// Byte offset into the packed RHS buffer for the column block containing n_idx.
size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(
    const size_t n_idx, const size_t k, const size_t bl) {
    const size_t n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot();
    KAI_ASSUME((n_idx % n_step) == 0);
    // Each packed RHS "row" spans n_step output columns.
    const size_t block_row = n_idx / n_step;
    return block_row * kai_get_rhs_packed_stride(k, bl);
}
163
164 100 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(
165 const size_t m_idx, const size_t n_idx, const size_t dst_stride) {
166 100 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot();
167 100 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot();
168 KAI_ASSERT((m_idx % m_step) == 0);
169 KAI_ASSERT((n_idx % n_step) == 0);
170 200 return (n_idx * kai_num_bytes_dst_value) + (m_idx * dst_stride);
171 100 }
172
173 20 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(const size_t m, const size_t n) {
174 20 return m * n * kai_num_bytes_dst_value;
175 }
176
177 22 void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(
178 const size_t m, //
179 const size_t n, //
180 const size_t k, //
181 const size_t bl, //
182 const void* restrict lhs_packed, //
183 const void* restrict rhs_packed, //
184 float* restrict dst, // NOLINT(readability-non-const-parameter)
185 const size_t dst_stride_row, //
186 const size_t dst_stride_col, //
187 const float scalar_min, //
188 const float scalar_max) {
189 KAI_ASSUME(dst_stride_col == sizeof(float));
190 KAI_ASSUME(n > 0);
191 KAI_ASSUME(m == 1);
192 KAI_ASSUME(k > 0);
193 KAI_ASSUME((bl % kai_k_multiple_of) == 0);
194 KAI_ASSUME((k % bl) == 0);
195
196 22 KernelArgs args;
197
198 22 args.dst = dst;
199 22 args.lhs_packed = lhs_packed;
200 22 args.rhs_packed = rhs_packed;
201 22 args.dst_stride_row = dst_stride_row;
202 22 args.m = m;
203 22 args.n = n;
204 22 args.k = k;
205 22 args.bl = bl;
206 22 args.k_internal = kai_k_roundedup(k);
207 22 args.lhs_stride = kai_get_lhs_packed_stride(k);
208 22 args.rhs_stride = kai_get_rhs_packed_stride(k, bl);
209 22 args.nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot();
210 22 const size_t bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
211 22 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
212 22 args.rhs_row_bytes = args.nr * num_blocks_per_row * bytes_per_block;
213 22 args.lhs_end_ptr = ((uint64_t)lhs_packed) + (m * args.lhs_stride);
214 22 args.scalar_max = scalar_max;
215 22 args.scalar_min = scalar_min;
216 22 args.lut = lut;
217
218 22 kai_commit_za();
219
220 22 kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot(&args);
221 22 }
222 #endif // Architectural features check.
223