Line |
Branch |
Exec |
Source |
1 |
|
|
// |
2 |
|
|
// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> |
3 |
|
|
// |
4 |
|
|
// SPDX-License-Identifier: Apache-2.0 |
5 |
|
|
// |
6 |
|
|
|
7 |
|
|
#pragma once |
8 |
|
|
|
9 |
|
|
#include <cstddef> |
10 |
|
|
#include <functional> |
11 |
|
|
#include <string_view> |
12 |
|
|
|
13 |
|
|
// clang-format off |
14 |
|
|
/// Expands to a brace-enclosed initializer list naming the eleven `kai_*_matmul_<name>` symbols of
/// one matmul micro-kernel, in this fixed order: m_step, n_step, mr, nr, kr, sr,
/// lhs_packed_offset, rhs_packed_offset, dst_offset, dst_size, run.
///
/// `name` is token-pasted onto every prefix, so all eleven resulting identifiers must be visible
/// at the point of expansion.
#define UKERNEL_MATMUL_VARIANT(name)              \
    {                                             \
        kai_get_m_step_matmul_##name,             \
        kai_get_n_step_matmul_##name,             \
        kai_get_mr_matmul_##name,                 \
        kai_get_nr_matmul_##name,                 \
        kai_get_kr_matmul_##name,                 \
        kai_get_sr_matmul_##name,                 \
        kai_get_lhs_packed_offset_matmul_##name,  \
        kai_get_rhs_packed_offset_matmul_##name,  \
        kai_get_dst_offset_matmul_##name,         \
        kai_get_dst_size_matmul_##name,           \
        kai_run_matmul_##name                     \
    }
