test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <gtest/gtest.h> | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <cstdlib> | ||
| 13 | #include <functional> | ||
| 14 | #include <limits> | ||
| 15 | #include <random> | ||
| 16 | #include <sstream> | ||
| 17 | #include <string> | ||
| 18 | #include <string_view> | ||
| 19 | |||
| 20 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h" | ||
| 21 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.h" | ||
| 22 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h" | ||
| 23 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" | ||
| 24 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" | ||
| 25 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod.h" | ||
| 26 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod.h" | ||
| 27 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" | ||
| 28 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" | ||
| 29 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" | ||
| 30 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" | ||
| 31 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" | ||
| 32 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" | ||
| 33 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h" | ||
| 34 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" | ||
| 35 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h" | ||
| 36 | #include "test/common/abi_checker.hpp" | ||
| 37 | #include "test/common/buffer.hpp" | ||
| 38 | #include "test/common/compare.hpp" | ||
| 39 | #include "test/common/cpu_info.hpp" | ||
| 40 | #include "test/common/int4.hpp" | ||
| 41 | #include "test/common/matmul_test_common.hpp" | ||
| 42 | #include "test/common/matrix_portion.hpp" | ||
| 43 | #include "test/common/memory.hpp" | ||
| 44 | #include "test/common/round.hpp" | ||
| 45 | #include "test/common/test_suite.hpp" | ||
| 46 | #include "test/reference/cast.hpp" | ||
| 47 | #include "test/reference/clamp.hpp" | ||
| 48 | #include "test/reference/fill.hpp" | ||
| 49 | #include "test/reference/matmul.hpp" | ||
| 50 | #include "test/reference/pad.hpp" | ||
| 51 | #include "test/reference/quantize.hpp" | ||
| 52 | #include "test/reference/transpose.hpp" | ||
| 53 | |||
| 54 | namespace kai::test { | ||
| 55 | /// Matrix multiplication test information. | ||
| 56 | |||
| 57 | enum class RhsPackType { NxK, KxN }; | ||
| 58 | |||
| 59 | using ukernel_rhs_pack_function = std::function<decltype(kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>; | ||
| 60 | using ukernel_get_rhs_packed_size = std::function<decltype(kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>; | ||
| 61 | using ukernel_get_rhs_packed_offset = std::function<decltype(kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0)>; | ||
| 62 | using ukernel_get_rhs_offset = std::function<decltype(kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon)>; | ||
| 63 | |||
| 64 | template <typename T> | ||
| 65 | struct UkernelVariantCustom : public UkernelVariant<T> { | ||
| 66 | ukernel_rhs_pack_function run_rhs_pack; | ||
| 67 | ukernel_get_rhs_packed_size get_rhs_packed_size; | ||
| 68 | ukernel_get_rhs_packed_offset get_rhs_packed_offset; | ||
| 69 | ukernel_get_rhs_offset get_rhs_offset; | ||
| 70 | RhsPackType rhs_pack_type; | ||
| 71 | |||
| 72 | UkernelVariantCustom() = delete; | ||
| 73 | |||
| 74 | 80 | UkernelVariantCustom( | |
| 75 | T interface, std::string_view name, const std::function<bool(void)>& fn_is_supported, | ||
| 76 | ukernel_rhs_pack_function run_rhs_pack, ukernel_get_rhs_packed_size get_rhs_packed_size, | ||
| 77 | ukernel_get_rhs_packed_offset get_rhs_packed_offset, ukernel_get_rhs_offset get_rhs_offset, | ||
| 78 | const RhsPackType pack_type) : | ||
| 79 | 60 | UkernelVariant<T>(interface, name, fn_is_supported), | |
| 80 | 60 | run_rhs_pack(std::move(run_rhs_pack)), | |
| 81 | 60 | get_rhs_packed_size(std::move(get_rhs_packed_size)), | |
| 82 | 60 | get_rhs_packed_offset(std::move(get_rhs_packed_offset)), | |
| 83 | 60 | get_rhs_offset(std::move(get_rhs_offset)), | |
| 84 | 80 | rhs_pack_type(pack_type) { | |
| 85 | 80 | } | |
| 86 | }; | ||
| 87 | |||
| 88 | 3 | static const std::array<UkernelVariantCustom<kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel>, 20> | |
| 89 |
0/4✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
|
3 | variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp = { |
| 90 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
25 | {{UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa), |
| 91 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa__RHS_NxK__", cpu_has_sme2, |
| 92 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
| 93 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
3 | kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
| 94 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
| 95 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
3 | kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, RhsPackType::NxK}, |
| 96 | |||
| 97 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot), |
| 98 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot__RHS_NxK__", cpu_has_sme2, |
| 99 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
| 100 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
3 | kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
| 101 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, |
| 102 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
3 | kai_get_rhs_offset_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon, RhsPackType::NxK}, |
| 103 | |||
| 104 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod), |
| 105 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod__RHS_NxK__", cpu_has_dotprod, |
| 106 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 107 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 108 | RhsPackType::NxK}, | ||
| 109 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod), |
| 110 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod__RHS_KxN__", cpu_has_dotprod, |
| 111 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 112 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 113 | RhsPackType::KxN}, | ||
| 114 | |||
| 115 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod), |
| 116 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod, |
| 117 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 118 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 119 | RhsPackType::NxK}, | ||
| 120 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod), |
| 121 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod, |
| 122 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 123 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 124 | RhsPackType::KxN}, | ||
| 125 | |||
| 126 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod), |
| 127 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod, |
| 128 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 129 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 130 | RhsPackType::NxK}, | ||
| 131 | |||
| 132 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod), |
| 133 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod, |
| 134 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 135 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 136 | RhsPackType::KxN}, | ||
| 137 | |||
| 138 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod), |
| 139 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod, |
| 140 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 141 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 142 | RhsPackType::NxK}, | ||
| 143 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod), |
| 144 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod, |
| 145 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 146 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 147 | RhsPackType::KxN}, | ||
| 148 | |||
| 149 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod), |
| 150 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod__RHS_NxK__", cpu_has_dotprod, |
| 151 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 152 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 153 | RhsPackType::NxK}, | ||
| 154 | |||
| 155 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod), |
| 156 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod__RHS_KxN__", cpu_has_dotprod, |
| 157 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 158 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 159 | RhsPackType::KxN}, | ||
| 160 | |||
| 161 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm), |
| 162 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm, |
| 163 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 164 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 165 | RhsPackType::NxK}, | ||
| 166 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm), |
| 167 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm, |
| 168 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 169 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 170 | RhsPackType::KxN}, | ||
| 171 | |||
| 172 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm), |
| 173 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm, |
| 174 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 175 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 176 | RhsPackType::NxK}, | ||
| 177 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm), |
| 178 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm, |
| 179 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 180 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 181 | RhsPackType::KxN}, | ||
| 182 | |||
| 183 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm), |
| 184 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm, |
| 185 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 186 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 187 | RhsPackType::NxK}, | ||
| 188 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm), |
| 189 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm, |
| 190 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 191 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 192 | RhsPackType::KxN}, | ||
| 193 | |||
| 194 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm), |
| 195 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm__RHS_NxK__", cpu_has_i8mm, |
| 196 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 197 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0, |
| 198 | RhsPackType::NxK}, | ||
| 199 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm), |
| 200 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm__RHS_KxN__", cpu_has_i8mm, |
| 201 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 202 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0, |
| 203 | RhsPackType::KxN}} | ||
| 204 | |||
| 205 | }; | ||
| 206 | |||
| 207 | using MatMulClampTestPortionedParams = std::tuple<size_t, MatMulShape, MatrixPortion, float>; | ||
| 208 | class MatMulTest_f32_qai8dxp_qsi4cxp : public ::testing::TestWithParam<MatMulClampTestPortionedParams> {}; | ||
| 209 | |||
| 210 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
12006 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, Offset_RHS) { |
| 211 | 27840 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 212 | 9600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
| 213 | |||
| 214 |
3/4✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
|
4800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 215 |
3/6✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
|
240 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 216 | } | ||
| 217 | |||
| 218 | 9120 | const size_t M = matmul_shape.m; | |
| 219 | 9120 | const size_t N = matmul_shape.n; | |
| 220 | 9120 | const size_t K = matmul_shape.k; | |
| 221 | |||
| 222 | 4560 | auto m_step = ukernel_variant.interface.get_m_step(); | |
| 223 | 4560 | auto n_step = ukernel_variant.interface.get_n_step(); | |
| 224 | |||
| 225 | 9120 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 226 |
2/4✓ Branch 0 taken 4560 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4560 times.
|
4560 | if (rect.height() == 0 || rect.width() == 0) { |
| 227 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 228 | } | ||
| 229 | |||
| 230 | 4560 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 231 | 4560 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 232 | 4560 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 233 | |||
| 234 | 4560 | const auto rhs_start_row = rect.start_col(); | |
| 235 | 4560 | auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr); | |
| 236 | 4560 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); | |
| 237 |
3/14✓ Branch 0 taken 4560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 4560 times.
|
4560 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
| 238 | 4800 | } | |
| 239 | |||
| 240 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
12006 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, Offset_LHS) { |
| 241 | 27840 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 242 | 9600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
| 243 | |||
| 244 |
3/4✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
|
4800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 245 |
3/6✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
|
240 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 246 | } | ||
| 247 | |||
| 248 | 9120 | const size_t M = matmul_shape.m; | |
| 249 | 9120 | const size_t N = matmul_shape.n; | |
| 250 | 9120 | const size_t K = matmul_shape.k; | |
| 251 | |||
| 252 | 4560 | auto m_step = ukernel_variant.interface.get_m_step(); | |
| 253 | 4560 | auto n_step = ukernel_variant.interface.get_n_step(); | |
| 254 | |||
| 255 | 9120 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 256 |
2/4✓ Branch 0 taken 4560 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4560 times.
|
4560 | if (rect.height() == 0 || rect.width() == 0) { |
| 257 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 258 | } | ||
| 259 | |||
| 260 | 4560 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 261 | 4560 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 262 | 4560 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 263 | |||
| 264 | 4560 | const auto lhs_start_row = rect.start_row(); | |
| 265 | 4560 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); | |
| 266 | 4560 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); | |
| 267 | |||
| 268 |
3/14✓ Branch 0 taken 4560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 4560 times.
|
4560 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
| 269 | 4800 | } | |
| 270 | |||
| 271 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
12006 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsi4cx) { |
| 272 | 16800 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 273 | 9600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
| 274 | |||
| 275 |
3/4✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
|
4800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 276 |
3/6✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
|
240 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 277 | } | ||
| 278 |
2/2✓ Branch 0 taken 2160 times.
✓ Branch 1 taken 2400 times.
|
4560 | if (ukernel_variant.rhs_pack_type == RhsPackType::KxN) { |
| 279 |
3/6✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2160 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2160 times.
✗ Branch 5 not taken.
|
2160 | GTEST_SKIP() << "Wrong type. This test for NxK"; |
| 280 | } | ||
| 281 | |||
| 282 | 2400 | const uint32_t seed = 0; | |
| 283 | |||
| 284 | 4800 | const size_t M = matmul_shape.m; | |
| 285 | 4800 | const size_t N = matmul_shape.n; | |
| 286 | 4800 | const size_t K = matmul_shape.k; | |
| 287 | |||
| 288 | 2400 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 289 | 2400 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 290 | 2400 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 291 | 2400 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 292 | |||
| 293 | // Generates input data. | ||
| 294 | 2400 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
| 295 |
1/2✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
|
2400 | const auto ref_biases = fill_random<float>(N, seed + 2); |
| 296 | |||
| 297 |
1/2✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
|
2400 | std::uniform_real_distribution<float> dist(-10.0, 1.0); |
| 298 |
1/2✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
|
2400 | std::mt19937 rnd(seed + 1); |
| 299 |
1/2✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
|
10312800 | const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); }); |
| 300 | |||
| 301 | // Runs the reference implementation. | ||
| 302 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
| 303 | // * Quantizes the RHS matrix using 4-bit symmetric quantization. | ||
| 304 | // * Performs GEMM. | ||
| 305 |
1/2✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
|
2400 | QuantizationInfo lhs_qinfo{}; |
| 306 | lhs_qinfo.quant_width = K; | ||
| 307 | lhs_qinfo.dst_type = DataType::QAI8; | ||
| 308 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 309 | lhs_qinfo.zero_point_type = DataType::I32; | ||
| 310 | const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo); | ||
| 311 | |||
| 312 | QuantizationInfo rhs_qinfo{}; | ||
| 313 | rhs_qinfo.quant_width = K; | ||
| 314 | rhs_qinfo.dst_type = DataType::QSI4; | ||
| 315 | rhs_qinfo.scale_type = DataType::FP32; | ||
| 316 | const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo); | ||
| 317 | |||
| 318 | const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( | ||
| 319 | M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K, | ||
| 320 | ref_rhs_quant.data(), rhs_qoutputs.scales.data(), nullptr, K, ref_biases.data(), | ||
| 321 | std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | ||
| 322 | |||
| 323 | const auto [clamp_min, clamp_max] = | ||
| 324 | find_clamp_range(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_keep_ratio); | ||
| 325 | |||
| 326 | auto ref_clamped = clamp(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_min, clamp_max); | ||
| 327 | |||
| 328 | auto m_step = ukernel_variant.interface.get_m_step(); | ||
| 329 | ASSERT_TRUE(m_step % mr == 0); | ||
| 330 | |||
| 331 | auto n_step = ukernel_variant.interface.get_n_step(); | ||
| 332 | ASSERT_TRUE(n_step % nr == 0); | ||
| 333 | |||
| 334 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | ||
| 335 | if (rect.height() == 0 || rect.width() == 0) { | ||
| 336 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | ||
| 337 | } | ||
| 338 | |||
| 339 | const auto lhs_start_row = rect.start_row(); | ||
| 340 | size_t lhs_stride = K * sizeof(float); | ||
| 341 | |||
| 342 | // Runs the LHS packing micro-kernel. | ||
| 343 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); | ||
| 344 | Buffer imp_packed_lhs(imp_packed_lhs_size); | ||
| 345 | |||
| 346 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); | ||
| 347 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); | ||
| 348 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); | ||
| 349 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); | ||
| 350 | |||
| 351 | kai_run_lhs_quant_pack_qai8dxp_f32( | ||
| 352 | rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, | ||
| 353 | reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride, | ||
| 354 | imp_packed_lhs.data() + lhs_packed_offset); | ||
| 355 | |||
| 356 | // Runs the RHS packing micro-kernel. | ||
| 357 | // * Pads the 4-bit signed symmetric quantized input for the micro-kernel. | ||
| 358 | // * Packs the RHS matrix. | ||
| 359 | const auto ref_rhs_qsi4_padded = pad_row<Int4>( | ||
| 360 | ref_rhs_quant.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2)); | ||
| 361 | |||
| 362 | const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr); | ||
| 363 | const auto rhs_start_row = rect.start_col(); | ||
| 364 | auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr); | ||
| 365 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); | ||
| 366 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); | ||
| 367 | |||
| 368 | auto rhs_offset = ukernel_variant.get_rhs_offset(rhs_start_row, round_up_division(K, 2)); | ||
| 369 | size_t bias_offset = rhs_start_row * sizeof(float); | ||
| 370 | size_t scale_offset = rhs_start_row * sizeof(float); | ||
| 371 | |||
| 372 | Buffer imp_packed_rhs(imp_packed_rhs_size); | ||
| 373 | kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{}; | ||
| 374 | params.lhs_zero_point = 1; | ||
| 375 | params.rhs_zero_point = 0; | ||
| 376 | |||
| 377 | abi_check( | ||
| 378 | ukernel_variant.run_rhs_pack, 1, rect.width() /* n */, K, nr, kr, sr, | ||
| 379 | reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data() + rhs_offset), | ||
| 380 | reinterpret_cast<const float*>(ref_biases.data() + bias_offset), | ||
| 381 | reinterpret_cast<const float*>(rhs_qoutputs.scales.data() + scale_offset), | ||
| 382 | imp_packed_rhs.data() + rhs_packed_offset, 0, ¶ms); | ||
| 383 | |||
| 384 | const auto dst_stride = N * sizeof(float); | ||
| 385 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); | ||
| 386 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); | ||
| 387 | ASSERT_EQ(dst_offset, ref_dst_offset); | ||
| 388 | |||
| 389 | // Runs the GEMM micro-kernel. | ||
| 390 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); | ||
| 391 | ASSERT_EQ(imp_dst_size, ref_dst.size()); | ||
| 392 | Buffer imp_dst(imp_dst_size); | ||
| 393 | abi_check( | ||
| 394 | ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, | ||
| 395 | imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), | ||
| 396 | N * sizeof(float), sizeof(float), clamp_min, clamp_max); | ||
| 397 | |||
| 398 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
| 399 | // tested. | ||
| 400 | for (size_t y = 0; y < rect.height(); ++y) { | ||
| 401 | for (size_t x = 0; x < rect.width(); ++x) { | ||
| 402 | const auto imp_value = | ||
| 403 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); | ||
| 404 | const auto ref_value = | ||
| 405 | read_array<float>(ref_clamped.data(), (rect.start_row() + y) * N + (x + rect.start_col())); | ||
| 406 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value; | ||
| 407 | |||
| 408 | if (rel_error > 0.0001F) { | ||
| 409 | ASSERT_EQ(imp_value, ref_value); | ||
| 410 | } | ||
| 411 | } | ||
| 412 | } | ||
| 413 | ✗ | } | |
| 414 | |||
| 415 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
12006 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) { |
| 416 | 16800 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 417 | 9600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
| 418 | |||
| 419 |
3/4✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
|
4800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 420 |
3/6✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
|
240 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 421 | } | ||
| 422 |
2/2✓ Branch 0 taken 2160 times.
✓ Branch 1 taken 2400 times.
|
4560 | if (ukernel_variant.rhs_pack_type == RhsPackType::KxN) { |
| 423 |
3/6✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2160 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2160 times.
✗ Branch 5 not taken.
|
2160 | GTEST_SKIP() << "Wrong type. This test for NxK"; |
| 424 | } | ||
| 425 | |||
| 426 | 2400 | const uint32_t seed = 0; | |
| 427 | |||
| 428 | 4800 | const size_t M = matmul_shape.m; | |
| 429 | 4800 | const size_t N = matmul_shape.n; | |
| 430 | 4800 | const size_t K = matmul_shape.k; | |
| 431 | |||
| 432 | 2400 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 433 | 2400 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 434 | 2400 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 435 | 2400 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 436 | |||
| 437 | // Generates input data. | ||
| 438 | 2400 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
| 439 | |||
| 440 |
1/2✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
|
2400 | std::uniform_real_distribution<float> dist(-10.0, 1.0); |
| 441 |
1/2✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
|
2400 | std::mt19937 rnd(seed + 1); |
| 442 |
1/2✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
|
10312800 | const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); }); |
| 443 | |||
| 444 |
1/2✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
|
2400 | const auto ref_biases = fill_random<float>(N, seed + 2); |
| 445 | |||
| 446 | // Runs the reference implementation. | ||
| 447 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
| 448 | // * Quantizes the RHS matrix using 4-bit symmetric quantization. | ||
| 449 | // * Performs GEMM. | ||
| 450 |
1/2✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
|
2400 | QuantizationInfo lhs_qinfo{}; |
| 451 | lhs_qinfo.quant_width = K; | ||
| 452 | lhs_qinfo.dst_type = DataType::QAI8; | ||
| 453 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 454 | lhs_qinfo.zero_point_type = DataType::I32; | ||
| 455 | const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo); | ||
| 456 | |||
| 457 | QuantizationInfo rhs_qinfo{}; | ||
| 458 | rhs_qinfo.quant_width = K; | ||
| 459 | rhs_qinfo.dst_type = DataType::QSI4; | ||
| 460 | rhs_qinfo.scale_type = DataType::FP32; | ||
| 461 | const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo); | ||
| 462 | |||
| 463 | const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( | ||
| 464 | M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K, | ||
| 465 | ref_rhs_quant.data(), rhs_qoutputs.scales.data(), nullptr, K, ref_biases.data(), | ||
| 466 | std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | ||
| 467 | |||
| 468 | const auto [clamp_min, clamp_max] = | ||
| 469 | find_clamp_range(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_keep_ratio); | ||
| 470 | |||
| 471 | auto ref_clamped = clamp(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_min, clamp_max); | ||
| 472 | |||
| 473 | auto m_step = ukernel_variant.interface.get_m_step(); | ||
| 474 | ASSERT_TRUE(m_step % mr == 0); | ||
| 475 | |||
| 476 | auto n_step = ukernel_variant.interface.get_n_step(); | ||
| 477 | ASSERT_TRUE(n_step % nr == 0); | ||
| 478 | |||
| 479 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | ||
| 480 | if (rect.height() == 0 || rect.width() == 0) { | ||
| 481 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | ||
| 482 | } | ||
| 483 | |||
| 484 | const auto lhs_start_row = rect.start_row(); | ||
| 485 | size_t lhs_stride = K * sizeof(float); | ||
| 486 | |||
| 487 | // Runs the LHS packing micro-kernel. | ||
| 488 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); | ||
| 489 | Buffer imp_packed_lhs(imp_packed_lhs_size); | ||
| 490 | |||
| 491 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); | ||
| 492 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); | ||
| 493 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); | ||
| 494 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); | ||
| 495 | |||
| 496 | kai_run_lhs_quant_pack_qai8dxp_f32( | ||
| 497 | rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, | ||
| 498 | reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride, | ||
| 499 | imp_packed_lhs.data() + lhs_packed_offset); | ||
| 500 | |||
| 501 | const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_quant.data(), N * K); | ||
| 502 | // Runs the RHS packing micro-kernel. | ||
| 503 | // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. | ||
| 504 | // * Packs the RHS matrix. | ||
| 505 | const auto ref_rhs_qsu4_padded = pad_row<UInt4>( | ||
| 506 | ref_rhs_qsu4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2)); | ||
| 507 | |||
| 508 | const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr); | ||
| 509 | const auto rhs_start_row = rect.start_col(); | ||
| 510 | auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr); | ||
| 511 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); | ||
| 512 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); | ||
| 513 | |||
| 514 | auto rhs_offset = ukernel_variant.get_rhs_offset(rhs_start_row, round_up_division(K, 2)); | ||
| 515 | size_t bias_offset = rhs_start_row * sizeof(float); | ||
| 516 | size_t scale_offset = rhs_start_row * sizeof(float); | ||
| 517 | |||
| 518 | Buffer imp_packed_rhs(imp_packed_rhs_size); | ||
| 519 | kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{}; | ||
| 520 | params.lhs_zero_point = 1; | ||
| 521 | params.rhs_zero_point = 8; | ||
| 522 | abi_check( | ||
| 523 | ukernel_variant.run_rhs_pack, 1, rect.width() /* n */, K, nr, kr, sr, | ||
| 524 | reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_padded.data() + rhs_offset), | ||
| 525 | reinterpret_cast<const float*>(ref_biases.data() + bias_offset), | ||
| 526 | reinterpret_cast<const float*>(rhs_qoutputs.scales.data() + scale_offset), | ||
| 527 | imp_packed_rhs.data() + rhs_packed_offset, 0, ¶ms); | ||
| 528 | |||
| 529 | const auto dst_stride = N * sizeof(float); | ||
| 530 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); | ||
| 531 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); | ||
| 532 | ASSERT_EQ(dst_offset, ref_dst_offset); | ||
| 533 | // Runs the GEMM micro-kernel. | ||
| 534 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); | ||
| 535 | ASSERT_EQ(imp_dst_size, ref_dst.size()); | ||
| 536 | Buffer imp_dst(imp_dst_size); | ||
| 537 | abi_check( | ||
| 538 | ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, | ||
| 539 | imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), | ||
| 540 | N * sizeof(float), sizeof(float), clamp_min, clamp_max); | ||
| 541 | |||
| 542 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
| 543 | // tested. | ||
| 544 | for (size_t y = 0; y < rect.height(); ++y) { | ||
| 545 | for (size_t x = 0; x < rect.width(); ++x) { | ||
| 546 | const auto imp_value = | ||
| 547 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); | ||
| 548 | const auto ref_value = | ||
| 549 | read_array<float>(ref_clamped.data(), (rect.start_row() + y) * N + (x + rect.start_col())); | ||
| 550 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value; | ||
| 551 | |||
| 552 | if (rel_error > 0.0001F) { | ||
| 553 | ASSERT_EQ(imp_value, ref_value); | ||
| 554 | } | ||
| 555 | } | ||
| 556 | } | ||
| 557 | ✗ | } | |
| 558 | |||
| 559 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
12006 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) { |
| 560 | 16080 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 561 | 9600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
| 562 | |||
| 563 |
3/4✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
|
4800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 564 |
3/6✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
|
240 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 565 | } | ||
| 566 |
2/2✓ Branch 0 taken 2160 times.
✓ Branch 1 taken 2400 times.
|
4560 | if (ukernel_variant.rhs_pack_type == RhsPackType::NxK) { |
| 567 |
3/6✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2400 times.
✗ Branch 5 not taken.
|
2400 | GTEST_SKIP() << "Wrong type. This test for KxN"; |
| 568 | } | ||
| 569 | |||
| 570 | 2160 | const uint32_t seed = 0; | |
| 571 | |||
| 572 | 4320 | const size_t M = matmul_shape.m; | |
| 573 | 4320 | const size_t N = matmul_shape.n; | |
| 574 | 4320 | const size_t K = matmul_shape.k; | |
| 575 | |||
| 576 | 2160 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 577 | 2160 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 578 | 2160 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 579 | 2160 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 580 | |||
| 581 | // Generates input data. | ||
| 582 | 2160 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
| 583 | |||
| 584 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | std::uniform_real_distribution<float> dist(-10.0, 1.0); |
| 585 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | std::mt19937 rnd(seed + 1); |
| 586 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
9281520 | const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); }); |
| 587 | |||
| 588 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | const auto ref_biases = fill_random<float>(N, seed + 2); |
| 589 | |||
| 590 | // Transposed(nxk) RHS dimensions | ||
| 591 | 2160 | const size_t ref_rhs_qsi4_nxk_stride = K; | |
| 592 | |||
| 593 | // Non-Transposed(kxn) RHS dimensions | ||
| 594 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2); |
| 595 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(K * ref_rhs_qsi4_kxn_stride, 2); |
| 596 | |||
| 597 | // Runs the reference implementation. | ||
| 598 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
| 599 | // * Quantizes the RHS matrix using 4-bit symmetric quantization. | ||
| 600 | // * Performs GEMM. | ||
| 601 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | QuantizationInfo lhs_qinfo{}; |
| 602 | lhs_qinfo.quant_width = K; | ||
| 603 | lhs_qinfo.dst_type = DataType::QAI8; | ||
| 604 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 605 | lhs_qinfo.zero_point_type = DataType::I32; | ||
| 606 | const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo); | ||
| 607 | |||
| 608 | QuantizationInfo rhs_qinfo{}; | ||
| 609 | rhs_qinfo.quant_width = K; | ||
| 610 | rhs_qinfo.dst_type = DataType::QSI4; | ||
| 611 | rhs_qinfo.scale_type = DataType::FP32; | ||
| 612 | const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo); | ||
| 613 | |||
| 614 | const auto ref_rhs_qsi4 = transpose_with_padding<Int4>( | ||
| 615 | ref_rhs_quant.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes); | ||
| 616 | |||
| 617 | const auto ref_dst = matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( | ||
| 618 | M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K, | ||
| 619 | ref_rhs_qsi4.data(), rhs_qoutputs.scales.data(), nullptr, K, ref_biases.data(), | ||
| 620 | std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | ||
| 621 | |||
| 622 | const auto [clamp_min, clamp_max] = | ||
| 623 | find_clamp_range(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_keep_ratio); | ||
| 624 | |||
| 625 | auto ref_clamped = clamp(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_min, clamp_max); | ||
| 626 | |||
| 627 | auto m_step = ukernel_variant.interface.get_m_step(); | ||
| 628 | ASSERT_TRUE(m_step % mr == 0); | ||
| 629 | |||
| 630 | auto n_step = ukernel_variant.interface.get_n_step(); | ||
| 631 | ASSERT_TRUE(n_step % nr == 0); | ||
| 632 | |||
| 633 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | ||
| 634 | if (rect.height() == 0 || rect.width() == 0) { | ||
| 635 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | ||
| 636 | } | ||
| 637 | |||
| 638 | const auto lhs_start_row = rect.start_row(); | ||
| 639 | size_t lhs_stride = K * sizeof(float); | ||
| 640 | |||
| 641 | // Runs the LHS packing micro-kernel. | ||
| 642 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); | ||
| 643 | Buffer imp_packed_lhs(imp_packed_lhs_size); | ||
| 644 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); | ||
| 645 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); | ||
| 646 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); | ||
| 647 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); | ||
| 648 | |||
| 649 | kai_run_lhs_quant_pack_qai8dxp_f32( | ||
| 650 | rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, | ||
| 651 | reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride, | ||
| 652 | imp_packed_lhs.data() + lhs_packed_offset); | ||
| 653 | |||
| 654 | // Runs the RHS packing micro-kernel. | ||
| 655 | // * Pads the 4-bit signed symmetric quantized input for the micro-kernel. | ||
| 656 | // * Packs the RHS matrix. | ||
| 657 | const auto ref_rhs_qsi4_padded = pad_row<Int4>( | ||
| 658 | ref_rhs_qsi4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2)); | ||
| 659 | const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr); | ||
| 660 | |||
| 661 | const auto rhs_start_row = rect.start_col(); | ||
| 662 | auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(rhs_start_row, K, nr, kr, sr); | ||
| 663 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); | ||
| 664 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); | ||
| 665 | |||
| 666 | Buffer imp_packed_rhs(imp_packed_rhs_size); | ||
| 667 | kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{}; | ||
| 668 | params.lhs_zero_point = 1; | ||
| 669 | params.rhs_zero_point = 0; | ||
| 670 | abi_check( | ||
| 671 | ukernel_variant.run_rhs_pack, 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()), | ||
| 672 | reinterpret_cast<const float*>(ref_biases.data()), reinterpret_cast<const float*>(rhs_qoutputs.scales.data()), | ||
| 673 | imp_packed_rhs.data(), 0, ¶ms); | ||
| 674 | |||
| 675 | const auto dst_stride = N * sizeof(float); | ||
| 676 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); | ||
| 677 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); | ||
| 678 | ASSERT_EQ(dst_offset, ref_dst_offset); | ||
| 679 | |||
| 680 | // Runs the GEMM micro-kernel. | ||
| 681 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); | ||
| 682 | ASSERT_EQ(imp_dst_size, ref_dst.size()); | ||
| 683 | Buffer imp_dst(imp_dst_size); | ||
| 684 | abi_check( | ||
| 685 | ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, | ||
| 686 | imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), | ||
| 687 | N * sizeof(float), sizeof(float), clamp_min, clamp_max); | ||
| 688 | |||
| 689 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
| 690 | // tested. | ||
| 691 | for (size_t y = 0; y < rect.height(); ++y) { | ||
| 692 | for (size_t x = 0; x < rect.width(); ++x) { | ||
| 693 | const auto imp_value = | ||
| 694 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); | ||
| 695 | const auto ref_value = | ||
| 696 | read_array<float>(ref_clamped.data(), (rect.start_row() + y) * N + (x + rect.start_col())); | ||
| 697 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value; | ||
| 698 | |||
| 699 | if (rel_error > 0.0001F) { | ||
| 700 | ASSERT_EQ(imp_value, ref_value); | ||
| 701 | } | ||
| 702 | } | ||
| 703 | } | ||
| 704 | ✗ | } | |
| 705 | |||
| 706 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
12006 | TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) { |
| 707 | 16080 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 708 | 9600 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index); | |
| 709 | |||
| 710 |
3/4✓ Branch 0 taken 4800 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4560 times.
✓ Branch 3 taken 240 times.
|
4800 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 711 |
3/6✓ Branch 0 taken 240 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 240 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 240 times.
✗ Branch 5 not taken.
|
240 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 712 | } | ||
| 713 |
2/2✓ Branch 0 taken 2160 times.
✓ Branch 1 taken 2400 times.
|
4560 | if (ukernel_variant.rhs_pack_type == RhsPackType::NxK) { |
| 714 |
3/6✓ Branch 0 taken 2400 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2400 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2400 times.
✗ Branch 5 not taken.
|
2400 | GTEST_SKIP() << "Wrong type. This test for KxN"; |
| 715 | } | ||
| 716 | |||
| 717 | 2160 | const uint32_t seed = 0; | |
| 718 | |||
| 719 | 4320 | const size_t M = matmul_shape.m; | |
| 720 | 4320 | const size_t N = matmul_shape.n; | |
| 721 | 4320 | const size_t K = matmul_shape.k; | |
| 722 | |||
| 723 | 2160 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 724 | 2160 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 725 | 2160 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 726 | 2160 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 727 | |||
| 728 | // Generates input data. | ||
| 729 | 2160 | const auto ref_lhs = fill_random<float>(M * K, seed + 0); | |
| 730 | |||
| 731 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | std::uniform_real_distribution<float> dist(-10.0, 1.0); |
| 732 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | std::mt19937 rnd(seed + 1); |
| 733 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
9281520 | const auto ref_rhs = fill_matrix_raw<float>(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); }); |
| 734 | |||
| 735 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | const auto ref_biases = fill_random<float>(N, seed + 2); |
| 736 | |||
| 737 | // Transposed(nxk) RHS dimensions | ||
| 738 | 2160 | const size_t ref_rhs_qsi4_nxk_stride = K; | |
| 739 | |||
| 740 | // Non-Transposed(kxn) RHS dimensions | ||
| 741 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2); |
| 742 | 2160 | const size_t ref_rhs_qsi4_kxn_size = K * ref_rhs_qsi4_kxn_stride; | |
| 743 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(ref_rhs_qsi4_kxn_size, 2); |
| 744 | |||
| 745 | // Runs the reference implementation. | ||
| 746 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
| 747 | // * Quantizes the RHS matrix using 4-bit symmetric quantization. | ||
| 748 | // * Performs GEMM. | ||
| 749 |
1/2✓ Branch 0 taken 2160 times.
✗ Branch 1 not taken.
|
2160 | QuantizationInfo lhs_qinfo{}; |
| 750 | lhs_qinfo.quant_width = K; | ||
| 751 | lhs_qinfo.dst_type = DataType::QAI8; | ||
| 752 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 753 | lhs_qinfo.zero_point_type = DataType::I32; | ||
| 754 | const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo); | ||
| 755 | |||
| 756 | QuantizationInfo rhs_qinfo{}; | ||
| 757 | rhs_qinfo.quant_width = K; | ||
| 758 | rhs_qinfo.dst_type = DataType::QSI4; | ||
| 759 | rhs_qinfo.scale_type = DataType::FP32; | ||
| 760 | const auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo); | ||
| 761 | |||
| 762 | const auto ref_rhs_qsi4 = transpose_with_padding<Int4>( | ||
| 763 | ref_rhs_quant.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes); | ||
| 764 | |||
| 765 | const auto ref_dst = matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( | ||
| 766 | M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K, | ||
| 767 | ref_rhs_qsi4.data(), rhs_qoutputs.scales.data(), nullptr, K, ref_biases.data(), | ||
| 768 | std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | ||
| 769 | |||
| 770 | const auto [clamp_min, clamp_max] = | ||
| 771 | find_clamp_range(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_keep_ratio); | ||
| 772 | |||
| 773 | auto ref_clamped = clamp(DataType::FP32, ref_dst.data(), matmul_shape.m * matmul_shape.n, clamp_min, clamp_max); | ||
| 774 | |||
| 775 | auto m_step = ukernel_variant.interface.get_m_step(); | ||
| 776 | ASSERT_TRUE(m_step % mr == 0); | ||
| 777 | |||
| 778 | auto n_step = ukernel_variant.interface.get_n_step(); | ||
| 779 | ASSERT_TRUE(n_step % nr == 0); | ||
| 780 | |||
| 781 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | ||
| 782 | if (rect.height() == 0 || rect.width() == 0) { | ||
| 783 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | ||
| 784 | } | ||
| 785 | |||
| 786 | const auto lhs_start_row = rect.start_row(); | ||
| 787 | size_t lhs_stride = K * sizeof(float); | ||
| 788 | |||
| 789 | // Runs the LHS packing micro-kernel. | ||
| 790 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); | ||
| 791 | Buffer imp_packed_lhs(imp_packed_lhs_size); | ||
| 792 | |||
| 793 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); | ||
| 794 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); | ||
| 795 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); | ||
| 796 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); | ||
| 797 | |||
| 798 | kai_run_lhs_quant_pack_qai8dxp_f32( | ||
| 799 | rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, | ||
| 800 | reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride, | ||
| 801 | imp_packed_lhs.data() + lhs_packed_offset); | ||
| 802 | |||
| 803 | // Runs the RHS packing micro-kernel. | ||
| 804 | // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. | ||
| 805 | // * Packs the RHS matrix. | ||
| 806 | const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), ref_rhs_qsi4_kxn_size); | ||
| 807 | const auto ref_rhs_qsu4_padded = pad_row<UInt4>( | ||
| 808 | ref_rhs_qsu4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2)); | ||
| 809 | const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr); | ||
| 810 | |||
| 811 | const auto rhs_start_row = rect.start_col(); | ||
| 812 | auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr); | ||
| 813 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); | ||
| 814 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); | ||
| 815 | |||
| 816 | Buffer imp_packed_rhs(imp_packed_rhs_size); | ||
| 817 | kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{}; | ||
| 818 | params.lhs_zero_point = 1; | ||
| 819 | params.rhs_zero_point = 8; | ||
| 820 | abi_check( | ||
| 821 | ukernel_variant.run_rhs_pack, 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsu4_padded.data()), | ||
| 822 | reinterpret_cast<const float*>(ref_biases.data()), reinterpret_cast<const float*>(rhs_qoutputs.scales.data()), | ||
| 823 | imp_packed_rhs.data(), 0, &params); | ||
| 824 | |||
| 825 | const auto dst_stride = N * sizeof(float); | ||
| 826 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); | ||
| 827 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); | ||
| 828 | ASSERT_EQ(dst_offset, ref_dst_offset); | ||
| 829 | |||
| 830 | // Runs the GEMM micro-kernel. | ||
| 831 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); | ||
| 832 | ASSERT_EQ(imp_dst_size, ref_dst.size()); | ||
| 833 | Buffer imp_dst(imp_dst_size); | ||
| 834 | abi_check( | ||
| 835 | ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, | ||
| 836 | imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), | ||
| 837 | N * sizeof(float), sizeof(float), clamp_min, clamp_max); | ||
| 838 | |||
| 839 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
| 840 | // tested. | ||
| 841 | DefaultMismatchHandler handler(0, 0.1, 0, 0.05); | ||
| 842 | DataFormat dst_format = DataFormat(DataType::FP32); | ||
| 843 | const auto success = compare(imp_dst.data(), ref_clamped.data(), dst_format, M, N, rect, handler); | ||
| 844 | ASSERT_TRUE(success); | ||
| 845 | ✗ | } | |
| 846 | |||
| 847 |
29/94✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 6 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 6 times.
✓ Branch 12 taken 12 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 6 times.
✓ Branch 14 taken 12 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✓ Branch 16 taken 12 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 6 times.
✓ Branch 18 taken 12 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 6 times.
✓ Branch 20 taken 12 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✓ Branch 22 taken 12 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 6 times.
✓ Branch 24 taken 12 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 6 times.
✓ Branch 26 taken 12 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 12 times.
✓ Branch 28 taken 14400 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✓ Branch 30 taken 28800 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 48 taken 14400 times.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✓ Branch 50 taken 28800 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 28800 times.
✗ Branch 53 not taken.
|
86421 | INSTANTIATE_TEST_SUITE_P( |
| 848 | MatMul, MatMulTest_f32_qai8dxp_qsi4cxp, /* Instantiates the parameterized suite; test cases are the cross product of the four generators below. */ | ||
| 849 | testing::Combine( | ||
| 850 | testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.size()), /* One index per entry of the variant table. */ | ||
| 851 | testing::Values( /* Matrix shapes under test; presumably {m, n, k} - confirm against MatMulShape. */ | ||
| 852 | MatMulShape{16, 32, 64}, // | ||
| 853 | MatMulShape{16, 32, 36}, // | ||
| 854 | MatMulShape{15, 35, 65}, // | ||
| 855 | MatMulShape{8, 32, 64}, // | ||
| 856 | MatMulShape{15, 31, 45}, // | ||
| 857 | MatMulShape{1, 35, 65}, // | ||
| 858 | MatMulShape{1, 128, 32}, // | ||
| 859 | MatMulShape{64, 128, 32}, // | ||
| 860 | MatMulShape{1, 225, 55}, // | ||
| 861 | MatMulShape{125, 200, 56}), | ||
| 862 | testing::Values( /* Portions of the output matrix that the test computes and compares. */ | ||
| 863 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
| 864 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
| 865 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
| 866 | MatrixPortion(0, 0.5, 1, 0.8) // Somewhere in the middle. | ||
| 867 | ), | ||
| 868 | testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio | ||
| 869 | [](const auto& info) { /* Name generator: builds the test-case name from the parameter tuple. */ | ||
| 870 | const auto variant_idx = std::get<0>(info.param); | ||
| 871 | const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_idx).name}; | ||
| 872 | const auto shape = std::get<MatMulShape>(info.param); | ||
| 873 | const auto portion = std::get<2>(info.param); | ||
| 874 | const auto clamp_keep_ratio = std::get<3>(info.param); | ||
| 875 | |||
| 876 | return test_description(name, shape, portion, true, clamp_keep_ratio); /* NOTE(review): the meaning of the `true` argument is defined by test_description - confirm. */ | ||
| 877 | }); | ||
| 878 | |||
| 879 | } // namespace kai::test | ||
| 880 |