test/tests/matmul_clamp_qai8dxp_qsi4c32p_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <gtest/gtest.h> | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <ostream> | ||
| 12 | #include <sstream> | ||
| 13 | #include <string> | ||
| 14 | #include <tuple> | ||
| 15 | #include <utility> | ||
| 16 | #include <vector> | ||
| 17 | |||
| 18 | #include "kai/kai_common.h" | ||
| 19 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod.h" | ||
| 20 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm.h" | ||
| 21 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4c32p/kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_interface.h" | ||
| 22 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" | ||
| 23 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot.h" | ||
| 24 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h" | ||
| 25 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod.h" | ||
| 26 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" | ||
| 27 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod.h" | ||
| 28 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" | ||
| 29 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h" | ||
| 30 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod.h" | ||
| 31 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" | ||
| 32 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" | ||
| 33 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm.h" | ||
| 34 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" | ||
| 35 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" | ||
| 36 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h" | ||
| 37 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" | ||
| 38 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" | ||
| 39 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon.h" | ||
| 40 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" | ||
| 41 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon.h" | ||
| 42 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.h" | ||
| 43 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon.h" | ||
| 44 | #include "test/common/abi_checker.hpp" | ||
| 45 | #include "test/common/bfloat16.hpp" | ||
| 46 | #include "test/common/buffer.hpp" | ||
| 47 | #include "test/common/cache.hpp" | ||
| 48 | #include "test/common/compare.hpp" | ||
| 49 | #include "test/common/cpu_info.hpp" | ||
| 50 | #include "test/common/data_type.hpp" | ||
| 51 | #include "test/common/int4.hpp" | ||
| 52 | #include "test/common/matmul_test_common.hpp" | ||
| 53 | #include "test/common/matrix_portion.hpp" | ||
| 54 | #include "test/common/memory.hpp" | ||
| 55 | #include "test/common/round.hpp" | ||
| 56 | #include "test/common/seed.hpp" | ||
| 57 | #include "test/common/test_suite.hpp" | ||
| 58 | #include "test/reference/cast.hpp" | ||
| 59 | #include "test/reference/clamp.hpp" | ||
| 60 | #include "test/reference/fill.hpp" | ||
| 61 | #include "test/reference/matmul.hpp" | ||
| 62 | #include "test/reference/pad.hpp" | ||
| 63 | #include "test/reference/quantize.hpp" | ||
| 64 | #include "test/reference/transpose.hpp" | ||
| 65 | |||
| 66 | namespace kai::test { | ||
| 67 | |||
| 68 | namespace { | ||
| 69 | |||
| 70 | // LHS QAI8DXP | ||
| 71 | using kai_get_lhs_packed_size_func_t = decltype(&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32); | ||
| 72 | using kai_get_lhs_packed_offset_func_t = decltype(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32); | ||
| 73 | using kai_get_lhs_offset_func_t = decltype(&kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32); | ||
| 74 | using kai_run_lhs_pack_func_t = decltype(&kai_run_lhs_quant_pack_qai8dxp_f32); | ||
| 75 | |||
| 76 | // LHS QAI8DXP pack interface | ||
| 77 | struct kai_qai8dxp_pack_functions { | ||
| 78 | kai_get_lhs_packed_size_func_t packed_size; | ||
| 79 | kai_get_lhs_packed_offset_func_t get_packed_offset; | ||
| 80 | kai_get_lhs_offset_func_t get_offset; | ||
| 81 | kai_run_lhs_pack_func_t run_pack; | ||
| 82 | }; | ||
| 83 | |||
| 84 | // RHS QSI4C32P (nxk, BF16 block scales; sums float, bias float) | ||
| 85 | using kai_get_rhs_packed_size_func_t = decltype(&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0); | ||
| 86 | using kai_get_rhs_packed_offset_func_t = decltype(&kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0); | ||
| 87 | using kai_get_rhs_offset_func_t = decltype(&kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0); | ||
| 88 | using kai_run_rhs_pack_func_t = decltype(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0); | ||
| 89 | |||
| 90 | // RHS QSI4C32P pack interface | ||
| 91 | struct kai_qsi4c32p_pack_functions { | ||
| 92 | kai_get_rhs_packed_size_func_t packed_size; | ||
| 93 | kai_get_rhs_packed_offset_func_t get_packed_offset; | ||
| 94 | kai_get_rhs_offset_func_t get_offset; | ||
| 95 | kai_run_rhs_pack_func_t run_pack; | ||
| 96 | }; | ||
| 97 | |||
| 98 | 80972 | const auto& get_f32_gemm_variants() noexcept { | |
| 99 | using Variant = UkernelMatmulPackVariant< | ||
| 100 | kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>; | ||
| 101 | |||
| 102 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 80969 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
80973 | static const std::array<Variant, 12> variants = {{ |
| 103 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 104 | clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32, | ||
| 105 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, | ||
| 106 | /*rhs_s0s1_input=*/false), | ||
| 107 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 108 | clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32, | ||
| 109 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false), | ||
| 110 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 111 | clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32, | ||
| 112 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false), | ||
| 113 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 114 | clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32, | ||
| 115 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false), | ||
| 116 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 117 | clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32, | ||
| 118 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false), | ||
| 119 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 120 | clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32, | ||
| 121 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false), | ||
| 122 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 123 | clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod, cpu_has_dotprod, lhs_quant_pack_qai8dxp_f32, | ||
| 124 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false), | ||
| 125 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 126 | clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qai8dxp_f32, | ||
| 127 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false), | ||
| 128 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 129 | clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qai8dxp_f32, | ||
| 130 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false), | ||
| 131 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 132 | clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qai8dxp_f32, | ||
| 133 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false), | ||
| 134 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 135 | clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm, cpu_has_i8mm, lhs_quant_pack_qai8dxp_f32, | ||
| 136 | rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, false), | ||
| 137 | // SME2 MOPA | ||
| 138 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 139 | clamp_f32_qai8dxp1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, cpu_has_sme2, lhs_quant_pack_qai8dxp_f32, | ||
| 140 | rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon, false), | ||
| 141 | }}; | ||
| 142 | |||
| 143 | 80972 | return variants; | |
| 144 | } | ||
| 145 | |||
| 146 | 1322 | const auto& get_f32_gemv_variants() noexcept { | |
| 147 | using Variant = UkernelMatmulPackVariant< | ||
| 148 | kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>; | ||
| 149 | |||
| 150 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1319 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
1323 | static const std::array<Variant, 1> variants = {{ |
| 151 |
3/6✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
|
3 | UKERNEL_MATMUL_PACK_VARIANT( |
| 152 | clamp_f32_qai8dxp1x4_qsi4c32p4vlx4_1x4vl_sme2_dot, cpu_has_sme2, lhs_quant_pack_qai8dxp_f32, | ||
| 153 | rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon, false), | ||
| 154 | }}; | ||
| 155 | |||
| 156 | 1322 | return variants; | |
| 157 | } | ||
| 158 | |||
| 159 | 771 | const auto& get_bf16_gemm_variants() noexcept { | |
| 160 | using Variant = UkernelVariant<kai_matmul_clamp_bf16_qai8dxp_qsi4c32p_ukernel>; | ||
| 161 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 768 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
774 | static const std::array<Variant, 2> variants = { |
| 162 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
7 | Variant{ |
| 163 | 3 | UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod), | |
| 164 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4c32p4x8_1x4_neon_dotprod", cpu_has_dotprod_and_bf16}, |
| 165 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
5 | Variant{ |
| 166 | 3 | UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm), | |
| 167 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4c32p4x8_16x4_neon_i8mm", cpu_has_i8mm_and_bf16}, |
| 168 | }; | ||
| 169 | 771 | return variants; | |
| 170 | } | ||
| 171 | |||
| 172 | // NEON/i8mm only (exclude SME2) | ||
| 173 | 6 | const auto& get_f32_neon_gemm_variants_only() { | |
| 174 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
6 | static std::vector<UkernelMatmulPackVariant< |
| 175 | kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>> | ||
| 176 | 3 | filtered; | |
| 177 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (filtered.empty()) { |
| 178 | 3 | const auto& all = get_f32_gemm_variants(); | |
| 179 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 3 times.
|
39 | for (const auto& v : all) { |
| 180 | 36 | const char* n = v.ukernel.name.data(); | |
| 181 |
3/4✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✓ Branch 3 taken 33 times.
|
36 | if (n == nullptr || std::strstr(n, "sme2") == nullptr) { |
| 182 | 33 | filtered.push_back(v); | |
| 183 | 33 | } | |
| 184 | 36 | } | |
| 185 | 3 | } | |
| 186 | 6 | return filtered; | |
| 187 | } | ||
| 188 | |||
| 189 | enum class RhsPackType : std::uint8_t { NxK = 0, KxN = 1 }; | ||
| 190 | |||
| 191 | 3430 | std::tuple<Buffer, size_t> pack_lhs_qai8dxp( | |
| 192 | const kai_qai8dxp_pack_functions& pack_interface, const size_t M, const size_t K, const size_t mr, const size_t kr, | ||
| 193 | const size_t sr, const Buffer& lhs_values_f32, const size_t lhs_stride_bytes, const size_t rect_start_row, | ||
| 194 | const size_t rect_height) { | ||
| 195 | 3430 | const auto lhs_packed_size = pack_interface.packed_size(M, K, mr, kr, sr); | |
| 196 | 3430 | Buffer lhs_packed(lhs_packed_size, 0); | |
| 197 | |||
| 198 |
1/2✓ Branch 0 taken 3430 times.
✗ Branch 1 not taken.
|
3430 | const auto lhs_offset = pack_interface.get_offset(rect_start_row, lhs_stride_bytes); |
| 199 |
1/2✓ Branch 0 taken 3430 times.
✗ Branch 1 not taken.
|
3430 | const auto lhs_packed_offset = pack_interface.get_packed_offset(rect_start_row, K, mr, kr, sr); |
| 200 | |||
| 201 |
1/2✓ Branch 0 taken 3430 times.
✗ Branch 1 not taken.
|
3430 | abi_check( |
| 202 | 3430 | pack_interface.run_pack, rect_height, K, mr, kr, sr, 0, | |
| 203 |
1/2✓ Branch 0 taken 3430 times.
✗ Branch 1 not taken.
|
3430 | reinterpret_cast<const float*>(lhs_values_f32.data() + lhs_offset), lhs_stride_bytes, |
| 204 |
1/2✓ Branch 0 taken 3430 times.
✗ Branch 1 not taken.
|
3430 | lhs_packed.data() + lhs_packed_offset); |
| 205 | |||
| 206 | 3430 | return {std::move(lhs_packed), lhs_packed_offset}; | |
| 207 | 3430 | } | |
| 208 | |||
| 209 | // Executes the scalar RHS packing micro-kernel. | ||
| 210 | 1250 | std::tuple<Buffer, size_t> pack_rhs_qsi4c32pscalebf16( | |
| 211 | // clang-format off | ||
| 212 | const size_t N, | ||
| 213 | const size_t K, | ||
| 214 | const size_t nr, | ||
| 215 | const size_t kr, | ||
| 216 | const size_t sr, | ||
| 217 | const size_t bl, | ||
| 218 | const Buffer& rhs_values_qsi4, | ||
| 219 | const Buffer& biases, | ||
| 220 | const size_t bias_offset, | ||
| 221 | const Buffer& rhs_scales, | ||
| 222 | const RhsPackType pack_type, | ||
| 223 | const size_t rect_start_row, | ||
| 224 | const size_t rect_width, | ||
| 225 | const bool use_ps1s0) { | ||
| 226 | // clang-format on | ||
| 227 |
2/2✓ Branch 0 taken 1154 times.
✓ Branch 1 taken 96 times.
|
1250 | const size_t width = pack_type == RhsPackType::KxN ? N : K; |
| 228 |
2/2✓ Branch 0 taken 1154 times.
✓ Branch 1 taken 96 times.
|
1250 | const size_t height = pack_type == RhsPackType::KxN ? K : N; |
| 229 | 1250 | constexpr kai_datatype scale_dt = kai_dt_bf16; | |
| 230 | |||
| 231 | 1250 | const size_t rhs_stride = round_up_multiple(width, 2); | |
| 232 | 1250 | const size_t rhs_stride_bytes = round_up_division(width, 2); | |
| 233 | 1250 | const size_t scales_stride_bytes = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt); | |
| 234 | |||
| 235 | − | KAI_ASSUME_ALWAYS(rhs_values_qsi4.size() == round_up_division(height * rhs_stride, 2)); | |
| 236 | |||
| 237 | 1250 | const auto rhs_values_qsu4 = cast_qsu4_qsi4(rhs_values_qsi4.data(), rhs_values_qsi4.size() * 2); | |
| 238 |
1/2✓ Branch 0 taken 1250 times.
✗ Branch 1 not taken.
|
1250 | const size_t dst_bytes_total = round_up_division(height * rhs_stride, 2); |
| 239 | 1250 | const size_t dst_bytes_total_safe = dst_bytes_total + rhs_stride_bytes + 8; | |
| 240 | 1250 | const auto rhs_qsu4 = | |
| 241 |
1/2✓ Branch 0 taken 1250 times.
✗ Branch 1 not taken.
|
1250 | pad_row<UInt4>(rhs_values_qsu4.data(), height, width, width, rhs_stride_bytes * 2, dst_bytes_total_safe); |
| 242 | |||
| 243 | 1250 | const size_t scale_offset = rect_start_row * scales_stride_bytes; | |
| 244 | 1250 | size_t rhs_offset = 0; | |
| 245 | 1250 | size_t rhs_packed_offset = 0; | |
| 246 | 1250 | size_t imp_packed_rhs_size = 0; | |
| 247 | |||
| 248 |
2/2✓ Branch 0 taken 1154 times.
✓ Branch 1 taken 96 times.
|
1250 | if (pack_type == RhsPackType::KxN) { |
| 249 |
2/2✓ Branch 0 taken 46 times.
✓ Branch 1 taken 1108 times.
|
1154 | if (use_ps1s0) { |
| 250 | 46 | rhs_offset = | |
| 251 |
1/2✓ Branch 0 taken 46 times.
✗ Branch 1 not taken.
|
46 | kai_get_rhs_offset_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(rect_start_row, rhs_stride_bytes); |
| 252 |
1/2✓ Branch 0 taken 46 times.
✗ Branch 1 not taken.
|
46 | rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon( |
| 253 | 46 | rect_start_row, K, nr, kr, sr, bl, scale_dt); | |
| 254 | 46 | imp_packed_rhs_size = | |
| 255 |
1/2✓ Branch 0 taken 46 times.
✗ Branch 1 not taken.
|
46 | kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt); |
| 256 | 46 | } else { | |
| 257 |
1/2✓ Branch 0 taken 1108 times.
✗ Branch 1 not taken.
|
1108 | rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(rect_start_row, rhs_stride_bytes); |
| 258 |
1/2✓ Branch 0 taken 1108 times.
✗ Branch 1 not taken.
|
1108 | rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( |
| 259 | 1108 | rect_start_row, K, nr, kr, sr, bl, scale_dt); | |
| 260 | 1108 | imp_packed_rhs_size = | |
| 261 |
1/2✓ Branch 0 taken 1108 times.
✗ Branch 1 not taken.
|
1108 | kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt); |
| 262 | } | ||
| 263 | 1154 | } else { | |
| 264 |
1/2✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
|
96 | rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rect_start_row, rhs_stride_bytes); |
| 265 | 96 | rhs_packed_offset = | |
| 266 |
1/2✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
|
96 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(rect_start_row, K, nr, kr, sr, bl, scale_dt); |
| 267 |
1/2✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
|
96 | imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt); |
| 268 | } | ||
| 269 | |||
| 270 |
1/2✓ Branch 0 taken 1250 times.
✗ Branch 1 not taken.
|
1250 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
| 271 |
2/2✓ Branch 0 taken 1154 times.
✓ Branch 1 taken 96 times.
|
1250 | if (pack_type == RhsPackType::KxN) { |
| 272 | 1154 | kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params{}; | |
| 273 | 1154 | params.lhs_zero_point = 1; | |
| 274 | 1154 | params.rhs_zero_point = 8; | |
| 275 | 1154 | params.scale_dt = scale_dt; | |
| 276 | |||
| 277 |
2/2✓ Branch 0 taken 46 times.
✓ Branch 1 taken 1108 times.
|
1154 | if (use_ps1s0) { |
| 278 | // clang-format off | ||
| 279 |
1/2✓ Branch 0 taken 46 times.
✗ Branch 1 not taken.
|
46 | abi_check( |
| 280 | kai_run_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon, | ||
| 281 | 46 | 1, // num_groups | |
| 282 | rect_width, // n | ||
| 283 | K, // k | ||
| 284 | nr, kr, sr, bl, // packing args | ||
| 285 | 46 | reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset), | |
| 286 | rhs_stride_bytes, | ||
| 287 | 46 | reinterpret_cast<const float*>(biases.data() + bias_offset), | |
| 288 | 46 | reinterpret_cast<const void*>(rhs_scales.data() + scale_offset), | |
| 289 | scales_stride_bytes, | ||
| 290 | 46 | static_cast<void*>(imp_packed_rhs.data() + rhs_packed_offset), | |
| 291 | 46 | 0, | |
| 292 | 46 | &params); | |
| 293 | // clang-format on | ||
| 294 | 46 | } else { | |
| 295 | 1108 | kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params_kxn{}; | |
| 296 | 1108 | params_kxn.lhs_zero_point = 1; | |
| 297 | 1108 | params_kxn.rhs_zero_point = 8; | |
| 298 | 1108 | params_kxn.scale_dt = scale_dt; | |
| 299 | |||
| 300 | // clang-format off | ||
| 301 |
1/2✓ Branch 0 taken 1108 times.
✗ Branch 1 not taken.
|
1108 | abi_check( |
| 302 | kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0, | ||
| 303 | 1108 | 1, | |
| 304 | rect_width, | ||
| 305 | K, | ||
| 306 | nr, kr, sr, bl, | ||
| 307 | 1108 | reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset), | |
| 308 | rhs_stride_bytes, | ||
| 309 | 1108 | reinterpret_cast<const float*>(biases.data() + bias_offset), | |
| 310 | 1108 | reinterpret_cast<const void*>(rhs_scales.data() + scale_offset), | |
| 311 | scales_stride_bytes, | ||
| 312 | 1108 | static_cast<void*>(imp_packed_rhs.data() + rhs_packed_offset), | |
| 313 | 1108 | 0, | |
| 314 | 1108 | &params_kxn); | |
| 315 | // clang-format on | ||
| 316 | 1108 | } | |
| 317 | 1154 | } else { | |
| 318 | 96 | kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{}; | |
| 319 | 96 | params.lhs_zero_point = 1; | |
| 320 | 96 | params.rhs_zero_point = 8; | |
| 321 | 96 | params.scale_dt = scale_dt; | |
| 322 | |||
| 323 |
1/2✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
|
96 | abi_check( |
| 324 | // clang-format off | ||
| 325 | kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0, | ||
| 326 | 96 | 1, | |
| 327 | rect_width, | ||
| 328 | K, | ||
| 329 | nr, kr, sr, bl, | ||
| 330 | 96 | reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset), | |
| 331 | rhs_stride_bytes, | ||
| 332 | 96 | reinterpret_cast<const float*>(biases.data() + bias_offset), | |
| 333 | 96 | reinterpret_cast<const void*>(rhs_scales.data() + scale_offset), | |
| 334 | scales_stride_bytes, | ||
| 335 | 96 | static_cast<void*>(imp_packed_rhs.data() + rhs_packed_offset), | |
| 336 | 96 | 0, | |
| 337 | 96 | &params); | |
| 338 | // clang-format on | ||
| 339 | 96 | } | |
| 340 | |||
| 341 | 1250 | return {std::move(imp_packed_rhs), rhs_packed_offset}; | |
| 342 | 1250 | } | |
| 343 | |||
| 344 | /// Executes RHS NxK packing helper | ||
| 345 | 1078 | std::tuple<Buffer, size_t> pack_rhs_qsi4c32p_nxk( | |
| 346 | const kai_qsi4c32p_pack_functions& pack_iface, const size_t N, const size_t K, const size_t nr, const size_t kr, | ||
| 347 | const size_t sr, const size_t bl, const Buffer& rhs_values_qsi4, const float* bias, const Buffer& rhs_scales, | ||
| 348 | const size_t rect_start_row, const size_t rect_width, const bool rhs_s0s1_input) { | ||
| 349 | // Convert signed int4 -> unsigned int4, preserving any row padding in the source buffer. | ||
| 350 | 1078 | const auto rhs_qsu4s1s0 = cast_qsu4_qsi4(rhs_values_qsi4.data(), rhs_values_qsi4.size() * 2); | |
| 351 | |||
| 352 |
1/2✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
|
1078 | const auto rhs_packed_size = pack_iface.packed_size(N, K, nr, kr, sr, bl, kai_dt_bf16); |
| 353 |
1/2✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
|
1078 | Buffer rhs_packed(rhs_packed_size); |
| 354 |
1/2✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
|
1078 | const auto rhs_packed_offset = pack_iface.get_packed_offset(rect_start_row, K, nr, kr, sr, bl, kai_dt_bf16); |
| 355 | |||
| 356 |
1/2✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
|
1078 | const size_t rhs_stride_bytes = round_up_division(K, 2); // bytes per row |
| 357 |
2/4✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1078 times.
✗ Branch 3 not taken.
|
1078 | const size_t scales_stride_bytes = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(kai_dt_bf16); |
| 358 | 1078 | const size_t scale_offset = rect_start_row * scales_stride_bytes; | |
| 359 |
1/2✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
|
1078 | const size_t rhs_offset = pack_iface.get_offset(rect_start_row, rhs_stride_bytes); |
| 360 | |||
| 361 | 1078 | kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{}; | |
| 362 | 1078 | params.lhs_zero_point = 1; | |
| 363 | 1078 | params.rhs_zero_point = 8; | |
| 364 | 1078 | params.scale_dt = kai_dt_bf16; | |
| 365 | |||
| 366 | // Apply optional s0s1 -> s1s0 nibble swap. | ||
| 367 | 1078 | const Buffer* rhs_qsu4_ptr = &rhs_qsu4s1s0; | |
| 368 | 1078 | Buffer rhs_qsu4_converted; | |
| 369 |
1/2✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
|
1078 | if (rhs_s0s1_input) { |
| 370 | ✗ | rhs_qsu4_converted = convert_s0s1_s1s0(rhs_qsu4s1s0); | |
| 371 | ✗ | rhs_qsu4_ptr = &rhs_qsu4_converted; | |
| 372 | ✗ | } | |
| 373 | |||
| 374 |
1/2✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
|
1078 | abi_check( |
| 375 | 1078 | pack_iface.run_pack, 1, rect_width, K, nr, kr, sr, bl, | |
| 376 | 1078 | reinterpret_cast<const uint8_t*>(rhs_qsu4_ptr->data() + rhs_offset), rhs_stride_bytes, bias, | |
| 377 | 1078 | rhs_scales.data() + scale_offset, scales_stride_bytes, rhs_packed.data() + rhs_packed_offset, 0, &params); | |
| 378 | |||
| 379 | 1078 | return {std::move(rhs_packed), rhs_packed_offset}; | |
| 380 | 1078 | } | |
| 381 | |||
| 382 | // Executes F32-only RHS KxN packing helper (wrapper around BF16-scaled helper for clarity) | ||
| 383 | 1058 | std::tuple<Buffer, size_t> pack_rhs_qsi4c32p_kxn( | |
| 384 | const size_t N, const size_t K, const size_t nr, const size_t kr, const size_t sr, const size_t bl, | ||
| 385 | const Buffer& rhs_values_qsi4, const Buffer& biases, const size_t bias_offset, const Buffer& rhs_scales, | ||
| 386 | const size_t rect_start_row, const size_t rect_width, const bool use_ps1s0) { | ||
| 387 | 1058 | return pack_rhs_qsi4c32pscalebf16( | |
| 388 | 1058 | N, K, nr, kr, sr, bl, rhs_values_qsi4, biases, bias_offset, rhs_scales, RhsPackType::KxN, rect_start_row, | |
| 389 | 1058 | rect_width, use_ps1s0); | |
| 390 | } | ||
| 391 | |||
| 392 | /// Executes the vectorized RHS packing micro-kernels for block length of 4 bytes or 8 bytes | ||
| 393 | 448 | std::tuple<Buffer, size_t> pack_rhs_qsi4c32pscalebf16_neon( | |
| 394 | const size_t N, const size_t K, const size_t nr, const size_t kr, const size_t sr, const size_t bl, | ||
| 395 | const Buffer& rhs_values_qsi4, const Buffer& biases, const size_t bias_offset, const Buffer& rhs_scales, | ||
| 396 | const RhsPackType pack_type, const size_t rect_start_row, const size_t rect_width) { | ||
| 397 | − | KAI_ASSUME_ALWAYS(kr / sr == 8 || kr / sr == 4); | |
| 398 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 448 times.
|
448 | const size_t width = pack_type == RhsPackType::KxN ? N : K; |
| 399 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 448 times.
|
448 | const size_t height = pack_type == RhsPackType::KxN ? K : N; |
| 400 | 448 | constexpr kai_datatype scale_dt = kai_dt_bf16; | |
| 401 | |||
| 402 | 448 | const size_t rhs_stride = round_up_multiple(width, 2); | |
| 403 | 448 | const size_t rhs_stride_bytes = round_up_division(width, 2); | |
| 404 | 448 | const size_t scales_stride_bytes = round_up_division(K, bl) * kai_get_datatype_size_in_bytes(scale_dt); | |
| 405 | |||
| 406 | − | KAI_ASSUME_ALWAYS(rhs_values_qsi4.size() == round_up_division(height * rhs_stride, 2)); | |
| 407 | |||
| 408 | 448 | const auto rhs_values_qsu4 = cast_qsu4_qsi4(rhs_values_qsi4.data(), rhs_values_qsi4.size() * 2); | |
| 409 |
1/2✓ Branch 0 taken 448 times.
✗ Branch 1 not taken.
|
448 | const size_t dst_bytes_total = round_up_division(height * rhs_stride, 2); |
| 410 | 448 | const size_t dst_bytes_total_safe = dst_bytes_total + rhs_stride_bytes + 8; | |
| 411 | 448 | const auto rhs_qsu4 = | |
| 412 |
1/2✓ Branch 0 taken 448 times.
✗ Branch 1 not taken.
|
448 | pad_row<UInt4>(rhs_values_qsu4.data(), height, width, width, rhs_stride_bytes * 2, dst_bytes_total_safe); |
| 413 | |||
| 414 | 448 | const size_t scale_offset = rect_start_row * scales_stride_bytes; | |
| 415 | |||
| 416 | 448 | size_t imp_packed_rhs_size_neon = 0; | |
| 417 | 448 | size_t rhs_packed_offset_neon = 0; | |
| 418 | 448 | size_t rhs_offset_neon = 0; | |
| 419 | |||
| 420 |
2/2✓ Branch 0 taken 128 times.
✓ Branch 1 taken 320 times.
|
448 | if (kr / sr == 8) { |
| 421 | 320 | imp_packed_rhs_size_neon = | |
| 422 |
1/2✓ Branch 0 taken 320 times.
✗ Branch 1 not taken.
|
320 | kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt); |
| 423 |
1/2✓ Branch 0 taken 320 times.
✗ Branch 1 not taken.
|
320 | rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( |
| 424 | 320 | rect_start_row, K, nr, kr, sr, bl, scale_dt); | |
| 425 | 320 | rhs_offset_neon = | |
| 426 |
1/2✓ Branch 0 taken 320 times.
✗ Branch 1 not taken.
|
320 | kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(rect_start_row, rhs_stride_bytes); |
| 427 | 320 | } else { | |
| 428 | 128 | imp_packed_rhs_size_neon = | |
| 429 |
1/2✓ Branch 0 taken 128 times.
✗ Branch 1 not taken.
|
128 | kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(N, K, nr, kr, sr, bl, scale_dt); |
| 430 |
1/2✓ Branch 0 taken 128 times.
✗ Branch 1 not taken.
|
128 | rhs_packed_offset_neon = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( |
| 431 | 128 | rect_start_row, K, nr, kr, sr, bl, scale_dt); | |
| 432 | 128 | rhs_offset_neon = | |
| 433 |
1/2✓ Branch 0 taken 128 times.
✗ Branch 1 not taken.
|
128 | kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon(rect_start_row, rhs_stride_bytes); |
| 434 | } | ||
| 435 | |||
| 436 | 448 | kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{}; | |
| 437 | 448 | params.lhs_zero_point = 1; | |
| 438 | 448 | params.rhs_zero_point = 8; | |
| 439 | 448 | params.scale_dt = scale_dt; | |
| 440 | |||
| 441 |
1/2✓ Branch 0 taken 448 times.
✗ Branch 1 not taken.
|
448 | Buffer imp_packed_rhs_neon(imp_packed_rhs_size_neon); |
| 442 |
2/2✓ Branch 0 taken 128 times.
✓ Branch 1 taken 320 times.
|
448 | if (kr / sr == 8) { |
| 443 |
1/2✓ Branch 0 taken 320 times.
✗ Branch 1 not taken.
|
320 | kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon( |
| 444 | 320 | 1, rect_width /* n */, K, nr, kr, sr, bl, | |
| 445 | 320 | reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset_neon), rhs_stride_bytes, | |
| 446 | 320 | reinterpret_cast<const float*>(biases.data() + bias_offset), rhs_scales.data() + scale_offset, | |
| 447 | 320 | scales_stride_bytes, imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, ¶ms); | |
| 448 | 320 | } else { | |
| 449 |
1/2✓ Branch 0 taken 128 times.
✗ Branch 1 not taken.
|
128 | kai_run_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon( |
| 450 | 128 | 1, rect_width /* n */, K, nr, kr, sr, bl, | |
| 451 | 128 | reinterpret_cast<const uint8_t*>(rhs_qsu4.data() + rhs_offset_neon), rhs_stride_bytes, | |
| 452 | 128 | reinterpret_cast<const float*>(biases.data() + bias_offset), rhs_scales.data() + scale_offset, | |
| 453 | 128 | scales_stride_bytes, imp_packed_rhs_neon.data() + rhs_packed_offset_neon, 0, ¶ms); | |
| 454 | } | ||
| 455 | 448 | return {std::move(imp_packed_rhs_neon), rhs_packed_offset_neon}; | |
| 456 | 448 | } | |
| 457 | |||
| 458 | 48840 | std::string test_description( | |
| 459 | const std::string& name, const RhsPackType rhs_pack_type, const MatMulShape& shape, const size_t bl, | ||
| 460 | const MatrixPortion& portion, const float clamp_keep_ratio) { | ||
| 461 | // Remove redundant prefix to make output easier to read | ||
| 462 | 48840 | std::string clean_name = name; | |
| 463 |
1/2✓ Branch 0 taken 48840 times.
✗ Branch 1 not taken.
|
48840 | const std::string prefix = "kai_matmul_clamp_"; |
| 464 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 48840 times.
|
48840 | if (clean_name.rfind(prefix, 0) == 0) { // starts with prefix |
| 465 |
1/2✓ Branch 0 taken 48840 times.
✗ Branch 1 not taken.
|
48840 | clean_name.erase(0, prefix.length()); |
| 466 | 48840 | } | |
| 467 | |||
| 468 |
1/2✓ Branch 0 taken 48840 times.
✗ Branch 1 not taken.
|
48840 | std::ostringstream sstream; |
| 469 |
5/10✓ Branch 0 taken 48840 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 48840 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 48840 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 48840 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 48840 times.
✗ Branch 9 not taken.
|
97680 | sstream << test_description(clean_name, shape, portion, /*bias=*/false, clamp_keep_ratio) << "__BL_" << bl << "__" |
| 470 |
3/4✓ Branch 0 taken 23472 times.
✓ Branch 1 taken 25368 times.
✓ Branch 2 taken 48840 times.
✗ Branch 3 not taken.
|
48840 | << ((rhs_pack_type == RhsPackType::NxK) ? "NxK" : "KxN"); |
| 471 | |||
| 472 |
1/2✓ Branch 0 taken 48840 times.
✗ Branch 1 not taken.
|
48840 | return sstream.str(); |
| 473 | 48840 | } | |
| 474 | |||
| 475 | 1884 | struct TestData { | |
| 476 | 942 | size_t M{}, N{}, K{}, bl{}; | |
| 477 | |||
| 478 | 1884 | Rect rect{0, 0, 0, 0}; | |
| 479 | |||
| 480 | Buffer lhs; | ||
| 481 | Buffer rhs; | ||
| 482 | Buffer bias; | ||
| 483 | |||
| 484 | Buffer rhs_quant; | ||
| 485 | Buffer rhs_scales; | ||
| 486 | |||
| 487 | Buffer lhs_packed; | ||
| 488 | 942 | size_t lhs_packed_offset{}; | |
| 489 | |||
| 490 | Buffer ref_dst_clamped; | ||
| 491 | Range<float> clamp; | ||
| 492 | }; | ||
| 493 | |||
| 494 | using BF16QMatMulRefKey = std::tuple< | ||
| 495 | MatMulShape, // shape | ||
| 496 | size_t, // bl | ||
| 497 | size_t, // mr | ||
| 498 | size_t, // nr | ||
| 499 | size_t, // kr | ||
| 500 | size_t, // sr | ||
| 501 | size_t, size_t, size_t, size_t, // rect.start_row, rect.start_col, rect.height, rect.width | ||
| 502 | RhsPackType, // rhs_pack_type | ||
| 503 | float // clamp_keep_ratio | ||
| 504 | >; | ||
| 505 | |||
| 506 | 384 | struct BF16TestData { | |
| 507 | 192 | size_t M{}, N{}, K{}, bl{}; | |
| 508 | 384 | Rect rect{0, 0, 0, 0}; | |
| 509 | |||
| 510 | Buffer lhs_bf16; // Original BF16 LHS (kept for completeness) | ||
| 511 | Buffer bias; // Biases (FP32) | ||
| 512 | Buffer rhs_quant; // QSI4 quantized RHS (possibly transposed to match pack type) | ||
| 513 | Buffer rhs_scales; // BF16 per-block scales | ||
| 514 | |||
| 515 | Buffer lhs_packed; // Packed LHS buffer (BF16 dynamic quant + pack) | ||
| 516 | 192 | size_t lhs_packed_offset{}; // Offset for rect.start_row | |
| 517 | |||
| 518 | 384 | Range<float> clamp{}; // Clamp range used for matmul | |
| 519 | Buffer ref_dst_bf16; // Reference DST in BF16 (clamped) | ||
| 520 | }; | ||
| 521 | |||
| 522 | } // anonymous namespace | ||
| 523 | |||
| 524 | using QMatmulClampF32ParamT = std::tuple<size_t, bool, MatMulShape, size_t, MatrixPortion, RhsPackType, float>; | ||
| 525 | |||
| 526 | 16088 | class QMatMulClampF32Test : public ::testing::TestWithParam<QMatmulClampF32ParamT> { | |
| 527 | struct TestParams { | ||
| 528 | const UkernelMatmulPackVariant< | ||
| 529 | kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>* | ||
| 530 | variant; | ||
| 531 | size_t variant_index; | ||
| 532 | MatMulShape matmul_shape; | ||
| 533 | size_t bl; | ||
| 534 | MatrixPortion portion; | ||
| 535 | RhsPackType rhs_pack_type; | ||
| 536 | Rect rect; | ||
| 537 | float clamp_keep_ratio; | ||
| 538 | bool is_sme2; | ||
| 539 | |||
| 540 | 24132 | TestParams() : | |
| 541 | 16088 | variant(nullptr), | |
| 542 | 16088 | variant_index(0), | |
| 543 | 16088 | matmul_shape{0, 0, 0}, | |
| 544 | 16088 | bl(32), | |
| 545 | 16088 | portion(0, 0, 1, 1), | |
| 546 | 16088 | rhs_pack_type(RhsPackType::NxK), | |
| 547 | 16088 | rect(0, 0, 0, 0), | |
| 548 | 16088 | clamp_keep_ratio(0.8F), | |
| 549 | 24132 | is_sme2(false) { | |
| 550 | 24132 | } | |
| 551 | |||
| 552 | TestParams( | ||
| 553 | const UkernelMatmulPackVariant< | ||
| 554 | kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>* | ||
| 555 | variant, | ||
| 556 | const size_t v_idx, const MatMulShape& shape, const size_t bl, const MatrixPortion p, const RhsPackType r, | ||
| 557 | const Rect& rect, const float clamp_keep_ratio) : | ||
| 558 | variant(variant), | ||
| 559 | variant_index(v_idx), | ||
| 560 | matmul_shape(shape), | ||
| 561 | bl(bl), | ||
| 562 | portion(p), | ||
| 563 | rhs_pack_type(r), | ||
| 564 | rect(rect), | ||
| 565 | clamp_keep_ratio(clamp_keep_ratio), | ||
| 566 | is_sme2(false) { | ||
| 567 | } | ||
| 568 | }; | ||
| 569 | |||
| 570 | TestParams params; | ||
| 571 | |||
| 572 | protected: | ||
| 573 | static const TestData& test_data(); | ||
| 574 | 15304 | void SetupCommonForParam() { | |
| 575 | 15304 | TestWithParam::SetUp(); | |
| 576 |
2/2✓ Branch 0 taken 15164 times.
✓ Branch 1 taken 140 times.
|
15304 | if (std::get<1>(GetParam())) { // is_gemm |
| 577 | 15164 | SetupCommon(get_f32_gemm_variants()); | |
| 578 | 15164 | } else { | |
| 579 | 140 | SetupCommon(get_f32_gemv_variants()); | |
| 580 | } | ||
| 581 | 15304 | } | |
| 582 | |||
| 583 | [[nodiscard]] const TestParams& GetParams() const { | ||
| 584 | return params; | ||
| 585 | } | ||
| 586 | 30608 | TestParams& GetParams() { | |
| 587 | 30608 | return params; | |
| 588 | } | ||
| 589 | |||
| 590 | 16088 | void SetUp() override { | |
| 591 | // Gate CPU features before computing kernel interface params (which may touch unsupported instructions). | ||
| 592 | 16088 | const auto& param = GetParam(); | |
| 593 | 16088 | const size_t variant_index = std::get<0>(param); | |
| 594 | 16088 | const bool is_gemm = std::get<1>(param); | |
| 595 | 32176 | const auto& variant = | |
| 596 |
2/2✓ Branch 0 taken 15808 times.
✓ Branch 1 taken 280 times.
|
16088 | is_gemm ? get_f32_gemm_variants().at(variant_index) : get_f32_gemv_variants().at(variant_index); |
| 597 | |||
| 598 |
3/4✓ Branch 0 taken 16088 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 15304 times.
✓ Branch 3 taken 784 times.
|
16088 | if (variant.ukernel.fn_is_supported && !variant.ukernel.fn_is_supported()) { |
| 599 |
3/6✓ Branch 0 taken 784 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 784 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 784 times.
✗ Branch 5 not taken.
|
784 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 600 | return; | ||
| 601 | } | ||
| 602 | |||
| 603 | // Safe to compute aligned params/rect now. | ||
| 604 | 15304 | SetupCommonForParam(); | |
| 605 | 15304 | const auto& p = GetParams(); | |
| 606 | |||
| 607 | // GEMV vs GEMM constraints (after params are set) | ||
| 608 |
2/2✓ Branch 0 taken 15164 times.
✓ Branch 1 taken 140 times.
|
15304 | if (!is_gemm) { |
| 609 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 140 times.
|
140 | if (p.matmul_shape.m != 1) { |
| 610 | ✗ | GTEST_SKIP() << "GEMV requires M=1"; | |
| 611 | return; | ||
| 612 | } | ||
| 613 |
2/4✓ Branch 0 taken 140 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 140 times.
|
140 | if (p.rect.height() != 1 || p.rect.start_row() != 0) { |
| 614 | ✗ | GTEST_SKIP() << "GEMV portion invalid, rect height != 1 or start_row != 0"; | |
| 615 | return; | ||
| 616 | } | ||
| 617 | 140 | } | |
| 618 | 16088 | } | |
| 619 | |||
| 620 | template <size_t ArrN> | ||
| 621 | 15304 | void SetupCommon( | |
| 622 | const std::array< | ||
| 623 | UkernelMatmulPackVariant< | ||
| 624 | kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel, kai_qai8dxp_pack_functions, kai_qsi4c32p_pack_functions>, | ||
| 625 | ArrN>& variants) { | ||
| 626 | 153040 | const auto& [variant_index, is_gemm, shape, bl, portion, rhs_dir, clamp_keep_ratio] = GetParam(); | |
| 627 | 30608 | const auto& variant = variants.at(variant_index); | |
| 628 | |||
| 629 | 15304 | params.variant = &variant; | |
| 630 | 15304 | params.variant_index = variant_index; | |
| 631 | |||
| 632 | // Compute aligned portion rect once | ||
| 633 | 15304 | const size_t m_step = variant.ukernel.interface.get_m_step(); | |
| 634 | 15304 | const size_t n_step = variant.ukernel.interface.get_n_step(); | |
| 635 | 45912 | const Rect rect = portion.compute_portion(shape.m, shape.n, m_step, n_step); | |
| 636 | |||
| 637 | 15304 | params.matmul_shape = shape; | |
| 638 | 15304 | params.bl = bl; | |
| 639 | 15304 | params.portion = portion; | |
| 640 | 15304 | params.rhs_pack_type = rhs_dir; | |
| 641 | 15304 | params.rect = rect; | |
| 642 | 15304 | params.clamp_keep_ratio = clamp_keep_ratio; | |
| 643 | 15304 | params.is_sme2 = | |
| 644 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 15164 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 140 times.
|
15304 | (variant.ukernel.name.data() != nullptr && std::strstr(variant.ukernel.name.data(), "sme2") != nullptr); |
| 645 | 15304 | } | |
| 646 | }; | ||
| 647 | |||
| 648 | using F32QMatMulRefKey = std::tuple< | ||
| 649 | MatMulShape, // shape | ||
| 650 | size_t, // bl | ||
| 651 | size_t, // mr | ||
| 652 | size_t, // kr | ||
| 653 | size_t, // sr | ||
| 654 | size_t, // rect_start_row | ||
| 655 | size_t, // rect_start_col | ||
| 656 | size_t, // rect_height | ||
| 657 | size_t, // rect_width | ||
| 658 | RhsPackType, // rhs_pack_type | ||
| 659 | int, // clamp_pct | ||
| 660 | const void* // lhs_pack_key | ||
| 661 | >; | ||
| 662 | |||
| 663 | template <> | ||
| 664 | 942 | TestData ReferenceGenerator<F32QMatMulRefKey, TestData>::generate_reference(const F32QMatMulRefKey& test_id) { | |
| 665 | 1884 | TestData ref{}; | |
| 666 | |||
| 667 | 8478 | const auto& [shape, bl, mr, kr, sr, rect_start_row, rect_start_col, rect_height, rect_width, rhs_pack_type, clamp_pct, lhs_pack_key] = | |
| 668 | 942 | test_id; | |
| 669 | KAI_UNUSED(lhs_pack_key); | ||
| 670 | 1884 | const float clamp_keep_ratio = static_cast<float>(clamp_pct) / 100.0F; | |
| 671 |
5/10✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 942 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 942 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 942 times.
✗ Branch 9 not taken.
|
4710 | const Rect rect(rect_start_row, rect_start_col, rect_height, rect_width); |
| 672 | |||
| 673 | 942 | ref.M = shape.m; | |
| 674 | 942 | ref.N = shape.n; | |
| 675 | 942 | ref.K = shape.k; | |
| 676 | 942 | ref.bl = bl; | |
| 677 | 942 | ref.rect = rect; | |
| 678 | |||
| 679 | // Creates a unique seed for the test data. | ||
| 680 |
8/16✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 942 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 942 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 942 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 942 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 942 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 942 times.
✗ Branch 15 not taken.
|
2826 | const auto key = std::string("F32QMatMulRefKey:") + std::to_string(ref.M) + "x" + std::to_string(ref.N) + "x" + |
| 681 |
10/18✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 942 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 942 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 942 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 942 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 462 times.
✓ Branch 13 taken 480 times.
✓ Branch 14 taken 942 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 942 times.
✗ Branch 17 not taken.
|
1884 | std::to_string(ref.K) + "_" + std::to_string(bl) + "_" + ((rhs_pack_type == RhsPackType::NxK) ? "NxK" : "KxN") + |
| 682 |
2/4✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
|
942 | "_" + std::to_string(clamp_keep_ratio); |
| 683 |
1/2✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
|
942 | auto& feed = seed_stream(key); |
| 684 | |||
| 685 |
2/4✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
|
942 | ref.lhs = fill_random<float>(ref.M * ref.K, feed()); |
| 686 |
2/4✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
|
942 | ref.rhs = fill_random<float>(ref.N * ref.K, feed()); |
| 687 |
2/4✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 942 times.
✗ Branch 3 not taken.
|
942 | ref.bias = fill_random<float>(ref.N, feed()); |
| 688 | |||
| 689 | // Dynamic LHS quantization (reference only). | ||
| 690 |
1/2✓ Branch 0 taken 942 times.
✗ Branch 1 not taken.
|
942 | QuantizationInfo lhs_qinfo{}; |
| 691 | lhs_qinfo.quant_width = ref.K; | ||
| 692 | lhs_qinfo.dst_type = DataType::QAI8; | ||
| 693 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 694 | lhs_qinfo.zero_point_type = DataType::I32; | ||
| 695 | auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref.lhs.data(), DataType::FP32, ref.M, ref.K, lhs_qinfo); | ||
| 696 | |||
| 697 | // Dynamic RHS quantization to QSI4 with BF16 block scales. | ||
| 698 | QuantizationInfo rhs_qinfo{}; | ||
| 699 | rhs_qinfo.quant_width = bl; | ||
| 700 | rhs_qinfo.dst_type = DataType::QSI4; | ||
| 701 | rhs_qinfo.scale_type = DataType::BF16; | ||
| 702 | auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref.rhs.data(), DataType::FP32, ref.N, ref.K, rhs_qinfo); | ||
| 703 | |||
| 704 | ref.rhs_quant = std::move(ref_rhs_quant); | ||
| 705 | ref.rhs_scales = std::move(rhs_qoutputs.scales); | ||
| 706 | |||
| 707 | const bool transposed = (rhs_pack_type == RhsPackType::NxK); | ||
| 708 | const size_t width = transposed ? ref.K : ref.N; | ||
| 709 | const size_t height = transposed ? ref.N : ref.K; | ||
| 710 | |||
| 711 | const size_t qsi4_stride = round_up_multiple(width, 2); | ||
| 712 | const size_t qsi4_size_bytes = round_up_division(height * qsi4_stride, 2); | ||
| 713 | |||
| 714 | if (!transposed) { | ||
| 715 | ref.rhs_quant = | ||
| 716 | transpose_with_padding<Int4>(ref.rhs_quant.data(), ref.N, ref.K, ref.K, qsi4_stride, qsi4_size_bytes); | ||
| 717 | } | ||
| 718 | |||
| 719 | Buffer ref_dst_noclamp; | ||
| 720 | if (transposed) { | ||
| 721 | ref_dst_noclamp = | ||
| 722 | matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( | ||
| 723 | ref.M, ref.N, ref.K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), | ||
| 724 | 1, ref.K, ref.rhs_quant.data(), ref.rhs_scales.data(), nullptr, 1, bl, ref.bias.data(), nullptr, | ||
| 725 | nullptr, 1); | ||
| 726 | } else { | ||
| 727 | ref_dst_noclamp = matmul_nt_nt_quantized< | ||
| 728 | int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( | ||
| 729 | ref.M, ref.N, ref.K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), 1, | ||
| 730 | ref.K, ref.rhs_quant.data(), ref.rhs_scales.data(), nullptr, 1, bl, ref.bias.data(), nullptr, nullptr, 1); | ||
| 731 | } | ||
| 732 | |||
| 733 | const auto [cmin, cmax] = find_clamp_range<float>(ref_dst_noclamp.data(), ref.M * ref.N, clamp_keep_ratio); | ||
| 734 | ref.clamp = {cmin, cmax}; | ||
| 735 | ref.ref_dst_clamped = clamp<float>(ref_dst_noclamp.data(), ref.M * ref.N, cmin, cmax); | ||
| 736 | |||
| 737 | // Pack LHS once for this key. | ||
| 738 | const size_t lhs_stride_bytes = ref.K * sizeof(float); | ||
| 739 | constexpr kai_qai8dxp_pack_functions lhs_iface{ | ||
| 740 | kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32, | ||
| 741 | kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32, | ||
| 742 | kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32, | ||
| 743 | kai_run_lhs_quant_pack_qai8dxp_f32, | ||
| 744 | }; | ||
| 745 | |||
| 746 | auto [lhs_packed, lhs_packed_offset] = pack_lhs_qai8dxp( | ||
| 747 | lhs_iface, ref.M, ref.K, mr, kr, sr, ref.lhs, lhs_stride_bytes, rect.start_row(), rect.height()); | ||
| 748 | |||
| 749 | ref.lhs_packed = std::move(lhs_packed); | ||
| 750 | ref.lhs_packed_offset = lhs_packed_offset; | ||
| 751 | |||
| 752 | return ref; | ||
| 753 | ✗ | } | |
| 754 | |||
| 755 | 48264 | [[maybe_unused]] static void PrintTo(const QMatmulClampF32ParamT& param, std::ostream* os) { | |
| 756 | 193056 | const auto& [variant_idx, is_gemm, shape, bl, portion, rhs_pack_type, clamp_keep_ratio] = param; | |
| 757 |
1/2✓ Branch 0 taken 32176 times.
✗ Branch 1 not taken.
|
96528 | const auto name = std::string( |
| 758 |
2/2✓ Branch 0 taken 47424 times.
✓ Branch 1 taken 840 times.
|
48264 | (is_gemm ? get_f32_gemm_variants().at(variant_idx).ukernel.name |
| 759 | 1680 | : get_f32_gemv_variants().at(variant_idx).ukernel.name)); | |
| 760 |
5/10✓ Branch 0 taken 48264 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 48264 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 48264 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 48264 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 48264 times.
✗ Branch 9 not taken.
|
193056 | *os << test_description(name, rhs_pack_type, shape, bl, portion, clamp_keep_ratio); |
| 761 | 48264 | } | |
| 762 | |||
| 763 | 2488 | const TestData& QMatMulClampF32Test::test_data() { | |
| 764 | 2488 | const auto& param = GetParam(); | |
| 765 | 2488 | const size_t variant_index = std::get<0>(param); | |
| 766 | 2488 | const bool is_gemm = std::get<1>(param); | |
| 767 | 2488 | const MatMulShape& shape = std::get<2>(param); | |
| 768 | 2488 | const size_t bl = std::get<3>(param); | |
| 769 | 2488 | const MatrixPortion& portion = std::get<4>(param); | |
| 770 | 2488 | const RhsPackType rhs_pack_type = std::get<5>(param); | |
| 771 | 2488 | const float clamp_keep_ratio = std::get<6>(param); | |
| 772 | |||
| 773 | 4976 | const auto& variant = | |
| 774 |
2/2✓ Branch 0 taken 2468 times.
✓ Branch 1 taken 20 times.
|
2488 | is_gemm ? get_f32_gemm_variants().at(variant_index) : get_f32_gemv_variants().at(variant_index); |
| 775 | 2488 | const auto& iface = variant.ukernel.interface; | |
| 776 | |||
| 777 | 2488 | const size_t mr = iface.get_mr(); | |
| 778 | 2488 | const size_t kr = iface.get_kr(); | |
| 779 | 2488 | const size_t sr = iface.get_sr(); | |
| 780 | 2488 | const size_t m_step = iface.get_m_step(); | |
| 781 | 2488 | const size_t n_step = iface.get_n_step(); | |
| 782 | 2488 | const Rect rect = portion.compute_portion(shape.m, shape.n, m_step, n_step); | |
| 783 | |||
| 784 | 2488 | const int clamp_pct = static_cast<int>(clamp_keep_ratio * 100 + 0.5F); | |
| 785 | |||
| 786 | 4976 | const F32QMatMulRefKey key{ | |
| 787 | 2488 | shape, | |
| 788 | bl, | ||
| 789 | mr, | ||
| 790 | kr, | ||
| 791 | sr, | ||
| 792 | 2488 | rect.start_row(), | |
| 793 | 2488 | rect.start_col(), | |
| 794 | 2488 | rect.height(), | |
| 795 | 2488 | rect.width(), | |
| 796 | rhs_pack_type, | ||
| 797 | clamp_pct, | ||
| 798 | 2488 | reinterpret_cast<const void*>(variant.lhs_pack_interface.run_pack)}; | |
| 799 | |||
| 800 | 4976 | return getV<F32QMatMulRefKey, TestData>(key); | |
| 801 | 2488 | } | |
| 802 | |||
| 803 | using MatMulTestParams_withBL_withRHSPackType = | ||
| 804 | std::tuple<size_t, MatMulShape, size_t, MatrixPortion, RhsPackType, float>; | ||
| 805 | |||
| 806 | 576 | [[maybe_unused]] static void PrintTo(const MatMulTestParams_withBL_withRHSPackType& param, std::ostream* os) { | |
| 807 | 576 | const size_t variant_idx = std::get<0>(param); | |
| 808 | 576 | const MatMulShape shape = std::get<1>(param); | |
| 809 | 576 | const size_t bl = std::get<2>(param); | |
| 810 | 576 | const MatrixPortion portion = std::get<3>(param); | |
| 811 | 576 | const RhsPackType rhs_pack_type = std::get<4>(param); | |
| 812 |
1/2✓ Branch 0 taken 384 times.
✗ Branch 1 not taken.
|
576 | const std::string name{get_bf16_gemm_variants().at(variant_idx).name}; |
| 813 | 576 | const float clamp_keep_ratio = std::get<5>(param); | |
| 814 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | *os << test_description(name, rhs_pack_type, shape, bl, portion, clamp_keep_ratio); |
| 815 | 576 | } | |
| 816 | |||
| 817 | template <> | ||
| 818 | 192 | BF16TestData ReferenceGenerator<BF16QMatMulRefKey, BF16TestData>::generate_reference(const BF16QMatMulRefKey& test_id) { | |
| 819 | 576 | BF16TestData ref{}; | |
| 820 | |||
| 821 | 192 | const MatMulShape shape = std::get<0>(test_id); | |
| 822 | 192 | const size_t bl = std::get<1>(test_id); | |
| 823 | 192 | const size_t mr = std::get<2>(test_id); | |
| 824 | 192 | const size_t nr = std::get<3>(test_id); | |
| 825 | KAI_UNUSED(nr); | ||
| 826 | 192 | const size_t kr = std::get<4>(test_id); | |
| 827 | 192 | const size_t sr = std::get<5>(test_id); | |
| 828 | 192 | const size_t rect_start_row = std::get<6>(test_id); | |
| 829 | 192 | const size_t rect_start_col = std::get<7>(test_id); | |
| 830 | 192 | const size_t rect_height = std::get<8>(test_id); | |
| 831 | 192 | const size_t rect_width = std::get<9>(test_id); | |
| 832 | 192 | const RhsPackType rhs_pack_type = std::get<10>(test_id); | |
| 833 | 192 | const float clamp_keep_ratio = std::get<11>(test_id); | |
| 834 | |||
| 835 | 192 | ref.M = shape.m; | |
| 836 | 192 | ref.N = shape.n; | |
| 837 | 192 | ref.K = shape.k; | |
| 838 | 192 | ref.bl = bl; | |
| 839 |
1/2✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
|
192 | ref.rect = Rect(rect_start_row, rect_start_col, rect_height, rect_width); |
| 840 | |||
| 841 | // Creates a unique seed for the test data. | ||
| 842 |
8/16✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 192 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 192 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 192 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 192 times.
✗ Branch 15 not taken.
|
576 | const auto key = std::string("BF16QMatMulRefKey:") + std::to_string(ref.M) + "x" + std::to_string(ref.N) + "x" + |
| 843 |
9/16✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 192 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 96 times.
✓ Branch 11 taken 96 times.
✓ Branch 12 taken 192 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 192 times.
✗ Branch 15 not taken.
|
384 | std::to_string(ref.K) + "_" + std::to_string(bl) + "_" + ((rhs_pack_type == RhsPackType::NxK) ? "NxK" : "KxN") + |
| 844 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
192 | "_" + std::to_string(clamp_keep_ratio); |
| 845 |
1/2✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
|
192 | auto& feed = seed_stream(key); |
| 846 | |||
| 847 | // Inputs | ||
| 848 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
192 | ref.lhs_bf16 = fill_random<BFloat16<false>>(ref.M * ref.K, feed()); |
| 849 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
192 | Buffer const ref_rhs = fill_random<float>(ref.N * ref.K, feed()); |
| 850 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
192 | ref.bias = fill_random<float>(ref.N, feed()); |
| 851 | |||
| 852 | // Cast BF16 LHS to FP32 for reference quantization | ||
| 853 | 192 | const Buffer ref_lhs = | |
| 854 |
1/2✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
|
192 | cast<float, BFloat16<false>>(ref.lhs_bf16.data(), ref.lhs_bf16.size() * 8 / size_in_bits<BFloat16<false>>); |
| 855 | |||
| 856 | // Reference quantizations for LHS and RHS | ||
| 857 |
1/2✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
|
192 | QuantizationInfo lhs_qinfo{}; |
| 858 | lhs_qinfo.quant_width = ref.K; | ||
| 859 | lhs_qinfo.dst_type = DataType::QAI8; | ||
| 860 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 861 | lhs_qinfo.zero_point_type = DataType::I32; | ||
| 862 | auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, ref.M, ref.K, lhs_qinfo); | ||
| 863 | |||
| 864 | QuantizationInfo rhs_qinfo{}; | ||
| 865 | rhs_qinfo.quant_width = bl; | ||
| 866 | rhs_qinfo.dst_type = DataType::QSI4; | ||
| 867 | rhs_qinfo.scale_type = DataType::BF16; | ||
| 868 | auto [ref_rhs_quant, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, ref.N, ref.K, rhs_qinfo); | ||
| 869 | |||
| 870 | // Prepare RHS layout per pack type | ||
| 871 | const bool transposed = (rhs_pack_type == RhsPackType::NxK); | ||
| 872 | const size_t width = transposed ? ref.K : ref.N; | ||
| 873 | const size_t height = transposed ? ref.N : ref.K; | ||
| 874 | |||
| 875 | const size_t qsi4_stride = round_up_multiple(width, 2); | ||
| 876 | const size_t qsi4_size_bytes = round_up_division(height * qsi4_stride, 2); | ||
| 877 | |||
| 878 | ref.rhs_quant = std::move(ref_rhs_quant); | ||
| 879 | if (!transposed) { | ||
| 880 | ref.rhs_quant = | ||
| 881 | transpose_with_padding<Int4>(ref.rhs_quant.data(), ref.N, ref.K, ref.K, qsi4_stride, qsi4_size_bytes); | ||
| 882 | } | ||
| 883 | ref.rhs_scales = std::move(rhs_qoutputs.scales); | ||
| 884 | |||
| 885 | // Compute reference destination (float), clamp, and cast to BF16 | ||
| 886 | Buffer ref_dst_noclamp; | ||
| 887 | if (transposed) { | ||
| 888 | ref_dst_noclamp = | ||
| 889 | matmul_nt_t_quantized<int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( | ||
| 890 | ref.M, ref.N, ref.K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), | ||
| 891 | 1, ref.K, ref.rhs_quant.data(), ref.rhs_scales.data(), nullptr, 1, bl, ref.bias.data(), nullptr, | ||
| 892 | nullptr, 1); | ||
| 893 | } else { | ||
| 894 | ref_dst_noclamp = matmul_nt_nt_quantized< | ||
| 895 | int8_t, float, int32_t, Int4, BFloat16<false>, int32_t, float, float, int32_t, float>( | ||
| 896 | ref.M, ref.N, ref.K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), 1, | ||
| 897 | ref.K, ref.rhs_quant.data(), ref.rhs_scales.data(), nullptr, 1, bl, ref.bias.data(), nullptr, nullptr, 1); | ||
| 898 | } | ||
| 899 | |||
| 900 | const auto [clamp_min, clamp_max] = | ||
| 901 | find_clamp_range<float>(ref_dst_noclamp.data(), ref.M * ref.N, clamp_keep_ratio); | ||
| 902 | ref.clamp = {clamp_min, clamp_max}; | ||
| 903 | const Buffer ref_dst_float = clamp<float>(ref_dst_noclamp.data(), ref.M * ref.N, clamp_min, clamp_max); | ||
| 904 | ref.ref_dst_bf16 = | ||
| 905 | cast<BFloat16<false>, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>); | ||
| 906 | |||
| 907 | // Pack LHS once (BF16 packer) | ||
| 908 | const size_t imp_packed_lhs_size = | ||
| 909 | kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(ref.M, ref.K, mr, kr, sr); | ||
| 910 | ref.lhs_packed = Buffer(imp_packed_lhs_size); | ||
| 911 | |||
| 912 | const size_t lhs_stride = ref.K * sizeof(uint16_t); | ||
| 913 | const size_t lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(rect_start_row, lhs_stride); | ||
| 914 | ref.lhs_packed_offset = | ||
| 915 | kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(rect_start_row, ref.K, mr, kr, sr); | ||
| 916 | |||
| 917 | kai_run_lhs_quant_pack_qai8dxp_bf16_neon( | ||
| 918 | rect_height, ref.K, mr, kr, sr, 0, ref.lhs_bf16.data() + lhs_offset, lhs_stride, | ||
| 919 | reinterpret_cast<uint8_t*>(ref.lhs_packed.data()) + ref.lhs_packed_offset); | ||
| 920 | |||
| 921 | return ref; | ||
| 922 | ✗ | } | |
| 923 | |||
| 924 | /// Verifies RHS packed offsets (KxN vs NxK) match each other and the matmul interface at n_step. | ||
| 925 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
5514 | TEST_P(QMatMulClampF32Test, OffsetRHS) { |
| 926 | 2136 | const auto& p = GetParams(); | |
| 927 | 2136 | const auto fn_supported = p.variant->ukernel.fn_is_supported; | |
| 928 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
|
2136 | if (fn_supported && !fn_supported()) { |
| 929 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
| 930 | return; | ||
| 931 | } | ||
| 932 | |||
| 933 | 2136 | const auto& ukernel = p.variant->ukernel; | |
| 934 | 2136 | const size_t K = p.matmul_shape.k; | |
| 935 | 2136 | const size_t bl = p.bl; | |
| 936 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto nr = ukernel.interface.get_nr(); |
| 937 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto kr = ukernel.interface.get_kr(); |
| 938 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto sr = ukernel.interface.get_sr(); |
| 939 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto n_step = ukernel.interface.get_n_step(); |
| 940 | |||
| 941 | 2136 | const auto rhs_packed_offset_kxn = | |
| 942 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n_step, K, nr, kr, sr, bl, kai_dt_bf16); |
| 943 |
2/4✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
|
4272 | const auto rhs_packed_offset_kxn_ps1s0 = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon( |
| 944 | 2136 | n_step, K, nr, kr, sr, bl, kai_dt_bf16); | |
| 945 | 2136 | const auto rhs_packed_offset_nxk = | |
| 946 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n_step, K, nr, kr, sr, bl, kai_dt_bf16); |
| 947 | 2136 | const auto rhs_packed_offset_nxk_ps1s0_nrx4 = | |
| 948 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32ps1s0nrx4_qsu4c32s1s0_neon( |
| 949 | 2136 | n_step, K, nr, kr, sr, bl, kai_dt_bf16); | |
| 950 | |||
| 951 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(rhs_packed_offset_kxn, rhs_packed_offset_kxn_ps1s0); |
| 952 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(rhs_packed_offset_kxn_ps1s0, rhs_packed_offset_nxk); |
| 953 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(rhs_packed_offset_nxk, rhs_packed_offset_nxk_ps1s0_nrx4); |
| 954 | |||
| 955 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto rhs_matmul_offset = ukernel.interface.get_rhs_packed_offset(n_step, K, bl); |
| 956 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(rhs_packed_offset_kxn, rhs_matmul_offset); |
| 957 | 2136 | } | |
| 958 | |||
| 959 | /// Verifies LHS packed offset matches the matmul interface at m_step. | ||
| 960 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
5514 | TEST_P(QMatMulClampF32Test, OffsetLHS) { |
| 961 | 2136 | const auto& p = GetParams(); | |
| 962 | 2136 | const auto fn_supported = p.variant->ukernel.fn_is_supported; | |
| 963 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
|
2136 | if (fn_supported && !fn_supported()) { |
| 964 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
| 965 | return; | ||
| 966 | } | ||
| 967 | |||
| 968 | 2136 | const auto& ukernel = p.variant->ukernel; | |
| 969 | 2136 | const size_t K = p.matmul_shape.k; | |
| 970 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto mr = ukernel.interface.get_mr(); |
| 971 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto kr = ukernel.interface.get_kr(); |
| 972 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto sr = ukernel.interface.get_sr(); |
| 973 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto m_step = ukernel.interface.get_m_step(); |
| 974 | |||
| 975 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(m_step, K, mr, kr, sr); |
| 976 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto lhs_matmul_offset = ukernel.interface.get_lhs_packed_offset(m_step, K); |
| 977 | |||
| 978 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
| 979 | 2136 | } | |
| 980 | |||
| 981 | /// Verifies that the kernel's get_dst_offset computes row/col addressing correctly at tile-aligned starts. | ||
| 982 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
5514 | TEST_P(QMatMulClampF32Test, OffsetDst) { |
| 983 | 2136 | const auto& p = GetParams(); | |
| 984 | 2136 | const auto fn_supported = p.variant->ukernel.fn_is_supported; | |
| 985 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
|
2136 | if (fn_supported && !fn_supported()) { |
| 986 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
| 987 | return; | ||
| 988 | } | ||
| 989 | |||
| 990 | 2136 | const auto& ukernel = p.variant->ukernel; | |
| 991 | 2136 | const size_t M = p.matmul_shape.m; | |
| 992 | 2136 | const size_t N = p.matmul_shape.n; | |
| 993 | |||
| 994 | 2136 | const auto dst_stride_row = N * sizeof(float); | |
| 995 | 2136 | constexpr auto dst_stride_col = sizeof(float); | |
| 996 | |||
| 997 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto m_step = ukernel.interface.get_m_step(); |
| 998 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto n_step = ukernel.interface.get_n_step(); |
| 999 | |||
| 1000 |
5/18✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
|
2136 | ASSERT_TRUE(m_step % ukernel.interface.get_mr() == 0); |
| 1001 |
5/18✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
|
2136 | ASSERT_TRUE(n_step % ukernel.interface.get_nr() == 0); |
| 1002 | |||
| 1003 |
2/2✓ Branch 0 taken 2116 times.
✓ Branch 1 taken 20 times.
|
2136 | const size_t m_idx = (M > m_step) ? m_step : 0; |
| 1004 |
2/2✓ Branch 0 taken 2094 times.
✓ Branch 1 taken 42 times.
|
2136 | const size_t n_idx = (N > n_step) ? n_step : 0; |
| 1005 | |||
| 1006 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto off00 = ukernel.interface.get_dst_offset(0, 0, dst_stride_row); |
| 1007 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(off00, 0U); |
| 1008 | |||
| 1009 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto off10 = ukernel.interface.get_dst_offset(m_idx, 0, dst_stride_row); |
| 1010 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(off10, m_idx * dst_stride_row); |
| 1011 | |||
| 1012 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto off01 = ukernel.interface.get_dst_offset(0, n_idx, dst_stride_row); |
| 1013 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(off01, n_idx * dst_stride_col); |
| 1014 | |||
| 1015 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto off11 = ukernel.interface.get_dst_offset(m_idx, n_idx, dst_stride_row); |
| 1016 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(off11, m_idx * dst_stride_row + n_idx * dst_stride_col); |
| 1017 | 2136 | } | |
| 1018 | |||
| 1019 | /// Sanity-checks kernel interface parameters (mr/nr/kr/sr and step alignment). | ||
| 1020 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
5514 | TEST_P(QMatMulClampF32Test, KernelInvariants) { |
| 1021 | 2136 | const auto& p = GetParams(); | |
| 1022 | 2136 | const auto fn_supported = p.variant->ukernel.fn_is_supported; | |
| 1023 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
|
2136 | if (fn_supported && !fn_supported()) { |
| 1024 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
| 1025 | return; | ||
| 1026 | } | ||
| 1027 | |||
| 1028 | 2136 | const auto& ukernel = p.variant->ukernel; | |
| 1029 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto mr = ukernel.interface.get_mr(); |
| 1030 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto nr = ukernel.interface.get_nr(); |
| 1031 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto kr = ukernel.interface.get_kr(); |
| 1032 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto sr = ukernel.interface.get_sr(); |
| 1033 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto m_step = ukernel.interface.get_m_step(); |
| 1034 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto n_step = ukernel.interface.get_n_step(); |
| 1035 | |||
| 1036 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_GT(mr, 0U); |
| 1037 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_GT(nr, 0U); |
| 1038 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_GT(kr, 0U); |
| 1039 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_GT(sr, 0U); |
| 1040 | |||
| 1041 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(m_step % mr, 0U); |
| 1042 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(n_step % nr, 0U); |
| 1043 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(kr % sr, 0U); |
| 1044 | 2136 | } | |
| 1045 | |||
| 1046 | /// Verifies RHS row stride using difference of offsets equals the layout formula. | ||
| 1047 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
5514 | TEST_P(QMatMulClampF32Test, RhsStrideByDifference) { |
| 1048 | 2136 | const auto& p = GetParams(); | |
| 1049 | 2136 | const auto fn_supported = p.variant->ukernel.fn_is_supported; | |
| 1050 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
|
2136 | if (fn_supported && !fn_supported()) { |
| 1051 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
| 1052 | return; | ||
| 1053 | } | ||
| 1054 | |||
| 1055 | 2136 | const auto& ukernel = p.variant->ukernel; | |
| 1056 | 2136 | const size_t K = p.matmul_shape.k; | |
| 1057 | 2136 | const size_t bl = p.bl; | |
| 1058 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto nr = ukernel.interface.get_nr(); |
| 1059 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto n_step = ukernel.interface.get_n_step(); |
| 1060 | |||
| 1061 | // Stride by difference using kernel offsets at 0 and n_step. | ||
| 1062 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const size_t off0 = ukernel.interface.get_rhs_packed_offset(0, K, bl); |
| 1063 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const size_t off1 = ukernel.interface.get_rhs_packed_offset(n_step, K, bl); |
| 1064 | 2136 | const size_t stride_by_diff = off1 - off0; | |
| 1065 | |||
| 1066 | // Expected stride formula for qsi4c32p with BF16 scales: | ||
| 1067 | // nr * ( num_blocks * (bl/2 + 2) + 4 /*rsum*/ + 4 /*bias*/ ) | ||
| 1068 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const size_t k_internal = round_up_multiple(K, 32); |
| 1069 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const size_t num_blocks = round_up_division(k_internal, bl); |
| 1070 | 2136 | const size_t bytes_per_block = (bl / 2) + 2; // int4 values + BF16 scale | |
| 1071 | 2136 | const size_t expected_stride = nr * (num_blocks * bytes_per_block) + nr * 4 + nr * 4; | |
| 1072 | |||
| 1073 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_EQ(stride_by_diff, expected_stride); |
| 1074 | 2136 | } | |
| 1075 | |||
| 1076 | /// Validates the packed LHS group slice against a reconstructed reference. | ||
| 1077 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
5514 | TEST_P(QMatMulClampF32Test, LhsPackBufferMatchesReference) { |
| 1078 | 2136 | const auto& p = GetParams(); | |
| 1079 | 2136 | const auto fn_supported = p.variant->ukernel.fn_is_supported; | |
| 1080 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
|
2136 | if (fn_supported && !fn_supported()) { |
| 1081 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
| 1082 | return; | ||
| 1083 | } | ||
| 1084 | 2136 | const auto& uk = p.variant->ukernel; | |
| 1085 | |||
| 1086 | 2136 | const size_t M = p.matmul_shape.m; | |
| 1087 | 2136 | const size_t K = p.matmul_shape.k; | |
| 1088 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const size_t mr = uk.interface.get_mr(); |
| 1089 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const size_t kr = uk.interface.get_kr(); |
| 1090 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const size_t sr = uk.interface.get_sr(); |
| 1091 | |||
| 1092 | 2136 | const size_t k_block_len = kr / sr; | |
| 1093 | 2136 | const size_t k_internal = ((K + 31) / 32) * 32; | |
| 1094 | |||
| 1095 | 2136 | const size_t i8_region_bytes = mr * k_internal; | |
| 1096 | 2136 | const size_t neg_zero_point_region_bytes = mr * sizeof(int32_t); | |
| 1097 | 2136 | const size_t recip_scale_region_bytes = mr * sizeof(float); | |
| 1098 | 2136 | const size_t group_stride = i8_region_bytes + neg_zero_point_region_bytes + recip_scale_region_bytes; | |
| 1099 | |||
| 1100 | 2136 | constexpr size_t rect_start_row = 0; | |
| 1101 | 2136 | constexpr size_t rect_height = 1; | |
| 1102 | |||
| 1103 |
4/8✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
|
2136 | const auto ref_lhs = fill_random<float>(M * K, seed_stream(current_test_key())()); |
| 1104 | |||
| 1105 | 2136 | const size_t lhs_stride = K * sizeof(float); | |
| 1106 |
2/4✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
|
4272 | std::tuple<Buffer, size_t> pack_pair = pack_lhs_qai8dxp( |
| 1107 | 2136 | p.variant->lhs_pack_interface, M, K, mr, kr, sr, ref_lhs, lhs_stride, rect_start_row, rect_height); | |
| 1108 | |||
| 1109 | 2136 | Buffer const lhs_packed = std::move(std::get<0>(pack_pair)); | |
| 1110 | 2136 | const size_t lhs_packed_off = std::get<1>(pack_pair); | |
| 1111 | |||
| 1112 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | QuantizationInfo lhs_qinfo{}; |
| 1113 | lhs_qinfo.quant_width = K; | ||
| 1114 | lhs_qinfo.dst_type = DataType::QAI8; | ||
| 1115 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 1116 | lhs_qinfo.zero_point_type = DataType::I32; | ||
| 1117 | auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo); | ||
| 1118 | |||
| 1119 | Buffer const expected(group_stride, 0); | ||
| 1120 | std::byte* expected_bytes = expected.data(); | ||
| 1121 | |||
| 1122 | // Build reference layout into `expected` | ||
| 1123 | constexpr size_t lane_row_idx = rect_start_row; | ||
| 1124 | const size_t lane = lane_row_idx % mr; | ||
| 1125 | const size_t ref_row_base = lane_row_idx * K; | ||
| 1126 | const auto pad_val = read_array<int8_t>(ref_lhs_quant.data(), ref_row_base + (K - 1)); | ||
| 1127 | |||
| 1128 | size_t ref_idx = 0; | ||
| 1129 | const size_t num_blocks_internal = k_internal / k_block_len; | ||
| 1130 | |||
| 1131 | for (size_t b = 0; b < num_blocks_internal; ++b) { | ||
| 1132 | const size_t block_base = b * mr * k_block_len; | ||
| 1133 | const size_t lane_offset = block_base + lane * k_block_len; | ||
| 1134 | |||
| 1135 | for (size_t i = 0; i < k_block_len; ++i) { | ||
| 1136 | const size_t dst_index = lane_offset + i; | ||
| 1137 | const bool in_range = ref_idx < K; | ||
| 1138 | |||
| 1139 | const int8_t val = in_range ? read_array<int8_t>(ref_lhs_quant.data(), ref_row_base + ref_idx) : pad_val; | ||
| 1140 | |||
| 1141 | write_array<int8_t>(expected_bytes, dst_index, val); | ||
| 1142 | |||
| 1143 | if (in_range) { | ||
| 1144 | ++ref_idx; | ||
| 1145 | } | ||
| 1146 | } | ||
| 1147 | } | ||
| 1148 | |||
| 1149 | // Header (per-lane): neg_zero_point, recip_scale | ||
| 1150 | const size_t neg_zero_point_elem_base = i8_region_bytes / sizeof(int32_t); | ||
| 1151 | const size_t recip_scale_elem_base = (i8_region_bytes + neg_zero_point_region_bytes) / sizeof(float); | ||
| 1152 | |||
| 1153 | write_array<int32_t>( | ||
| 1154 | expected_bytes, neg_zero_point_elem_base + lane, | ||
| 1155 | -read_array<int32_t>(lhs_qoutputs.zero_points.data(), lane_row_idx)); | ||
| 1156 | |||
| 1157 | write_array<float>( | ||
| 1158 | expected_bytes, recip_scale_elem_base + lane, read_array<float>(lhs_qoutputs.scales.data(), lane_row_idx)); | ||
| 1159 | |||
| 1160 | // Validate packed buffer vs reference | ||
| 1161 | KAI_ASSUME_ALWAYS(lhs_packed_off + group_stride <= lhs_packed.size()); | ||
| 1162 | |||
| 1163 | // Int8 region: allow ±1 LSB | ||
| 1164 | for (size_t i = 0; i < i8_region_bytes; ++i) { | ||
| 1165 | const auto g = read_array<int8_t>(lhs_packed.data(), lhs_packed_off + i); | ||
| 1166 | const auto e = read_array<int8_t>(expected.data(), i); | ||
| 1167 | const int dq = static_cast<int>(g) - static_cast<int>(e); | ||
| 1168 | EXPECT_LE(std::abs(dq), 1) << "int8 mismatch at byte " << i << " (got=" << static_cast<int>(g) | ||
| 1169 | << ", exp=" << static_cast<int>(e) << ", dq=" << dq << ")"; | ||
| 1170 | } | ||
| 1171 | |||
| 1172 | // Region offsets (in bytes) | ||
| 1173 | const size_t neg_zero_point_offset = i8_region_bytes; | ||
| 1174 | const size_t recip_scale_offset = neg_zero_point_offset + neg_zero_point_region_bytes; | ||
| 1175 | |||
| 1176 | // neg_zero_point (exact) | ||
| 1177 | for (size_t hdr_lane = 0; hdr_lane < mr; ++hdr_lane) { | ||
| 1178 | const auto gzp = read_array<int32_t>( | ||
| 1179 | lhs_packed.data(), lhs_packed_off / sizeof(int32_t) + (neg_zero_point_offset / sizeof(int32_t)) + hdr_lane); | ||
| 1180 | const auto ezp = read_array<int32_t>(expected.data(), (neg_zero_point_offset / sizeof(int32_t)) + hdr_lane); | ||
| 1181 | EXPECT_EQ(gzp, ezp) << "neg_zp mismatch at lane " << hdr_lane; | ||
| 1182 | } | ||
| 1183 | |||
| 1184 | // recip_scale (near-equal) | ||
| 1185 | for (size_t hdr_lane = 0; hdr_lane < mr; ++hdr_lane) { | ||
| 1186 | const auto gsc = read_array<float>( | ||
| 1187 | lhs_packed.data(), lhs_packed_off / sizeof(float) + (recip_scale_offset / sizeof(float)) + hdr_lane); | ||
| 1188 | const auto esc = read_array<float>(expected.data(), (recip_scale_offset / sizeof(float)) + hdr_lane); | ||
| 1189 | EXPECT_NEAR(gsc, esc, 1e-5F) << "recip_scale mismatch at lane " << hdr_lane; | ||
| 1190 | } | ||
| 1191 | ✗ | } | |
| 1192 | |||
| 1193 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
5514 | TEST_P(QMatMulClampF32Test, EndToEnd) { |
| 1194 | 2136 | const auto& p = GetParams(); | |
| 1195 | 2136 | const auto fn_supported = p.variant->ukernel.fn_is_supported; | |
| 1196 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 2136 times.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
|
2136 | if (fn_supported && !fn_supported()) { |
| 1197 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
| 1198 | return; | ||
| 1199 | } | ||
| 1200 | 2136 | const auto& ukernel = p.variant->ukernel; | |
| 1201 | |||
| 1202 | 2136 | const size_t bl = p.bl; | |
| 1203 | 2136 | const RhsPackType rhs_pack_type = p.rhs_pack_type; | |
| 1204 | |||
| 1205 | − | KAI_ASSUME_ALWAYS(bl % 32 == 0); | |
| 1206 | |||
| 1207 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto nr = ukernel.interface.get_nr(); |
| 1208 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto kr = ukernel.interface.get_kr(); |
| 1209 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto sr = ukernel.interface.get_sr(); |
| 1210 | |||
| 1211 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto n_step = ukernel.interface.get_n_step(); |
| 1212 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_TRUE(n_step % nr == 0); |
| 1213 | |||
| 1214 | 2136 | const auto rect = p.rect; | |
| 1215 |
5/18✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
|
2136 | ASSERT_GT(rect.height(), 0U); |
| 1216 |
5/18✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
|
2136 | ASSERT_GT(rect.width(), 0U); |
| 1217 | |||
| 1218 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto& data = test_data(); |
| 1219 | |||
| 1220 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto rhs_start_col = rect.start_col(); |
| 1221 | 2136 | const size_t bias_offset_bytes = rhs_start_col * sizeof(float); | |
| 1222 | |||
| 1223 | 2136 | Buffer imp_packed_rhs; | |
| 1224 | 2136 | size_t rhs_packed_offset = 0; | |
| 1225 |
2/2✓ Branch 0 taken 1058 times.
✓ Branch 1 taken 1078 times.
|
2136 | if (rhs_pack_type == RhsPackType::NxK) { |
| 1226 |
1/2✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
|
1078 | const float* bias_ptr = reinterpret_cast<const float*>(data.bias.data()) + rhs_start_col; |
| 1227 |
1/2✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
|
1078 | std::tie(imp_packed_rhs, rhs_packed_offset) = pack_rhs_qsi4c32p_nxk( |
| 1228 | 1078 | p.variant->rhs_pack_interface, data.N, data.K, nr, kr, sr, bl, data.rhs_quant, bias_ptr, data.rhs_scales, | |
| 1229 |
1/2✓ Branch 0 taken 1078 times.
✗ Branch 1 not taken.
|
1078 | rhs_start_col, rect.width(), p.variant->rhs_s0s1_input); |
| 1230 | 1078 | } else { | |
| 1231 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1058 times.
|
1058 | if ((rhs_start_col % 2) != 0) { |
| 1232 | ✗ | GTEST_SKIP() << "KxN RHS pack requires even N-start index"; | |
| 1233 | return; | ||
| 1234 | } | ||
| 1235 |
1/2✓ Branch 0 taken 1058 times.
✗ Branch 1 not taken.
|
1058 | std::tie(imp_packed_rhs, rhs_packed_offset) = pack_rhs_qsi4c32p_kxn( |
| 1236 | 1058 | data.N, data.K, nr, kr, sr, bl, data.rhs_quant, data.bias, bias_offset_bytes, data.rhs_scales, | |
| 1237 |
1/2✓ Branch 0 taken 1058 times.
✗ Branch 1 not taken.
|
1058 | rhs_start_col, rect.width(), p.is_sme2); |
| 1238 | } | ||
| 1239 | |||
| 1240 |
5/18✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
|
2136 | ASSERT_EQ(rhs_packed_offset, ukernel.interface.get_rhs_packed_offset(rhs_start_col, data.K, bl)); |
| 1241 | |||
| 1242 | // Destination buffer and offsets | ||
| 1243 | 2136 | const auto dst_stride_row = data.N * sizeof(float); | |
| 1244 | 2136 | constexpr auto dst_stride_col = sizeof(float); | |
| 1245 |
2/4✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
|
2136 | const auto dst_offset = ukernel.interface.get_dst_offset(rect.start_row(), rhs_start_col, dst_stride_row); |
| 1246 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | const auto imp_dst_size = ukernel.interface.get_dst_size(data.M, data.N); |
| 1247 |
5/18✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2136 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2136 times.
|
2136 | ASSERT_EQ(imp_dst_size, data.ref_dst_clamped.size()); |
| 1248 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | Buffer const imp_dst(imp_dst_size); |
| 1249 | |||
| 1250 | // Run matmul | ||
| 1251 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | abi_check( |
| 1252 |
2/4✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
|
2136 | ukernel.interface.run_matmul, rect.height(), rect.width(), data.K, bl, |
| 1253 |
2/4✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
|
2136 | data.lhs_packed.data() + data.lhs_packed_offset, imp_packed_rhs.data() + rhs_packed_offset, |
| 1254 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | reinterpret_cast<float*>(imp_dst.data() + dst_offset), dst_stride_row, dst_stride_col, data.clamp.min, |
| 1255 | 2136 | data.clamp.max); | |
| 1256 | |||
| 1257 |
1/2✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
|
2136 | DefaultMismatchHandler handler(0, 0.1, 0, 0.05); |
| 1258 | 2136 | const auto dst_format = DataFormat(DataType::FP32); | |
| 1259 | 4272 | const auto success = | |
| 1260 |
3/6✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
|
2136 | compare(imp_dst.data(), data.ref_dst_clamped.data(), dst_format, data.M, data.N, rect, handler); |
| 1261 |
4/16✓ Branch 0 taken 2136 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2136 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2136 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2136 times.
|
2136 | ASSERT_TRUE(success); |
| 1262 | 2136 | } | |
| 1263 | |||
| 1264 | /// RHS vectorised packer format is s16s0 this is not relevant for sme2 kernels | ||
| 1265 | class NeonRhsPackF32Test : public QMatMulClampF32Test {}; | ||
| 1266 | |||
| 1267 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
886 | TEST_P(NeonRhsPackF32Test, EndToEndNeonRhsPack) { |
| 1268 | 352 | const auto& p = GetParams(); | |
| 1269 | 352 | const auto fn_supported = p.variant->ukernel.fn_is_supported; | |
| 1270 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 352 times.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
|
352 | if (fn_supported && !fn_supported()) { |
| 1271 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
| 1272 | return; | ||
| 1273 | } | ||
| 1274 | 352 | const auto& ukernel = p.variant->ukernel; | |
| 1275 | |||
| 1276 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | const size_t mr = ukernel.interface.get_mr(); |
| 1277 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | const size_t nr = ukernel.interface.get_nr(); |
| 1278 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | const size_t kr = ukernel.interface.get_kr(); |
| 1279 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | const size_t sr = ukernel.interface.get_sr(); |
| 1280 |
5/18✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 352 times.
|
352 | ASSERT_EQ(ukernel.interface.get_m_step() % mr, 0U); |
| 1281 |
5/18✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 352 times.
|
352 | ASSERT_EQ(ukernel.interface.get_n_step() % nr, 0U); |
| 1282 | |||
| 1283 |
4/6✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 128 times.
✓ Branch 3 taken 224 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 128 times.
|
352 | if (p.rhs_pack_type != RhsPackType::NxK || (kr / sr != 8 && kr / sr != 4)) { |
| 1284 | ✗ | GTEST_SKIP() << "RHS packers not applicable"; | |
| 1285 | } | ||
| 1286 |
5/18✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 352 times.
|
352 | ASSERT_GT(p.rect.height(), 0U); |
| 1287 |
5/18✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 352 times.
|
352 | ASSERT_GT(p.rect.width(), 0U); |
| 1288 | |||
| 1289 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | const auto& data = test_data(); |
| 1290 | |||
| 1291 | // LHS pack | ||
| 1292 | 352 | const size_t lhs_stride_bytes = data.K * sizeof(float); | |
| 1293 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
1056 | auto [imp_packed_lhs, lhs_packed_offset] = pack_lhs_qai8dxp( |
| 1294 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | p.variant->lhs_pack_interface, data.M, data.K, mr, kr, sr, data.lhs, lhs_stride_bytes, p.rect.start_row(), |
| 1295 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | p.rect.height()); |
| 1296 |
7/22✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 352 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 352 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 352 times.
|
704 | ASSERT_EQ(lhs_packed_offset, ukernel.interface.get_lhs_packed_offset(p.rect.start_row(), data.K)); |
| 1297 | |||
| 1298 | // RHS pack | ||
| 1299 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | const size_t rhs_start_row = p.rect.start_col(); |
| 1300 | 352 | const size_t bias_offset = rhs_start_row * sizeof(float); | |
| 1301 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
704 | const auto [imp_packed_rhs_neon, rhs_packed_offset_neon] = pack_rhs_qsi4c32pscalebf16_neon( |
| 1302 | 352 | data.N, data.K, nr, kr, sr, p.bl, data.rhs_quant, data.bias, bias_offset, data.rhs_scales, p.rhs_pack_type, | |
| 1303 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | rhs_start_row, p.rect.width()); |
| 1304 | |||
| 1305 |
6/20✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 352 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✓ Branch 19 taken 352 times.
|
704 | ASSERT_EQ(rhs_packed_offset_neon, ukernel.interface.get_rhs_packed_offset(rhs_start_row, data.K, p.bl)); |
| 1306 | |||
| 1307 | 352 | const auto dst_stride_row = data.N * sizeof(float); | |
| 1308 |
2/4✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
|
352 | Buffer const imp_dst(ukernel.interface.get_dst_size(data.M, data.N)); |
| 1309 |
2/4✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
|
352 | const auto dst_offset = ukernel.interface.get_dst_offset(p.rect.start_row(), rhs_start_row, dst_stride_row); |
| 1310 | |||
| 1311 | // Run matmul | ||
| 1312 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | abi_check( |
| 1313 |
2/4✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
|
352 | ukernel.interface.run_matmul, p.rect.height(), p.rect.width(), data.K, p.bl, |
| 1314 |
4/8✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
|
704 | imp_packed_lhs.data() + lhs_packed_offset, imp_packed_rhs_neon.data() + rhs_packed_offset_neon, |
| 1315 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | reinterpret_cast<float*>(imp_dst.data() + dst_offset), dst_stride_row, sizeof(float), data.clamp.min, |
| 1316 | 352 | data.clamp.max); | |
| 1317 | |||
| 1318 |
1/2✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
|
352 | DefaultMismatchHandler handler(0, 0.1, 0, 0.05); |
| 1319 | 352 | const DataFormat dst_format(DataType::FP32); | |
| 1320 |
7/22✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 352 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 352 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 352 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 352 times.
|
352 | ASSERT_TRUE(compare(imp_dst.data(), data.ref_dst_clamped.data(), dst_format, data.M, data.N, p.rect, handler)); |
| 1321 | 352 | } | |
| 1322 | |||
| 1323 | class QMatMulClampBF16Test : public ::testing::TestWithParam<MatMulTestParams_withBL_withRHSPackType> {}; | ||
| 1324 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
486 | TEST_P(QMatMulClampBF16Test, EndToEnd) { |
| 1325 | 2592 | const auto& [variant_index, matmul_shape, bl, portion, rhs_pack_type, clamp_keep_ratio] = GetParam(); | |
| 1326 | 384 | const auto& ukernel_variant = get_bf16_gemm_variants().at(variant_index); | |
| 1327 | |||
| 1328 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
192 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 1329 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
| 1330 | } | ||
| 1331 | |||
| 1332 | 384 | const size_t M = matmul_shape.m; | |
| 1333 | 384 | const size_t N = matmul_shape.n; | |
| 1334 | 384 | const size_t K = matmul_shape.k; | |
| 1335 | |||
| 1336 | 192 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 1337 | 192 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 1338 | 192 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 1339 | 192 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 1340 | |||
| 1341 | 192 | const auto m_step = ukernel_variant.interface.get_m_step(); | |
| 1342 |
3/14✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 192 times.
|
192 | ASSERT_TRUE(m_step % mr == 0); |
| 1343 | |||
| 1344 | 192 | const auto n_step = ukernel_variant.interface.get_n_step(); | |
| 1345 |
3/14✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 192 times.
|
192 | ASSERT_TRUE(n_step % nr == 0); |
| 1346 | |||
| 1347 | 384 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 1348 |
3/14✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 192 times.
|
192 | ASSERT_GT(rect.height(), 0U); |
| 1349 |
3/14✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 192 times.
|
192 | ASSERT_GT(rect.width(), 0U); |
| 1350 | |||
| 1351 | // Cached reference and inputs | ||
| 1352 | 384 | const BF16QMatMulRefKey key{matmul_shape, | |
| 1353 | bl, | ||
| 1354 | mr, | ||
| 1355 | nr, | ||
| 1356 | kr, | ||
| 1357 | sr, | ||
| 1358 | 192 | rect.start_row(), | |
| 1359 | 192 | rect.start_col(), | |
| 1360 | 192 | rect.height(), | |
| 1361 | 192 | rect.width(), | |
| 1362 | rhs_pack_type, | ||
| 1363 | clamp_keep_ratio}; | ||
| 1364 | 192 | const BF16TestData& data = getV<BF16QMatMulRefKey, BF16TestData>(key); | |
| 1365 | |||
| 1366 | // Verify LHS offsets match interface | ||
| 1367 | 192 | const auto lhs_start_row = rect.start_row(); | |
| 1368 | 384 | const auto lhs_packed_offset = | |
| 1369 | 192 | kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); | |
| 1370 | 192 | const auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); | |
| 1371 |
3/14✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 192 times.
|
192 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
| 1372 | |||
| 1373 | // RHS: pack using cached quant/scales/bias | ||
| 1374 | 192 | const size_t rhs_start_row = rect.start_col(); | |
| 1375 | 192 | const size_t bias_offset = rhs_start_row * sizeof(float); | |
| 1376 |
3/4✓ Branch 0 taken 96 times.
✓ Branch 1 taken 96 times.
✓ Branch 2 taken 96 times.
✗ Branch 3 not taken.
|
192 | if (rhs_pack_type == RhsPackType::KxN && (rhs_start_row % 2) != 0) { |
| 1377 | ✗ | GTEST_SKIP() << "KxN RHS pack requires even N-start index"; | |
| 1378 | return; | ||
| 1379 | } | ||
| 1380 | |||
| 1381 | 672 | auto [imp_packed_rhs, rhs_packed_offset] = pack_rhs_qsi4c32pscalebf16( | |
| 1382 | 576 | N, K, nr, kr, sr, bl, data.rhs_quant, data.bias, bias_offset, data.rhs_scales, rhs_pack_type, rhs_start_row, | |
| 1383 | 192 | rect.width(), false); | |
| 1384 | |||
| 1385 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
384 | const auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K, bl); |
| 1386 |
5/18✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 192 times.
|
384 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
| 1387 | |||
| 1388 | // Destination | ||
| 1389 | 192 | const auto dst_stride_row = N * sizeof(uint16_t); | |
| 1390 | 192 | constexpr auto dst_stride_col = sizeof(uint16_t); | |
| 1391 | 384 | const auto dst_offset = | |
| 1392 |
3/6✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
|
192 | ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); |
| 1393 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
192 | const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; |
| 1394 |
4/16✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 192 times.
|
192 | ASSERT_EQ(dst_offset, ref_dst_offset); |
| 1395 | |||
| 1396 |
1/2✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
|
192 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
| 1397 |
5/18✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 192 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 192 times.
|
192 | ASSERT_EQ(imp_dst_size, data.ref_dst_bf16.size()); |
| 1398 |
1/2✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
|
192 | Buffer imp_dst(imp_dst_size); |
| 1399 | |||
| 1400 | // Run matmul | ||
| 1401 |
1/2✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
|
192 | abi_check( |
| 1402 |
2/4✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
|
192 | ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, bl, |
| 1403 |
1/2✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
|
192 | reinterpret_cast<const uint8_t*>(data.lhs_packed.data()) + lhs_matmul_offset, |
| 1404 |
3/6✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
|
384 | reinterpret_cast<const uint8_t*>(imp_packed_rhs.data()) + rhs_matmul_offset, imp_dst.data() + dst_offset, |
| 1405 | 192 | dst_stride_row, dst_stride_col, data.clamp.min, data.clamp.max); | |
| 1406 | |||
| 1407 |
1/2✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
|
192 | DefaultMismatchHandler handler(0, 0.02, 0, 0.05); |
| 1408 | 192 | auto dst_format = DataFormat(DataType::BF16); | |
| 1409 |
3/6✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
|
192 | const auto success = compare(imp_dst.data(), data.ref_dst_bf16.data(), dst_format, M, N, rect, handler); |
| 1410 |
4/16✓ Branch 0 taken 192 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 192 times.
|
192 | ASSERT_TRUE(success); |
| 1411 | |||
| 1412 | // Test vectorized packing micro-kernels, if packing parameters allow | ||
| 1413 |
3/6✓ Branch 0 taken 96 times.
✓ Branch 1 taken 96 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 96 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
192 | if (rhs_pack_type == RhsPackType::NxK && (kr / sr == 8 || kr / sr == 4)) { |
| 1414 |
1/2✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
|
96 | const auto [imp_packed_rhs_neon, rhs_packed_offset_neon] = pack_rhs_qsi4c32pscalebf16_neon( |
| 1415 | 288 | N, K, nr, kr, sr, bl, data.rhs_quant, data.bias, bias_offset, data.rhs_scales, rhs_pack_type, rhs_start_row, | |
| 1416 |
1/2✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
|
96 | rect.width()); |
| 1417 |
5/18✓ Branch 0 taken 96 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 96 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 96 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 96 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 96 times.
|
192 | ASSERT_EQ(rhs_packed_offset_neon, rhs_packed_offset); |
| 1418 | 96 | } | |
| 1419 | 192 | } | |
| 1420 | |||
| 1421 | // clang-format off | ||
| 1422 | |||
| 1423 | /// Portion categories (GEMM/GEMV) | ||
| 1424 | static constexpr std::array gemm_portions{ | ||
| 1425 | MatrixPortion(0, 0, 1, 1), // Full matrix | ||
| 1426 | MatrixPortion(0.4, 0.5, 0.6, 0.8), // Middle block | ||
| 1427 | }; | ||
| 1428 | static constexpr std::array gemv_portions{ | ||
| 1429 | MatrixPortion(0, 0, 1, 1), // Full width | ||
| 1430 | MatrixPortion(0, 0.5, 1, 0.5), // Right half | ||
| 1431 | }; | ||
| 1432 | |||
| 1433 | /// Shape categories (GEMM/GEMV) | ||
| 1434 | |||
| 1435 | /// Small/Odd edge coverage (odd m/n, varied K) | ||
| 1436 | static constexpr std::array gemm_shapes_small_odd{ | ||
| 1437 | MatMulShape{ 17, 25, 64}, | ||
| 1438 | MatMulShape{ 31, 31, 64}, | ||
| 1439 | MatMulShape{ 21, 53, 256}, | ||
| 1440 | MatMulShape{ 35, 27, 320}, | ||
| 1441 | }; | ||
| 1442 | |||
| 1443 | /// Aligned squares (cache-friendly, power-of-two-ish) | ||
| 1444 | static constexpr std::array gemm_shapes_aligned{ | ||
| 1445 | MatMulShape{ 32, 32, 128}, | ||
| 1446 | MatMulShape{ 64, 64, 128}, | ||
| 1447 | MatMulShape{128, 128, 256}, | ||
| 1448 | MatMulShape{192, 192, 384}, | ||
| 1449 | }; | ||
| 1450 | |||
| 1451 | /// Rectangular (skinny/wide), varied K | ||
| 1452 | static constexpr std::array gemm_shapes_rect{ | ||
| 1453 | MatMulShape{ 64, 128, 256}, // wide N | ||
| 1454 | MatMulShape{128, 64, 256}, // tall M | ||
| 1455 | MatMulShape{ 96, 192, 384}, | ||
| 1456 | MatMulShape{160, 96, 320}, | ||
| 1457 | }; | ||
| 1458 | |||
| 1459 | /// Larger/stress (within reason for CI) | ||
| 1460 | static constexpr std::array gemm_shapes_large{ | ||
| 1461 | MatMulShape{128, 160, 320}, | ||
| 1462 | MatMulShape{160, 128, 320}, | ||
| 1463 | MatMulShape{224, 160, 320}, | ||
| 1464 | MatMulShape{160, 224, 320}, | ||
| 1465 | }; | ||
| 1466 | |||
| 1467 | /// GEMV shape categories (F32) | ||
| 1468 | /// M = 1, RHS NxK only in instantiation | ||
| 1469 | |||
| 1470 | /// Small/medium N, diverse K (aligned/odd N) | ||
| 1471 | static constexpr std::array gemv_shapes_small{ | ||
| 1472 | MatMulShape{ 1, 16, 64}, | ||
| 1473 | MatMulShape{ 1, 31, 64}, | ||
| 1474 | MatMulShape{ 1, 128, 256}, | ||
| 1475 | MatMulShape{ 1, 256, 256}, | ||
| 1476 | MatMulShape{ 1, 320, 320}, | ||
| 1477 | }; | ||
| 1478 | |||
| 1479 | /// Larger N bands (bandwidth/cache stress) | ||
| 1480 | static constexpr std::array gemv_shapes_large{ | ||
| 1481 | MatMulShape{ 1, 512, 256}, | ||
| 1482 | MatMulShape{ 1, 640, 320}, | ||
| 1483 | MatMulShape{ 1, 768, 384}, | ||
| 1484 | MatMulShape{ 1, 1024, 256}, | ||
| 1485 | MatMulShape{ 1, 896, 384}, | ||
| 1486 | }; | ||
| 1487 | |||
| 1488 | static constexpr std::array bf16_shapes { | ||
| 1489 | MatMulShape{ 32, 32, 64}, // small aligned | ||
| 1490 | MatMulShape{ 48, 64, 64}, // rectangular (tall K-block reuse) | ||
| 1491 | MatMulShape{ 64, 64, 128}, // aligned square | ||
| 1492 | MatMulShape{ 96, 96, 192}, // larger aligned | ||
| 1493 | MatMulShape{128, 64, 256}, // rectangular (tall M) | ||
| 1494 | MatMulShape{ 17, 25, 64}, // odd sizes (edge behavior) | ||
| 1495 | MatMulShape{ 33, 29, 192}, // odd sizes with larger K | ||
| 1496 | MatMulShape{128, 160, 320}, // larger rectangular | ||
| 1497 | }; | ||
| 1498 | |||
| 1499 | /// Dedicated clamp sweep ratios | ||
| 1500 | static constexpr std::array<float, 3> clamp_keep_ratios_sweep{ | ||
| 1501 | 1.0F, // no clamp | ||
| 1502 | 0.5F, // clamp away 50% | ||
| 1503 | 0.1F, // clamp away 90% | ||
| 1504 | }; | ||
| 1505 | |||
| 1506 |
24/80✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 1344 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 2688 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
4056 | INSTANTIATE_TEST_SUITE_P( |
| 1507 | MatMulGemm_SmallOdd, QMatMulClampF32Test, | ||
| 1508 | testing::Combine( | ||
| 1509 | testing::Range<size_t>(0, get_f32_gemm_variants().size()), | ||
| 1510 | testing::Values(true), | ||
| 1511 | testing::ValuesIn(gemm_shapes_small_odd), | ||
| 1512 | testing::Values(32), | ||
| 1513 | testing::ValuesIn(gemm_portions), | ||
| 1514 | testing::Values(RhsPackType::NxK, RhsPackType::KxN), | ||
| 1515 | testing::Values(0.5F)), | ||
| 1516 | testing::PrintToStringParamName()); | ||
| 1517 | |||
| 1518 |
24/80✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 1344 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 2688 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
4056 | INSTANTIATE_TEST_SUITE_P( |
| 1519 | MatMulGemm_Aligned, QMatMulClampF32Test, | ||
| 1520 | testing::Combine( | ||
| 1521 | testing::Range<size_t>(0, get_f32_gemm_variants().size()), | ||
| 1522 | testing::Values(true), | ||
| 1523 | testing::ValuesIn(gemm_shapes_aligned), | ||
| 1524 | testing::Values(32), | ||
| 1525 | testing::ValuesIn(gemm_portions), | ||
| 1526 | testing::Values(RhsPackType::NxK, RhsPackType::KxN), | ||
| 1527 | testing::Values(0.5F)), | ||
| 1528 | testing::PrintToStringParamName()); | ||
| 1529 | |||
| 1530 |
24/80✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 2688 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 5376 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
8088 | INSTANTIATE_TEST_SUITE_P( |
| 1531 | MatMulGemm_Rect, QMatMulClampF32Test, | ||
| 1532 | testing::Combine( | ||
| 1533 | testing::Range<size_t>(0, get_f32_gemm_variants().size()), | ||
| 1534 | testing::Values(true), | ||
| 1535 | testing::ValuesIn(gemm_shapes_rect), | ||
| 1536 | testing::Values(32, 64), | ||
| 1537 | testing::ValuesIn(gemm_portions), | ||
| 1538 | testing::Values(RhsPackType::NxK, RhsPackType::KxN), | ||
| 1539 | testing::Values(0.5F)), | ||
| 1540 | testing::PrintToStringParamName()); | ||
| 1541 | |||
| 1542 |
24/80✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 1344 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 2688 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
4056 | INSTANTIATE_TEST_SUITE_P( |
| 1543 | MatMulGemm_Large, QMatMulClampF32Test, | ||
| 1544 | testing::Combine( | ||
| 1545 | testing::Range<size_t>(0, get_f32_gemm_variants().size()), | ||
| 1546 | testing::Values(true), | ||
| 1547 | testing::ValuesIn(gemm_shapes_large), | ||
| 1548 | testing::Values(32), | ||
| 1549 | testing::ValuesIn(gemm_portions), | ||
| 1550 | testing::Values(RhsPackType::NxK, RhsPackType::KxN), | ||
| 1551 | testing::Values(0.5F)), | ||
| 1552 | testing::PrintToStringParamName()); | ||
| 1553 | |||
| 1554 |
24/80✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 70 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 140 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
234 | INSTANTIATE_TEST_SUITE_P( |
| 1555 | MatMulGemv_Small, QMatMulClampF32Test, | ||
| 1556 | testing::Combine( | ||
| 1557 | testing::Range<size_t>(0, get_f32_gemv_variants().size()), | ||
| 1558 | testing::Values(false), | ||
| 1559 | testing::ValuesIn(gemv_shapes_small), | ||
| 1560 | testing::Values(32), | ||
| 1561 | testing::ValuesIn(gemv_portions), | ||
| 1562 | testing::Values(RhsPackType::NxK), | ||
| 1563 | testing::Values(0.5F)), | ||
| 1564 | testing::PrintToStringParamName()); | ||
| 1565 | |||
| 1566 |
24/80✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✓ Branch 26 taken 70 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 140 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
234 | INSTANTIATE_TEST_SUITE_P( |
| 1567 | MatMulGemv_Large, QMatMulClampF32Test, | ||
| 1568 | testing::Combine( | ||
| 1569 | testing::Range<size_t>(0, get_f32_gemv_variants().size()), | ||
| 1570 | testing::Values(false), | ||
| 1571 | testing::ValuesIn(gemv_shapes_large), | ||
| 1572 | testing::Values(32), | ||
| 1573 | testing::ValuesIn(gemv_portions), | ||
| 1574 | testing::Values(RhsPackType::NxK), | ||
| 1575 | testing::Values(0.5F)), | ||
| 1576 | testing::PrintToStringParamName()); | ||
| 1577 | |||
| 1578 |
24/80✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 88 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 176 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
270 | INSTANTIATE_TEST_SUITE_P( |
| 1579 | MatMulNeonRhsPackGemm_SmallOdd, NeonRhsPackF32Test, | ||
| 1580 | testing::Combine( | ||
| 1581 | testing::Range<size_t>(0, get_f32_neon_gemm_variants_only().size()), | ||
| 1582 | testing::Values(true), | ||
| 1583 | testing::ValuesIn(gemm_shapes_small_odd), | ||
| 1584 | testing::Values(32), | ||
| 1585 | testing::ValuesIn(gemm_portions), | ||
| 1586 | testing::Values(RhsPackType::NxK), | ||
| 1587 | testing::Values(0.5F)), | ||
| 1588 | testing::PrintToStringParamName()); | ||
| 1589 | |||
| 1590 |
24/80✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 88 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 176 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
270 | INSTANTIATE_TEST_SUITE_P( |
| 1591 | MatMulNeonRhsPackGemm_Aligned, NeonRhsPackF32Test, | ||
| 1592 | testing::Combine( | ||
| 1593 | testing::Range<size_t>(0, get_f32_neon_gemm_variants_only().size()), | ||
| 1594 | testing::Values(true), | ||
| 1595 | testing::ValuesIn(gemm_shapes_aligned), | ||
| 1596 | testing::Values(32), | ||
| 1597 | testing::ValuesIn(gemm_portions), | ||
| 1598 | testing::Values(RhsPackType::NxK), | ||
| 1599 | testing::Values(0.5F)), | ||
| 1600 | testing::PrintToStringParamName()); | ||
| 1601 | |||
| 1602 | static constexpr std::array clamp_sweep_shapes{ | ||
| 1603 | MatMulShape{ 64, 64, 128 }, | ||
| 1604 | MatMulShape{ 64, 128, 256 }, | ||
| 1605 | }; | ||
| 1606 | |||
| 1607 |
26/88✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 12 taken 14 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 7 times.
✓ Branch 20 taken 14 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 22 taken 14 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 7 times.
✓ Branch 26 taken 14 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 14 times.
✓ Branch 28 taken 1008 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✓ Branch 30 taken 2016 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
|
3048 | INSTANTIATE_TEST_SUITE_P( |
| 1608 | MatMulGemm_ClampSweep, QMatMulClampF32Test, | ||
| 1609 | testing::Combine( | ||
| 1610 | testing::Range<size_t>(0, get_f32_gemm_variants().size()), | ||
| 1611 | testing::Values(true), | ||
| 1612 | testing::ValuesIn(clamp_sweep_shapes), | ||
| 1613 | testing::Values(32), | ||
| 1614 | testing::Values(MatrixPortion(0, 0, 1, 1)), | ||
| 1615 | testing::Values(RhsPackType::NxK, RhsPackType::KxN), | ||
| 1616 | testing::ValuesIn(clamp_keep_ratios_sweep)), | ||
| 1617 | testing::PrintToStringParamName()); | ||
| 1618 | |||
| 1619 |
24/80✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 96 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 28 taken 192 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
|
294 | INSTANTIATE_TEST_SUITE_P( |
| 1620 | MatMulBF16_SingleSet, QMatMulClampBF16Test, | ||
| 1621 | testing::Combine( | ||
| 1622 | testing::Range<size_t>(0, get_bf16_gemm_variants().size()), | ||
| 1623 | testing::ValuesIn(bf16_shapes), | ||
| 1624 | testing::Values(32), | ||
| 1625 | testing::Values(MatrixPortion(0, 0, 1, 1)), | ||
| 1626 | testing::Values(RhsPackType::NxK, RhsPackType::KxN), | ||
| 1627 | testing::ValuesIn(clamp_keep_ratios_sweep) | ||
| 1628 | ), | ||
| 1629 | testing::PrintToStringParamName()); | ||
| 1630 | |||
| 1631 | } // namespace kai::test | ||
| 1632 |