test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <gtest/gtest.h> | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <functional> | ||
| 13 | #include <limits> | ||
| 14 | #include <map> | ||
| 15 | #include <string_view> | ||
| 16 | #include <tuple> | ||
| 17 | #include <utility> | ||
| 18 | |||
| 19 | #include "kai/kai_common.h" | ||
| 20 | #include "test/common/abi_checker.hpp" | ||
| 21 | #include "test/common/buffer.hpp" | ||
| 22 | #include "test/common/compare.hpp" | ||
| 23 | #include "test/common/cpu_info.hpp" | ||
| 24 | #include "test/common/data_format.hpp" | ||
| 25 | #include "test/common/data_type.hpp" | ||
| 26 | #include "test/common/matmul_test_common.hpp" | ||
| 27 | #include "test/common/matrix_portion.hpp" | ||
| 28 | #include "test/common/printer.hpp" | ||
| 29 | #include "test/common/seed.hpp" | ||
| 30 | #include "test/common/sme.hpp" | ||
| 31 | #include "test/reference/cast.hpp" | ||
| 32 | #include "test/reference/clamp.hpp" | ||
| 33 | #include "test/reference/fill.hpp" | ||
| 34 | #include "test/reference/matmul.hpp" | ||
| 35 | #include "test/reference/pack.hpp" | ||
| 36 | |||
| 37 | // matmul_clamp_f32_bf16p_bf16p | ||
| 38 | #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h" | ||
| 39 | #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" | ||
| 40 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.h" | ||
| 41 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.h" | ||
| 42 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.h" | ||
| 43 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.h" | ||
| 44 | #include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h" | ||
| 45 | |||
| 46 | // SME files here. | ||
| 47 | #include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" | ||
| 48 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h" | ||
| 49 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" | ||
| 50 | |||
| 51 | namespace kai::test { | ||
| 52 | |||
| 53 | /// List of supported matrix multiplication methods. | ||
| 54 | namespace { | ||
| 55 | |||
| 56 | 3 | static const std::array<MatMulMethod, 5>& get_gemm_methods() { | |
| 57 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
3 | static std::array<MatMulMethod, 5> gemm_methods{}; |
| 58 | gemm_methods[0].name = "matmul_nt_nt_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa"; | ||
| 59 | gemm_methods[0].m0 = 2 * get_sme_vector_length<float>(); | ||
| 60 | gemm_methods[0].n0 = 2 * get_sme_vector_length<float>(); | ||
| 61 | gemm_methods[0].k0 = 2; | ||
| 62 | gemm_methods[0].dst_format = DataFormat(DataType::FP32); | ||
| 63 | gemm_methods[0].lhs_format = DataFormat(DataType::FP32); | ||
| 64 | gemm_methods[0].packed_lhs_format = DataFormat( | ||
| 65 | DataType::BF16, 2 * get_sme_vector_length<float>(), 2, DataFormat::PackFormat::NONE, DataType::FP32, | ||
| 66 | DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 2); | ||
| 67 | gemm_methods[0].rhs_format = DataFormat(DataType::FP32); | ||
| 68 | gemm_methods[0].packed_rhs_format = DataFormat( | ||
| 69 | DataType::BF16, 2 * get_sme_vector_length<float>(), 2, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, | ||
| 70 | DataType::UNKNOWN, 2 * get_sme_vector_length<float>(), 2); | ||
| 71 | gemm_methods[0].bias_format = DataFormat(DataType::FP32); | ||
| 72 | gemm_methods[0].fn_is_supported = cpu_has_sme2; | ||
| 73 | gemm_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 74 | gemm_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 75 | gemm_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 76 | gemm_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 77 | gemm_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 78 | gemm_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; | ||
| 79 | gemm_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 80 | gemm_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme; | ||
| 81 | gemm_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme; | ||
| 82 | gemm_methods[0].fn_get_packed_lhs_offset = | ||
| 83 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 84 | gemm_methods[0].fn_pack_lhs = kai_run_lhs_pack_bf16p2vlx2_f32_sme; | ||
| 85 | gemm_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; | ||
| 86 | gemm_methods[0].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; | ||
| 87 | gemm_methods[0].fn_get_main_packed_rhs_offset = | ||
| 88 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 89 | gemm_methods[0].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; | ||
| 90 | gemm_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme; | ||
| 91 | gemm_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 92 | gemm_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 93 | gemm_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa; | ||
| 94 | |||
| 95 | gemm_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla"; | ||
| 96 | gemm_methods[1].m0 = 8; | ||
| 97 | gemm_methods[1].n0 = 12; | ||
| 98 | gemm_methods[1].k0 = 4; | ||
| 99 | gemm_methods[1].dst_format = DataFormat(DataType::FP32); | ||
| 100 | gemm_methods[1].lhs_format = DataFormat(DataType::FP32); | ||
| 101 | gemm_methods[1].packed_lhs_format = | ||
| 102 | DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); | ||
| 103 | gemm_methods[1].rhs_format = DataFormat(DataType::FP32); | ||
| 104 | gemm_methods[1].packed_rhs_format = DataFormat( | ||
| 105 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
| 106 | gemm_methods[1].bias_format = DataFormat(DataType::FP32); | ||
| 107 | gemm_methods[1].fn_is_supported = cpu_has_bf16; | ||
| 108 | gemm_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 109 | gemm_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 110 | gemm_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 111 | gemm_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 112 | gemm_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 113 | gemm_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 114 | gemm_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 115 | gemm_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; | ||
| 116 | gemm_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; | ||
| 117 | gemm_methods[1].fn_get_packed_lhs_offset = | ||
| 118 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 119 | gemm_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; | ||
| 120 | gemm_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 121 | gemm_methods[1].fn_get_packed_rhs_size_generic_block_size = | ||
| 122 | kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 123 | gemm_methods[1].fn_get_main_packed_rhs_offset = | ||
| 124 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 125 | gemm_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 126 | gemm_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 127 | gemm_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 128 | gemm_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 129 | gemm_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 130 | |||
| 131 | gemm_methods[2].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output"; | ||
| 132 | gemm_methods[2].m0 = 8; | ||
| 133 | gemm_methods[2].n0 = 12; | ||
| 134 | gemm_methods[2].k0 = 4; | ||
| 135 | gemm_methods[2].dst_format = DataFormat(DataType::FP32); | ||
| 136 | gemm_methods[2].lhs_format = DataFormat(DataType::FP16); | ||
| 137 | gemm_methods[2].packed_lhs_format = | ||
| 138 | DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); | ||
| 139 | gemm_methods[2].rhs_format = DataFormat(DataType::FP16); | ||
| 140 | gemm_methods[2].packed_rhs_format = DataFormat( | ||
| 141 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
| 142 | gemm_methods[2].bias_format = DataFormat(DataType::FP32); | ||
| 143 | gemm_methods[2].fn_is_supported = cpu_has_bf16; | ||
| 144 | gemm_methods[2].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 145 | gemm_methods[2].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 146 | gemm_methods[2].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 147 | gemm_methods[2].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 148 | gemm_methods[2].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 149 | gemm_methods[2].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
| 150 | gemm_methods[2].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 151 | gemm_methods[2].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; | ||
| 152 | gemm_methods[2].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; | ||
| 153 | gemm_methods[2].fn_get_packed_lhs_offset = | ||
| 154 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 155 | gemm_methods[2].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; | ||
| 156 | gemm_methods[2].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
| 157 | gemm_methods[2].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
| 158 | gemm_methods[2].fn_get_main_packed_rhs_offset = | ||
| 159 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 160 | gemm_methods[2].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
| 161 | gemm_methods[2].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
| 162 | gemm_methods[2].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 163 | gemm_methods[2].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 164 | gemm_methods[2].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 165 | |||
| 166 | gemm_methods[3].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_f16_inputs_f32_bias_and_output_opt_bias"; | ||
| 167 | gemm_methods[3].m0 = 8; | ||
| 168 | gemm_methods[3].n0 = 12; | ||
| 169 | gemm_methods[3].k0 = 4; | ||
| 170 | gemm_methods[3].dst_format = DataFormat(DataType::FP32); | ||
| 171 | gemm_methods[3].lhs_format = DataFormat(DataType::FP16); | ||
| 172 | gemm_methods[3].packed_lhs_format = | ||
| 173 | DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP16, DataType::UNKNOWN, 8, 4); | ||
| 174 | gemm_methods[3].rhs_format = DataFormat(DataType::FP16); | ||
| 175 | gemm_methods[3].packed_rhs_format = DataFormat( | ||
| 176 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
| 177 | gemm_methods[3].bias_format = DataFormat(DataType::UNKNOWN); | ||
| 178 | gemm_methods[3].fn_is_supported = cpu_has_bf16; | ||
| 179 | gemm_methods[3].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 180 | gemm_methods[3].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 181 | gemm_methods[3].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 182 | gemm_methods[3].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 183 | gemm_methods[3].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 184 | gemm_methods[3].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
| 185 | gemm_methods[3].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 186 | gemm_methods[3].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_bf16p8x4_f16_neon; | ||
| 187 | gemm_methods[3].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_bf16p8x4_f16_neon; | ||
| 188 | gemm_methods[3].fn_get_packed_lhs_offset = | ||
| 189 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 190 | gemm_methods[3].fn_pack_lhs = kai_run_lhs_pack_bf16p8x4_f16_neon; | ||
| 191 | gemm_methods[3].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
| 192 | gemm_methods[3].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
| 193 | gemm_methods[3].fn_get_main_packed_rhs_offset = | ||
| 194 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 195 | gemm_methods[3].fn_pack_rhs = kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
| 196 | gemm_methods[3].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon; | ||
| 197 | gemm_methods[3].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 198 | gemm_methods[3].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 199 | gemm_methods[3].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 200 | |||
| 201 | gemm_methods[4].name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias"; | ||
| 202 | gemm_methods[4].m0 = 8; | ||
| 203 | gemm_methods[4].n0 = 12; | ||
| 204 | gemm_methods[4].k0 = 4; | ||
| 205 | gemm_methods[4].dst_format = DataFormat(DataType::FP32); | ||
| 206 | gemm_methods[4].lhs_format = DataFormat(DataType::FP32); | ||
| 207 | gemm_methods[4].packed_lhs_format = | ||
| 208 | DataFormat(DataType::BF16, 8, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 8, 4); | ||
| 209 | gemm_methods[4].rhs_format = DataFormat(DataType::FP32); | ||
| 210 | gemm_methods[4].packed_rhs_format = DataFormat( | ||
| 211 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
| 212 | gemm_methods[4].bias_format = DataFormat(DataType::UNKNOWN); | ||
| 213 | gemm_methods[4].fn_is_supported = cpu_has_bf16; | ||
| 214 | gemm_methods[4].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 215 | gemm_methods[4].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 216 | gemm_methods[4].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 217 | gemm_methods[4].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 218 | gemm_methods[4].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 219 | gemm_methods[4].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 220 | gemm_methods[4].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 221 | gemm_methods[4].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon; | ||
| 222 | gemm_methods[4].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon; | ||
| 223 | gemm_methods[4].fn_get_packed_lhs_offset = | ||
| 224 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 225 | gemm_methods[4].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon; | ||
| 226 | gemm_methods[4].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 227 | gemm_methods[4].fn_get_packed_rhs_size_generic_block_size = | ||
| 228 | kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 229 | gemm_methods[4].fn_get_main_packed_rhs_offset = | ||
| 230 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 231 | gemm_methods[4].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 232 | gemm_methods[4].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 233 | gemm_methods[4].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 234 | gemm_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 235 | gemm_methods[4].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla; | ||
| 236 | |||
| 237 | return gemm_methods; | ||
| 238 | } | ||
| 239 | |||
| 240 | 3 | static const std::array<MatMulMethod, 2>& get_gemv_methods() { | |
| 241 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
3 | static std::array<MatMulMethod, 2> gemv_methods{}; |
| 242 | gemv_methods[0].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot"; | ||
| 243 | gemv_methods[0].m0 = 1; | ||
| 244 | gemv_methods[0].n0 = 12; | ||
| 245 | gemv_methods[0].k0 = 4; | ||
| 246 | gemv_methods[0].dst_format = DataFormat(DataType::FP32); | ||
| 247 | gemv_methods[0].lhs_format = DataFormat(DataType::FP32); | ||
| 248 | gemv_methods[0].packed_lhs_format = | ||
| 249 | DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); | ||
| 250 | gemv_methods[0].rhs_format = DataFormat(DataType::FP32); | ||
| 251 | gemv_methods[0].packed_rhs_format = DataFormat( | ||
| 252 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
| 253 | gemv_methods[0].bias_format = DataFormat(DataType::FP32); | ||
| 254 | gemv_methods[0].fn_is_supported = cpu_has_bf16; | ||
| 255 | gemv_methods[0].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 256 | gemv_methods[0].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 257 | gemv_methods[0].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 258 | gemv_methods[0].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 259 | gemv_methods[0].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 260 | gemv_methods[0].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 261 | gemv_methods[0].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 262 | gemv_methods[0].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; | ||
| 263 | gemv_methods[0].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; | ||
| 264 | gemv_methods[0].fn_get_packed_lhs_offset = | ||
| 265 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 266 | gemv_methods[0].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; | ||
| 267 | gemv_methods[0].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 268 | gemv_methods[0].fn_get_packed_rhs_size_generic_block_size = | ||
| 269 | kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 270 | gemv_methods[0].fn_get_main_packed_rhs_offset = | ||
| 271 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 272 | gemv_methods[0].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 273 | gemv_methods[0].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 274 | gemv_methods[0].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 275 | gemv_methods[0].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 276 | gemv_methods[0].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 277 | |||
| 278 | gemv_methods[1].name = "matmul_nt_nt_f32_bf16p_bf16p_1x36_neon_dot_opt_bias"; | ||
| 279 | gemv_methods[1].m0 = 1; | ||
| 280 | gemv_methods[1].n0 = 12; | ||
| 281 | gemv_methods[1].k0 = 4; | ||
| 282 | gemv_methods[1].dst_format = DataFormat(DataType::FP32); | ||
| 283 | gemv_methods[1].lhs_format = DataFormat(DataType::FP32); | ||
| 284 | gemv_methods[1].packed_lhs_format = | ||
| 285 | DataFormat(DataType::BF16, 1, 4, DataFormat::PackFormat::NONE, DataType::FP32, DataType::UNKNOWN, 1, 4); | ||
| 286 | gemv_methods[1].rhs_format = DataFormat(DataType::FP32); | ||
| 287 | gemv_methods[1].packed_rhs_format = DataFormat( | ||
| 288 | DataType::BF16, 12, 4, DataFormat::PackFormat::BIAS_PER_ROW, DataType::FP32, DataType::UNKNOWN, 12, 4); | ||
| 289 | gemv_methods[1].bias_format = DataFormat(DataType::UNKNOWN); | ||
| 290 | gemv_methods[1].fn_is_supported = cpu_has_bf16; | ||
| 291 | gemv_methods[1].fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 292 | gemv_methods[1].fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 293 | gemv_methods[1].fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 294 | gemv_methods[1].fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 295 | gemv_methods[1].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 296 | gemv_methods[1].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 297 | gemv_methods[1].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 298 | gemv_methods[1].fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon; | ||
| 299 | gemv_methods[1].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon; | ||
| 300 | gemv_methods[1].fn_get_packed_lhs_offset = | ||
| 301 | kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 302 | gemv_methods[1].fn_pack_lhs = kai_run_lhs_quant_pack_bf16p1x4_f32_neon; | ||
| 303 | gemv_methods[1].fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 304 | gemv_methods[1].fn_get_packed_rhs_size_generic_block_size = | ||
| 305 | kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 306 | gemv_methods[1].fn_get_main_packed_rhs_offset = | ||
| 307 | kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 308 | gemv_methods[1].fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 309 | gemv_methods[1].fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon; | ||
| 310 | gemv_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 311 | gemv_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 312 | gemv_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; | ||
| 313 | |||
| 314 | return gemv_methods; | ||
| 315 | } | ||
| 316 | |||
| 317 | } // namespace | ||
| 318 | |||
| 319 | /// Matrix multiplication test fixture. | ||
| 320 | class MatMulTestBf16 : public testing::TestWithParam<MatMulClampTestParams> { | ||
| 321 | private: | ||
| 322 | /// Unique ID: m, n, k | ||
| 323 | using TestDataId = std::tuple<size_t, size_t, size_t, float, std::string_view>; | ||
| 324 | |||
| 325 | protected: | ||
| 326 | /// Cached test data that is shared between multiple test case. | ||
| 327 | 414 | struct TestData { | |
| 328 | 828 | Buffer lhs{}; ///< LHS operand. | |
| 329 | 828 | Buffer ref_packed_lhs{}; ///< Reference packed LHS. | |
| 330 | 828 | Buffer rhs{}; ///< RHS operand. | |
| 331 | 828 | Buffer rhs_scales{}; ///< RHS per-row quantization scales. | |
| 332 | 828 | Buffer bias{}; ///< Bias. | |
| 333 | 828 | Buffer ref_packed_rhs{}; ///< Reference packed RHS. | |
| 334 | 828 | Buffer ref_dst{}; ///< Reference output. | |
| 335 | 828 | Buffer ref_clamped{}; ///< Reference clamped. | |
| 336 | Range<float> clamp_range; ///< Clamping Range. | ||
| 337 | }; | ||
| 338 | |||
| 339 | /// Gets the test data for the current test case. | ||
| 340 | 1926 | static const TestData& test_data() { | |
| 341 | 30960 | const auto& [method, info, portion, bias_mode, clamp_keep_ratio] = GetParam(); | |
| 342 | 9630 | const TestDataId data_id{info.m, info.n, info.k, clamp_keep_ratio, method.name}; | |
| 343 | |||
| 344 | // Creates a unique seed for the test data. | ||
| 345 |
12/24✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1926 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1926 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1926 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1926 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1926 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1926 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1926 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 888 times.
✗ Branch 23 not taken.
|
7704 | const auto key = std::string(method.name) + "_" + std::to_string(info.m) + "x" + std::to_string(info.n) + "x" + |
| 346 |
8/14✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1038 times.
✓ Branch 7 taken 888 times.
✓ Branch 8 taken 1926 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1926 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1926 times.
✗ Branch 13 not taken.
|
7704 | std::to_string(info.k) + "_" + (bias_mode == BiasMode::INTERNAL ? "internal" : "provided") + ":" + |
| 347 |
2/4✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
|
3852 | std::to_string(clamp_keep_ratio); |
| 348 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | auto& feed = seed_stream(key); |
| 349 | |||
| 350 | // If the test data is already available, returns it. | ||
| 351 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | const auto data_it = _data.find(data_id); |
| 352 | |||
| 353 |
4/4✓ Branch 0 taken 1734 times.
✓ Branch 1 taken 192 times.
✓ Branch 2 taken 816 times.
✓ Branch 3 taken 222 times.
|
1926 | if (data_it != _data.end()) { |
| 354 |
1/2✓ Branch 0 taken 816 times.
✗ Branch 1 not taken.
|
1512 | return data_it->second; |
| 355 | } | ||
| 356 | |||
| 357 | // Generates the test data. | ||
| 358 |
2/4✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
|
828 | const auto has_lhs_pack = method.packed_lhs_format.data_type() != DataType::UNKNOWN; |
| 359 |
2/4✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
|
828 | const auto has_rhs_pack = method.packed_rhs_format.data_type() != DataType::UNKNOWN; |
| 360 |
2/4✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
|
828 | const auto has_bias = method.bias_format.data_type() != DataType::UNKNOWN; |
| 361 | |||
| 362 | 828 | const auto lhs_h = info.m; | |
| 363 | 828 | const auto lhs_w = info.k; | |
| 364 |
3/6✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
|
828 | auto lhs = fill_matrix_random(lhs_h, lhs_w, method.lhs_format, feed()); |
| 365 | 414 | Buffer ref_packed_lhs; | |
| 366 | |||
| 367 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 414 times.
|
414 | if (has_lhs_pack) { |
| 368 | 414 | ref_packed_lhs = | |
| 369 |
3/6✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
|
828 | pack(method.packed_lhs_format, lhs.data(), nullptr, nullptr, method.lhs_format, lhs_h, lhs_w); |
| 370 | 414 | } | |
| 371 | |||
| 372 | 828 | const auto rhs_h = info.k; | |
| 373 | 828 | const auto rhs_w = info.n; | |
| 374 |
3/6✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
|
828 | auto rhs = fill_matrix_random(rhs_h, rhs_w, method.rhs_format, feed()); |
| 375 | |||
| 376 | 414 | Buffer rhs_scales; | |
| 377 |
3/8✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 414 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
414 | if (data_type_is_quantized(method.rhs_format.data_type()) && |
| 378 | ✗ | method.rhs_format.pack_format() == DataFormat::PackFormat::NONE) { | |
| 379 | ✗ | rhs_scales = fill_matrix_random(rhs_h, 1, DataFormat(DataType::FP32), feed()); | |
| 380 | ✗ | } | |
| 381 | |||
| 382 | 414 | const auto bias_h = 1; | |
| 383 | 828 | const auto bias_w = info.n; | |
| 384 | 414 | Buffer bias; | |
| 385 | |||
| 386 |
2/2✓ Branch 0 taken 192 times.
✓ Branch 1 taken 222 times.
|
414 | if (has_bias) { |
| 387 |
3/6✓ Branch 0 taken 222 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 222 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 222 times.
✗ Branch 5 not taken.
|
444 | bias = fill_matrix_random(bias_h, bias_w, method.bias_format, feed()); |
| 388 | 222 | } | |
| 389 | |||
| 390 | 414 | constexpr size_t nr = 12; | |
| 391 | 414 | constexpr size_t kr = 4; | |
| 392 | |||
| 393 | 414 | size_t packed_rhs_size = 0; | |
| 394 | |||
| 395 |
2/2✓ Branch 0 taken 150 times.
✓ Branch 1 taken 264 times.
|
414 | if (method.fn_get_packed_rhs_size) { |
| 396 |
1/2✓ Branch 0 taken 150 times.
✗ Branch 1 not taken.
|
150 | packed_rhs_size = method.fn_get_packed_rhs_size(rhs_w, rhs_h); |
| 397 |
1/2✓ Branch 0 taken 264 times.
✗ Branch 1 not taken.
|
414 | } else if (method.fn_get_packed_rhs_size_generic_block_size) { |
| 398 |
1/2✓ Branch 0 taken 264 times.
✗ Branch 1 not taken.
|
264 | packed_rhs_size = method.fn_get_packed_rhs_size_generic_block_size(rhs_w, rhs_h, nr, kr); |
| 399 | 264 | } else { | |
| 400 | − | KAI_ERROR("No function to calculate Packed Rhs Matrix Size"); | |
| 401 | } | ||
| 402 | |||
| 403 |
1/2✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
|
414 | Buffer packed_rhs(packed_rhs_size); |
| 404 | |||
| 405 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 414 times.
|
414 | if (has_rhs_pack) { |
| 406 |
2/4✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
|
828 | const auto ref_rhs_row_stride = method.rhs_format.default_row_stride(rhs_w); |
| 407 |
1/2✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
|
414 | method.pack_rhs( |
| 408 |
5/8✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 222 times.
✓ Branch 5 taken 192 times.
✓ Branch 6 taken 222 times.
✗ Branch 7 not taken.
|
828 | info.n, info.k, rhs.data(), ref_rhs_row_stride, has_bias ? bias.data() : nullptr, nullptr, |
| 409 |
1/2✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
|
414 | packed_rhs.data()); |
| 410 | 414 | } | |
| 411 | |||
| 412 | − | KAI_ASSUME_ALWAYS(method.lhs_format.is_raw()); | |
| 413 | − | KAI_ASSUME_ALWAYS(method.rhs_format.is_raw()); | |
| 414 | − | KAI_ASSUME_ALWAYS(method.dst_format.is_raw()); | |
| 415 | |||
| 416 | 414 | Buffer tmp_lhs; | |
| 417 | 414 | Buffer tmp_rhs; | |
| 418 |
1/2✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
|
414 | const void* p_lhs_buff = lhs.data(); |
| 419 |
1/2✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
|
414 | const void* p_rhs_buff = rhs.data(); |
| 420 | |||
| 421 |
5/8✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 120 times.
✓ Branch 3 taken 294 times.
✓ Branch 4 taken 120 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 120 times.
✗ Branch 7 not taken.
|
414 | if (method.lhs_format.data_type() == DataType::FP32 || method.lhs_format.data_type() == DataType::FP16) { |
| 422 |
3/6✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
|
828 | tmp_lhs = cast(p_lhs_buff, method.lhs_format.data_type(), DataType::BF16, lhs_h, lhs_w); |
| 423 |
1/2✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
|
414 | p_lhs_buff = tmp_lhs.data(); |
| 424 | 414 | } | |
| 425 |
5/8✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 120 times.
✓ Branch 3 taken 294 times.
✓ Branch 4 taken 120 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 120 times.
✗ Branch 7 not taken.
|
414 | if (method.rhs_format.data_type() == DataType::FP32 || method.rhs_format.data_type() == DataType::FP16) { |
| 426 |
3/6✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
|
828 | tmp_rhs = cast(p_rhs_buff, method.rhs_format.data_type(), DataType::BF16, rhs_h, rhs_w); |
| 427 |
1/2✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
|
414 | p_rhs_buff = tmp_rhs.data(); |
| 428 | 414 | } | |
| 429 | |||
| 430 | 414 | auto ref_dst = | |
| 431 |
1/2✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
|
414 | matmul_nt_nt_quantized<BFloat16<>, float, float, BFloat16<>, float, float, float, float, float, float>( |
| 432 | 1656 | info.m, info.n, info.k, p_lhs_buff, nullptr, nullptr, 1, info.k, p_rhs_buff, nullptr, nullptr, 1, | |
| 433 |
1/2✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
|
414 | info.k, bias.data(), nullptr, nullptr, info.k); |
| 434 | |||
| 435 | 1656 | const auto [min, max] = | |
| 436 |
5/10✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 414 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 414 times.
✗ Branch 9 not taken.
|
414 | find_clamp_range(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_keep_ratio); |
| 437 | |||
| 438 |
5/10✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 414 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 414 times.
✗ Branch 9 not taken.
|
414 | auto ref_clamped = clamp(DataType::FP32, ref_dst.data(), info.m * info.n, min, max); |
| 439 | |||
| 440 |
9/18✓ Branch 0 taken 414 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 414 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 414 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 414 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 414 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 414 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 414 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 414 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 414 times.
✗ Branch 17 not taken.
|
3726 | auto& data = _data[data_id] = {}; |
| 441 | 414 | data.lhs = std::move(lhs); | |
| 442 | 414 | data.ref_packed_lhs = std::move(ref_packed_lhs); | |
| 443 | 414 | data.rhs = std::move(rhs); | |
| 444 | 414 | data.rhs_scales = std::move(rhs_scales); | |
| 445 | 414 | data.bias = std::move(bias); | |
| 446 | 414 | data.ref_packed_rhs = std::move(packed_rhs); | |
| 447 | 414 | data.ref_dst = std::move(ref_dst); | |
| 448 | 414 | data.ref_clamped = std::move(ref_clamped); | |
| 449 | 1242 | data.clamp_range = {min, max}; | |
| 450 | |||
| 451 | 414 | return data; | |
| 452 | 1926 | } | |
| 453 | |||
| 454 | private: | ||
| 455 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
| 456 | static std::map<TestDataId, TestData> _data; | ||
| 457 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
| 458 | }; | ||
| 459 | |||
| 460 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
| 461 | 3 | std::map<MatMulTestBf16::TestDataId, MatMulTestBf16::TestData> MatMulTestBf16::_data; | |
| 462 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
| 463 | |||
| 464 | /// Tests the output. | ||
| 465 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
5196 | TEST_P(MatMulTestBf16, Output) { |
| 466 | 66672 | const auto& [method, info, portion, clamp_keep_ratio, bias_mode] = GetParam(); | |
| 467 | |||
| 468 |
3/4✓ Branch 0 taken 2076 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✓ Branch 3 taken 150 times.
|
2076 | if (method.fn_is_supported && !method.fn_is_supported()) { |
| 469 |
3/6✓ Branch 0 taken 150 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 150 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 150 times.
✗ Branch 5 not taken.
|
150 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 470 | } | ||
| 471 | |||
| 472 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | if (!method.has_main_kernel()) { |
| 473 | ✗ | GTEST_SKIP() << "No main kernel available"; | |
| 474 | } | ||
| 475 | |||
| 476 | 1926 | const auto& data = test_data(); | |
| 477 | 3852 | const auto m_step = method.fn_get_main_m_step(); | |
| 478 |
4/16✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
|
3852 | ASSERT_TRUE(m_step % method.m0 == 0); |
| 479 | |||
| 480 | 3852 | const auto n_step = method.fn_get_main_n_step(); | |
| 481 |
4/16✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
|
3852 | ASSERT_TRUE(n_step % method.n0 == 0); |
| 482 | |||
| 483 | 5778 | const auto rect = portion.compute_portion(info.m, info.n, m_step, n_step); | |
| 484 | |||
| 485 |
2/4✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1926 times.
|
1926 | if (rect.height() == 0 || rect.width() == 0) { |
| 486 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 487 | } | ||
| 488 | |||
| 489 | 3852 | const size_t lhs_w = info.k; | |
| 490 | 1926 | const size_t rhs_w = rect.width(); | |
| 491 | 3852 | const size_t bias_w = info.n; | |
| 492 | 3852 | const size_t dst_w = info.n; | |
| 493 | 1926 | const bool has_bias = (data.bias.size() > 0); | |
| 494 | |||
| 495 | 1926 | const auto lhs_start_row = rect.start_row(); | |
| 496 | 3852 | const auto lhs_stride = method.lhs_format.default_row_stride(lhs_w); | |
| 497 | |||
| 498 | 11556 | const size_t lhs_packed_size = method.fn_get_packed_lhs_size(info.m, info.k, method.m0, method.k0, 1 /* sr */); | |
| 499 | 1926 | Buffer lhs_data(lhs_packed_size); | |
| 500 | |||
| 501 |
2/4✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
|
3852 | uintptr_t lhs_offset = method.fn_get_lhs_offset(lhs_start_row, lhs_stride); |
| 502 |
3/6✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
|
5778 | uintptr_t lhs_packed_offset = method.fn_get_packed_lhs_offset(lhs_start_row, info.k); |
| 503 | |||
| 504 | KAI_UNUSED(lhs_offset); | ||
| 505 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | abi_check( |
| 506 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | method.fn_pack_lhs, rect.height(), info.k, method.m0, method.k0, 1 /* sr */, 0 /* m_idx_start */, |
| 507 |
2/4✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
|
1926 | data.lhs.data() + lhs_offset, lhs_stride, lhs_data.data() + lhs_packed_offset); |
| 508 | |||
| 509 |
3/6✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
|
5778 | const auto rhs_stride = method.rhs_format.default_row_stride(info.n); |
| 510 | |||
| 511 | 1926 | size_t rhs_packed_size = 0; | |
| 512 | |||
| 513 |
2/2✓ Branch 0 taken 1176 times.
✓ Branch 1 taken 750 times.
|
1926 | if (method.fn_get_packed_rhs_size_generic_block_size) { |
| 514 |
5/10✓ Branch 0 taken 1176 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1176 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1176 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1176 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1176 times.
✗ Branch 9 not taken.
|
5880 | rhs_packed_size = method.fn_get_packed_rhs_size_generic_block_size(info.n, info.k, method.n0, method.k0); |
| 515 |
1/2✓ Branch 0 taken 750 times.
✗ Branch 1 not taken.
|
1926 | } else if (method.fn_get_packed_rhs_size) { |
| 516 |
3/6✓ Branch 0 taken 750 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 750 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 750 times.
✗ Branch 5 not taken.
|
2250 | rhs_packed_size = method.fn_get_packed_rhs_size(info.n, info.k); |
| 517 | 750 | } | |
| 518 | |||
| 519 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | Buffer rhs_data(rhs_packed_size); |
| 520 | |||
| 521 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | const auto packed_rhs_start_row = rect.start_col(); |
| 522 | 1926 | const auto packed_rhs_start_col = 0; | |
| 523 | |||
| 524 |
3/6✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
|
3852 | uintptr_t rhs_offset = method.fn_get_rhs_offset(rect.start_col()); |
| 525 |
3/6✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
|
5778 | uintptr_t rhs_packed_offset = method.fn_get_main_packed_rhs_offset(packed_rhs_start_row, info.k); |
| 526 | 1926 | const auto ref_rhs_packed_offset = | |
| 527 |
2/4✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
|
3852 | method.packed_rhs_format.default_offset_in_bytes(packed_rhs_start_row, packed_rhs_start_col, info.k); |
| 528 | |||
| 529 |
4/16✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
|
1926 | ASSERT_EQ(rhs_packed_offset, ref_rhs_packed_offset); |
| 530 | |||
| 531 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | uintptr_t bias_offset = sizeof(float) * rect.start_col(); |
| 532 | |||
| 533 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | abi_check( |
| 534 | 1926 | method.fn_pack_rhs, | |
| 535 | 1926 | 1, // num_groups | |
| 536 | 5778 | rhs_w, info.k, method.n0, method.k0, | |
| 537 | 1926 | 1, // sr | |
| 538 |
4/6✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1038 times.
✓ Branch 3 taken 888 times.
✓ Branch 4 taken 1038 times.
✗ Branch 5 not taken.
|
1926 | rhs_stride, data.rhs.data() + rhs_offset, has_bias ? data.bias.data() + bias_offset : nullptr, |
| 539 | 1926 | nullptr, // Scale | |
| 540 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | rhs_data.data() + rhs_packed_offset, 0, nullptr); |
| 541 | |||
| 542 |
2/2✓ Branch 0 taken 888 times.
✓ Branch 1 taken 1038 times.
|
1926 | if (has_bias) { |
| 543 |
3/6✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1038 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1038 times.
✗ Branch 5 not taken.
|
2076 | const auto ref_bias_offset = method.bias_format.default_offset_in_bytes(0, rect.start_col(), bias_w); |
| 544 |
4/16✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1038 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1038 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1038 times.
|
1038 | ASSERT_EQ(ref_bias_offset, bias_offset); |
| 545 | 1038 | } | |
| 546 | |||
| 547 |
2/4✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
|
3852 | const auto dst_stride = method.dst_format.default_row_stride(dst_w); |
| 548 |
4/8✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
|
3852 | const auto dst_offset = method.fn_get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
| 549 |
4/8✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
|
3852 | const auto ref_dst_offset = method.dst_format.default_offset_in_bytes(rect.start_row(), rect.start_col(), dst_w); |
| 550 |
4/16✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
|
1926 | ASSERT_EQ(dst_offset, ref_dst_offset); |
| 551 | |||
| 552 |
4/8✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
|
7704 | const auto dst_size = method.fn_get_dst_size(info.m, info.n); |
| 553 |
4/8✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
|
7704 | const auto ref_dst_size = method.dst_format.default_size_in_bytes(info.m, info.n); |
| 554 |
4/16✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
|
1926 | ASSERT_EQ(dst_size, ref_dst_size); |
| 555 | |||
| 556 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | Buffer dst(dst_size); |
| 557 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | abi_check( |
| 558 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1926 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1926 times.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1926 times.
✗ Branch 9 not taken.
|
3852 | &MatMulMethod::main_kernel, method, rect.height(), rect.width(), info.k, lhs_data.data() + lhs_packed_offset, |
| 559 |
2/4✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
|
1926 | rhs_data.data() + rhs_packed_offset, nullptr, dst.data() + dst_offset, lhs_stride, rhs_stride, dst_stride, |
| 560 | 1926 | data.clamp_range.min, data.clamp_range.max); | |
| 561 | |||
| 562 |
1/2✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
|
1926 | DefaultMismatchHandler handler(0, 0.02, 0, 0.05); |
| 563 |
5/10✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1926 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1926 times.
✗ Branch 9 not taken.
|
1926 | const auto success = compare(dst.data(), data.ref_clamped.data(), method.dst_format, info.m, info.n, rect, handler); |
| 564 | |||
| 565 |
4/16✓ Branch 0 taken 1926 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1926 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1926 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1926 times.
|
1926 | ASSERT_TRUE(success); |
| 566 | 2076 | } | |
| 567 | |||
| 568 |
30/104✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✓ Branch 28 taken 2 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 time.
✓ Branch 30 taken 2 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 2 times.
✓ Branch 32 taken 750 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✓ Branch 34 taken 1500 times.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
|
2256 | INSTANTIATE_TEST_SUITE_P( |
| 569 | MatMulGemm, MatMulTestBf16, | ||
| 570 | testing::Combine( | ||
| 571 | testing::ValuesIn(get_gemm_methods()), | ||
| 572 | testing::Values( | ||
| 573 | MatMulShape{1, 1, 1}, // Smallest Possible Shape | ||
| 574 | MatMulShape{3, 7, 3}, // Smaller than block size | ||
| 575 | MatMulShape{12, 8, 4}, // Same block size | ||
| 576 | MatMulShape{1, 1, 1023}, // Long K | ||
| 577 | MatMulShape{1013, 1, 5}, // Long M | ||
| 578 | MatMulShape{2, 1013, 6}, // Long N | ||
| 579 | MatMulShape{13, 33, 23}, // | ||
| 580 | MatMulShape{93, 57, 89}, // | ||
| 581 | MatMulShape{256, 256, 256}, // Nice shapes | ||
| 582 | MatMulShape{257, 113, 373} // Prime numbers | ||
| 583 | ), | ||
| 584 | testing::Values( | ||
| 585 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
| 586 | MatrixPortion(0, 0, 0.25, 0.25), // Top-left corner. | ||
| 587 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
| 588 | MatrixPortion(0.75, 0, 1, 1), // Partial rows | ||
| 589 | MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle | ||
| 590 | ), | ||
| 591 | testing::Values(BiasMode::PROVIDED), // | ||
| 592 | testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio | ||
| 593 | testing::PrintToStringParamName()); | ||
| 594 | |||
| 595 |
28/96✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✓ Branch 28 taken 2 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
✓ Branch 30 taken 288 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✓ Branch 32 taken 576 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
|
870 | INSTANTIATE_TEST_SUITE_P( |
| 596 | MatMulGemv, MatMulTestBf16, | ||
| 597 | testing::Combine( | ||
| 598 | testing::ValuesIn(get_gemv_methods()), | ||
| 599 | testing::Values( | ||
| 600 | MatMulShape{1, 1, 1}, // Smallest Possible Shape | ||
| 601 | MatMulShape{1, 1, 1023}, // Long K | ||
| 602 | MatMulShape{1, 1023, 1}, // Long N | ||
| 603 | MatMulShape{1, 1013, 1023}, // Large Rhs | ||
| 604 | MatMulShape{1, 37, 23}, // | ||
| 605 | MatMulShape{1, 57, 89}, // | ||
| 606 | MatMulShape{1, 36, 89}, // | ||
| 607 | MatMulShape{1, 98, 23}, // | ||
| 608 | MatMulShape{1, 64, 1024}, // Nice shapes - Long Rhs Rect | ||
| 609 | MatMulShape{1, 1024, 64}, // Nice shapes - Wide Rhs Rect | ||
| 610 | MatMulShape{1, 256, 256}, // Nice shapes - Square | ||
| 611 | MatMulShape{1, 113, 373} // Prime numbers | ||
| 612 | ), | ||
| 613 | testing::Values( | ||
| 614 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
| 615 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
| 616 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
| 617 | MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle | ||
| 618 | ), | ||
| 619 | testing::Values(BiasMode::PROVIDED), // | ||
| 620 | testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio | ||
| 621 | testing::PrintToStringParamName()); | ||
| 622 | } // namespace kai::test | ||
| 623 |