KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.5% 65 7 73
Functions: 100.0% 14 0 14
Branches: 50.0% 1 14 16

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error "This file must be compiled for AArch64, FEAT_SVE2"
9 #else // Architectural features check.
10
11 #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
12
13 #include <stddef.h>
14
15 #include "kai/kai_common.h"
16
// Argument block handed by pointer to the assembly micro-kernel
// kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.
// The hexadecimal comments are the byte offsets at which the assembly code
// reads each member: do NOT reorder, insert, remove, or resize fields.
typedef struct {
    float* dst;              // 0    output matrix base pointer
    const void* lhs_packed;  // 0x8  packed (quantized) LHS base pointer
    const void* rhs_packed;  // 0x10 packed (quantized) RHS base pointer
    size_t dst_stride_row;   // 0x18 bytes between consecutive DST rows
    size_t m;                // 0x20 number of output rows
    size_t n;                // 0x28 number of output columns
    size_t k;                // 0x30 accumulation depth as given by the caller
    size_t k_internal;       // 0x38 k rounded up to a multiple of 32
    size_t lhs_stride;       // 0x40 bytes per packed LHS row block
    size_t rhs_stride;       // 0x48 bytes per packed RHS column block
    size_t rhs_row_bytes;    // 0x50 bytes of quantized values per RHS block (nr * k_internal)
    size_t lhs_end;          // 0x58 one-past-the-end address of the packed LHS region
    float clamp_min;         // 0x60 lower clamp bound applied to DST values
    float clamp_max;         // 0x64 upper clamp bound applied to DST values
} KernelArgs;
33
34 void kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(KernelArgs* args_ptr);
35
// Compute args (tile sizes processed per micro-kernel invocation)
static const size_t kai_m_step = 1;
static const size_t kai_n_step = 4;  // multiple of vector length
// Packing args (block sizes used by the matching packing routines)
static const size_t kai_mr = 1;
static const size_t kai_nr = 4;  // multiple of vector length
static const size_t kai_kr = 4;
static const size_t kai_sr = 1;
// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
static const size_t kai_num_bytes_qvalue_lhs = 1;
static const size_t kai_num_bytes_multiplier_lhs = 4;
static const size_t kai_num_bytes_zp_lhs = 4;
// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
// asymmetric))
static const size_t kai_num_bytes_qvalue_rhs = 1;
static const size_t kai_num_bytes_multiplier_rhs = 4;
static const size_t kai_num_bytes_rsum_rhs = 4;
// DST format args (f32 output values)
static const size_t kai_num_bytes_dst_value = 4;
// Extra args (f32 bias packed alongside the RHS)
static const size_t kai_num_bytes_bias = 4;
// The packed accumulation depth must be a multiple of this value.
static const size_t kai_k_multiple_of = 32;
58
59 827 inline static size_t kai_k_roundedup(size_t k) {
60 // Round up k to be a multiple of 32.
61 827 return kai_roundup(k, kai_k_multiple_of);
62 }
63
64 346 inline static size_t kai_get_lhs_packed_stride(size_t k) {
65 346 const size_t k_internal = kai_k_roundedup(k);
66 KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
67 346 const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot();
68 346 size_t lhs_packed_stride = mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
69 // Since the LHS matrix is asymmetric with per-row quantization, we must include the
70 // the number of bytes to hold the zero point value
71 346 lhs_packed_stride += mr * kai_num_bytes_zp_lhs;
72
73 692 return lhs_packed_stride;
74 346 }
75
76 346 inline static size_t kai_get_rhs_packed_stride(size_t k) {
77 346 const size_t k_internal = kai_k_roundedup(k);
78 KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
79
80 346 const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot();
81
82 346 size_t rhs_packed_stride = nr * (k_internal * kai_num_bytes_qvalue_rhs);
83 346 rhs_packed_stride += nr * kai_num_bytes_multiplier_rhs;
84 // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
85 // the number of bytes for the reduction sum
86 346 rhs_packed_stride += nr * kai_num_bytes_rsum_rhs;
87 // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
88 346 rhs_packed_stride += nr * kai_num_bytes_bias;
89
90 692 return rhs_packed_stride;
91 346 }
92
93 576 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) {
94 576 return kai_m_step;
95 }
96
97 576 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) {
98 576 return kai_n_step * kai_get_sme_vector_length_u8() / kai_kr;
99 }
100
101 577 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) {
102 577 return kai_mr;
103 }
104
105 923 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) {
106 923 return kai_nr * kai_get_sme_vector_length_u8() / kai_kr;
107 }
108
109 308 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) {
110 308 return kai_kr;
111 }
112
113 308 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(void) {
114 308 return kai_sr;
115 }
116
117 211 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(size_t m_idx, size_t k) {
118 KAI_ASSUME((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot()) == 0);
119
120 211 return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
121 }
122
// Byte offset into the packed RHS buffer for the column block starting at n_idx.
// n_idx must be a multiple of n_step.
size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(size_t n_idx, size_t k) {
    KAI_ASSUME((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot()) == 0);

    const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot();
    const size_t col_block_idx = n_idx / nr;
    return col_block_idx * kai_get_rhs_packed_stride(k);
}
128
129 134 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(
130 size_t m_idx, size_t n_idx, size_t dst_stride) {
131 KAI_ASSUME((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot()) == 0);
132 KAI_ASSUME((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot()) == 0);
133
134 134 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
135 }
136
137 134 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(size_t m, size_t n) {
138 134 return m * n * kai_num_bytes_dst_value;
139 }
140
141 135 void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(
142 size_t m, //
143 size_t n, //
144 size_t k, //
145 const void* restrict lhs_packed, //
146 const void* restrict rhs_packed, //
147 float* restrict dst, // NOLINT(readability-non-const-parameter)
148 size_t dst_stride_row, //
149 size_t dst_stride_col, //
150 float scalar_min, //
151 float scalar_max) {
152 KAI_ASSUME(dst_stride_col == sizeof(float));
153
154
1/2
✓ Branch 0 taken 135 times.
✗ Branch 1 not taken.
135 if (m == 0) {
155 return;
156 }
157
158 135 const size_t k_internal = kai_k_roundedup(k);
159 135 const size_t lhs_stride = kai_get_lhs_packed_stride(k);
160 135 const size_t rhs_stride = kai_get_rhs_packed_stride(k);
161 135 const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot();
162
163 135 const size_t rhs_row_bytes = nr * k_internal;
164 135 const size_t lhs_end_ptr = ((size_t)lhs_packed) + (m * lhs_stride);
165
166 135 KernelArgs args;
167
168 135 args.dst = dst;
169 135 args.lhs_packed = lhs_packed;
170 135 args.rhs_packed = rhs_packed;
171 135 args.clamp_max = scalar_max;
172 135 args.clamp_min = scalar_min;
173 135 args.dst_stride_row = dst_stride_row;
174 135 args.m = m;
175 135 args.n = n;
176 135 args.k = k;
177 135 args.k_internal = k_internal;
178 135 args.lhs_stride = lhs_stride;
179 135 args.rhs_stride = rhs_stride;
180 135 args.rhs_row_bytes = rhs_row_bytes;
181 135 args.lhs_end = lhs_end_ptr;
182
183 135 kai_kernel_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot(&args);
184 135 }
185
186 #endif // Architectural features check.
187