test/nextgen/operators/matmul/matmul/matmul_wrapper.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/nextgen/operators/matmul/matmul/matmul_wrapper.hpp" | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <memory> | ||
| 11 | |||
| 12 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h" | ||
| 13 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h" | ||
| 14 | #include "test/common/data_type.hpp" | ||
| 15 | #include "test/common/sme.hpp" | ||
| 16 | #include "test/nextgen/format/block2d_row_format.hpp" | ||
| 17 | #include "test/nextgen/format/plain_format.hpp" | ||
| 18 | #include "test/nextgen/functions/round.hpp" | ||
| 19 | #include "test/nextgen/harness/kernel_wrapper.hpp" | ||
| 20 | #include "test/nextgen/operators/matmul/matmul/matmul_dq_wrapper.hpp" | ||
| 21 | #include "test/nextgen/operators/matmul/matmul/matmul_interface.hpp" | ||
| 22 | #include "test/nextgen/quantization/asymm_linear_quantizer.hpp" | ||
| 23 | #include "test/nextgen/quantization/symm_linear_quantizer.hpp" | ||
| 24 | |||
| 25 | namespace kai::test { | ||
| 26 | |||
| 27 | 3 | std::unique_ptr<KernelWrapper> create_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa() { | |
| 28 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
6 | return std::make_unique<MatMulDqWrapper>( |
| 29 | "matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa", | ||
| 30 | 3 | MatMulDqInterface{ | |
| 31 | kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 32 | kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 33 | kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 34 | kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 35 | kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 36 | kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 37 | kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 38 | kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 39 | kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 40 | kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 41 | kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa, | ||
| 42 | }, | ||
| 43 | 6 | std::make_unique<AsymmLinearQuantizer>( | |
| 44 | 3 | DataType::I8, DataType::FP32, DataType::I32, RoundMode::TIE_AWAY, RoundMode::CURRENT, 1, 0), | |
| 45 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | std::make_unique<SymmLinearQuantizer>(DataType::U4, DataType::FP32, RoundMode::CURRENT, 1, 0), |
| 46 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
4 | make_poly<Block2dRowFormat>( |
| 47 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | 1 * get_sme_vector_length<float>(), 4, 32, true, DataType::I8, std::array<DataType, 0>{}, |
| 48 | 3 | std::array{DataType::I32, DataType::FP32}), | |
| 49 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
4 | make_poly<Block2dRowFormat>( |
| 50 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | 4 * get_sme_vector_length<float>(), 4, 32, false, DataType::I4, std::array<DataType, 0>{}, |
| 51 | 3 | std::array{DataType::I32, DataType::FP32, DataType::FP32}), | |
| 52 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | make_poly<PlainFormat>(DataType::FP32)); |
| 53 | ✗ | } | |
| 54 | |||
| 55 | 3 | std::unique_ptr<KernelWrapper> create_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() { | |
| 56 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
6 | return std::make_unique<MatMulDqWrapper>( |
| 57 | "matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot", | ||
| 58 | 3 | MatMulDqInterface{ | |
| 59 | kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 60 | kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 61 | kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 62 | kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 63 | kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 64 | kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 65 | kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 66 | kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 67 | kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 68 | kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 69 | kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot, | ||
| 70 | }, | ||
| 71 | 6 | std::make_unique<AsymmLinearQuantizer>( | |
| 72 | 3 | DataType::I8, DataType::FP32, DataType::I32, RoundMode::TIE_AWAY, RoundMode::CURRENT, 1, 0), | |
| 73 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | std::make_unique<SymmLinearQuantizer>(DataType::U4, DataType::FP32, RoundMode::CURRENT, 1, 0), |
| 74 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | make_poly<Block2dRowFormat>( |
| 75 | 3 | 1, 4, 32, true, DataType::I8, std::array<DataType, 0>{}, std::array{DataType::I32, DataType::FP32}), | |
| 76 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
4 | make_poly<Block2dRowFormat>( |
| 77 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | 4 * get_sme_vector_length<float>(), 4, 32, false, DataType::I4, std::array<DataType, 0>{}, |
| 78 | 3 | std::array{DataType::I32, DataType::FP32, DataType::FP32}), | |
| 79 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | make_poly<PlainFormat>(DataType::FP32)); |
| 80 | ✗ | } | |
| 81 | |||
| 82 | } // namespace kai::test | ||
| 83 |