KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.5% 66 13 80
Functions: 100.0% 16 0 16
Branches: 50.0% 1 26 28

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
7 #error This file must be compiled for AArch64, FEAT_SVE2.
8 #else // Architectural features check.
9
10 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16 typedef struct {
17 float* dst; // 0
18 const void* lhs_packed; // 0x8
19 const void* rhs_packed; // 0x10
20 size_t rhs_packed_stride; // 0x18
21 size_t n; // 0x20
22 size_t k; // 0x28
23 size_t bl; // 0x30
24 const int32_t* lut; // 0x38
25 float min; // 0x40
26 float max; // 0x44
27 } KernelArgs;
28
29 void kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(KernelArgs* args_ptr);
30 // Compute args
31 static const size_t kai_m_step = 1;
32 static const size_t kai_n_step = 4; // Multiple of vector length
33 // Packing args
34 static const size_t kai_mr = 1;
35 static const size_t kai_nr = 4; // Multiple of vector length
36 static const size_t kai_kr = 8;
37 static const size_t kai_sr = 2;
38 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
39 static const size_t kai_num_bytes_qvalue_lhs = 1;
40 static const size_t kai_num_bytes_multiplier_lhs = 4;
41 static const size_t kai_num_bytes_sum_lhs = 4;
42 // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
43 // asymmetric))
44 static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
45 static const size_t kai_num_bytes_multiplier_rhs = 4;
46 static const size_t kai_num_bytes_offset_rhs = 4;
47
48 // DST format args
49 static const size_t kai_num_bytes_dst_value = 4;
50 // Extra args
51 static const size_t kai_num_bytes_bias = 4;
52 static const size_t kai_bl = 32;
53
54 // Look-up table used for int4->int8 convert
55 static const int32_t lut[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
56
57 276 inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) {
58 276 return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs;
59 }
60
61 553 inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
62 KAI_ASSUME((bl % kai_bl) == 0);
63 1106 size_t num_bytes_per_block_rhs =
64 553 (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs;
65 1106 return num_bytes_per_block_rhs;
66 553 }
67
68 829 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
69 KAI_ASSUME((bl % kai_bl) == 0);
70
71 829 return kai_roundup(k, bl) / bl;
72 }
73
74 276 inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) {
75 276 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
76 552 return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl);
77 276 }
78
79 553 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
80 KAI_ASSUME(bl % kai_bl == 0);
81 KAI_ASSUME((k % kai_bl) == 0);
82
83 553 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
84 553 const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
85 553 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
86
87 553 size_t rhs_packed_stride = nr * (num_bytes_per_block * num_blocks_per_row);
88 // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
89 553 rhs_packed_stride += nr * kai_num_bytes_bias;
90
91 1106 return rhs_packed_stride;
92 553 }
93
94 1980 size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
95 1980 return kai_m_step;
96 }
97
98 1980 size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
99 1980 return kai_n_step * kai_get_sme_vector_length_u32();
100 }
101
102 2624 size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
103 2624 return kai_mr;
104 }
105
106 2901 size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
107 2901 return kai_nr * kai_get_sme_vector_length_u32();
108 }
109
110 2072 size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
111 2072 return kai_kr;
112 }
113
114 2072 size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(void) {
115 2072 return kai_sr;
116 }
117
118 276 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(
119 size_t m_idx, size_t k, size_t bl) {
120 276 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
121 276 const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
122
123 KAI_ASSUME((m_idx % m_step) == 0);
124
125 552 return (m_idx / mr) * kai_get_lhs_packed_stride(k, bl);
126 276 }
127
128 276 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(
129 size_t n_idx, size_t k, size_t bl) {
130 KAI_ASSUME((k % bl) == 0);
131 276 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
132 276 const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
133
134 KAI_ASSUME((n_idx % n_step) == 0);
135
136 552 return (n_idx / nr) * kai_get_rhs_packed_stride(k, bl);
137 276 }
138
139 276 size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(
140 size_t m_idx, size_t n_idx, size_t dst_stride) {
141 276 const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
142 276 const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot();
143 KAI_ASSUME((m_idx % m_step) == 0);
144 KAI_ASSUME((n_idx % n_step) == 0);
145
146 552 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
147 276 }
148
149 276 size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(size_t m, size_t n) {
150 276 return m * n * kai_num_bytes_dst_value;
151 }
152
153 277 void kai_run_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(
154 size_t m, //
155 size_t n, //
156 size_t k, //
157 size_t bl, //
158 const void* restrict lhs_packed, //
159 const void* restrict rhs_packed, //
160 float* restrict dst, // NOLINT(readability-non-const-parameter)
161 size_t dst_stride_row, //
162 size_t dst_stride_col, //
163 float scalar_min, //
164 float scalar_max) {
165 KAI_ASSUME(dst_stride_col == sizeof(float));
166 KAI_ASSUME((k % bl) == 0);
167 KAI_ASSUME((bl % kai_bl) == 0);
168 KAI_ASSUME(m == 1);
169
170 277 KAI_UNUSED(dst_stride_row);
171
172
1/2
✓ Branch 0 taken 277 times.
✗ Branch 1 not taken.
277 if (m == 0) {
173 return;
174 }
175
176 277 KernelArgs args;
177 277 args.dst = dst;
178 277 args.lhs_packed = lhs_packed;
179 277 args.rhs_packed = rhs_packed;
180 277 args.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl);
181 277 args.n = n;
182 277 args.k = k;
183 277 args.bl = bl;
184 277 args.lut = lut;
185 277 args.min = scalar_min;
186 277 args.max = scalar_max;
187
188 277 kai_kernel_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot(&args);
189 277 }
190
191 #endif // Architectural features check.
192