test/common/matmul_test_common.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | #pragma once | ||
| 7 | |||
| 8 | #include <cstddef> | ||
| 9 | #include <cstdint> | ||
| 10 | #include <functional> | ||
| 11 | #include <iosfwd> | ||
| 12 | #include <string> | ||
| 13 | #include <string_view> | ||
| 14 | #include <tuple> | ||
| 15 | |||
| 16 | #include "kai/kai_common.h" | ||
| 17 | #include "test/common/buffer.hpp" | ||
| 18 | #include "test/common/data_format.hpp" | ||
| 19 | #include "test/common/float16.hpp" | ||
| 20 | #include "test/common/matrix_portion.hpp" | ||
| 21 | #include "test/common/seed.hpp" | ||
| 22 | |||
| 23 | namespace kai::test { | ||
| 24 | /// Matrix multiplication shape. | ||
| 25 | struct MatMulShape { | ||
| 26 | size_t m; ///< LHS height. | ||
| 27 | size_t n; ///< RHS width. | ||
| 28 | size_t k; ///< LHS width and RHS height. | ||
| 29 | |||
| 30 | struct Hash { | ||
| 31 | 59329 | size_t operator()(const MatMulShape& shape) const { | |
| 32 | 59329 | return // | |
| 33 | 118658 | (std::hash<size_t>{}(shape.m) << 0) ^ // | |
| 34 | 118658 | (std::hash<size_t>{}(shape.n) << 1) ^ // | |
| 35 | 59329 | (std::hash<size_t>{}(shape.k) << 2); // | |
| 36 | } | ||
| 37 | }; | ||
| 38 | |||
| 39 | private: | ||
| 40 | 50910 | friend bool operator==(const MatMulShape& lhs, const MatMulShape& rhs) { | |
| 41 | 50910 | return // | |
| 42 |
2/2✓ Branch 0 taken 44994 times.
✓ Branch 1 taken 5916 times.
|
50910 | lhs.m == rhs.m && // |
| 43 |
2/2✓ Branch 0 taken 1308 times.
✓ Branch 1 taken 43686 times.
|
44994 | lhs.n == rhs.n && // |
| 44 | 43686 | lhs.k == rhs.k; | |
| 45 | } | ||
| 46 | friend std::ostream& operator<<(std::ostream& os, const MatMulShape& shape); | ||
| 47 | }; | ||
| 48 | |||
| 49 | /// Value range | ||
| 50 | template <typename T> | ||
| 51 | struct Range { | ||
| 52 | T min; | ||
| 53 | T max; | ||
| 54 | |||
| 55 | [[nodiscard]] T range() const { | ||
| 56 | return max - min; | ||
| 57 | } | ||
| 58 | }; | ||
| 59 | |||
| 60 | // NOLINTBEGIN(misc-non-private-member-variables-in-classes) | ||
| 61 | |||
| 62 | /// Matrix multiplication method. | ||
| 63 |
34/68✓ Branch 0 taken 115458 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 115458 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 115458 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 115458 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 115458 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 115458 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 115458 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 115458 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 115458 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 115458 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 115458 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 115458 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 115458 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 115458 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 115458 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 115458 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 115458 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 115458 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 115458 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 115458 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 115458 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 115458 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 115458 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 115458 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 115458 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 115458 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 115458 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 115458 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 115458 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 115458 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 115458 times.
✗ Branch 61 not taken.
✓ Branch 62 taken 115458 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 115458 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 115458 times.
✗ Branch 67 not taken.
|
115458 | struct MatMulMethod { |
| 64 | std::string_view name{}; ///< Name of matmul method. | ||
| 65 | |||
| 66 | size_t m0{0}; ///< Block size in M dimension. | ||
| 67 | size_t n0{0}; ///< Block size in N dimension. | ||
| 68 | size_t k0{0}; ///< Block size in K dimension. | ||
| 69 | |||
| 70 | DataFormat dst_format{}; ///< Data format of the destination matrix. | ||
| 71 | DataFormat lhs_format{}; ///< Data format of the LHS matrix. | ||
| 72 | DataFormat packed_lhs_format{}; ///< Data format of the packed LHS matrix. | ||
| 73 | DataFormat rhs_format{}; ///< Data format of the RHS matrix. | ||
| 74 | DataFormat packed_rhs_format{}; ///< Data format of the packed RHS matrix. | ||
| 75 | DataFormat bias_format{}; ///< Data format of the bias vector. | ||
| 76 | bool nb_support{}; ///< Does the kernel support null_bias. | ||
| 77 | |||
| 78 | /// Generate LHS matrix. | ||
| 79 | /// | ||
| 80 | /// @param[in] m Number of rows in the LHS matrix. | ||
| 81 | /// @param[in] k Number of columns in the LHS matrix. | ||
| 82 | /// @param[in] seed_feed Seed feed for random number generation. | ||
| 83 | /// | ||
| 84 | /// @return LHS matrix data buffer. | ||
| 85 | std::function<Buffer(size_t, size_t, SeedFeed&)> fn_generate_lhs{nullptr}; | ||
| 86 | |||
| 87 | /// Generate RHS matrix. | ||
| 88 | /// | ||
| 89 | /// @param[in] k Number of rows in the RHS matrix. | ||
| 90 | /// @param[in] n Number of columns in the RHS matrix. | ||
| 91 | /// @param[in] seed_feed Seed feed for random number generation. | ||
| 92 | /// | ||
| 93 | /// @return RHS matrix data buffer. | ||
| 94 | std::function<Buffer(size_t, size_t, SeedFeed&)> fn_generate_rhs{nullptr}; | ||
| 95 | |||
| 96 | /// Generate bias. | ||
| 97 | /// | ||
| 98 | /// @param[in] n Number of rows in the bias. | ||
| 99 | /// @param[in] k Number of columns in the bias. | ||
| 100 | /// @param[in] seed_feed Seed feed for random number generation. | ||
| 101 | /// @param[in] null_bias_mode Whether to generate null bias (true) or real bias (false). | ||
| 102 | /// | ||
| 103 | /// @return Bias data buffer. | ||
| 104 | std::function<Buffer(size_t, size_t, SeedFeed&, bool)> fn_generate_bias{nullptr}; | ||
| 105 | |||
| 106 | /// Check if CPU supports required features. | ||
| 107 | /// | ||
| 108 | /// @return Supported (true) or not supported (false). | ||
| 109 | std::function<bool(void)> fn_is_supported{nullptr}; | ||
| 110 | |||
| 111 | /// Gets mr value. | ||
| 112 | /// | ||
| 113 | /// This is the packing parameter which must be used to pack the LHS matrix (if necessary). | ||
| 114 | /// | ||
| 115 | /// @return The mr value. | ||
| 116 | std::function<size_t(void)> fn_get_mr{nullptr}; | ||
| 117 | |||
| 118 | /// Gets nr value. | ||
| 119 | /// | ||
| 120 | /// This is the packing parameter which must be used to pack the RHS matrix (if necessary). | ||
| 121 | /// | ||
| 122 | /// @return The nr value. | ||
| 123 | std::function<size_t(void)> fn_get_nr{nullptr}; | ||
| 124 | |||
| 125 | /// Gets kr value. | ||
| 126 | /// | ||
| 127 | /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary). | ||
| 128 | /// | ||
| 129 | /// @return The kr value. | ||
| 130 | std::function<size_t(void)> fn_get_kr{nullptr}; | ||
| 131 | |||
| 132 | /// Gets sr value. | ||
| 133 | /// | ||
| 134 | /// This is the packing parameter which must be used to pack the RHS matrix. | ||
| 135 | /// | ||
| 136 | /// @return The sr value. | ||
| 137 | std::function<size_t(void)> fn_get_sr{nullptr}; | ||
| 138 | |||
| 139 | /// Gets m step value for main kernel. | ||
| 140 | /// | ||
| 141 | /// The starting row index must be divisible by `m_step`. | ||
| 142 | /// | ||
| 143 | /// @return The m step value. | ||
| 144 | std::function<size_t(void)> fn_get_main_m_step{nullptr}; | ||
| 145 | |||
| 146 | /// Gets n step value for RHS packing micro-kernel. | ||
| 147 | /// | ||
| 148 | /// The starting row index must be divisible by `n_step`. | ||
| 149 | /// | ||
| 150 | /// @return The n step value. | ||
| 151 | std::function<size_t(void)> fn_get_pack_rhs_n_step{nullptr}; | ||
| 152 | |||
| 153 | /// Gets n step value for main kernel. | ||
| 154 | /// | ||
| 155 | /// The starting column index must be divisible by `n_step`. | ||
| 156 | /// | ||
| 157 | /// @return The n step value. | ||
| 158 | std::function<size_t(void)> fn_get_main_n_step{nullptr}; | ||
| 159 | |||
| 160 | /// Gets the offset in bytes of the LHS matrix. | ||
| 161 | /// | ||
| 162 | /// @param[in] m_idx Coordinate of the matrix in M dimension. | ||
| 163 | /// @param[in] stride Row stride in bytes. | ||
| 164 | /// | ||
| 165 | /// @return The offset in bytes. | ||
| 166 | std::function<size_t(size_t m_idx, size_t stride)> fn_get_lhs_offset{nullptr}; | ||
| 167 | |||
| 168 | /// Gets the size in bytes of the packed LHS matrix. | ||
| 169 | /// | ||
| 170 | /// @param[in] m Number of rows in the unpacked LHS matrix. | ||
| 171 | /// @param[in] k Number of columns in the unpacked LHS matrix. | ||
| 172 | /// @param[in] mr Number of rows to be interleaved. | ||
| 173 | /// @param[in] kr Unused. Must be 1. | ||
| 174 | /// @param[in] sr Unused. Must be 1. | ||
| 175 | /// | ||
| 176 | /// @return The size in bytes. | ||
| 177 | std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> fn_get_packed_lhs_size{nullptr}; | ||
| 178 | |||
| 179 | /// Gets the offset in bytes of the packed LHS matrix. | ||
| 180 | /// | ||
| 181 | /// @param[in] m_idx Coordinate of the matrix in M dimension. | ||
| 182 | /// @param[in] k Size of the matrix in K dimension. | ||
| 183 | /// | ||
| 184 | /// @return The offset in bytes. | ||
| 185 | std::function<size_t(size_t m_idx, size_t k)> fn_get_packed_lhs_offset{nullptr}; | ||
| 186 | |||
| 187 | /// Preprocesses the LHS matrix. | ||
| 188 | /// | ||
| 189 | /// @param[in] m Number of rows of the unpacked LHS matrix. | ||
| 190 | /// @param[in] k Common dimension between the LHS and RHS matrix. | ||
| 191 | /// @param[in] mr Block size in M dimension. It must be the value returned by `fn_get_mr`. | ||
| 192 | /// @param[in] kr Block size in K dimension. It must be the value returned by `fn_get_kr`. | ||
| 193 | /// @param[in] sr Number of kr splits. It must be 1. | ||
| 194 | /// @param[in] m_idx_start Unused. Must be 0. | ||
| 195 | /// @param[in] lhs LHS matrix data buffer. | ||
| 196 | /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. | ||
| 197 | /// @param[out] lhs_packed Packed LHS matrix. | ||
| 198 | std::function<void( | ||
| 199 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||
| 200 | void* lhs_packed)> | ||
| 201 | fn_pack_lhs{nullptr}; | ||
| 202 | |||
| 203 | /// Gets a value indicating whether LHS packing is needed. | ||
| 204 | 4368 | [[nodiscard]] bool is_pack_lhs_needed() const { | |
| 205 | 4368 | return fn_pack_lhs != nullptr; | |
| 206 | } | ||
| 207 | |||
| 208 | /// Gets the offset in bytes of the RHS matrix. | ||
| 209 | /// | ||
| 210 | /// @param[in] n_idx Coordinate of the matrix in N dimension. | ||
| 211 | /// | ||
| 212 | /// @return The offset in bytes. | ||
| 213 | std::function<size_t(size_t n_idx)> fn_get_rhs_offset{nullptr}; | ||
| 214 | |||
| 215 | /// Gets the size in bytes of the packed RHS matrix. | ||
| 216 | /// | ||
| 217 | /// @param[in] n Size of the matrix in N dimension. | ||
| 218 | /// @param[in] k Size of the matrix in K dimension. | ||
| 219 | /// | ||
| 220 | /// @return The size in bytes. | ||
| 221 | std::function<size_t(size_t n, size_t k)> fn_get_packed_rhs_size{nullptr}; | ||
| 222 | |||
| 223 | /// Gets the size in bytes of the packed RHS matrix for the given block sizes. | ||
| 224 | /// | ||
| 225 | /// @param[in] n Size of the matrix in N dimension. | ||
| 226 | /// @param[in] k Size of the matrix in K dimension. | ||
| 227 | /// @param[in] nr Block size in N dimension. | ||
| 228 | /// @param[in] kr Block size in K dimension. | ||
| 229 | /// | ||
| 230 | /// @return The size in bytes. | ||
| 231 | std::function<size_t(size_t n, size_t k, size_t nr, size_t kr)> fn_get_packed_rhs_size_generic_block_size = nullptr; | ||
| 232 | |||
| 233 | /// Gets the offset in bytes of the packed RHS matrix in the RHS packing micro-kernel. | ||
| 234 | /// | ||
| 235 | /// @param[in] n_idx Coordinate of the matrix in N dimension. | ||
| 236 | /// @param[in] k Size of the matrix in K dimension. | ||
| 237 | /// | ||
| 238 | /// @return The offset in bytes. | ||
| 239 | std::function<size_t(size_t n_idx, size_t k)> fn_get_pack_rhs_packed_rhs_offset{nullptr}; | ||
| 240 | |||
| 241 | /// Gets the offset in bytes of the packed RHS matrix in the main kernel. | ||
| 242 | /// | ||
| 243 | /// @param[in] n_idx Coordinate of the matrix in N dimension. | ||
| 244 | /// @param[in] k Size of the matrix in K dimension. | ||
| 245 | /// | ||
| 246 | /// @return The offset in bytes. | ||
| 247 | std::function<size_t(size_t n_idx, size_t k)> fn_get_main_packed_rhs_offset{nullptr}; | ||
| 248 | |||
| 249 | std::function<void( | ||
| 250 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, | ||
| 251 | const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)> | ||
| 252 | fn_pack_rhs{nullptr}; | ||
| 253 | |||
| 254 | /// Gets n step value. | ||
| 255 | /// | ||
| 256 | /// The starting row index must be divisible by `n_step`. | ||
| 257 | /// | ||
| 258 | /// @return The n step value. | ||
| 259 | std::function<size_t()> fn_pack_rhs_nxk_get_n_step{nullptr}; | ||
| 260 | |||
| 261 | /// Gets the offset in bytes to the data element in the RHS matrix buffer. | ||
| 262 | /// | ||
| 263 | /// @param[in] n_idx Column index. | ||
| 264 | /// @param[in] rhs_stride Row stride in bytes of the RHS matrix. | ||
| 265 | /// | ||
| 266 | /// @return The offset in bytes to the data element. | ||
| 267 | std::function<size_t(size_t n_idx, size_t rhs_stride)> fn_pack_rhs_nxk_get_rhs_offset{nullptr}; | ||
| 268 | |||
| 269 | /// Gets the offset in bytes to the data element in the bias buffer. | ||
| 270 | /// | ||
| 271 | /// @param[in] n_idx Column index. | ||
| 272 | /// | ||
| 273 | /// @return The offset in bytes to the data element. | ||
| 274 | std::function<size_t(size_t n_idx)> fn_pack_rhs_nxk_get_bias_offset{nullptr}; | ||
| 275 | |||
| 276 | /// Gets the offset in bytes to the data element in the packed RHS buffer. | ||
| 277 | /// | ||
| 278 | /// @param[in] n_idx Row index. | ||
| 279 | /// @param[in] k Number of columns. | ||
| 280 | /// | ||
| 281 | /// @return The offset in bytes to the data element. | ||
| 282 | std::function<size_t(size_t n_idx, size_t k)> fn_pack_rhs_nxk_get_packed_rhs_offset{nullptr}; | ||
| 283 | |||
| 284 | /// Gets the size in bytes of the packed RHS buffer. | ||
| 285 | /// | ||
| 286 | /// @param[in] n Number of rows. | ||
| 287 | /// @param[in] k Number of columns. | ||
| 288 | /// | ||
| 289 | /// @return The size in bytes of the packed RHS buffer. | ||
| 290 | std::function<size_t(size_t n, size_t k)> fn_pack_rhs_nxk_get_packed_rhs_size{nullptr}; | ||
| 291 | |||
| 292 | /// Runs the RHS packing micro-kernel for matrix multiplication. | ||
| 293 | /// | ||
| 294 | /// The pointer to each buffer (RHS, bias and packed RHS) must be advanced by the offset | ||
| 295 | /// calculated using the following functions: | ||
| 296 | /// | ||
| 297 | /// * RHS: @ref kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme. | ||
| 298 | /// * Bias: @ref kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme. | ||
| 299 | /// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1b_f32_f32_sme. | ||
| 300 | /// | ||
| 301 | /// @param[in] num_groups Number of groups. It must be 1. | ||
| 302 | /// @param[in] n Number of columns of the output matrix. | ||
| 303 | /// @param[in] k Common dimension between the LHS and RHS matrix. | ||
| 304 | /// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u32(). | ||
| 305 | /// @param[in] kr Block size in K dimension. It must be 1. | ||
| 306 | /// @param[in] sr Number of kr splits. It must be 1. | ||
| 307 | /// @param[in] rhs_stride Row stride in bytes of the RHS matrix. | ||
| 308 | /// @param[in] rhs RHS matrix data buffer. | ||
| 309 | /// @param[in] bias Bias matrix data buffer. | ||
| 310 | /// @param[in] scale Scale data buffer. It must be NULL. | ||
| 311 | /// @param[out] rhs_packed Packed RHS matrix. | ||
| 312 | /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. | ||
| 313 | /// @param[in] params Extra packing parameters. It must be NULL. | ||
| 314 | std::function<void( | ||
| 315 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, | ||
| 316 | const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)> | ||
| 317 | fn_pack_rhs_nxk{nullptr}; | ||
| 318 | |||
| 319 | /// Gets the offset in bytes to the data element in the bias buffer. | ||
| 320 | /// | ||
| 321 | /// @param[in] n_idx Column index. | ||
| 322 | /// | ||
| 323 | /// @return The offset in bytes to the data element. | ||
| 324 | std::function<size_t(size_t n_idx)> fn_get_bias_offset{nullptr}; | ||
| 325 | |||
| 326 | /// Gets the offset in bytes to the data element in the destination matrix buffer. | ||
| 327 | /// | ||
| 328 | /// @param[in] m_idx Row index. | ||
| 329 | /// @param[in] n_idx Column index. | ||
| 330 | /// @param[in] stride Row stride in bytes. | ||
| 331 | /// | ||
| 332 | /// @return The offset in bytes to the data element. | ||
| 333 | std::function<size_t(size_t m_idx, size_t n_idx, size_t stride)> fn_get_dst_offset{nullptr}; | ||
| 334 | |||
| 335 | /// Gets the size in bytes of the destination matrix buffer. | ||
| 336 | /// | ||
| 337 | /// @param[in] m Number of rows. | ||
| 338 | /// @param[in] n Number of columns. | ||
| 339 | /// | ||
| 340 | /// @return The size in bytes of the destination matrix buffer. | ||
| 341 | std::function<size_t(size_t m, size_t n)> fn_get_dst_size{nullptr}; | ||
| 342 | |||
| 343 | /// Performs F16 or F32 matrix multiplication with RHS packing | ||
| 344 | /// followed by clamp operation. | ||
| 345 | /// | ||
| 346 | /// @param[in] m Size of the matrix in M dimension. | ||
| 347 | /// @param[in] n Size of the matrix in N dimension. | ||
| 348 | /// @param[in] k Size of the matrix in K dimension. | ||
| 349 | /// @param[in] lhs LHS data buffer. | ||
| 350 | /// @param[in] packed_rhs Packed RHS data buffer. | ||
| 351 | /// @param[out] dst Output data buffer. | ||
| 352 | /// @param[in] lhs_stride LHS row stride. | ||
| 353 | /// @param[in] dst_stride_row Output row stride. | ||
| 354 | /// @param[in] dst_stride_col Output column stride. | ||
| 355 | /// @param[in] clamp_min Lower bound of the output data. | ||
| 356 | /// @param[in] clamp_max Upper bound of the output data. | ||
| 357 | std::function<void( | ||
| 358 | size_t m, size_t n, size_t k, // | ||
| 359 | const void* lhs, size_t lhs_stride, // | ||
| 360 | const void* packed_rhs, // | ||
| 361 | void* dst, size_t dst_stride_row, size_t dst_stride_col, // | ||
| 362 | float clamp_min, float clamp_max)> | ||
| 363 | fn_matmul_f16_f16_f16p = nullptr; | ||
| 364 | |||
| 365 | std::function<void( | ||
| 366 | size_t m, size_t n, size_t k, // | ||
| 367 | const void* lhs, size_t lhs_stride, // | ||
| 368 | const void* packed_rhs, // | ||
| 369 | void* dst, size_t dst_stride_row, size_t dst_stride_col, // | ||
| 370 | float clamp_min, float clamp_max)> | ||
| 371 | fn_matmul_f32_f32_f32p = nullptr; | ||
| 372 | |||
| 373 | /// Performs BF16 matrix multiplication with LHS and RHS packing | ||
| 374 | /// followed by clamp operation. | ||
| 375 | /// | ||
| 376 | /// @param[in] m Size of the matrix in M dimension. | ||
| 377 | /// @param[in] n Size of the matrix in N dimension. | ||
| 378 | /// @param[in] k Size of the matrix in K dimension. | ||
| 379 | /// @param[in] packed_lhs Packed LHS data buffer. | ||
| 380 | /// @param[in] packed_rhs Packed RHS data buffer. | ||
| 381 | /// @param[out] dst Output data buffer. | ||
| 382 | /// @param[in] dst_stride_row Output row stride. | ||
| 383 | /// @param[in] dst_stride_col Output column stride. | ||
| 384 | /// @param[in] clamp_min Lower bound of the output data. | ||
| 385 | /// @param[in] clamp_max Upper bound of the output data. | ||
| 386 | std::function<void( | ||
| 387 | size_t m, size_t n, size_t k, // | ||
| 388 | const void* packed_lhs, // | ||
| 389 | const void* packed_rhs, // | ||
| 390 | void* dst, size_t dst_stride_row, size_t dst_stride_col, // | ||
| 391 | float clamp_min, float clamp_max)> | ||
| 392 | fn_matmul_f32_bf16p_bf16p = nullptr; | ||
| 393 | |||
| 394 | std::function<void( | ||
| 395 | size_t m, size_t n, size_t k, // | ||
| 396 | const void* packed_lhs, // | ||
| 397 | const void* packed_rhs, // | ||
| 398 | void* dst, size_t dst_stride_row, size_t dst_stride_col, // | ||
| 399 | float clamp_min, float clamp_max)> | ||
| 400 | fn_matmul_f16_bf16p_bf16p = nullptr; | ||
| 401 | |||
| 402 | /// Performs F16 or F32 matrix multiplication with LHS & RHS packing | ||
| 403 | /// followed by clamp operation. | ||
| 404 | /// | ||
| 405 | /// @param[in] m Number of output rows to be computed. | ||
| 406 | /// @param[in] n Number of output columns to be computed. | ||
| 407 | /// @param[in] k Common dimension of the LHS and RHS operands. | ||
| 408 | /// @param[in] packed_lhs Packed LHS matrix buffer. | ||
| 409 | /// @param[in] packed_rhs Packed RHS matrix buffer. | ||
| 410 | /// @param[out] dst Output matrix buffer. | ||
| 411 | /// @param[in] dst_stride_row Row stride in bytes of the output matrix. | ||
| 412 | /// @param[in] dst_stride_col Column stride in bytes of the output matrix. | ||
| 413 | /// @param[in] clamp_min Minimum value to clamp the final result. | ||
| 414 | /// @param[in] clamp_max Maximum value to clamp the final result. | ||
| 415 | std::function<void( | ||
| 416 | size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, | ||
| 417 | size_t dst_stride_col, float clamp_min, float clamp_max)> | ||
| 418 | fn_matmul_f16_f16p_f16p = nullptr; | ||
| 419 | |||
| 420 | std::function<void( | ||
| 421 | size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, | ||
| 422 | size_t dst_stride_col, float clamp_min, float clamp_max)> | ||
| 423 | fn_matmul_f32_f32p_f32p = nullptr; | ||
| 424 | |||
| 425 | /// Gets a value indicating whether pre-processing the RHS matrix is needed. | ||
| 426 | 4368 | [[nodiscard]] bool is_pack_rhs_needed() const { | |
| 427 | 4368 | return fn_pack_rhs != nullptr; | |
| 428 | } | ||
| 429 | |||
| 430 | /// Gets a value indicating whether pre-processing the transposed RHS matrix is needed. | ||
| 431 | 2184 | [[nodiscard]] bool is_pack_rhs_nxk_needed() const { | |
| 432 | 2184 | return fn_pack_rhs_nxk != nullptr; | |
| 433 | } | ||
| 434 | |||
| 435 | /// Preprocesses the RHS matrix. | ||
| 436 | /// | ||
| 437 | /// @param[in] n Size of the matrix in N dimension. | ||
| 438 | /// @param[in] k Size of the matrix in K dimension. | ||
| 439 | /// @param[in] rhs RHS data buffer. | ||
| 440 | /// @param[in] rhs_row_stride RHS row stride. | ||
| 441 | /// @param[in] bias Bias data buffer. | ||
| 442 | /// @param[in] scale Quantization scales data buffer. | ||
| 443 | /// @param[out] packed_rhs Packed RHS data buffer. | ||
| 444 | 2706 | void pack_rhs( | |
| 445 | size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, | ||
| 446 | void* packed_rhs) const { | ||
| 447 | KAI_UNUSED(scale); | ||
| 448 | |||
| 449 |
1/2✓ Branch 0 taken 2706 times.
✗ Branch 1 not taken.
|
2706 | if (fn_pack_rhs != nullptr) { |
| 450 | 5412 | fn_pack_rhs( | |
| 451 | 2706 | 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, | |
| 452 | nullptr); | ||
| 453 | 2706 | } else { | |
| 454 | − | KAI_ERROR("RHS pre-processing is not supported!"); | |
| 455 | } | ||
| 456 | 2706 | } | |
| 457 | |||
| 458 | /// Preprocesses the transposed RHS matrix. | ||
| 459 | /// | ||
| 460 | /// @param[in] n Size of the matrix in N dimension. | ||
| 461 | /// @param[in] k Size of the matrix in K dimension. | ||
| 462 | /// @param[in] rhs RHS data buffer. | ||
| 463 | /// @param[in] rhs_row_stride RHS row stride. | ||
| 464 | /// @param[in] bias Bias data buffer. | ||
| 465 | /// @param[in] scale Quantization scales data buffer. | ||
| 466 | /// @param[out] packed_rhs Packed RHS data buffer. | ||
| 467 | 216 | void pack_rhs_nxk( | |
| 468 | size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale, | ||
| 469 | void* packed_rhs) const { | ||
| 470 | KAI_UNUSED(scale); | ||
| 471 | |||
| 472 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | if (fn_pack_rhs_nxk != nullptr) { |
| 473 | 432 | fn_pack_rhs_nxk( | |
| 474 | 216 | 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0, | |
| 475 | nullptr); | ||
| 476 | 216 | } else { | |
| 477 | − | KAI_ERROR("RHS pre-processing is not supported!"); | |
| 478 | } | ||
| 479 | 216 | } | |
| 480 | |||
| 481 | 4650 | [[nodiscard]] bool has_main_kernel() const { | |
| 482 |
2/2✓ Branch 0 taken 3798 times.
✓ Branch 1 taken 852 times.
|
8448 | return fn_matmul_f16_f16_f16p != nullptr || // |
| 483 |
2/2✓ Branch 0 taken 3690 times.
✓ Branch 1 taken 108 times.
|
3798 | fn_matmul_f16_f16p_f16p != nullptr || // |
| 484 |
2/2✓ Branch 0 taken 3582 times.
✓ Branch 1 taken 108 times.
|
3690 | fn_matmul_f32_f32p_f32p != nullptr || // |
| 485 |
2/2✓ Branch 0 taken 2466 times.
✓ Branch 1 taken 1116 times.
|
3582 | fn_matmul_f32_f32_f32p != nullptr || // |
| 486 |
2/2✓ Branch 0 taken 1926 times.
✓ Branch 1 taken 540 times.
|
2466 | fn_matmul_f32_bf16p_bf16p != nullptr || // |
| 487 | 540 | fn_matmul_f16_bf16p_bf16p != nullptr; | |
| 488 | } | ||
| 489 | |||
| 490 | 4650 | void main_kernel( | |
| 491 | size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride, | ||
| 492 | size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const { | ||
| 493 | KAI_UNUSED(bias); | ||
| 494 | KAI_UNUSED(rhs_stride); | ||
| 495 | |||
| 496 |
2/2✓ Branch 0 taken 852 times.
✓ Branch 1 taken 3798 times.
|
4650 | if (fn_matmul_f16_f16_f16p) { |
| 497 | 1704 | fn_matmul_f16_f16_f16p( | |
| 498 | 852 | m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(uint16_t), clamp_min, clamp_max); | |
| 499 |
2/2✓ Branch 0 taken 1116 times.
✓ Branch 1 taken 2682 times.
|
4650 | } else if (fn_matmul_f32_f32_f32p) { |
| 500 | 1116 | fn_matmul_f32_f32_f32p(m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); | |
| 501 |
2/2✓ Branch 0 taken 108 times.
✓ Branch 1 taken 2574 times.
|
3798 | } else if (fn_matmul_f16_f16p_f16p) { |
| 502 | 108 | fn_matmul_f16_f16p_f16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(Float16), clamp_min, clamp_max); | |
| 503 |
2/2✓ Branch 0 taken 108 times.
✓ Branch 1 taken 2466 times.
|
2682 | } else if (fn_matmul_f32_f32p_f32p) { |
| 504 | 108 | fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); | |
| 505 |
2/2✓ Branch 0 taken 1926 times.
✓ Branch 1 taken 540 times.
|
2574 | } else if (fn_matmul_f32_bf16p_bf16p) { |
| 506 | 3852 | fn_matmul_f32_bf16p_bf16p( | |
| 507 | 1926 | m, n, k, reinterpret_cast<const uint16_t*>(lhs), rhs, reinterpret_cast<float*>(dst), dst_stride, | |
| 508 | 1926 | sizeof(float), clamp_min, clamp_max); | |
| 509 |
1/2✓ Branch 0 taken 540 times.
✗ Branch 1 not taken.
|
2466 | } else if (fn_matmul_f16_bf16p_bf16p) { |
| 510 | 540 | fn_matmul_f16_bf16p_bf16p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(uint16_t), clamp_min, clamp_max); | |
| 511 | 540 | } else { | |
| 512 | − | KAI_ERROR("Main kernel is not available!"); | |
| 513 | } | ||
| 514 | 4650 | } | |
| 515 | }; | ||
| 516 | |||
| 517 | // NOLINTEND(misc-non-private-member-variables-in-classes) | ||
| 518 | |||
| 519 | /// Describes bias handling | ||
| 520 | enum class BiasMode { | ||
| 521 | INTERNAL, // Zero bias internally generated in kernel | ||
| 522 | PROVIDED, // Bias provided by kernel caller | ||
| 523 | }; | ||
| 524 | |||
| 525 | /// Matrix multiplication test information. | ||
| 526 | using MatMulClampTestParams = std::tuple<MatMulMethod, MatMulShape, MatrixPortion, BiasMode, float>; | ||
| 527 | using MatMulClampTestPortionedParams = std::tuple<size_t, MatMulShape, MatrixPortion, float>; | ||
| 528 | using MatMulClampTestPortionedParamsWithBias = std::tuple<size_t, MatMulShape, MatrixPortion, float, bool>; | ||
| 529 | using MatMulClampTestPortionedParamsWithBias_WithBL = | ||
| 530 | std::tuple<size_t, MatMulShape, size_t, MatrixPortion, float, bool>; | ||
| 531 | |||
| 532 | /// Prints the test information. | ||
| 533 | void PrintTo(const MatMulClampTestParams& param, std::ostream* os); | ||
| 534 | void PrintTo(const MatMulShape& shape, std::ostream* os); | ||
| 535 | void PrintTo(const MatrixPortion& portion, std::ostream* os); | ||
| 536 | void PrintTo(const BiasMode& bias_mode, std::ostream* os); | ||
| 537 | |||
| 538 | /// Generate test information | ||
| 539 | std::string test_description( | ||
| 540 | const std::string_view& name, const MatMulShape& shape, const MatrixPortion& portion, bool bias, | ||
| 541 | float clamp_keep_ratio); | ||
| 542 | |||
| 543 | } // namespace kai::test | ||
| 544 | |||
| 545 | template <> | ||
| 546 | struct std::hash<kai::test::MatMulShape> { | ||
| 547 | 27295 | size_t operator()(const kai::test::MatMulShape& ms) const { | |
| 548 | 27295 | return kai::test::MatMulShape::Hash{}(ms); | |
| 549 | } | ||
| 550 | }; | ||
| 551 |