Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) | ||
8 | #error This file must be compiled for AArch64, FEAT_SVE2. | ||
9 | #else // Architectural features check. | ||
10 | |||
11 | #include "kai_lhs_pack_x16p2vlx2_x16_sme.h" | ||
12 | |||
13 | #include <stddef.h> | ||
14 | #include <stdint.h> | ||
15 | |||
16 | #include "kai/kai_common.h" | ||
17 | |||
18 | enum { | ||
19 | MR = 2, | ||
20 | KR = 2, | ||
21 | MAX_M_STEP = (MR * (KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint16_t)) / KR), | ||
22 | SR = 1, | ||
23 | }; | ||
24 | |||
// Argument block handed to kai_kernel_lhs_pack_x16p2vlx2_x16_sme.
//
// NOTE: the kernel is defined outside this translation unit, so the field
// order and types must not be changed here without updating the kernel too.
typedef struct KernelArgs {
    size_t m;                // Total number of LHS rows.
    size_t k;                // Total number of LHS columns.
    size_t mr;               // Set to the MR constant by the caller in this file.
    size_t kr;               // Set to the KR constant by the caller in this file.
    size_t sr;               // Set to the SR constant by the caller in this file.
    size_t m_idx_start;      // Starting row index as passed to the run function.
    const void* lhs;         // Base address of the unpacked LHS matrix.
    size_t lhs_stride;       // Distance in bytes between consecutive LHS rows.
    void* lhs_packed;        // Base address of the packed output buffer.
    size_t height;           // Number of valid rows in the current block.
    size_t width;            // Number of columns to pack (equal to k in this file).
    const void* const* in;   // Table of row pointers; `height` entries are populated.
    size_t row_offset;       // Always 0 in this file.
    void* out;               // Destination of the current block inside lhs_packed.
} KernelArgs;

