KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 85.3% 232 / 0 / 272
Functions: 93.5% 29 / 0 / 31
Branches: 43.7% 246 / 0 / 563

benchmark/main.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <getopt.h>
8 #include <unistd.h>
9
10 #include <array>
11 #include <cerrno>
12 #include <cstdlib>
13 #include <cstring>
14 #include <iostream>
15 #include <optional>
16 #include <sstream>
17 #include <string>
18 #include <string_view>
19 #include <vector>
20
21 #include "benchmark/dwconv/dwconv_registry.hpp"
22 #include "benchmark/imatmul/imatmul_registry.hpp"
23 #include "benchmark/matmul/matmul_registry.hpp"
24 #include "kai/kai_common.h"
25
26 #ifdef __GNUC__
27 #pragma GCC diagnostic push
28 #pragma GCC diagnostic ignored "-Wswitch-default"
29 #endif // __GNUC__
30
31 #include <benchmark/benchmark.h>
32
33 #ifdef __GNUC__
34 #pragma GCC diagnostic pop
35 #endif // __GNUC__
36
37 namespace {
38
39 using namespace std::string_literals;
40
41 6 void print_matmul_usage(std::string_view name, bool defaulted = false) {
42 6 std::ostringstream oss;
43
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
6 if (defaulted) {
44 oss << "Warning: No operation specified, defaulting to 'matmul' mode.\n";
45 oss << "If you intended to run a different operation, specify it explicitly like so:\n";
46 oss << '\t' << name << " imatmul [options]\n\n";
47 }
48
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "Matmul usage:" << '\n';
49
4/8
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
6 oss << '\t' << name << " matmul -m <M> -n <N> -k <K> [-b <block_size>]" << '\n';
50
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "Options:" << '\n';
51
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "\t-m,-n,-k\tMatrix dimensions (LHS MxK, RHS KxN)" << '\n';
52
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "\t-b\t\t(Optional) Block size for blockwise quantization" << '\n';
53
3/6
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
6 std::cerr << oss.str() << '\n';
54 6 }
55
56 6 void print_imatmul_usage(std::string_view name) {
57 6 std::ostringstream oss;
58
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "IndirectMatmul usage:" << '\n';
59
4/8
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
6 oss << '\t' << name << " imatmul -m <M> -n <N> -c <k_chunk_count> -l <k_chunk_length>" << '\n';
60
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "Options:" << '\n';
61
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "\t-m\tNumber of rows (LHS)" << '\n';
62
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "\t-n\tNumber of columns (RHS)" << '\n';
63
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "\t-c\tK chunk count" << '\n';
64
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "\t-l\tK chunk length" << '\n';
65
3/6
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
6 std::cerr << oss.str() << '\n';
66 6 }
67
68 6 void print_dwconv_usage(std::string_view name) {
69 6 std::ostringstream oss;
70
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "DWConv usage:" << '\n';
71
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << '\t' << name
72
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 << " dwconv --input_height <H> --input_width <W> --channels <C> [--stride <S_h,S_w>] "
73 "[--padding <P_top,P_bottom,P_left,P_right>] [--dilation <D_h,D_w>]"
74
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 << '\n';
75
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "Options:" << '\n';
76
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "\t--input_height\tInput height (required)" << '\n';
77
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "\t--input_width\tInput width (required)" << '\n';
78
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "\t--channels\tNumber of channels (required)" << '\n';
79
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 oss << "\t--stride\t(Optional) Two positive comma-separated values for (row, col) stride (default: 1,1)" << '\n';
80
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 oss << "\t--padding\t(Optional) Four non-negative comma-separated values for (top, bottom, left, right) padding "
81 "(default: 0,0,0,0)"
82
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 << '\n';
83
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 oss << "\t--dilation\t(Optional) Two positive comma-separated values for (row, col) dilation (default: 1,1)"
84
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 << '\n';
85
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 oss << "\nCurrent DWConv micro-kernels only support stride=1 and dilation=1.\n";
86
3/6
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
6 std::cerr << oss.str() << '\n';
87 6 }
88
89 3 void print_global_usage(std::string_view name) {
90 3 std::ostringstream oss;
91
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 oss << "Usage:" << '\n';
92
4/8
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
3 oss << '\t' << name << " <matmul|imatmul|dwconv> [<options>]" << '\n';
93
4/8
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
3 oss << "\nIf no operation is provided, defaults to: " << name << " matmul [options]" << '\n';
94
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 oss << "\nBenchmark Framework options:" << '\n';
95
4/8
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
3 oss << '\t' << name << " --help" << '\n';
96
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
3 std::cerr << oss.str() << '\n';
97
98
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 print_matmul_usage(name);
99
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 print_imatmul_usage(name);
100
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 print_dwconv_usage(name);
101 3 }
102
103 enum class DwConvValueKind { Positive, NonNegative };
104
105 21 bool parse_size_t_arg(const char* arg, const char* name, DwConvValueKind kind, size_t& out, std::string& error) {
106
1/2
✓ Branch 0 taken 21 times.
✗ Branch 1 not taken.
21 if (!arg) {
107 error = "Missing value for "s + name;
108 return false;
109 }
110
111 21 errno = 0;
112 21 char* end = nullptr;
113 21 const unsigned long parsed = std::strtoul(arg, &end, 10);
114
3/6
✓ Branch 0 taken 21 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 21 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 21 times.
21 if (errno != 0 || end == arg || *end != '\0') {
115 error = "Invalid value for "s + name + ": " + arg;
116 return false;
117 }
118
3/4
✓ Branch 0 taken 9 times.
✓ Branch 1 taken 12 times.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
21 if (kind == DwConvValueKind::Positive && parsed == 0) {
119 error = "Value for "s + name + " must be greater than 0.";
120 return false;
121 }
122
123 21 out = static_cast<size_t>(parsed);
124 21 return true;
125 21 }
126
127 12 std::string_view trim_view(std::string_view sv) {
128 12 const size_t start = sv.find_first_not_of(" \t");
129
1/2
✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
12 if (start == std::string_view::npos) {
130 return {};
131 }
132 12 const size_t end = sv.find_last_not_of(" \t");
133 12 return sv.substr(start, end - start + 1);
134 12 }
135
136 /// Parses a comma-separated list of `N` size_t values with optional whitespace.
137 template <size_t N>
138 3 bool parse_size_t_list(
139 const char* arg, const char* name, DwConvValueKind kind, std::array<size_t, N>& out, std::string& error) {
140
1/4
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 if (!arg) {
141 error = "Missing value for "s + name;
142 return false;
143 }
144
145 3 std::string_view values(arg);
146 3 std::vector<std::string> tokens;
147 3 size_t start = 0;
148
1/4
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
12 while (start <= values.size()) {
149 12 const size_t pos = values.find(',', start);
150
2/4
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 3 times.
12 const size_t len = (pos == std::string::npos) ? std::string::npos : pos - start;
151
1/4
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
12 std::string_view token = values.substr(start, len);
152
1/4
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
12 token = trim_view(token);
153
1/4
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
12 if (token.empty()) {
154 error = "Invalid value for "s + name + ": " + arg;
155 return false;
156 }
157
1/4
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
12 tokens.emplace_back(token);
158
2/4
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 3 times.
12 if (pos == std::string::npos) {
159 3 break;
160 }
161 9 start = pos + 1;
162 12 }
163
164
1/4
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
3 if (tokens.size() != N) {
165 error = std::string(name) + " expects " + std::to_string(N) + " comma-separated values.";
166 return false;
167 }
168
169
3/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✓ Branch 5 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
15 for (size_t i = 0; i < N; ++i) {
170 12 size_t parsed = 0;
171
2/8
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 12 times.
✗ Branch 7 not taken.
12 if (!parse_size_t_arg(tokens[i].c_str(), name, kind, parsed, error)) {
172 return false;
173 }
174 12 out[i] = parsed;
175 12 }
176
177 3 return true;
178 3 }
179
180 6 std::optional<kai::benchmark::DwConvShape> parse_dwconv_cli(int argc, char** argv, std::string& error) {
181 enum : int {
182 OPT_CHANNELS = 1000,
183 OPT_INPUT_HEIGHT,
184 OPT_INPUT_WIDTH,
185 OPT_STRIDE,
186 OPT_PADDING,
187 OPT_DILATION,
188 };
189
190 18 kai::benchmark::DwConvShape shape{};
191 shape.stride = {1, 1};
192 shape.padding = {0, 0, 0, 0};
193 shape.dilation = {1, 1};
194
195 bool input_height_set = false;
196 bool input_width_set = false;
197 bool channels_set = false;
198
199 optind = 1;
200 static const struct option long_options[] = {
201 {"channels", required_argument, nullptr, OPT_CHANNELS},
202 {"input_height", required_argument, nullptr, OPT_INPUT_HEIGHT},
203 {"input_width", required_argument, nullptr, OPT_INPUT_WIDTH},
204 {"stride", required_argument, nullptr, OPT_STRIDE},
205 {"padding", required_argument, nullptr, OPT_PADDING},
206 {"dilation", required_argument, nullptr, OPT_DILATION},
207 {nullptr, 0, nullptr, 0},
208 };
209
210 int opt;
211 while ((opt = getopt_long(argc, argv, "", long_options, nullptr)) != -1) {
212 switch (opt) {
213 case OPT_CHANNELS:
214 if (!parse_size_t_arg(optarg, "--channels", DwConvValueKind::Positive, shape.num_channels, error)) {
215 return std::nullopt;
216 }
217 channels_set = true;
218 break;
219 case OPT_INPUT_HEIGHT:
220 if (!parse_size_t_arg(optarg, "--input_height", DwConvValueKind::Positive, shape.input_height, error)) {
221 return std::nullopt;
222 }
223 input_height_set = true;
224 break;
225 case OPT_INPUT_WIDTH:
226 if (!parse_size_t_arg(optarg, "--input_width", DwConvValueKind::Positive, shape.input_width, error)) {
227 return std::nullopt;
228 }
229 input_width_set = true;
230 break;
231 case OPT_STRIDE:
232 if (!parse_size_t_list(optarg, "--stride", DwConvValueKind::Positive, shape.stride, error)) {
233 return std::nullopt;
234 }
235 break;
236 case OPT_PADDING:
237 if (!parse_size_t_list(optarg, "--padding", DwConvValueKind::NonNegative, shape.padding, error)) {
238 return std::nullopt;
239 }
240 break;
241 case OPT_DILATION:
242 if (!parse_size_t_list(optarg, "--dilation", DwConvValueKind::Positive, shape.dilation, error)) {
243 return std::nullopt;
244 }
245 break;
246 case '?':
247 default:
248 error = "Unrecognized option for dwconv benchmark.";
249 return std::nullopt;
250 }
251 }
252
253 if (!input_height_set) {
254 error = "Missing required option --input_height";
255 return std::nullopt;
256 }
257 if (!input_width_set) {
258 error = "Missing required option --input_width";
259 return std::nullopt;
260 }
261 if (!channels_set) {
262 error = "Missing required option --channels";
263 return std::nullopt;
264 }
265
266 return shape;
267 }
268
269 } // namespace
270
271 36 static std::optional<std::string> find_user_benchmark_filter(int argc, char** argv) {
272 static constexpr std::string_view benchmark_filter_eq = "--benchmark_filter=";
273 static constexpr std::string_view benchmark_filter = "--benchmark_filter";
274
275
3/5
✓ Branch 0 taken 135 times.
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 36 times.
171 for (int i = 1; i < argc; ++i) {
276 135 const char* arg = argv[i];
277
1/2
✓ Branch 0 taken 135 times.
✗ Branch 1 not taken.
135 if (!arg) {
278 continue;
279 }
280
281 // --benchmark_filter=REGEX
282 135 std::string_view arg_view(arg);
283
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 135 times.
135 if (arg_view.substr(0, benchmark_filter_eq.length()) == benchmark_filter_eq) {
284 auto val = arg_view.substr(benchmark_filter_eq.length());
285 return std::string(val);
286 }
287
288 // --benchmark_filter REGEX
289
1/4
✗ Branch 0 not taken.
✓ Branch 1 taken 135 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
135 if (arg_view == benchmark_filter && i + 1 < argc) {
290 const char* val = argv[i + 1];
291 return std::string(val ? val : "");
292 }
293 135 }
294 36 return std::nullopt;
295 36 }
296
297 9 static int run_matmul(
298 int argc, char** argv, bool default_to_matmul, const std::optional<std::string>& user_filter_opt) {
299 9 bool mflag = false, nflag = false, kflag = false, bflag = false;
300 9 size_t m = 1, n = 1, k = 1, bl = 32;
301
302 9 optind = 1;
303 9 int opt;
304
2/2
✓ Branch 0 taken 18 times.
✓ Branch 1 taken 9 times.
27 while ((opt = getopt(argc, argv, "m:n:k:b:")) != -1) {
305
3/5
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
18 switch (opt) {
306 case 'm':
307 6 m = std::atoi(optarg);
308 6 mflag = true;
309 6 break;
310 case 'n':
311 6 n = std::atoi(optarg);
312 6 nflag = true;
313 6 break;
314 case 'k':
315 6 k = std::atoi(optarg);
316 6 kflag = true;
317 6 break;
318 case 'b':
319 bl = std::atoi(optarg);
320 bflag = true;
321 break;
322 default:
323 print_matmul_usage(argv[0], default_to_matmul);
324 return EXIT_FAILURE;
325 }
326 }
327
328
4/6
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
9 if (!mflag || !nflag || !kflag) {
329 3 print_matmul_usage(argv[0]);
330 3 return EXIT_FAILURE;
331 }
332
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
6 if (!bflag) {
333 6 std::cerr << "Optional argument -b not specified. Defaulting to block size " << bl << "\n";
334 6 }
335
336 6 kai::benchmark::RegisterMatMulBenchmarks({m, n, k}, bl);
337
338 // Default filter if user didn’t supply one
339
4/12
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
6 std::string spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_matmul");
340
341
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
6 ::benchmark::RunSpecifiedBenchmarks(nullptr, nullptr, spec);
342
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 ::benchmark::Shutdown();
343 6 return 0;
344 9 }
345
346 6 static int run_imatmul(int argc, char** argv, const std::optional<std::string>& user_filter_opt) {
347 6 bool mflag = false, nflag = false, cflag = false, lflag = false;
348 6 size_t m = 1, n = 1, k_chunk_count = 1, k_chunk_length = 1;
349
350 6 optind = 1;
351 6 int opt;
352
2/2
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
18 while ((opt = getopt(argc, argv, "m:n:c:l:")) != -1) {
353
4/5
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 3 times.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
12 switch (opt) {
354 case 'm':
355 3 m = std::atoi(optarg);
356 3 mflag = true;
357 3 break;
358 case 'n':
359 3 n = std::atoi(optarg);
360 3 nflag = true;
361 3 break;
362 case 'c':
363 3 k_chunk_count = std::atoi(optarg);
364 3 cflag = true;
365 3 break;
366 case 'l':
367 3 k_chunk_length = std::atoi(optarg);
368 3 lflag = true;
369 3 break;
370 default:
371 print_imatmul_usage(argv[0]);
372 return EXIT_FAILURE;
373 }
374 }
375
376
5/8
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
6 if (!mflag || !nflag || !cflag || !lflag) {
377 3 print_imatmul_usage(argv[0]);
378 3 return EXIT_FAILURE;
379 }
380
381 3 std::cerr << "Running imatmul benchmarks with m=" << m << ", n=" << n << ", k_chunk_count=" << k_chunk_count
382 3 << ", k_chunk_length=" << k_chunk_length << "\n";
383
384 3 kai::benchmark::RegisteriMatMulBenchmarks(m, n, k_chunk_count, k_chunk_length);
385
386 // Default filter if user didn’t supply one
387
4/12
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
3 std::string spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_imatmul");
388
389
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 ::benchmark::RunSpecifiedBenchmarks(nullptr, nullptr, spec);
390
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 ::benchmark::Shutdown();
391 3 return 0;
392 6 }
393
394 6 static int run_dwconv(int argc, char** argv, const std::optional<std::string>& user_filter_opt) {
395 6 std::string parse_error;
396
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 auto shape_opt = parse_dwconv_cli(argc, argv, parse_error);
397
2/2
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
6 if (!shape_opt) {
398
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 if (!parse_error.empty()) {
399
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 std::cerr << parse_error << '\n';
400 3 }
401
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 print_dwconv_usage(argv[0]);
402 3 return EXIT_FAILURE;
403 }
404
405
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 const auto inferred_dims = kai::benchmark::InferDwConvOutputDims(*shape_opt);
406
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
3 if (!inferred_dims) {
407 std::cerr << "Invalid DWConv configuration: inferred output dimensions are non-positive. "
408 << "Check stride, padding, and dilation relative to the kernel size.\n";
409 return EXIT_FAILURE;
410 }
411
412 3 const size_t inferred_out_h = inferred_dims->height;
413 3 const size_t inferred_out_w = inferred_dims->width;
414
415 3 const auto& shape = *shape_opt;
416 12 const auto format_array = [](const auto& values) {
417 9 std::ostringstream oss;
418
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
9 oss << '[';
419
4/4
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 12 times.
✓ Branch 3 taken 3 times.
33 for (size_t i = 0; i < values.size(); ++i) {
420
4/4
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 3 times.
24 if (i != 0) {
421
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
15 oss << ", ";
422 15 }
423
2/4
✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
24 oss << values[i];
424 24 }
425
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
9 oss << ']';
426
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
9 return oss.str();
427 9 };
428
429
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
3 if (!kai::benchmark::supports_unit_stride_and_dilation(shape)) {
430 std::cerr << "Configured stride=" << format_array(shape.stride) << " dilation=" << format_array(shape.dilation)
431 << " is not supported by current DWConv micro-kernels. "
432 << "Only stride=1 and dilation=1 are available.\n";
433 return EXIT_FAILURE;
434 }
435
436
4/8
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
6 std::cerr << "Running dwconv benchmarks with input=" << shape.input_height << 'x' << shape.input_width
437
6/12
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
3 << ", output=" << inferred_out_h << 'x' << inferred_out_w << ", channels=" << shape.num_channels
438
6/12
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
3 << ", stride=" << format_array(shape.stride) << ", padding=" << format_array(shape.padding)
439
4/8
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
3 << ", dilation=" << format_array(shape.dilation) << "\n";
440
441
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 kai::benchmark::RegisterDwConvBenchmarks(shape);
442
443
5/14
✓ Branch 0 taken 1 time.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
3 std::string spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_dwconv");
444
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 ::benchmark::RunSpecifiedBenchmarks(nullptr, nullptr, spec);
445
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 ::benchmark::Shutdown();
446 3 return 0;
447 6 }
448
449 36 int main(int argc, char** argv) {
450 // Detect user-provided filter BEFORE Initialize() consumes the benchmark framework flags
451 36 const auto user_filter_opt = find_user_benchmark_filter(argc, argv);
452
453 // Check for --benchmark_list_tests in argv
454 36 bool list_tests = false;
455
2/2
✓ Branch 0 taken 135 times.
✓ Branch 1 taken 24 times.
171 for (int i = 1; i < argc; ++i) {
456
4/4
✓ Branch 0 taken 53 times.
✓ Branch 1 taken 82 times.
✓ Branch 2 taken 4 times.
✓ Branch 3 taken 41 times.
135 if (std::strstr(argv[i], "--benchmark_list_tests") == argv[i]) {
457 12 list_tests = true;
458 12 break;
459 }
460 123 }
461
462
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 ::benchmark::Initialize(&argc, argv);
463
464
4/8
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 36 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 36 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 36 times.
✗ Branch 7 not taken.
36 std::cerr << "KleidiAI version: v" << kai_get_version() << "\n";
465
466 // Determine subcommand (mode): matmul or imatmul.
467 36 enum class Mode : uint8_t { COMPAT, MATMUL, IMATMUL, DWCONV } mode = Mode::COMPAT;
468
469 static constexpr std::string_view MATMUL = "matmul";
470 static constexpr std::string_view IMATMUL = "imatmul";
471 static constexpr std::string_view DWCONV = "dwconv";
472
473
1/2
✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
36 std::vector<std::string_view> args(argv, argv + argc);
474
475
4/4
✓ Branch 0 taken 24 times.
✓ Branch 1 taken 12 times.
✓ Branch 2 taken 21 times.
✓ Branch 3 taken 3 times.
36 if (!list_tests && argc < 2) {
476
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
3 print_global_usage(argv[0]);
477 3 return EXIT_FAILURE;
478 }
479
480
4/4
✓ Branch 0 taken 30 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 21 times.
✓ Branch 3 taken 9 times.
33 if (argc >= 2 && args[1] == MATMUL) {
481 9 mode = Mode::MATMUL;
482 9 argv += 1;
483 9 argc -= 1;
484
4/4
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 12 times.
✓ Branch 3 taken 9 times.
33 } else if (argc >= 2 && args[1] == IMATMUL) {
485 9 mode = Mode::IMATMUL;
486 9 argv += 1;
487 9 argc -= 1;
488
4/4
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 3 times.
✓ Branch 3 taken 9 times.
24 } else if (argc >= 2 && args[1] == DWCONV) {
489 9 mode = Mode::DWCONV;
490 9 argv += 1;
491 9 argc -= 1;
492 9 }
493
494
2/2
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 21 times.
33 if (list_tests) {
495 12 std::string spec;
496
2/2
✓ Branch 0 taken 9 times.
✓ Branch 1 taken 3 times.
12 if (mode == Mode::COMPAT) {
497
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 kai::benchmark::RegisterMatMulBenchmarks({1, 1, 1}, 32);
498
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 kai::benchmark::RegisteriMatMulBenchmarks(1, 1, 1, 1);
499
5/10
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
12 kai::benchmark::RegisterDwConvBenchmarks({3, 3, 1});
500 spec = user_filter_opt.value_or("");
501 } else if (mode == Mode::MATMUL) {
502 kai::benchmark::RegisterMatMulBenchmarks({1, 1, 1}, 32);
503 spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_matmul");
504 } else if (mode == Mode::IMATMUL) {
505 kai::benchmark::RegisteriMatMulBenchmarks(1, 1, 1, 1);
506 spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_imatmul");
507 } else if (mode == Mode::DWCONV) {
508 kai::benchmark::RegisterDwConvBenchmarks({3, 3, 1});
509 spec = user_filter_opt.has_value() ? *user_filter_opt : std::string("^kai_dwconv");
510 }
511 ::benchmark::RunSpecifiedBenchmarks(nullptr, nullptr, spec);
512 ::benchmark::Shutdown();
513 return 0;
514 }
515
516 switch (mode) {
517 case Mode::COMPAT:
518 return run_matmul(argc, argv, true, user_filter_opt);
519 case Mode::MATMUL:
520 return run_matmul(argc, argv, false, user_filter_opt);
521 case Mode::IMATMUL:
522 return run_imatmul(argc, argv, user_filter_opt);
523 case Mode::DWCONV:
524 return run_dwconv(argc, argv, user_filter_opt);
525 default:
526 print_global_usage(argv[0]);
527 return EXIT_FAILURE;
528 }
529 }
530