test/tests/matmul_clamp_bf16_qai8dxp_qsi4cxp_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <gtest/gtest.h> | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <cstdlib> | ||
| 13 | #include <limits> | ||
| 14 | #include <sstream> | ||
| 15 | #include <string> | ||
| 16 | #include <tuple> | ||
| 17 | #include <vector> | ||
| 18 | |||
| 19 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h" | ||
| 20 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h" | ||
| 21 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h" | ||
| 22 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h" | ||
| 23 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h" | ||
| 24 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" | ||
| 25 | #include "test/common/bfloat16.hpp" | ||
| 26 | #include "test/common/buffer.hpp" | ||
| 27 | #include "test/common/cache.hpp" | ||
| 28 | #include "test/common/compare.hpp" | ||
| 29 | #include "test/common/cpu_info.hpp" | ||
| 30 | #include "test/common/data_format.hpp" | ||
| 31 | #include "test/common/int4.hpp" | ||
| 32 | #include "test/common/matmul_test_common.hpp" | ||
| 33 | #include "test/common/matrix_portion.hpp" | ||
| 34 | #include "test/common/memory.hpp" | ||
| 35 | #include "test/common/round.hpp" | ||
| 36 | #include "test/common/seed.hpp" | ||
| 37 | #include "test/common/test_suite.hpp" | ||
| 38 | #include "test/reference/cast.hpp" | ||
| 39 | #include "test/reference/clamp.hpp" | ||
| 40 | #include "test/reference/fill.hpp" | ||
| 41 | #include "test/reference/matmul.hpp" | ||
| 42 | #include "test/reference/pack.hpp" | ||
| 43 | #include "test/reference/pad.hpp" | ||
| 44 | #include "test/reference/quantize.hpp" | ||
| 45 | #include "test/reference/transpose.hpp" | ||
| 46 | |||
| 47 | // Uses the truncating BFloat16 implementation (BFloat16<false>) to match the existing packing/inference code | ||
| 48 | |||
| 49 | namespace kai::test { | ||
| 50 | |||
| 51 | using Bf16Qai8Qsi4CacheDataId = std::tuple< // | ||
| 52 | MatMulShape, // | ||
| 53 | DataFormat, // lhs format | ||
| 54 | DataFormat, // rhs format | ||
| 55 | DataFormat, // bias format | ||
| 56 | float // clamp_keep_ratio | ||
| 57 | >; | ||
| 58 | |||
| 59 | struct Bf16Qai8Qsi4CacheData { | ||
| 60 | Buffer ref_dst_nt_t; | ||
| 61 | Buffer ref_dst_nt_nt; | ||
| 62 | Buffer ref_rhs_qsi4_nt_t; | ||
| 63 | Buffer ref_rhs_qsi4_nt_nt; | ||
| 64 | Buffer ref_rhs_scales; | ||
| 65 | Buffer ref_lhs_bf16; | ||
| 66 | Buffer ref_biases_buf; | ||
| 67 | Range<float> clamp_nt_nt; | ||
| 68 | Range<float> clamp_nt_t; | ||
| 69 | }; | ||
| 70 | |||
| 71 | template <> | ||
| 72 | 168 | Bf16Qai8Qsi4CacheData ReferenceGenerator<Bf16Qai8Qsi4CacheDataId, Bf16Qai8Qsi4CacheData>::generate_reference( | |
| 73 | const Bf16Qai8Qsi4CacheDataId& data_id) { | ||
| 74 | 2184 | auto [shape, lhs_format, rhs_format, bias_format, clamp_keep_ratio] = data_id; | |
| 75 | |||
| 76 | 336 | size_t M = shape.m; | |
| 77 | 336 | size_t N = shape.n; | |
| 78 | 336 | size_t K = shape.k; | |
| 79 | |||
| 80 | // Seed the random generator. | ||
| 81 |
8/16✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 168 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 168 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 168 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 84 times.
✗ Branch 15 not taken.
|
504 | const auto key = std::string("Bf16Qai8Qsi4_cache:") + std::to_string(M) + "x" + std::to_string(N) + "x" + |
| 82 |
8/16✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 168 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 168 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 168 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 84 times.
✗ Branch 15 not taken.
|
504 | std::to_string(K) + ":" + std::to_string(static_cast<uint32_t>(lhs_format.data_type())) + ":" + |
| 83 |
5/10✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 84 times.
✗ Branch 9 not taken.
|
588 | std::to_string(static_cast<uint32_t>(rhs_format.data_type())) + ":" + |
| 84 |
7/14✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 168 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 168 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 84 times.
✗ Branch 13 not taken.
|
420 | std::to_string(static_cast<uint32_t>(bias_format.data_type())) + ":" + std::to_string(clamp_keep_ratio); |
| 85 |
1/2✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
|
168 | auto& feed = seed_stream(key); |
| 86 | |||
| 87 |
2/4✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
|
336 | bool has_bias = bias_format.data_type() != DataType::UNKNOWN; |
| 88 |
5/10✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 168 times.
✗ Branch 9 not taken.
|
672 | Buffer lhs = fill_matrix_random(shape.m, shape.k, lhs_format, feed()); |
| 89 |
5/10✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 168 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 168 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 168 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 168 times.
✗ Branch 9 not taken.
|
672 | Buffer ref_rhs = fill_matrix_random(shape.n, shape.k, rhs_format, feed()); |
| 90 |
5/8✓ Branch 0 taken 84 times.
✓ Branch 1 taken 84 times.
✓ Branch 2 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 84 times.
✗ Branch 7 not taken.
|
168 | Buffer bias = has_bias ? fill_matrix_random(1, shape.n, bias_format, feed()) : Buffer(); |
| 91 | |||
| 92 | 168 | Bf16Qai8Qsi4CacheData out; | |
| 93 | // For reference implementation, Casting BF16 input to FP32 type and FP32 output back to BFP16 because the matmul | ||
| 94 | // implementation works with FP32 accumulation and casts the result to BFP16 | ||
| 95 |
1/2✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
|
336 | const auto ref_lhs = cast<float, BFloat16<false>>( |
| 96 |
1/2✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
|
168 | lhs.data(), // |
| 97 |
1/2✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
|
168 | lhs.size() * 8 / size_in_bits<BFloat16<false>>); |
| 98 | |||
| 99 | // Transposed(nxk) RHS dimensions | ||
| 100 | 168 | const size_t ref_rhs_qsi4_nxk_stride = K; | |
| 101 | |||
| 102 | // Non-Transposed(kxn) RHS dimensions | ||
| 103 |
1/2✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
|
168 | const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2); |
| 104 |
1/2✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
|
168 | const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(K * ref_rhs_qsi4_kxn_stride, 2); |
| 105 | |||
| 106 |
1/2✓ Branch 0 taken 168 times.
✗ Branch 1 not taken.
|
168 | QuantizationInfo lhs_qinfo{}; |
| 107 | lhs_qinfo.quant_width = K; | ||
| 108 | lhs_qinfo.dst_type = DataType::QAI8; | ||
| 109 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 110 | lhs_qinfo.zero_point_type = DataType::I32; | ||
| 111 | const auto [ref_lhs_quant, lhs_qoutputs] = quantize_dynamic(ref_lhs.data(), DataType::FP32, M, K, lhs_qinfo); | ||
| 112 | |||
| 113 | QuantizationInfo rhs_qinfo{}; | ||
| 114 | rhs_qinfo.quant_width = K; | ||
| 115 | rhs_qinfo.dst_type = DataType::QSI4; | ||
| 116 | rhs_qinfo.scale_type = DataType::FP32; | ||
| 117 | auto [ref_rhs_quant_t, rhs_qoutputs] = quantize_dynamic(ref_rhs.data(), DataType::FP32, N, K, rhs_qinfo); | ||
| 118 | |||
| 119 | auto ref_rhs_qsi4 = transpose_with_padding<Int4>( | ||
| 120 | ref_rhs_quant_t.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, ref_rhs_qsi4_kxn_size_bytes); | ||
| 121 | |||
| 122 | const auto ref_dst_nt_nt = matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( | ||
| 123 | M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), K, | ||
| 124 | ref_rhs_qsi4.data(), rhs_qoutputs.scales.data(), nullptr, K, has_bias ? bias.data() : nullptr, | ||
| 125 | std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | ||
| 126 | |||
| 127 | const auto [clamp_min_nt_nt, clamp_max_nt_nt] = | ||
| 128 | find_clamp_range<float>(ref_dst_nt_nt.data(), M * N, clamp_keep_ratio); | ||
| 129 | out.ref_rhs_qsi4_nt_nt = std::move(ref_rhs_qsi4); | ||
| 130 | |||
| 131 | const auto ref_dst_float_nt_nt = clamp<float>(ref_dst_nt_nt.data(), M * N, clamp_min_nt_nt, clamp_max_nt_nt); | ||
| 132 | |||
| 133 | auto ref_dst_nt_nt_bf16 = | ||
| 134 | cast<BFloat16<false>, float>(ref_dst_float_nt_nt.data(), ref_dst_float_nt_nt.size() * 8 / size_in_bits<float>); | ||
| 135 | out.ref_dst_nt_nt = std::move(ref_dst_nt_nt_bf16); | ||
| 136 | |||
| 137 | out.clamp_nt_nt = {clamp_min_nt_nt, clamp_max_nt_nt}; | ||
| 138 | |||
| 139 | const auto ref_dst_nt_t = | ||
| 140 | matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>( | ||
| 141 | M, N, K, ref_lhs_quant.data(), lhs_qoutputs.scales.data(), lhs_qoutputs.zero_points.data(), 1, K, | ||
| 142 | ref_rhs_quant_t.data(), rhs_qoutputs.scales.data(), nullptr, 1, K, has_bias ? bias.data() : nullptr, | ||
| 143 | nullptr, nullptr, 1); | ||
| 144 | |||
| 145 | const auto [clamp_min_nt_t, clamp_max_nt_t] = find_clamp_range<float>(ref_dst_nt_t.data(), M * N, clamp_keep_ratio); | ||
| 146 | out.ref_rhs_qsi4_nt_t = std::move(ref_rhs_quant_t); | ||
| 147 | const auto ref_dst_nt_t_float = clamp<float>(ref_dst_nt_t.data(), M * N, clamp_min_nt_t, clamp_max_nt_t); | ||
| 148 | |||
| 149 | auto ref_dst_nt_t_bf16 = | ||
| 150 | cast<BFloat16<false>, float>(ref_dst_nt_t_float.data(), ref_dst_nt_t_float.size() * 8 / size_in_bits<float>); | ||
| 151 | out.ref_dst_nt_t = std::move(ref_dst_nt_t_bf16); | ||
| 152 | out.clamp_nt_t = {clamp_min_nt_t, clamp_max_nt_t}; | ||
| 153 | out.ref_lhs_bf16 = std::move(lhs); | ||
| 154 | out.ref_biases_buf = std::move(bias); | ||
| 155 | out.ref_rhs_scales = std::move(rhs_qoutputs.scales); | ||
| 156 | return out; | ||
| 157 | ✗ | } | |
| 158 | |||
| 159 | 3 | static const std::array<UkernelVariant<kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel>, 2> | |
| 160 |
0/4✗ Branch 0 not taken.
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 1 not taken.
|
5 | variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp = {{ |
| 161 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod), |
| 162 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", cpu_has_dotprod_and_bf16}, |
| 163 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
6 | {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm), |
| 164 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
3 | "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", cpu_has_i8mm_and_bf16}, |
| 165 | }}; | ||
| 166 | |||
| 167 | using MatMulClampTestPortionedParamsWithBias = std::tuple<size_t, MatMulShape, MatrixPortion, float, bool>; | ||
| 168 | class MatMulTest_bf16_qai8dxp_qsi4cxp : public ::testing::TestWithParam<MatMulClampTestPortionedParamsWithBias> {}; | ||
| 169 | |||
| 170 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
5886 | TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_NxK) { |
| 171 | 21168 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio, has_bias] = GetParam(); | |
| 172 | 4704 | const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.at(variant_index); | |
| 173 | |||
| 174 |
2/4✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
|
2352 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 175 | ✗ | GTEST_SKIP() << "CPU features are not supported by current CPU"; | |
| 176 | } | ||
| 177 | |||
| 178 | 4704 | const size_t M = matmul_shape.m; | |
| 179 | 4704 | const size_t N = matmul_shape.n; | |
| 180 | 4704 | const size_t K = matmul_shape.k; | |
| 181 | |||
| 182 | 2352 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 183 | 2352 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 184 | 2352 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 185 | 2352 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 186 | |||
| 187 | 2352 | const auto lhs_format = DataFormat(DataType::BF16); | |
| 188 | 2352 | const auto rhs_format = DataFormat(DataType::FP32); | |
| 189 |
4/4✓ Branch 0 taken 1176 times.
✓ Branch 1 taken 1176 times.
✓ Branch 2 taken 1176 times.
✓ Branch 3 taken 1176 times.
|
4704 | const auto bias_format = has_bias ? DataFormat(DataType::FP32) : DataFormat(DataType::UNKNOWN); |
| 190 | |||
| 191 | 2352 | auto m_step = ukernel_variant.interface.get_m_step(); | |
| 192 |
3/14✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2352 times.
|
2352 | ASSERT_TRUE(m_step % mr == 0); |
| 193 | |||
| 194 | 2352 | auto n_step = ukernel_variant.interface.get_n_step(); | |
| 195 |
3/14✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2352 times.
|
2352 | ASSERT_TRUE(n_step % nr == 0); |
| 196 | |||
| 197 | 4704 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 198 |
2/4✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2352 times.
|
2352 | if (rect.height() == 0 || rect.width() == 0) { |
| 199 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 200 | } | ||
| 201 | |||
| 202 | 4704 | const Bf16Qai8Qsi4CacheDataId testdata_id = {matmul_shape, lhs_format, rhs_format, bias_format, clamp_keep_ratio}; | |
| 203 | 2352 | const Bf16Qai8Qsi4CacheData& testdata = getV<Bf16Qai8Qsi4CacheDataId, Bf16Qai8Qsi4CacheData>(testdata_id); | |
| 204 | |||
| 205 | 2352 | const auto& ref_lhs_bf16 = testdata.ref_lhs_bf16; | |
| 206 | 2352 | const auto& ref_rhs_qsi4 = testdata.ref_rhs_qsi4_nt_t; | |
| 207 | 2352 | const auto& ref_biases_buf = testdata.ref_biases_buf; | |
| 208 | 2352 | const auto& ref_rhs_scales = testdata.ref_rhs_scales; | |
| 209 | 2352 | const auto& ref_dst = testdata.ref_dst_nt_t; | |
| 210 | 7056 | auto [clamp_min, clamp_max] = testdata.clamp_nt_t; | |
| 211 | 2352 | const auto lhs_start_row = rect.start_row(); | |
| 212 | 2352 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); | |
| 213 | 2352 | Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size); | |
| 214 | |||
| 215 | 2352 | auto lhs_stride = K * sizeof(uint16_t); | |
| 216 | |||
| 217 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); |
| 218 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); |
| 219 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
| 220 | |||
| 221 |
4/16✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 2352 times.
✗ Branch 15 not taken.
|
2352 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
| 222 | |||
| 223 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | kai_run_lhs_quant_pack_qai8dxp_bf16_neon( |
| 224 |
2/4✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
|
2352 | rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride, |
| 225 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | reinterpret_cast<uint8_t*>(imp_packed_lhs_buf.data()) + lhs_packed_offset); |
| 226 | |||
| 227 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
4704 | const auto ref_rhs_qsi4_padded = pad_row<Int4>( |
| 228 |
4/8✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2352 times.
✗ Branch 7 not taken.
|
2352 | ref_rhs_qsi4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2)); |
| 229 | |||
| 230 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr); |
| 231 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | Buffer imp_packed_rhs_buf = Buffer(imp_packed_rhs_size); |
| 232 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | const auto rhs_start_row = rect.start_col(); |
| 233 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr); |
| 234 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); |
| 235 |
4/16✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
|
2352 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
| 236 | // Runs the RHS packing micro-kernel. | ||
| 237 | 2352 | kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{}; | |
| 238 | 2352 | params.lhs_zero_point = 1; | |
| 239 | 2352 | params.rhs_zero_point = 0; | |
| 240 | |||
| 241 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( |
| 242 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()), |
| 243 |
3/4✓ Branch 0 taken 1176 times.
✓ Branch 1 taken 1176 times.
✓ Branch 2 taken 1176 times.
✗ Branch 3 not taken.
|
2352 | has_bias ? reinterpret_cast<const float*>(ref_biases_buf.data()) : nullptr, |
| 244 |
2/4✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
|
2352 | reinterpret_cast<const float*>(ref_rhs_scales.data()), reinterpret_cast<uint8_t*>(imp_packed_rhs_buf.data()), 0, |
| 245 | ¶ms); | ||
| 246 | |||
| 247 | 2352 | const auto dst_stride_row = N * sizeof(uint16_t); | |
| 248 | 2352 | const auto dst_stride_col = sizeof(uint16_t); | |
| 249 | 4704 | const auto dst_offset = | |
| 250 |
3/6✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
|
2352 | ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); |
| 251 |
2/4✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
|
2352 | const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; |
| 252 |
4/16✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
|
2352 | ASSERT_EQ(dst_offset, ref_dst_offset); |
| 253 | |||
| 254 | // Runs the GEMM micro-kernel. | ||
| 255 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
| 256 |
5/18✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2352 times.
|
2352 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
| 257 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | Buffer imp_dst_buf = Buffer(imp_dst_size); |
| 258 | |||
| 259 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
4704 | ukernel_variant.interface.run_matmul( |
| 260 |
3/6✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
|
2352 | rect.height(), rect.width(), K, reinterpret_cast<const uint8_t*>(imp_packed_lhs_buf.data()) + lhs_matmul_offset, |
| 261 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | reinterpret_cast<const uint8_t*>(imp_packed_rhs_buf.data()) + rhs_matmul_offset, |
| 262 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | reinterpret_cast<uint8_t*>(imp_dst_buf.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min, |
| 263 | 2352 | clamp_max); | |
| 264 | |||
| 265 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
| 266 | // tested. | ||
| 267 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | DefaultMismatchHandler handler(0, 0.02, 0, 0.05); |
| 268 | 2352 | DataFormat dst_format = DataFormat(DataType::BF16); | |
| 269 | 4704 | const auto success = | |
| 270 |
3/6✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
|
2352 | compare(reinterpret_cast<const uint8_t*>(imp_dst_buf.data()), ref_dst.data(), dst_format, M, N, rect, handler); |
| 271 |
4/16✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
|
2352 | ASSERT_TRUE(success); |
| 272 | 2352 | } | |
| 273 | |||
| 274 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
5886 | TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_KxN) { |
| 275 | 21168 | const auto& [variant_index, matmul_shape, portion, clamp_keep_ratio, has_bias] = GetParam(); | |
| 276 | 4704 | const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.at(variant_index); | |
| 277 | |||
| 278 |
2/4✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
|
2352 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
| 279 | ✗ | GTEST_SKIP() << "CPU features are not supported by current CPU"; | |
| 280 | } | ||
| 281 | |||
| 282 | 4704 | const size_t M = matmul_shape.m; | |
| 283 | 4704 | const size_t N = matmul_shape.n; | |
| 284 | 4704 | const size_t K = matmul_shape.k; | |
| 285 | |||
| 286 | 2352 | const auto mr = ukernel_variant.interface.get_mr(); | |
| 287 | 2352 | const auto nr = ukernel_variant.interface.get_nr(); | |
| 288 | 2352 | const auto kr = ukernel_variant.interface.get_kr(); | |
| 289 | 2352 | const auto sr = ukernel_variant.interface.get_sr(); | |
| 290 | |||
| 291 | 2352 | const auto lhs_format = DataFormat(DataType::BF16); | |
| 292 | 2352 | const auto rhs_format = DataFormat(DataType::FP32); | |
| 293 |
4/4✓ Branch 0 taken 1176 times.
✓ Branch 1 taken 1176 times.
✓ Branch 2 taken 1176 times.
✓ Branch 3 taken 1176 times.
|
4704 | const auto bias_format = has_bias ? DataFormat(DataType::FP32) : DataFormat(DataType::UNKNOWN); |
| 294 | |||
| 295 | // Generates input data. | ||
| 296 | 4704 | const Bf16Qai8Qsi4CacheDataId testdata_id = {matmul_shape, lhs_format, rhs_format, bias_format, clamp_keep_ratio}; | |
| 297 | 2352 | const Bf16Qai8Qsi4CacheData& testdata = getV<Bf16Qai8Qsi4CacheDataId, Bf16Qai8Qsi4CacheData>(testdata_id); | |
| 298 | |||
| 299 | 2352 | const auto& ref_lhs_bf16 = testdata.ref_lhs_bf16; | |
| 300 | 2352 | const auto& ref_rhs_qsi4 = testdata.ref_rhs_qsi4_nt_nt; | |
| 301 | 2352 | const auto& ref_biases_buf = testdata.ref_biases_buf; | |
| 302 | 2352 | const auto& ref_rhs_scales = testdata.ref_rhs_scales; | |
| 303 | 2352 | const auto& ref_dst = testdata.ref_dst_nt_nt; | |
| 304 | 7056 | auto [clamp_min, clamp_max] = testdata.clamp_nt_nt; | |
| 305 | 2352 | auto m_step = ukernel_variant.interface.get_m_step(); | |
| 306 |
3/14✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2352 times.
|
2352 | ASSERT_TRUE(m_step % mr == 0); |
| 307 | |||
| 308 | 2352 | auto n_step = ukernel_variant.interface.get_n_step(); | |
| 309 |
3/14✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 2352 times.
|
2352 | ASSERT_TRUE(n_step % nr == 0); |
| 310 | |||
| 311 | 4704 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
| 312 |
2/4✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2352 times.
|
2352 | if (rect.height() == 0 || rect.width() == 0) { |
| 313 | ✗ | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; | |
| 314 | } | ||
| 315 | |||
| 316 | 2352 | const auto lhs_start_row = rect.start_row(); | |
| 317 | 2352 | size_t lhs_stride = K * sizeof(uint16_t); | |
| 318 | |||
| 319 | // Runs the LHS packing micro-kernel. | ||
| 320 | 2352 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); | |
| 321 | 2352 | Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size); | |
| 322 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); |
| 323 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); |
| 324 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
| 325 |
4/16✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 2352 times.
✗ Branch 15 not taken.
|
2352 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
| 326 | |||
| 327 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | kai_run_lhs_quant_pack_qai8dxp_bf16_neon( |
| 328 |
2/4✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
|
2352 | rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, ref_lhs_bf16.data() + lhs_offset, lhs_stride, |
| 329 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | reinterpret_cast<uint8_t*>(imp_packed_lhs_buf.data()) + lhs_packed_offset); |
| 330 | |||
| 331 | // Runs the RHS packing micro-kernel. | ||
| 332 | // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. | ||
| 333 | // * Packs the RHS matrix. | ||
| 334 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
4704 | const auto ref_rhs_qsi4_padded = pad_row<Int4>( |
| 335 |
4/8✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2352 times.
✗ Branch 7 not taken.
|
2352 | ref_rhs_qsi4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2)); |
| 336 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr); |
| 337 | |||
| 338 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | const auto rhs_start_row = rect.start_col(); |
| 339 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr); |
| 340 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); |
| 341 |
4/16✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
|
2352 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
| 342 | |||
| 343 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | Buffer imp_packed_rhs_buf = Buffer(imp_packed_rhs_size); |
| 344 | 2352 | kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{}; | |
| 345 | 2352 | params.lhs_zero_point = 1; | |
| 346 | 2352 | params.rhs_zero_point = 0; | |
| 347 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0( |
| 348 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()), |
| 349 |
3/4✓ Branch 0 taken 1176 times.
✓ Branch 1 taken 1176 times.
✓ Branch 2 taken 1176 times.
✗ Branch 3 not taken.
|
2352 | has_bias ? reinterpret_cast<const float*>(ref_biases_buf.data()) : nullptr, |
| 350 |
2/4✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
|
2352 | reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs_buf.data(), 0, &params); |
| 351 | |||
| 352 | 2352 | const auto dst_stride = N * sizeof(uint16_t); | |
| 353 |
3/6✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
|
2352 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
| 354 |
2/4✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
|
2352 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(uint16_t); |
| 355 |
4/16✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
|
2352 | ASSERT_EQ(dst_offset, ref_dst_offset); |
| 356 | |||
| 357 | // Runs the GEMM micro-kernel. | ||
| 358 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
| 359 |
5/18✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2352 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 2352 times.
|
2352 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
| 360 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | Buffer imp_dst_buf = Buffer(imp_dst_size); |
| 361 | |||
| 362 | 2352 | const auto dst_stride_row = N * sizeof(uint16_t); | |
| 363 | 2352 | const auto dst_stride_col = sizeof(uint16_t); | |
| 364 | |||
| 365 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
4704 | ukernel_variant.interface.run_matmul( |
| 366 |
3/6✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
|
2352 | rect.height(), rect.width(), K, reinterpret_cast<const uint8_t*>(imp_packed_lhs_buf.data()) + lhs_matmul_offset, |
| 367 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | reinterpret_cast<const uint8_t*>(imp_packed_rhs_buf.data()) + rhs_matmul_offset, |
| 368 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | reinterpret_cast<uint8_t*>(imp_dst_buf.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min, |
| 369 | 2352 | clamp_max); | |
| 370 | |||
| 371 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
| 372 | // tested. | ||
| 373 |
1/2✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
|
2352 | DefaultMismatchHandler handler(0, 0.02, 0, 0.05); |
| 374 | 2352 | DataFormat dst_format = DataFormat(DataType::BF16); | |
| 375 | 4704 | const auto success = | |
| 376 |
3/6✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
|
2352 | compare(reinterpret_cast<const uint8_t*>(imp_dst_buf.data()), ref_dst.data(), dst_format, M, N, rect, handler); |
| 377 |
4/16✓ Branch 0 taken 2352 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2352 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2352 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2352 times.
|
2352 | ASSERT_TRUE(success); |
| 378 | 2352 | } | |
| 379 | |||
| 380 |
34/120✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✓ Branch 18 taken 4 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✓ Branch 20 taken 4 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✓ Branch 24 taken 4 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✓ Branch 26 taken 4 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✓ Branch 28 taken 4 times.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
✓ Branch 30 taken 4 times.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 2 times.
✓ Branch 32 taken 4 times.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✓ Branch 34 taken 4 times.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 4 times.
✓ Branch 36 taken 2352 times.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✓ Branch 38 taken 4704 times.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✗ Branch 59 not taken.
✗ Branch 60 not taken.
✗ Branch 60 not taken.
✗ Branch 61 not taken.
✗ Branch 61 not taken.
✗ Branch 62 not taken.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✗ Branch 63 not taken.
✗ Branch 64 not taken.
✗ Branch 65 not taken.
|
14121 | INSTANTIATE_TEST_SUITE_P( |
| 381 | MatMul, MatMulTest_bf16_qai8dxp_qsi4cxp, | ||
| 382 | testing::Combine( | ||
| 383 | testing::Range<size_t>(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.size()), | ||
| 384 | testing::Values( | ||
| 385 | MatMulShape{1, 2, 32}, // | ||
| 386 | MatMulShape{1, 3, 32}, // | ||
| 387 | MatMulShape{1, 4, 32}, // | ||
| 388 | MatMulShape{1, 5, 32}, // | ||
| 389 | MatMulShape{3, 3, 32}, // | ||
| 390 | MatMulShape{4, 4, 32}, // | ||
| 391 | MatMulShape{5, 5, 32}, // | ||
| 392 | MatMulShape{32, 64, 64}, // | ||
| 393 | MatMulShape{16, 32, 64}, // | ||
| 394 | MatMulShape{8, 32, 64}, // | ||
| 395 | MatMulShape{15, 32, 32}, // | ||
| 396 | MatMulShape{77, 99, 64}, // | ||
| 397 | MatMulShape{77, 99, 66}, // | ||
| 398 | MatMulShape{77, 99, 31}), | ||
| 399 | testing::Values( | ||
| 400 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
| 401 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
| 402 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
| 403 | MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle | ||
| 404 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
| 405 | MatrixPortion(0.75, 0, 1, 1), // Partial rows | ||
| 406 | MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle | ||
| 407 | ), | ||
| 408 | testing::ValuesIn(std::initializer_list<float>{1.0f, 0.9f, 0.5f}), // | ||
| 409 | testing::Bool()), | ||
| 410 | [](const auto& info) -> std::string { | ||
| 411 | const auto variant_idx = std::get<0>(info.param); | ||
| 412 | const auto& name = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp[variant_idx].name; | ||
| 413 | return test_description( | ||
| 414 | name, std::get<MatMulShape>(info.param), std::get<2>(info.param), std::get<4>(info.param), | ||
| 415 | std::get<3>(info.param)); | ||
| 416 | }); | ||
| 417 | |||
| 418 | } // namespace kai::test | ||
| 419 |