26 |
|
|
|
27 |
|
|
/// Expands to a brace-enclosed initializer list naming the four symbols of one RHS packing
/// micro-kernel, in this order: `kai_get_rhs_packed_size_<rhs_pack>`,
/// `kai_get_rhs_packed_offset_<rhs_pack>`, `kai_get_rhs_offset_<rhs_pack>`, `kai_run_<rhs_pack>`.
#define UKERNEL_RHS_PACK_VARIANT(rhs_pack)                                      \
    {kai_get_rhs_packed_size_##rhs_pack, kai_get_rhs_packed_offset_##rhs_pack,  \
     kai_get_rhs_offset_##rhs_pack, kai_run_##rhs_pack}
34 |
|
|
|
35 |
|
|
/// Expands to a brace-enclosed initializer list naming the four symbols of one LHS packing
/// micro-kernel, in this order: `kai_get_lhs_packed_size_<lhs_pack>`,
/// `kai_get_lhs_packed_offset_<lhs_pack>`, `kai_get_lhs_offset_<lhs_pack>`, `kai_run_<lhs_pack>`.
#define UKERNEL_LHS_PACK_VARIANT(lhs_pack)                                      \
    {kai_get_lhs_packed_size_##lhs_pack, kai_get_lhs_packed_offset_##lhs_pack,  \
     kai_get_lhs_offset_##lhs_pack, kai_run_##lhs_pack}
42 |
|
|
|
43 |
|
|
/// Identity macro: yields its argument unchanged after one additional preprocessor rescan.
/// UKERNEL_MATMUL_PACK_VARIANT wraps its arity dispatch in it.
/// NOTE(review): presumably a workaround for preprocessors that forward `__VA_ARGS__` as a single
/// argument (e.g. the traditional MSVC preprocessor) - confirm before removing.
#define EXPAND(expr) expr
44 |
|
|
|
45 |
|
|
/// Expands to a nested initializer list describing one matmul + packing test variant, laid out in
/// the member order of kai::test::UkernelMatmulPackVariant:
///   1. `{matmul symbols, "kai_matmul_<test_name>", features_check}` for the `ukernel` member,
///   2. the LHS packing symbols built from `lhs_pack`,
///   3. the RHS packing symbols built from `rhs_pack`,
///   4. `s0s1_input`.
///
/// `test_name` is only stringified into the variant name; `name` selects the matmul symbols.
#define UKERNEL_MATMUL_PACK_VARIANT_NAME(test_name, name, features_check, lhs_pack, rhs_pack, s0s1_input) \
    {{UKERNEL_MATMUL_VARIANT(name), "kai_matmul_" #test_name, features_check},                            \
     UKERNEL_LHS_PACK_VARIANT(lhs_pack),                                                                   \
     UKERNEL_RHS_PACK_VARIANT(rhs_pack),                                                                   \
     s0s1_input}
52 |
|
|
|
53 |
|
|
/// Five-argument form of the variant initializer: forwards to UKERNEL_MATMUL_PACK_VARIANT_NAME
/// with the kernel name doubling as the test name, so the variant is labelled
/// "kai_matmul_<kernel>".
#define UKERNEL_MATMUL_PACK_VARIANT_DEFAULT(kernel, is_supported, lhs_pack, rhs_pack, s0s1_input) \
    UKERNEL_MATMUL_PACK_VARIANT_NAME(kernel, kernel, is_supported, lhs_pack, rhs_pack, s0s1_input)
55 |
|
|
|
56 |
|
|
/// Arity selector: discards its first six arguments and yields the seventh; anything after the
/// seventh is ignored. Callers append candidate macro names after their own argument list so that
/// the argument count decides which candidate lands in the seventh slot.
#define GET_MACRO(a1, a2, a3, a4, a5, a6, SELECTED, ...) SELECTED
57 |
|
|
/// Builds the initializer for one matmul + packing test variant, dispatching on argument count:
///   - 6 arguments -> UKERNEL_MATMUL_PACK_VARIANT_NAME(test_name, name, features_check, lhs_pack, rhs_pack, s0s1_input)
///   - 5 arguments -> UKERNEL_MATMUL_PACK_VARIANT_DEFAULT(name, features_check, lhs_pack, rhs_pack, s0s1_input)
///
/// GET_MACRO yields its seventh argument, so the two candidate names are appended after the
/// caller's arguments. The trailing `0` is a sentinel that GET_MACRO always discards: it keeps
/// GET_MACRO's variadic part non-empty in the 5-argument case, which standards before C++20
/// require (and `-pedantic` diagnoses otherwise).
#define UKERNEL_MATMUL_PACK_VARIANT(...)                                                          \
    EXPAND(GET_MACRO(                                                                             \
        __VA_ARGS__, UKERNEL_MATMUL_PACK_VARIANT_NAME, UKERNEL_MATMUL_PACK_VARIANT_DEFAULT, 0)(   \
        __VA_ARGS__))
59 |
|
|
// clang-format on |
60 |
|
|
|
61 |
|
|
namespace kai::test { |
62 |
|
|
|
63 |
|
|
/// Pairs a micro-kernel interface with a display name and a CPU-support predicate.
///
/// @tparam T Type of the interface object held by the variant.
template <typename T>
struct UkernelVariant {
    T interface;                            ///< Interface for testing variant.
    std::string_view name{};                ///< Name of the test variant (non-owning view).
    std::function<bool()> fn_is_supported;  ///< Returns true when the CPU supports the required features.

    /// Copies each argument into the member of the same name.
    ///
    /// @param[in] interface Interface for testing variant.
    /// @param[in] name Name of the test variant; only the view is stored, not the characters.
    /// @param[in] fn_is_supported Predicate reporting whether the CPU supports the required features.
    UkernelVariant(T interface, const std::string_view name, const std::function<bool()>& fn_is_supported) :
        interface(interface),
        name(name),
        fn_is_supported(fn_is_supported) {
    }
};
80 |
|
|
|
81 |
|
|
template <typename T, typename L, typename R> |
82 |
|
|
struct UkernelMatmulPackVariant { |
83 |
|
|
/// Interface for matmul variant. |
84 |
|
|
UkernelVariant<T> ukernel; |
85 |
|
|
|
86 |
|
|
L lhs_pack_interface; |
87 |
|
|
R rhs_pack_interface; |
88 |
|
|
|
89 |
|
|
bool rhs_s0s1_input; |
90 |
|
|
|
91 |
|
|
UkernelMatmulPackVariant() = delete; |
92 |
|
|
}; |
93 |
|
|
} // namespace kai::test |
94 |
|
|
|