test/tests/matmul_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/reference/matmul.hpp" | ||
| 8 | |||
| 9 | #include <gtest/gtest.h> | ||
| 10 | |||
| 11 | #include <array> | ||
| 12 | #include <cstddef> | ||
| 13 | #include <cstdint> | ||
| 14 | #include <functional> | ||
| 15 | #include <map> | ||
| 16 | #include <string_view> | ||
| 17 | #include <tuple> | ||
| 18 | #include <utility> | ||
| 19 | |||
| 20 | #include "kai/kai_common.h" | ||
| 21 | #include "test/common/abi_checker.hpp" | ||
| 22 | #include "test/common/buffer.hpp" | ||
| 23 | #include "test/common/compare.hpp" | ||
| 24 | #include "test/common/cpu_info.hpp" | ||
| 25 | #include "test/common/data_format.hpp" | ||
| 26 | #include "test/common/data_type.hpp" | ||
| 27 | #include "test/common/matmul_test_common.hpp" | ||
| 28 | #include "test/common/matrix_portion.hpp" | ||
| 29 | #include "test/common/seed.hpp" | ||
| 30 | #include "test/common/sme.hpp" | ||
| 31 | #include "test/common/sve.hpp" | ||
| 32 | #include "test/reference/clamp.hpp" | ||
| 33 | #include "test/reference/fill.hpp" | ||
| 34 | #include "test/reference/generators.hpp" | ||
| 35 | #include "test/reference/pack.hpp" | ||
| 36 | #include "test/reference/transpose.hpp" | ||
| 37 | |||
| 38 | // matmul_clamp_f16_f16_f16p | ||
| 39 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" | ||
| 40 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h" | ||
| 41 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla.h" | ||
| 42 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla.h" | ||
| 43 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55.h" | ||
| 44 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" | ||
| 45 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p32x1b_x16_x16_neon.h" | ||
| 46 | |||
| 47 | // matmul_clamp_f16_f16p_f16p | ||
| 48 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" | ||
| 49 | #include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h" | ||
| 50 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h" | ||
| 51 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h" | ||
| 52 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h" | ||
| 53 | |||
| 54 | // matmul_clamp_f32_f32_f32p | ||
| 55 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" | ||
| 56 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla.h" | ||
| 57 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55.h" | ||
| 58 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" | ||
| 59 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h" | ||
| 60 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla.h" | ||
| 61 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" | ||
| 62 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme.h" | ||
| 63 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h" | ||
| 64 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x32p16x1b_x32_x32_neon.h" | ||
| 65 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve.h" | ||
| 66 | |||
| 67 | // matmul_clamp_f32_f32p_f32p | ||
| 68 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" | ||
| 69 | #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" | ||
| 70 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" | ||
| 71 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" | ||
| 72 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h" | ||
| 73 | |||
| 74 | namespace kai::test { | ||
| 75 | |||
| 76 | 12 | static const auto& get_matmul_methods() { | |
| 77 | // List of supported matrix multiplication methods. | ||
| 78 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
12 | static std::array<MatMulMethod, 7> matmul_methods{}; |
| 79 | |||
| 80 | matmul_methods[0].name = "matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla"; | ||
| 81 | matmul_methods[0].m0 = 6; | ||
| 82 | matmul_methods[0].n0 = 16; | ||
| 83 | matmul_methods[0].dst_format = DataFormat(DataType::FP16); | ||
| 84 | matmul_methods[0].lhs_format = DataFormat(DataType::FP16); | ||
| 85 | matmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 86 | matmul_methods[0].rhs_format = DataFormat(DataType::FP16); | ||
| 87 | matmul_methods[0].packed_rhs_format = DataFormat( | ||
| 88 | DataType::FP16, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 16, 1); | ||
| 89 | matmul_methods[0].bias_format = DataFormat(DataType::FP16); | ||
| 90 | 36 | matmul_methods[0].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 91 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return NormalRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 92 | ✗ | }; | |
| 93 | 36 | matmul_methods[0].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 94 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return NormalRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 95 | ✗ | }; | |
| 96 | 36 | matmul_methods[0].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 97 | KAI_UNUSED(null_bias_mode); | ||
| 98 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return NormalRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 99 | ✗ | }; | |
| 100 | matmul_methods[0].fn_is_supported = cpu_has_fp16; | ||
| 101 | matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
| 102 | matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
| 103 | matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
| 104 | matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
| 105 | matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
| 106 | matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
| 107 | matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
| 108 | matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
| 109 | matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
| 110 | matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = | ||
| 111 | kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
| 112 | matmul_methods[0].fn_get_main_packed_rhs_offset = | ||
| 113 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
| 114 | matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
| 115 | matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon; | ||
| 116 | matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
| 117 | matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
| 118 | matmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla; | ||
| 119 | |||
| 120 | matmul_methods[1].name = "matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa"; | ||
| 121 | matmul_methods[1].m0 = 2 * get_sme_vector_length<float>(); | ||
| 122 | matmul_methods[1].n0 = 2 * get_sme_vector_length<float>(); | ||
| 123 | matmul_methods[1].dst_format = DataFormat(DataType::FP16); | ||
| 124 | matmul_methods[1].lhs_format = DataFormat(DataType::FP16); | ||
| 125 | matmul_methods[1].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length<float>(), 2); | ||
| 126 | matmul_methods[1].rhs_format = DataFormat(DataType::FP16); | ||
| 127 | matmul_methods[1].packed_rhs_format = DataFormat( | ||
| 128 | DataType::FP16, // Output type | ||
| 129 | 2 * get_sme_vector_length<float>(), 2, // Block size | ||
| 130 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
| 131 | DataType::FP16, // Bias format | ||
| 132 | DataType::UNKNOWN, // Scaling type | ||
| 133 | 2 * get_sme_vector_length<float>(), 2); // Sub-block | ||
| 134 | matmul_methods[1].bias_format = DataFormat(DataType::FP16); | ||
| 135 | 18 | matmul_methods[1].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 136 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 137 | ✗ | }; | |
| 138 | 18 | matmul_methods[1].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 139 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 140 | ✗ | }; | |
| 141 | 18 | matmul_methods[1].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 142 | KAI_UNUSED(null_bias_mode); | ||
| 143 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 144 | ✗ | }; | |
| 145 | matmul_methods[1].fn_is_supported = cpu_has_sme2; | ||
| 146 | matmul_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 147 | matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 148 | matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 149 | matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 150 | matmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 151 | matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 152 | matmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 153 | matmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; | ||
| 154 | matmul_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme; | ||
| 155 | matmul_methods[1].fn_get_packed_lhs_offset = | ||
| 156 | kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 157 | matmul_methods[1].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme; | ||
| 158 | matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 159 | matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 160 | matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 161 | matmul_methods[1].fn_get_main_packed_rhs_offset = | ||
| 162 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 163 | matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 164 | matmul_methods[1].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 165 | matmul_methods[1].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 166 | matmul_methods[1].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 167 | matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_offset = | ||
| 168 | kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 169 | matmul_methods[1].fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 170 | matmul_methods[1].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 171 | matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 172 | matmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 173 | matmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 174 | matmul_methods[1].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 175 | |||
| 176 | matmul_methods[2].name = "matmul_nt_nt_fp32_fp32_fp32_6x8_neon_mla"; | ||
| 177 | matmul_methods[2].m0 = 6; | ||
| 178 | matmul_methods[2].n0 = 8; | ||
| 179 | matmul_methods[2].dst_format = DataFormat(DataType::FP32); | ||
| 180 | matmul_methods[2].lhs_format = DataFormat(DataType::FP32); | ||
| 181 | matmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 182 | matmul_methods[2].rhs_format = DataFormat(DataType::FP32); | ||
| 183 | matmul_methods[2].packed_rhs_format = | ||
| 184 | DataFormat(DataType::FP32, 8, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 8, 1); | ||
| 185 | matmul_methods[2].bias_format = DataFormat(DataType::FP32); | ||
| 186 | 36 | matmul_methods[2].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 187 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 188 | ✗ | }; | |
| 189 | 36 | matmul_methods[2].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 190 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 191 | ✗ | }; | |
| 192 | 36 | matmul_methods[2].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 193 | KAI_UNUSED(null_bias_mode); | ||
| 194 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 195 | ✗ | }; | |
| 196 | matmul_methods[2].fn_is_supported = cpu_has_advsimd; | ||
| 197 | matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
| 198 | matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
| 199 | matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
| 200 | matmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
| 201 | matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
| 202 | matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
| 203 | matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
| 204 | matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
| 205 | matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
| 206 | matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset = | ||
| 207 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
| 208 | matmul_methods[2].fn_get_main_packed_rhs_offset = | ||
| 209 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
| 210 | matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
| 211 | matmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon; | ||
| 212 | matmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
| 213 | matmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
| 214 | matmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla; | ||
| 215 | |||
| 216 | matmul_methods[3].name = "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"; | ||
| 217 | matmul_methods[3].m0 = 2 * get_sme_vector_length<float>(); | ||
| 218 | matmul_methods[3].n0 = 2 * get_sme_vector_length<float>(); | ||
| 219 | matmul_methods[3].dst_format = DataFormat(DataType::FP32); | ||
| 220 | matmul_methods[3].lhs_format = DataFormat(DataType::FP32); | ||
| 221 | matmul_methods[3].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length<float>(), 1); | ||
| 222 | matmul_methods[3].rhs_format = DataFormat(DataType::FP32); | ||
| 223 | matmul_methods[3].packed_rhs_format = DataFormat( | ||
| 224 | DataType::FP32, 2 * get_sme_vector_length<float>(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, | ||
| 225 | DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 1); | ||
| 226 | matmul_methods[3].bias_format = DataFormat(DataType::FP32); | ||
| 227 | 18 | matmul_methods[3].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 228 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 229 | ✗ | }; | |
| 230 | 18 | matmul_methods[3].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 231 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 232 | ✗ | }; | |
| 233 | 18 | matmul_methods[3].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 234 | KAI_UNUSED(null_bias_mode); | ||
| 235 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 236 | ✗ | }; | |
| 237 | matmul_methods[3].fn_is_supported = cpu_has_sme2; | ||
| 238 | matmul_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 239 | matmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 240 | matmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 241 | matmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 242 | matmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 243 | matmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 244 | matmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 245 | matmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme; | ||
| 246 | matmul_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme; | ||
| 247 | matmul_methods[3].fn_get_packed_lhs_offset = | ||
| 248 | kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 249 | matmul_methods[3].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme; | ||
| 250 | matmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 251 | matmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 252 | matmul_methods[3].fn_get_pack_rhs_packed_rhs_offset = | ||
| 253 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 254 | matmul_methods[3].fn_get_main_packed_rhs_offset = | ||
| 255 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 256 | matmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 257 | matmul_methods[3].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 258 | matmul_methods[3].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 259 | matmul_methods[3].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 260 | matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_offset = | ||
| 261 | kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 262 | matmul_methods[3].fn_pack_rhs_nxk_get_packed_rhs_size = | ||
| 263 | kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 264 | matmul_methods[3].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 265 | matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 266 | matmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 267 | matmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 268 | matmul_methods[3].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa; | ||
| 269 | |||
| 270 | matmul_methods[4].name = "matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa"; | ||
| 271 | matmul_methods[4].m0 = 2 * get_sme_vector_length<float>(); | ||
| 272 | matmul_methods[4].n0 = 2 * get_sme_vector_length<float>(); | ||
| 273 | matmul_methods[4].dst_format = DataFormat(DataType::FP32); | ||
| 274 | matmul_methods[4].lhs_format = DataFormat(DataType::FP32); | ||
| 275 | matmul_methods[4].packed_lhs_format = DataFormat(DataType::FP32, 2 * get_sme_vector_length<float>(), 1); | ||
| 276 | matmul_methods[4].rhs_format = DataFormat(DataType::FP32); | ||
| 277 | matmul_methods[4].packed_rhs_format = DataFormat( | ||
| 278 | DataType::FP32, 2 * get_sme_vector_length<float>(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, | ||
| 279 | DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 1); | ||
| 280 | matmul_methods[4].bias_format = DataFormat(DataType::FP32); | ||
| 281 | 18 | matmul_methods[4].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 282 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 283 | ✗ | }; | |
| 284 | 18 | matmul_methods[4].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 285 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 286 | ✗ | }; | |
| 287 | 18 | matmul_methods[4].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 288 | KAI_UNUSED(null_bias_mode); | ||
| 289 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 290 | ✗ | }; | |
| 291 | matmul_methods[4].fn_is_supported = cpu_has_sme; | ||
| 292 | matmul_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 293 | matmul_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 294 | matmul_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 295 | matmul_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 296 | matmul_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 297 | matmul_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 298 | matmul_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 299 | matmul_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme; | ||
| 300 | matmul_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme; | ||
| 301 | matmul_methods[4].fn_get_packed_lhs_offset = | ||
| 302 | kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 303 | matmul_methods[4].fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme; | ||
| 304 | matmul_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 305 | matmul_methods[4].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 306 | matmul_methods[4].fn_get_pack_rhs_packed_rhs_offset = | ||
| 307 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 308 | matmul_methods[4].fn_get_main_packed_rhs_offset = | ||
| 309 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 310 | matmul_methods[4].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 311 | matmul_methods[4].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 312 | matmul_methods[4].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 313 | matmul_methods[4].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 314 | matmul_methods[4].fn_pack_rhs_nxk_get_packed_rhs_offset = | ||
| 315 | kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 316 | matmul_methods[4].fn_pack_rhs_nxk_get_packed_rhs_size = | ||
| 317 | kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 318 | matmul_methods[4].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme; | ||
| 319 | matmul_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 320 | matmul_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 321 | matmul_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 322 | matmul_methods[4].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | ||
| 323 | |||
| 324 | matmul_methods[5].name = "matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa"; | ||
| 325 | matmul_methods[5].m0 = 2 * get_sme_vector_length<float>(); | ||
| 326 | matmul_methods[5].n0 = 2 * get_sme_vector_length<float>(); | ||
| 327 | matmul_methods[5].dst_format = DataFormat(DataType::FP16); | ||
| 328 | matmul_methods[5].lhs_format = DataFormat(DataType::FP16); | ||
| 329 | matmul_methods[5].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length<float>(), 2); | ||
| 330 | matmul_methods[5].rhs_format = DataFormat(DataType::FP16); | ||
| 331 | matmul_methods[5].packed_rhs_format = DataFormat( | ||
| 332 | DataType::FP16, // Output type | ||
| 333 | 2 * get_sme_vector_length<float>(), 2, // Block size | ||
| 334 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
| 335 | DataType::FP16, // Bias format | ||
| 336 | DataType::UNKNOWN, // Scaling type | ||
| 337 | 2 * get_sme_vector_length<float>(), 2); // Sub-block | ||
| 338 | matmul_methods[5].bias_format = DataFormat(DataType::FP16); | ||
| 339 | 18 | matmul_methods[5].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 340 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 341 | ✗ | }; | |
| 342 | 18 | matmul_methods[5].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 343 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 344 | ✗ | }; | |
| 345 | 18 | matmul_methods[5].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 346 | KAI_UNUSED(null_bias_mode); | ||
| 347 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 348 | ✗ | }; | |
| 349 | matmul_methods[5].fn_is_supported = cpu_has_sme; | ||
| 350 | matmul_methods[5].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 351 | matmul_methods[5].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 352 | matmul_methods[5].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 353 | matmul_methods[5].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 354 | matmul_methods[5].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 355 | matmul_methods[5].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 356 | matmul_methods[5].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 357 | matmul_methods[5].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; | ||
| 358 | matmul_methods[5].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme; | ||
| 359 | matmul_methods[5].fn_get_packed_lhs_offset = | ||
| 360 | kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 361 | matmul_methods[5].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme; | ||
| 362 | matmul_methods[5].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 363 | matmul_methods[5].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 364 | matmul_methods[5].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 365 | matmul_methods[5].fn_get_main_packed_rhs_offset = | ||
| 366 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 367 | matmul_methods[5].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 368 | matmul_methods[5].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 369 | matmul_methods[5].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 370 | matmul_methods[5].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 371 | matmul_methods[5].fn_pack_rhs_nxk_get_packed_rhs_offset = | ||
| 372 | kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 373 | matmul_methods[5].fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 374 | matmul_methods[5].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; | ||
| 375 | matmul_methods[5].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 376 | matmul_methods[5].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 377 | matmul_methods[5].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 378 | matmul_methods[5].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | ||
| 379 | |||
| 380 | matmul_methods[6].name = "matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla"; | ||
| 381 | matmul_methods[6].m0 = 1; | ||
| 382 | matmul_methods[6].n0 = 4 * get_sve_vector_length<float>(); | ||
| 383 | matmul_methods[6].dst_format = DataFormat(DataType::FP32); | ||
| 384 | matmul_methods[6].lhs_format = DataFormat(DataType::FP32); | ||
| 385 | matmul_methods[6].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 386 | matmul_methods[6].rhs_format = DataFormat(DataType::FP32); | ||
| 387 | matmul_methods[6].packed_rhs_format = DataFormat( | ||
| 388 | DataType::FP32, 4 * get_sve_vector_length<float>(), 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, | ||
| 389 | DataType::UNKNOWN, 4 * get_sve_vector_length<float>(), 1); | ||
| 390 | matmul_methods[6].bias_format = DataFormat(DataType::FP32); | ||
| 391 | 36 | matmul_methods[6].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 392 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 393 | ✗ | }; | |
| 394 | 36 | matmul_methods[6].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 395 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 396 | ✗ | }; | |
| 397 | 36 | matmul_methods[6].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 398 | KAI_UNUSED(null_bias_mode); | ||
| 399 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 400 | ✗ | }; | |
| 401 | matmul_methods[6].fn_is_supported = cpu_has_sve; | ||
| 402 | matmul_methods[6].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla; | ||
| 403 | matmul_methods[6].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla; | ||
| 404 | matmul_methods[6].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla; | ||
| 405 | matmul_methods[6].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla; | ||
| 406 | matmul_methods[6].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve; | ||
| 407 | matmul_methods[6].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla; | ||
| 408 | matmul_methods[6].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla; | ||
| 409 | matmul_methods[6].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve; | ||
| 410 | matmul_methods[6].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve; | ||
| 411 | matmul_methods[6].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve; | ||
| 412 | matmul_methods[6].fn_get_main_packed_rhs_offset = | ||
| 413 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla; | ||
| 414 | matmul_methods[6].fn_pack_rhs = kai_run_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve; | ||
| 415 | matmul_methods[6].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x32p4vlx1b_x32_x32_sve; | ||
| 416 | matmul_methods[6].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla; | ||
| 417 | matmul_methods[6].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla; | ||
| 418 | matmul_methods[6].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p4vlx1b_6x4vl_sve_mla; | ||
| 419 | |||
| 420 | return matmul_methods; | ||
| 421 | } | ||
| 422 | |||
| 423 | 12 | static const auto& get_vecmul_methods() { | |
| 424 | // List of supported vector by matrix multiplication methods | ||
| 425 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
12 | static std::array<MatMulMethod, 5> vecmul_methods{}; |
| 426 | |||
| 427 | vecmul_methods[0].name = "matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot"; | ||
| 428 | vecmul_methods[0].m0 = 1; | ||
| 429 | vecmul_methods[0].n0 = 16 * get_sme_vector_length<float>(); | ||
| 430 | vecmul_methods[0].dst_format = DataFormat(DataType::FP16); | ||
| 431 | vecmul_methods[0].lhs_format = DataFormat(DataType::FP16); | ||
| 432 | vecmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 433 | vecmul_methods[0].rhs_format = DataFormat(DataType::FP16); | ||
| 434 | vecmul_methods[0].packed_rhs_format = DataFormat( | ||
| 435 | DataType::FP16, // Output type | ||
| 436 | 2 * get_sme_vector_length<float>(), 2, // Block size | ||
| 437 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
| 438 | DataType::FP16, // Bias format | ||
| 439 | DataType::UNKNOWN, // Scaling type | ||
| 440 | 2 * get_sme_vector_length<float>(), 2); // Sub-block | ||
| 441 | vecmul_methods[0].bias_format = DataFormat(DataType::FP16); | ||
| 442 | 39 | vecmul_methods[0].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 443 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 444 | ✗ | }; | |
| 445 | 39 | vecmul_methods[0].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 446 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 447 | ✗ | }; | |
| 448 | 39 | vecmul_methods[0].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 449 | KAI_UNUSED(null_bias_mode); | ||
| 450 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 451 | ✗ | }; | |
| 452 | vecmul_methods[0].fn_is_supported = cpu_has_sme2; | ||
| 453 | vecmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
| 454 | vecmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
| 455 | vecmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
| 456 | vecmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
| 457 | vecmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 458 | vecmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
| 459 | vecmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; | ||
| 460 | vecmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 461 | vecmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 462 | vecmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 463 | vecmul_methods[0].fn_get_main_packed_rhs_offset = | ||
| 464 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
| 465 | vecmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 466 | vecmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 467 | vecmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
| 468 | vecmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
| 469 | vecmul_methods[0].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot; | ||
| 470 | |||
| 471 | vecmul_methods[1].name = "matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla"; | ||
| 472 | vecmul_methods[1].m0 = 1; | ||
| 473 | vecmul_methods[1].n0 = 8 * get_sme_vector_length<float>(); | ||
| 474 | vecmul_methods[1].dst_format = DataFormat(DataType::FP16); | ||
| 475 | vecmul_methods[1].lhs_format = DataFormat(DataType::FP16); | ||
| 476 | vecmul_methods[1].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 477 | vecmul_methods[1].rhs_format = DataFormat(DataType::FP16); | ||
| 478 | vecmul_methods[1].packed_rhs_format = DataFormat( | ||
| 479 | DataType::FP16, // Output type | ||
| 480 | 2 * get_sme_vector_length<float>(), 2, // Block size | ||
| 481 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
| 482 | DataType::FP16, // Bias format | ||
| 483 | DataType::UNKNOWN, // Scaling type | ||
| 484 | 2 * get_sme_vector_length<float>(), 2); // Sub-block | ||
| 485 | vecmul_methods[1].bias_format = DataFormat(DataType::FP16); | ||
| 486 | 39 | vecmul_methods[1].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 487 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 488 | ✗ | }; | |
| 489 | 39 | vecmul_methods[1].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 490 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 491 | ✗ | }; | |
| 492 | 39 | vecmul_methods[1].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 493 | KAI_UNUSED(null_bias_mode); | ||
| 494 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 495 | ✗ | }; | |
| 496 | vecmul_methods[1].fn_is_supported = cpu_has_sme; | ||
| 497 | vecmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
| 498 | vecmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
| 499 | vecmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
| 500 | vecmul_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
| 501 | vecmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 502 | vecmul_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
| 503 | vecmul_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; | ||
| 504 | vecmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 505 | vecmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 506 | vecmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 507 | vecmul_methods[1].fn_get_main_packed_rhs_offset = | ||
| 508 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
| 509 | vecmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 510 | vecmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 511 | vecmul_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
| 512 | vecmul_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
| 513 | vecmul_methods[1].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla; | ||
| 514 | |||
| 515 | vecmul_methods[2].name = "matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla"; | ||
| 516 | vecmul_methods[2].m0 = 1; | ||
| 517 | vecmul_methods[2].n0 = 8 * get_sme_vector_length<float>(); | ||
| 518 | vecmul_methods[2].dst_format = DataFormat(DataType::FP32); | ||
| 519 | vecmul_methods[2].lhs_format = DataFormat(DataType::FP32); | ||
| 520 | vecmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 521 | vecmul_methods[2].rhs_format = DataFormat(DataType::FP32); | ||
| 522 | vecmul_methods[2].packed_rhs_format = DataFormat( | ||
| 523 | DataType::FP32, // Output type | ||
| 524 | 2 * get_sme_vector_length<float>(), 1, // Block size | ||
| 525 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
| 526 | DataType::FP32, // Bias format | ||
| 527 | DataType::UNKNOWN, // Scaling type | ||
| 528 | 2 * get_sme_vector_length<float>(), 1); // Sub-block | ||
| 529 | vecmul_methods[2].bias_format = DataFormat(DataType::FP32); | ||
| 530 | 39 | vecmul_methods[2].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 531 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 532 | ✗ | }; | |
| 533 | 39 | vecmul_methods[2].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 534 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 535 | ✗ | }; | |
| 536 | 39 | vecmul_methods[2].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 537 | KAI_UNUSED(null_bias_mode); | ||
| 538 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 539 | ✗ | }; | |
| 540 | vecmul_methods[2].fn_is_supported = cpu_has_sme; | ||
| 541 | vecmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
| 542 | vecmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
| 543 | vecmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
| 544 | vecmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
| 545 | vecmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
| 546 | vecmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 547 | vecmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
| 548 | vecmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 549 | vecmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 550 | vecmul_methods[2].fn_get_pack_rhs_packed_rhs_offset = | ||
| 551 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 552 | vecmul_methods[2].fn_get_main_packed_rhs_offset = | ||
| 553 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
| 554 | vecmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 555 | vecmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 556 | vecmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
| 557 | vecmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
| 558 | vecmul_methods[2].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla; | ||
| 559 | |||
| 560 | vecmul_methods[3].name = "matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla"; | ||
| 561 | vecmul_methods[3].m0 = 1; | ||
| 562 | vecmul_methods[3].n0 = 16 * get_sme_vector_length<float>(); | ||
| 563 | vecmul_methods[3].dst_format = DataFormat(DataType::FP32); | ||
| 564 | vecmul_methods[3].lhs_format = DataFormat(DataType::FP32); | ||
| 565 | vecmul_methods[3].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 566 | vecmul_methods[3].rhs_format = DataFormat(DataType::FP32); | ||
| 567 | vecmul_methods[3].packed_rhs_format = DataFormat( | ||
| 568 | DataType::FP32, // Output type | ||
| 569 | 2 * get_sme_vector_length<float>(), 1, // Block size | ||
| 570 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
| 571 | DataType::FP32, // Bias format | ||
| 572 | DataType::UNKNOWN, // Scaling type | ||
| 573 | 2 * get_sme_vector_length<float>(), 1); // Sub-block | ||
| 574 | vecmul_methods[3].bias_format = DataFormat(DataType::FP32); | ||
| 575 | 39 | vecmul_methods[3].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 576 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 577 | ✗ | }; | |
| 578 | 39 | vecmul_methods[3].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 579 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 580 | ✗ | }; | |
| 581 | 39 | vecmul_methods[3].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 582 | KAI_UNUSED(null_bias_mode); | ||
| 583 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 584 | ✗ | }; | |
| 585 | vecmul_methods[3].fn_is_supported = cpu_has_sme2; | ||
| 586 | vecmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
| 587 | vecmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
| 588 | vecmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
| 589 | vecmul_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
| 590 | vecmul_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
| 591 | vecmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 592 | vecmul_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
| 593 | vecmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 594 | vecmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 595 | vecmul_methods[3].fn_get_pack_rhs_packed_rhs_offset = | ||
| 596 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 597 | vecmul_methods[3].fn_get_main_packed_rhs_offset = | ||
| 598 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
| 599 | vecmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 600 | vecmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme; | ||
| 601 | vecmul_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
| 602 | vecmul_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
| 603 | vecmul_methods[3].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla; | ||
| 604 | |||
| 605 | vecmul_methods[4].name = "matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla"; | ||
| 606 | vecmul_methods[4].m0 = 1; | ||
| 607 | vecmul_methods[4].n0 = 16 * get_sme_vector_length<float>(); | ||
| 608 | vecmul_methods[4].dst_format = DataFormat(DataType::FP32); | ||
| 609 | vecmul_methods[4].lhs_format = DataFormat(DataType::FP32); | ||
| 610 | vecmul_methods[4].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 611 | vecmul_methods[4].rhs_format = DataFormat(DataType::FP32); | ||
| 612 | vecmul_methods[4].packed_rhs_format = DataFormat( | ||
| 613 | DataType::FP32, // Output type | ||
| 614 | 16 * get_sme_vector_length<float>(), 1, // Block size | ||
| 615 | DataFormat::PackFormat::BIAS_PER_ROW, // Data layout | ||
| 616 | DataType::FP32, // Bias format | ||
| 617 | DataType::UNKNOWN, // Scaling type | ||
| 618 | 16 * get_sme_vector_length<float>(), 1); // Sub-block | ||
| 619 | vecmul_methods[4].bias_format = DataFormat(DataType::FP32); | ||
| 620 | 39 | vecmul_methods[4].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 621 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 622 | ✗ | }; | |
| 623 | 39 | vecmul_methods[4].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 624 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 625 | ✗ | }; | |
| 626 | 39 | vecmul_methods[4].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 627 | KAI_UNUSED(null_bias_mode); | ||
| 628 |
1/2✓ Branch 0 taken 39 times.
✗ Branch 1 not taken.
|
39 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 629 | ✗ | }; | |
| 630 | vecmul_methods[4].fn_is_supported = cpu_has_sme2; | ||
| 631 | vecmul_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
| 632 | vecmul_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
| 633 | vecmul_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
| 634 | vecmul_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
| 635 | vecmul_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
| 636 | vecmul_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
| 637 | vecmul_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
| 638 | vecmul_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
| 639 | vecmul_methods[4].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
| 640 | vecmul_methods[4].fn_get_pack_rhs_packed_rhs_offset = | ||
| 641 | kai_get_rhs_packed_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
| 642 | vecmul_methods[4].fn_get_main_packed_rhs_offset = | ||
| 643 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
| 644 | vecmul_methods[4].fn_pack_rhs = kai_run_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
| 645 | vecmul_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme; | ||
| 646 | vecmul_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
| 647 | vecmul_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
| 648 | vecmul_methods[4].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla; | ||
| 649 | |||
| 650 | return vecmul_methods; | ||
| 651 | } | ||
| 652 | |||
| 653 | 12 | static const auto& get_nullbias_matmul_methods() { | |
| 654 | // List of supported vector by matrix multiplication methods | ||
| 655 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
12 | static std::array<MatMulMethod, 4> nullbias_matmul_methods{}; |
| 656 | |||
| 657 | nullbias_matmul_methods[0].name = "matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla"; | ||
| 658 | nullbias_matmul_methods[0].m0 = 6; | ||
| 659 | nullbias_matmul_methods[0].n0 = 16; | ||
| 660 | nullbias_matmul_methods[0].dst_format = DataFormat(DataType::FP32); | ||
| 661 | nullbias_matmul_methods[0].lhs_format = DataFormat(DataType::FP32); | ||
| 662 | nullbias_matmul_methods[0].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 663 | nullbias_matmul_methods[0].rhs_format = DataFormat(DataType::FP32); | ||
| 664 | nullbias_matmul_methods[0].packed_rhs_format = DataFormat( | ||
| 665 | DataType::FP32, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 16, 1); | ||
| 666 | nullbias_matmul_methods[0].bias_format = DataFormat(DataType::FP32); | ||
| 667 | 72 | nullbias_matmul_methods[0].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 668 |
1/2✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
|
72 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 669 | ✗ | }; | |
| 670 | 72 | nullbias_matmul_methods[0].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 671 |
1/2✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
|
72 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 672 | ✗ | }; | |
| 673 | 72 | nullbias_matmul_methods[0].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 674 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 36 times.
|
72 | if (null_bias_mode) { |
| 675 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return ConstantGenerator<float>(0.0)(rows, cols); |
| 676 | } else { | ||
| 677 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 678 | } | ||
| 679 | 72 | }; | |
| 680 | nullbias_matmul_methods[0].fn_is_supported = cpu_has_advsimd; | ||
| 681 | nullbias_matmul_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
| 682 | nullbias_matmul_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
| 683 | nullbias_matmul_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
| 684 | nullbias_matmul_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
| 685 | nullbias_matmul_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 686 | nullbias_matmul_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
| 687 | nullbias_matmul_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
| 688 | nullbias_matmul_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 689 | nullbias_matmul_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 690 | nullbias_matmul_methods[0].fn_get_pack_rhs_packed_rhs_offset = | ||
| 691 | kai_get_rhs_packed_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 692 | nullbias_matmul_methods[0].fn_get_main_packed_rhs_offset = | ||
| 693 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
| 694 | nullbias_matmul_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 695 | nullbias_matmul_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 696 | nullbias_matmul_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
| 697 | nullbias_matmul_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
| 698 | nullbias_matmul_methods[0].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla; | ||
| 699 | |||
| 700 | nullbias_matmul_methods[1].name = "matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55"; | ||
| 701 | nullbias_matmul_methods[1].m0 = 6; | ||
| 702 | nullbias_matmul_methods[1].n0 = 16; | ||
| 703 | nullbias_matmul_methods[1].dst_format = DataFormat(DataType::FP32); | ||
| 704 | nullbias_matmul_methods[1].lhs_format = DataFormat(DataType::FP32); | ||
| 705 | nullbias_matmul_methods[1].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 706 | nullbias_matmul_methods[1].rhs_format = DataFormat(DataType::FP32); | ||
| 707 | nullbias_matmul_methods[1].packed_rhs_format = DataFormat( | ||
| 708 | DataType::FP32, 16, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 16, 1); | ||
| 709 | nullbias_matmul_methods[1].bias_format = DataFormat(DataType::FP32); | ||
| 710 | 72 | nullbias_matmul_methods[1].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 711 |
1/2✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
|
72 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 712 | ✗ | }; | |
| 713 | 72 | nullbias_matmul_methods[1].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 714 |
1/2✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
|
72 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 715 | ✗ | }; | |
| 716 | 72 | nullbias_matmul_methods[1].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 717 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 36 times.
|
72 | if (null_bias_mode) { |
| 718 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return ConstantGenerator<float>(0.0)(rows, cols); |
| 719 | } else { | ||
| 720 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return UniformRandomGenerator<float>(-1.0, 1.0, feed())(rows, cols); |
| 721 | } | ||
| 722 | 72 | }; | |
| 723 | nullbias_matmul_methods[1].fn_is_supported = cpu_has_advsimd; | ||
| 724 | nullbias_matmul_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
| 725 | nullbias_matmul_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
| 726 | nullbias_matmul_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
| 727 | nullbias_matmul_methods[1].fn_get_main_m_step = | ||
| 728 | kai_get_m_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
| 729 | nullbias_matmul_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 730 | nullbias_matmul_methods[1].fn_get_main_n_step = | ||
| 731 | kai_get_n_step_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
| 732 | nullbias_matmul_methods[1].fn_get_lhs_offset = | ||
| 733 | kai_get_lhs_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
| 734 | nullbias_matmul_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 735 | nullbias_matmul_methods[1].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 736 | nullbias_matmul_methods[1].fn_get_pack_rhs_packed_rhs_offset = | ||
| 737 | kai_get_rhs_packed_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 738 | nullbias_matmul_methods[1].fn_get_main_packed_rhs_offset = | ||
| 739 | kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
| 740 | nullbias_matmul_methods[1].fn_pack_rhs = kai_run_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 741 | nullbias_matmul_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x32p16x1b_x32_x32_neon; | ||
| 742 | nullbias_matmul_methods[1].fn_get_dst_offset = | ||
| 743 | kai_get_dst_offset_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
| 744 | nullbias_matmul_methods[1].fn_get_dst_size = | ||
| 745 | kai_get_dst_size_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
| 746 | nullbias_matmul_methods[1].fn_matmul_f32_f32_f32p = kai_run_matmul_clamp_f32_f32_f32p16x1b_6x16_neon_mla_cortexa55; | ||
| 747 | |||
| 748 | nullbias_matmul_methods[2].name = "matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla"; | ||
| 749 | nullbias_matmul_methods[2].m0 = 6; | ||
| 750 | nullbias_matmul_methods[2].n0 = 32; | ||
| 751 | nullbias_matmul_methods[2].dst_format = DataFormat(DataType::FP16); | ||
| 752 | nullbias_matmul_methods[2].lhs_format = DataFormat(DataType::FP16); | ||
| 753 | nullbias_matmul_methods[2].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 754 | nullbias_matmul_methods[2].rhs_format = DataFormat(DataType::FP16); | ||
| 755 | nullbias_matmul_methods[2].packed_rhs_format = DataFormat( | ||
| 756 | DataType::FP16, 32, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 32, 1); | ||
| 757 | nullbias_matmul_methods[2].bias_format = DataFormat(DataType::FP16); | ||
| 758 | 72 | nullbias_matmul_methods[2].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 759 |
1/2✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
|
72 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 760 | ✗ | }; | |
| 761 | 72 | nullbias_matmul_methods[2].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 762 |
1/2✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
|
72 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 763 | ✗ | }; | |
| 764 | 72 | nullbias_matmul_methods[2].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 765 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 36 times.
|
72 | if (null_bias_mode) { |
| 766 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return ConstantGenerator<Float16>(0.0)(rows, cols); |
| 767 | } else { | ||
| 768 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 769 | } | ||
| 770 | 72 | }; | |
| 771 | nullbias_matmul_methods[2].fn_is_supported = cpu_has_fp16; | ||
| 772 | nullbias_matmul_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
| 773 | nullbias_matmul_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
| 774 | nullbias_matmul_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
| 775 | nullbias_matmul_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
| 776 | nullbias_matmul_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 777 | nullbias_matmul_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
| 778 | nullbias_matmul_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
| 779 | nullbias_matmul_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 780 | nullbias_matmul_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 781 | nullbias_matmul_methods[2].fn_get_pack_rhs_packed_rhs_offset = | ||
| 782 | kai_get_rhs_packed_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 783 | nullbias_matmul_methods[2].fn_get_main_packed_rhs_offset = | ||
| 784 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
| 785 | nullbias_matmul_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 786 | nullbias_matmul_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 787 | nullbias_matmul_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
| 788 | nullbias_matmul_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
| 789 | nullbias_matmul_methods[2].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla; | ||
| 790 | |||
| 791 | nullbias_matmul_methods[3].name = "matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55"; | ||
| 792 | nullbias_matmul_methods[3].m0 = 6; | ||
| 793 | nullbias_matmul_methods[3].n0 = 32; | ||
| 794 | nullbias_matmul_methods[3].dst_format = DataFormat(DataType::FP16); | ||
| 795 | nullbias_matmul_methods[3].lhs_format = DataFormat(DataType::FP16); | ||
| 796 | nullbias_matmul_methods[3].packed_lhs_format = DataFormat(DataType::UNKNOWN); | ||
| 797 | nullbias_matmul_methods[3].rhs_format = DataFormat(DataType::FP16); | ||
| 798 | nullbias_matmul_methods[3].packed_rhs_format = DataFormat( | ||
| 799 | DataType::FP16, 32, 0, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP16, DataType::UNKNOWN, 32, 1); | ||
| 800 | nullbias_matmul_methods[3].bias_format = DataFormat(DataType::FP16); | ||
| 801 | 72 | nullbias_matmul_methods[3].fn_generate_lhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 802 |
1/2✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
|
72 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 803 | ✗ | }; | |
| 804 | 72 | nullbias_matmul_methods[3].fn_generate_rhs = [](size_t rows, size_t cols, SeedFeed& feed) { | |
| 805 |
1/2✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
|
72 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 806 | ✗ | }; | |
| 807 | 72 | nullbias_matmul_methods[3].fn_generate_bias = [](size_t rows, size_t cols, SeedFeed& feed, bool null_bias_mode) { | |
| 808 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 36 times.
|
72 | if (null_bias_mode) { |
| 809 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return ConstantGenerator<Float16>(0.0)(rows, cols); |
| 810 | } else { | ||
| 811 |
1/2✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
|
36 | return UniformRandomGenerator<Float16>(-1.0, 1.0, feed())(rows, cols); |
| 812 | } | ||
| 813 | 72 | }; | |
| 814 | nullbias_matmul_methods[3].fn_is_supported = cpu_has_fp16; | ||
| 815 | nullbias_matmul_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
| 816 | nullbias_matmul_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
| 817 | nullbias_matmul_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
| 818 | nullbias_matmul_methods[3].fn_get_main_m_step = | ||
| 819 | kai_get_m_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
| 820 | nullbias_matmul_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 821 | nullbias_matmul_methods[3].fn_get_main_n_step = | ||
| 822 | kai_get_n_step_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
| 823 | nullbias_matmul_methods[3].fn_get_lhs_offset = | ||
| 824 | kai_get_lhs_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
| 825 | nullbias_matmul_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 826 | nullbias_matmul_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 827 | nullbias_matmul_methods[3].fn_get_pack_rhs_packed_rhs_offset = | ||
| 828 | kai_get_rhs_packed_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 829 | nullbias_matmul_methods[3].fn_get_main_packed_rhs_offset = | ||
| 830 | kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
| 831 | nullbias_matmul_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 832 | nullbias_matmul_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p32x1b_x16_x16_neon; | ||
| 833 | nullbias_matmul_methods[3].fn_get_dst_offset = | ||
| 834 | kai_get_dst_offset_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
| 835 | nullbias_matmul_methods[3].fn_get_dst_size = | ||
| 836 | kai_get_dst_size_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
| 837 | nullbias_matmul_methods[3].fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p32x1b_6x32_neon_mla_cortexa55; | ||
| 838 | |||
| 839 | return nullbias_matmul_methods; | ||
| 840 | } | ||
| 841 | |||
| 842 | using MatMulClampTestParams = std::tuple<MatMulMethod, MatMulShape, MatrixPortion, BiasMode, float>; | ||
| 843 | |||
| 844 | /// Matrix multiplication test fixture. | ||
| 845 | class MatMulTest : public testing::TestWithParam<MatMulClampTestParams> { | ||
| 846 | private: | ||
| 847 | /// Unique ID: m, n, k, method_id. | ||
| 848 | using TestDataId = std::tuple<size_t, size_t, size_t, std::string_view, BiasMode, float>; | ||
| 849 | |||
| 850 | protected: | ||
| 851 | /// Cached test data that is shared between multiple test case. | ||
| 852 | 663 | struct TestData { | |
| 853 | 1326 | Buffer lhs{}; ///< LHS operand. | |
| 854 | 1326 | Buffer ref_packed_lhs{}; ///< Reference packed LHS. | |
| 855 | 1326 | Buffer rhs{}; ///< RHS operand. | |
| 856 | 1326 | Buffer rhs_scales{}; ///< RHS per-row quantization scales. | |
| 857 | 1326 | Buffer bias{}; ///< Bias. | |
| 858 | 1326 | Buffer rhs_t{}; ///< Transposed RHS matrix. | |
| 859 | 1326 | Buffer ref_packed_rhs{}; ///< Reference packed RHS. | |
| 860 | 1326 | Buffer ref_dst{}; ///< Reference output. | |
| 861 | 663 | float clamp_min{}; ///< Minimum output value. | |
| 862 | 663 | float clamp_max{}; ///< Maximum output value. | |
| 863 | }; | ||
| 864 | |||
| 865 | /// Gets the test data for the current test case. | ||
| 866 | 4800 | static const TestData& test_data() { | |
| 867 | 69417 | const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam(); | |
| 868 | 28800 | const TestDataId data_id{info.m, info.n, info.k, method.name, bias_mode, clamp_keep_ratio}; | |
| 869 | |||
| 870 | // Creates a unique seed for the test data. | ||
| 871 |
12/24✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4800 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4800 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4800 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4800 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4800 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4800 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4800 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4800 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 4800 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 4800 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1296 times.
✗ Branch 23 not taken.
|
19200 | const auto key = std::string(method.name) + "_" + std::to_string(info.m) + "x" + std::to_string(info.n) + "x" + |
| 872 |
8/14✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4800 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4800 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3504 times.
✓ Branch 7 taken 1296 times.
✓ Branch 8 taken 4800 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4800 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4800 times.
✗ Branch 13 not taken.
|
19200 | std::to_string(info.k) + "_" + (bias_mode == BiasMode::INTERNAL ? "internal" : "provided") + "_" + |
| 873 |
2/4✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4800 times.
✗ Branch 3 not taken.
|
9600 | std::to_string(clamp_keep_ratio); |
| 874 |
1/2✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
|
4800 | auto& feed = seed_stream(key); |
| 875 | |||
| 876 | // If the test data is already available, returns it. | ||
| 877 |
1/2✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
|
4800 | const auto data_it = _data.find(data_id); |
| 878 | |||
| 879 |
4/4✓ Branch 0 taken 4584 times.
✓ Branch 1 taken 216 times.
✓ Branch 2 taken 3057 times.
✓ Branch 3 taken 447 times.
|
4800 | if (data_it != _data.end()) { |
| 880 |
1/2✓ Branch 0 taken 3057 times.
✗ Branch 1 not taken.
|
4137 | return data_it->second; |
| 881 | } | ||
| 882 | |||
| 883 | // Generates the test data. | ||
| 884 |
2/4✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
|
1326 | const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN; |
| 885 |
2/4✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
|
1326 | const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; |
| 886 |
2/4✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
|
1326 | const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; |
| 887 | 1326 | const bool null_bias_mode = bias_mode == BiasMode::INTERNAL; | |
| 888 | |||
| 889 | 1326 | const auto lhs_h = info.m; | |
| 890 | 1326 | const auto lhs_w = info.k; | |
| 891 |
2/4✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
|
1326 | auto lhs = method.fn_generate_lhs(lhs_h, lhs_w, feed); |
| 892 | |||
| 893 | 663 | Buffer ref_packed_lhs; | |
| 894 |
2/2✓ Branch 0 taken 591 times.
✓ Branch 1 taken 72 times.
|
663 | if (has_lhs_pack) { |
| 895 | 72 | ref_packed_lhs = | |
| 896 |
3/6✓ Branch 0 taken 72 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 72 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 72 times.
✗ Branch 5 not taken.
|
144 | pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); |
| 897 | 72 | } | |
| 898 | |||
| 899 | 1326 | const auto rhs_h = info.k; | |
| 900 | 1326 | const auto rhs_w = info.n; | |
| 901 |
2/4✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
|
1326 | auto rhs = method.fn_generate_rhs(rhs_h, rhs_w, feed); |
| 902 | |||
| 903 | − | KAI_ASSUME_ALWAYS(method.rhs_format.is_raw()); | |
| 904 |
3/6✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
|
663 | auto rhs_t = transpose(rhs.data(), method.rhs_format.data_type(), rhs_h, rhs_w); |
| 905 | |||
| 906 | 663 | Buffer rhs_scales; | |
| 907 |
3/8✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 663 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
663 | if (data_type_is_quantized(method.rhs_format.data_type()) && |
| 908 | ✗ | method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) { | |
| 909 | ✗ | rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), feed()); | |
| 910 | ✗ | } | |
| 911 | |||
| 912 | 663 | const auto bias_h = 1; | |
| 913 | 1326 | const auto bias_w = info.n; | |
| 914 | 663 | Buffer bias; | |
| 915 | |||
| 916 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 663 times.
|
663 | if (has_bias) { |
| 917 |
2/4✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
|
1326 | bias = method.fn_generate_bias(bias_h, bias_w, feed, null_bias_mode); |
| 918 | 663 | } | |
| 919 | |||
| 920 | 663 | Buffer packed_rhs; | |
| 921 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 663 times.
|
663 | if (has_rhs_pack) { |
| 922 |
1/2✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
|
1326 | packed_rhs = matmul_pack_rhs( |
| 923 |
3/6✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
|
663 | rhs.data(), rhs_scales.data(), bias.data(), method.rhs_format, method.packed_rhs_format, info.n, info.k, |
| 924 | true); | ||
| 925 | 663 | } | |
| 926 | |||
| 927 | − | KAI_ASSUME_ALWAYS(method.lhs_format.is_raw()); | |
| 928 | − | KAI_ASSUME_ALWAYS(method.rhs_format.is_raw()); | |
| 929 | − | KAI_ASSUME_ALWAYS(method.dst_format.is_raw()); | |
| 930 |
1/2✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
|
1326 | auto ref_dst = matmul( |
| 931 |
2/4✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
|
663 | lhs.data(), nullptr, nullptr, method.lhs_format.data_type(), // |
| 932 |
3/6✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
|
663 | rhs.data(), rhs_scales.data(), nullptr, method.rhs_format.data_type(), // |
| 933 |
2/4✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
|
663 | bias.data(), nullptr, nullptr, method.bias_format.data_type(), // |
| 934 |
1/2✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
|
663 | method.dst_format.data_type(), // |
| 935 | 1989 | info.m, info.n, info.k, false, false); | |
| 936 | |||
| 937 | 3978 | const auto [clamp_min, clamp_max] = | |
| 938 |
5/10✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 663 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 663 times.
✗ Branch 9 not taken.
|
663 | find_clamp_range(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_keep_ratio); |
| 939 |
7/14✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 663 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 663 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 663 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 663 times.
✗ Branch 13 not taken.
|
1326 | ref_dst = clamp(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_min, clamp_max); |
| 940 | |||
| 941 |
9/18✓ Branch 0 taken 663 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 663 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 663 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 663 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 663 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 663 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 663 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 663 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 663 times.
✗ Branch 17 not taken.
|
5967 | auto& data = _data[data_id] = {}; |
| 942 | 663 | data.lhs = std::move(lhs); | |
| 943 | 663 | data.ref_packed_lhs = std::move(ref_packed_lhs); | |
| 944 | 663 | data.rhs = std::move(rhs); | |
| 945 | 663 | data.rhs_scales = std::move(rhs_scales); | |
| 946 | 663 | data.bias = std::move(bias); | |
| 947 | 663 | data.rhs_t = std::move(rhs_t); | |
| 948 | 663 | data.ref_packed_rhs = std::move(packed_rhs); | |
| 949 | 663 | data.ref_dst = std::move(ref_dst); | |
| 950 | 663 | data.clamp_min = clamp_min; | |
| 951 | 663 | data.clamp_max = clamp_max; | |
| 952 | |||
| 953 | 663 | return data; | |
| 954 | 4800 | } | |
| 955 | |||
| 956 | private: | ||
| 957 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
| 958 | static std::map<TestDataId, TestData> _data; | ||
| 959 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
| 960 | }; | ||
| 961 | |||
| 962 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
| 963 | 3 | std::map<MatMulTest::TestDataId, MatMulTest::TestData> MatMulTest::_data; | |
| 964 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
| 965 | |||
| 966 | /// Tests the LHS packing micro-kernel. | ||
| 967 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
8064 | TEST_P(MatMulTest, PackedLhs) { |
| 968 | 6474 | const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam(); | |
| 969 | |||
| 970 |
3/4✓ Branch 0 taken 3234 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✓ Branch 3 taken 1050 times.
|
3234 | if (method.fn_is_supported && !method.fn_is_supported()) { |
| 971 |
3/6✓ Branch 0 taken 1050 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1050 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1050 times.
✗ Branch 5 not taken.
|
1050 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 972 | } | ||
| 973 | |||
| 974 |
2/2✓ Branch 0 taken 216 times.
✓ Branch 1 taken 1968 times.
|
2184 | if (!method.is_pack_lhs_needed()) { |
| 975 |
3/6✓ Branch 0 taken 1968 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1968 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1968 times.
✗ Branch 5 not taken.
|
1968 | GTEST_SKIP() << "Test not valid w/o LHS pack"; |
| 976 | } | ||
| 977 | |||
| 978 | 216 | const auto& data = test_data(); | |
| 979 | 432 | const auto lhs_h = info.m; | |
| 980 | 432 | const auto lhs_w = info.k; | |
| 981 | |||
| 982 | 432 | const auto rect = portion.compute_portion( | |
| 983 | 432 | lhs_h, lhs_w, method.packed_lhs_format.scheduler_block_height(lhs_h), | |
| 984 | 216 | lhs_w); // LHS packing micro-kernel API doesn't support scheduling over K dimension. | |
| 985 | |||
| 986 |
2/4✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 216 times.
|
216 | if (rect.height() == 0 || rect.width() == 0) { |
| 987 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 988 | } | ||
| 989 | |||
| 990 | 432 | const auto mr = method.fn_get_mr(); | |
| 991 | 432 | const auto kr = method.fn_get_kr(); | |
| 992 | 432 | const auto sr = method.fn_get_sr(); | |
| 993 | 432 | const auto ref_lhs_row_stride = method.lhs_format.default_row_stride(lhs_w); | |
| 994 | |||
| 995 | 864 | const auto packed_lhs_size = method.fn_get_packed_lhs_size(info.m, info.k, mr, kr, sr); | |
| 996 | 432 | const auto ref_packed_lhs_size = method.packed_lhs_format.default_size_in_bytes(lhs_h, lhs_w); | |
| 997 |
3/14✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
|
216 | ASSERT_EQ(packed_lhs_size, ref_packed_lhs_size); |
| 998 | |||
| 999 | 432 | const auto lhs_offset = method.fn_get_lhs_offset(rect.start_row(), ref_lhs_row_stride); | |
| 1000 | 432 | const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), lhs_w); | |
| 1001 |
3/14✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
|
216 | ASSERT_EQ(lhs_offset, ref_lhs_offset); |
| 1002 | |||
| 1003 | 648 | const auto packed_lhs_offset = method.fn_get_packed_lhs_offset(rect.start_row(), info.k); | |
| 1004 | 432 | const auto ref_packed_lhs_offset = method.packed_lhs_format.default_offset_in_bytes(rect.start_row(), 0, lhs_w); | |
| 1005 |
3/14✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
|
216 | ASSERT_EQ(packed_lhs_offset, ref_packed_lhs_offset); |
| 1006 | |||
| 1007 | 216 | Buffer packed_lhs(packed_lhs_size, 0); | |
| 1008 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | abi_check( |
| 1009 |
3/6✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
|
216 | method.fn_pack_lhs, rect.height(), rect.width(), mr, kr, sr, 0, data.lhs.data() + lhs_offset, |
| 1010 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | ref_lhs_row_stride, packed_lhs.data() + packed_lhs_offset); |
| 1011 | |||
| 1012 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | DefaultMismatchHandler handler(0, 0.0001, 0, 0.001); |
| 1013 | 432 | const auto success = | |
| 1014 |
3/6✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
|
216 | compare(packed_lhs.data(), data.ref_packed_lhs.data(), method.packed_lhs_format, lhs_h, lhs_w, rect, handler); |
| 1015 |
4/16✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 216 times.
|
216 | ASSERT_TRUE(success); |
| 1016 | 3234 | } | |
| 1017 | |||
| 1018 | /// Tests the RHS packing micro-kernel. | ||
| 1019 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
8064 | TEST_P(MatMulTest, PackedRhs) { |
| 1020 | 31626 | const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam(); | |
| 1021 | |||
| 1022 |
3/4✓ Branch 0 taken 3234 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✓ Branch 3 taken 1050 times.
|
3234 | if (method.fn_is_supported && !method.fn_is_supported()) { |
| 1023 |
3/6✓ Branch 0 taken 1050 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1050 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1050 times.
✗ Branch 5 not taken.
|
1050 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 1024 | } | ||
| 1025 | |||
| 1026 |
1/2✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
|
2184 | if (!method.is_pack_rhs_needed()) { |
| 1027 | ✗ | GTEST_SKIP() << "Test not valid w/o RHS pack"; | |
| 1028 | } | ||
| 1029 | |||
| 1030 | 2184 | const auto& data = test_data(); | |
| 1031 | 4368 | const auto rhs_full_width = info.n; | |
| 1032 | 4368 | const auto rhs_full_height = info.k; | |
| 1033 | |||
| 1034 | 4368 | const auto block_height = method.packed_rhs_format.scheduler_block_height(rhs_full_width); | |
| 1035 | 4368 | const auto block_width = method.packed_rhs_format.scheduler_block_width(rhs_full_height); | |
| 1036 | |||
| 1037 | 4368 | const auto null_bias_mode = bias_mode == BiasMode::INTERNAL; | |
| 1038 | |||
| 1039 | 4368 | const Rect rect = portion.compute_portion(rhs_full_width, rhs_full_height, block_height, block_width); | |
| 1040 | |||
| 1041 |
2/4✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2184 times.
|
2184 | if (rect.height() == 0 || rect.width() == 0) { |
| 1042 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 1043 | } | ||
| 1044 | |||
| 1045 | 2184 | const auto rhs_start_row = rect.start_row(); | |
| 1046 | 2184 | const auto rhs_start_col = rect.start_col(); | |
| 1047 | 2184 | const auto width = rect.width(); | |
| 1048 | 2184 | const auto height = rect.height(); | |
| 1049 | 4368 | const auto rhs_row_stride = method.rhs_format.default_row_stride(rhs_full_width); | |
| 1050 | |||
| 1051 | /** Ensure that all relevant parameters are sane **/ | ||
| 1052 | 4368 | const auto n_step = method.fn_get_pack_rhs_n_step(); | |
| 1053 | 2184 | const auto ref_n_step = block_height; | |
| 1054 |
3/14✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
|
2184 | ASSERT_EQ(n_step, ref_n_step); |
| 1055 | |||
| 1056 | 4368 | const auto rhs_offset = method.fn_get_rhs_offset(rhs_start_row); | |
| 1057 | 4368 | const auto ref_rhs_offset = | |
| 1058 | 2184 | method.rhs_format.default_offset_in_bytes(rhs_start_col, rhs_start_row, rhs_full_height); | |
| 1059 |
3/14✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
|
2184 | ASSERT_EQ(rhs_offset, ref_rhs_offset); |
| 1060 | |||
| 1061 | 4368 | const auto packed_rhs_size = method.fn_get_packed_rhs_size(rhs_full_width, rhs_full_height); | |
| 1062 | 4368 | const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(rhs_full_width, rhs_full_height); | |
| 1063 |
3/14✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
|
2184 | ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size); |
| 1064 | |||
| 1065 | 4368 | const auto packed_rhs_offset = method.fn_get_pack_rhs_packed_rhs_offset(rhs_start_row, rhs_full_height); | |
| 1066 | 4368 | const auto ref_packed_rhs_offset = | |
| 1067 | 2184 | method.packed_rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_full_height); | |
| 1068 |
3/14✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
|
2184 | ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset); |
| 1069 | |||
| 1070 | 4368 | const auto scale_type = method.packed_rhs_format.scale_data_type(); | |
| 1071 | 2184 | const auto ref_rhs_scales_offset = rhs_start_row * data_type_size_in_bits(scale_type) / 8; | |
| 1072 | |||
| 1073 | 4368 | const auto bias_offset = method.fn_get_bias_offset(rhs_start_row); | |
| 1074 | 4368 | const auto ref_bias_offset = | |
| 1075 |
2/2✓ Branch 0 taken 432 times.
✓ Branch 1 taken 1752 times.
|
2184 | !null_bias_mode ? method.bias_format.default_offset_in_bytes(0, rhs_start_row, rhs_full_height) : bias_offset; |
| 1076 |
3/14✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
|
2184 | ASSERT_EQ(bias_offset, ref_bias_offset); |
| 1077 | |||
| 1078 | /** Perform RHS packing, and compare with reference result **/ | ||
| 1079 | 2184 | Buffer packed_rhs(packed_rhs_size, 0); | |
| 1080 |
1/2✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
|
2184 | abi_check( |
| 1081 |
2/4✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
|
4368 | &MatMulMethod::pack_rhs, method, height, width, data.rhs.data() + rhs_offset, rhs_row_stride, |
| 1082 |
3/4✓ Branch 0 taken 432 times.
✓ Branch 1 taken 1752 times.
✓ Branch 2 taken 1752 times.
✗ Branch 3 not taken.
|
2184 | !null_bias_mode ? data.bias.data() + bias_offset : nullptr, |
| 1083 |
2/6✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2184 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
2184 | data.rhs_scales.data() != nullptr ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr, |
| 1084 |
1/2✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
|
2184 | packed_rhs.data() + packed_rhs_offset); |
| 1085 | |||
| 1086 |
2/4✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
|
4368 | const bool exact = method.packed_rhs_format.pack_format() != DataFormat::PackFormat::QUANTIZE_PER_ROW; |
| 1087 |
1/2✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
|
2184 | DefaultMismatchHandler handler(0, exact ? 0 : 0.0001, 0, exact ? 0 : 0.001); |
| 1088 |
1/2✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
|
4368 | const auto success = compare( |
| 1089 |
2/4✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
|
2184 | packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, rhs_full_width, rhs_full_height, rect, |
| 1090 | handler); | ||
| 1091 |
4/16✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2184 times.
|
2184 | ASSERT_TRUE(success); |
| 1092 | 3234 | } | |
| 1093 | |||
| 1094 | /// Tests the transposed RHS packing micro-kernel. | ||
| 1095 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
8064 | TEST_P(MatMulTest, PackedTransposedRhs) { |
| 1096 | 9282 | const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam(); | |
| 1097 | |||
| 1098 |
3/4✓ Branch 0 taken 3234 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✓ Branch 3 taken 1050 times.
|
3234 | if (method.fn_is_supported && !method.fn_is_supported()) { |
| 1099 |
3/6✓ Branch 0 taken 1050 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1050 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1050 times.
✗ Branch 5 not taken.
|
1050 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 1100 | } | ||
| 1101 | |||
| 1102 |
2/2✓ Branch 0 taken 216 times.
✓ Branch 1 taken 1968 times.
|
2184 | if (!method.is_pack_rhs_nxk_needed()) { |
| 1103 |
3/6✓ Branch 0 taken 1968 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1968 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1968 times.
✗ Branch 5 not taken.
|
1968 | GTEST_SKIP() << "Test not valid w/o pre-processing of transposed RHS matrix"; |
| 1104 | } | ||
| 1105 | |||
| 1106 | 216 | const auto& data = test_data(); | |
| 1107 | 432 | const bool null_bias_mode = bias_mode == BiasMode::INTERNAL; | |
| 1108 | |||
| 1109 | 432 | const auto n_step = method.fn_pack_rhs_nxk_get_n_step(); | |
| 1110 | 648 | const auto ref_n_step = method.packed_rhs_format.scheduler_block_height(info.n); | |
| 1111 |
3/14✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
|
216 | ASSERT_EQ(n_step, ref_n_step); |
| 1112 | |||
| 1113 | 432 | const auto rect = portion.compute_portion( | |
| 1114 | 864 | info.n, info.k, method.packed_rhs_format.scheduler_block_height(info.n), | |
| 1115 | 432 | method.packed_rhs_format.scheduler_block_width(info.k)); | |
| 1116 | |||
| 1117 |
2/4✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 216 times.
|
216 | if (rect.height() == 0 || rect.width() == 0) { |
| 1118 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 1119 | } | ||
| 1120 | |||
| 1121 | 648 | const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(info.k); | |
| 1122 | |||
| 1123 | 432 | const auto rhs_offset = method.fn_pack_rhs_nxk_get_rhs_offset(rect.start_row(), ref_rhs_row_stride); | |
| 1124 | 648 | const auto ref_rhs_offset = method.rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.k); | |
| 1125 |
3/14✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
|
216 | ASSERT_EQ(rhs_offset, ref_rhs_offset); |
| 1126 | |||
| 1127 | 864 | const auto packed_rhs_size = method.fn_pack_rhs_nxk_get_packed_rhs_size(info.n, info.k); | |
| 1128 | 864 | const auto ref_packed_rhs_size = method.packed_rhs_format.default_size_in_bytes(info.n, info.k); | |
| 1129 |
3/14✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
|
216 | ASSERT_EQ(packed_rhs_size, ref_packed_rhs_size); |
| 1130 | |||
| 1131 | 648 | const auto packed_rhs_offset = method.fn_pack_rhs_nxk_get_packed_rhs_offset(rect.start_row(), info.k); | |
| 1132 | 432 | const auto ref_packed_rhs_offset = | |
| 1133 | 432 | method.packed_rhs_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), info.k); | |
| 1134 |
3/14✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
|
216 | ASSERT_EQ(packed_rhs_offset, ref_packed_rhs_offset); |
| 1135 | |||
| 1136 | 432 | const auto ref_rhs_scales_offset = | |
| 1137 | 432 | rect.start_row() * data_type_size_in_bits(method.packed_rhs_format.scale_data_type()) / 8; | |
| 1138 | |||
| 1139 | 432 | const auto bias_offset = method.fn_get_bias_offset(rect.start_row()); | |
| 1140 | 432 | const auto ref_bias_offset = | |
| 1141 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 216 times.
|
216 | !null_bias_mode ? method.bias_format.default_offset_in_bytes(0, rect.start_row(), info.n) : bias_offset; |
| 1142 |
3/14✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
|
216 | ASSERT_EQ(bias_offset, ref_bias_offset); |
| 1143 | |||
| 1144 | 216 | Buffer packed_rhs(packed_rhs_size, 0); | |
| 1145 | |||
| 1146 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | abi_check( |
| 1147 |
4/8✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
|
432 | &MatMulMethod::pack_rhs_nxk, method, rect.height(), rect.width(), data.rhs_t.data() + rhs_offset, |
| 1148 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 216 times.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
|
216 | ref_rhs_row_stride, null_bias_mode ? nullptr : data.bias.data() + bias_offset, |
| 1149 |
2/6✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 216 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
216 | data.rhs_scales.data() != nullptr ? data.rhs_scales.data() + ref_rhs_scales_offset : nullptr, |
| 1150 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | packed_rhs.data() + packed_rhs_offset); |
| 1151 | |||
| 1152 |
2/4✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
|
432 | const auto exact = method.packed_rhs_format.pack_format() != DataFormat::PackFormat::QUANTIZE_PER_ROW; |
| 1153 |
1/2✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
|
216 | DefaultMismatchHandler handler(0, exact ? 0 : 0.0001, 0, exact ? 0 : 0.001); |
| 1154 | 432 | const auto success = | |
| 1155 |
5/10✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 216 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 216 times.
✗ Branch 9 not taken.
|
216 | compare(packed_rhs.data(), data.ref_packed_rhs.data(), method.packed_rhs_format, info.n, info.k, rect, handler); |
| 1156 |
4/16✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 216 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 216 times.
|
216 | ASSERT_TRUE(success); |
| 1157 | 3234 | } | |
| 1158 | |||
| 1159 | /// Tests the output. | ||
| 1160 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
8064 | TEST_P(MatMulTest, Output) { |
| 1161 | 63066 | const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam(); | |
| 1162 | |||
| 1163 |
3/4✓ Branch 0 taken 3234 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✓ Branch 3 taken 1050 times.
|
3234 | if (method.fn_is_supported && !method.fn_is_supported()) { |
| 1164 |
3/6✓ Branch 0 taken 1050 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1050 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1050 times.
✗ Branch 5 not taken.
|
1050 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 1165 | } | ||
| 1166 | |||
| 1167 |
1/2✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
|
2184 | if (!method.has_main_kernel()) { |
| 1168 | ✗ | GTEST_SKIP() << "No main kernel available"; | |
| 1169 | } | ||
| 1170 | |||
| 1171 | 2184 | const auto& data = test_data(); | |
| 1172 | 4368 | const auto m_step = method.fn_get_main_m_step(); | |
| 1173 |
4/16✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2184 times.
|
4368 | ASSERT_EQ(m_step, method.m0); |
| 1174 | |||
| 1175 | 4368 | const auto n_step = method.fn_get_main_n_step(); | |
| 1176 |
4/16✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2184 times.
|
4368 | ASSERT_EQ(n_step, method.n0); |
| 1177 | |||
| 1178 | 10920 | const auto rect = portion.compute_portion(info.m, info.n, method.m0, method.n0); | |
| 1179 | |||
| 1180 |
2/4✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2184 times.
|
2184 | if (rect.height() == 0 || rect.width() == 0) { |
| 1181 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 1182 | } | ||
| 1183 | |||
| 1184 | 4368 | const auto lhs_w = info.k; | |
| 1185 | 4368 | const auto rhs_w = info.n; | |
| 1186 | 4368 | const auto bias_w = info.n; | |
| 1187 | 4368 | const auto dst_w = info.n; | |
| 1188 | |||
| 1189 | 4368 | const bool null_bias_mode = bias_mode == BiasMode::INTERNAL; | |
| 1190 | |||
| 1191 | 2184 | const auto lhs_start_row = rect.start_row(); | |
| 1192 | 2184 | const auto lhs_start_col = 0; | |
| 1193 | 4368 | const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w); | |
| 1194 | |||
| 1195 | 2184 | const std::byte* lhs_data = nullptr; | |
| 1196 | 2184 | uintptr_t lhs_offset = 0; | |
| 1197 | |||
| 1198 |
2/2✓ Branch 0 taken 216 times.
✓ Branch 1 taken 1968 times.
|
2184 | if (method.is_pack_lhs_needed()) { |
| 1199 | 216 | lhs_data = data.ref_packed_lhs.data(); | |
| 1200 | |||
| 1201 | 432 | const auto ref_packed_lhs_offset = | |
| 1202 | 432 | method.packed_lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, info.k); | |
| 1203 | 432 | lhs_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); | |
| 1204 |
3/14✓ Branch 0 taken 216 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 216 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 216 times.
|
216 | ASSERT_EQ(lhs_offset, ref_packed_lhs_offset); |
| 1205 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 216 times.
|
216 | } else { |
| 1206 | 1968 | lhs_data = data.lhs.data(); | |
| 1207 | |||
| 1208 | 1968 | lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride); | |
| 1209 | 3936 | const auto ref_lhs_offset = method.lhs_format.default_offset_in_bytes(lhs_start_row, lhs_start_col, lhs_w); | |
| 1210 |
3/14✓ Branch 0 taken 1968 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1968 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1968 times.
|
1968 | ASSERT_EQ(lhs_offset, ref_lhs_offset); |
| 1211 | 1968 | } | |
| 1212 | |||
| 1213 | 4368 | const auto rhs_stride = method.rhs_format.default_row_stride(rhs_w); | |
| 1214 | |||
| 1215 | 2184 | const std::byte* rhs_data = nullptr; | |
| 1216 | 2184 | uintptr_t rhs_offset = 0; | |
| 1217 | |||
| 1218 |
1/2✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
|
2184 | if (method.is_pack_rhs_needed()) { |
| 1219 | 2184 | const auto packed_rhs_start_row = rect.start_col(); | |
| 1220 | 2184 | const auto packed_rhs_start_col = 0; | |
| 1221 | |||
| 1222 | 2184 | rhs_data = data.ref_packed_rhs.data(); | |
| 1223 | |||
| 1224 | 4368 | rhs_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k); | |
| 1225 | 4368 | const auto ref_rhs_offset = | |
| 1226 | 4368 | method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k); | |
| 1227 |
3/14✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
|
2184 | ASSERT_EQ(rhs_offset, ref_rhs_offset); |
| 1228 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2184 times.
|
2184 | } else { |
| 1229 | ✗ | const auto rhs_start_row = 0; | |
| 1230 | ✗ | const auto rhs_start_col = rect.start_col(); | |
| 1231 | |||
| 1232 | ✗ | rhs_data = data.rhs.data(); | |
| 1233 | ✗ | rhs_offset = method.rhs_format.default_offset_in_bytes(rhs_start_row, rhs_start_col, rhs_w); | |
| 1234 | ✗ | } | |
| 1235 | |||
| 1236 | 2184 | const std::byte* bias_data = nullptr; | |
| 1237 |
2/2✓ Branch 0 taken 432 times.
✓ Branch 1 taken 1752 times.
|
2184 | if (!null_bias_mode) { |
| 1238 | 3504 | const auto bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_row(), bias_w); | |
| 1239 | 1752 | bias_data = data.bias.data() + bias_offset; | |
| 1240 | 1752 | } | |
| 1241 | |||
| 1242 | 4368 | const auto dst_stride = method.dst_format.default_row_stride(dst_w); | |
| 1243 | 4368 | const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); | |
| 1244 | 4368 | const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w); | |
| 1245 |
3/14✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
|
2184 | ASSERT_EQ(dst_offset, ref_dst_offset); |
| 1246 | |||
| 1247 | 8736 | const auto dst_size = method.fn_get_dst_size(info.m, info.n); | |
| 1248 | 8736 | const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n); | |
| 1249 |
3/14✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2184 times.
|
2184 | ASSERT_EQ(dst_size, ref_dst_size); |
| 1250 | |||
| 1251 | 2184 | Buffer dst(dst_size, 0); | |
| 1252 | |||
| 1253 |
1/2✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
|
2184 | abi_check( |
| 1254 |
3/6✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
|
4368 | &MatMulMethod::main_kernel, method, rect.height(), rect.width(), info.k, lhs_data + lhs_offset, |
| 1255 |
1/2✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
|
2184 | rhs_data + rhs_offset, bias_data, dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, data.clamp_min, |
| 1256 | 2184 | data.clamp_max); | |
| 1257 | |||
| 1258 |
1/2✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
|
2184 | DefaultMismatchHandler handler(0, 0.1, 0, 0.1); |
| 1259 |
5/10✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2184 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2184 times.
✗ Branch 9 not taken.
|
2184 | const auto success = compare(dst.data(), data.ref_dst.data(), method.dst_format, info.m, info.n, rect, handler); |
| 1260 |
4/16✓ Branch 0 taken 2184 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2184 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2184 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2184 times.
|
2184 | ASSERT_TRUE(success); |
| 1261 | 3234 | } | |
| 1262 | |||
| 1263 |
1/2✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
|
3 | const std::vector<MatrixPortion> MatrixPortions = { |
| 1264 | {0, 0, 1, 1}, | ||
| 1265 | {0, 0, 0.25, 0.25}, | ||
| 1266 | {0.75, 0.75, 1, 1}, | ||
| 1267 | }; | ||
| 1268 |
1/2✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
|
3 | const std::vector<MatMulShape> MatMulShapes = { |
| 1269 | {1, 16, 16}, // | ||
| 1270 | {20, 1, 20}, // | ||
| 1271 | {6, 16, 32}, // | ||
| 1272 | {12, 32, 17}, // | ||
| 1273 | {13, 33, 23}, // | ||
| 1274 | {87, 93, 56}, // | ||
| 1275 | }; | ||
| 1276 | |||
| 1277 |
20/64✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✓ Branch 18 taken 8 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✓ Branch 20 taken 8 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 8 times.
✓ Branch 22 taken 1512 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 3024 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
|
4551 | INSTANTIATE_TEST_SUITE_P( |
| 1278 | MatMul, MatMulTest, | ||
| 1279 | testing::Combine( | ||
| 1280 | testing::ValuesIn(get_matmul_methods()), // | ||
| 1281 | testing::ValuesIn(MatMulShapes), // | ||
| 1282 | testing::ValuesIn(MatrixPortions), // | ||
| 1283 | testing::Values(BiasMode::PROVIDED), // | ||
| 1284 | testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio | ||
| 1285 | testing::PrintToStringParamName()); | ||
| 1286 | |||
| 1287 |
20/64✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✓ Branch 18 taken 8 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✓ Branch 20 taken 8 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 8 times.
✓ Branch 22 taken 1728 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 3456 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
|
5199 | INSTANTIATE_TEST_SUITE_P( |
| 1288 | NullBiasMatMul, MatMulTest, | ||
| 1289 | testing::Combine( | ||
| 1290 | testing::ValuesIn(get_nullbias_matmul_methods()), // | ||
| 1291 | testing::ValuesIn(MatMulShapes), // | ||
| 1292 | testing::ValuesIn(MatrixPortions), // | ||
| 1293 | testing::Values(BiasMode::INTERNAL, BiasMode::PROVIDED), // | ||
| 1294 | testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio | ||
| 1295 | testing::PrintToStringParamName()); | ||
| 1296 | |||
| 1297 |
28/96✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✓ Branch 18 taken 8 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✓ Branch 20 taken 8 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✓ Branch 22 taken 8 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 4 times.
✓ Branch 24 taken 8 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 4 times.
✓ Branch 26 taken 8 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✓ Branch 28 taken 8 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 8 times.
✓ Branch 30 taken 3120 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✓ Branch 32 taken 6240 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
|
9375 | INSTANTIATE_TEST_SUITE_P( |
| 1298 | VecMul, MatMulTest, | ||
| 1299 | testing::Combine( | ||
| 1300 | testing::ValuesIn(get_vecmul_methods()), | ||
| 1301 | testing::Values( | ||
| 1302 | MatMulShape{1, 16, 16}, // | ||
| 1303 | MatMulShape{1, 1, 20}, // | ||
| 1304 | MatMulShape{1, 16, 32}, // | ||
| 1305 | MatMulShape{1, 32, 17}, // | ||
| 1306 | MatMulShape{1, 33, 23}, // | ||
| 1307 | MatMulShape{1, 1500, 20}, // | ||
| 1308 | MatMulShape{1, 93, 56}, // | ||
| 1309 | MatMulShape{1, 1, 1}, // | ||
| 1310 | MatMulShape{1, 16, 1}, // | ||
| 1311 | MatMulShape{1, 32, 64}, // | ||
| 1312 | MatMulShape{1, 7, 74}, // | ||
| 1313 | MatMulShape{1, 800, 64}, // | ||
| 1314 | MatMulShape{1, 512, 130} // | ||
| 1315 | ), | ||
| 1316 | testing::Values( | ||
| 1317 | MatrixPortion(0, 0, 1, 1), // Full row. | ||
| 1318 | MatrixPortion(0, 0, 1, 0.5), // First half | ||
| 1319 | MatrixPortion(0, .4, 1, 0.3), // mid row-section. | ||
| 1320 | MatrixPortion(0, 0.75, 1, .25) // right row section | ||
| 1321 | ), | ||
| 1322 | testing::Values(BiasMode::PROVIDED), // | ||
| 1323 | testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio | ||
| 1324 | testing::PrintToStringParamName()); | ||
| 1325 | |||
| 1326 | } // namespace kai::test | ||
| 1327 |