benchmark/main.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <getopt.h> | ||
| 8 | #include <unistd.h> | ||
| 9 | |||
| 10 | #include <array> | ||
| 11 | #include <cerrno> | ||
| 12 | #include <cstdlib> | ||
| 13 | #include <cstring> | ||
| 14 | #include <iostream> | ||
| 15 | #include <optional> | ||
| 16 | #include <sstream> | ||
| 17 | #include <string> | ||
| 18 | #include <string_view> | ||
| 19 | #include <vector> | ||
| 20 | |||
| 21 | #include "benchmark/dwconv/dwconv_registry.hpp" | ||
| 22 | #include "benchmark/imatmul/imatmul_registry.hpp" | ||
| 23 | #include "benchmark/matmul/matmul_registry.hpp" | ||
| 24 | #include "kai/kai_common.h" | ||
| 25 | |||
| 26 | #ifdef __GNUC__ | ||
| 27 | #pragma GCC diagnostic push | ||
| 28 | #pragma GCC diagnostic ignored "-Wswitch-default" | ||
| 29 | #endif // __GNUC__ | ||
| 30 | |||
| 31 | #include <benchmark/benchmark.h> | ||
| 32 | |||
| 33 | #ifdef __GNUC__ | ||
| 34 | #pragma GCC diagnostic pop | ||
| 35 | #endif // __GNUC__ | ||
| 36 | |||
| 37 | namespace { | ||
| 38 | |||
| 39 | using namespace std::string_literals; | ||
| 40 | |||
| 41 | 6 | void print_matmul_usage(std::string_view name, bool defaulted = false) { | |
| 42 | 6 | std::ostringstream oss; | |
| 43 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
|
6 | if (defaulted) { |
| 44 | ✗ | oss << "Warning: No operation specified, defaulting to 'matmul' mode.\n"; | |
| 45 | ✗ | oss << "If you intended to run a different operation, specify it explicitly like so:\n"; | |
| 46 | ✗ | oss << '\t' << name << " imatmul [options]\n\n"; | |
| 47 | ✗ | } | |
| 48 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "Matmul usage:" << '\n'; |
| 49 |
4/8✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
|
6 | oss << '\t' << name << " matmul -m <M> -n <N> -k <K> [-b <block_size>]" << '\n'; |
| 50 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "Options:" << '\n'; |
| 51 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "\t-m,-n,-k\tMatrix dimensions (LHS MxK, RHS KxN)" << '\n'; |
| 52 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "\t-b\t\t(Optional) Block size for blockwise quantization" << '\n'; |
| 53 |
3/6✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
|
6 | std::cerr << oss.str() << '\n'; |
| 54 | 6 | } | |
| 55 | |||
| 56 | 6 | void print_imatmul_usage(std::string_view name) { | |
| 57 | 6 | std::ostringstream oss; | |
| 58 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "IndirectMatmul usage:" << '\n'; |
| 59 |
4/8✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
|
6 | oss << '\t' << name << " imatmul -m <M> -n <N> -c <k_chunk_count> -l <k_chunk_length>" << '\n'; |
| 60 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "Options:" << '\n'; |
| 61 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "\t-m\tNumber of rows (LHS)" << '\n'; |
| 62 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "\t-n\tNumber of columns (RHS)" << '\n'; |
| 63 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "\t-c\tK chunk count" << '\n'; |
| 64 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "\t-l\tK chunk length" << '\n'; |
| 65 |
3/6✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
|
6 | std::cerr << oss.str() << '\n'; |
| 66 | 6 | } | |
| 67 | |||
| 68 | 6 | void print_dwconv_usage(std::string_view name) { | |
| 69 | 6 | std::ostringstream oss; | |
| 70 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "DWConv usage:" << '\n'; |
| 71 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << '\t' << name |
| 72 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | << " dwconv --input_height <H> --input_width <W> --channels <C> [--stride <S_h,S_w>] " |
| 73 | "[--padding <P_top,P_bottom,P_left,P_right>] [--dilation <D_h,D_w>]" | ||
| 74 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | << '\n'; |
| 75 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "Options:" << '\n'; |
| 76 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "\t--input_height\tInput height (required)" << '\n'; |
| 77 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "\t--input_width\tInput width (required)" << '\n'; |
| 78 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "\t--channels\tNumber of channels (required)" << '\n'; |
| 79 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | oss << "\t--stride\t(Optional) Two positive comma-separated values for (row, col) stride (default: 1,1)" << '\n'; |
| 80 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | oss << "\t--padding\t(Optional) Four non-negative comma-separated values for (top, bottom, left, right) padding " |
| 81 | "(default: 0,0,0,0)" | ||
| 82 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | << '\n'; |
| 83 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | oss << "\t--dilation\t(Optional) Two positive comma-separated values for (row, col) dilation (default: 1,1)" |
| 84 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | << '\n'; |
| 85 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | oss << "\nCurrent DWConv micro-kernels only support stride=1 and dilation=1.\n"; |
| 86 |
3/6✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
|
6 | std::cerr << oss.str() << '\n'; |
| 87 | 6 | } | |
| 88 | |||
| 89 | 3 | void print_global_usage(std::string_view name) { | |
| 90 | 3 | std::ostringstream oss; | |
| 91 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
3 | oss << "Usage:" << '\n'; |
| 92 |
4/8✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
|
3 | oss << '\t' << name << " <matmul|imatmul|dwconv> [<options>]" << '\n'; |
| 93 |
4/8✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
|
3 | oss << "\nIf no operation is provided, defaults to: " << name << " matmul [options]" << '\n'; |
| 94 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
3 | oss << "\nBenchmark Framework options:" << '\n'; |
| 95 |
4/8✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
|
3 | oss << '\t' << name << " --help" << '\n'; |
| 96 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
|
3 | std::cerr << oss.str() << '\n'; |
| 97 | |||
| 98 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | print_matmul_usage(name); |
| 99 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | print_imatmul_usage(name); |
| 100 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | print_dwconv_usage(name); |
| 101 | 3 | } | |
| 102 | |||
| 103 | enum class DwConvValueKind { Positive, NonNegative }; | ||
| 104 | |||
| 105 | 21 | bool parse_size_t_arg(const char* arg, const char* name, DwConvValueKind kind, size_t& out, std::string& error) { | |
| 106 |
1/2✓ Branch 0 taken 21 times.
✗ Branch 1 not taken.
|
21 | if (!arg) { |
| 107 | ✗ | error = "Missing value for "s + name; | |
| 108 | ✗ | return false; | |
| 109 | } | ||
| 110 | |||
| 111 | 21 | errno = 0; | |
| 112 | 21 | char* end = nullptr; | |
| 113 | 21 | const unsigned long parsed = std::strtoul(arg, &end, 10); | |
| 114 |
3/6✓ Branch 0 taken 21 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 21 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 21 times.
|
21 | if (errno != 0 || end == arg || *end != '\0') { |
| 115 | ✗ | error = "Invalid value for "s + name + ": " + arg; | |
| 116 | ✗ | return false; | |
| 117 | } | ||
| 118 |
3/4✓ Branch 0 taken 9 times.
✓ Branch 1 taken 12 times.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
|
21 | if (kind == DwConvValueKind::Positive && parsed == 0) { |
| 119 | ✗ | error = "Value for "s + name + " must be greater than 0."; | |
| 120 | ✗ | return false; | |
| 121 | } | ||
| 122 | |||
| 123 | 21 | out = static_cast<size_t>(parsed); | |
| 124 | 21 | return true; | |
| 125 | 21 | } | |
| 126 | |||
| 127 | 12 | std::string_view trim_view(std::string_view sv) { | |
| 128 | 12 | const size_t start = sv.find_first_not_of(" \t"); | |
| 129 |
1/2✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
|
12 | if (start == std::string_view::npos) { |
| 130 | ✗ | return {}; | |
| 131 | } | ||
| 132 | 12 | const size_t end = sv.find_last_not_of(" \t"); | |
| 133 | 12 | return sv.substr(start, end - start + 1); | |
| 134 | 12 | } | |
| 135 | |||
| 136 | /// Parses a comma-separated list of `N` size_t values with optional whitespace. | ||
| 137 | template <size_t N> | ||
| 138 | 3 | bool parse_size_t_list( | |
| 139 | const char* arg, const char* name, DwConvValueKind kind, std::array<size_t, N>& out, std::string& error) { | ||
| 140 |
1/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
3 | if (!arg) { |
| 141 | ✗ | error = "Missing value for "s + name; | |
| 142 | ✗ | return false; | |
| 143 | } | ||
| 144 | |||
| 145 | 3 | std::string_view values(arg); | |
| 146 | 3 | std::vector<std::string> tokens; | |
| 147 | 3 | size_t start = 0; | |
| 148 |
1/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
|
12 | while (start <= values.size()) { |
| 149 | 12 | const size_t pos = values.find(',', start); | |
| 150 |
2/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 3 times.
|
12 | const size_t len = (pos == std::string::npos) ? std::string::npos : pos - start; |
| 151 |
1/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
|
12 | std::string_view token = values.substr(start, len); |
| 152 |
1/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
|
12 | token = trim_view(token); |
| 153 |
1/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
|
12 | if (token.empty()) { |
| 154 | ✗ | error = "Invalid value for "s + name + ": " + arg; | |
| 155 | ✗ | return false; | |
| 156 | } | ||
| 157 |
1/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
|
12 | tokens.emplace_back(token); |
| 158 |
2/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 3 times.
|
12 | if (pos == std::string::npos) { |
| 159 | 3 | break; | |
| 160 | } | ||
| 161 | 9 | start = pos + 1; | |
| 162 | 12 | } | |
| 163 | |||
| 164 |
1/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
3 | if (tokens.size() != N) { |
| 165 | ✗ | error = std::string(name) + " expects " + std::to_string(N) + " comma-separated values."; | |
| 166 | ✗ | return false; | |
| 167 | } | ||
| 168 | |||
| 169 |
3/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✓ Branch 5 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
|
15 | for (size_t i = 0; i < N; ++i) { |
| 170 | 12 | size_t parsed = 0; | |
| 171 |
2/8✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 12 times.
✗ Branch 7 not taken.
|
12 | if (!parse_size_t_arg(tokens[i].c_str(), name, kind, parsed, error)) { |
| 172 | ✗ | return false; | |
| 173 | } | ||
| 174 | 12 | out[i] = parsed; | |
| 175 | 12 | } | |
| 176 | |||
| 177 | 3 | return true; | |
| 178 | 3 | } | |
| 179 | |||
| 180 | 6 | std::optional<kai::benchmark::DwConvShape> parse_dwconv_cli(int argc, char** argv, std::string& error) { | |
| 181 | enum : int { | ||
| 182 | OPT_CHANNELS = 1000, | ||
| 183 | OPT_INPUT_HEIGHT, | ||
| 184 | OPT_INPUT_WIDTH, | ||
| 185 | OPT_STRIDE, | ||
| 186 | OPT_PADDING, | ||
| 187 | OPT_DILATION, | ||
| 188 | }; | ||
| 189 | |||
| 190 | 18 | kai::benchmark::DwConvShape shape{}; | |
| 191 | shape.stride = {1, 1}; | ||
| 192 | shape.padding = {0, 0, 0, 0}; | ||
| 193 | shape.dilation = {1, 1}; | ||
| 194 | |||
| 195 | bool input_height_set = false; | ||
| 196 | bool input_width_set = false; | ||
| 197 | bool channels_set = false; | ||
| 198 | |||
| 199 | optind = 1; | ||
| 200 | static const struct option long_options[] = { | ||
| 201 | {"channels", required_argument, nullptr, OPT_CHANNELS}, | ||
| 202 | {"input_height", required_argument, nullptr, OPT_INPUT_HEIGHT}, | ||
| 203 | {"input_width", required_argument, nullptr, OPT_INPUT_WIDTH}, | ||
| 204 | {"stride", required_argument, nullptr, OPT_STRIDE}, | ||
| 205 | {"padding", required_argument, nullptr, OPT_PADDING}, | ||
| 206 | {"dilation", required_argument, nullptr, OPT_DILATION}, | ||
| 207 | {nullptr, 0, nullptr, 0}, | ||
| 208 | }; | ||
| 209 | |||
| 210 | int opt; | ||
| 211 | while ((opt = getopt_long(argc, argv, "", long_options, nullptr)) != -1) { | ||
| 212 | switch (opt) { | ||
| 213 | case OPT_CHANNELS: | ||
| 214 | if (!parse_size_t_arg(optarg, "--channels", DwConvValueKind::Positive, shape.num_channels, error)) { | ||
| 215 | return std::nullopt; | ||
| 216 | } | ||
| 217 | channels_set = true; | ||
| 218 | break; | ||
| 219 | case OPT_INPUT_HEIGHT: | ||
| 220 | if (!parse_size_t_arg(optarg, "--input_height", DwConvValueKind::Positive, shape.input_height, error)) { | ||
| 221 | return std::nullopt; | ||
| 222 | } | ||
| 223 | input_height_set = true; | ||
| 224 | break; | ||
| 225 | case OPT_INPUT_WIDTH: | ||
| 226 | if (!parse_size_t_arg(optarg, "--input_width", DwConvValueKind::Positive, shape.input_width, error)) { | ||
| 227 | return std::nullopt; | ||
| 228 | } | ||
| 229 | input_width_set = true; | ||
| 230 | break; | ||
| 231 | case OPT_STRIDE: | ||
| 232 | if (!parse_size_t_list(optarg, "--stride", DwConvValueKind::Positive, shape.stride, error)) { | ||
| 233 | return std::nullopt; | ||
| 234 | } | ||
| 235 | break; | ||
| 236 | case OPT_PADDING: | ||
| 237 | if (!parse_size_t_list(optarg, "--padding", DwConvValueKind::NonNegative, shape.padding, error)) { | ||
| 238 | return std::nullopt; | ||
| 239 | } | ||
| 240 | break; | ||
| 241 | case OPT_DILATION: | ||
| 242 | if (!parse_size_t_list(optarg, "--dilation", DwConvValueKind::Positive, shape.dilation, error)) { | ||
| 243 | return std::nullopt; | ||
| 244 | } | ||
| 245 | break; | ||
| 246 | case '?': | ||
| 247 | default: | ||
| 248 | error = "Unrecognized option for dwconv benchmark."; | ||
| 249 | return std::nullopt; | ||
| 250 | } | ||
| 251 | } | ||
| 252 | |||
| 253 | if (!input_height_set) { | ||
| 254 | error = "Missing required option --input_height"; | ||
| 255 | return std::nullopt; | ||
| 256 | } | ||
| 257 | if (!input_width_set) { | ||
| 258 | error = "Missing required option --input_width"; | ||
| 259 | return std::nullopt; | ||
| 260 | } | ||
| 261 | if (!channels_set) { | ||
| 262 | error = "Missing required option --channels"; | ||
| 263 | return std::nullopt; | ||
| 264 | } | ||
| 265 | |||
| 266 | return shape; | ||
| 267 | } | ||
| 268 | |||
| 269 | } // namespace | ||
| 270 | |||
| 271 | 36 | static std::optional<std::string> find_user_benchmark_filter(int argc, char** argv) { | |
| 272 | static constexpr std::string_view benchmark_filter_eq = "--benchmark_filter="; | ||
| 273 | static constexpr std::string_view benchmark_filter = "--benchmark_filter"; | ||
| 274 | |||
| 275 |
3/5✓ Branch 0 taken 135 times.
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 36 times.
|
171 | for (int i = 1; i < argc; ++i) { |
| 276 | 135 | const char* arg = argv[i]; | |
| 277 |
1/2✓ Branch 0 taken 135 times.
✗ Branch 1 not taken.
|
135 | if (!arg) { |
| 278 | ✗ | continue; | |
| 279 | } | ||
| 280 | |||
| 281 | // --benchmark_filter=REGEX | ||
| 282 | 135 | std::string_view arg_view(arg); | |
| 283 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 135 times.
|
135 | if (arg_view.substr(0, benchmark_filter_eq.length()) == benchmark_filter_eq) { |
| 284 | ✗ | auto val = arg_view.substr(benchmark_filter_eq.length()); | |
| 285 | ✗ | return std::string(val); | |
| 286 | ✗ | } | |
| 287 | |||
| 288 | // --benchmark_filter REGEX | ||
| 289 |
1/4✗ Branch 0 not taken.
✓ Branch 1 taken 135 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
135 | if (arg_view == benchmark_filter && i + 1 < argc) { |
| 290 | ✗ | const char* val = argv[i + 1]; | |
| 291 | ✗ | return std::string(val ? val : ""); | |
| 292 | ✗ | } | |
| 293 | 135 | } | |
| 294 | 36 | return std::nullopt; | |
| 295 | 36 | } | |
| 296 | |||
| 297 | 9 | static int run_matmul( | |
| 298 | int argc, char** argv, bool default_to_matmul, const std::optional<std::string>& user_filter_opt) { | ||
| 299 | 9 | bool mflag = false, nflag = false, kflag = false, bflag = false; | |
| 300 | 9 | size_t m = 1, n = 1, k = 1, bl = 32; | |
| 301 | |||
| 302 | 9 | optind = 1; | |
| 303 | 9 | int opt; | |
| 304 |
2/2✓ Branch 0 taken 18 times.
✓ Branch 1 taken 9 times.
|
27 | while ((opt = getopt(argc, argv, "m:n:k:b:")) != -1) { |
| 305 |
3/5✓ Branch 0 taken 6 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
18 | switch (opt) { |
| 306 | case 'm': | ||
| 307 | 6 | m = std::atoi(optarg); | |
| 308 | 6 | mflag = true; | |
| 309 | 6 | break; | |
| 310 | case 'n': | ||
| 311 | 6 | n = std::atoi(optarg); | |
| 312 | 6 | nflag = true; | |
| 313 | 6 | break; | |
| 314 | case 'k': | ||
| 315 | 6 | k = std::atoi(optarg); | |
| 316 | 6 | kflag = true; | |
| 317 | 6 | break; | |
| 318 | case 'b': | ||
| 319 | ✗ | bl = std::atoi(optarg); | |
| 320 | ✗ | bflag = true; | |
| 321 | ✗ | break; | |
| 322 | default: | ||
| 323 | ✗ | print_matmul_usage(argv[0], default_to_matmul); | |
| 324 | ✗ | return EXIT_FAILURE; | |
| 325 | } | ||
| 326 | } | ||
| 327 | |||
| 328 |
4/6✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
|
9 | if (!mflag || !nflag || !kflag) { |
| 329 | 3 | print_matmul_usage(argv[0]); | |
| 330 | 3 | return EXIT_FAILURE; | |
| 331 | } | ||
| 332 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
|
6 | if (!bflag) { |
| 333 | 6 | std::cerr << "Optional argument -b not specified. Defaulting to block size " << bl << "\n"; | |
| 334 | 6 | } | |
| 335 | |||
| 336 | 6 | kai::benchmark::RegisterMatMulBenchmarks({m, n, k}, bl); | |
| 337 | |||
| 338 | // Default filter if user didn’t supply one | ||
| 339 |
4/12✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
|
6 | std::string spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_matmul"); |
| 340 | |||
| 341 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | ::benchmark::RunSpecifiedBenchmarks(nullptr, nullptr, spec); |
| 342 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | ::benchmark::Shutdown(); |
| 343 | 6 | return 0; | |
| 344 | 9 | } | |
| 345 | |||
| 346 | 6 | static int run_imatmul(int argc, char** argv, const std::optional<std::string>& user_filter_opt) { | |
| 347 | 6 | bool mflag = false, nflag = false, cflag = false, lflag = false; | |
| 348 | 6 | size_t m = 1, n = 1, k_chunk_count = 1, k_chunk_length = 1; | |
| 349 | |||
| 350 | 6 | optind = 1; | |
| 351 | 6 | int opt; | |
| 352 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
|
18 | while ((opt = getopt(argc, argv, "m:n:c:l:")) != -1) { |
| 353 |
4/5✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 3 times.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
|
12 | switch (opt) { |
| 354 | case 'm': | ||
| 355 | 3 | m = std::atoi(optarg); | |
| 356 | 3 | mflag = true; | |
| 357 | 3 | break; | |
| 358 | case 'n': | ||
| 359 | 3 | n = std::atoi(optarg); | |
| 360 | 3 | nflag = true; | |
| 361 | 3 | break; | |
| 362 | case 'c': | ||
| 363 | 3 | k_chunk_count = std::atoi(optarg); | |
| 364 | 3 | cflag = true; | |
| 365 | 3 | break; | |
| 366 | case 'l': | ||
| 367 | 3 | k_chunk_length = std::atoi(optarg); | |
| 368 | 3 | lflag = true; | |
| 369 | 3 | break; | |
| 370 | default: | ||
| 371 | ✗ | print_imatmul_usage(argv[0]); | |
| 372 | ✗ | return EXIT_FAILURE; | |
| 373 | } | ||
| 374 | } | ||
| 375 | |||
| 376 |
5/8✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
|
6 | if (!mflag || !nflag || !cflag || !lflag) { |
| 377 | 3 | print_imatmul_usage(argv[0]); | |
| 378 | 3 | return EXIT_FAILURE; | |
| 379 | } | ||
| 380 | |||
| 381 | 3 | std::cerr << "Running imatmul benchmarks with m=" << m << ", n=" << n << ", k_chunk_count=" << k_chunk_count | |
| 382 | 3 | << ", k_chunk_length=" << k_chunk_length << "\n"; | |
| 383 | |||
| 384 | 3 | kai::benchmark::RegisteriMatMulBenchmarks(m, n, k_chunk_count, k_chunk_length); | |
| 385 | |||
| 386 | // Default filter if user didn’t supply one | ||
| 387 |
4/12✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
|
3 | std::string spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_imatmul"); |
| 388 | |||
| 389 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
3 | ::benchmark::RunSpecifiedBenchmarks(nullptr, nullptr, spec); |
| 390 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | ::benchmark::Shutdown(); |
| 391 | 3 | return 0; | |
| 392 | 6 | } | |
| 393 | |||
| 394 | 6 | static int run_dwconv(int argc, char** argv, const std::optional<std::string>& user_filter_opt) { | |
| 395 | 6 | std::string parse_error; | |
| 396 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | auto shape_opt = parse_dwconv_cli(argc, argv, parse_error); |
| 397 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (!shape_opt) { |
| 398 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | if (!parse_error.empty()) { |
| 399 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
3 | std::cerr << parse_error << '\n'; |
| 400 | 3 | } | |
| 401 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | print_dwconv_usage(argv[0]); |
| 402 | 3 | return EXIT_FAILURE; | |
| 403 | } | ||
| 404 | |||
| 405 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | const auto inferred_dims = kai::benchmark::InferDwConvOutputDims(*shape_opt); |
| 406 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
|
3 | if (!inferred_dims) { |
| 407 | ✗ | std::cerr << "Invalid DWConv configuration: inferred output dimensions are non-positive. " | |
| 408 | ✗ | << "Check stride, padding, and dilation relative to the kernel size.\n"; | |
| 409 | ✗ | return EXIT_FAILURE; | |
| 410 | } | ||
| 411 | |||
| 412 | 3 | const size_t inferred_out_h = inferred_dims->height; | |
| 413 | 3 | const size_t inferred_out_w = inferred_dims->width; | |
| 414 | |||
| 415 | 3 | const auto& shape = *shape_opt; | |
| 416 | 12 | const auto format_array = [](const auto& values) { | |
| 417 | 9 | std::ostringstream oss; | |
| 418 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
9 | oss << '['; |
| 419 |
4/4✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 12 times.
✓ Branch 3 taken 3 times.
|
33 | for (size_t i = 0; i < values.size(); ++i) { |
| 420 |
4/4✓ Branch 0 taken 6 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 3 times.
|
24 | if (i != 0) { |
| 421 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
|
15 | oss << ", "; |
| 422 | 15 | } | |
| 423 |
2/4✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
|
24 | oss << values[i]; |
| 424 | 24 | } | |
| 425 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
9 | oss << ']'; |
| 426 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
9 | return oss.str(); |
| 427 | 9 | }; | |
| 428 | |||
| 429 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
3 | if (!kai::benchmark::supports_unit_stride_and_dilation(shape)) { |
| 430 | ✗ | std::cerr << "Configured stride=" << format_array(shape.stride) << " dilation=" << format_array(shape.dilation) | |
| 431 | ✗ | << " is not supported by current DWConv micro-kernels. " | |
| 432 | ✗ | << "Only stride=1 and dilation=1 are available.\n"; | |
| 433 | ✗ | return EXIT_FAILURE; | |
| 434 | } | ||
| 435 | |||
| 436 |
4/8✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
|
6 | std::cerr << "Running dwconv benchmarks with input=" << shape.input_height << 'x' << shape.input_width |
| 437 |
6/12✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
|
3 | << ", output=" << inferred_out_h << 'x' << inferred_out_w << ", channels=" << shape.num_channels |
| 438 |
6/12✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
|
3 | << ", stride=" << format_array(shape.stride) << ", padding=" << format_array(shape.padding) |
| 439 |
4/8✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
|
3 | << ", dilation=" << format_array(shape.dilation) << "\n"; |
| 440 | |||
| 441 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | kai::benchmark::RegisterDwConvBenchmarks(shape); |
| 442 | |||
| 443 |
5/14✓ Branch 0 taken 1 time.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
|
3 | std::string spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_dwconv"); |
| 444 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
3 | ::benchmark::RunSpecifiedBenchmarks(nullptr, nullptr, spec); |
| 445 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | ::benchmark::Shutdown(); |
| 446 | 3 | return 0; | |
| 447 | 6 | } | |
| 448 | |||
| 449 | 36 | int main(int argc, char** argv) { | |
| 450 | // Detect user-provided filter BEFORE Initialize() consumes the benchmark framework flags | ||
| 451 | 36 | const auto user_filter_opt = find_user_benchmark_filter(argc, argv); | |
| 452 | |||
| 453 | // Check for --benchmark_list_tests in argv | ||
| 454 | 36 | bool list_tests = false; | |
| 455 |
2/2✓ Branch 0 taken 135 times.
✓ Branch 1 taken 24 times.
|
171 | for (int i = 1; i < argc; ++i) { |
| 456 |
4/4✓ Branch 0 taken 53 times.
✓ Branch 1 taken 82 times.
✓ Branch 2 taken 4 times.
✓ Branch 3 taken 41 times.
|
135 | if (std::strstr(argv[i], "--benchmark_list_tests") == argv[i]) { |
| 457 | 12 | list_tests = true; | |
| 458 | 12 | break; | |
| 459 | } | ||
| 460 | 123 | } | |
| 461 | |||
| 462 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | ::benchmark::Initialize(&argc, argv); |
| 463 | |||
| 464 |
4/8✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 36 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 36 times.
✗ Branch 7 not taken.
|
36 | std::cerr << "KleidiAI version: v" << kai_get_version() << "\n"; |
| 465 | |||
| 466 | // Determine subcommand (mode): matmul or imatmul. | ||
| 467 | 36 | enum class Mode : uint8_t { COMPAT, MATMUL, IMATMUL, DWCONV } mode = Mode::COMPAT; | |
| 468 | |||
| 469 | static constexpr std::string_view MATMUL = "matmul"; | ||
| 470 | static constexpr std::string_view IMATMUL = "imatmul"; | ||
| 471 | static constexpr std::string_view DWCONV = "dwconv"; | ||
| 472 | |||
| 473 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | std::vector<std::string_view> args(argv, argv + argc); |
| 474 | |||
| 475 |
4/4✓ Branch 0 taken 24 times.
✓ Branch 1 taken 12 times.
✓ Branch 2 taken 21 times.
✓ Branch 3 taken 3 times.
|
36 | if (!list_tests && argc < 2) { |
| 476 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | print_global_usage(argv[0]); |
| 477 | 3 | return EXIT_FAILURE; | |
| 478 | } | ||
| 479 | |||
| 480 |
4/4✓ Branch 0 taken 30 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 21 times.
✓ Branch 3 taken 9 times.
|
33 | if (argc >= 2 && args[1] == MATMUL) { |
| 481 | 9 | mode = Mode::MATMUL; | |
| 482 | 9 | argv += 1; | |
| 483 | 9 | argc -= 1; | |
| 484 |
4/4✓ Branch 0 taken 21 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 12 times.
✓ Branch 3 taken 9 times.
|
33 | } else if (argc >= 2 && args[1] == IMATMUL) { |
| 485 | 9 | mode = Mode::IMATMUL; | |
| 486 | 9 | argv += 1; | |
| 487 | 9 | argc -= 1; | |
| 488 |
4/4✓ Branch 0 taken 12 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 3 times.
✓ Branch 3 taken 9 times.
|
24 | } else if (argc >= 2 && args[1] == DWCONV) { |
| 489 | 9 | mode = Mode::DWCONV; | |
| 490 | 9 | argv += 1; | |
| 491 | 9 | argc -= 1; | |
| 492 | 9 | } | |
| 493 | |||
| 494 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 21 times.
|
33 | if (list_tests) { |
| 495 | 12 | std::string spec; | |
| 496 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 3 times.
|
12 | if (mode == Mode::COMPAT) { |
| 497 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | kai::benchmark::RegisterMatMulBenchmarks({1, 1, 1}, 32); |
| 498 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | kai::benchmark::RegisteriMatMulBenchmarks(1, 1, 1, 1); |
| 499 |
5/10✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
|
12 | kai::benchmark::RegisterDwConvBenchmarks({3, 3, 1}); |
| 500 | spec = user_filter_opt.value_or(""); | ||
| 501 | } else if (mode == Mode::MATMUL) { | ||
| 502 | kai::benchmark::RegisterMatMulBenchmarks({1, 1, 1}, 32); | ||
| 503 | spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_matmul"); | ||
| 504 | } else if (mode == Mode::IMATMUL) { | ||
| 505 | kai::benchmark::RegisteriMatMulBenchmarks(1, 1, 1, 1); | ||
| 506 | spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_imatmul"); | ||
| 507 | } else if (mode == Mode::DWCONV) { | ||
| 508 | kai::benchmark::RegisterDwConvBenchmarks({3, 3, 1}); | ||
| 509 | spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_dwconv"); | ||
| 510 | } | ||
| 511 | ::benchmark::RunSpecifiedBenchmarks(nullptr, nullptr, spec); | ||
| 512 | ::benchmark::Shutdown(); | ||
| 513 | return 0; | ||
| 514 | } | ||
| 515 | |||
| 516 | switch (mode) { | ||
| 517 | case Mode::COMPAT: | ||
| 518 | return run_matmul(argc, argv, true, user_filter_opt); | ||
| 519 | case Mode::MATMUL: | ||
| 520 | return run_matmul(argc, argv, false, user_filter_opt); | ||
| 521 | case Mode::IMATMUL: | ||
| 522 | return run_imatmul(argc, argv, user_filter_opt); | ||
| 523 | case Mode::DWCONV: | ||
| 524 | return run_dwconv(argc, argv, user_filter_opt); | ||
| 525 | default: | ||
| 526 | print_global_usage(argv[0]); | ||
| 527 | return EXIT_FAILURE; | ||
| 528 | } | ||
| 529 | ✗ | } | |
| 530 |