benchmark/imatmul/imatmul_runner.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #pragma once | ||
| 8 | |||
| 9 | #include <cfloat> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <test/common/data_type.hpp> | ||
| 13 | |||
| 14 | #include "imatmul_interface.hpp" | ||
| 15 | #include "kai/kai_common.h" | ||
| 16 | |||
| 17 | namespace kai::benchmark { | ||
| 18 | |||
| 19 | using DataType = test::DataType; | ||
| 20 | |||
| 21 | /// Runner for the indirect matrix multiplication micro-kernel (imatmul). | ||
| 22 | /// | ||
| 23 | /// Prepares and executes the run method of the imatmul micro-kernel. | ||
| 24 | /// | ||
| 25 | /// @tparam IndirectMatMulInterface Interface of the indirect matrix multiplication micro-kernel. | ||
| 26 | template <typename IndirectMatMulInterface> | ||
| 27 | class ImatmulRunner { | ||
| 28 | public: | ||
| 29 | /// Constructs an ImatmulRunner object. | ||
| 30 | /// | ||
| 31 | /// @param imatmul_interface Abstraction containing the micro-kernel to run. | ||
| 32 | /// @param dst_type Output type of the micro-kernel. Required for the micro-kernel to make certain assumptions | ||
| 33 | /// internally about the stride of the data. | ||
| 34 | 12 | ImatmulRunner(const IndirectMatMulInterface& imatmul_interface, const DataType dst_type) : | |
| 35 | 12 | m_imatmul_interface(imatmul_interface), m_dst_type(dst_type) { | |
| 36 | 6 | set_mnk_chunked(m_m, m_n, m_k_chunk_count, m_k_chunk_length); | |
| 37 | 12 | } | |
| 38 | |||
| 39 | /// Sets the M, N and chunked K dimensions for imatmul micro-kernels. | ||
| 40 | /// | ||
| 41 | /// @param m Number of rows in the LHS and DST matrices. | ||
| 42 | /// @param n Number of columns in the RHS and DST matrices. | ||
| 43 | /// @param k_chunk_count Number of K chunks (for chunked K dimension). | ||
| 44 | /// @param k_chunk_length Length of each K chunk. | ||
| 45 | 12 | void set_mnk_chunked(size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length) { | |
| 46 | 12 | m_m = m; | |
| 47 | 12 | m_n = n; | |
| 48 | 12 | m_k_chunk_count = k_chunk_count; | |
| 49 | 12 | m_k_chunk_length = k_chunk_length; | |
| 50 | 12 | m_dst_stride_row = m_n * data_type_size_in_bits(m_dst_type) / 8; | |
| 51 | 12 | m_dst_stride_col = data_type_size_in_bits(m_dst_type) / 8; | |
| 52 | 12 | } | |
| 53 | |||
| 54 | /// Runs the indirect matrix multiplication micro-kernel. | ||
| 55 | /// | ||
| 56 | /// @param lhs Buffer containing LHS matrix data. | ||
| 57 | /// @param rhs Buffer containing RHS matrix data. | ||
| 58 | /// @param dst Destination buffer to write to. | ||
| 59 | void run(const void* lhs, const void* rhs, void* dst); | ||
| 60 | |||
| 61 | private: | ||
| 62 | IndirectMatMulInterface m_imatmul_interface = {}; | ||
| 63 | DataType m_dst_type = DataType::FP32; | ||
| 64 | 6 | size_t m_m = 1; | |
| 65 | 6 | size_t m_n = 1; | |
| 66 | 6 | size_t m_k_chunk_count = 1; | |
| 67 | 6 | size_t m_k_chunk_length = 1; | |
| 68 | 6 | size_t m_dst_stride_row = 1; | |
| 69 | 6 | size_t m_dst_stride_col = 1; | |
| 70 | }; | ||
| 71 | |||
| 72 | /// Default run method for imatmul micro-kernels (ImatmulBaseInterface) | ||
| 73 | template <typename IndirectMatMulInterface> | ||
| 74 | 4 | void ImatmulRunner<IndirectMatMulInterface>::run(const void* lhs, const void* rhs, void* dst) { | |
| 75 | 8 | m_imatmul_interface.run_imatmul( | |
| 76 | 4 | m_m, m_n, m_k_chunk_count, m_k_chunk_length, lhs, rhs, dst, m_dst_stride_row, -FLT_MAX, FLT_MAX); | |
| 77 | 4 | } | |
| 78 | |||
| 79 | /// Specialized run method for static quantization interface (ImatmulStaticQuantInterface) | ||
| 80 | template <> | ||
| 81 | 2 | inline void ImatmulRunner<ImatmulStaticQuantInterface>::run(const void* lhs, const void* rhs, void* dst) { | |
| 82 | 2 | constexpr kai_matmul_requantize32_params params = {INT8_MIN, INT8_MAX, 0}; | |
| 83 | 4 | m_imatmul_interface.run_imatmul( | |
| 84 | 2 | m_m, m_n, m_k_chunk_count, m_k_chunk_length, lhs, rhs, dst, m_dst_stride_row, &params); | |
| 85 | 2 | } | |
| 86 | |||
| 87 | } // namespace kai::benchmark | ||
| 88 |