KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 100.0% 38 4 42
Functions: 100.0% 9 0 9
Branches: -% 0 8 8

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10 #include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16
17 typedef struct {
18 const void* A;
19 const void* B;
20 void* C;
21 uint64_t ldcb;
22 uint64_t M;
23 uint64_t N;
24 uint64_t K;
25 float min;
26 float max;
27 void* accumulator_buffer;
28 uint64_t flags;
29 } KernelArgs;
30
31 static const size_t kai_mr = 2;
32 static const size_t kai_nr = 2;
33 static const size_t kai_kr = 1;
34
35 void kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(KernelArgs* args);
36
37 // Returns a constant value specific to this kernel that's relative to vector length
38 34056 static size_t kai_get_kernel_vec_length_constant(void) {
39 34056 const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u32() / kai_kr;
40 68112 return kernel_vec_length_constant;
41 34056 }
42
43 11352 size_t kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(void) {
44 11352 return kai_mr * kai_get_kernel_vec_length_constant();
45 }
46
47 22704 size_t kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(void) {
48 22704 return kai_nr * kai_get_kernel_vec_length_constant();
49 }
50
51 5676 size_t kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
52 size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) {
53 KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa() == 0);
54 5676 return m_idx * k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(float);
55 }
56
57 5676 static size_t kai_get_rhs_packed_stride_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
58 size_t k_chunk_count, size_t k_chunk_length) {
59 11352 return kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa() *
60 5676 (sizeof(float) + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(float));
61 }
62
63 5676 size_t kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
64 size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) {
65 KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa() == 0);
66 5676 const size_t block_idx = n_idx / kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
67 17028 return block_idx *
68 5676 kai_get_rhs_packed_stride_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
69 5676 k_chunk_count, k_chunk_length);
70 5676 }
71
72 5676 size_t kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
73 size_t m_idx, size_t n_idx, size_t dst_stride_row) {
74 KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa() == 0);
75 KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa() == 0);
76
77 5676 return m_idx * dst_stride_row + n_idx * sizeof(float);
78 }
79
80 5676 size_t kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(size_t m, size_t n) {
81 5676 return m * n * sizeof(float);
82 }
83
84 5676 void kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
85 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed,
86 void* dst, size_t dst_stride_row, float clamp_min, float clamp_max) {
87 5676 KernelArgs args;
88
89 5676 args.A = lhs_packed;
90 5676 args.B = rhs_packed;
91 5676 args.C = dst;
92 5676 args.ldcb = dst_stride_row;
93 5676 args.M = m;
94 5676 args.N = n;
95 5676 args.K = k_chunk_count * kai_roundup(k_chunk_length, kai_kr);
96 5676 args.min = clamp_min;
97 5676 args.max = clamp_max;
98 5676 args.accumulator_buffer = NULL;
99 5676 args.flags = 0;
100
101 5676 kai_kernel_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(&args);
102 5676 }
103
104 #endif // Architectural features check.
105