KleidiAI Coverage Report


Directory: ./
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%
Summary columns: Coverage | Exec / Excl / Total
Lines: 100.0% 40 / 5 / 45
Functions: 100.0% 11 / 0 / 11
Branches: -% 0 / 10 / 10 (no percentage: all 10 branches are excluded from measurement)

kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
// Argument block passed by pointer to the external matmul kernel.
// NOTE(review): the kernel is defined outside this file and presumably reads
// these fields at fixed offsets, so keep member order and types unchanged.
typedef struct {
    float maxval;            // upper clamp bound (set from clamp_max by the run function)
    float minval;            // lower clamp bound (set from clamp_min by the run function)
    const void *A_ptr;       // LHS input
    const void *B_ptr;       // packed RHS input
    size_t N;                // n dimension
    size_t K;                // k dimension
    void *output_ptr;        // destination buffer
    uint64_t flags;          // kernel control flags
} KernelArgs;
28
// Tile-shape constants for this micro-kernel.
// Rows processed per call (the run function asserts m == 1).
static const size_t kai_m_step = 1;
// Column multipliers; the n_step/nr getters scale these by the SME vector length.
static const size_t kai_nr = 16;
static const size_t kai_n_step = 16;
// Packing parameters along k.
static const size_t kai_kr = 1;
static const size_t kai_sr = 1;
34
35 void kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(KernelArgs* args_ptr);
36
37 156 size_t kai_get_m_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) {
38 156 return kai_m_step;
39 }
40
41 780 size_t kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) {
42 780 return kai_n_step * kai_get_sme_vector_length_u32() / kai_kr;
43 }
44
45 156 size_t kai_get_nr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) {
46 156 return kai_nr * kai_get_sme_vector_length_u32() / kai_kr;
47 }
48
49 156 size_t kai_get_kr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) {
50 156 return kai_kr;
51 }
52
53 156 size_t kai_get_sr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(void) {
54 156 return kai_sr;
55 }
56
// Byte offset of row m_idx within the (unpacked, f32) LHS buffer.
//
// m_idx: LHS row index; only 0 is supported (this kernel processes a single row).
// k:     number of f32 elements per LHS row.
// Returns the offset in bytes, consistent with the RHS/DST offset getters of
// this kernel (which scale by sizeof(float)). For the supported m_idx == 0 the
// result is always 0.
size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t m_idx, size_t k) {
    KAI_ASSUME(m_idx == 0);

    // Scale by the element size: the previous `m_idx * k` was an element count,
    // not a byte offset like every other offset getter here.
    return m_idx * k * sizeof(float);
}
62
63 156 static size_t kai_get_rhs_packed_stride_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t k) {
64 312 return kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla() *
65 156 (kai_roundup(k, kai_kr) * sizeof(float) + sizeof(float));
66 }
67
// Byte offset of column n_idx inside the packed RHS buffer.
// n_idx must be a multiple of n_step; the offset is the block index times the
// packed block stride for the given k.
size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t n_idx, size_t k) {
    const size_t step = kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla();
    KAI_ASSUME(n_idx % step == 0);

    const size_t which_block = n_idx / step;
    return which_block * kai_get_rhs_packed_stride_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(k);
}
74
// Byte offset of element (m_idx, n_idx) in the f32 destination buffer.
// m_idx must be 0 and n_idx a multiple of n_step; dst_stride is the row stride in bytes.
size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(
    size_t m_idx, size_t n_idx, size_t dst_stride) {
    KAI_ASSUME(m_idx == 0);
    KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla() == 0);

    const size_t row_bytes = m_idx * dst_stride;
    const size_t column_bytes = n_idx * sizeof(float);
    return row_bytes + column_bytes;
}
82
// Total size in bytes of an m x n f32 destination matrix.
size_t kai_get_dst_size_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(size_t m, size_t n) {
    const size_t element_count = m * n;
    return element_count * sizeof(float);
}
86
87 158 void kai_run_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(
88 size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst,
89 size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) {
90 158 KAI_UNUSED(dst_stride_row);
91 158 KAI_UNUSED(dst_stride_col);
92 158 KAI_UNUSED(lhs_stride);
93 KAI_ASSUME(m == 1);
94
95 158 uint64_t flags = 2;
96
97 158 KernelArgs args;
98
99 158 args.maxval = clamp_max;
100 158 args.minval = clamp_min;
101 158 args.A_ptr = lhs;
102 158 args.B_ptr = rhs_packed;
103 158 args.N = n;
104 158 args.K = k;
105 158 args.output_ptr = dst;
106 158 args.flags = flags;
107
108 158 kai_commit_za();
109
110 158 kai_kernel_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla(&args);
111 158 }
112
113 #endif // Architectural features check.
114