KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 34 / 6 / 40
Functions: 100.0% 7 / 0 / 7
Branches: -% 0 / 12 / 12

kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15 #include <string.h>
16
17 #include "kai/kai_common.h"
18
19 enum {
20 NR = 2,
21 KR = 2,
22 MAX_N_STEP = NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR),
23 };
24
/*
 * Argument block passed by pointer to kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.
 * NOTE(review): the kernel is not defined in this file (likely hand-written assembly), so
 * member order and types may be relied upon by offset -- do not reorder or retype members.
 */
typedef struct kai_rhs_imatmul_pack_kernel_args {
    const void *bias_ptr;   /* bias values, one per output column */
    size_t width;           /* N: number of RHS columns to pack */
    size_t height;          /* k_chunk_length: rows per K chunk */
    size_t k_chunk_count;   /* number of K chunks */
    size_t in_stride;       /* byte stride between rows of the unpacked RHS */
    size_t out_stride;      /* byte size of one packed block of n_step columns */
    const void *in;         /* unpacked RHS source */
    void *out;              /* packed RHS destination */
    const void *pad_row;    /* zero-filled row supplied by the run function */
} KernelArgs;
36
/* Byte sizes of the 16-bit elements handled by this packer. */
static const size_t kai_num_bytes_bias = sizeof(uint16_t);   /* one bias value */
static const size_t kai_num_bytes_input = sizeof(uint16_t);  /* one unpacked RHS element */
static const size_t kai_num_bytes_output = sizeof(uint16_t); /* one packed RHS element */
40
41 void kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(const KernelArgs* args_ptr);
42
43 46920 size_t kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(void) {
44 46920 return NR * kai_get_sme_vector_length_u16() / KR;
45 }
46
47 4692 size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx) {
48 KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() == 0);
49
50 4692 return n_idx * kai_num_bytes_input;
51 }
52
53 4692 size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(size_t n_idx) {
54 4692 return n_idx * kai_num_bytes_bias;
55 }
56
57 14076 size_t kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
58 size_t k_chunk_count, size_t k_chunk_length) {
59 28152 return kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() *
60 14076 (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, KR) * kai_num_bytes_output);
61 }
62
/*
 * Byte offset, within the packed RHS, of the block that starts at column n_idx.
 * n_idx must be a multiple of the N step; the offset is the block index times
 * the packed block stride.
 */
size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
    size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) {
    KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() == 0);

    const size_t column_block = n_idx / kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme();
    const size_t block_stride =
        kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(k_chunk_count, k_chunk_length);

    return column_block * block_stride;
}
71
/*
 * Total size in bytes of the packed RHS for n columns.
 * n is rounded up to a whole number of N steps; the size equals the packed
 * offset of that (one-past-the-end) padded column index.
 */
size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
    size_t n, size_t k_chunk_count, size_t k_chunk_length) {
    /* n rounded up to a multiple of n_step (not a block count). */
    const size_t n_padded = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme());

    const size_t packed_bytes = kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
        n_padded, k_chunk_count, k_chunk_length);
    return packed_bytes;
}
78
79 4692 void kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
80 size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias,
81 void* rhs_packed) {
82 KAI_ASSUME(rhs != NULL);
83 KAI_ASSUME(bias != NULL);
84 KAI_ASSUME(rhs_packed != NULL);
85
86 KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme() <= MAX_N_STEP);
87 static const uint16_t pad_row[MAX_N_STEP] = {0};
88
89 4692 KernelArgs args;
90 4692 args.bias_ptr = bias;
91 4692 args.height = k_chunk_length;
92 4692 args.width = n;
93 4692 args.in = rhs;
94 4692 args.out = rhs_packed;
95 4692 args.k_chunk_count = k_chunk_count;
96 4692 args.in_stride = rhs_stride_row;
97 4692 args.out_stride =
98 4692 kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(k_chunk_count, k_chunk_length);
99 4692 args.pad_row = pad_row;
100
101 4692 kai_commit_za();
102
103 4692 kai_kernel_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(&args);
104 4692 }
105
106 #endif // Architectural features check.
107