KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 98.5% 67 / 13 / 81
Functions: 100.0% 16 / 0 / 16
Branches: 50.0% 1 / 26 / 28

kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
7 #error This file must be compiled for AArch64, FEAT_SVE2.
8 #else // Architectural features check.
9
10 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16 typedef struct {
17 float* dst; // 0
18 const void* lhs_packed; // 0x8
19 const void* rhs_packed; // 0x10
20 size_t rhs_packed_stride; // 0x18
21 size_t n; // 0x20
22 size_t k; // 0x28
23 size_t bl; // 0x30
24 const int32_t* lut; // 0x38
25 float min; // 0x40
26 float max; // 0x44
27 } KernelArgs;
28
29 void kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(KernelArgs* args_ptr);
30 // Compute args
31 static const size_t kai_m_step = 1;
32 static const size_t kai_n_step = 4; // Multiple of vector length
33 // Packing args
34 static const size_t kai_mr = 1;
35 static const size_t kai_nr = 4; // Multiple of vector length
36 static const size_t kai_kr = 8;
37 static const size_t kai_sr = 2;
38 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
39 static const size_t kai_num_bytes_qvalue_lhs = 1;
40 static const size_t kai_num_bytes_multiplier_lhs = 4;
41 static const size_t kai_num_bytes_sum_lhs = 4;
42 // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
43 // asymmetric))
44 static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
45 static const size_t kai_num_bytes_multiplier_rhs = 4;
46 static const size_t kai_num_bytes_offset_rhs = 4;
47
48 // DST format args
49 static const size_t kai_num_bytes_dst_value = 4;
50 // Extra args
51 static const size_t kai_num_bytes_bias = 4;
52 static const size_t kai_bl = 32;
53
54 // Look-up table used for int4->int8 convert
55 static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
56
57 1176 inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) {
58 1176 return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs;
59 }
60
61 2354 inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
62 KAI_ASSUME((bl % kai_bl) == 0);
63 4708 size_t num_bytes_per_block_rhs =
64 2354 (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs;
65 4708 return num_bytes_per_block_rhs;
66 2354 }
67
68 3530 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
69 KAI_ASSUME((bl % kai_bl) == 0);
70
71 3530 return kai_roundup(k, bl) / bl;
72 }
73
74 1176 inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) {
75 1176 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
76 2352 return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl);
77 1176 }
78
79 2354 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
80 KAI_ASSUME(bl % kai_bl == 0);
81 KAI_ASSUME((k % kai_bl) == 0);
82
83 2354 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
84 2354 const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
85 2354 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
86
87 2354 size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row);
88 // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
89 2354 rhs_packed_stride += nr * kai_num_bytes_bias;
90
91 4708 return rhs_packed_stride;
92 2354 }
93
94 6636 size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
95 6636 return kai_m_step;
96 }
97
98 6636 size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
99 6636 return kai_n_step * kai_get_sme_vector_length_u32();
100 }
101
102 8568 size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
103 8568 return kai_mr;
104 }
105
106 9746 size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
107 9746 return kai_nr * kai_get_sme_vector_length_u32();
108 }
109
110 6216 size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
111 6216 return kai_kr;
112 }
113
114 6216 size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
115 6216 return kai_sr;
116 }
117
118 1176 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(
119 size_t m_idx, size_t k, size_t bl) {
120 1176 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
121 1176 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
122
123 KAI_ASSUME((m_idx % m_step) == 0);
124
125 2352 return (m_idx / mr) * kai_get_lhs_packed_stride(k, bl);
126 1176 }
127
128 1176 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(
129 size_t n_idx, size_t k, size_t bl) {
130 KAI_ASSUME((k % bl) == 0);
131 1176 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
132 1176 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
133
134 KAI_ASSUME((n_idx % n_step) == 0);
135
136 2352 return (n_idx / nr) * kai_get_rhs_packed_stride(k, bl);
137 1176 }
138
139 1176 size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(
140 size_t m_idx, size_t n_idx, size_t dst_stride) {
141 1176 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
142 1176 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
143 KAI_ASSUME((m_idx % m_step) == 0);
144 KAI_ASSUME((n_idx % n_step) == 0);
145
146 2352 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
147 1176 }
148
149 1176 size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(size_t m, size_t n) {
150 1176 return m * n * kai_num_bytes_dst_value;
151 }
152
153 1178 void kai_run_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(
154 size_t m, //
155 size_t n, //
156 size_t k, //
157 size_t bl, //
158 const void* restrict lhs_packed, //
159 const void* restrict rhs_packed, //
160 float* restrict dst, // NOLINT(readability-non-const-parameter)
161 size_t dst_stride_row, //
162 size_t dst_stride_col, //
163 float scalar_min, //
164 float scalar_max) {
165 KAI_ASSUME(dst_stride_col == sizeof(float));
166 KAI_ASSUME((k % bl) == 0);
167 KAI_ASSUME((bl % kai_bl) == 0);
168 KAI_ASSUME(m == 1);
169
170 1178 KAI_UNUSED(dst_stride_row);
171
172
1/2
✓ Branch 0 taken 1178 times.
✗ Branch 1 not taken.
1178 if (m == 0) {
173 return;
174 }
175
176 1178 KernelArgs args;
177 1178 args.dst = dst;
178 1178 args.lhs_packed = lhs_packed;
179 1178 args.rhs_packed = rhs_packed;
180 1178 args.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl);
181 1178 args.n = n;
182 1178 args.k = k;
183 1178 args.bl = bl;
184 1178 args.lut = lut;
185 1178 args.min = scalar_min;
186 1178 args.max = scalar_max;
187
188 1178 kai_commit_za();
189
190 1178 kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(&args);
191 1178 }
192
193 #endif // Architectural features check.
194