// Packing kernel entry point; only declared here, its definition is not part of this file.
void kai_kernel_lhs_pack_x16p2vlx2_x16_sme(const KernelArgs* args_ptr);
43 | |||
44 | 262 | static size_t kai_get_mr_lhs_pack_x16p2vlx2_x16_sme(void) { | |
45 | 262 | return MR * kai_get_sme_vector_length_u16() / KR; | |
46 | } | ||
47 | |||
// Returns the row step of this packing routine.
//
// `mr` must equal the block height reported by
// kai_get_mr_lhs_pack_x16p2vlx2_x16_sme(); it is only checked, and the
// returned step is always that block height.
size_t kai_get_m_step_lhs_pack_x16p2vlx2_x16_sme(size_t mr) {
    const size_t expected_mr = kai_get_mr_lhs_pack_x16p2vlx2_x16_sme();

    KAI_ASSUME(mr == expected_mr);
    KAI_UNUSED(mr);  // Keeps the parameter referenced when the check is compiled out.

    return expected_mr;
}
54 | |||
// Returns the offset of row `m_idx` inside the unpacked LHS matrix.
//
// `m_idx` must be a multiple of the block height. The result is expressed in
// the same unit as `lhs_stride` (the run function in this file treats the
// stride as a byte count).
size_t kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme(size_t m_idx, size_t lhs_stride) {
    const size_t rows_per_block = kai_get_mr_lhs_pack_x16p2vlx2_x16_sme();
    KAI_ASSUME((m_idx % rows_per_block) == 0);

    const size_t offset = lhs_stride * m_idx;
    return offset;
}
60 | |||
61 | ✗ | size_t kai_get_lhs_packed_offset_lhs_pack_x16p2vlx2_x16_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | |
62 | − | KAI_ASSUME(m_idx % kai_get_m_step_lhs_pack_x16p2vlx2_x16_sme(mr) == 0); | |
63 | − | KAI_ASSUME(mr == kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()); | |
64 | − | KAI_ASSUME(kr == KR); | |
65 | − | KAI_ASSUME(sr == SR); | |
66 | |||
67 | ✗ | KAI_UNUSED(mr); | |
68 | ✗ | KAI_UNUSED(kr); | |
69 | ✗ | KAI_UNUSED(sr); | |
70 | |||
71 | ✗ | return m_idx * kai_roundup(k, KR) * sizeof(uint16_t); | |
72 | } | ||
73 | |||
74 | 34 | size_t kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
75 | − | KAI_ASSUME(mr == kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()); | |
76 | − | KAI_ASSUME(kr == KR); | |
77 | − | KAI_ASSUME(sr == SR); | |
78 | |||
79 | 34 | KAI_UNUSED(mr); | |
80 | 34 | KAI_UNUSED(kr); | |
81 | 34 | KAI_UNUSED(sr); | |
82 | |||
83 | 34 | return kai_roundup(m, kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()) * kai_roundup(k, KR) * sizeof(uint16_t); | |
84 | } | ||
85 | |||
86 | 34 | void kai_run_lhs_pack_x16p2vlx2_x16_sme( | |
87 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||
88 | void* lhs_packed) { | ||
89 | − | KAI_ASSUME(mr == kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()); | |
90 | − | KAI_ASSUME(kr == KR); | |
91 | − | KAI_ASSUME(sr == SR); | |
92 | − | KAI_ASSUME(lhs != NULL); | |
93 | − | KAI_ASSUME(lhs_packed != NULL); | |
94 | − | KAI_ASSUME(m_idx_start == 0); | |
95 | |||
96 | 34 | const size_t m_step = kai_get_mr_lhs_pack_x16p2vlx2_x16_sme(); | |
97 | 34 | const size_t block_height = mr; | |
98 | 34 | const size_t width = k; | |
99 | 34 | const size_t row_offset = 0; | |
100 | |||
101 | − | KAI_ASSERT(m_step <= MAX_M_STEP); | |
102 | 34 | const void* in[MAX_M_STEP]; | |
103 | |||
104 | 34 | uint8_t* lhs_packed_ptr = lhs_packed; | |
105 | 34 | const uint8_t* lhs_ptr = lhs; | |
106 |
2/2✓ Branch 0 taken 34 times.
✓ Branch 1 taken 38 times.
|
72 | for (size_t block_y = 0; block_y < m; block_y += block_height) { |
107 |
2/2✓ Branch 0 taken 32 times.
✓ Branch 1 taken 6 times.
|
38 | const size_t height = KAI_MIN(m - block_y, block_height); |
108 | 38 | void* out = lhs_packed_ptr + block_y * kai_roundup(k, KR) * sizeof(uint16_t); | |
109 | |||
110 |
2/2✓ Branch 0 taken 594 times.
✓ Branch 1 taken 38 times.
|
632 | for (size_t y = 0; y < height; y++) { |
111 | 594 | in[y] = lhs_ptr + (block_y + y) * lhs_stride; | |
112 | 594 | } | |
113 | |||
114 | 38 | KernelArgs args; | |
115 | 38 | args.m = m; | |
116 | 38 | args.k = k; | |
117 | 38 | args.mr = MR; | |
118 | 38 | args.kr = KR; | |
119 | 38 | args.sr = SR; | |
120 | 38 | args.m_idx_start = m_idx_start; | |
121 | 38 | args.lhs = lhs; | |
122 | 38 | args.lhs_stride = lhs_stride; | |
123 | 38 | args.lhs_packed = lhs_packed; | |
124 | 38 | args.height = height; | |
125 | 38 | args.width = width; | |
126 | 38 | args.in = in; | |
127 | 38 | args.row_offset = row_offset; | |
128 | 38 | args.out = out; | |
129 | |||
130 | 38 | kai_kernel_lhs_pack_x16p2vlx2_x16_sme(&args); | |
131 | 38 | } | |
132 | 34 | } | |
133 | |||
134 | #endif // Architectural features check. | ||
135 |