test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <gtest/gtest.h> | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <cstdlib> | ||
| 13 | #include <sstream> | ||
| 14 | #include <string> | ||
| 15 | #include <tuple> | ||
| 16 | |||
| 17 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa.h" | ||
| 18 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot.h" | ||
| 19 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod.h" | ||
| 20 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.h" | ||
| 21 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod.h" | ||
| 22 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h" | ||
| 23 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p_qai4c32p_interface.h" | ||
| 24 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h" | ||
| 25 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.h" | ||
| 26 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon.h" | ||
| 27 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon.h" | ||
| 28 | #include "test/common/abi_checker.hpp" | ||
| 29 | #include "test/common/buffer.hpp" | ||
| 30 | #include "test/common/compare.hpp" | ||
| 31 | #include "test/common/cpu_info.hpp" | ||
| 32 | #include "test/common/int4.hpp" | ||
| 33 | #include "test/common/matmul_test_common.hpp" | ||
| 34 | #include "test/common/matrix_portion.hpp" | ||
| 35 | #include "test/common/memory.hpp" | ||
| 36 | #include "test/common/round.hpp" | ||
| 37 | #include "test/common/seed.hpp" | ||
| 38 | #include "test/common/test_suite.hpp" | ||
| 39 | #include "test/reference/cast.hpp" | ||
| 40 | #include "test/reference/clamp.hpp" | ||
| 41 | #include "test/reference/fill.hpp" | ||
| 42 | #include "test/reference/matmul.hpp" | ||
| 43 | #include "test/reference/pack.hpp" | ||
| 44 | #include "test/reference/quantize.hpp" | ||
| 45 | |||
| 46 | namespace kai::test { | ||
| 47 | // Interface for the LHS and RHS packed size and packing micro-kernels | ||
| 48 | using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon); | ||
| 49 | using kai_get_rhs_packed_size_func_t = | ||
| 50 | decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); | ||
| 51 | using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon); | ||
| 52 | using kai_get_rhs_packed_offset_func_t = | ||
| 53 | decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); | ||
| 54 | using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon); | ||
| 55 | using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); | ||
| 56 | using kai_run_lhs_pack_func_t = decltype(&kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon); | ||
| 57 | using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon); | ||
| 58 | |||
| 59 | // Micro-kernel interface | ||
| 60 | struct kai_qai4c32p_pack_functions { | ||
| 61 | kai_get_rhs_packed_size_func_t packed_size; | ||
| 62 | kai_get_rhs_packed_offset_func_t get_packed_offset; | ||
| 63 | kai_get_rhs_offset_func_t get_offset; | ||
| 64 | kai_run_rhs_pack_func_t run_pack; | ||
| 65 | }; | ||
| 66 | |||
| 67 | struct kai_qsi8d32p_pack_functions { | ||
| 68 | kai_get_lhs_packed_size_func_t packed_size; | ||
| 69 | kai_get_lhs_packed_offset_func_t get_packed_offset; | ||
| 70 | kai_get_lhs_offset_func_t get_offset; | ||
| 71 | kai_run_lhs_pack_func_t run_pack; | ||
| 72 | }; | ||
| 73 | |||
| 74 | ✗ | static const std::array< | |
| 75 | UkernelMatmulPackVariant< | ||
| 76 | kai_matmul_clamp_f32_qsi8d32p_qai4c32p_ukernel, kai_qsi8d32p_pack_functions, kai_qai4c32p_pack_functions>, | ||
| 77 | 8> | ||
| 78 |
0/4✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
|
3 | variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p = { |
| 79 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
10 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 80 | clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod, cpu_has_dotprod, | ||
| 81 | lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), | ||
| 82 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 83 | clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32pscalef32_f32_neon, | ||
| 84 | rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), | ||
| 85 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 86 | clamp_f32_qsi8d32p4x4_qai4c32p4x4_8x4_neon_dotprod, cpu_has_dotprod, | ||
| 87 | lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), | ||
| 88 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 89 | clamp_f32_qsi8d32p1x4_qai4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod, | ||
| 90 | lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon, true), | ||
| 91 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 92 | clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qsi8d32pscalef32_f32_neon, | ||
| 93 | rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon, false), | ||
| 94 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 95 | clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, | ||
| 96 | lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s1s0_f32_f32_f32_neon, | ||
| 97 | false), | ||
| 98 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 99 | clamp_f32_qsi8d32p1x4_qai4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qsi8d32pscalef32_f32_neon, | ||
| 100 | rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon, true), | ||
| 101 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 102 | clamp_f32_qsi8d32p1vlx4_qai4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, | ||
| 103 | lhs_quant_pack_qsi8d32pscalef32_f32_neon, rhs_pack_nxk_qai4c32ps1s0nrx4_qau4c32s0s1_f32_f32_f32_neon, | ||
| 104 | true)}}; | ||
| 105 | |||
| 106 | // Executes the LHS packing micro-kernel. | ||
| 107 | 50148 | static inline std::tuple<Buffer, size_t> pack_lhs_qsi8d32p( | |
| 108 | const kai_qsi8d32p_pack_functions& pack_interface, size_t M, size_t K, size_t bl, size_t mr, size_t kr, size_t sr, | ||
| 109 | const Buffer& lhs_values_qsi8, size_t stride, size_t rect_start_row, size_t rect_height) { | ||
| 110 | 50148 | const auto imp_packed_lhs_size = pack_interface.packed_size(M, K, bl, mr, kr, sr); | |
| 111 | 50148 | Buffer imp_packed_lhs(imp_packed_lhs_size, 0); | |
| 112 | |||
| 113 |
1/2✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
|
50148 | auto lhs_offset = pack_interface.get_offset(rect_start_row, stride); |
| 114 |
1/2✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
|
50148 | auto lhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, bl, mr, kr, sr); |
| 115 | |||
| 116 |
1/2✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
|
50148 | abi_check( |
| 117 | 50148 | pack_interface.run_pack, rect_height, K, bl, mr, kr, sr, 0, | |
| 118 |
1/2✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
|
50148 | reinterpret_cast<const float*>(lhs_values_qsi8.data() + lhs_offset), stride, |
| 119 |
1/2✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
|
50148 | imp_packed_lhs.data() + lhs_packed_offset); |
| 120 | |||
| 121 | 50148 | return {std::move(imp_packed_lhs), lhs_packed_offset}; | |
| 122 | 50148 | } | |
| 123 | |||
| 124 | // Executes the RHS packing micro-kernel. | ||
| 125 | 12852 | static inline std::tuple<Buffer, size_t> pack_rhs_qai4c32p( | |
| 126 | const kai_qai4c32p_pack_functions& pack_interface, size_t N, size_t K, size_t bl, size_t nr, size_t kr, size_t sr, | ||
| 127 | const Buffer& rhs_values_qai4, const bool has_bias, const Buffer& biases, const Buffer& rhs_scales, | ||
| 128 | const Buffer& rhs_zp, bool s0s1_input, size_t rect_start_row) { | ||
| 129 | // Cast to unsigned int | ||
| 130 | 12852 | auto rhs_qau4s1s0 = cast_qsu4_qsi4(rhs_values_qai4.data(), N * K); | |
| 131 | |||
| 132 |
1/2✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
|
12852 | const auto imp_packed_rhs_size = pack_interface.packed_size(N, K, nr, kr, bl); |
| 133 |
1/2✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
|
12852 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
| 134 |
1/2✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
|
12852 | auto rhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, nr, kr, bl); |
| 135 | |||
| 136 | // Runs the RHS packing micro-kernel. | ||
| 137 | 12852 | kai_rhs_pack_nxk_qai4c32p_params params{}; | |
| 138 | 12852 | params.lhs_zero_point = 1; | |
| 139 | 12852 | params.rhs_zero_point = 8; | |
| 140 | |||
| 141 |
5/10✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2142 times.
✓ Branch 3 taken 10710 times.
✓ Branch 4 taken 2142 times.
✓ Branch 5 taken 10710 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
|
12852 | abi_check( |
| 142 | 12852 | pack_interface.run_pack, 1, N, K, nr, kr, sr, bl, | |
| 143 |
3/4✓ Branch 0 taken 10710 times.
✓ Branch 1 taken 2142 times.
✓ Branch 2 taken 10710 times.
✗ Branch 3 not taken.
|
12852 | reinterpret_cast<const uint8_t*>(s0s1_input ? convert_s0s1_s1s0(rhs_qau4s1s0).data() : rhs_qau4s1s0.data()), |
| 144 |
2/2✓ Branch 0 taken 6426 times.
✓ Branch 1 taken 6426 times.
|
12852 | rhs_zp.data(), has_bias ? biases.data() : nullptr, rhs_scales.data(), imp_packed_rhs.data(), 0, &params); |
| 145 | |||
| 146 | 12852 | return {std::move(imp_packed_rhs), rhs_packed_offset}; | |
| 147 | 12852 | } | |
| 148 | |||
| 149 | using MatMulTestClampPortionedParamsWithBias_WithBL = | ||
| 150 | std::tuple<size_t, MatMulShape, size_t, MatrixPortion, float, bool>; | ||
| 151 | class MatMulTest_f32_qsi8d32p_qai4c32p | ||
| 152 | : public ::testing::TestWithParam<MatMulTestClampPortionedParamsWithBias_WithBL> {}; | ||
| 153 | |||
| 154 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
84006 | TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, LhsPackedWithSameBlockdepth) { |
| 155 | // Verify LHS quant and pack int8 kernel behaves same for int4 and int8 matmul kernels, | ||
| 156 | // when the block-depth is same for different values of kr, sr. | ||
| 157 | |||
| 158 | 20880384 | const auto& [variant_index, matmul_shape, bl, portion, clamp_keep_ratio, has_bias] = GetParam(); | |
| 159 | 67200 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index); | |
| 160 | |||
| 161 |
3/4✓ Branch 0 taken 33600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 25200 times.
✓ Branch 3 taken 8400 times.
|
33600 | if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { |
| 162 |
3/6✓ Branch 0 taken 8400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8400 times.
✗ Branch 5 not taken.
|
8400 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 163 | } | ||
| 164 | |||
| 165 | 50400 | const size_t M = matmul_shape.m; | |
| 166 | 50400 | const size_t N = matmul_shape.n; | |
| 167 | 50400 | const size_t K = matmul_shape.k; | |
| 168 | |||
| 169 |
4/4✓ Branch 0 taken 6552 times.
✓ Branch 1 taken 18648 times.
✓ Branch 2 taken 6552 times.
✓ Branch 3 taken 18648 times.
|
50400 | if (K % bl != 0) { |
| 170 |
3/6✓ Branch 0 taken 6552 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6552 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6552 times.
✗ Branch 5 not taken.
|
6552 | GTEST_SKIP() << "K must be a multiple of bl"; |
| 171 | } | ||
| 172 | |||
| 173 | 18648 | const auto mr = ukernel_variant.ukernel.interface.get_mr(); | |
| 174 | 18648 | const auto nr = ukernel_variant.ukernel.interface.get_nr(); | |
| 175 | 18648 | const auto kr = ukernel_variant.ukernel.interface.get_kr(); | |
| 176 | 18648 | const auto sr = ukernel_variant.ukernel.interface.get_sr(); | |
| 177 | |||
| 178 | 18648 | auto m_step = ukernel_variant.ukernel.interface.get_m_step(); | |
| 179 |
3/14✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 18648 times.
|
18648 | ASSERT_TRUE(m_step % mr == 0); |
| 180 | |||
| 181 | 18648 | auto n_step = ukernel_variant.ukernel.interface.get_n_step(); | |
| 182 |
3/14✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 18648 times.
|
18648 | ASSERT_TRUE(n_step % nr == 0); |
| 183 | |||
| 184 | 37296 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 185 | |||
| 186 | // Generates input data. | ||
| 187 |
3/6✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18648 times.
✗ Branch 5 not taken.
|
18648 | const auto ref_lhs = fill_random<float>(M * K, seed_stream(current_test_key())()); |
| 188 | |||
| 189 | // Runs the LHS packing micro-kernel. | ||
| 190 |
1/2✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
|
18648 | const auto lhs_start_row = rect.start_row(); |
| 191 | 18648 | auto lhs_stride = K * sizeof(float); | |
| 192 | |||
| 193 |
1/2✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
|
55944 | auto [imp_packed_lhs, lhs_packed_offset] = pack_lhs_qsi8d32p( |
| 194 |
2/4✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
|
37296 | ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr, sr, ref_lhs, lhs_stride, lhs_start_row, rect.height()); |
| 195 | |||
| 196 | 18648 | const size_t kr_qsi8 = kr / sr; | |
| 197 | 18648 | const size_t sr_qsi8 = 1; | |
| 198 | |||
| 199 |
1/2✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
|
37296 | auto [imp_packed_lhs_qsi8, lhs_qsi8_packed_offset] = pack_lhs_qsi8d32p( |
| 200 | 37296 | ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr_qsi8, sr_qsi8, ref_lhs, lhs_stride, lhs_start_row, | |
| 201 |
1/2✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
|
18648 | rect.height()); |
| 202 | |||
| 203 |
5/18✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18648 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 18648 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 18648 times.
|
37296 | ASSERT_EQ(lhs_qsi8_packed_offset, lhs_packed_offset); |
| 204 | |||
| 205 |
2/4✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
|
37296 | auto* imp_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs.data()); |
| 206 |
2/4✓ Branch 0 taken 18648 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 18648 times.
✗ Branch 3 not taken.
|
37296 | auto* imp_packed_lhs_qsi8_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs_qsi8.data()); |
| 207 |
5/8✗ Branch 0 not taken.
✓ Branch 1 taken 20656440 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 20656440 times.
✓ Branch 4 taken 18648 times.
✓ Branch 5 taken 20637792 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18648 times.
|
20656440 | for (size_t i = 0; i < ukernel_variant.lhs_pack_interface.packed_size(M, K, bl, mr, kr, sr); i++) { |
| 208 |
4/16✓ Branch 0 taken 20637792 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20637792 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20637792 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 20637792 times.
|
20637792 | ASSERT_EQ(imp_packed_lhs_ptr[i], imp_packed_lhs_qsi8_ptr[i]); |
| 209 | 20637792 | } | |
| 210 | 33600 | } | |
| 211 | |||
| 212 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
84006 | TEST_P(MatMulTest_f32_qsi8d32p_qai4c32p, EndToEnd) { |
| 213 | 193704 | const auto& [variant_index, matmul_shape, bl, portion, clamp_keep_ratio, has_bias] = GetParam(); | |
| 214 | 67200 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_index); | |
| 215 | |||
| 216 |
3/4✓ Branch 0 taken 33600 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 25200 times.
✓ Branch 3 taken 8400 times.
|
33600 | if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { |
| 217 |
3/6✓ Branch 0 taken 8400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8400 times.
✗ Branch 5 not taken.
|
8400 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 218 | } | ||
| 219 | |||
| 220 | 50400 | const size_t M = matmul_shape.m; | |
| 221 | 50400 | const size_t N = matmul_shape.n; | |
| 222 | 50400 | const size_t K = matmul_shape.k; | |
| 223 | |||
| 224 |
4/4✓ Branch 0 taken 6552 times.
✓ Branch 1 taken 18648 times.
✓ Branch 2 taken 6552 times.
✓ Branch 3 taken 18648 times.
|
50400 | if (K % bl != 0) { |
| 225 |
3/6✓ Branch 0 taken 6552 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6552 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6552 times.
✗ Branch 5 not taken.
|
6552 | GTEST_SKIP() << "K must be a multiple of bl"; |
| 226 | } | ||
| 227 | |||
| 228 | 18648 | const auto mr = ukernel_variant.ukernel.interface.get_mr(); | |
| 229 | 18648 | const auto nr = ukernel_variant.ukernel.interface.get_nr(); | |
| 230 | 18648 | const auto kr = ukernel_variant.ukernel.interface.get_kr(); | |
| 231 | 18648 | const auto sr = ukernel_variant.ukernel.interface.get_sr(); | |
| 232 | |||
| 233 |
4/4✓ Branch 0 taken 9324 times.
✓ Branch 1 taken 9324 times.
✓ Branch 2 taken 3528 times.
✓ Branch 3 taken 5796 times.
|
18648 | if (mr == 1 && M > 1) { |
| 234 |
3/6✓ Branch 0 taken 5796 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5796 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5796 times.
✗ Branch 5 not taken.
|
5796 | GTEST_SKIP() << "Kernel does not support M != 1"; |
| 235 | } | ||
| 236 | |||
| 237 | 12852 | auto m_step = ukernel_variant.ukernel.interface.get_m_step(); | |
| 238 |
3/14✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12852 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 12852 times.
|
12852 | ASSERT_TRUE(m_step % mr == 0); |
| 239 | |||
| 240 | 12852 | auto n_step = ukernel_variant.ukernel.interface.get_n_step(); | |
| 241 |
3/14✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12852 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 12852 times.
|
12852 | ASSERT_TRUE(n_step % nr == 0); |
| 242 | |||
| 243 | 25704 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 244 |
2/4✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 12852 times.
|
12852 | if (rect.height() == 0 || rect.width() == 0) { |
| 245 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 246 | } | ||
| 247 | |||
| 248 | // Seed the random generator. | ||
| 249 |
1/2✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
|
12852 | auto& feed = seed_stream(current_test_key()); |
| 250 | |||
| 251 | // Generates input data. | ||
| 252 | 12852 | const auto ref_lhs = fill_random<float>(M * K, feed()); | |
| 253 |
2/4✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12852 times.
✗ Branch 3 not taken.
|
12852 | const auto ref_rhs = fill_random<float>(N * K, feed()); |
| 254 | 12852 | Buffer ref_biases; | |
| 255 | |||
| 256 |
2/2✓ Branch 0 taken 6426 times.
✓ Branch 1 taken 6426 times.
|
12852 | if (has_bias) { |
| 257 |
2/4✓ Branch 0 taken 6426 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6426 times.
✗ Branch 3 not taken.
|
6426 | ref_biases = fill_random<float>(N, feed()); |
| 258 | 6426 | } | |
| 259 | // Runs the reference implementation. | ||
| 260 | // * Quantizes the LHS matrix using 8-bit symmetric quantization. | ||
| 261 | // * Quantizes the RHS matrix using 4-bit asymmetric quantization. | ||
| 262 | // * Performs GEMM. | ||
| 263 |
1/2✓ Branch 0 taken 12852 times.
✗ Branch 1 not taken.
|
12852 | QuantizationInfo lhs_qinfo{}; |
| 264 | lhs_qinfo.quant_width = bl; | ||
| 265 | lhs_qinfo.dst_type = DataType::QSI8; | ||
| 266 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 267 | const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo); | ||
| 268 | |||
| 269 | QuantizationInfo rhs_qinfo{}; | ||
| 270 | rhs_qinfo.quant_width = bl; | ||
| 271 | rhs_qinfo.dst_type = DataType::QAI4; | ||
| 272 | rhs_qinfo.scale_type = DataType::FP32; | ||
| 273 | rhs_qinfo.zero_point_type = DataType::I32; | ||
| 274 | const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo); | ||
| 275 | |||
| 276 | const auto ref_dst_no_clamp = | ||
| 277 | matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>( | ||
| 278 | M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), nullptr, 1, bl, ref_rhs_quant.data(), | ||
| 279 | rhs_qoutputs.scales.data(), rhs_qoutputs.zero_points.data(), 1, bl, has_bias ? ref_biases.data() : nullptr, | ||
| 280 | nullptr, nullptr, 1); | ||
| 281 | |||
| 282 | // Clamps the reference output. | ||
| 283 | const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), M * N, clamp_keep_ratio); | ||
| 284 | const auto ref_dst = clamp<float>(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max); | ||
| 285 | |||
| 286 | // Runs the LHS packing micro-kernel. | ||
| 287 | const auto lhs_start_row = rect.start_row(); | ||
| 288 | auto [imp_packed_lhs, lhs_packed_offset] = pack_lhs_qsi8d32p( | ||
| 289 | ukernel_variant.lhs_pack_interface, M, K, bl, mr, kr, sr, ref_lhs, K * sizeof(float), lhs_start_row, | ||
| 290 | rect.height()); | ||
| 291 | auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl); | ||
| 292 | |||
| 293 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); | ||
| 294 | |||
| 295 | // Prepare the offsets as the RHS packing micro-kernel expects the scaled zero-points in float. | ||
| 296 | const size_t num_blocks_per_row = round_up_division(K, bl); | ||
| 297 | const size_t ref_zp_size = N * num_blocks_per_row; | ||
| 298 | const size_t ref_zp_size_in_bytes = ref_zp_size * sizeof(float); | ||
| 299 | Buffer ref_rhs_zp_f32(ref_zp_size_in_bytes); | ||
| 300 | for (size_t i = 0; i < ref_zp_size; ++i) { | ||
| 301 | reinterpret_cast<float*>(ref_rhs_zp_f32.data())[i] = | ||
| 302 | -reinterpret_cast<const int32_t*>(rhs_qoutputs.zero_points.data())[i] * | ||
| 303 | reinterpret_cast<const float*>(rhs_qoutputs.scales.data())[i]; | ||
| 304 | } | ||
| 305 | |||
| 306 | const auto rhs_start_row = rect.start_col(); | ||
| 307 | auto [imp_packed_rhs, rhs_packed_offset] = pack_rhs_qai4c32p( | ||
| 308 | ukernel_variant.rhs_pack_interface, N, K, bl, nr, kr, sr, ref_rhs_quant, has_bias, ref_biases, | ||
| 309 | rhs_qoutputs.scales, ref_rhs_zp_f32, ukernel_variant.rhs_s0s1_input, rhs_start_row); | ||
| 310 | |||
| 311 | auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl); | ||
| 312 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); | ||
| 313 | |||
| 314 | const auto dst_stride_row = N * sizeof(float); | ||
| 315 | const auto dst_stride_col = sizeof(float); | ||
| 316 | const auto dst_offset = | ||
| 317 | ukernel_variant.ukernel.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); | ||
| 318 | const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; | ||
| 319 | ASSERT_EQ(dst_offset, ref_dst_offset); | ||
| 320 | |||
| 321 | // Runs the GEMM micro-kernel. | ||
| 322 | const auto imp_dst_size = ukernel_variant.ukernel.interface.get_dst_size(M, N); | ||
| 323 | ASSERT_EQ(imp_dst_size, ref_dst.size()); | ||
| 324 | Buffer imp_dst(imp_dst_size); | ||
| 325 | abi_check( | ||
| 326 | ukernel_variant.ukernel.interface.run_matmul, rect.height(), rect.width(), K, bl, | ||
| 327 | imp_packed_lhs.data() + lhs_matmul_offset, imp_packed_rhs.data() + rhs_matmul_offset, | ||
| 328 | reinterpret_cast<float*>(imp_dst.data() + dst_offset), dst_stride_row, dst_stride_col, clamp_min, clamp_max); | ||
| 329 | |||
| 330 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
| 331 | // tested. | ||
| 332 | DefaultMismatchHandler handler(0, 0.1, 0, 0.05); | ||
| 333 | DataFormat dst_format = DataFormat(DataType::FP32); | ||
| 334 | const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); | ||
| 335 | ASSERT_TRUE(success); | ||
| 336 | ✗ | } | |
| 337 |
77/202✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✓ Branch 18 taken 4 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✓ Branch 20 taken 4 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✓ Branch 24 taken 4 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 4 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✓ Branch 28 taken 4 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
✓ Branch 30 taken 4 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 2 times.
✓ Branch 32 taken 4 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✓ Branch 34 taken 4 times.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 2 times.
✓ Branch 36 taken 4 times.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✓ Branch 38 taken 4 times.
✓ Branch 38 taken 33600 times.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✓ Branch 40 taken 67200 times.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✗ Branch 59 not taken.
✗ Branch 60 not taken.
✗ Branch 60 not taken.
✗ Branch 61 not taken.
✗ Branch 61 not taken.
✗ Branch 62 not taken.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✗ Branch 63 not taken.
✗ Branch 64 not taken.
✗ Branch 64 not taken.
✗ Branch 65 not taken.
✗ Branch 65 not taken.
✗ Branch 66 not taken.
✗ Branch 66 not taken.
✗ Branch 67 not taken.
✗ Branch 67 not taken.
✗ Branch 68 not taken.
✓ Branch 68 taken 33600 times.
✗ Branch 69 not taken.
✗ Branch 69 not taken.
✓ Branch 70 taken 33600 times.
✓ Branch 70 taken 67200 times.
✗ Branch 71 not taken.
✗ Branch 71 not taken.
✓ Branch 72 taken 33600 times.
✓ Branch 72 taken 67200 times.
✗ Branch 73 not taken.
✗ Branch 73 not taken.
✓ Branch 74 taken 33600 times.
✓ Branch 74 taken 67200 times.
✗ Branch 75 not taken.
✗ Branch 75 not taken.
✓ Branch 76 taken 33600 times.
✓ Branch 76 taken 67200 times.
✗ Branch 77 not taken.
✗ Branch 77 not taken.
✓ Branch 78 taken 33600 times.
✓ Branch 78 taken 67200 times.
✗ Branch 79 not taken.
✗ Branch 79 not taken.
✓ Branch 80 taken 33600 times.
✓ Branch 80 taken 67200 times.
✗ Branch 81 not taken.
✗ Branch 81 not taken.
✓ Branch 82 taken 33600 times.
✓ Branch 82 taken 67200 times.
✗ Branch 83 not taken.
✗ Branch 83 not taken.
✓ Branch 84 taken 33600 times.
✓ Branch 84 taken 67200 times.
✗ Branch 85 not taken.
✗ Branch 85 not taken.
✓ Branch 86 taken 16800 times.
✓ Branch 86 taken 67200 times.
✓ Branch 87 taken 16800 times.
✗ Branch 87 not taken.
✓ Branch 88 taken 16800 times.
✓ Branch 88 taken 67200 times.
✗ Branch 89 not taken.
✗ Branch 89 not taken.
✓ Branch 90 taken 16800 times.
✓ Branch 90 taken 33600 times.
✗ Branch 91 not taken.
✓ Branch 91 taken 33600 times.
✓ Branch 92 taken 33600 times.
✓ Branch 92 taken 33600 times.
✗ Branch 93 not taken.
✗ Branch 93 not taken.
✓ Branch 94 taken 25200 times.
✓ Branch 94 taken 33600 times.
✓ Branch 95 taken 8400 times.
✗ Branch 95 not taken.
✓ Branch 96 taken 25200 times.
✓ Branch 96 taken 67200 times.
✗ Branch 97 not taken.
✗ Branch 97 not taken.
✓ Branch 98 taken 8400 times.
✓ Branch 98 taken 50400 times.
✗ Branch 99 not taken.
✓ Branch 99 taken 16800 times.
✓ Branch 100 taken 33600 times.
✓ Branch 100 taken 50400 times.
✗ Branch 101 not taken.
✗ Branch 101 not taken.
✓ Branch 102 taken 33600 times.
✓ Branch 102 taken 16800 times.
✗ Branch 103 not taken.
✗ Branch 103 not taken.
✓ Branch 104 taken 67200 times.
✗ Branch 105 not taken.
✓ Branch 106 taken 67200 times.
✗ Branch 107 not taken.
|
277209 | INSTANTIATE_TEST_SUITE_P( |
| 338 | MatMul, MatMulTest_f32_qsi8d32p_qai4c32p, | ||
| 339 | testing::Combine( | ||
| 340 | testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.size()), | ||
| 341 | testing::Values( | ||
| 342 | MatMulShape{1, 64, 32}, // | ||
| 343 | MatMulShape{1, 63, 32}, // | ||
| 344 | MatMulShape{1, 65, 32}, // | ||
| 345 | MatMulShape{1, 64, 64}, // | ||
| 346 | MatMulShape{1, 64, 128}, // | ||
| 347 | MatMulShape{1, 128, 32}, // | ||
| 348 | MatMulShape{1, 128, 128}, // | ||
| 349 | MatMulShape{1, 2, 32}, // | ||
| 350 | MatMulShape{1, 3, 32}, // | ||
| 351 | MatMulShape{1, 4, 32}, // | ||
| 352 | MatMulShape{1, 5, 32}, // | ||
| 353 | MatMulShape{3, 3, 32}, // | ||
| 354 | MatMulShape{4, 4, 32}, // | ||
| 355 | MatMulShape{5, 5, 32}, // | ||
| 356 | MatMulShape{32, 128, 32}, // | ||
| 357 | MatMulShape{15, 64, 64}, // | ||
| 358 | MatMulShape{17, 64, 64}, // | ||
| 359 | MatMulShape{16, 63, 64}, // | ||
| 360 | MatMulShape{16, 64, 64}, // | ||
| 361 | MatMulShape{16, 65, 64}, // | ||
| 362 | MatMulShape{32, 64, 64}, // | ||
| 363 | MatMulShape{16, 32, 64}, // | ||
| 364 | MatMulShape{8, 32, 64}, // | ||
| 365 | MatMulShape{15, 32, 32}, // | ||
| 366 | MatMulShape{77, 99, 64}), | ||
| 367 | testing::Values(32, 64), | ||
| 368 | testing::Values( | ||
| 369 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
| 370 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
| 371 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
| 372 | MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle | ||
| 373 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
| 374 | MatrixPortion(0.75, 0, 1, 1), // Partial rows | ||
| 375 | MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle | ||
| 376 | ), | ||
| 377 | testing::ValuesIn(std::initializer_list<float>{1.0f, 0.9f, 0.5f}), // | ||
| 378 | testing::Bool()), | ||
| 379 | [](const auto& info) { | ||
| 380 | const auto variant_idx = std::get<0>(info.param); | ||
| 381 | const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_idx).ukernel.name}; | ||
| 382 | const auto shape = std::get<MatMulShape>(info.param); | ||
| 383 | const auto bl = std::get<2>(info.param); | ||
| 384 | const auto portion = std::get<3>(info.param); | ||
| 385 | const auto clamp_keep_ratio = std::get<4>(info.param); | ||
| 386 | const auto has_bias = std::get<5>(info.param); | ||
| 387 | |||
| 388 | std::ostringstream sstream; | ||
| 389 | sstream << name << "__"; | ||
| 390 | PrintTo(shape, &sstream); | ||
| 391 | sstream << "__BL_" << bl << "_"; | ||
| 392 | sstream << "__clamp_keep_ratio_" << static_cast<int>(clamp_keep_ratio * 100); | ||
| 393 | |||
| 394 | if (has_bias) { | ||
| 395 | sstream << "_withBias_"; | ||
| 396 | } else { | ||
| 397 | sstream << "_noBias_"; | ||
| 398 | } | ||
| 399 | if (variants_kai_matmul_clamp_f32_qsi8d32p_qai4c32p.at(variant_idx).rhs_s0s1_input) { | ||
| 400 | sstream << "_RHS_s0s1__"; | ||
| 401 | } else { | ||
| 402 | sstream << "_RHS_s1s0__"; | ||
| 403 | } | ||
| 404 | PrintTo(portion, &sstream); | ||
| 405 | |||
| 406 | return sstream.str(); | ||
| 407 | }); | ||
| 408 | |||
| 409 | } // namespace kai::test | ||
| 410 |