benchmark/imatmul/imatmul_benchmark_logic.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #pragma once | ||
| 8 | |||
| 9 | #include <cstddef> | ||
| 10 | #include <cstdint> | ||
| 11 | #include <functional> | ||
| 12 | #include <test/common/cpu_info.hpp> | ||
| 13 | #include <vector> | ||
| 14 | |||
| 15 | #include "imatmul_interface.hpp" | ||
| 16 | #include "imatmul_runner.hpp" | ||
| 17 | #include "kai/kai_common.h" | ||
| 18 | |||
| 19 | #ifdef __GNUC__ | ||
| 20 | #pragma GCC diagnostic push | ||
| 21 | #pragma GCC diagnostic ignored "-Wswitch-default" | ||
| 22 | #endif // __GNUC__ | ||
| 23 | |||
| 24 | #include <benchmark/benchmark.h> | ||
| 25 | |||
| 26 | #ifdef __GNUC__ | ||
| 27 | #pragma GCC diagnostic pop | ||
| 28 | #endif // __GNUC__ | ||
| 29 | |||
| 30 | namespace kai::benchmark { | ||
| 31 | using Buffer = std::vector<uint8_t>; | ||
| 32 | using CpuRequirement = std::function<bool()>; | ||
| 33 | |||
| 34 | /// Benchmarks an indirect matrix multiplication micro-kernel. | ||
| 35 | /// | ||
| 36 | /// @tparam ImatmulInterface Interface of the indirect matrix multiplication micro-kernel. | ||
| 37 | /// @param state State for the benchmark to use. | ||
| 38 | /// @param imatmul_interface Abstraction containing the micro-kernel to run. | ||
| 39 | /// @param dst_type Output type of the micro-kernel. | ||
| 40 | /// @param is_cpu_supported Function that checks the CPU feature requirement to run this benchmark. | ||
| 41 | template <typename ImatmulInterface> | ||
| 42 | 18 | void kai_benchmark_imatmul( | |
| 43 | ::benchmark::State& state, const ImatmulInterface imatmul_interface, const DataType dst_type, | ||
| 44 | const CpuRequirement& is_cpu_supported) { | ||
| 45 |
4/4✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
✓ Branch 3 taken 2 times.
|
18 | if (!is_cpu_supported()) { |
| 46 |
4/10✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 7 not taken.
|
12 | state.SkipWithMessage("Unsupported CPU feature"); |
| 47 | 12 | return; | |
| 48 | } | ||
| 49 | |||
| 50 | 6 | const size_t m = state.range(0); | |
| 51 | 6 | const size_t n = state.range(1); | |
| 52 | 6 | const size_t k_chunk_count = state.range(2); | |
| 53 | 6 | const size_t k_chunk_length = state.range(3); | |
| 54 | 6 | const size_t k = k_chunk_count * k_chunk_length; | |
| 55 | |||
| 56 | // Create sufficiently large buffers | ||
| 57 | 6 | size_t lhs_size = m * k * sizeof(uint64_t); | |
| 58 | 6 | size_t rhs_size = n * k * sizeof(uint64_t); | |
| 59 | 6 | size_t dst_size = m * n * sizeof(uint32_t); | |
| 60 | |||
| 61 |
2/8✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
6 | if (test::cpu_has_sme() || test::cpu_has_sme2()) { |
| 62 | 6 | lhs_size *= kai_get_sme_vector_length_u32(); | |
| 63 | 6 | rhs_size *= kai_get_sme_vector_length_u32(); | |
| 64 | 6 | dst_size *= kai_get_sme_vector_length_u32(); | |
| 65 | 6 | } | |
| 66 | |||
| 67 |
0/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
6 | const Buffer lhs(lhs_size); |
| 68 |
2/6✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 3 not taken.
|
6 | const Buffer rhs(rhs_size); |
| 69 |
2/6✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 3 not taken.
|
6 | Buffer dst(dst_size); |
| 70 | |||
| 71 |
2/4✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | ImatmulRunner imatmul_runner(imatmul_interface, dst_type); |
| 72 |
2/4✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | imatmul_runner.set_mnk_chunked(m, n, k_chunk_count, k_chunk_length); |
| 73 | |||
| 74 |
12/20✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✓ Branch 7 taken 4 times.
✓ Branch 8 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✓ Branch 19 taken 2 times.
|
12 | for (auto _ : state) { |
| 75 |
2/4✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | imatmul_runner.run(lhs.data(), rhs.data(), dst.data()); |
| 76 | 6 | } | |
| 77 | 18 | } | |
| 78 | |||
| 79 | } // namespace kai::benchmark | ||
| 80 |