KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 83.7% 41 16 65
Functions: 66.7% 4 0 6
Branches: 100.0% 6 32 38

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_lhs_pack_x16p2vlx2_x16_sme.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 enum {
19 MR = 2,
20 KR = 2,
21 MAX_M_STEP = (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR),
22 SR = 1,
23 };
24
// Argument block handed to the packing routine for one block of rows.
// NOTE(review): the routine is defined outside this file (presumably in assembly) and likely
// reads these fields at fixed offsets — keep the field order and types exactly as they are.
typedef struct {
    // Problem description, copied from the public run function's arguments.
    size_t m, k, mr, kr, sr, m_idx_start;
    // Source matrix, its row stride in bytes, and the start of the packed destination buffer.
    const void* lhs;
    size_t lhs_stride;
    void* lhs_packed;
    // Work description for the current block: number of rows, row length, row pointers,
    // starting offset within each row, and destination address of this block.
    size_t height, width;
    const void* const* in;
    size_t row_offset;
    void* out;
} KernelArgs;

// Packing routine implemented outside this translation unit; invoked once per block of rows.
void kai_kernel_lhs_pack_x16p2vlx2_x16_sme(const KernelArgs* args_ptr);
43
44 262 static size_t kai_get_mr_lhs_pack_x16p2vlx2_x16_sme(void) {
45 262 return MR * kai_get_sme_vector_length_u16() / KR;
46 }
47
// Returns the m step (row-block height) of this packer.
// The caller-supplied mr must equal the runtime row-block height; it is otherwise unused.
size_t kai_get_m_step_lhs_pack_x16p2vlx2_x16_sme(size_t mr) {
    const size_t m_step = kai_get_mr_lhs_pack_x16p2vlx2_x16_sme();
    KAI_ASSUME(mr == m_step);
    KAI_UNUSED(mr);
    return m_step;
}
54
// Byte offset of row m_idx within the unpacked LHS matrix, given its row stride in bytes.
// m_idx must be a multiple of the runtime row-block height.
size_t kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme(size_t m_idx, size_t lhs_stride) {
    KAI_ASSUME(m_idx % kai_get_mr_lhs_pack_x16p2vlx2_x16_sme() == 0);

    const size_t offset_bytes = lhs_stride * m_idx;
    return offset_bytes;
}
60
61 size_t kai_get_lhs_packed_offset_lhs_pack_x16p2vlx2_x16_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
62 KAI_ASSUME(m_idx % kai_get_m_step_lhs_pack_x16p2vlx2_x16_sme(mr) == 0);
63 KAI_ASSUME(mr == kai_get_mr_lhs_pack_x16p2vlx2_x16_sme());
64 KAI_ASSUME(kr == KR);
65 KAI_ASSUME(sr == SR);
66
67 KAI_UNUSED(mr);
68 KAI_UNUSED(kr);
69 KAI_UNUSED(sr);
70
71 return m_idx * kai_roundup(k, KR) * sizeof(uint16_t);
72 }
73
74 34 size_t kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
75 KAI_ASSUME(mr == kai_get_mr_lhs_pack_x16p2vlx2_x16_sme());
76 KAI_ASSUME(kr == KR);
77 KAI_ASSUME(sr == SR);
78
79 34 KAI_UNUSED(mr);
80 34 KAI_UNUSED(kr);
81 34 KAI_UNUSED(sr);
82
83 34 return kai_roundup(m, kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()) * kai_roundup(k, KR) * sizeof(uint16_t);
84 }
85
86 34 void kai_run_lhs_pack_x16p2vlx2_x16_sme(
87 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
88 void* lhs_packed) {
89 KAI_ASSUME(mr == kai_get_mr_lhs_pack_x16p2vlx2_x16_sme());
90 KAI_ASSUME(kr == KR);
91 KAI_ASSUME(sr == SR);
92 KAI_ASSUME(lhs != NULL);
93 KAI_ASSUME(lhs_packed != NULL);
94 KAI_ASSUME(m_idx_start == 0);
95
96 34 const size_t m_step = kai_get_mr_lhs_pack_x16p2vlx2_x16_sme();
97 34 const size_t block_height = mr;
98 34 const size_t width = k;
99 34 const size_t row_offset = 0;
100
101 KAI_ASSERT(m_step <= MAX_M_STEP);
102 34 const void* in[MAX_M_STEP];
103
104 34 uint8_t* lhs_packed_ptr = lhs_packed;
105 34 const uint8_t* lhs_ptr = lhs;
106
2/2
✓ Branch 0 taken 34 times.
✓ Branch 1 taken 38 times.
72 for (size_t block_y = 0; block_y < m; block_y += block_height) {
107
2/2
✓ Branch 0 taken 32 times.
✓ Branch 1 taken 6 times.
38 const size_t height = KAI_MIN(m - block_y, block_height);
108 38 void* out = lhs_packed_ptr + block_y * kai_roundup(k, KR) * sizeof(uint16_t);
109
110
2/2
✓ Branch 0 taken 594 times.
✓ Branch 1 taken 38 times.
632 for (size_t y = 0; y < height; y++) {
111 594 in[y] = lhs_ptr + (block_y + y) * lhs_stride;
112 594 }
113
114 38 KernelArgs args;
115 38 args.m = m;
116 38 args.k = k;
117 38 args.mr = MR;
118 38 args.kr = KR;
119 38 args.sr = SR;
120 38 args.m_idx_start = m_idx_start;
121 38 args.lhs = lhs;
122 38 args.lhs_stride = lhs_stride;
123 38 args.lhs_packed = lhs_packed;
124 38 args.height = height;
125 38 args.width = width;
126 38 args.in = in;
127 38 args.row_offset = row_offset;
128 38 args.out = out;
129
130 38 kai_kernel_lhs_pack_x16p2vlx2_x16_sme(&args);
131 38 }
132 34 }
133
134 #endif // Architectural features check.
135