KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 94.9% 37 5 44
Functions: 90.9% 10 0 11
Branches: -% 0 10 10

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h"
12
13 #include <arm_neon.h>
14 #include <stddef.h>
15 #include <stdint.h>
16
17 #include "kai/kai_common.h"
18
/*
 * Argument block passed by pointer to
 * kai_kernel_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot().
 *
 * That kernel is not defined in this file. The member order and types are
 * kept exactly as-is, because the kernel presumably reads members at fixed
 * offsets (TODO confirm against the kernel implementation).
 */
typedef struct KernelArgs {
    uint16_t maxval;    /* Upper clamp bound, as returned by the f16-from-float helper. */
    uint16_t minval;    /* Lower clamp bound, as returned by the f16-from-float helper. */
    const void* A_ptr;  /* LHS input (the caller's `lhs`). */
    const void* B_ptr;  /* Packed RHS input (the caller's `rhs_packed`). */
    size_t N;           /* N dimension (the caller's `n`). */
    size_t K;           /* K dimension (the caller's `k`). */
    void* output_ptr;   /* Destination buffer (the caller's `dst`). */
    uint64_t flags;     /* Kernel control flags. */
} KernelArgs;
29
30 void kai_kernel_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(KernelArgs* args_ptr);
31 uint16_t kai_f16_from_float_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(float value);
32
/*
 * Blocking parameters of this micro-kernel.
 *
 * m_step, kr and sr are reported as-is by their getters. n_step and nr are
 * base multipliers that the getters scale by the SME vector length (in 16-bit
 * lanes) and divide by kr.
 */
static const size_t kai_m_step = 1;  /* LHS rows handled per run call. */
static const size_t kai_n_step = 16; /* Base multiplier of the N step. */
static const size_t kai_nr = 2;      /* Base multiplier of nr. */
static const size_t kai_kr = 2;      /* The packed RHS rounds K up to a multiple of this. */
static const size_t kai_sr = 1;      /* Reported unchanged by the sr getter. */
38
39 52 size_t kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(void) {
40 52 return kai_m_step;
41 }
42
43 236 size_t kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(void) {
44 236 return kai_n_step * kai_get_sme_vector_length_u16() / kai_kr;
45 }
46
47 46 size_t kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(void) {
48 46 return kai_nr * kai_get_sme_vector_length_u16() / kai_kr;
49 }
50
51 46 size_t kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(void) {
52 46 return kai_kr;
53 }
54
55 46 size_t kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(void) {
56 46 return kai_sr;
57 }
58
/*
 * Byte offset of LHS row m_idx.
 *
 * Only m_idx == 0 is accepted (this kernel processes a single LHS row), so
 * the result is 0 for every valid call. The value is expressed in bytes,
 * assuming rows of k contiguous 16-bit elements. This matches the other
 * *_offset getters in this file, which all return byte offsets; previously
 * this getter returned an element count (m_idx * k).
 *
 * @param m_idx Row index; must be 0.
 * @param k     Number of elements along K in one LHS row.
 * @return Offset in bytes.
 */
size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(size_t m_idx, size_t k) {
    KAI_ASSUME(m_idx == 0);

    return m_idx * k * sizeof(uint16_t);
}
64
65 46 static size_t kai_get_rhs_packed_stride_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(size_t k) {
66 92 return kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot() *
67 46 (kai_roundup(k, kai_kr) * sizeof(uint16_t) + sizeof(uint16_t));
68 }
69
/*
 * Byte offset of the packed-RHS block that starts at column n_idx.
 *
 * @param n_idx Starting column; must be a multiple of the N step.
 * @param k     Reduction dimension.
 * @return Offset in bytes from the start of the packed RHS.
 */
size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(size_t n_idx, size_t k) {
    const size_t n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot();
    KAI_ASSUME(n_idx % n_step == 0);

    const size_t block = n_idx / n_step;
    const size_t block_bytes = kai_get_rhs_packed_stride_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(k);
    return block * block_bytes;
}
76
/*
 * Byte offset of destination element (m_idx, n_idx).
 *
 * @param m_idx      Row index; must be 0.
 * @param n_idx      Column index; must be a multiple of the N step.
 * @param dst_stride Row stride of the destination, in bytes.
 * @return Offset in bytes; columns are 16-bit elements.
 */
size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(
    size_t m_idx, size_t n_idx, size_t dst_stride) {
    KAI_ASSUME(m_idx == 0);
    KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot() == 0);

    const size_t row_bytes = m_idx * dst_stride;
    const size_t col_bytes = n_idx * sizeof(uint16_t);
    return row_bytes + col_bytes;
}
84
/*
 * Total size in bytes of an m x n destination of 16-bit elements.
 *
 * @param m Number of rows.
 * @param n Number of columns.
 * @return m * n * sizeof(uint16_t).
 */
size_t kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(size_t m, size_t n) {
    const size_t elements = m * n;
    return elements * sizeof(uint16_t);
}
88
89 47 void kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(
90 size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst,
91 size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) {
92 47 KAI_UNUSED(dst_stride_row);
93 47 KAI_UNUSED(dst_stride_col);
94
95 47 KAI_UNUSED(lhs_stride);
96
97 KAI_ASSUME(m == 1);
98
99 47 uint64_t flags = 2;
100
101 47 KernelArgs args;
102
103 47 args.minval = kai_f16_from_float_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(clamp_min);
104 47 args.maxval = kai_f16_from_float_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(clamp_max);
105
106 47 args.A_ptr = lhs;
107 47 args.B_ptr = rhs_packed;
108 47 args.N = n;
109 47 args.K = k;
110 47 args.output_ptr = dst;
111 47 args.flags = flags;
112
113 47 kai_kernel_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot(&args);
114 47 }
115
116 #endif // Architectural features check.
117