Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include <gtest/gtest.h> | ||
8 | |||
9 | #include <array> | ||
10 | #include <cstddef> | ||
11 | #include <cstdint> | ||
12 | #include <cstdlib> | ||
13 | #include <limits> | ||
14 | #include <string> | ||
15 | #include <tuple> | ||
16 | |||
17 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" | ||
18 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h" | ||
19 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" | ||
20 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot.h" | ||
21 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" | ||
22 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" | ||
23 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h" | ||
24 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" | ||
25 | #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h" | ||
26 | #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" | ||
27 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" | ||
28 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h" | ||
29 | #include "test/common/buffer.hpp" | ||
30 | #include "test/common/cache.hpp" | ||
31 | #include "test/common/cpu_info.hpp" | ||
32 | #include "test/common/matmul_test_common.hpp" | ||
33 | #include "test/common/matrix_portion.hpp" | ||
34 | #include "test/common/memory.hpp" | ||
35 | #include "test/common/printer.hpp" | ||
36 | #include "test/common/test_suite.hpp" | ||
37 | #include "test/reference/fill.hpp" | ||
38 | #include "test/reference/matmul.hpp" | ||
39 | #include "test/reference/quantize.hpp" | ||
40 | #include "test/reference/transpose.hpp" | ||
41 | |||
42 | namespace kai::test { | ||
43 | using CacheDataId = std::tuple<MatMulShape, DataFormat, DataFormat, DataFormat>; | ||
44 | |||
45 | struct CacheData { | ||
46 | Buffer lhs; | ||
47 | Buffer rhs; | ||
48 | Buffer bias; | ||
49 | }; | ||
50 | |||
51 | template <> | ||
52 | 11 | CacheData ReferenceGenerator<CacheDataId, CacheData>::generate_reference(const CacheDataId& k) { | |
53 | 11 | MatMulShape shape = std::get<0>(k); | |
54 | 11 | DataFormat lhs_format = std::get<1>(k); | |
55 | 11 | DataFormat rhs_format = std::get<2>(k); | |
56 | 11 | DataFormat bias_format = std::get<3>(k); | |
57 | |||
58 | static size_t seed = 1; | ||
59 | 11 | Buffer lhs = fill_matrix_random(shape.m, shape.k, lhs_format, seed++); | |
60 |
1/2✓ Branch 0 taken 11 times.
✗ Branch 1 not taken.
|
11 | Buffer rhs = fill_matrix_random(shape.k, shape.n, rhs_format, seed++); |
61 |
1/2✓ Branch 0 taken 11 times.
✗ Branch 1 not taken.
|
11 | Buffer bias = fill_matrix_random(1, shape.n, bias_format, seed++); |
62 | |||
63 | 11 | CacheData test_reference; | |
64 | 11 | test_reference.lhs = std::move(lhs); | |
65 | 11 | test_reference.rhs = std::move(rhs); | |
66 | 11 | test_reference.bias = std::move(bias); | |
67 | |||
68 | 11 | return test_reference; | |
69 | 11 | } | |
70 | |||
71 | 1 | static const std::array<UkernelVariant<kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel>, 8> | |
72 |
0/2✗ Branch 0 not taken.
✗ Branch 1 not taken.
|
9 | variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp = {{ |
73 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod), |
74 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod", cpu_has_dotprod}, |
75 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod), |
76 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod", cpu_has_dotprod}, |
77 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod), |
78 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod", cpu_has_dotprod}, |
79 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm), |
80 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm", cpu_has_i8mm}, |
81 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot), |
82 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme_dot", cpu_has_sme}, |
83 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa), |
84 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa", cpu_has_sme}, |
85 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot), |
86 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot", cpu_has_sme2}, |
87 |
1/2✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
|
2 | {UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa), |
88 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
|
1 | "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa", cpu_has_sme2}, |
89 | }}; | ||
90 | |||
91 | class MatMulTest_f32_qai8dxp_qsi8cxp : public ::testing::TestWithParam<MatMulTestPortionedParams> {}; | ||
92 | |||
93 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1850 | TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, Offset_RHS) { |
94 | 1848 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
95 | 1232 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index); | |
96 | |||
97 |
2/4✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
|
616 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
98 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
99 | } | ||
100 | |||
101 | 1232 | const size_t K = matmul_shape.k; | |
102 | 616 | const auto nr = ukernel_variant.interface.get_nr(); | |
103 | 616 | const auto kr = ukernel_variant.interface.get_kr(); | |
104 | 616 | const auto sr = ukernel_variant.interface.get_sr(); | |
105 | |||
106 | 616 | auto n_step = ukernel_variant.interface.get_n_step(); | |
107 | |||
108 | 616 | auto rhs_packed_offset_kxn = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(n_step, K, nr, kr, sr); | |
109 | 616 | auto rhs_packed_offset_nxk = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(n_step, K, nr, kr, sr); | |
110 | |||
111 |
3/14✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 616 times.
|
616 | ASSERT_EQ(rhs_packed_offset_kxn, rhs_packed_offset_nxk); |
112 | |||
113 | 616 | auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(n_step, K); | |
114 |
3/14✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 616 times.
|
616 | ASSERT_EQ(rhs_packed_offset_kxn, rhs_matmul_offset); |
115 | 616 | } | |
116 | |||
117 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1850 | TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, Offset_LHS) { |
118 | 1848 | const auto& [variant_index, matmul_shape, portion] = GetParam(); | |
119 | 1232 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index); | |
120 | |||
121 |
2/4✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
|
616 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
122 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
123 | } | ||
124 | |||
125 | 1232 | const size_t K = matmul_shape.k; | |
126 | 616 | const auto mr = ukernel_variant.interface.get_mr(); | |
127 | 616 | const auto kr = ukernel_variant.interface.get_kr(); | |
128 | 616 | const auto sr = ukernel_variant.interface.get_sr(); | |
129 | |||
130 | 616 | auto m_step = ukernel_variant.interface.get_m_step(); | |
131 | |||
132 | 616 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(m_step, K, mr, kr, sr); | |
133 | 616 | auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(m_step, K); | |
134 | |||
135 |
3/14✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 616 times.
|
616 | ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset); |
136 | 616 | } | |
137 | |||
138 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1850 | TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_nxk_qsi8cx) { |
139 | 4312 | auto& [variant_index, matmul_shape, portion] = GetParam(); | |
140 | 1232 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index); | |
141 | |||
142 |
2/4✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
|
616 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
143 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
144 | } | ||
145 | |||
146 | 1232 | const size_t M = matmul_shape.m; | |
147 | 1232 | const size_t N = matmul_shape.n; | |
148 | 1232 | const size_t K = matmul_shape.k; | |
149 | |||
150 | 616 | const auto mr = ukernel_variant.interface.get_mr(); | |
151 | 616 | const auto nr = ukernel_variant.interface.get_nr(); | |
152 | 616 | const auto kr = ukernel_variant.interface.get_kr(); | |
153 | 616 | const auto sr = ukernel_variant.interface.get_sr(); | |
154 | |||
155 | // Generates input data. | ||
156 | 1232 | const CacheDataId id = { | |
157 | matmul_shape, // | ||
158 | 616 | DataFormat(DataType::FP32), // | |
159 | 616 | DataFormat(DataType::FP32), // | |
160 | 616 | DataFormat(DataType::FP32)}; | |
161 | 616 | const CacheData& test_data = getV<CacheDataId, CacheData>(id); | |
162 | |||
163 | // Runs the reference implementation. | ||
164 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
165 | // * Quantizes the RHS matrix using 8-bit symmetric quantization. | ||
166 | // * Performs GEMM. | ||
167 | 1848 | const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
168 | 616 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(test_data.lhs.data(), M, K, K); | |
169 | 3000 | const auto [ref_rhs_qsi8, ref_rhs_scales] = | |
170 |
2/4✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
|
616 | quantize_symmetric_per_block_dynamic<float, int8_t, float>(test_data.rhs.data(), N, K, K); |
171 | |||
172 |
1/2✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
|
1232 | const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>( |
173 |
6/12✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 616 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 616 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 616 times.
✗ Branch 11 not taken.
|
1232 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi8.data(), |
174 |
2/4✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
|
616 | ref_rhs_scales.data(), nullptr, K, test_data.bias.data(), std::numeric_limits<float>::lowest(), |
175 | 616 | std::numeric_limits<float>::max()); | |
176 | |||
177 |
1/2✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
|
616 | auto m_step = ukernel_variant.interface.get_m_step(); |
178 |
4/16✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 616 times.
|
616 | ASSERT_TRUE(m_step % mr == 0); |
179 | |||
180 |
1/2✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
|
616 | auto n_step = ukernel_variant.interface.get_n_step(); |
181 |
4/16✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 616 times.
|
616 | ASSERT_TRUE(n_step % nr == 0); |
182 | |||
183 |
2/4✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
|
1232 | const auto rect = portion.compute_portion(M, N, m_step, n_step); |
184 |
5/8✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✓ Branch 3 taken 40 times.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 576 times.
|
616 | if (rect.height() == 0 || rect.width() == 0) { |
185 |
10/20✓ Branch 0 taken 40 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 40 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 40 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 40 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 40 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 40 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 40 times.
✗ Branch 19 not taken.
|
40 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
186 | } | ||
187 | |||
188 | // Runs the LHS packing micro-kernel. | ||
189 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); |
190 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | Buffer imp_packed_lhs(imp_packed_lhs_size); |
191 | |||
192 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | const auto lhs_start_row = rect.start_row(); |
193 | 576 | size_t lhs_stride = K * sizeof(float); | |
194 | |||
195 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); |
196 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); |
197 | |||
198 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | kai_run_lhs_quant_pack_qai8dxp_f32( |
199 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | rect.height(), K, mr, kr, sr, 0, reinterpret_cast<const float*>(test_data.lhs.data() + lhs_offset), lhs_stride, |
200 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | imp_packed_lhs.data() + lhs_packed_offset); |
201 | |||
202 | // Runs the RHS packing micro-kernel. | ||
203 | // * Generates the 8-bit signed symmetric quantized input for the micro-kernel. | ||
204 | // * Packs the RHS matrix. | ||
205 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); |
206 | |||
207 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
208 | 576 | const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f}; | |
209 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon( |
210 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
1152 | 1, N, K, nr, kr, sr, reinterpret_cast<const int8_t*>(ref_rhs_qsi8.data()), |
211 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | reinterpret_cast<const float*>(test_data.bias.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()), |
212 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | imp_packed_rhs.data(), 0, &params); |
213 | |||
214 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | const auto packed_rhs_start_row = rect.start_col(); |
215 | 576 | auto rhs_packed_offset = | |
216 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | kai_get_rhs_packed_offset_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(packed_rhs_start_row, K, nr, kr, sr); |
217 | |||
218 | 576 | const auto dst_stride = N * sizeof(float); | |
219 |
3/6✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
|
576 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
220 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); |
221 |
4/16✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
|
576 | ASSERT_EQ(dst_offset, ref_dst_offset); |
222 | |||
223 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | const auto matmul_lhs_packed_offset = ukernel_variant.interface.get_lhs_packed_offset(rect.start_row(), K); |
224 |
4/16✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
|
576 | ASSERT_EQ(lhs_packed_offset, matmul_lhs_packed_offset); |
225 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | const auto matmul_rhs_packed_offset = ukernel_variant.interface.get_rhs_packed_offset(rect.start_col(), K); |
226 |
4/16✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
|
576 | ASSERT_EQ(rhs_packed_offset, matmul_rhs_packed_offset); |
227 | |||
228 | // Runs the GEMM micro-kernel. | ||
229 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
230 |
5/18✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 576 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 576 times.
|
576 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
231 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | Buffer imp_dst(imp_dst_size); |
232 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
1152 | ukernel_variant.interface.run_matmul( |
233 |
3/6✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
|
576 | rect.height(), rect.width(), K, imp_packed_lhs.data() + matmul_lhs_packed_offset, |
234 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | imp_packed_rhs.data() + matmul_rhs_packed_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), |
235 | 576 | N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | |
236 | |||
237 | // Compares the output of the micro-kernels against the output of the reference implementation. | ||
238 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 5812 times.
✓ Branch 2 taken 5236 times.
✓ Branch 3 taken 576 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 576 times.
|
5812 | for (size_t y = 0; y < rect.height(); ++y) { |
239 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 248408 times.
✓ Branch 2 taken 243172 times.
✓ Branch 3 taken 5236 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5236 times.
|
248408 | for (size_t x = 0; x < rect.width(); ++x) { |
240 | 486344 | const auto imp_value = | |
241 |
4/8✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 243172 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 243172 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 243172 times.
✗ Branch 7 not taken.
|
243172 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
242 | 486344 | const auto ref_value = | |
243 |
4/8✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 243172 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 243172 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 243172 times.
✗ Branch 7 not taken.
|
243172 | read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
244 |
1/2✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
|
243172 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); |
245 | |||
246 |
1/2✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
|
243172 | if (rel_error > 0.0001F) { |
247 | ✗ | ASSERT_EQ(imp_value, ref_value); | |
248 | ✗ | } | |
249 | 243172 | } | |
250 | 5236 | } | |
251 | 616 | } | |
252 | |||
253 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
1850 | TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_kxn_qsi8cx) { |
254 | 4312 | auto& [variant_index, matmul_shape, portion] = GetParam(); | |
255 | 1232 | const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_index); | |
256 | |||
257 |
2/4✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
|
616 | if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { |
258 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
259 | } | ||
260 | |||
261 | 1232 | const size_t M = matmul_shape.m; | |
262 | 1232 | const size_t N = matmul_shape.n; | |
263 | 1232 | const size_t K = matmul_shape.k; | |
264 | |||
265 | 616 | const auto mr = ukernel_variant.interface.get_mr(); | |
266 | 616 | const auto nr = ukernel_variant.interface.get_nr(); | |
267 | 616 | const auto kr = ukernel_variant.interface.get_kr(); | |
268 | 616 | const auto sr = ukernel_variant.interface.get_sr(); | |
269 | |||
270 | // Generates input data. | ||
271 | 1232 | const CacheDataId id = { | |
272 | matmul_shape, // | ||
273 | 616 | DataFormat(DataType::FP32), // | |
274 | 616 | DataFormat(DataType::FP32), // | |
275 | 616 | DataFormat(DataType::FP32)}; | |
276 | 616 | const CacheData& test_data = getV<CacheDataId, CacheData>(id); | |
277 | |||
278 | // Transposed(nxk) RHS dimensions | ||
279 | 616 | const size_t ref_rhs_qsi8_nxk_stride = K; | |
280 | |||
281 | // Non-Transposed(kxn) RHS dimensions | ||
282 | 616 | const size_t ref_rhs_qsi8_kxn_stride = N; | |
283 | 616 | const size_t ref_rhs_qsi8_kxn_size_bytes = K * ref_rhs_qsi8_kxn_stride; | |
284 | |||
285 | // Runs the reference implementation. | ||
286 | // * Quantizes the LHS matrix using 8-bit asymmetric quantization. | ||
287 | // * Quantizes the RHS matrix using 8-bit symmetric quantization. | ||
288 | // * Performs GEMM. | ||
289 | 1848 | const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = | |
290 | 616 | quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(test_data.lhs.data(), M, K, K); | |
291 | 3040 | const auto [ref_rhs_qsi8_transposed, ref_rhs_scales] = | |
292 |
2/4✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
|
616 | quantize_symmetric_per_block_dynamic<float, int8_t, float>(test_data.rhs.data(), N, K, K); |
293 | |||
294 |
1/2✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
|
616 | const auto ref_rhs_qsi8 = transpose_with_padding<int8_t>( |
295 |
1/2✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
|
616 | ref_rhs_qsi8_transposed.data(), N, K, ref_rhs_qsi8_nxk_stride, ref_rhs_qsi8_kxn_stride, |
296 | 616 | ref_rhs_qsi8_kxn_size_bytes); | |
297 | |||
298 |
1/2✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
|
1232 | const auto ref_dst = matmul_clamp_nt_nt<int8_t, float, int32_t, int8_t, float, int32_t, float, int32_t, float>( |
299 |
5/10✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 616 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 616 times.
✗ Branch 9 not taken.
|
1232 | M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi8.data(), |
300 |
2/4✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
|
616 | ref_rhs_scales.data(), nullptr, K, test_data.bias.data(), std::numeric_limits<float>::lowest(), |
301 | 616 | std::numeric_limits<float>::max()); | |
302 | |||
303 |
1/2✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
|
616 | auto m_step = ukernel_variant.interface.get_m_step(); |
304 |
4/16✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 616 times.
|
616 | ASSERT_TRUE(m_step % mr == 0); |
305 | |||
306 |
1/2✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
|
616 | auto n_step = ukernel_variant.interface.get_n_step(); |
307 |
4/16✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 616 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 616 times.
|
616 | ASSERT_TRUE(n_step % nr == 0); |
308 | |||
309 |
2/4✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 616 times.
✗ Branch 3 not taken.
|
1232 | const auto rect = portion.compute_portion(M, N, m_step, n_step); |
310 |
5/8✓ Branch 0 taken 616 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✓ Branch 3 taken 40 times.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 576 times.
|
616 | if (rect.height() == 0 || rect.width() == 0) { |
311 |
10/20✓ Branch 0 taken 40 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 40 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 40 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 40 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 40 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 40 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 40 times.
✗ Branch 19 not taken.
|
40 | GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; |
312 | } | ||
313 | |||
314 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | const auto lhs_start_row = rect.start_row(); |
315 | 576 | size_t const lhs_stride = K * sizeof(float); | |
316 | |||
317 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride); |
318 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr); |
319 | |||
320 | // Runs the LHS packing micro-kernel. | ||
321 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); |
322 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | Buffer imp_packed_lhs(imp_packed_lhs_size); |
323 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | kai_run_lhs_quant_pack_qai8dxp_f32( |
324 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | rect.height(), K, mr, kr, sr, 0, reinterpret_cast<const float*>(test_data.lhs.data() + lhs_offset), |
325 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | K * sizeof(float), imp_packed_lhs.data() + lhs_packed_offset); |
326 | |||
327 | // Runs the RHS packing micro-kernel. | ||
328 | // * Generates the 8-bit signed symmetric quantized input for the micro-kernel. | ||
329 | // * Packs the RHS matrix. | ||
330 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); |
331 | |||
332 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | Buffer imp_packed_rhs(imp_packed_rhs_size); |
333 | 576 | const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f}; | |
334 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon( |
335 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | 1, N, K, nr, kr, sr, reinterpret_cast<const int8_t*>(ref_rhs_qsi8.data()), |
336 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | reinterpret_cast<const float*>(test_data.bias.data()), reinterpret_cast<const float*>(ref_rhs_scales.data()), |
337 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | imp_packed_rhs.data(), 0, ¶ms); |
338 | |||
339 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | const auto packed_rhs_start_row = rect.start_col(); |
340 | 576 | auto rhs_packed_offset = | |
341 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(packed_rhs_start_row, K, nr, kr, sr); |
342 | |||
343 | 576 | const auto dst_stride = N * sizeof(float); | |
344 |
3/6✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
|
576 | const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride); |
345 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float); |
346 |
4/16✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
|
576 | ASSERT_EQ(dst_offset, ref_dst_offset); |
347 | |||
348 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | const auto matmul_lhs_packed_offset = ukernel_variant.interface.get_lhs_packed_offset(rect.start_row(), K); |
349 |
4/16✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
|
576 | ASSERT_EQ(lhs_packed_offset, matmul_lhs_packed_offset); |
350 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | const auto matmul_rhs_packed_offset = ukernel_variant.interface.get_rhs_packed_offset(rect.start_col(), K); |
351 |
4/16✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 576 times.
|
576 | ASSERT_EQ(rhs_packed_offset, matmul_rhs_packed_offset); |
352 | |||
353 | // Runs the GEMM micro-kernel. | ||
354 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); |
355 |
5/18✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 576 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 576 times.
|
576 | ASSERT_EQ(imp_dst_size, ref_dst.size()); |
356 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
576 | Buffer imp_dst(imp_dst_size); |
357 |
1/2✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
|
1152 | ukernel_variant.interface.run_matmul( |
358 |
3/6✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 576 times.
✗ Branch 5 not taken.
|
576 | rect.height(), rect.width(), K, imp_packed_lhs.data() + matmul_lhs_packed_offset, |
359 |
2/4✓ Branch 0 taken 576 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 576 times.
✗ Branch 3 not taken.
|
576 | imp_packed_rhs.data() + matmul_rhs_packed_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset), |
360 | 576 | N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max()); | |
361 | |||
362 | // Compares the output of the micro-kernels against the output of the reference implementation. | ||
363 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 5812 times.
✓ Branch 2 taken 5236 times.
✓ Branch 3 taken 576 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 576 times.
|
5812 | for (size_t y = 0; y < rect.height(); ++y) { |
364 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 248408 times.
✓ Branch 2 taken 243172 times.
✓ Branch 3 taken 5236 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5236 times.
|
248408 | for (size_t x = 0; x < rect.width(); ++x) { |
365 | 486344 | const auto imp_value = | |
366 |
4/8✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 243172 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 243172 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 243172 times.
✗ Branch 7 not taken.
|
243172 | read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
367 | 486344 | const auto ref_value = | |
368 |
4/8✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 243172 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 243172 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 243172 times.
✗ Branch 7 not taken.
|
243172 | read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); |
369 |
1/2✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
|
243172 | const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); |
370 | |||
371 |
1/2✓ Branch 0 taken 243172 times.
✗ Branch 1 not taken.
|
243172 | if (rel_error > 0.0001F) { |
372 | ✗ | ASSERT_EQ(imp_value, ref_value); | |
373 | ✗ | } | |
374 | 243172 | } | |
375 | 5236 | } | |
376 | 616 | } | |
377 | |||
378 |
18/58✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 4 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 4 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 4 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 2464 times.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✓ Branch 56 taken 2464 times.
✗ Branch 57 not taken.
|
4933 | INSTANTIATE_TEST_SUITE_P( |
379 | MatMul, MatMulTest_f32_qai8dxp_qsi8cxp, | ||
380 | testing::Combine( | ||
381 | testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.size()), | ||
382 | testing::Values( | ||
383 | MatMulShape{17, 33, 67}, // | ||
384 | MatMulShape{19, 35, 63}, // | ||
385 | MatMulShape{1, 27, 31}, // | ||
386 | MatMulShape{1, 65, 35}, // | ||
387 | MatMulShape{1, 64, 65}, // | ||
388 | MatMulShape{1, 63, 15}, // | ||
389 | MatMulShape{1, 130, 15}, // | ||
390 | MatMulShape{15, 65, 35}, // | ||
391 | MatMulShape{16, 64, 65}, // | ||
392 | MatMulShape{17, 63, 15}, // | ||
393 | MatMulShape{20, 130, 15}), | ||
394 | testing::Values( | ||
395 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
396 | MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. | ||
397 | MatrixPortion(0, 0.75, 1, 1), // Rightmost portion. | ||
398 | MatrixPortion(0, 0.5, 1, 0.8), // Somewhere Middle | ||
399 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
400 | MatrixPortion(0.75, 0, 1, 1), // Partial rows | ||
401 | MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle | ||
402 | )), | ||
403 | [](const auto& info) { | ||
404 | const auto variant_idx = std::get<0>(info.param); | ||
405 | const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi8cxp.at(variant_idx).name}; | ||
406 | const auto shape = std::get<MatMulShape>(info.param); | ||
407 | const auto portion = std::get<MatrixPortion>(info.param); | ||
408 | |||
409 | return test_description(name, shape, portion, true); | ||
410 | }); | ||
411 | |||
412 | } // namespace kai::test | ||
413 |