KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 33 / 5 / 38
Functions: 100.0% 7 / 0 / 7
Branches: -% 0 / 10 / 10

kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 enum {
19 NR = 2,
20 KR = 1,
21 };
22
23 typedef struct {
24 const void* bias_ptr;
25 size_t width;
26 size_t height;
27 size_t k_chunk_count;
28 size_t in_stride;
29 size_t out_stride;
30 const void* in;
31 void* out;
32 } KernelArgs;
33
34 static const size_t kai_num_bytes_input = sizeof(uint32_t);
35 static const size_t kai_num_bytes_output = sizeof(uint32_t);
36 static const size_t kai_num_bytes_bias = sizeof(float);
37
38 void kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(const KernelArgs* args_ptr);
39
40 42228 size_t kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(void) {
41 42228 return NR * kai_get_sme_vector_length_u32() / KR;
42 }
43
44 4692 size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(size_t n_idx) {
45 KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme() == 0);
46
47 4692 return n_idx * kai_num_bytes_input;
48 }
49
50 4692 size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(size_t n_idx) {
51 4692 return n_idx * kai_num_bytes_bias;
52 }
53
54 14076 size_t kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
55 size_t k_chunk_count, size_t k_chunk_length) {
56 28152 return kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme() *
57 14076 (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, KR) * kai_num_bytes_output);
58 }
59
60 9384 size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
61 size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) {
62 KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme() == 0);
63
64 9384 const size_t block_idx = n_idx / kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme();
65 28152 return block_idx *
66 9384 kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(k_chunk_count, k_chunk_length);
67 9384 }
68
69 4692 size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
70 size_t n, size_t k_chunk_count, size_t k_chunk_length) {
71 4692 const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme());
72 9384 return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
73 4692 n_nr_blocks, k_chunk_count, k_chunk_length);
74 4692 }
75
76 4692 void kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
77 size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias,
78 void* rhs_packed) {
79 KAI_ASSUME(rhs != NULL);
80 KAI_ASSUME(bias != NULL);
81 KAI_ASSUME(rhs_packed != NULL);
82
83 4692 KernelArgs args;
84 4692 args.bias_ptr = bias;
85 4692 args.height = k_chunk_length;
86 4692 args.width = n;
87 4692 args.in = rhs;
88 4692 args.out = rhs_packed;
89 4692 args.k_chunk_count = k_chunk_count;
90 4692 args.in_stride = rhs_stride_row;
91 4692 args.out_stride =
92 4692 kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(k_chunk_count, k_chunk_length);
93
94 4692 kai_commit_za();
95
96 4692 kai_kernel_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(&args);
97 4692 }
98
99 #endif // Architectural features check.
100