test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <gtest/gtest.h> | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <cstdlib> | ||
| 13 | #include <limits> | ||
| 14 | #include <string> | ||
| 15 | #include <tuple> | ||
| 16 | |||
| 17 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" | ||
| 18 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h" | ||
| 19 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" | ||
| 20 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot.h" | ||
| 21 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" | ||
| 22 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" | ||
| 23 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h" | ||
| 24 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" | ||
| 25 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h" | ||
| 26 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" | ||
| 27 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" | ||
| 28 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h" | ||
| 29 | #include "test/common/abi_checker.hpp" | ||
| 30 | #include "test/common/buffer.hpp" | ||
| 31 | #include "test/common/cache.hpp" | ||
| 32 | #include "test/common/cpu_info.hpp" | ||
| 33 | #include "test/common/matmul_test_common.hpp" | ||
| 34 | #include "test/common/matrix_portion.hpp" | ||
| 35 | #include "test/common/memory.hpp" | ||
| 36 | #include "test/common/printer.hpp" | ||
| 37 | #include "test/common/seed.hpp" | ||
| 38 | #include "test/common/test_suite.hpp" | ||
| 39 | #include "test/reference/clamp.hpp" | ||
| 40 | #include "test/reference/fill.hpp" | ||
| 41 | #include "test/reference/matmul.hpp" | ||
| 42 | #include "test/reference/quantize.hpp" | ||
| 43 | #include "test/reference/transpose.hpp" | ||
| 44 | |||
| 45 | namespace kai::test { | ||
| 46 | |||
| 47 | using F32Qai8Qsi8CacheDataId = std::tuple< | ||
| 48 | MatMulShape, // | ||
| 49 | DataFormat, // lhs format | ||
| 50 | DataFormat, // rhs format | ||
| 51 | DataFormat, // bias format | ||
| 52 | float>; | ||
| 53 | |||
| 54 | struct F32Qai8Qsi8CacheData { | ||
| 55 | Buffer ref_dst_nt_t; | ||
| 56 | Buffer ref_dst_nt_nt; | ||
| 57 | Buffer ref_rhs_qsi8_nt_t; | ||
| 58 | Buffer ref_rhs_qsi8_nt_nt; | ||
| 59 | Buffer ref_rhs_scales; | ||
| 60 | Buffer ref_lhs; | ||
| 61 | Buffer ref_bias; | ||
| 62 | Range<float> clamp_range; | ||
| 63 | }; | ||
| 64 | |||
| 65 | template <> | ||
| 66 | 66 | F32Qai8Qsi8CacheData ReferenceGenerator<F32Qai8Qsi8CacheDataId, F32Qai8Qsi8CacheData>::generate_reference( | |
| 67 | const F32Qai8Qsi8CacheDataId& data_id) { | ||
| 68 | 1056 | auto [shape, lhs_format, rhs_format, bias_format, clamp_keep_ratio] = data_id; | |
| 69 | |||
| 70 | 132 | const size_t M = shape.m; | |
| 71 | 132 | const size_t N = shape.n; | |
| 72 | 132 | const size_t K = shape.k; | |
| 73 | |||
| 74 | // Seed the random generator. | ||
| 75 |
8/16✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 66 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 66 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 66 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 33 times.
✗ Branch 15 not taken.
|
198 | const auto key = std::string("F32Qai8Qsi8_cache:") + std::to_string(M) + "x" + std::to_string(N) + "x" + |
| 76 |
8/16✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 66 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 66 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 66 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 33 times.
✗ Branch 15 not taken.
|
198 | std::to_string(K) + ":" + std::to_string(static_cast<uint32_t>(lhs_format.data_type())) + ":" + |
| 77 |
5/10✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 33 times.
✗ Branch 9 not taken.
|
231 | std::to_string(static_cast<uint32_t>(rhs_format.data_type())) + ":" + |
| 78 |
7/14✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 66 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 66 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 33 times.
✗ Branch 13 not taken.
|
165 | std::to_string(static_cast<uint32_t>(bias_format.data_type())) + ":" + std::to_string(clamp_keep_ratio); |
| 79 |
1/2✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
|
66 | auto& feed = seed_stream(key); |
| 80 | |||
| 81 |
5/10✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 66 times.
✗ Branch 9 not taken.
|
264 | Buffer lhs = fill_matrix_random(shape.m, shape.k, lhs_format, feed()); |
| 82 |
5/10✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 66 times.
✗ Branch 9 not taken.
|
264 | Buffer rhs = fill_matrix_random(shape.k, shape.n, rhs_format, feed()); |
| 83 |
4/8✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 66 times.
✗ Branch 7 not taken.
|
198 | Buffer bias = fill_matrix_random(1, shape.n, bias_format, feed()); |
| 84 | |||
| 85 |
1/2✓ Branch 0 taken 66 times.
✗ Branch 1 not taken.
|
66 | QuantizationInfo lhs_qinfo{}; |
| 86 | lhs_qinfo.quant_width = K; | ||
| 87 | lhs_qinfo.dst_type = DataType::QAI8; | ||
| 88 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 89 | lhs_qinfo.zero_point_type = DataType::I32; | ||
| 90 | const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(lhs.data(), DataType::FP32, M, K, lhs_qinfo); | ||
| 91 | |||
| 92 | QuantizationInfo rhs_qinfo{}; | ||
| 93 | rhs_qinfo.quant_width = K; | ||
| 94 | rhs_qinfo.dst_type = DataType::QSI8; | ||
| 95 | rhs_qinfo.scale_type = DataType::FP32; | ||
| 96 | auto [ref_rhs_quant_t, rhs_qoutputs] = quantize_dynamic(rhs.data(), DataType::FP32, N, K, rhs_qinfo); | ||
| 97 | |||
| 98 | F32Qai8Qsi8CacheData out; | ||
| 99 | |||
| 100 | // Transposed RHS path. | ||
| 101 | const size_t ref_rhs_qsi8_nxk_stride = K; | ||
| 102 | const size_t ref_rhs_qsi8_kxn_stride = N; | ||
| 103 | const size_t ref_rhs_qsi8_kxn_size_bytes = K * ref_rhs_qsi8_kxn_stride; | ||
| 104 | |||
| 105 | // Non-Transposed(kxn) RHS dimensions | ||
| 106 | auto ref_rhs_qsi8 = transpose_with_padding<int8_t>( | ||
| 107 | ref_rhs_quant_t.data(), N, K, ref_rhs_qsi8_nxk_stride, ref_rhs_qsi8_kxn_stride, ref_rhs_qsi8_kxn_size_bytes); | ||
| 108 | |||
| 109 | auto ref_dst_nt_nt = matmul_clamp_nt_nt<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>( | ||
| 110 | M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K, | ||
| 111 | ref_rhs_qsi8.data(), rhs_qoutputs.scales.data(), nullptr, K, bias.data(), std::numeric_limits<float>::lowest(), | ||
| 112 | std::numeric_limits<float>::max()); | ||
| 113 | |||
| 114 | // Non-transposed RHS path. | ||
| 115 | auto ref_dst_nt_t = matmul_clamp_nt_t<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>( | ||
| 116 | M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K, | ||
| 117 | ref_rhs_quant_t.data(), rhs_qoutputs.scales.data(), nullptr, K, bias.data(), | ||
| 118 | std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | ||
| 119 | |||
| 120 | // Only need to calculate range once for both, apply clamping | ||
| 121 | const auto [clamp_min, clamp_max] = find_clamp_range(DataType::FP32, ref_dst_nt_t.data(), M * N, clamp_keep_ratio); | ||
| 122 | auto ref_clamped_nt_t = clamp(DataType::FP32, ref_dst_nt_t.data(), M * N, clamp_min, clamp_max); | ||
| 123 | auto ref_clamped_nt_nt = clamp(DataType::FP32, ref_dst_nt_nt.data(), M * N, clamp_min, clamp_max); | ||
| 124 | |||
| 125 | out.ref_rhs_qsi8_nt_nt = std::move(ref_rhs_qsi8); | ||
| 126 | out.ref_rhs_qsi8_nt_t = std::move(ref_rhs_quant_t); | ||
| 127 | out.ref_dst_nt_nt = std::move(ref_clamped_nt_nt); | ||
| 128 | out.ref_dst_nt_t = std::move(ref_clamped_nt_t); | ||
| 129 | out.ref_lhs = std::move(lhs); | ||
| 130 | out.ref_bias = std::move(bias); | ||
| 131 | out.ref_rhs_scales = std::move(rhs_qoutputs.scales); | ||
| 132 | out.clamp_range = {clamp_min, clamp_max}; | ||
| 133 | |||
| 134 | return out; | ||
| 135 | ✗ | } | |
| 136 | |||
| 137 | 3 | static const std::array<UkernelVariant<kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel>, 8> | |
| 138 |
0/4✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
|
11 | variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp = {{ |
| 139 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod), |
| 140 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", cpu_has_dotprod}, |
| 141 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod), |
| 142 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", cpu_has_dotprod}, |
| 143 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod), |
| 144 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", cpu_has_dotprod}, |
| 145 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm), |
| 146 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", cpu_has_i8mm}, |
| 147 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot), |
| 148 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot", cpu_has_sme}, |
| 149 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa), |
| 150 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa", cpu_has_sme}, |
| 151 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot), |
| 152 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot", cpu_has_sme2}, |
| 153 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa), |
| 154 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa", cpu_has_sme2}, |
| 155 | }}; | ||
| 156 | |||
| 157 | using MatMulClampTestPortionedParams = std::tuple<size_t, MatMulShape, MatrixPortion, float>; | ||
| 158 | |||
| 159 | class MatMulTest_f32_qai8dxp_qsi8cxp : public ::testing::TestWithParam<MatMulClampTestPortionedParams> {}; | ||
| 160 | |||
| 161 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
9246 | TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, Offset_RHS) { |
| 162 | 10164 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 163 | 7392 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index); | |
| 164 | |||
| 165 |
3/4✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✓ Branch 3 taken 924 times.
|
3696 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 166 |
3/6✓ Branch 0 taken 924 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 924 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 924 times.
✗ Branch 5 not taken.
|
924 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 167 | } | ||
| 168 | |||
| 169 | 5544 | const size_t K = matmul_shape.k; | |
| 170 | 2772 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 171 | 2772 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 172 | 2772 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 173 | |||
| 174 | 2772 | auto n_step = ukernel_variant.interface.get_n_step(); | |
| 175 | |||
| 176 | 2772 | auto rhs_packed_offset_kxn = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(n_step, K, nr, kr, sr); | |
| 177 | 2772 | auto rhs_packed_offset_nxk = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(n_step, K, nr, kr, sr); | |
| 178 | |||
| 179 |
3/14✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
|
2772 | ASSERT_EQ(rhs_packed_offset_kxn, rhs_packed_offset_nxk); |
| 180 | |||
| 181 | 2772 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(n_step, K); | |
| 182 |
3/14✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
|
2772 | ASSERT_EQ(rhs_packed_offset_kxn, rhs_matmul_offset); |
| 183 | 3696 | } | |
| 184 | |||
| 185 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
9246 | TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, Offset_LHS) { |
| 186 | 10164 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 187 | 7392 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index); | |
| 188 | |||
| 189 |
3/4✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✓ Branch 3 taken 924 times.
|
3696 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 190 |
3/6✓ Branch 0 taken 924 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 924 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 924 times.
✗ Branch 5 not taken.
|
924 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 191 | } | ||
| 192 | |||
| 193 | 5544 | const size_t K = matmul_shape.k; | |
| 194 | 2772 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 195 | 2772 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 196 | 2772 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 197 | |||
| 198 | 2772 | auto m_step = ukernel_variant.interface.get_m_step(); | |
| 199 | |||
| 200 | 2772 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(m_step, K, mr, kr, sr); | |
| 201 | 2772 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(m_step, K); | |
| 202 | |||
| 203 |
3/14✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
|
2772 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
| 204 | 3696 | } | |
| 205 | |||
| 206 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
9246 | TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_nxk_qsi8cx) { |
| 207 | 24024 | auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 208 | 7392 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index); | |
| 209 | |||
| 210 |
3/4✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✓ Branch 3 taken 924 times.
|
3696 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 211 |
3/6✓ Branch 0 taken 924 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 924 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 924 times.
✗ Branch 5 not taken.
|
924 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 212 | } | ||
| 213 | |||
| 214 | 5544 | const size_t M = matmul_shape.m; | |
| 215 | 5544 | const size_t N = matmul_shape.n; | |
| 216 | 5544 | const size_t K = matmul_shape.k; | |
| 217 | |||
| 218 | 2772 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 219 | 2772 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 220 | 2772 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 221 | 2772 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 222 | |||
| 223 | 5544 | const F32Qai8Qsi8CacheDataId testdata_id = { | |
| 224 | matmul_shape, // | ||
| 225 | 2772 | DataFormat(DataType::FP32), // | |
| 226 | 2772 | DataFormat(DataType::FP32), // | |
| 227 | 2772 | DataFormat(DataType::FP32), clamp_keep_ratio}; | |
| 228 | 2772 | const F32Qai8Qsi8CacheData& testdata = getV<F32Qai8Qsi8CacheDataId, F32Qai8Qsi8CacheData>(testdata_id); | |
| 229 | |||
| 230 | 2772 | const auto& ref_rhs_qsi8 = testdata.ref_rhs_qsi8_nt_t; | |
| 231 | 2772 | const auto& ref_rhs_scales = testdata.ref_rhs_scales; | |
| 232 | 2772 | const auto& ref_dst = testdata.ref_dst_nt_t; | |
| 233 | 2772 | const auto& ref_bias = testdata.ref_bias; | |
| 234 | 2772 | const auto& ref_lhs = testdata.ref_lhs; | |
| 235 | 2772 | auto m_step = ukernel_variant.interface.get_m_step(); | |
| 236 |
3/14✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
|
2772 | ASSERT_TRUE(m_step % mr == 0); |
| 237 | |||
| 238 | 2772 | auto n_step = ukernel_variant.interface.get_n_step(); | |
| 239 |
3/14✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
|
2772 | ASSERT_TRUE(n_step % nr == 0); |
| 240 | |||
| 241 | 5544 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 242 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2772 times.
|
2772 | if (rect.height() == 0 || rect.width() == 0) { |
| 243 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 244 | } | ||
| 245 | |||
| 246 | // Runs the LHS packing micro-kernel. | ||
| 247 | 2772 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); | |
| 248 | 2772 | Buffer imp_packed_lhs(imp_packed_lhs_size); | |
| 249 | |||
| 250 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | const auto lhs_start_row = rect.start_row(); |
| 251 | 2772 | size_t lhs_stride = K * sizeof(float); | |
| 252 | |||
| 253 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); |
| 254 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); |
| 255 | |||
| 256 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | kai_run_lhs_quant_pack_qai8dxp_f32( |
| 257 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | rect.height(), K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride, |
| 258 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | imp_packed_lhs.data() + lhs_packed_offset); |
| 259 | |||
| 260 | // Runs the RHS packing micro-kernel. | ||
| 261 | // * Generates the 8-bit signed symmetric quantized input for the micro-kernel. | ||
| 262 | // * Packs the RHS matrix. | ||
| 263 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); |
| 264 | |||
| 265 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
| 266 | 2772 | const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f}; | |
| 267 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon( |
| 268 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | 1, N, K, nr, kr, sr, reinterpret_cast<const int8_t*>(ref_rhs_qsi8.data()), |
| 269 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | reinterpret_cast<const float*>(ref_bias.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()), |
| 270 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | imp_packed_rhs.data(), 0, &params); |
| 271 | |||
| 272 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | const auto packed_rhs_start_row = rect.start_col(); |
| 273 | 2772 | auto rhs_packed_offset = | |
| 274 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(packed_rhs_start_row, K, nr, kr, sr); |
| 275 | |||
| 276 | 2772 | const auto dst_stride = N * sizeof(float); | |
| 277 |
3/6✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
|
2772 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
| 278 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); |
| 279 |
4/16✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
|
2772 | ASSERT_EQ(dst_offset, ref_dst_offset); |
| 280 | |||
| 281 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | const auto matmul_lhs_packed_offset = ukernel_variant.interface.get_lhs_packed_offset(rect.start_row(), K); |
| 282 |
4/16✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
|
2772 | ASSERT_EQ(lhs_packed_offset, matmul_lhs_packed_offset); |
| 283 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | const auto matmul_rhs_packed_offset = ukernel_variant.interface.get_rhs_packed_offset(rect.start_col(), K); |
| 284 |
4/16✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
|
2772 | ASSERT_EQ(rhs_packed_offset, matmul_rhs_packed_offset); |
| 285 | |||
| 286 | // Runs the GEMM micro-kernel. | ||
| 287 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
| 288 |
5/18✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2772 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2772 times.
|
2772 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
| 289 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | Buffer imp_dst(imp_dst_size); |
| 290 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | abi_check( |
| 291 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, |
| 292 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | imp_packed_lhs.data() + matmul_lhs_packed_offset, imp_packed_rhs.data() + matmul_rhs_packed_offset, |
| 293 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | reinterpret_cast<float*>(imp_dst.data() + dst_offset), N * sizeof(float), sizeof(float), |
| 294 | 2772 | testdata.clamp_range.min, testdata.clamp_range.max); | |
| 295 | |||
| 296 | // Compares the output of the micro-kernels against the output of the reference implementation. | ||
| 297 |
5/6✓ Branch 0 taken 8880 times.
✓ Branch 1 taken 17760 times.
✓ Branch 2 taken 23868 times.
✓ Branch 3 taken 2772 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2772 times.
|
26640 | for (size_t y = 0; y < rect.height(); ++y) { |
| 298 |
5/6✓ Branch 0 taken 294936 times.
✓ Branch 1 taken 759120 times.
✓ Branch 2 taken 1030188 times.
✓ Branch 3 taken 23868 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 23868 times.
|
1054056 | for (size_t x = 0; x < rect.width(); ++x) { |
| 299 | 2060376 | const auto imp_value = | |
| 300 |
4/8✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1030188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1030188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1030188 times.
✗ Branch 7 not taken.
|
1030188 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
| 301 | 2060376 | const auto ref_value = | |
| 302 |
4/8✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1030188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1030188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1030188 times.
✗ Branch 7 not taken.
|
1030188 | read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
| 303 |
3/6✓ Branch 0 taken 743208 times.
✓ Branch 1 taken 286980 times.
✓ Branch 2 taken 286980 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
1030188 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); |
| 304 | |||
| 305 |
1/2✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
|
1030188 | if (rel_error > 0.0001F) { |
| 306 | ✗ | ASSERT_EQ(imp_value, ref_value); | |
| 307 | ✗ | } | |
| 308 | 1030188 | } | |
| 309 | 23868 | } | |
| 310 | 3696 | } | |
| 311 | |||
| 312 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
9246 | TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_kxn_qsi8cx) { |
| 313 | 24024 | auto& [variant_index, matmul_shape, portion, clamp_keep_ratio] = GetParam(); | |
| 314 | 7392 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index); | |
| 315 | |||
| 316 |
3/4✓ Branch 0 taken 3696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✓ Branch 3 taken 924 times.
|
3696 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 317 |
3/6✓ Branch 0 taken 924 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 924 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 924 times.
✗ Branch 5 not taken.
|
924 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 318 | } | ||
| 319 | |||
| 320 | 5544 | const size_t M = matmul_shape.m; | |
| 321 | 5544 | const size_t N = matmul_shape.n; | |
| 322 | 5544 | const size_t K = matmul_shape.k; | |
| 323 | |||
| 324 | 2772 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 325 | 2772 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 326 | 2772 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 327 | 2772 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 328 | |||
| 329 | 5544 | const F32Qai8Qsi8CacheDataId testdata_id = { | |
| 330 | matmul_shape, // | ||
| 331 | 2772 | DataFormat(DataType::FP32), // | |
| 332 | 2772 | DataFormat(DataType::FP32), // | |
| 333 | 2772 | DataFormat(DataType::FP32), clamp_keep_ratio}; | |
| 334 | 2772 | const F32Qai8Qsi8CacheData& testdata = getV<F32Qai8Qsi8CacheDataId, F32Qai8Qsi8CacheData>(testdata_id); | |
| 335 | 2772 | const auto& ref_rhs_qsi8 = testdata.ref_rhs_qsi8_nt_nt; | |
| 336 | 2772 | const auto& ref_rhs_scales = testdata.ref_rhs_scales; | |
| 337 | 2772 | const auto& ref_dst = testdata.ref_dst_nt_nt; | |
| 338 | 2772 | const auto& ref_bias = testdata.ref_bias; | |
| 339 | 2772 | const auto& ref_lhs = testdata.ref_lhs; | |
| 340 | |||
| 341 | 2772 | auto m_step = ukernel_variant.interface.get_m_step(); | |
| 342 |
3/14✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
|
2772 | ASSERT_TRUE(m_step % mr == 0); |
| 343 | |||
| 344 | 2772 | auto n_step = ukernel_variant.interface.get_n_step(); | |
| 345 |
3/14✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2772 times.
|
2772 | ASSERT_TRUE(n_step % nr == 0); |
| 346 | |||
| 347 | 5544 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 348 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2772 times.
|
2772 | if (rect.height() == 0 || rect.width() == 0) { |
| 349 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 350 | } | ||
| 351 | |||
| 352 | 2772 | const auto lhs_start_row = rect.start_row(); | |
| 353 | 2772 | size_t const lhs_stride = K * sizeof(float); | |
| 354 | |||
| 355 | 2772 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); | |
| 356 | 2772 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); | |
| 357 | |||
| 358 | // Runs the LHS packing micro-kernel. | ||
| 359 | 2772 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); | |
| 360 | 2772 | Buffer imp_packed_lhs(imp_packed_lhs_size); | |
| 361 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | kai_run_lhs_quant_pack_qai8dxp_f32( |
| 362 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | rect.height(), K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), K * sizeof(float), |
| 363 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | imp_packed_lhs.data() + lhs_packed_offset); |
| 364 | |||
| 365 | // Runs the RHS packing micro-kernel. | ||
| 366 | // * Generates the 8-bit signed symmetric quantized input for the micro-kernel. | ||
| 367 | // * Packs the RHS matrix. | ||
| 368 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); |
| 369 | |||
| 370 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
| 371 | 2772 | const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f}; | |
| 372 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon( |
| 373 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | 1, N, K, nr, kr, sr, reinterpret_cast<const int8_t*>(ref_rhs_qsi8.data()), |
| 374 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | reinterpret_cast<const float*>(ref_bias.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()), |
| 375 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | imp_packed_rhs.data(), 0, &params); |
| 376 | |||
| 377 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | const auto packed_rhs_start_row = rect.start_col(); |
| 378 | 2772 | auto rhs_packed_offset = | |
| 379 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(packed_rhs_start_row, K, nr, kr, sr); |
| 380 | |||
| 381 | 2772 | const auto dst_stride = N * sizeof(float); | |
| 382 |
3/6✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
|
2772 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
| 383 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); |
| 384 |
4/16✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
|
2772 | ASSERT_EQ(dst_offset, ref_dst_offset); |
| 385 | |||
| 386 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | const auto matmul_lhs_packed_offset = ukernel_variant.interface.get_lhs_packed_offset(rect.start_row(), K); |
| 387 |
4/16✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
|
2772 | ASSERT_EQ(lhs_packed_offset, matmul_lhs_packed_offset); |
| 388 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | const auto matmul_rhs_packed_offset = ukernel_variant.interface.get_rhs_packed_offset(rect.start_col(), K); |
| 389 |
4/16✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2772 times.
|
2772 | ASSERT_EQ(rhs_packed_offset, matmul_rhs_packed_offset); |
| 390 | |||
| 391 | // Runs the GEMM micro-kernel. | ||
| 392 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
| 393 |
5/18✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2772 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2772 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2772 times.
|
2772 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
| 394 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | Buffer imp_dst(imp_dst_size); |
| 395 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | abi_check( |
| 396 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | ukernel_variant.interface.run_matmul, rect.height(), rect.width(), K, |
| 397 |
2/4✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2772 times.
✗ Branch 3 not taken.
|
2772 | imp_packed_lhs.data() + matmul_lhs_packed_offset, imp_packed_rhs.data() + matmul_rhs_packed_offset, |
| 398 |
1/2✓ Branch 0 taken 2772 times.
✗ Branch 1 not taken.
|
2772 | reinterpret_cast<float*>(imp_dst.data() + dst_offset), N * sizeof(float), sizeof(float), |
| 399 | 2772 | testdata.clamp_range.min, testdata.clamp_range.max); | |
| 400 | |||
| 401 | // Compares the output of the micro-kernels against the output of the reference implementation. | ||
| 402 |
5/6✓ Branch 0 taken 8880 times.
✓ Branch 1 taken 17760 times.
✓ Branch 2 taken 23868 times.
✓ Branch 3 taken 2772 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2772 times.
|
26640 | for (size_t y = 0; y < rect.height(); ++y) { |
| 403 |
5/6✓ Branch 0 taken 294936 times.
✓ Branch 1 taken 759120 times.
✓ Branch 2 taken 1030188 times.
✓ Branch 3 taken 23868 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 23868 times.
|
1054056 | for (size_t x = 0; x < rect.width(); ++x) { |
| 404 | 2060376 | const auto imp_value = | |
| 405 |
4/8✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1030188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1030188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1030188 times.
✗ Branch 7 not taken.
|
1030188 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
| 406 | 2060376 | const auto ref_value = | |
| 407 |
4/8✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1030188 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1030188 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1030188 times.
✗ Branch 7 not taken.
|
1030188 | read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
| 408 |
3/6✓ Branch 0 taken 743208 times.
✓ Branch 1 taken 286980 times.
✓ Branch 2 taken 286980 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
1030188 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); |
| 409 | |||
| 410 |
1/2✓ Branch 0 taken 1030188 times.
✗ Branch 1 not taken.
|
1030188 | if (rel_error > 0.0001F) { |
| 411 | ✗ | ASSERT_EQ(imp_value, ref_value); | |
| 412 | ✗ | } | |
| 413 | 1030188 | } | |
| 414 | 23868 | } | |
| 415 | 3696 | } | |
| 416 | |||
| 417 |
35/118✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✓ Branch 14 taken 8 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✓ Branch 18 taken 8 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✓ Branch 20 taken 8 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✓ Branch 22 taken 8 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 4 times.
✓ Branch 24 taken 8 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 4 times.
✓ Branch 26 taken 8 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✓ Branch 28 taken 8 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 4 times.
✓ Branch 30 taken 8 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 4 times.
✓ Branch 32 taken 8 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✓ Branch 34 taken 8 times.
✓ Branch 34 taken 7392 times.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✓ Branch 36 taken 14784 times.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✗ Branch 59 not taken.
✗ Branch 60 not taken.
✓ Branch 60 taken 7392 times.
✗ Branch 61 not taken.
✗ Branch 61 not taken.
✓ Branch 62 taken 14784 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 14784 times.
✗ Branch 65 not taken.
|
44367 | INSTANTIATE_TEST_SUITE_P( |
| 418 | MatMul, MatMulTest_f32_qai8dxp_qsi8cxp, | ||
| 419 | testing::Combine( | ||
| 420 | testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.size()), | ||
| 421 | testing::Values( | ||
| 422 | MatMulShape{17, 33, 67}, // | ||
| 423 | MatMulShape{19, 35, 63}, // | ||
| 424 | MatMulShape{1, 27, 31}, // | ||
| 425 | MatMulShape{1, 65, 35}, // | ||
| 426 | MatMulShape{1, 64, 65}, // | ||
| 427 | MatMulShape{1, 63, 15}, // | ||
| 428 | MatMulShape{1, 130, 15}, // | ||
| 429 | MatMulShape{15, 65, 35}, // | ||
| 430 | MatMulShape{16, 64, 65}, // | ||
| 431 | MatMulShape{17, 63, 15}, // | ||
| 432 | MatMulShape{20, 130, 15}), | ||
| 433 | testing::Values( | ||
| 434 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
| 435 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
| 436 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
| 437 | MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle | ||
| 438 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
| 439 | MatrixPortion(0.75, 0, 1, 1), // Partial rows | ||
| 440 | MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle | ||
| 441 | ), | ||
| 442 | testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio | ||
| 443 | [](const auto& info) { | ||
| 444 | const auto variant_idx = std::get<0>(info.param); | ||
| 445 | const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_idx).name}; | ||
| 446 | const auto shape = std::get<MatMulShape>(info.param); | ||
| 447 | const auto portion = std::get<MatrixPortion>(info.param); | ||
| 448 | const auto clamp_keep_ratio = std::get<float>(info.param); | ||
| 449 | |||
| 450 | return test_description(name, shape, portion, true, clamp_keep_ratio); | ||
| 451 | }); | ||
| 452 | |||
| 453 | } // namespace kai::test | ||
| 454 |