KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 100.0% 49 16 65
Functions: 100.0% 6 0 6
Branches: 100.0% 6 32 38

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_lhs_pack_x8p2vlx4_x8_sme.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 enum {
19 MR = 2,
20 KR = 4,
21 MAX_M_STEP = (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(int8_t)) / KR),
22 SR = 1,
23 };
24
25 typedef struct {
26 size_t m;
27 size_t k;
28 size_t mr;
29 size_t kr;
30 size_t sr;
31 size_t m_idx_start;
32 const void* lhs;
33 size_t lhs_stride;
34 void* lhs_packed;
35 size_t height;
36 size_t width;
37 const void* const* in;
38 size_t row_offset;
39 void* out;
40 } KernelArgs;
41
42 void kai_kernel_lhs_pack_x8p2vlx4_x8_sme(const KernelArgs* args_ptr);
43
44 5724 static size_t kai_get_mr_lhs_pack_x8p2vlx4_x8_sme(void) {
45 5724 return MR * kai_get_sme_vector_length_u8() / KR;
46 }
47
48 798 size_t kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme(size_t mr) {
49 KAI_ASSUME(mr == kai_get_mr_lhs_pack_x8p2vlx4_x8_sme());
50 798 KAI_UNUSED(mr);
51
52 798 return kai_get_mr_lhs_pack_x8p2vlx4_x8_sme();
53 }
54
55 666 size_t kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t lhs_stride) {
56 KAI_ASSUME(m_idx % kai_get_mr_lhs_pack_x8p2vlx4_x8_sme() == 0);
57
58 666 return m_idx * lhs_stride;
59 }
60
61 798 size_t kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
62 KAI_ASSUME(m_idx % kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme(mr) == 0);
63 KAI_ASSUME(mr == kai_get_mr_lhs_pack_x8p2vlx4_x8_sme());
64 KAI_ASSUME(kr == KR);
65 KAI_ASSUME(sr == SR);
66
67 798 KAI_UNUSED(mr);
68 798 KAI_UNUSED(kr);
69 798 KAI_UNUSED(sr);
70
71 798 return m_idx * kai_roundup(k, KR) * sizeof(int8_t);
72 }
73
74 666 size_t kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
75 KAI_ASSUME(mr == kai_get_mr_lhs_pack_x8p2vlx4_x8_sme());
76 KAI_ASSUME(kr == KR);
77 KAI_ASSUME(sr == SR);
78
79 666 KAI_UNUSED(mr);
80 666 KAI_UNUSED(kr);
81 666 KAI_UNUSED(sr);
82
83 666 return kai_roundup(m, kai_get_mr_lhs_pack_x8p2vlx4_x8_sme()) * kai_roundup(k, KR) * sizeof(int8_t);
84 }
85
86 666 void kai_run_lhs_pack_x8p2vlx4_x8_sme(
87 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
88 void* lhs_packed) {
89 KAI_ASSUME(mr == kai_get_mr_lhs_pack_x8p2vlx4_x8_sme());
90 KAI_ASSUME(kr == KR);
91 KAI_ASSUME(sr == SR);
92 KAI_ASSUME(lhs != NULL);
93 KAI_ASSUME(lhs_packed != NULL);
94 KAI_ASSUME(m_idx_start == 0);
95
96 666 const size_t m_step = kai_get_mr_lhs_pack_x8p2vlx4_x8_sme();
97 666 const size_t block_height = mr;
98 666 const size_t width = k;
99 666 const size_t row_offset = 0;
100
101 KAI_ASSERT(m_step <= MAX_M_STEP);
102 666 const void* in[MAX_M_STEP];
103
104 666 uint8_t* lhs_packed_ptr = lhs_packed;
105 666 const uint8_t* lhs_ptr = lhs;
106
2/2
✓ Branch 0 taken 666 times.
✓ Branch 1 taken 684 times.
1350 for (size_t block_y = 0; block_y < m; block_y += block_height) {
107
2/2
✓ Branch 0 taken 414 times.
✓ Branch 1 taken 270 times.
684 const size_t height = KAI_MIN(m - block_y, block_height);
108 684 void* out = lhs_packed_ptr + block_y * kai_roundup(k, KR) * sizeof(int8_t);
109
110
2/2
✓ Branch 0 taken 10902 times.
✓ Branch 1 taken 684 times.
11586 for (size_t y = 0; y < height; y++) {
111 10902 in[y] = lhs_ptr + (block_y + y) * lhs_stride;
112 10902 }
113
114 684 KernelArgs args;
115 684 args.m = m;
116 684 args.k = k;
117 684 args.mr = MR;
118 684 args.kr = KR;
119 684 args.sr = SR;
120 684 args.m_idx_start = m_idx_start;
121 684 args.lhs = lhs;
122 684 args.lhs_stride = lhs_stride;
123 684 args.lhs_packed = lhs_packed;
124 684 args.height = height;
125 684 args.width = width;
126 684 args.in = in;
127 684 args.row_offset = row_offset;
128 684 args.out = out;
129
130 684 kai_kernel_lhs_pack_x8p2vlx4_x8_sme(&args);
131 684 }
132 666 }
133
134 #endif // Architectural features check.
135