KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 46 / 4 / 50
Functions: 100.0% 13 / 0 / 13
Branches: -% 0 / 8 / 8

kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10 #include "kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16
17 typedef struct {
18 const void* A;
19 const void* B;
20 void* C;
21 uint64_t ldcb;
22 uint64_t M;
23 uint64_t N;
24 uint64_t K;
25 float min;
26 float max;
27 void* accumulator_buffer;
28 uint64_t flags;
29 } KernelArgs;
30
31 static const size_t kai_mr = 2;
32 static const size_t kai_nr = 2;
33 static const size_t kai_kr = 1;
34 static const size_t kai_sr = 1;
35
36 void kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(KernelArgs* args);
37
38 // Returns a constant value specific to this kernel that's relative to vector length
39 648 static size_t kai_get_kernel_vec_length_constant(void) {
40 648 const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u32() / kai_kr;
41 1296 return kernel_vec_length_constant;
42 648 }
43
44 216 size_t kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) {
45 216 return kai_mr * kai_get_kernel_vec_length_constant();
46 }
47
48 270 size_t kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) {
49 270 return kai_nr * kai_get_kernel_vec_length_constant();
50 }
51
52 54 size_t kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) {
53 54 return kai_mr * kai_get_kernel_vec_length_constant();
54 }
55
56 108 size_t kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) {
57 108 return kai_nr * kai_get_kernel_vec_length_constant();
58 }
59
60 162 size_t kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) {
61 162 return kai_kr;
62 }
63
64 162 size_t kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) {
65 162 return kai_sr;
66 }
67
68 108 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m_idx, size_t k) {
69 KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0);
70 108 return m_idx * kai_roundup(k, kai_kr) * sizeof(float);
71 }
72
73 54 static size_t kai_get_rhs_packed_stride_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t k) {
74 108 return kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() *
75 54 (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(float));
76 }
77
78 54 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t n_idx, size_t k) {
79 KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0);
80 54 const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
81 108 return block_idx * kai_get_rhs_packed_stride_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(k);
82 54 }
83
84 54 size_t kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
85 size_t m_idx, size_t n_idx, size_t dst_stride_row) {
86 KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0);
87 KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0);
88
89 54 return m_idx * dst_stride_row + n_idx * sizeof(float);
90 }
91
92 54 size_t kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m, size_t n) {
93 54 return m * n * sizeof(float);
94 }
95
96 56 void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
97 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
98 size_t dst_stride_col, float clamp_min, float clamp_max) {
99 56 KAI_UNUSED(dst_stride_col);
100 56 KernelArgs args;
101
102 56 args.A = lhs_packed;
103 56 args.B = rhs_packed;
104 56 args.C = dst;
105 56 args.ldcb = dst_stride_row;
106 56 args.M = m;
107 56 args.N = n;
108 56 args.K = k;
109 56 args.min = clamp_min;
110 56 args.max = clamp_max;
111 56 args.accumulator_buffer = NULL;
112 56 args.flags = 0;
113
114 56 kai_commit_za();
115
116 56 kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(&args);
117 56 }
118
119 #endif // Architectural features check.
120