KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.6% 70 12 83
Functions: 100.0% 16 0 16
Branches: 50.0% 1 24 26

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
7 #error This file must be compiled for AArch64, FEAT_SVE2.
8 #else // Architectural features check.
9
10 #include "kai_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
// Argument block handed to the micro-kernel by the run function below.
//
// The hex comments record each member's byte offset (8-byte pointers and
// size_t). The kernel is implemented elsewhere and presumably reads the struct
// at these offsets, so keep member order and sizes stable.
typedef struct {
    // 0x00: output matrix. This kernel stores 2-byte (f16) values
    // (kai_num_bytes_dst_value), so the pointer is untyped rather than float*.
    void* dst;
    const void* lhs_packed;    // 0x08
    const void* rhs_packed;    // 0x10
    size_t dst_stride_row;     // 0x18
    size_t lhs_packed_stride;  // 0x20
    size_t rhs_packed_stride;  // 0x28
    // 0x30: set by the run function to rhs_packed_stride minus the packed bias
    // bytes, i.e. the byte offset of the bias within each packed RHS stride.
    size_t bias;
    size_t m;                  // 0x38
    size_t n;                  // 0x40
    size_t k;                  // 0x48
    size_t bl;                 // 0x50
    const int32_t* lut;        // 0x58
    float min;                 // 0x60
    float max;                 // 0x64
} KernelArgs;

// Micro-kernel entry point; receives all parameters through a KernelArgs block.
void kai_kernel_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(KernelArgs* args_ptr);
// ---- Compute step sizes ----------------------------------------------------
// Both are multiplied at runtime by the SME vector length (see the
// kai_get_m_step_* / kai_get_n_step_* getters).
static const size_t kai_m_step = 1;
static const size_t kai_n_step = 4;

// ---- Packing parameters ----------------------------------------------------
// mr and nr are likewise scaled by the SME vector length; kr and sr are plain.
static const size_t kai_mr = 1;
static const size_t kai_nr = 4;
static const size_t kai_kr = 8;
static const size_t kai_sr = 2;

// ---- Packed LHS block layout -----------------------------------------------
// Bytes per quantized value, per-block multiplier bytes, per-block sum bytes.
static const size_t kai_num_bytes_qvalue_lhs = 1;
static const size_t kai_num_bytes_multiplier_lhs = 4;
static const size_t kai_num_bytes_sum_lhs = 4;

// ---- Packed RHS block layout -----------------------------------------------
// "recip" is the reciprocal of bytes per value: one byte holds two RHS values,
// so a block of bl values occupies bl / 2 bytes. Each block also stores a
// multiplier and an offset.
static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
static const size_t kai_num_bytes_multiplier_rhs = 4;
static const size_t kai_num_bytes_offset_rhs = 4;

// ---- Output layout ---------------------------------------------------------
static const size_t kai_num_bytes_dst_value = 2;

// ---- Miscellaneous ---------------------------------------------------------
// Bytes per packed bias value, and the granularity the block length must be a multiple of.
static const size_t kai_num_bytes_bias = 4;
static const size_t kai_bl = 32;

