test/tests/imatmul_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <gtest/gtest.h> | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <initializer_list> | ||
| 12 | #include <string_view> | ||
| 13 | #include <tuple> | ||
| 14 | #include <unordered_map> | ||
| 15 | |||
| 16 | #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" | ||
| 17 | #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h" | ||
| 18 | #include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h" | ||
| 19 | #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" | ||
| 20 | #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" | ||
| 21 | #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h" | ||
| 22 | #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" | ||
| 23 | #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" | ||
| 24 | #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" | ||
| 25 | #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" | ||
| 26 | #include "test/common/abi_checker.hpp" | ||
| 27 | #include "test/common/buffer.hpp" | ||
| 28 | #include "test/common/compare.hpp" | ||
| 29 | #include "test/common/cpu_info.hpp" | ||
| 30 | #include "test/common/matmul_test_common.hpp" | ||
| 31 | #include "test/common/matrix_portion.hpp" | ||
| 32 | #include "test/common/memory.hpp" | ||
| 33 | #include "test/common/round.hpp" | ||
| 34 | #include "test/common/seed.hpp" | ||
| 35 | #include "test/common/sme.hpp" | ||
| 36 | #include "test/reference/clamp.hpp" | ||
| 37 | #include "test/reference/fill.hpp" | ||
| 38 | #include "test/reference/matmul.hpp" | ||
| 39 | #include "test/reference/reorder.hpp" | ||
| 40 | |||
| 41 | namespace kai::test { | ||
| 42 | |||
| 43 | // Ensure static linkage for all functionality local to this test file | ||
| 44 | namespace { | ||
| 45 | |||
| 46 | /// Convenience wrapper for K-chunk handling | ||
| 47 | struct KChunk { | ||
| 48 | size_t count; | ||
| 49 | size_t length; | ||
| 50 | }; | ||
| 51 | |||
| 52 | /// Interface for indirect matmul LHS packing micro-kernel | ||
| 53 |
2/4✓ Branch 0 taken 140904 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 140904 times.
✗ Branch 3 not taken.
|
140904 | struct LhsPackIndirectKernel { |
| 54 | std::function<size_t()> get_m_step; | ||
| 55 | std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_offset; | ||
| 56 | std::function<size_t(size_t m, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_size; | ||
| 57 | std::function<void( | ||
| 58 | size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, | ||
| 59 | const void* zero, void* lhs_packed)> | ||
| 60 | pack; | ||
| 61 | }; | ||
| 62 | |||
| 63 | /// Interface for indirect matmul RHS packing micro-kernel | ||
| 64 |
4/8✓ Branch 0 taken 140904 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 140904 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 140904 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 140904 times.
✗ Branch 7 not taken.
|
140904 | struct RhsPackIndirectKernel { |
| 65 | std::function<size_t()> get_n_step; | ||
| 66 | std::function<size_t(size_t n_idx)> get_rhs_offset; | ||
| 67 | std::function<size_t(size_t n_idx)> get_bias_offset; | ||
| 68 | std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_offset; | ||
| 69 | std::function<size_t(size_t n, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_size; | ||
| 70 | std::function<void( | ||
| 71 | size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_row_stride, const void* rhs, const void* bias, | ||
| 72 | void* rhs_packed)> | ||
| 73 | pack; | ||
| 74 | }; | ||
| 75 | |||
| 76 | /// Interface for indirect matmul kernel | ||
| 77 |
8/16✓ Branch 0 taken 140904 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 140904 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 140904 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 140904 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 140904 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 140904 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 140904 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 140904 times.
✗ Branch 15 not taken.
|
140904 | struct MatMulIndirectKernel { |
| 78 | std::function<size_t(void)> get_m_step; | ||
| 79 | std::function<size_t(void)> get_n_step; | ||
| 80 | std::function<size_t(void)> get_mr; | ||
| 81 | std::function<size_t(void)> get_nr; | ||
| 82 | std::function<size_t(void)> get_kr; | ||
| 83 | std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_offset; | ||
| 84 | std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_offset; | ||
| 85 | std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride_row)> get_dst_offset; | ||
| 86 | std::function<size_t(size_t m, size_t n)> get_dst_size; | ||
| 87 | std::function<void( | ||
| 88 | size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, | ||
| 89 | void* dst, size_t dst_stride_row, float clamp_min, float clamp_max)> | ||
| 90 | imatmul; | ||
| 91 | }; | ||
| 92 | |||
| 93 | /// Description of an Indirect Matmul kernel set | ||
| 94 |
2/4✓ Branch 0 taken 140904 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 140904 times.
✗ Branch 3 not taken.
|
140904 | struct IndirectMatMul { |
| 95 | std::string_view name; | ||
| 96 | std::function<bool(void)> is_supported; | ||
| 97 | |||
| 98 | MatMulShape pack_shape; | ||
| 99 | struct Format { | ||
| 100 | DataFormat lhs; | ||
| 101 | DataFormat rhs; | ||
| 102 | DataFormat bias; | ||
| 103 | DataFormat out; | ||
| 104 | |||
| 105 | struct Hash { | ||
| 106 | 11152 | size_t operator()(const Format& format) const { | |
| 107 | 11152 | return // | |
| 108 | 22304 | (DataFormat::Hash{}(format.lhs) << 0) ^ // | |
| 109 | 22304 | (DataFormat::Hash{}(format.rhs) << 1) ^ // | |
| 110 | 22304 | (DataFormat::Hash{}(format.bias) << 2) ^ // | |
| 111 | 11152 | (DataFormat::Hash{}(format.out) << 3); | |
| 112 | } | ||
| 113 | }; | ||
| 114 | |||
| 115 | private: | ||
| 116 | 8500 | friend bool operator==(const Format& lhs, const Format& rhs) { | |
| 117 | 8500 | return // | |
| 118 |
1/2✓ Branch 0 taken 8500 times.
✗ Branch 1 not taken.
|
8500 | lhs.lhs == rhs.lhs && // |
| 119 |
1/2✓ Branch 0 taken 8500 times.
✗ Branch 1 not taken.
|
8500 | lhs.rhs == rhs.rhs && // |
| 120 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8500 times.
|
8500 | lhs.bias == rhs.bias && // |
| 121 | 8500 | lhs.out == rhs.out; | |
| 122 | } | ||
| 123 | } format; | ||
| 124 | |||
| 125 | LhsPackIndirectKernel lhs; | ||
| 126 | RhsPackIndirectKernel rhs; | ||
| 127 | MatMulIndirectKernel imatmul; | ||
| 128 | }; | ||
| 129 | |||
| 130 | /// Test parameter bundle type | ||
| 131 | using IndirectMatMulTestParams = std::tuple<IndirectMatMul, MatMulShape, size_t, MatrixPortion, float>; | ||
| 132 | |||
| 133 | /// Test type | ||
| 134 | using IndirectMatMulTest = testing::TestWithParam<IndirectMatMulTestParams>; | ||
| 135 | |||
| 136 | /// Use interface for matmul kernel | ||
| 137 | 18 | const kai_imatmul_clamp_f16_f16p_f16p_ukernel& get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa() { | |
| 138 | static kai_imatmul_clamp_f16_f16p_f16p_ukernel ukernel; | ||
| 139 | 18 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
| 140 | 18 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
| 141 | 18 | ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
| 142 | 18 | ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
| 143 | 18 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
| 144 | 18 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
| 145 | 18 | ukernel.run_imatmul = kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa; | |
| 146 | 18 | return ukernel; | |
| 147 | } | ||
| 148 | |||
| 149 | 18 | const kai_imatmul_clamp_f16_f16p_f16p_ukernel& get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa() { | |
| 150 | static kai_imatmul_clamp_f16_f16p_f16p_ukernel ukernel; | ||
| 151 | 18 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
| 152 | 18 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
| 153 | 18 | ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
| 154 | 18 | ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
| 155 | 18 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
| 156 | 18 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
| 157 | 18 | ukernel.run_imatmul = kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; | |
| 158 | 18 | return ukernel; | |
| 159 | } | ||
| 160 | |||
| 161 | /// Use interface for matmul kernel | ||
| 162 | 18 | const kai_imatmul_clamp_f32_f32p_f32p_ukernel& get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa() { | |
| 163 | static kai_imatmul_clamp_f32_f32p_f32p_ukernel ukernel; | ||
| 164 | 18 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
| 165 | 18 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
| 166 | 18 | ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
| 167 | 18 | ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
| 168 | 18 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
| 169 | 18 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
| 170 | 18 | ukernel.run_imatmul = kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa; | |
| 171 | 18 | return ukernel; | |
| 172 | } | ||
| 173 | |||
| 174 | 18 | const kai_imatmul_clamp_f32_f32p_f32p_ukernel& get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa() { | |
| 175 | static kai_imatmul_clamp_f32_f32p_f32p_ukernel ukernel; | ||
| 176 | 18 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
| 177 | 18 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
| 178 | 18 | ukernel.get_lhs_packed_offset = kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
| 179 | 18 | ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
| 180 | 18 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
| 181 | 18 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
| 182 | 18 | ukernel.run_imatmul = kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; | |
| 183 | 18 | return ukernel; | |
| 184 | } | ||
| 185 | |||
| 186 | /// Retrieve the test list | ||
| 187 | 18 | const auto& get_indirect_matmul_methods() { | |
| 188 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 15 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
18 | static std::array<IndirectMatMul, 4> indirect_matmul_methods{}; |
| 189 | |||
| 190 | // F16 IMATMUL SME2 /////////////////////////////////////////////////////// | ||
| 191 | 18 | indirect_matmul_methods[0].name = "imatmul_f16_f16p_f16p_2vlx2vl_sme2_mopa"; | |
| 192 | 18 | indirect_matmul_methods[0].is_supported = cpu_has_sme2; | |
| 193 | 18 | indirect_matmul_methods[0].pack_shape.m = 2 * get_sme_vector_length<int32_t>(); | |
| 194 | 18 | indirect_matmul_methods[0].pack_shape.n = 2 * get_sme_vector_length<int32_t>(); | |
| 195 | 18 | indirect_matmul_methods[0].pack_shape.k = sizeof(int32_t); | |
| 196 | 18 | indirect_matmul_methods[0].format.lhs = DataFormat(DataType::FP16); | |
| 197 | 18 | indirect_matmul_methods[0].format.rhs = DataFormat(DataType::FP16); | |
| 198 | 18 | indirect_matmul_methods[0].format.bias = DataFormat(DataType::FP16); | |
| 199 | 18 | indirect_matmul_methods[0].format.out = DataFormat(DataType::FP16); | |
| 200 | |||
| 201 | // LHS | ||
| 202 | 18 | indirect_matmul_methods[0].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
| 203 | 18 | indirect_matmul_methods[0].lhs.get_lhs_packed_offset = | |
| 204 | kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | ||
| 205 | 18 | indirect_matmul_methods[0].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
| 206 | 18 | indirect_matmul_methods[0].lhs.pack = kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
| 207 | |||
| 208 | // RHS | ||
| 209 | 18 | indirect_matmul_methods[0].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
| 210 | 18 | indirect_matmul_methods[0].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
| 211 | 18 | indirect_matmul_methods[0].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
| 212 | 18 | indirect_matmul_methods[0].rhs.get_rhs_packed_offset = | |
| 213 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 214 | 18 | indirect_matmul_methods[0].rhs.get_rhs_packed_size = | |
| 215 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 216 | 18 | indirect_matmul_methods[0].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
| 217 | |||
| 218 | // IMATMUL | ||
| 219 | 36 | const kai_imatmul_clamp_f16_f16p_f16p_ukernel& ukernel_f16_sme2 = | |
| 220 | 18 | get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(); | |
| 221 | 18 | indirect_matmul_methods[0].imatmul.get_m_step = ukernel_f16_sme2.get_m_step; | |
| 222 | 18 | indirect_matmul_methods[0].imatmul.get_n_step = ukernel_f16_sme2.get_n_step; | |
| 223 | 18 | indirect_matmul_methods[0].imatmul.get_lhs_packed_offset = ukernel_f16_sme2.get_lhs_packed_offset; | |
| 224 | 18 | indirect_matmul_methods[0].imatmul.get_rhs_packed_offset = ukernel_f16_sme2.get_rhs_packed_offset; | |
| 225 | 18 | indirect_matmul_methods[0].imatmul.get_dst_offset = ukernel_f16_sme2.get_dst_offset; | |
| 226 | 18 | indirect_matmul_methods[0].imatmul.get_dst_size = ukernel_f16_sme2.get_dst_size; | |
| 227 | 18 | indirect_matmul_methods[0].imatmul.imatmul = ukernel_f16_sme2.run_imatmul; | |
| 228 | |||
| 229 | // F32 IMATMUL SME2 /////////////////////////////////////////////////////// | ||
| 230 | 18 | indirect_matmul_methods[1].name = "imatmul_f32_f32p_f32p_2vlx2vl_sme2_mopa"; | |
| 231 | 18 | indirect_matmul_methods[1].is_supported = cpu_has_sme2; | |
| 232 | 18 | indirect_matmul_methods[1].pack_shape.m = 2 * get_sme_vector_length<int32_t>(); | |
| 233 | 18 | indirect_matmul_methods[1].pack_shape.n = 2 * get_sme_vector_length<int32_t>(); | |
| 234 | 18 | indirect_matmul_methods[1].pack_shape.k = sizeof(int32_t); | |
| 235 | 18 | indirect_matmul_methods[1].format.lhs = DataFormat(DataType::FP32); | |
| 236 | 18 | indirect_matmul_methods[1].format.rhs = DataFormat(DataType::FP32); | |
| 237 | 18 | indirect_matmul_methods[1].format.bias = DataFormat(DataType::FP32); | |
| 238 | 18 | indirect_matmul_methods[1].format.out = DataFormat(DataType::FP32); | |
| 239 | |||
| 240 | // LHS | ||
| 241 | 18 | indirect_matmul_methods[1].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
| 242 | 18 | indirect_matmul_methods[1].lhs.get_lhs_packed_offset = | |
| 243 | kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | ||
| 244 | 18 | indirect_matmul_methods[1].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
| 245 | 18 | indirect_matmul_methods[1].lhs.pack = kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
| 246 | |||
| 247 | // RHS | ||
| 248 | 18 | indirect_matmul_methods[1].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
| 249 | 18 | indirect_matmul_methods[1].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
| 250 | 18 | indirect_matmul_methods[1].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
| 251 | 18 | indirect_matmul_methods[1].rhs.get_rhs_packed_offset = | |
| 252 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | ||
| 253 | 18 | indirect_matmul_methods[1].rhs.get_rhs_packed_size = | |
| 254 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | ||
| 255 | 18 | indirect_matmul_methods[1].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
| 256 | |||
| 257 | // IMATMUL | ||
| 258 | 36 | const kai_imatmul_clamp_f32_f32p_f32p_ukernel& ukernel_f32_sme2 = | |
| 259 | 18 | get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); | |
| 260 | 18 | indirect_matmul_methods[1].imatmul.get_m_step = ukernel_f32_sme2.get_m_step; | |
| 261 | 18 | indirect_matmul_methods[1].imatmul.get_n_step = ukernel_f32_sme2.get_n_step; | |
| 262 | 18 | indirect_matmul_methods[1].imatmul.get_lhs_packed_offset = ukernel_f32_sme2.get_lhs_packed_offset; | |
| 263 | 18 | indirect_matmul_methods[1].imatmul.get_rhs_packed_offset = ukernel_f32_sme2.get_rhs_packed_offset; | |
| 264 | 18 | indirect_matmul_methods[1].imatmul.get_dst_offset = ukernel_f32_sme2.get_dst_offset; | |
| 265 | 18 | indirect_matmul_methods[1].imatmul.get_dst_size = ukernel_f32_sme2.get_dst_size; | |
| 266 | 18 | indirect_matmul_methods[1].imatmul.imatmul = ukernel_f32_sme2.run_imatmul; | |
| 267 | |||
| 268 | // F16 IMATMUL SME //////////////////////////////////////////////////////// | ||
| 269 | 18 | indirect_matmul_methods[2].name = "imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa"; | |
| 270 | 18 | indirect_matmul_methods[2].is_supported = cpu_has_sme; | |
| 271 | 18 | indirect_matmul_methods[2].pack_shape.m = 2 * get_sme_vector_length<int32_t>(); | |
| 272 | 18 | indirect_matmul_methods[2].pack_shape.n = 2 * get_sme_vector_length<int32_t>(); | |
| 273 | 18 | indirect_matmul_methods[2].pack_shape.k = sizeof(int32_t); | |
| 274 | 18 | indirect_matmul_methods[2].format.lhs = DataFormat(DataType::FP16); | |
| 275 | 18 | indirect_matmul_methods[2].format.rhs = DataFormat(DataType::FP16); | |
| 276 | 18 | indirect_matmul_methods[2].format.bias = DataFormat(DataType::FP16); | |
| 277 | 18 | indirect_matmul_methods[2].format.out = DataFormat(DataType::FP16); | |
| 278 | |||
| 279 | // LHS | ||
| 280 | 18 | indirect_matmul_methods[2].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
| 281 | 18 | indirect_matmul_methods[2].lhs.get_lhs_packed_offset = | |
| 282 | kai_get_lhs_packed_offset_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | ||
| 283 | 18 | indirect_matmul_methods[2].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
| 284 | 18 | indirect_matmul_methods[2].lhs.pack = kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme; | |
| 285 | |||
| 286 | // RHS | ||
| 287 | 18 | indirect_matmul_methods[2].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
| 288 | 18 | indirect_matmul_methods[2].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
| 289 | 18 | indirect_matmul_methods[2].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
| 290 | 18 | indirect_matmul_methods[2].rhs.get_rhs_packed_offset = | |
| 291 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 292 | 18 | indirect_matmul_methods[2].rhs.get_rhs_packed_size = | |
| 293 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | ||
| 294 | 18 | indirect_matmul_methods[2].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme; | |
| 295 | |||
| 296 | // IMATMUL | ||
| 297 | 36 | const kai_imatmul_clamp_f16_f16p_f16p_ukernel& ukernel_f16_sme = | |
| 298 | 18 | get_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(); | |
| 299 | 18 | indirect_matmul_methods[2].imatmul.get_m_step = ukernel_f16_sme.get_m_step; | |
| 300 | 18 | indirect_matmul_methods[2].imatmul.get_n_step = ukernel_f16_sme.get_n_step; | |
| 301 | 18 | indirect_matmul_methods[2].imatmul.get_lhs_packed_offset = ukernel_f16_sme.get_lhs_packed_offset; | |
| 302 | 18 | indirect_matmul_methods[2].imatmul.get_rhs_packed_offset = ukernel_f16_sme.get_rhs_packed_offset; | |
| 303 | 18 | indirect_matmul_methods[2].imatmul.get_dst_offset = ukernel_f16_sme.get_dst_offset; | |
| 304 | 18 | indirect_matmul_methods[2].imatmul.get_dst_size = ukernel_f16_sme.get_dst_size; | |
| 305 | 18 | indirect_matmul_methods[2].imatmul.imatmul = ukernel_f16_sme.run_imatmul; | |
| 306 | |||
| 307 | // F32 IMATMUL SME //////////////////////////////////////////////////////// | ||
| 308 | 18 | indirect_matmul_methods[3].name = "imatmul_f32_f32p_f32p_2vlx2vl_sme_mopa"; | |
| 309 | 18 | indirect_matmul_methods[3].is_supported = cpu_has_sme; | |
| 310 | 18 | indirect_matmul_methods[3].pack_shape.m = 2 * get_sme_vector_length<int32_t>(); | |
| 311 | 18 | indirect_matmul_methods[3].pack_shape.n = 2 * get_sme_vector_length<int32_t>(); | |
| 312 | 18 | indirect_matmul_methods[3].pack_shape.k = sizeof(int32_t); | |
| 313 | 18 | indirect_matmul_methods[3].format.lhs = DataFormat(DataType::FP32); | |
| 314 | 18 | indirect_matmul_methods[3].format.rhs = DataFormat(DataType::FP32); | |
| 315 | 18 | indirect_matmul_methods[3].format.bias = DataFormat(DataType::FP32); | |
| 316 | 18 | indirect_matmul_methods[3].format.out = DataFormat(DataType::FP32); | |
| 317 | |||
| 318 | // LHS | ||
| 319 | 18 | indirect_matmul_methods[3].lhs.get_m_step = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
| 320 | 18 | indirect_matmul_methods[3].lhs.get_lhs_packed_offset = | |
| 321 | kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | ||
| 322 | 18 | indirect_matmul_methods[3].lhs.get_lhs_packed_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
| 323 | 18 | indirect_matmul_methods[3].lhs.pack = kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme; | |
| 324 | |||
| 325 | // RHS | ||
| 326 | 18 | indirect_matmul_methods[3].rhs.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
| 327 | 18 | indirect_matmul_methods[3].rhs.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
| 328 | 18 | indirect_matmul_methods[3].rhs.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
| 329 | 18 | indirect_matmul_methods[3].rhs.get_rhs_packed_offset = | |
| 330 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | ||
| 331 | 18 | indirect_matmul_methods[3].rhs.get_rhs_packed_size = | |
| 332 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | ||
| 333 | 18 | indirect_matmul_methods[3].rhs.pack = kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme; | |
| 334 | |||
| 335 | // IMATMUL | ||
| 336 | 36 | const kai_imatmul_clamp_f32_f32p_f32p_ukernel& ukernel_f32_sme = | |
| 337 | 18 | get_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); | |
| 338 | 18 | indirect_matmul_methods[3].imatmul.get_m_step = ukernel_f32_sme.get_m_step; | |
| 339 | 18 | indirect_matmul_methods[3].imatmul.get_n_step = ukernel_f32_sme.get_n_step; | |
| 340 | 18 | indirect_matmul_methods[3].imatmul.get_lhs_packed_offset = ukernel_f32_sme.get_lhs_packed_offset; | |
| 341 | 18 | indirect_matmul_methods[3].imatmul.get_rhs_packed_offset = ukernel_f32_sme.get_rhs_packed_offset; | |
| 342 | 18 | indirect_matmul_methods[3].imatmul.get_dst_offset = ukernel_f32_sme.get_dst_offset; | |
| 343 | 18 | indirect_matmul_methods[3].imatmul.get_dst_size = ukernel_f32_sme.get_dst_size; | |
| 344 | 18 | indirect_matmul_methods[3].imatmul.imatmul = ukernel_f32_sme.run_imatmul; | |
| 345 | |||
| 346 | 18 | return indirect_matmul_methods; | |
| 347 | 18 | } | |
| 348 | |||
| 349 | /// Test reference identification | ||
| 350 | struct TestDataId { | ||
| 351 | MatMulShape shape; | ||
| 352 | MatMulShape pack_shape; | ||
| 353 | IndirectMatMul::Format format; | ||
| 354 | size_t k_chunk_length; | ||
| 355 | float clamp_keep_ratio; | ||
| 356 | |||
| 357 | struct Hash { | ||
| 358 | 11152 | size_t operator()(const TestDataId& test_id) const { | |
| 359 | 11152 | return // | |
| 360 | 22304 | (MatMulShape::Hash{}(test_id.shape) << 0) ^ // | |
| 361 | 22304 | (MatMulShape::Hash{}(test_id.pack_shape) << 1) ^ // | |
| 362 | 22304 | (IndirectMatMul::Format::Hash{}(test_id.format) << 2) ^ // | |
| 363 | 22304 | (std::hash<size_t>{}(test_id.k_chunk_length) << 3) ^ // | |
| 364 | 11152 | (std::hash<float>{}(test_id.clamp_keep_ratio) << 4); // | |
| 365 | } | ||
| 366 | }; | ||
| 367 | |||
| 368 | private: | ||
| 369 | 11166 | friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { | |
| 370 | 11166 | return // | |
| 371 |
2/2✓ Branch 0 taken 8500 times.
✓ Branch 1 taken 2666 times.
|
11166 | lhs.shape == rhs.shape && // |
| 372 |
1/2✓ Branch 0 taken 8500 times.
✗ Branch 1 not taken.
|
8500 | lhs.pack_shape == rhs.pack_shape && // |
| 373 |
1/2✓ Branch 0 taken 8500 times.
✗ Branch 1 not taken.
|
8500 | lhs.format == rhs.format && // |
| 374 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8500 times.
|
8500 | lhs.k_chunk_length == rhs.k_chunk_length && // |
| 375 | 8500 | lhs.clamp_keep_ratio == rhs.clamp_keep_ratio; | |
| 376 | } | ||
| 377 | }; | ||
| 378 | |||
| 379 | /// Test reference data | ||
| 380 | struct TestData { | ||
| 381 | Buffer lhs; ///< LHS input matrix | ||
| 382 | Buffer rhs; ///< RHS input matrix | ||
| 383 | Buffer bias; ///< Bias vector | ||
| 384 | Buffer out; ///< Reference imatmul result | ||
| 385 | Buffer indirection; ///< LHS indirection buffer | ||
| 386 | uintptr_t indirection_offset; ///< LHS indirection buffer offset | ||
| 387 | Buffer padding; ///< Padding buffer | ||
| 388 | Range<float> clamp_range; ///< Clamp range | ||
| 389 | }; | ||
| 390 | |||
| 391 | /// Reference data generator | ||
| 392 | /// | ||
| 393 | /// Uses test id to generate reference data, and caches it. | ||
| 394 | struct ReferenceGenerator { | ||
| 395 | /// Retrieve reference data for the provided test identification | ||
| 396 | 9384 | static const TestData& get_test_reference(const TestDataId& test_id) { | |
| 397 |
3/4✓ Branch 0 taken 1 time.
✓ Branch 1 taken 9383 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
9384 | static std::unordered_map<TestDataId, TestData, TestDataId::Hash> m_data; |
| 398 |
4/5✓ Branch 0 taken 8500 times.
✓ Branch 1 taken 884 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 8500 times.
✓ Branch 4 taken 884 times.
|
17884 | if (const auto itr = m_data.find(test_id); itr != end(m_data)) { |
| 399 | 8500 | return itr->second; | |
| 400 | } | ||
| 401 | |||
| 402 |
1/2✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
|
884 | return m_data[test_id] = generate_reference(test_id); |
| 403 | 9384 | } | |
| 404 | |||
| 405 | private: | ||
| 406 | /// Generate reference data. Not intended to be called | ||
| 407 | /// directly, as this would bypass the caching mechanism. | ||
| 408 | 884 | static TestData generate_reference(const TestDataId& test_id) { | |
| 409 | 424646 | const auto& [chunked_shape, pack_shape, format, k_chunk_length, clamp_keep_ratio] = test_id; | |
| 410 | |||
| 411 | // The LHS matrix will be split into several chunks in the K dimension | ||
| 412 | 1768 | const size_t k_chunk_count = chunked_shape.k; | |
| 413 | 3536 | MatMulShape shape = {chunked_shape.m, chunked_shape.n, k_chunk_count * k_chunk_length}; | |
| 414 | |||
| 415 | // Stable key derived from the cache identifier. | ||
| 416 | 884 | const auto key_hash = static_cast<std::uint32_t>(TestDataId::Hash{}(test_id)); | |
| 417 |
2/4✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
|
884 | const auto key = std::string("imatmul_cache:") + std::to_string(key_hash); |
| 418 |
1/2✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
|
884 | auto& feed = seed_stream(key); |
| 419 | |||
| 420 | // Generate random input data | ||
| 421 |
3/6✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 884 times.
✗ Branch 5 not taken.
|
1768 | Buffer lhs = fill_matrix_random(shape.m, shape.k, format.lhs, feed()); |
| 422 |
3/6✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 884 times.
✗ Branch 5 not taken.
|
1768 | Buffer rhs = fill_matrix_random(shape.k, shape.n, format.rhs, feed()); |
| 423 |
3/6✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 884 times.
✗ Branch 5 not taken.
|
1768 | Buffer bias = fill_matrix_random(1, shape.n, format.bias, feed()); |
| 424 | |||
| 425 | // Data types used | ||
| 426 |
2/4✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
|
1768 | const DataType lhs_dt = format.lhs.data_type(); |
| 427 |
2/4✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
|
1768 | const DataType rhs_dt = format.rhs.data_type(); |
| 428 |
2/4✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
|
1768 | const DataType out_dt = format.out.data_type(); |
| 429 |
2/4✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
|
1768 | const DataType bias_dt = format.bias.data_type(); |
| 430 | |||
| 431 | // Create a padding chunk | ||
| 432 |
3/6✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 884 times.
✗ Branch 5 not taken.
|
1768 | const size_t k_chunk_size = round_up_division(k_chunk_length * data_type_size_in_bits(lhs_dt), 8); |
| 433 | 884 | const size_t row_size = k_chunk_count * k_chunk_size; | |
| 434 |
1/2✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
|
884 | Buffer lhs_padding(k_chunk_size); |
| 435 |
4/4✓ Branch 0 taken 16524 times.
✓ Branch 1 taken 884 times.
✓ Branch 2 taken 16524 times.
✓ Branch 3 taken 884 times.
|
17408 | for (size_t i = 0; i < k_chunk_length; i += 1) { |
| 436 | static constexpr double padding_value = 0; | ||
| 437 |
2/4✓ Branch 0 taken 16524 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 16524 times.
✗ Branch 3 not taken.
|
16524 | write_array(lhs_dt, lhs_padding.data(), i, padding_value); |
| 438 | 16524 | } | |
| 439 | |||
| 440 | // Set up indirection buffer | ||
| 441 |
1/2✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
|
884 | const uintptr_t indirection_offset = reinterpret_cast<uintptr_t>(lhs.data()); |
| 442 |
3/6✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 884 times.
✗ Branch 5 not taken.
|
2652 | std::vector<const void*> indirection(chunked_shape.m * chunked_shape.k); |
| 443 |
4/4✓ Branch 0 taken 28158 times.
✓ Branch 1 taken 884 times.
✓ Branch 2 taken 28158 times.
✓ Branch 3 taken 884 times.
|
29042 | for (size_t i_m = 0; i_m < chunked_shape.m; i_m += 1) { |
| 444 |
4/4✓ Branch 0 taken 28158 times.
✓ Branch 1 taken 333242 times.
✓ Branch 2 taken 28158 times.
✓ Branch 3 taken 333242 times.
|
361400 | for (size_t i_k = 0; i_k < chunked_shape.k; i_k += 1) { |
| 445 | 666484 | const size_t idx = i_m * chunked_shape.k + i_k; | |
| 446 | // Test padding pointers using first LHS row for shapes where M > 1 | ||
| 447 |
4/4✓ Branch 0 taken 330616 times.
✓ Branch 1 taken 2626 times.
✓ Branch 2 taken 321100 times.
✓ Branch 3 taken 9516 times.
|
333242 | if (chunked_shape.m > 1 && i_m == 0) { |
| 448 |
2/4✓ Branch 0 taken 9516 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9516 times.
✗ Branch 3 not taken.
|
9516 | indirection.at(idx) = lhs_padding.data(); |
| 449 | 9516 | } else { | |
| 450 | 323726 | uintptr_t offset = i_m * row_size + i_k * k_chunk_size; | |
| 451 |
1/2✓ Branch 0 taken 323726 times.
✗ Branch 1 not taken.
|
323726 | indirection.at(idx) = reinterpret_cast<const void*>(offset); |
| 452 | 323726 | } | |
| 453 | 333242 | } | |
| 454 | 28158 | } | |
| 455 | |||
| 456 | // Pack indirection buffer | ||
| 457 |
2/4✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
|
1768 | Buffer indirection_packed = reorder_block<const void*>( |
| 458 | 3536 | reinterpret_cast<const void* const*>(indirection.data()), chunked_shape.m, chunked_shape.k, pack_shape.m, | |
| 459 | 1); | ||
| 460 | |||
| 461 |
1/2✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
|
1768 | Buffer out = indirect_matmul( // |
| 462 |
1/2✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
|
884 | indirection.data(), indirection_offset, lhs_padding.data(), nullptr, nullptr, lhs_dt, // LHS |
| 463 |
1/2✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
|
884 | rhs.data(), nullptr, nullptr, rhs_dt, // RHS |
| 464 |
1/2✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
|
884 | bias.data(), nullptr, nullptr, bias_dt, // Bias |
| 465 | 884 | out_dt, // Out | |
| 466 | 3536 | chunked_shape.m, chunked_shape.n, chunked_shape.k, k_chunk_length); | |
| 467 | |||
| 468 | // Calculate clamping range based on full range of values, and then clamp values | ||
| 469 |
3/6✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 884 times.
✗ Branch 5 not taken.
|
884 | const auto [min, max] = find_clamp_range(out_dt, out.data(), shape.m * shape.n, clamp_keep_ratio); |
| 470 |
4/8✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 884 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 884 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 884 times.
✗ Branch 7 not taken.
|
884 | Buffer out_clamped = clamp(out_dt, out.data(), shape.m * shape.n, min, max); |
| 471 | |||
| 472 | // Populate reference data | ||
| 473 | 884 | TestData test_reference; | |
| 474 | 884 | test_reference.lhs = std::move(lhs); | |
| 475 | 884 | test_reference.rhs = std::move(rhs); | |
| 476 | 884 | test_reference.bias = std::move(bias); | |
| 477 | 884 | test_reference.padding = std::move(lhs_padding); | |
| 478 | 884 | test_reference.out = std::move(out_clamped); | |
| 479 | 884 | test_reference.indirection_offset = indirection_offset; | |
| 480 | 884 | test_reference.indirection = std::move(indirection_packed); | |
| 481 | 2652 | test_reference.clamp_range = {min, max}; | |
| 482 | |||
| 483 | 884 | return test_reference; | |
| 484 |
1/4✗ Branch 0 not taken.
✓ Branch 0 taken 884 times.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
|
884 | }; |
| 485 | }; | ||
| 486 | |||
| 487 | /// Perform LHS packing for indirect matmul | ||
| 488 | 9384 | Buffer pack_lhs( | |
| 489 | const LhsPackIndirectKernel& kernel, const Rect& portion, const TestData& reference, size_t m, | ||
| 490 | const KChunk& k_chunk) { | ||
| 491 | 9384 | const void* const* indirection_pointer = reinterpret_cast<const void* const*>(reference.indirection.data()); | |
| 492 | |||
| 493 | // Calculate size, and allocate buffer | ||
| 494 | 9384 | const size_t dst_size = kernel.get_lhs_packed_size(m, k_chunk.count, k_chunk.length); | |
| 495 | 9384 | Buffer dst(dst_size); | |
| 496 | |||
| 497 | // Calculate portion offsets | ||
| 498 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | const size_t input_offset = portion.start_row() * k_chunk.count; |
| 499 |
2/4✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
|
9384 | const size_t dst_offset = kernel.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); |
| 500 | |||
| 501 | // Perform packing | ||
| 502 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | abi_check( |
| 503 | 9384 | kernel.pack, // Kernel | |
| 504 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | portion.height(), k_chunk.count, k_chunk.length, // Dimensions |
| 505 | 9384 | indirection_pointer + input_offset, // Indirection input | |
| 506 | 9384 | reference.indirection_offset, // Chunk offset | |
| 507 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | reference.padding.data(), // Padding pointer |
| 508 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | dst.data() + dst_offset); |
| 509 | 9384 | return dst; | |
| 510 | 9384 | } | |
| 511 | |||
| 512 | /// Perform RHS packing for indirect matmul | ||
| 513 | 9384 | Buffer pack_rhs( | |
| 514 | const RhsPackIndirectKernel& kernel, const Rect& portion, const TestData& reference, size_t n, | ||
| 515 | const KChunk& k_chunk, DataType type) { | ||
| 516 | // Calculate size, and allocate buffer | ||
| 517 | 9384 | const size_t row_stride = round_up_division(n * data_type_size_in_bits(type), 8); | |
| 518 | 9384 | const size_t dst_size = kernel.get_rhs_packed_size(n, k_chunk.count, k_chunk.length); | |
| 519 | 9384 | Buffer dst(dst_size); | |
| 520 | |||
| 521 | // Calculate offsets | ||
| 522 |
2/4✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
|
9384 | const size_t rhs_offset = kernel.get_rhs_offset(portion.start_col()); |
| 523 |
2/4✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
|
9384 | const size_t bias_offset = kernel.get_bias_offset(portion.start_col()); |
| 524 |
2/4✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
|
9384 | const size_t dst_offset = kernel.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); |
| 525 | |||
| 526 | // Perform actual packing | ||
| 527 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | abi_check( |
| 528 | 9384 | kernel.pack, // Kernel | |
| 529 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | portion.width(), k_chunk.count, k_chunk.length, row_stride, // Dimensions |
| 530 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | reference.rhs.data() + rhs_offset, // RHS input |
| 531 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | reference.bias.data() + bias_offset, // Bias |
| 532 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | dst.data() + dst_offset); // Output |
| 533 | 9384 | return dst; | |
| 534 | 9384 | } | |
| 535 | |||
| 536 | /// Perform imatmul | ||
| 537 | /// | ||
| 538 | /// Note: this should not be aware of the reference result, so as to make it clear that | ||
| 539 | /// any produced result is strictly from the code under test | ||
| 540 | 9384 | Buffer imatmul( | |
| 541 | const MatMulIndirectKernel& kernel, const Rect& portion, const MatMulShape& shape, const KChunk& k_chunk, | ||
| 542 | const Buffer& lhs_packed, const Buffer& rhs_packed, Range<float> clamp_range, DataType type) { | ||
| 543 | // Calculate size, and allocate buffer | ||
| 544 | 9384 | const size_t dst_size = kernel.get_dst_size(shape.m, shape.n); | |
| 545 | 9384 | const size_t row_stride = round_up_division(shape.n * data_type_size_in_bits(type), 8); | |
| 546 | 9384 | Buffer dst(dst_size); | |
| 547 | |||
| 548 | // Calculate portion offsets | ||
| 549 |
2/4✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
|
9384 | const size_t lhs_offset = kernel.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); |
| 550 |
2/4✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
|
9384 | const size_t rhs_offset = kernel.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); |
| 551 |
3/6✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9384 times.
✗ Branch 5 not taken.
|
9384 | const size_t dst_offset = kernel.get_dst_offset(portion.start_row(), portion.start_col(), row_stride); |
| 552 | |||
| 553 | // Call matmul kernel | ||
| 554 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | abi_check( |
| 555 | 9384 | kernel.imatmul, // Kernel | |
| 556 |
2/4✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
|
9384 | portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions |
| 557 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | lhs_packed.data() + lhs_offset, // LHS |
| 558 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | rhs_packed.data() + rhs_offset, // RHS |
| 559 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | dst.data() + dst_offset, // DST |
| 560 | 9384 | row_stride, clamp_range.min, clamp_range.max); | |
| 561 | |||
| 562 | 9384 | return dst; | |
| 563 | 9384 | } | |
| 564 | |||
| 565 | } // namespace | ||
| 566 | |||
| 567 | /// End-to-end test for indirection matmul kernels | ||
| 568 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
46926 | TEST_P(IndirectMatMulTest, Output) { |
| 569 | 103224 | const auto& [method, shape, k_chunk_length, output_portion, clamp_keep_ratio] = GetParam(); | |
| 570 |
2/2✓ Branch 0 taken 9384 times.
✓ Branch 1 taken 9384 times.
|
18768 | if (not method.is_supported()) { |
| 571 |
3/6✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9384 times.
✗ Branch 5 not taken.
|
9384 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 572 | } | ||
| 573 | |||
| 574 | 28152 | const KChunk k_chunk{shape.k, k_chunk_length}; | |
| 575 | |||
| 576 | // Retrieve reference data | ||
| 577 | 56304 | const TestDataId test_id{shape, method.pack_shape, method.format, k_chunk_length, clamp_keep_ratio}; | |
| 578 | 9384 | const TestData& test_data = ReferenceGenerator::get_test_reference(test_id); | |
| 579 | 46920 | const Rect portion = output_portion.compute_portion(shape.m, shape.n, method.pack_shape.m, method.pack_shape.n); | |
| 580 | |||
| 581 |
2/4✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 9384 times.
|
9384 | if (portion.height() == 0 || portion.width() == 0) { |
| 582 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << portion.width() << "," << portion.height() << ")"; | |
| 583 | } | ||
| 584 | |||
| 585 | // Call packing micro-kernels, and then imatmul kernel | ||
| 586 | 28152 | Buffer lhs_packed = pack_lhs(method.lhs, portion, test_data, shape.m, k_chunk); | |
| 587 |
5/10✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9384 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 9384 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 9384 times.
✗ Branch 9 not taken.
|
37536 | Buffer rhs_packed = pack_rhs(method.rhs, portion, test_data, shape.n, k_chunk, method.format.rhs.data_type()); |
| 588 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
18768 | Buffer out = imatmul( |
| 589 | 18768 | method.imatmul, portion, shape, k_chunk, lhs_packed, rhs_packed, test_data.clamp_range, | |
| 590 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | method.format.out.data_type()); |
| 591 | |||
| 592 | // Compare the actual result with the reference result | ||
| 593 |
1/2✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
|
9384 | DefaultMismatchHandler handler(0, 0.1, 0, 0.05); |
| 594 | 18768 | const auto success = | |
| 595 |
7/14✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9384 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 9384 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 9384 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9384 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 9384 times.
✗ Branch 13 not taken.
|
9384 | compare(out.data(), test_data.out.data(), method.format.out.data_type(), shape.m, shape.n, portion, handler); |
| 596 |
4/16✓ Branch 0 taken 9384 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9384 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9384 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 9384 times.
|
9384 | ASSERT_TRUE(success); |
| 597 | 18768 | } | |
| 598 | |||
| 599 | /// Name generator for test case | ||
| 600 | 56304 | [[maybe_unused]] static void PrintTo(const IndirectMatMulTestParams& param, std::ostream* os) { | |
| 601 | 337824 | const auto& [method, shape, k_chunk_length, portion, clamp_keep_ratio] = param; | |
| 602 | 112608 | *os << method.name << "__"; | |
| 603 | 56304 | PrintTo(shape, os); | |
| 604 | 112608 | *os << "__K_chunk_length_" << k_chunk_length; | |
| 605 | 112608 | *os << "__clamp_keep_ratio_" << static_cast<int>(clamp_keep_ratio * 100) << "__"; | |
| 606 | 56304 | PrintTo(portion, os); | |
| 607 | 56304 | } | |
| 608 | |||
| 609 | 18 | static auto get_indirect_matmul_shapes() { | |
| 610 | static const std::array indirect_matmul_shapes{ | ||
| 611 | // clang-format off | ||
| 612 | MatMulShape{ 1, 1, 1}, | ||
| 613 | MatMulShape{ 1, 17, 4}, | ||
| 614 | MatMulShape{ 1, 19, 24}, | ||
| 615 | MatMulShape{ 1, 32, 4}, | ||
| 616 | MatMulShape{ 1, 32, 32}, | ||
| 617 | MatMulShape{ 1, 33, 7}, | ||
| 618 | MatMulShape{ 1, 49, 21}, | ||
| 619 | MatMulShape{ 1, 64, 4}, | ||
| 620 | MatMulShape{ 1, 65, 4}, | ||
| 621 | MatMulShape{ 3, 6, 6}, | ||
| 622 | MatMulShape{ 3, 28, 25}, | ||
| 623 | MatMulShape{ 4, 16, 4}, | ||
| 624 | MatMulShape{ 4, 16, 27}, | ||
| 625 | MatMulShape{ 6, 18, 31}, | ||
| 626 | MatMulShape{ 6, 28, 1}, | ||
| 627 | MatMulShape{ 6, 29, 24}, | ||
| 628 | MatMulShape{ 8, 16, 16}, | ||
| 629 | MatMulShape{ 16, 16, 4}, | ||
| 630 | MatMulShape{ 16, 16, 16}, | ||
| 631 | MatMulShape{ 20, 30, 40}, | ||
| 632 | MatMulShape{ 23, 1, 43}, | ||
| 633 | MatMulShape{ 32, 14, 1}, | ||
| 634 | MatMulShape{ 32, 16, 27}, | ||
| 635 | MatMulShape{ 32, 32, 3}, | ||
| 636 | MatMulShape{ 32, 32, 4}, | ||
| 637 | MatMulShape{ 33, 29, 24}, | ||
| 638 | MatMulShape{ 64, 64, 3}, | ||
| 639 | MatMulShape{ 64, 64, 4}, | ||
| 640 | MatMulShape{ 96, 96, 3}, | ||
| 641 | MatMulShape{ 96, 97, 3}, | ||
| 642 | MatMulShape{ 97, 96, 3}, | ||
| 643 | MatMulShape{123, 85, 45}, | ||
| 644 | MatMulShape{128, 128, 3}, | ||
| 645 | MatMulShape{130, 130, 6}, | ||
| 646 | // clang-format on | ||
| 647 | }; | ||
| 648 | |||
| 649 | 18 | return indirect_matmul_shapes; | |
| 650 | } | ||
| 651 | |||
| 652 | 15 | static auto get_indirect_matmul_portions() { | |
| 653 | static const std::array<MatrixPortion, 6> indirect_matmul_portions{ | ||
| 654 | // (Start row , start col , height , width) | ||
| 655 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
| 656 | MatrixPortion(0, 0, 1, 0.5), // Left half | ||
| 657 | MatrixPortion(0, 0, 0.5, 1), // Upper half | ||
| 658 | MatrixPortion(0, 0.5, 1, 0.5), // Right half | ||
| 659 | MatrixPortion(0.5, 0, 0.5, 1), // Bottom half | ||
| 660 | MatrixPortion(0.4, 0.4, 0.3, 0.3), // Center ninth | ||
| 661 | }; | ||
| 662 | |||
| 663 | 15 | return indirect_matmul_portions; | |
| 664 | } | ||
| 665 | |||
| 666 | // Test suite focused on small K chunk | ||
| 667 |
24/72✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 5712 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 11424 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
|
17142 | INSTANTIATE_TEST_SUITE_P( |
| 668 | ShapesSmallKC, IndirectMatMulTest, | ||
| 669 | testing::Combine( | ||
| 670 | testing::ValuesIn(get_indirect_matmul_methods()), // | ||
| 671 | testing::ValuesIn(get_indirect_matmul_shapes()), // | ||
| 672 | testing::ValuesIn(std::initializer_list<size_t>{1, 2, 3, 4, 8, 11, 16}), // | ||
| 673 | testing::ValuesIn(get_indirect_matmul_portions()), // | ||
| 674 | testing::Values(0.5F)), // | ||
| 675 | testing::PrintToStringParamName()); | ||
| 676 | |||
| 677 | // Test suite focused on K chunk 31 | ||
| 678 |
20/64✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 816 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 1632 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
|
2454 | INSTANTIATE_TEST_SUITE_P( |
| 679 | ShapesKC31, IndirectMatMulTest, | ||
| 680 | testing::Combine( | ||
| 681 | testing::ValuesIn(get_indirect_matmul_methods()), // | ||
| 682 | testing::ValuesIn(get_indirect_matmul_shapes()), // | ||
| 683 | testing::Values(static_cast<size_t>(31)), // | ||
| 684 | testing::ValuesIn(get_indirect_matmul_portions()), // | ||
| 685 | testing::Values(0.5F)), // | ||
| 686 | testing::PrintToStringParamName()); | ||
| 687 | |||
| 688 | // Test suite focused on K chunk 32 | ||
| 689 |
20/64✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 816 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 1632 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
|
2454 | INSTANTIATE_TEST_SUITE_P( |
| 690 | ShapesKC32, IndirectMatMulTest, | ||
| 691 | testing::Combine( | ||
| 692 | testing::ValuesIn(get_indirect_matmul_methods()), // | ||
| 693 | testing::ValuesIn(get_indirect_matmul_shapes()), // | ||
| 694 | testing::Values(static_cast<size_t>(32)), // | ||
| 695 | testing::ValuesIn(get_indirect_matmul_portions()), // | ||
| 696 | testing::Values(0.5F)), // | ||
| 697 | testing::PrintToStringParamName()); | ||
| 698 | |||
| 699 | // Test suite focused on K chunk 64 | ||
| 700 |
20/64✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 816 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 1632 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
|
2454 | INSTANTIATE_TEST_SUITE_P( |
| 701 | ShapesKC64, IndirectMatMulTest, | ||
| 702 | testing::Combine( | ||
| 703 | testing::ValuesIn(get_indirect_matmul_methods()), // | ||
| 704 | testing::ValuesIn(get_indirect_matmul_shapes()), // | ||
| 705 | testing::Values(static_cast<size_t>(64)), // | ||
| 706 | testing::ValuesIn(get_indirect_matmul_portions()), // | ||
| 707 | testing::Values(0.5F)), // | ||
| 708 | testing::PrintToStringParamName()); | ||
| 709 | |||
| 710 | // Test suite focused on K chunk 65, other parameters are limited | ||
| 711 |
20/64✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 816 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 1632 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
|
2454 | INSTANTIATE_TEST_SUITE_P( |
| 712 | ShapesKC65, IndirectMatMulTest, | ||
| 713 | testing::Combine( | ||
| 714 | testing::ValuesIn(get_indirect_matmul_methods()), // | ||
| 715 | testing::ValuesIn(get_indirect_matmul_shapes()), // | ||
| 716 | testing::Values(static_cast<size_t>(65)), // | ||
| 717 | testing::ValuesIn(get_indirect_matmul_portions()), // | ||
| 718 | testing::Values(0.5F)), // | ||
| 719 | testing::PrintToStringParamName()); | ||
| 720 | |||
| 721 | // Test suite focused on clamping values, other parameters are limited | ||
| 722 |
22/72✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✓ Branch 24 taken 408 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✓ Branch 26 taken 816 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
|
1230 | INSTANTIATE_TEST_SUITE_P( |
| 723 | Clamp, IndirectMatMulTest, | ||
| 724 | testing::Combine( | ||
| 725 | testing::ValuesIn(get_indirect_matmul_methods()), // | ||
| 726 | testing::ValuesIn(get_indirect_matmul_shapes()), // | ||
| 727 | testing::Values(static_cast<size_t>(3)), // | ||
| 728 | testing::Values(MatrixPortion(0, 0, 1, 1)), // | ||
| 729 | testing::Values(1.0F, 0.9F, 0.5F)), // | ||
| 730 | testing::PrintToStringParamName()); | ||
| 731 | |||
| 732 | } // namespace kai::test | ||
| 733 |