test/nextgen/operators/matmul/pack_rhs/matmul_pack_rhs_wrapper.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/nextgen/operators/matmul/pack_rhs/matmul_pack_rhs_wrapper.hpp" | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <memory> | ||
| 11 | |||
| 12 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h" | ||
| 13 | #include "test/common/data_type.hpp" | ||
| 14 | #include "test/common/sme.hpp" | ||
| 15 | #include "test/nextgen/common/poly.hpp" | ||
| 16 | #include "test/nextgen/format/block2d_row_format.hpp" | ||
| 17 | #include "test/nextgen/format/plain_format.hpp" | ||
| 18 | #include "test/nextgen/harness/kernel_wrapper.hpp" | ||
| 19 | #include "test/nextgen/operators/matmul/pack_rhs/matmul_pack_rhs_interface.hpp" | ||
| 20 | #include "test/nextgen/operators/matmul/pack_rhs/matmul_pack_rhs_quant_wrapper.hpp" | ||
| 21 | |||
| 22 | namespace kai::test { | ||
| 23 | |||
| 24 | 6 | std::unique_ptr<KernelWrapper> create_matmul_rhs_pack_nxk_qsi4cxp4vlx4s1s0_qsu4cxs1s0_neon() { | |
| 25 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
12 | return std::make_unique<MatMulPackRhsQuantWrapper>( |
| 26 | "matmul_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon", | ||
| 27 | 6 | MatMulPackRhsQuantInterface{ | |
| 28 | kai_get_n_step_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, | ||
| 29 | kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, | ||
| 30 | kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, | ||
| 31 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, | ||
| 32 | kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, | ||
| 33 | kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, | ||
| 34 | }, | ||
| 35 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | make_poly<PlainFormat>(DataType::U4), make_poly<PlainFormat>(DataType::FP32), |
| 36 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
|
6 | make_poly<PlainFormat>(DataType::FP32), make_poly<PlainFormat>(DataType::I32), |
| 37 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
8 | make_poly<Block2dRowFormat>( |
| 38 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | 4 * get_sme_vector_length<float>(), 4, 32, false, DataType::I4, std::array<DataType, 0>{}, |
| 39 | 6 | std::array{DataType::I32, DataType::FP32, DataType::FP32})); | |
| 40 | ✗ | } | |
| 41 | |||
| 42 | } // namespace kai::test | ||
| 43 |