KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 65 / 14 / 79
Functions: 100.0% 16 / 0 / 16
Branches: -% 0 / 28 / 28

kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 #define KAI_LUT_NENTRIES 64
19
20 /// LUT to be indexed by an i4 value, yielding that value as an i8 (e.g. -2 = 1110 -> 1111 1110).
21 static const int8_t lut[KAI_LUT_NENTRIES] = {
22 // clang-format off
23 0, 0, 0, 0,
24 1, 0, 0, 0,
25 2, 0, 0, 0,
26 3, 0, 0, 0,
27 4, 0, 0, 0,
28 5, 0, 0, 0,
29 6, 0, 0, 0,
30 7, 0, 0, 0,
31 -8, 0, 0, 0,
32 -7, 0, 0, 0,
33 -6, 0, 0, 0,
34 -5, 0, 0, 0,
35 -4, 0, 0, 0,
36 -3, 0, 0, 0,
37 -2, 0, 0, 0,
38 -1, 0, 0, 0
39 // clang-format on
40 };
41
42 typedef struct {
43 float* dst; // 0 (0x00)
44 const void* lhs_packed; // 8 (0x08)
45 const void* rhs_packed; // 16 (0x10)
46 size_t dst_stride_row; // 24 (0x18)
47 size_t lhs_stride; // 32 (0x20)
48 size_t rhs_stride; // 40 (0x28)
49 size_t m; // 48 (0x30)
50 size_t n; // 56 (0x38)
51 size_t k; // 64 (0x40)
52 size_t bl; // 72 (0x48)
53 const int8_t* lut; // 80 (0x50)
54 float scalar_max; // 88 (0x58)
55 float scalar_min; // 92 (0x5C)
56 } KernelArgs;
57
58 extern void kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(KernelArgs* args_ptr);
59
60 // Compute args
61 static const size_t kai_m_step = 1; // multiple of vector length
62 static const size_t kai_n_step = 4; // multiple of vector length
63 // Packing args
64 static const size_t kai_mr = 1; // multiple of vector length
65 static const size_t kai_nr = 4; // multiple of vector length
66 static const size_t kai_kr = 8;
67 static const size_t kai_sr = 2;
68 // LHS format args
69 static const size_t kai_num_bytes_qvalue_lhs = 1;
70 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
71 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
72 // RHS format args
73 static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
74 static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t);
75 static const size_t kai_num_bytes_sum_rhs = sizeof(float);
76 static const size_t kai_num_bytes_bias_rhs = sizeof(float);
77 // DST format args
78 static const size_t kai_num_bytes_dst_value = 4;
79 // Extra args
80 static const size_t kai_k_multiple_of = 32;
81 static const size_t kai_bl = 32;
82
83 742 static size_t kai_k_roundedup(const size_t k) {
84 // Round up k to be a multiple of 32.
85 742 return kai_roundup(k, kai_k_multiple_of);
86 }
87
88 462 static size_t kai_get_num_bytes_per_block_rhs(const size_t bl) {
89 KAI_ASSUME((bl % kai_bl) == 0);
90 462 const size_t num_bytes_per_block_rhs = (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
91 924 return num_bytes_per_block_rhs;
92 462 }
93
94 462 static size_t kai_get_num_blocks_per_row(const size_t k, const size_t bl) {
95 KAI_ASSUME((bl % kai_bl) == 0);
96 462 return kai_roundup(k, bl) / bl;
97 }
98
99 186 static size_t kai_get_lhs_packed_stride(const size_t k) {
100 186 const size_t k_internal = kai_k_roundedup(k);
101 186 const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
102 372 return mr * (k_internal * kai_num_bytes_qvalue_lhs + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
103 186 }
104
105 462 static size_t kai_get_rhs_packed_stride(const size_t k, const size_t bl) {
106 KAI_ASSUME((bl % kai_bl) == 0);
107 KAI_ASSUME((k % kai_bl) == 0);
108
109 462 const size_t k_internal = kai_k_roundedup(k);
110 462 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k_internal, bl);
111 462 const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
112 462 const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
113
114 462 size_t rhs_packed_stride = nr * (num_blocks_per_row * num_bytes_per_block);
115 462 rhs_packed_stride += nr * kai_num_bytes_sum_rhs; // per-column rsum
116 462 rhs_packed_stride += nr * kai_num_bytes_bias_rhs; // per-column bias
117
118 924 return rhs_packed_stride;
119 462 }
120
121 1564 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
122 1564 return kai_m_step * kai_get_sme_vector_length_u32();
123 }
124
125 2024 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
126 2024 return kai_n_step * kai_get_sme_vector_length_u32();
127 }
128
129 738 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
130 738 return kai_mr * kai_get_sme_vector_length_u32();
131 }
132
133 1290 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
134 1290 return kai_nr * kai_get_sme_vector_length_u32();
135 }
136
137 552 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
138 552 return kai_kr;
139 }
140
141 552 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
142 552 return kai_sr;
143 }
144
145 92 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) {
146 KAI_ASSUME((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa()) == 0);
147 92 const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
148 184 return (m_idx / mr) * kai_get_lhs_packed_stride(k);
149 92 }
150
151 368 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
152 const size_t n_idx, const size_t k, const size_t bl) {
153 KAI_ASSUME((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa()) == 0);
154 368 const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
155 736 return (n_idx / nr) * kai_get_rhs_packed_stride(k, bl);
156 368 }
157
158 460 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
159 const size_t m_idx, const size_t n_idx, const size_t dst_stride) {
160 KAI_ASSUME((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa()) == 0);
161 KAI_ASSUME((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa()) == 0);
162 460 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
163 }
164
165 92 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(const size_t m, const size_t n) {
166 92 return m * n * kai_num_bytes_dst_value;
167 }
168
169 94 void kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
170 const size_t m, //
171 const size_t n, //
172 const size_t k, //
173 const size_t bl, //
174 const void* restrict lhs_packed, //
175 const void* restrict rhs_packed, //
176 float* restrict dst, // NOLINT(readability-non-const-parameter)
177 const size_t dst_stride_row, //
178 const size_t dst_stride_col, //
179 const float scalar_min, //
180 const float scalar_max) {
181 KAI_ASSUME(dst_stride_col == sizeof(float));
182 KAI_ASSUME(m > 0);
183 KAI_ASSUME(n > 0);
184 KAI_ASSUME(k > 0);
185 KAI_ASSUME((bl % kai_k_multiple_of) == 0);
186 KAI_ASSUME((k % bl) == 0);
187
188 94 KernelArgs args;
189 94 args.dst = dst;
190 94 args.lhs_packed = lhs_packed;
191 94 args.rhs_packed = rhs_packed;
192 94 args.dst_stride_row = dst_stride_row;
193 94 args.m = m;
194 94 args.n = n;
195 94 args.k = kai_k_roundedup(k);
196 94 args.bl = bl;
197 94 args.lhs_stride = kai_get_lhs_packed_stride(k);
198 94 args.rhs_stride = kai_get_rhs_packed_stride(k, bl);
199 94 args.scalar_max = scalar_max;
200 94 args.scalar_min = scalar_min;
201 94 args.lut = lut;
202
203 94 kai_commit_za();
204
205 94 kai_kernel_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(&args);
206 94 }
207
208 #endif // Architectural features check.
209