// Lookup table for int4 -> int8 conversion: entry i holds the value i - 8.
static const int32_t lut[] = {
    -8, -7, -6, -5, -4, -3, -2, -1,  //
    0,  1,  2,  3,  4,  5,  6,  7,   //
};
60
61 2065 inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) {
62 2065 return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs;
63 }
64
65 2065 inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
66 KAI_ASSUME((bl % kai_bl) == 0);
67 4130 size_t num_bytes_per_block_rhs =
68 2065 (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs;
69 4130 return num_bytes_per_block_rhs;
70 2065 }
71
72 4130 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
73 KAI_ASSUME((bl % kai_bl) == 0);
74
75 4130 return kai_roundup(k, bl) / bl;
76 }
77
// Byte stride between consecutive groups of mr packed LHS rows:
// mr rows x blocks per row x bytes per LHS block.
static inline size_t kai_get_lhs_packed_stride(size_t k, size_t bl) {
    const size_t rows_per_group = kai_get_mr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();
    const size_t blocks = kai_get_num_blocks_per_row(k, bl);
    const size_t block_bytes = kai_get_num_bytes_per_block_lhs(bl);

    return rows_per_group * blocks * block_bytes;
}
82
83 2065 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
84 KAI_ASSUME(bl % kai_bl == 0);
85 KAI_ASSUME((k % kai_bl) == 0);
86
87 2065 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
88 2065 const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
89 2065 const size_t nr = kai_get_nr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();
90
91 2065 size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row);
92 // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
93 2065 rhs_packed_stride += nr * kai_num_bytes_bias;
94
95 4130 return rhs_packed_stride;
96 2065 }
97
98 4136 size_t kai_get_m_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
99 4136 return kai_m_step * kai_get_sme_vector_length_u32();
100 }
101
102 4136 size_t kai_get_n_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
103 4136 return kai_n_step * kai_get_sme_vector_length_u32();
104 }
105
106 5169 size_t kai_get_mr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
107 5169 return kai_mr * kai_get_sme_vector_length_u32();
108 }
109
110 6202 size_t kai_get_nr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
111 6202 return kai_nr * kai_get_sme_vector_length_u32();
112 }
113
114 2072 size_t kai_get_kr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
115 2072 return kai_kr;
116 }
117
118 2072 size_t kai_get_sr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(void) {
119 2072 return kai_sr;
120 }
121
// Byte offset into the packed LHS buffer of the row group containing row m_idx.
// m_idx must be a multiple of the m step.
size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(
    size_t m_idx, size_t k, size_t bl) {
    const size_t step = kai_get_m_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();
    const size_t rows_per_group = kai_get_mr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();

    KAI_ASSUME((m_idx % step) == 0);

    const size_t group_index = m_idx / rows_per_group;
    return group_index * kai_get_lhs_packed_stride(k, bl);
}
131
// Byte offset into the packed RHS buffer of the row group containing column n_idx.
// k must be a multiple of bl, and n_idx a multiple of the n step.
size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(
    size_t n_idx, size_t k, size_t bl) {
    KAI_ASSUME((k % bl) == 0);

    const size_t step = kai_get_n_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();
    const size_t rows_per_group = kai_get_nr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();

    KAI_ASSUME((n_idx % step) == 0);

    const size_t group_index = n_idx / rows_per_group;
    return group_index * kai_get_rhs_packed_stride(k, bl);
}
142
143 1032 size_t kai_get_dst_offset_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(
144 size_t m_idx, size_t n_idx, size_t dst_stride) {
145 1032 const size_t m_step = kai_get_m_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();
146 1032 const size_t n_step = kai_get_n_step_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();
147 KAI_ASSUME((m_idx % m_step) == 0);
148 KAI_ASSUME((n_idx % n_step) == 0);
149
150 2064 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
151 1032 }
152
153 1032 size_t kai_get_dst_size_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(size_t m, size_t n) {
154 1032 return m * n * kai_num_bytes_dst_value;
155 }
156
157 1033 void kai_run_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(
158 size_t m, //
159 size_t n, //
160 size_t k, //
161 size_t bl, //
162 const void* restrict lhs_packed, //
163 const void* restrict rhs_packed, //
164 void* restrict dst, // NOLINT(readability-non-const-parameter)
165 size_t dst_stride_row, //
166 size_t dst_stride_col, //
167 float scalar_min, //
168 float scalar_max) {
169 KAI_ASSUME(dst_stride_col == sizeof(uint16_t));
170 KAI_ASSUME((k % bl) == 0);
171 KAI_ASSUME((bl % kai_bl) == 0);
172
173
1/2
✓ Branch 0 taken 1033 times.
✗ Branch 1 not taken.
1033 if (m == 0) {
174 return;
175 }
176
177 1033 const size_t nr = kai_get_nr_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa();
178
179 1033 KernelArgs args;
180 1033 args.dst = dst;
181 1033 args.lhs_packed = lhs_packed;
182 1033 args.rhs_packed = rhs_packed;
183 1033 args.dst_stride_row = dst_stride_row;
184 1033 args.lhs_packed_stride = kai_get_lhs_packed_stride(k, bl);
185 1033 args.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl);
186 1033 args.bias = args.rhs_packed_stride - nr * kai_num_bytes_bias;
187 1033 args.m = m;
188 1033 args.n = n;
189 1033 args.k = k;
190 1033 args.bl = bl;
191 1033 args.lut = lut;
192 1033 args.min = scalar_min;
193 1033 args.max = scalar_max;
194
195 1033 kai_kernel_matmul_clamp_f16_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa(&args);
196 1033 }
197
198 #endif // Architectural features check.
199