Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <cstdlib> | ||
13 | #include <limits> | ||
14 | #include <sstream> | ||
15 | #include <string> | ||
16 | #include <tuple> | ||
17 | |||
18 | #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" | ||
19 | #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" | ||
20 | #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h" | ||
21 | #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" | ||
22 | #include "kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp_qsi8cxp_interface.h" | ||
23 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f16_neon.h" | ||
24 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" | ||
25 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h" | ||
26 | #include "test/common/buffer.hpp" | ||
27 | #include "test/common/compare.hpp" | ||
28 | #include "test/common/cpu_info.hpp" | ||
29 | #include "test/common/data_format.hpp" | ||
30 | #include "test/common/float16.hpp" | ||
31 | #include "test/common/matmul_test_common.hpp" | ||
32 | #include "test/common/matrix_portion.hpp" | ||
33 | #include "test/common/memory.hpp" | ||
34 | #include "test/common/round.hpp" | ||
35 | #include "test/common/test_suite.hpp" | ||
36 | #include "test/reference/cast.hpp" | ||
37 | #include "test/reference/clamp.hpp" | ||
38 | #include "test/reference/fill.hpp" | ||
39 | #include "test/reference/matmul.hpp" | ||
40 | #include "test/reference/pack.hpp" | ||
41 | #include "test/reference/pad.hpp" | ||
42 | #include "test/reference/quantize.hpp" | ||
43 | |||
44 | namespace kai::test { | ||
45 | |||
46 | static const std::array<UkernelVariant<kai_matmul_clamp_f16_qai8dxp_qsi8cxp_ukernel>, 4> | ||
47 | variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp = {{ | ||
48 | {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod), | ||
49 | "kai_matmul_clamp_f16_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", cpu_has_dotprod_and_fp16}, | ||
50 | {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod), | ||
51 | "kai_matmul_clamp_f16_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", cpu_has_dotprod_and_fp16}, | ||
52 | {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod), | ||
53 | "kai_matmul_clamp_f16_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", cpu_has_dotprod_and_fp16}, | ||
54 | {UKERNEL_MATMUL_VARIANT(clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm), | ||
55 | "kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", cpu_has_i8mm_and_fp16}, | ||
56 | }}; | ||
57 | |||
58 | class MatMulTest_f16_qai8dxp_qsi8cxp : public ::testing::TestWithParam<MatMulTestPortionedParamsWithBias> {}; | ||
59 | |||
60 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
2690 | TEST_P(MatMulTest_f16_qai8dxp_qsi8cxp, EndToEnd) { |
61 | 6776 | const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam(); | |
62 | 1792 | const auto& ukernel_variant = variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp.at(variant_index); | |
63 | |||
64 |
2/4✓ Branch 0 taken 896 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 896 times.
✗ Branch 3 not taken.
|
896 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
65 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
66 | } | ||
67 | |||
68 | 896 | const std::uint32_t seed = 0; | |
69 | |||
70 | 1792 | const size_t M = matmul_shape.m; | |
71 | 1792 | const size_t N = matmul_shape.n; | |
72 | 1792 | const size_t K = matmul_shape.k; | |
73 | |||
74 | 896 | const auto mr = ukernel_variant.interface.get_mr(); | |
75 | 896 | const auto nr = ukernel_variant.interface.get_nr(); | |
76 | 896 | const auto kr = ukernel_variant.interface.get_kr(); | |
77 | 896 | const auto sr = ukernel_variant.interface.get_sr(); | |
78 | |||
79 |
4/4✓ Branch 0 taken 448 times.
✓ Branch 1 taken 448 times.
✓ Branch 2 taken 168 times.
✓ Branch 3 taken 280 times.
|
896 | if (mr == 1 && M > 1) { |
80 |
3/6✓ Branch 0 taken 280 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 280 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 280 times.
✗ Branch 5 not taken.
|
280 | GTEST_SKIP() << "Kernel does not support M != 1"; |
81 | } | ||
82 | |||
83 | 616 | auto m_step = ukernel_variant.interface.get_m_step(); | |
84 |
3/14✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 616 times.
|
616 | ASSERT_TRUE(m_step % mr == 0); |
85 | |||
86 | 616 | auto n_step = ukernel_variant.interface.get_n_step(); | |
87 |
3/14✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 616 times.
|
616 | ASSERT_TRUE(n_step % nr == 0); |
88 | |||
89 | 1232 | const auto rect = portion.compute_portion(M, N, m_step, n_step); | |
90 |
4/4✓ Branch 0 taken 568 times.
✓ Branch 1 taken 48 times.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 560 times.
|
616 | if (rect.height() == 0 || rect.width() == 0) { |
91 |
9/18✓ Branch 0 taken 56 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 56 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 56 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 56 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 56 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 56 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 56 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 56 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 56 times.
✗ Branch 17 not taken.
|
56 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
92 | } | ||
93 | |||
94 | // Generates input data. | ||
95 | 560 | const auto ref_lhs_f16 = fill_random<Float16>(M * K, seed + 0); | |
96 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | const auto ref_rhs = fill_random<float>(N * K, seed + 1); |
97 | 560 | Buffer ref_biases; | |
98 | |||
99 |
2/2✓ Branch 0 taken 280 times.
✓ Branch 1 taken 280 times.
|
560 | if (has_bias) { |
100 |
1/2✓ Branch 0 taken 280 times.
✗ Branch 1 not taken.
|
280 | ref_biases = fill_random<float>(N, seed + 2); |
101 | 280 | } | |
102 | // For the reference implementation, cast the FP16 input to FP32 and the FP32 output back to FP16, because the | ||
103 | // matmul implementation accumulates in FP32 and casts the result to FP16. | ||
104 |
3/6✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
|
560 | const auto ref_lhs = cast<float, Float16>(ref_lhs_f16.data(), ref_lhs_f16.size() * 8 / size_in_bits<Float16>); |
105 | |||
106 | // Runs the reference implementation. | ||
107 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
108 | // * Quantizes the RHS matrix using 8-bit symmetric quantization. | ||
109 | // * Performs GEMM. | ||
110 | 1680 | const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
111 |
2/4✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
|
560 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(ref_lhs.data(), M, K, K); |
112 | 2800 | const auto [ref_rhs_qsi8, ref_rhs_scales] = | |
113 |
2/4✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
|
560 | quantize_symmetric_per_block_dynamic<float, int8_t, float>(ref_rhs.data(), N, K, K); |
114 | |||
115 | 560 | const auto ref_dst_no_clamp = | |
116 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, float, float, int32_t, float>( |
117 |
4/8✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 560 times.
✗ Branch 7 not taken.
|
1120 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), 1, K, |
118 |
7/10✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 280 times.
✓ Branch 5 taken 280 times.
✓ Branch 6 taken 280 times.
✓ Branch 7 taken 280 times.
✓ Branch 8 taken 280 times.
✗ Branch 9 not taken.
|
560 | ref_rhs_qsi8.data(), ref_rhs_scales.data(), nullptr, 1, K, has_bias ? ref_biases.data() : nullptr, nullptr, |
119 | nullptr, 1); | ||
120 | |||
121 | // Clamps the reference output. | ||
122 | 560 | const auto clamp_ratio = 0.8F; | |
123 |
2/4✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
|
1680 | const auto [clamp_min, clamp_max] = find_clamp_range<float>(ref_dst_no_clamp.data(), M * N, clamp_ratio); |
124 |
4/8✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 560 times.
✗ Branch 7 not taken.
|
560 | const auto ref_dst_float = clamp<float>(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max); |
125 | |||
126 | // Casts the reference output to FP16. | ||
127 |
3/6✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
|
560 | auto ref_dst = cast<Float16, float>(ref_dst_float.data(), ref_dst_float.size() * 8 / size_in_bits<float>); |
128 | |||
129 | // Runs the LHS packing micro-kernel. | ||
130 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | const auto lhs_start_row = rect.start_row(); |
131 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f16_neon(M, K, mr, kr, sr); |
132 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | Buffer imp_packed_lhs(imp_packed_lhs_size); |
133 | |||
134 | 560 | auto lhs_stride = K * sizeof(uint16_t); | |
135 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f16_neon(lhs_start_row, lhs_stride); |
136 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f16_neon(lhs_start_row, K, mr, kr, sr); |
137 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K); |
138 | |||
139 |
4/16✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 560 times.
✗ Branch 15 not taken.
|
560 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
140 | |||
141 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | kai_run_lhs_quant_pack_qai8dxp_f16_neon( |
142 |
2/4✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
|
560 | rect.height() /* m */, K, mr, kr, sr, 0, ref_lhs_f16.data() + lhs_offset, lhs_stride, |
143 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | imp_packed_lhs.data() + lhs_packed_offset); |
144 | |||
145 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); |
146 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
147 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | const auto rhs_start_row = rect.start_col(); |
148 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | auto rhs_packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(rhs_start_row, K, nr, kr, sr); |
149 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); |
150 |
4/16✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 560 times.
|
560 | ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); |
151 | |||
152 | // Runs the RHS packing micro-kernel. | ||
153 | 560 | const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f}; | |
154 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon( |
155 |
2/4✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
|
1120 | 1, N, K, nr, kr, sr, reinterpret_cast<const int8_t*>(ref_rhs_qsi8.data()), |
156 |
3/4✓ Branch 0 taken 280 times.
✓ Branch 1 taken 280 times.
✓ Branch 2 taken 280 times.
✗ Branch 3 not taken.
|
560 | has_bias ? reinterpret_cast<const float*>(ref_biases.data()) : nullptr, |
157 |
2/4✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
|
560 | reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params); |
158 | |||
159 | 560 | const auto dst_stride_row = N * sizeof(uint16_t); | |
160 | 560 | const auto dst_stride_col = sizeof(uint16_t); | |
161 | 1120 | const auto dst_offset = | |
162 |
3/6✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
|
560 | ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row); |
163 |
2/4✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
|
560 | const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col; |
164 |
4/16✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 560 times.
|
560 | ASSERT_EQ(dst_offset, ref_dst_offset); |
165 | |||
166 | // Runs the GEMM micro-kernel. | ||
167 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
168 |
5/18✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 560 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 560 times.
|
560 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
169 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | Buffer imp_dst(imp_dst_size); |
170 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
1120 | ukernel_variant.interface.run_matmul( |
171 |
3/6✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
|
560 | rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, |
172 |
2/4✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
|
560 | imp_packed_rhs.data() + rhs_matmul_offset, imp_dst.data() + dst_offset, dst_stride_row, dst_stride_col, |
173 | 1120 | clamp_min, clamp_max); | |
174 | |||
175 | // Compares the output of the micro-kernels against the output of the reference implementation for the portion | ||
176 | // tested. | ||
177 |
1/2✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
|
560 | DefaultMismatchHandler handler(0, 0.02, 0, 0.05); |
178 | 560 | DataFormat dst_format = DataFormat(DataType::FP16); | |
179 |
3/6✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
|
560 | const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); |
180 |
4/16✓ Branch 0 taken 560 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 560 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 560 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 560 times.
|
560 | ASSERT_TRUE(success); |
181 | 896 | } | |
182 |
19/62✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 time.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 time.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 time.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 time.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 time.
✗ Branch 33 not taken.
✓ Branch 34 taken 896 times.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✓ Branch 60 taken 896 times.
✗ Branch 61 not taken.
|
1794 | INSTANTIATE_TEST_SUITE_P( |
183 | MatMul, MatMulTest_f16_qai8dxp_qsi8cxp, | ||
184 | testing::Combine( | ||
185 | testing::Range<size_t>(0, variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp.size()), | ||
186 | testing::Values( | ||
187 | MatMulShape{1, 2, 32}, // | ||
188 | MatMulShape{1, 3, 32}, // | ||
189 | MatMulShape{1, 4, 32}, // | ||
190 | MatMulShape{1, 5, 31}, // | ||
191 | MatMulShape{3, 3, 32}, // | ||
192 | MatMulShape{4, 4, 32}, // | ||
193 | MatMulShape{5, 5, 31}, // | ||
194 | MatMulShape{16, 32, 64}, // | ||
195 | MatMulShape{16, 32, 36}, // | ||
196 | MatMulShape{15, 35, 65}, // | ||
197 | MatMulShape{8, 32, 64}, // | ||
198 | MatMulShape{15, 31, 45}, // | ||
199 | MatMulShape{1, 35, 65}, // | ||
200 | MatMulShape{1, 128, 32}, // | ||
201 | MatMulShape{64, 128, 32}, // | ||
202 | MatMulShape{77, 99, 64}), | ||
203 | testing::Values( | ||
204 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
205 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
206 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
207 | MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle | ||
208 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
209 | MatrixPortion(0.75, 0, 1, 1), // Partial rows | ||
210 | MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle | ||
211 | ), | ||
212 | testing::Bool()), | ||
213 | [](const auto& info) { | ||
214 | const auto variant_idx = std::get<0>(info.param); | ||
215 | const std::string name{variants_kai_matmul_clamp_f16_qai8dxp_qsi8cxp.at(variant_idx).name}; | ||
216 | const auto shape = std::get<MatMulShape>(info.param); | ||
217 | const auto portion = std::get<2>(info.param); | ||
218 | const auto has_bias = std::get<3>(info.param); | ||
219 | |||
220 | return test_description(name, shape, portion, has_bias); | ||
221 | }); | ||
222 | |||
223 | } // namespace kai::test | ||
224 |