KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 94.4% 34 / 13 / 49
Functions: 87.5% 7 / 0 / 8
Branches: 100.0% 6 / 26 / 32

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
// Compile-time constants for this packing routine.
// NR, KR and SR are checked against the nr/kr/sr arguments of the kai_run_* entry point:
// nr must equal NR * kai_get_sme_vector_length_u32() / KR, kr must equal KR, sr must equal SR.
18 enum {
19 NR = 2,
20 KR = 1,
21 SR = 1,
// Byte counts used in the packed-stride and bias-offset computations in this file.
// NOTE(review): presumably sizeof(float), going by the f32 naming - confirm.
22 NUM_BYTES_DATA = 4,
23 NUM_BYTES_BIAS = 4,
// Capacity of the row-pointer array (`in`) declared in the kai_run_* entry point.
24 MAX_BLOCK_HEIGHT = (NR * (KAI_SME_VEC_LENGTH_MAX_BYTES / NUM_BYTES_DATA) / KR),
25 };
26
// Forward declaration of the packing kernel. No definition appears in this listing.
// NOTE(review): presumably implemented in a separate (assembly) source - confirm.
// The kai_run_* entry point in this file calls it once per block of rows with:
//   height - number of valid rows in the block,
//   width  - the caller's k,
//   in     - array of per-row input pointers,
//   out    - destination inside the packed buffer for this block,
//   bias   - bias pointer for the first row of this block.
27 void kai_kernel_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(
28 size_t height, size_t width, const void* in, void* out, const void* bias);
29
// Returns the number of rows handled per block: NR * kai_get_sme_vector_length_u32() / KR.
// NOTE(review): kai_get_sme_vector_length_u32() is declared outside this file; presumably it
// reports the SME vector length in 32-bit elements - confirm against kai_common.h.
30 648 static size_t get_block_height(void) {
31 648 const size_t block_height = NR * kai_get_sme_vector_length_u32() / KR;
32 1296 return block_height;
33 648 }
34
// Returns the step in the N dimension, which is the block height from get_block_height().
// The offset helpers in this file require n_idx to be a multiple of this value.
35 108 size_t kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(void) {
36 108 return get_block_height();
37 }
38
// Returns the offset of row n_idx in the unpacked RHS, computed as n_idx * rhs_stride.
// n_idx must be a multiple of get_block_height(); this is checked with KAI_ASSUME.
// NOTE(review): rhs_stride is presumably in bytes, as in the kai_run_* entry point where it is
// applied to a uint8_t pointer - confirm against the header.
39 108 size_t kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t rhs_stride) {
40 KAI_ASSUME(n_idx % get_block_height() == 0);
41
42 108 return n_idx * rhs_stride;
43 }
44
// Returns the byte offset of the bias entry for row n_idx: n_idx * NUM_BYTES_BIAS.
// n_idx must be a multiple of get_block_height(); this is checked with KAI_ASSUME.
// NOTE(review): the rows of this function carry no execution count in this report, which is
// consistent with the "Functions: 7 / 0 / 8" summary - it appears to be the untested function.
45 size_t kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx) {
46 KAI_ASSUME(n_idx % get_block_height() == 0);
47
48 return n_idx * NUM_BYTES_BIAS;
49 }
50
// Returns the packed size in bytes attributed to one row (one n index):
// NUM_BYTES_BIAS for the bias plus kai_roundup(k, KR) * NUM_BYTES_DATA for the data.
// NOTE(review): kai_roundup is declared outside this file; presumably it rounds k up to a
// multiple of KR - confirm against kai_common.h.
51 216 size_t kai_get_rhs_packed_stride_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t k) {
52 216 return NUM_BYTES_BIAS + kai_roundup(k, KR) * NUM_BYTES_DATA;
53 }
54
// Returns the byte offset of row n_idx inside the packed buffer: n_idx multiplied by the
// per-row packed stride for the given k.
// n_idx must be a multiple of get_block_height(); this is checked with KAI_ASSUME.
55 108 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k) {
56 KAI_ASSUME(n_idx % get_block_height() == 0);
57
58 108 return n_idx * kai_get_rhs_packed_stride_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(k);
59 }
60
// Returns the total packed buffer size in bytes for an n x k RHS: n passed through
// kai_roundup with the block height, multiplied by the per-row packed stride for k.
// A final partial block is therefore sized as a full block.
// NOTE(review): relies on kai_roundup rounding n up to a multiple of the block height - confirm.
61 108 size_t kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n, size_t k) {
62 108 return kai_roundup(n, get_block_height()) * kai_get_rhs_packed_stride_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(k);
63 }
64
// Entry point: walks the n rows of `rhs` in blocks of get_block_height() rows and hands each
// block to kai_kernel_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme together with the matching
// output and bias pointers. The layout written to rhs_packed is decided by that kernel and is
// not visible in this file.
//
// Preconditions checked with KAI_ASSUME:
//   num_groups == 1, nr == get_block_height(), kr == KR, sr == SR,
//   rhs != NULL, bias != NULL, rhs_packed != NULL,
//   scale == NULL, extra_bytes == 0, params == NULL.
//
// Parameters used beyond those checks:
//   n          - number of RHS rows; upper bound of the outer loop.
//   k          - forwarded to the kernel as `width`; also determines the per-row packed stride.
//   rhs_stride - distance in bytes between consecutive RHS rows (added to a uint8_t pointer).
//   rhs        - base of the unpacked RHS data.
//   bias       - base of the bias data; advanced by NUM_BYTES_BIAS per processed row.
//   rhs_packed - base of the destination buffer.
65 108 void kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(
66 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
67 const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) {
68 KAI_ASSUME(num_groups == 1);
69 KAI_ASSUME(nr == get_block_height());
70 KAI_ASSUME(kr == KR);
71 KAI_ASSUME(sr == SR);
72 KAI_ASSUME(rhs != NULL);
73 KAI_ASSUME(bias != NULL);
74 KAI_ASSUME(scale == NULL);
75 KAI_ASSUME(rhs_packed != NULL);
76 KAI_ASSUME(extra_bytes == 0);
77 KAI_ASSUME(params == NULL);
78
79 108 const size_t block_height = get_block_height();
80 108 const size_t width = k;
81
// Per-row input pointers for one block, sized for MAX_BLOCK_HEIGHT rows.
// NOTE(review): nothing in this function checks block_height <= MAX_BLOCK_HEIGHT; presumably
// guaranteed by the definition of KAI_SME_VEC_LENGTH_MAX_BYTES - confirm.
82 108 const uint8_t* in[MAX_BLOCK_HEIGHT];
// Byte-typed aliases of the void pointers so that offsets below are in bytes.
83 108 uint8_t* rhs_packed_ptr = rhs_packed;
84 108 const uint8_t* rhs_ptr = rhs;
85 108 const uint8_t* bias_ptr = bias;
86
// NOTE(review): kai_commit_za() is declared outside this file; its effect is not visible here.
87 108 kai_commit_za();
88

2/2
✓ Branch 0 taken 108 times.
✓ Branch 1 taken 132 times.
// One iteration per block of up to block_height rows.
240 for (size_t block_y = 0; block_y < n; block_y += block_height) {
90
2/2
✓ Branch 0 taken 78 times.
✓ Branch 1 taken 54 times.
// The last block may be partial: clamp to the rows that remain.
132 const size_t height = KAI_MIN(n - block_y, block_height);
// Destination of this block: block_y rows times the per-row packed stride
// (same expression as kai_get_rhs_packed_stride_* for this k).
91 132 uint8_t* out = rhs_packed_ptr + block_y * (NUM_BYTES_BIAS + kai_roundup(k, KR) * NUM_BYTES_DATA);
92
93
2/2
✓ Branch 0 taken 2682 times.
✓ Branch 1 taken 132 times.
// Fill in the start address of each valid row of this block.
// NOTE(review): entries in[height..] are not written here for a partial block; whether the
// kernel reads them cannot be determined from this file.
2814 for (size_t y = 0; y < height; y++) {
94 2682 in[y] = rhs_ptr + (block_y + y) * rhs_stride;
95 2682 }
96
97 132 kai_kernel_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(
98 132 height, width, in, out, bias); // NOLINT(bugprone-multi-level-implicit-pointer-conversion)
99
// Advance the bias pointer past the rows just handled and store it back into the `bias`
// parameter, which is what the next kernel call receives.
100 132 bias_ptr += height * NUM_BYTES_BIAS;
101 132 bias = bias_ptr;
102 132 }
103 108 }
104
105 #endif // Architectural features check.
106