KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 41 / 4 / 45
Functions: 100.0% 10 / 0 / 10
Branches: -% 0 / 8 / 8

kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64.
9 #elif (!defined(__ARM_FEATURE_SVE) && !defined(_M_ARM64))
10 #error This file must be compiled for for AArch64, FEAT_SVE.
11 #else // Architectural features check.
12
13 #include "kai_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla.h"
14
15 #include <stddef.h>
16 #include <stdint.h>
17
18 #include "kai/kai_common.h"
19
/*
 * Argument pack handed by pointer to the micro-kernel.
 * NOTE(review): the kernel body is not visible here; the field order and types
 * presumably match offsets it relies on, so do not reorder or retype fields.
 */
typedef struct {
    float maxval;                       /* upper clamp bound */
    float minval;                       /* lower clamp bound */
    unsigned int num_strings;           /* number of entries in string_lengths */
    unsigned int const *string_lengths; /* reduction length of each string (k) */
    size_t N;                           /* number of output columns */
    void const *B_ptr;                  /* packed RHS buffer */
    size_t output_offset;               /* destination row stride, in floats */
    size_t input_initial_col;           /* starting column within the LHS */
    size_t input_offset;                /* LHS row stride, in floats */
    void *output_ptr;                   /* destination buffer */
    void const *bias;                   /* optional bias (NULL when unused) */
} KernelArgs;

/* Micro-kernel entry point; declared here, defined elsewhere (not shown). */
void kai_kernel_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(
    void const *input_ptr, size_t m, KernelArgs *args_ptr, unsigned long flags);
36
/* RHS packing parameters reported by the getters. */
static const size_t kai_nr = (size_t)4; /* column block, scaled by the SVE vector length in the getters */
static const size_t kai_kr = (size_t)1; /* kr packing parameter (divides the column block) */
static const size_t kai_sr = (size_t)1; /* sr packing parameter */

/* LHS/destination row step: every row index is an acceptable split point. */
static const size_t kai_m_step = (size_t)1;
42
43 324 size_t kai_get_m_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(void) {
44 324 return kai_m_step;
45 }
46
47 540 size_t kai_get_n_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(void) {
48 540 return kai_nr * kai_get_sve_vector_length_u32() / kai_kr;
49 }
50
51 108 size_t kai_get_nr_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(void) {
52 108 return kai_nr * kai_get_sve_vector_length_u32() / kai_kr;
53 }
54
55 108 size_t kai_get_kr_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(void) {
56 108 return kai_kr;
57 }
58
59 108 size_t kai_get_sr_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(void) {
60 108 return kai_sr;
61 }
62
// Byte offset of LHS row m_idx, given the LHS row stride in bytes.
// m_idx must be a multiple of the m step.
size_t kai_get_lhs_offset_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(size_t m_idx, size_t stride) {
    KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla() == 0);

    const size_t byte_offset = stride * m_idx;
    return byte_offset;
}
68
69 108 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(size_t n_idx, size_t k) {
70 KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla() == 0);
71
72 108 const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla();
73 324 return block_idx * kai_get_n_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla() *
74 108 (kai_roundup(k, kai_kr) * sizeof(float) + sizeof(float));
75 108 }
76
// Byte offset of destination element (m_idx, n_idx), given the destination
// row stride in bytes; columns are addressed as consecutive floats.
// Both indices must be multiples of their respective steps.
size_t kai_get_dst_offset_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(size_t m_idx, size_t n_idx, size_t stride) {
    KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla() == 0);
    KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla() == 0);

    const size_t row_bytes = m_idx * stride;
    const size_t column_bytes = n_idx * sizeof(float);
    return row_bytes + column_bytes;
}
83
// Size in bytes of a densely packed m x n float destination.
size_t kai_get_dst_size_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(size_t m, size_t n) {
    const size_t element_count = m * n;
    return element_count * sizeof(float);
}
87
88 112 void kai_run_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(
89 size_t m, size_t n, size_t k, //
90 const void* lhs, size_t lhs_stride, //
91 const void* rhs_packed, //
92 void* dst, size_t dst_stride_row, size_t dst_stride_col, //
93 float clamp_min, float clamp_max) {
94 112 KAI_UNUSED(dst_stride_col);
95
96 112 KernelArgs ka;
97
98 112 unsigned long flags = 0;
99
100 112 unsigned int string_length = k;
101 112 ka.num_strings = 1;
102 112 ka.string_lengths = &string_length;
103 112 ka.N = n;
104 112 ka.B_ptr = rhs_packed;
105 112 ka.bias = NULL;
106
107 // Direct input.
108 112 const void* input_ptr = lhs;
109 112 ka.input_offset = lhs_stride / sizeof(float);
110 112 ka.input_initial_col = 0;
111
112 // Direct output.
113 112 ka.output_ptr = dst;
114 112 ka.output_offset = dst_stride_row / sizeof(float);
115
116 // Clamping output.
117 112 flags |= 0x2;
118 112 ka.maxval = clamp_max;
119 112 ka.minval = clamp_min;
120
121 112 kai_kernel_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla(input_ptr, m, &ka, flags);
122 112 }
123
124 #endif // Architectural features check.
125