test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <gtest/gtest.h> | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <cstdlib> | ||
| 13 | #include <limits> | ||
| 14 | #include <sstream> | ||
| 15 | #include <string> | ||
| 16 | #include <tuple> | ||
| 17 | |||
| 18 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" | ||
| 19 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" | ||
| 20 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" | ||
| 21 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p8x4_1x8_sve_dotprod.h" | ||
| 22 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" | ||
| 23 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h" | ||
| 24 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" | ||
| 25 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" | ||
| 26 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" | ||
| 27 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h" | ||
| 28 | #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" | ||
| 29 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h" | ||
| 30 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.h" | ||
| 31 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.h" | ||
| 32 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" | ||
| 33 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" | ||
| 34 | #include "test/common/abi_checker.hpp" | ||
| 35 | #include "test/common/buffer.hpp" | ||
| 36 | #include "test/common/compare.hpp" | ||
| 37 | #include "test/common/cpu_info.hpp" | ||
| 38 | #include "test/common/float16.hpp" | ||
| 39 | #include "test/common/int4.hpp" | ||
| 40 | #include "test/common/matmul_test_common.hpp" | ||
| 41 | #include "test/common/matrix_portion.hpp" | ||
| 42 | #include "test/common/memory.hpp" | ||
| 43 | #include "test/common/round.hpp" | ||
| 44 | #include "test/common/test_suite.hpp" | ||
| 45 | #include "test/reference/cast.hpp" | ||
| 46 | #include "test/reference/clamp.hpp" | ||
| 47 | #include "test/reference/fill.hpp" | ||
| 48 | #include "test/reference/matmul.hpp" | ||
| 49 | #include "test/reference/pack.hpp" | ||
| 50 | #include "test/reference/quantize.hpp" | ||
| 51 | |||
| 52 | namespace kai::test { | ||
| 53 | |||
| 54 | // Interface for the LHS and RHS packed size and packing micro-kernels | ||
| 55 | using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32); | ||
| 56 | using kai_get_rhs_packed_size_func_t = decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0); | ||
| 57 | using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32); | ||
| 58 | using kai_get_rhs_packed_offset_func_t = | ||
| 59 | decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0); | ||
| 60 | using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32); | ||
| 61 | using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0); | ||
| 62 | using kai_run_lhs_pack_func_t = decltype(&kai_run_lhs_quant_pack_qsi8d32p_f32); | ||
| 63 | using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0); | ||
| 64 | |||
| 65 | // Micro-kernel interface | ||
| 66 | struct kai_qsi8d32p_pack_functions { | ||
| 67 | kai_get_lhs_packed_size_func_t packed_size; | ||
| 68 | kai_get_lhs_packed_offset_func_t get_packed_offset; | ||
| 69 | kai_get_lhs_offset_func_t get_offset; | ||
| 70 | kai_run_lhs_pack_func_t run_pack; | ||
| 71 | }; | ||
| 72 | struct kai_qsi4c32p_pack_functions { | ||
| 73 | kai_get_rhs_packed_size_func_t packed_size; | ||
| 74 | kai_get_rhs_packed_offset_func_t get_packed_offset; | ||
| 75 | kai_get_rhs_offset_func_t get_offset; | ||
| 76 | kai_run_rhs_pack_func_t run_pack; | ||
| 77 | }; | ||
| 78 | |||
| 79 | struct UKernelVariants { | ||
| 80 | UkernelMatmulPackVariant< | ||
| 81 | kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_ukernel, kai_qsi8d32p_pack_functions, kai_qsi4c32p_pack_functions> | ||
| 82 | variant; | ||
| 83 | bool clamp_support; | ||
| 84 | }; | ||
| 85 | |||
| 86 | // clang-format off | ||
| 87 | static const int num_non_clamping_kernels = 4; | ||
| 88 | ✗ | static const std::array<UKernelVariants, 11> | |
| 89 |
0/4✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
|
3 | variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p = { |
| 90 | 14 | { | |
| 91 | // NOTE: The following kernels do not support clamping despite their names. | ||
| 92 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 93 | clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32, | ||
| 94 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), false}, | ||
| 95 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 96 | clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32, | ||
| 97 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), false}, | ||
| 98 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 99 | clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, lhs_quant_pack_qsi8d32p_f32_neon, | ||
| 100 | rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, false), false}, | ||
| 101 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 102 | clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, cpu_has_sme2, lhs_quant_pack_qsi8d32p_f32_neon, | ||
| 103 | rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, false), false}, | ||
| 104 | |||
| 105 | // The kernels below this point will run clamping tests | ||
| 106 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 107 | clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32, | ||
| 108 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true}, | ||
| 109 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 110 | clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qsi8d32p_f32, | ||
| 111 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true}, | ||
| 112 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 113 | 4x8sb_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||
| 114 | cpu_has_i8mm, lhs_quant_pack_qsi8d32p4x8sb_f32_neon, rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true}, | ||
| 115 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 116 | clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qsi8d32p_f32, | ||
| 117 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true}, | ||
| 118 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 119 | clamp_f32_qsi8d32p1x4_qsi4c32p8x4_1x8_sve_dotprod, (cpu_check<cpu_has_sve_vl256, cpu_has_dotprod>), lhs_quant_pack_qsi8d32p_f32, | ||
| 120 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true}, | ||
| 121 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 122 | clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod, (cpu_check<cpu_has_sve_vl256, cpu_has_dotprod>), lhs_quant_pack_qsi8d32p_f32, | ||
| 123 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true}, | ||
| 124 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | {UKERNEL_MATMUL_PACK_VARIANT( |
| 125 | clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm, (cpu_check<cpu_has_sve_vl256, cpu_has_i8mm>), lhs_quant_pack_qsi8d32p_f32, | ||
| 126 | 3 | rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, false), true}}}; | |
| 127 | // clang-format on | ||
| 128 | |||
| 129 | class MatMulTest_f32_qsi8d32p_qsi4c32p | ||
| 130 | : public ::testing::TestWithParam<std::tuple<size_t, MatMulShape, MatrixPortion, float>> {}; | ||
| 131 | |||
| 132 | // Ensure non-clamping tests are marked correctly. | ||
| 133 |
9/18✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
|
12 | TEST(KernelClampingCheck, SanityCheck) { |
| 134 |
3/4✓ Branch 0 taken 22 times.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
|
24 | for (size_t i = 0; i < variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.size(); i++) { |
| 135 |
3/14✓ Branch 0 taken 22 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 22 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 22 times.
|
22 | ASSERT_EQ(variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(i).clamp_support, !(i < num_non_clamping_kernels)); |
| 136 | 22 | } | |
| 137 | 2 | } | |
| 138 | |||
| 139 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
4418 | TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, Offset_RHS) { |
| 140 | 8456 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 141 | 3640 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index).variant; | |
| 142 | |||
| 143 |
3/4✓ Branch 0 taken 1820 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1204 times.
✓ Branch 3 taken 616 times.
|
1820 | if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { |
| 144 |
3/6✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
|
616 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 145 | } | ||
| 146 | |||
| 147 | 1204 | const size_t bl = 32; | |
| 148 | 2408 | const size_t M = matmul_shape.m; | |
| 149 | 2408 | const size_t N = matmul_shape.n; | |
| 150 | 2408 | const size_t K = matmul_shape.k; | |
| 151 | |||
| 152 | 1204 | const auto nr = ukernel_variant.ukernel.interface.get_nr(); | |
| 153 | 1204 | const auto kr = ukernel_variant.ukernel.interface.get_kr(); | |
| 154 | |||
| 155 | 1204 | auto n_step = ukernel_variant.ukernel.interface.get_n_step(); | |
| 156 | 1204 | auto m_step = ukernel_variant.ukernel.interface.get_m_step(); | |
| 157 | |||
| 158 | 2408 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 159 |
2/4✓ Branch 0 taken 1204 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1204 times.
|
1204 | if (rect.height() == 0 || rect.width() == 0) { |
| 160 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 161 | } | ||
| 162 | |||
| 163 | 1204 | const auto rhs_start_row = rect.start_col(); | |
| 164 | 1204 | auto rhs_packed_offset = ukernel_variant.rhs_pack_interface.get_packed_offset(rhs_start_row, K, nr, kr, bl); | |
| 165 | 1204 | auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl); | |
| 166 |
3/14✓ Branch 0 taken 1204 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1204 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1204 times.
|
1204 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
| 167 | 1820 | } | |
| 168 | |||
| 169 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
4418 | TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, Offset_LHS) { |
| 170 | 8456 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 171 | 3640 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index).variant; | |
| 172 | |||
| 173 |
3/4✓ Branch 0 taken 1820 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1204 times.
✓ Branch 3 taken 616 times.
|
1820 | if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { |
| 174 |
3/6✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
|
616 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 175 | } | ||
| 176 | |||
| 177 | 1204 | const size_t bl = 32; | |
| 178 | 2408 | const size_t M = matmul_shape.m; | |
| 179 | 2408 | const size_t N = matmul_shape.n; | |
| 180 | 2408 | const size_t K = matmul_shape.k; | |
| 181 | |||
| 182 | 1204 | const auto mr = ukernel_variant.ukernel.interface.get_mr(); | |
| 183 | 1204 | const auto kr = ukernel_variant.ukernel.interface.get_kr(); | |
| 184 | 1204 | const auto sr = ukernel_variant.ukernel.interface.get_sr(); | |
| 185 | |||
| 186 | 1204 | auto m_step = ukernel_variant.ukernel.interface.get_m_step(); | |
| 187 | 1204 | auto n_step = ukernel_variant.ukernel.interface.get_n_step(); | |
| 188 | |||
| 189 | 2408 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 190 |
2/4✓ Branch 0 taken 1204 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1204 times.
|
1204 | if (rect.height() == 0 || rect.width() == 0) { |
| 191 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 192 | } | ||
| 193 | |||
| 194 | 1204 | const auto lhs_start_row = rect.start_row(); | |
| 195 | 1204 | auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr); | |
| 196 | 1204 | auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl); | |
| 197 | |||
| 198 |
3/14✓ Branch 0 taken 1204 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1204 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1204 times.
|
1204 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
| 199 | 1820 | } | |
| 200 | |||
| 201 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
4418 | TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, EndToEnd) { |
| 202 | 9376 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 203 | 3640 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index).variant; | |
| 204 | |||
| 205 |
3/4✓ Branch 0 taken 1820 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1204 times.
✓ Branch 3 taken 616 times.
|
1820 | if (ukernel_variant.ukernel.fn_is_supported && !ukernel_variant.ukernel.fn_is_supported()) { |
| 206 |
3/6✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
|
616 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 207 | } | ||
| 208 | |||
| 209 | // NOTE: Workaround - some kernels despite being called matmul_clamp do not support clamping. | ||
| 210 | 2408 | const bool clamp_support = variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_index).clamp_support; | |
| 211 | 1204 | const std::uint32_t seed = 0; | |
| 212 | |||
| 213 | 2408 | const size_t M = matmul_shape.m; | |
| 214 | 2408 | const size_t N = matmul_shape.n; | |
| 215 | 2408 | const size_t K = matmul_shape.k; | |
| 216 | 1204 | const size_t bl = 32; | |
| 217 | |||
| 218 | 1204 | const auto mr = ukernel_variant.ukernel.interface.get_mr(); | |
| 219 | 1204 | const auto nr = ukernel_variant.ukernel.interface.get_nr(); | |
| 220 | 1204 | const auto kr = ukernel_variant.ukernel.interface.get_kr(); | |
| 221 | 1204 | const auto sr = ukernel_variant.ukernel.interface.get_sr(); | |
| 222 | |||
| 223 | // Skip tests on clamping when kernel does not support it. | ||
| 224 | − | KAI_ASSERT_ALWAYS_IF_MSG(clamp_keep_ratio != 1.0F, clamp_support, "Clamping not supported by this kernel"); | |
| 225 | |||
| 226 |
4/4✓ Branch 0 taken 464 times.
✓ Branch 1 taken 740 times.
✓ Branch 2 taken 180 times.
✓ Branch 3 taken 284 times.
|
1204 | if (mr == 1 && M > 1) { |
| 227 |
3/6✓ Branch 0 taken 284 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 284 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 284 times.
✗ Branch 5 not taken.
|
284 | GTEST_SKIP() << "Kernel does not support M != 1"; |
| 228 | } | ||
| 229 | |||
| 230 | 920 | auto m_step = ukernel_variant.ukernel.interface.get_m_step(); | |
| 231 |
3/14✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 920 times.
|
920 | ASSERT_TRUE(m_step % mr == 0); |
| 232 | |||
| 233 | 920 | auto n_step = ukernel_variant.ukernel.interface.get_n_step(); | |
| 234 |
3/14✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 920 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 920 times.
|
920 | ASSERT_TRUE(n_step % nr == 0); |
| 235 | |||
| 236 | 1840 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 237 |
2/4✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 920 times.
|
920 | if (rect.height() == 0 || rect.width() == 0) { |
| 238 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 239 | } | ||
| 240 | // Generates input data. | ||
| 241 | 920 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
| 242 |
1/2✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
|
920 | const auto ref_rhs = fill_random<float>(N * K, seed + 1); |
| 243 | |||
| 244 | // Runs the reference implementation. | ||
| 245 |
1/2✓ Branch 0 taken 920 times.
✗ Branch 1 not taken.
|
920 | QuantizationInfo lhs_qinfo{}; |
| 246 | lhs_qinfo.quant_width = bl; | ||
| 247 | lhs_qinfo.dst_type = DataType::QSI8; | ||
| 248 | lhs_qinfo.scale_type = DataType::FP16; | ||
| 249 | const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo); | ||
| 250 | |||
| 251 | QuantizationInfo rhs_qinfo{}; | ||
| 252 | rhs_qinfo.quant_width = bl; | ||
| 253 | rhs_qinfo.dst_type = DataType::QSI4; | ||
| 254 | rhs_qinfo.scale_type = DataType::FP16; | ||
| 255 | const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo); | ||
| 256 | |||
| 257 | const auto ref_dst = matmul_clamp_nt_t<int8_t, Float16, int32_t, Int4, Float16, int32_t, float, int32_t, float>( | ||
| 258 | M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), nullptr, bl, ref_rhs_quant.data(), | ||
| 259 | rhs_qoutputs.scales.data(), nullptr, bl, nullptr, std::numeric_limits<float>::lowest(), | ||
| 260 | std::numeric_limits<float>::max()); | ||
| 261 | |||
| 262 | // Clamp reference output | ||
| 263 | const auto [min, max] = find_clamp_range<float>(ref_dst.data(), M * N, clamp_keep_ratio); | ||
| 264 | const auto out_clamped = clamp<float>(ref_dst.data(), M * N, min, max); | ||
| 265 | |||
| 266 | // Runs the LHS packing micro-kernel. | ||
| 267 | const auto lhs_start_row = rect.start_row(); | ||
| 268 | const auto imp_packed_lhs_size = ukernel_variant.lhs_pack_interface.packed_size(M, K, bl, mr, kr, sr); | ||
| 269 | Buffer imp_packed_lhs(imp_packed_lhs_size); | ||
| 270 | |||
| 271 | auto lhs_stride = K * sizeof(float); | ||
| 272 | auto lhs_offset = ukernel_variant.lhs_pack_interface.get_offset(lhs_start_row, lhs_stride); | ||
| 273 | auto lhs_packed_offset = ukernel_variant.lhs_pack_interface.get_packed_offset(lhs_start_row, K, bl, mr, kr, sr); | ||
| 274 | auto lhs_matmul_offset = ukernel_variant.ukernel.interface.get_lhs_packed_offset(lhs_start_row, K, bl); | ||
| 275 | |||
| 276 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); | ||
| 277 | |||
| 278 | abi_check( | ||
| 279 | ukernel_variant.lhs_pack_interface.run_pack, rect.height() /* m */, K, bl, mr, kr, sr, 0, | ||
| 280 | reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride, | ||
| 281 | imp_packed_lhs.data() + lhs_packed_offset); | ||
| 282 | |||
| 283 | // Runs the RHS packing micro-kernel. | ||
| 284 | const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_quant.data(), N * K); | ||
| 285 | const auto ref_rhs_qsu4_scale_f16 = | ||
| 286 | pack_data_scales_interleave_block<UInt4, Float16>(ref_rhs_qsu4.data(), rhs_qoutputs.scales.data(), N, K, bl); | ||
| 287 | |||
| 288 | const auto imp_packed_rhs_size = ukernel_variant.rhs_pack_interface.packed_size(N, K, nr, kr, bl); | ||
| 289 | Buffer imp_packed_rhs(imp_packed_rhs_size); | ||
| 290 | const auto rhs_start_row = rect.start_col(); | ||
| 291 | auto rhs_packed_offset = ukernel_variant.rhs_pack_interface.get_packed_offset(rhs_start_row, K, nr, kr, bl); | ||
| 292 | auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl); | ||
| 293 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); | ||
| 294 | |||
| 295 | const kai_rhs_pack_qs4cxs1s0_param params{.lhs_zero_point = 1, .rhs_zero_point = 8}; | ||
| 296 | abi_check( | ||
| 297 | ukernel_variant.rhs_pack_interface.run_pack, 1, N, K, nr, kr, sr, bl, | ||
| 298 | reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_scale_f16.data()), nullptr, imp_packed_rhs.data(), 0, &params); | ||
| 299 | |||
| 300 | const auto dst_stride_row = N * sizeof(float); | ||
| 301 | const auto dst_stride_col = sizeof(float); | ||
| 302 | const auto dst_offset = | ||
| 303 | ukernel_variant.ukernel.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); | ||
| 304 | |||
| 305 | const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; | ||
| 306 | ASSERT_EQ(dst_offset, ref_dst_offset); | ||
| 307 | |||
| 308 | // Runs the GEMM micro-kernel. | ||
| 309 | const auto imp_dst_size = ukernel_variant.ukernel.interface.get_dst_size(M, N); | ||
| 310 | ASSERT_EQ(imp_dst_size, ref_dst.size()); | ||
| 311 | Buffer imp_dst(imp_dst_size); | ||
| 312 | abi_check( | ||
| 313 | ukernel_variant.ukernel.interface.run_matmul, rect.height(), rect.width(), K, bl, | ||
| 314 | imp_packed_lhs.data() + lhs_matmul_offset, imp_packed_rhs.data() + rhs_matmul_offset, | ||
| 315 | reinterpret_cast<float*>(imp_dst.data() + dst_offset), dst_stride_row, dst_stride_col, min, max); | ||
| 316 | |||
| 317 | DefaultMismatchHandler handler(0, 0.0001, 0, 0.0001); | ||
| 318 | const auto success = compare(imp_dst.data(), out_clamped.data(), DataType::FP32, M, N, rect, handler); | ||
| 319 | |||
| 320 | ASSERT_TRUE(success); | ||
| 321 | ✗ | } | |
| 322 | |||
| 323 | // Test all kernels without clamping | ||
| 324 |
29/94✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✓ Branch 12 taken 6 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✓ Branch 14 taken 6 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✓ Branch 16 taken 6 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✓ Branch 18 taken 6 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✓ Branch 20 taken 6 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✓ Branch 22 taken 6 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 3 times.
✓ Branch 24 taken 6 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 3 times.
✓ Branch 26 taken 6 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✓ Branch 28 taken 1056 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✓ Branch 30 taken 2112 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 48 taken 1056 times.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✓ Branch 50 taken 2112 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 2112 times.
✗ Branch 53 not taken.
|
6348 | INSTANTIATE_TEST_SUITE_P( |
| 325 | MatMul, MatMulTest_f32_qsi8d32p_qsi4c32p, | ||
| 326 | testing::Combine( | ||
| 327 | testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.size()), | ||
| 328 | testing::Values( | ||
| 329 | MatMulShape{1, 2, 32}, // | ||
| 330 | MatMulShape{1, 40, 32}, // | ||
| 331 | MatMulShape{1, 33, 32}, // | ||
| 332 | MatMulShape{32, 64, 64}, // | ||
| 333 | MatMulShape{16, 32, 64}, // | ||
| 334 | MatMulShape{8, 32, 64}, // | ||
| 335 | MatMulShape{15, 32, 32}, // | ||
| 336 | MatMulShape{77, 99, 64}), | ||
| 337 | testing::Values( | ||
| 338 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
| 339 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
| 340 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
| 341 | MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle | ||
| 342 | ), | ||
| 343 | testing::ValuesIn(std::initializer_list<float>{1.0F})), // We keep 100% of values - no clamping | ||
| 344 | [](const auto& info) { | ||
| 345 | const auto variant_idx = std::get<0>(info.param); | ||
| 346 | const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_idx).variant.ukernel.name}; | ||
| 347 | const auto shape = std::get<MatMulShape>(info.param); | ||
| 348 | const auto portion = std::get<2>(info.param); | ||
| 349 | const auto clamp_keep_ratio = std::get<3>(info.param); | ||
| 350 | |||
| 351 | return test_description(name, shape, portion, true, clamp_keep_ratio); | ||
| 352 | }); | ||
| 353 | |||
| 354 | // Test supported matmul kernels with clamping support. | ||
| 355 |
29/94✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✓ Branch 12 taken 6 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✓ Branch 14 taken 6 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✓ Branch 16 taken 6 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✓ Branch 18 taken 6 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✓ Branch 20 taken 6 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✓ Branch 22 taken 6 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 3 times.
✓ Branch 24 taken 6 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 3 times.
✓ Branch 26 taken 6 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✓ Branch 28 taken 1260 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✓ Branch 30 taken 2520 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 48 taken 1260 times.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✓ Branch 50 taken 2520 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 2520 times.
✗ Branch 53 not taken.
|
7572 | INSTANTIATE_TEST_SUITE_P( |
| 356 | MatMulClamped, MatMulTest_f32_qsi8d32p_qsi4c32p, | ||
| 357 | testing::Combine( | ||
| 358 | testing::Range<size_t>(num_non_clamping_kernels, variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.size()), | ||
| 359 | testing::Values( | ||
| 360 | MatMulShape{1, 2, 32}, // | ||
| 361 | MatMulShape{1, 33, 32}, // | ||
| 362 | MatMulShape{17, 32, 64}, // | ||
| 363 | MatMulShape{32, 64, 64}, // | ||
| 364 | MatMulShape{77, 99, 64}), | ||
| 365 | testing::Values( | ||
| 366 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
| 367 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
| 368 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
| 369 | MatrixPortion(0, 0.5, 1, 0.8) // Somewhere Middle | ||
| 370 | ), | ||
| 371 | testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio | ||
| 372 | [](const auto& info) { | ||
| 373 | const auto variant_idx = std::get<0>(info.param); | ||
| 374 | const std::string name{variants_kai_matmul_clamp_f32_qsi8d32p_qsi4c32p.at(variant_idx).variant.ukernel.name}; | ||
| 375 | const auto shape = std::get<MatMulShape>(info.param); | ||
| 376 | const auto portion = std::get<2>(info.param); | ||
| 377 | const auto clamp_keep_ratio = std::get<3>(info.param); | ||
| 378 | |||
| 379 | return test_description(name, shape, portion, true, clamp_keep_ratio); | ||
| 380 | }); | ||
| 381 | |||
| 382 | } // namespace kai::test | ||
| 383 |