Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <cstdlib> | ||
13 | #include <limits> | ||
14 | #include <sstream> | ||
15 | #include <string> | ||
16 | #include <tuple> | ||
17 | #include <vector> | ||
18 | |||
19 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h" | ||
20 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h" | ||
21 | #include "kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h" | ||
22 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h" | ||
23 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h" | ||
24 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" | ||
25 | #include "test/common/bfloat16.hpp" | ||
26 | #include "test/common/buffer.hpp" | ||
27 | #include "test/common/compare.hpp" | ||
28 | #include "test/common/cpu_info.hpp" | ||
29 | #include "test/common/data_format.hpp" | ||
30 | #include "test/common/int4.hpp" | ||
31 | #include "test/common/matmul_test_common.hpp" | ||
32 | #include "test/common/matrix_portion.hpp" | ||
33 | #include "test/common/memory.hpp" | ||
34 | #include "test/common/round.hpp" | ||
35 | #include "test/common/test_suite.hpp" | ||
36 | #include "test/reference/cast.hpp" | ||
37 | #include "test/reference/clamp.hpp" | ||
38 | #include "test/reference/fill.hpp" | ||
39 | #include "test/reference/matmul.hpp" | ||
40 | #include "test/reference/pack.hpp" | ||
41 | #include "test/reference/pad.hpp" | ||
42 | #include "test/reference/quantize.hpp" | ||
43 | #include "test/reference/transpose.hpp" | ||
44 | |||
45 | // Uses the truncating BFloat16 implementation (BFloat16<false>) to match existing packing/inference | ||
46 | |||
47 | namespace kai::test { | ||
48 | |||
49 | static const std::array<UkernelVariant<kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel>, 2> | ||
50 | variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp = {{ | ||
51 | {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod), | ||
52 | "kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod", cpu_has_dotprod_and_bf16}, | ||
53 | {UKERNEL_MATMUL_VARIANT(clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm), | ||
54 | "kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm", cpu_has_i8mm_and_bf16}, | ||
55 | }}; | ||
56 | |||
57 | class MatMulTest_bf16_qai8dxp_qsi4cxp : public ::testing::TestWithParam<MatMulTestPortionedParamsWithBias> {}; | ||
58 | |||
59 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1178 | TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_NxK) { |
60 | 3468 | const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam(); | |
61 | 784 | const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.at(variant_index); | |
62 | |||
63 |
2/4✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
|
392 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
64 | ✗ | GTEST_SKIP() << "CPU features are not supported by current CPU"; | |
65 | } | ||
66 | |||
67 | 392 | const std::uint32_t seed = 0; | |
68 | |||
69 | 784 | const size_t M = matmul_shape.m; | |
70 | 784 | const size_t N = matmul_shape.n; | |
71 | 784 | const size_t K = matmul_shape.k; | |
72 | |||
73 | 392 | const auto mr = ukernel_variant.interface.get_mr(); | |
74 | 392 | const auto nr = ukernel_variant.interface.get_nr(); | |
75 | 392 | const auto kr = ukernel_variant.interface.get_kr(); | |
76 | 392 | const auto sr = ukernel_variant.interface.get_sr(); | |
77 | |||
78 | 392 | auto m_step = ukernel_variant.interface.get_m_step(); | |
79 |
3/14✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 392 times.
|
392 | ASSERT_TRUE(m_step % mr == 0); |
80 | |||
81 | 392 | auto n_step = ukernel_variant.interface.get_n_step(); | |
82 |
3/14✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 392 times.
|
392 | ASSERT_TRUE(n_step % nr == 0); |
83 | |||
84 | 784 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
85 |
4/4✓ Branch 0 taken 376 times.
✓ Branch 1 taken 16 times.
✓ Branch 2 taken 4 times.
✓ Branch 3 taken 372 times.
|
392 | if (rect.height() == 0 || rect.width() == 0) { |
86 |
9/18✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 20 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 20 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 20 times.
✗ Branch 17 not taken.
|
20 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
87 | } | ||
88 | |||
89 | // Generates input data. | ||
90 | 372 | const auto ref_lhs_bf16 = fill_random<BFloat16<false>>(M * K, seed + 0); | |
91 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto ref_rhs = fill_random<float>(N * K, seed + 1); |
92 | |||
93 | 372 | Buffer ref_biases_buf; | |
94 |
2/2✓ Branch 0 taken 186 times.
✓ Branch 1 taken 186 times.
|
372 | if (has_bias) { |
95 |
1/2✓ Branch 0 taken 186 times.
✗ Branch 1 not taken.
|
186 | ref_biases_buf = Buffer(fill_random<float>(N, seed + 2)); |
96 | 186 | } | |
97 | |||
98 | // For the reference implementation, cast the BF16 input to FP32 and the FP32 output back to BF16, because the matmul | ||
99 | // implementation works with FP32 accumulation and casts the result to BF16 | ||
100 | 372 | const auto ref_lhs = | |
101 |
3/6✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
|
372 | cast<float, BFloat16<false>>(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits<BFloat16<false>>); |
102 | |||
103 | // Runs the reference implementation. | ||
104 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
105 | // * Quantizes the RHS matrix using 4-bit symmetric quantization. | ||
106 | // * Performs GEMM. | ||
107 | 1116 | const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
108 |
2/4✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
|
372 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K); |
109 | 1860 | const auto [ref_rhs_qsi4, ref_rhs_scales] = | |
110 |
2/4✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
|
372 | quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K); |
111 | 372 | const auto ref_dst_no_clamp = | |
112 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | matmul_nt_t_quantized<int8_t, float, int32_t, Int4, float, int32_t, float, float, int32_t, float>( |
113 |
4/8✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
|
744 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K, |
114 |
7/10✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 186 times.
✓ Branch 5 taken 186 times.
✓ Branch 6 taken 186 times.
✓ Branch 7 taken 186 times.
✓ Branch 8 taken 186 times.
✗ Branch 9 not taken.
|
372 | ref_rhs_qsi4.data(), ref_rhs_scales.data(), nullptr, 1, K, has_bias ? ref_biases_buf.data() : nullptr, |
115 | nullptr, nullptr, 1); | ||
116 | |||
117 | // Clamps the reference output. | ||
118 | 372 | const auto clamp_ratio = 0.8F; | |
119 |
2/4✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
|
1116 | const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), M * N, clamp_ratio); |
120 |
4/8✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
|
372 | const auto ref_dst_float = clamp<float>(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max); |
121 | |||
122 | // Cast the reference output to BF16 | ||
123 |
3/6✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
|
372 | auto ref_dst = cast<BFloat16<false>, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>); |
124 | |||
125 | // Runs the LHS packing micro-kernel. | ||
126 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto lhs_start_row = rect.start_row(); |
127 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); |
128 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size); |
129 | |||
130 | 372 | auto lhs_stride = K * sizeof(uint16_t); | |
131 | |||
132 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); |
133 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); |
134 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
135 | |||
136 |
4/16✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 372 times.
✗ Branch 15 not taken.
|
372 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
137 | |||
138 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | kai_run_lhs_quant_pack_qai8dxp_bf16_neon( |
139 |
2/4✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
|
372 | rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_bf16.data() + lhs_offset, lhs_stride, |
140 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | reinterpret_cast<uint8_t*>(imp_packed_lhs_buf.data()) + lhs_packed_offset); |
141 | |||
142 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
744 | const auto ref_rhs_qsi4_padded = pad_row<Int4>( |
143 |
4/8✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
|
372 | ref_rhs_qsi4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2)); |
144 | |||
145 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr); |
146 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | Buffer imp_packed_rhs_buf = Buffer(imp_packed_rhs_size); |
147 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto rhs_start_row = rect.start_col(); |
148 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr); |
149 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); |
150 |
4/16✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
|
372 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
151 | // Runs the RHS packing micro-kernel. | ||
152 | 372 | kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{}; | |
153 | 372 | params.lhs_zero_point = 1; | |
154 | 372 | params.rhs_zero_point = 0; | |
155 | |||
156 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( |
157 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()), |
158 |
3/4✓ Branch 0 taken 186 times.
✓ Branch 1 taken 186 times.
✓ Branch 2 taken 186 times.
✗ Branch 3 not taken.
|
372 | has_bias ? reinterpret_cast<const float*>(ref_biases_buf.data()) : nullptr, |
159 |
2/4✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
|
372 | reinterpret_cast<const float*>(ref_rhs_scales.data()), reinterpret_cast<uint8_t*>(imp_packed_rhs_buf.data()), 0, |
160 | &params); | ||
161 | |||
162 | 372 | const auto dst_stride_row = N * sizeof(uint16_t); | |
163 | 372 | const auto dst_stride_col = sizeof(uint16_t); | |
164 | 744 | const auto dst_offset = | |
165 |
3/6✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
|
372 | ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); |
166 |
2/4✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
|
372 | const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; |
167 |
4/16✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
|
372 | ASSERT_EQ(dst_offset, ref_dst_offset); |
168 | |||
169 | // Runs the GEMM micro-kernel. | ||
170 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
171 |
5/18✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 372 times.
|
372 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
172 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | Buffer imp_dst_buf = Buffer(imp_dst_size); |
173 | |||
174 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
744 | ukernel_variant.interface.run_matmul( |
175 |
3/6✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
|
372 | rect.height(), rect.width(), K, reinterpret_cast<const uint8_t*>(imp_packed_lhs_buf.data()) + lhs_matmul_offset, |
176 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | reinterpret_cast<const uint8_t*>(imp_packed_rhs_buf.data()) + rhs_matmul_offset, |
177 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | reinterpret_cast<uint8_t*>(imp_dst_buf.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min, |
178 | 372 | clamp_max); | |
179 | |||
180 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
181 | // tested. | ||
182 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | DefaultMismatchHandler handler(0, 0.02, 0, 0.05); |
183 | 372 | DataFormat dst_format = DataFormat(DataType::BF16); | |
184 | 744 | const auto success = | |
185 |
3/6✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
|
372 | compare(reinterpret_cast<const uint8_t*>(imp_dst_buf.data()), ref_dst.data(), dst_format, M, N, rect, handler); |
186 |
4/16✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
|
372 | ASSERT_TRUE(success); |
187 | 392 | } | |
188 | |||
189 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1178 | TEST_P(MatMulTest_bf16_qai8dxp_qsi4cxp, EndToEnd_RHS_KxN) { |
190 | 3116 | const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam(); | |
191 | 784 | const auto& ukernel_variant = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.at(variant_index); | |
192 | |||
193 |
2/4✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
|
392 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
194 | ✗ | GTEST_SKIP() << "CPU features are not supported by current CPU"; | |
195 | } | ||
196 | |||
197 | 392 | const uint32_t seed = 0; | |
198 | |||
199 | 784 | const size_t M = matmul_shape.m; | |
200 | 784 | const size_t N = matmul_shape.n; | |
201 | 784 | const size_t K = matmul_shape.k; | |
202 | |||
203 | 392 | const auto mr = ukernel_variant.interface.get_mr(); | |
204 | 392 | const auto nr = ukernel_variant.interface.get_nr(); | |
205 | 392 | const auto kr = ukernel_variant.interface.get_kr(); | |
206 | 392 | const auto sr = ukernel_variant.interface.get_sr(); | |
207 | |||
208 | // Generates input data. | ||
209 | 392 | const auto ref_lhs_bf16 = fill_random<BFloat16<false>>(M * K, seed + 0); | |
210 |
1/2✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
|
392 | const auto ref_rhs = fill_random<float>(N * K, seed + 1); |
211 | 392 | Buffer ref_biases_buf; | |
212 |
2/2✓ Branch 0 taken 196 times.
✓ Branch 1 taken 196 times.
|
392 | if (has_bias) { |
213 |
1/2✓ Branch 0 taken 196 times.
✗ Branch 1 not taken.
|
196 | ref_biases_buf = Buffer(fill_random<float>(N, seed + 2)); |
214 | 196 | } | |
215 | |||
216 | 392 | const auto ref_lhs = | |
217 |
3/6✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
|
392 | cast<float, BFloat16<false>>(ref_lhs_bf16.data(), ref_lhs_bf16.size() * 8 / size_in_bits<BFloat16<false>>); |
218 | |||
219 | // Transposed(nxk) RHS dimensions | ||
220 | 392 | const size_t ref_rhs_qsi4_nxk_stride = K; | |
221 | |||
222 | // Non-Transposed(kxn) RHS dimensions | ||
223 |
1/2✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
|
392 | const size_t ref_rhs_qsi4_kxn_stride = round_up_multiple(N, 2); |
224 |
1/2✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
|
392 | const size_t ref_rhs_qsi4_kxn_size_bytes = round_up_division(K * ref_rhs_qsi4_kxn_stride, 2); |
225 | |||
226 | // Runs the reference implementation. | ||
227 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
228 | // * Quantizes the RHS matrix using 4-bit symmetric quantization. | ||
229 | // * Performs GEMM. | ||
230 | 1176 | const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
231 |
2/4✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
|
392 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K); |
232 | 1940 | const auto [ref_rhs_qsi4_transposed, ref_rhs_scales] = | |
233 |
2/4✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
|
392 | quantize_symmetric_per_block_dynamic<float, Int4, float>(ref_rhs.data(), N, K, K); |
234 | |||
235 |
1/2✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
|
392 | const auto ref_rhs_qsi4 = transpose_with_padding<Int4>( |
236 |
1/2✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
|
392 | ref_rhs_qsi4_transposed.data(), N, K, ref_rhs_qsi4_nxk_stride, ref_rhs_qsi4_kxn_stride, |
237 | 392 | ref_rhs_qsi4_kxn_size_bytes); | |
238 | |||
239 | 392 | const auto ref_dst_fp32_clamp = | |
240 |
1/2✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
|
392 | matmul_clamp_nt_nt<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>( |
241 |
5/10✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 392 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 392 times.
✗ Branch 9 not taken.
|
784 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), |
242 |
6/8✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 196 times.
✓ Branch 3 taken 196 times.
✓ Branch 4 taken 196 times.
✓ Branch 5 taken 196 times.
✓ Branch 6 taken 196 times.
✗ Branch 7 not taken.
|
392 | ref_rhs_scales.data(), nullptr, K, has_bias ? ref_biases_buf.data() : nullptr, |
243 | 392 | std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | |
244 | |||
245 | // Clamps the reference output. | ||
246 | 392 | const auto clamp_ratio = 0.8F; | |
247 |
2/4✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
|
1136 | const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_fp32_clamp.data(), M * N, clamp_ratio); |
248 |
4/8✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 392 times.
✗ Branch 7 not taken.
|
392 | const auto ref_dst_float = clamp<float>(ref_dst_fp32_clamp.data(), M * N, clamp_min, clamp_max); |
249 | |||
250 | // Cast the reference output to BF16 | ||
251 |
3/6✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
|
392 | auto ref_dst = cast<BFloat16<false>, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>); |
252 | |||
253 |
1/2✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
|
392 | auto m_step = ukernel_variant.interface.get_m_step(); |
254 |
4/16✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 392 times.
|
392 | ASSERT_TRUE(m_step % mr == 0); |
255 | |||
256 |
1/2✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
|
392 | auto n_step = ukernel_variant.interface.get_n_step(); |
257 |
4/16✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 392 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 392 times.
|
392 | ASSERT_TRUE(n_step % nr == 0); |
258 | |||
259 |
2/4✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 392 times.
✗ Branch 3 not taken.
|
784 | const auto rect = portion.compute_portion(M, N, m_step, n_step); |
260 |
6/8✓ Branch 0 taken 392 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 376 times.
✓ Branch 3 taken 16 times.
✓ Branch 4 taken 376 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✓ Branch 7 taken 372 times.
|
392 | if (rect.height() == 0 || rect.width() == 0) { |
261 |
10/20✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 20 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 20 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 20 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 20 times.
✗ Branch 19 not taken.
|
20 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
262 | } | ||
263 | |||
264 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto lhs_start_row = rect.start_row(); |
265 | 372 | size_t lhs_stride = K * sizeof(uint16_t); | |
266 | |||
267 | // Runs the LHS packing micro-kernel. | ||
268 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon(M, K, mr, kr, sr); |
269 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | Buffer imp_packed_lhs_buf = Buffer(imp_packed_lhs_size); |
270 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, lhs_stride); |
271 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon(lhs_start_row, K, mr, kr, sr); |
272 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
273 |
4/16✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 372 times.
✗ Branch 15 not taken.
|
372 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
274 | |||
275 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | kai_run_lhs_quant_pack_qai8dxp_bf16_neon( |
276 |
2/4✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
|
372 | rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start*/, ref_lhs_bf16.data() + lhs_offset, lhs_stride, |
277 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | reinterpret_cast<uint8_t*>(imp_packed_lhs_buf.data()) + lhs_packed_offset); |
278 | |||
279 | // Runs the RHS packing micro-kernel. | ||
280 | // * Generates the 4-bit signed symmetric quantized input for the micro-kernel. | ||
281 | // * Packs the RHS matrix. | ||
282 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
744 | const auto ref_rhs_qsi4_padded = pad_row<Int4>( |
283 |
4/8✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
|
372 | ref_rhs_qsi4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2)); |
284 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(N, K, nr, kr, sr); |
285 | |||
286 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto rhs_start_row = rect.start_col(); |
287 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(rhs_start_row, K, nr, kr, sr); |
288 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); |
289 |
4/16✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
|
372 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
290 | |||
291 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | Buffer imp_packed_rhs_buf = Buffer(imp_packed_rhs_size); |
292 | 372 | kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{}; | |
293 | 372 | params.lhs_zero_point = 1; | |
294 | 372 | params.rhs_zero_point = 0; | |
295 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0( |
296 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | 1, N, K, nr, kr, sr, reinterpret_cast<const uint8_t*>(ref_rhs_qsi4_padded.data()), |
297 |
3/4✓ Branch 0 taken 186 times.
✓ Branch 1 taken 186 times.
✓ Branch 2 taken 186 times.
✗ Branch 3 not taken.
|
372 | has_bias ? reinterpret_cast<const float*>(ref_biases_buf.data()) : nullptr, |
298 |
2/4✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
|
372 | reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs_buf.data(), 0, &params); |
299 | |||
300 | 372 | const auto dst_stride = N * sizeof(uint16_t); | |
301 |
3/6✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
|
372 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
302 |
2/4✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
|
372 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(uint16_t); |
303 |
4/16✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
|
372 | ASSERT_EQ(dst_offset, ref_dst_offset); |
304 | |||
305 | // Runs the GEMM micro-kernel. | ||
306 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
307 |
5/18✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 372 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 372 times.
|
372 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
308 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | Buffer imp_dst_buf = Buffer(imp_dst_size); |
309 | |||
310 | 372 | const auto dst_stride_row = N * sizeof(uint16_t); | |
311 | 372 | const auto dst_stride_col = sizeof(uint16_t); | |
312 | |||
313 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
744 | ukernel_variant.interface.run_matmul( |
314 |
3/6✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
|
372 | rect.height(), rect.width(), K, reinterpret_cast<const uint8_t*>(imp_packed_lhs_buf.data()) + lhs_matmul_offset, |
315 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | reinterpret_cast<const uint8_t*>(imp_packed_rhs_buf.data()) + rhs_matmul_offset, |
316 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | reinterpret_cast<uint8_t*>(imp_dst_buf.data()) + dst_offset, dst_stride_row, dst_stride_col, clamp_min, |
317 | 372 | clamp_max); | |
318 | |||
319 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
320 | // tested. | ||
321 |
1/2✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
|
372 | DefaultMismatchHandler handler(0, 0.02, 0, 0.05); |
322 | 372 | DataFormat dst_format = DataFormat(DataType::BF16); | |
323 | 744 | const auto success = | |
324 |
3/6✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
|
372 | compare(reinterpret_cast<const uint8_t*>(imp_dst_buf.data()), ref_dst.data(), dst_format, M, N, rect, handler); |
325 |
4/16✓ Branch 0 taken 372 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 372 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 372 times.
|
372 | ASSERT_TRUE(success); |
326 | 392 | } | |
327 | |||
328 |
18/60✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 784 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
|
1571 | INSTANTIATE_TEST_SUITE_P(
329 | MatMul, MatMulTest_bf16_qai8dxp_qsi4cxp, | ||
// Instantiates the MatMulTest_bf16_qai8dxp_qsi4cxp suite under the "MatMul" prefix.
// testing::Combine yields every combination of: kernel-variant index, MatMulShape,
// MatrixPortion, and a bool (presumably the bias on/off flag used by the tests - confirm).
330 | testing::Combine( | ||
331 | testing::Range<size_t>(0, variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp.size()), | ||
332 | testing::Values( | ||
333 | MatMulShape{1, 2, 32}, // | ||
334 | MatMulShape{1, 3, 32}, // | ||
335 | MatMulShape{1, 4, 32}, // | ||
336 | MatMulShape{1, 5, 32}, // | ||
337 | MatMulShape{3, 3, 32}, // | ||
338 | MatMulShape{4, 4, 32}, // | ||
339 | MatMulShape{5, 5, 32}, // | ||
340 | MatMulShape{32, 64, 64}, // | ||
341 | MatMulShape{16, 32, 64}, // | ||
342 | MatMulShape{8, 32, 64}, // | ||
343 | MatMulShape{15, 32, 32}, // | ||
344 | MatMulShape{77, 99, 64}, // | ||
345 | MatMulShape{77, 99, 66}, // | ||
346 | MatMulShape{77, 99, 31}), | ||
// NOTE(review): the portion labels below read the arguments as fractional
// (start_row, start_col, height, width), which is what the original labels imply -
// confirm against the MatrixPortion definition.
347 | testing::Values( | ||
348 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
349 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion (first quarter of the columns, all rows). | ||
350 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion (starts at 3/4 of the columns, all rows). | ||
351 | MatrixPortion(0, 0.5, 1, 0.8), // Starts at the middle column, all rows. | ||
352 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
353 | MatrixPortion(0.75, 0, 1, 1), // Partial rows (starts at 3/4 of the rows, all columns). | ||
354 | MatrixPortion(0.4, 0.5, 0.6, 0.8) // Starts mid-matrix in both rows and columns. | ||
355 | ), | ||
356 | testing::Bool()), | ||
// Test-name generator: looks up the variant name via tuple element 0 and passes it,
// together with the MatMulShape element and tuple elements 2 and 3, to test_description.
357 | [](const auto& info) -> std::string { | ||
358 | const auto variant_idx = std::get<0>(info.param); | ||
359 | const auto& name = variants_kai_matmul_clamp_bf16_qai8dxp_qsi4cxp[variant_idx].name; | ||
360 | return test_description( | ||
361 | name, std::get<MatMulShape>(info.param), std::get<2>(info.param), std::get<3>(info.param)); | ||
362 | }); | ||
363 | |||
364 | } // namespace kai::test | ||
365 |