Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
#include "test/reference/dwconv.hpp"

#include <gtest/gtest.h>

#include <array>
#include <cmath>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>

#include "kai/ukernels/dwconv/dwconv_f32_f32_f32p/kai_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla.h"
#include "kai/ukernels/dwconv/dwconv_f32_f32_f32p/kai_dwconv_clamp_f32_f32_f32p_interface.h"
#include "kai/ukernels/dwconv/pack/kai_rhs_dwconv_pack_x32p1vlx1b_x32_x32_sme.h"
#include "test/common/buffer.hpp"
#include "test/common/compare.hpp"
#include "test/common/cpu_info.hpp"
#include "test/common/matmul_test_common.hpp"
#include "test/common/matrix_portion.hpp"
#include "test/reference/clamp.hpp"
#include "test/reference/fill.hpp"
28 | |||
29 | namespace kai::test { | ||
30 | |||
31 | namespace { | ||
32 | |||
/// Function-table view of a planar depthwise-convolution micro-kernel.
///
/// Each member wraps one entry point so that the test body can drive any kernel
/// through the same interface.
struct DepthwisePlanarKernel {
    /// Destination-size query.
    using DstSizeFn = size_t(size_t m, size_t n, size_t k);
    /// Destination-offset query.
    using DstOffsetFn = size_t(size_t m, size_t n);
    /// Number of output rows processed per call of `conv`.
    using MStepFn = size_t();
    /// Convolution entry point.
    using ConvFn = void(
        const void* inptr, const void* packed_rhs, void* outptr_start, size_t stride_in_row, size_t stride_in_col,
        size_t dst_stride_row, size_t dst_stride_col, unsigned int valid_input_rows, unsigned int valid_out_rows,
        unsigned int pad_left, unsigned int pad_top, float pad_value, float clamp_min, float clamp_max);

    std::function<DstSizeFn> get_dst_size;
    std::function<DstOffsetFn> get_dst_offset;
    std::function<MStepFn> get_m_step;
    std::function<ConvFn> conv;
};
44 | |||
/// Function-table view of an RHS (weights + bias) packing micro-kernel for depthwise convolution.
struct RhsPackDepthwiseKernel {
    /// Packed-size query (filter height, filter width, number of channels).
    using PackedSizeFn = size_t(size_t fh, size_t fw, size_t nc);
    /// Packing entry point.
    using PackFn = void(
        size_t filter_height, size_t filter_width, size_t height, size_t width, size_t num_channels, const void* rhs,
        const void* bias, void* rhs_packed);

    std::function<PackedSizeFn> get_rhs_packed_size;
    std::function<PackFn> pack;
};
53 | |||
54 | /// Description of a Depthwise kernel set | ||
55 |
1/2✓ Branch 0 taken 302 times.
✗ Branch 1 not taken.
|
302 | struct Depthwise { |
56 | std::string_view name; | ||
57 | std::function<bool(void)> is_supported; | ||
58 | std::pair<unsigned int, unsigned int> filter; | ||
59 | DataType data_type; | ||
60 | DataType acc_type; | ||
61 | RhsPackDepthwiseKernel rhs; | ||
62 | DepthwisePlanarKernel depthwise; | ||
63 | }; | ||
64 | |||
65 | /// Convenience types for testing. | ||
66 | using DepthwiseArray = std::array<Depthwise, 1>; | ||
67 | using DepthwiseParamsParams = std::tuple<Depthwise, MatMulShape, Padding2D, float>; | ||
68 | using DepthwisePlanarTest = testing::TestWithParam<DepthwiseParamsParams>; | ||
69 | |||
70 | /// Use interface for depthwise kernel | ||
71 | 1 | const kai_dwconv_clamp_f32_f32_f32p_planar_ukernel& get_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla() { | |
72 | static kai_dwconv_clamp_f32_f32_f32p_planar_ukernel ukernel; | ||
73 | 1 | ukernel.get_m_step = kai_get_m_step_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla; | |
74 | 1 | ukernel.get_dst_offset = kai_get_dst_offset_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla; | |
75 | 1 | ukernel.get_dst_size = kai_get_dst_size_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla; | |
76 | 1 | ukernel.run_dwconv = kai_run_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla; | |
77 | 1 | return ukernel; | |
78 | } | ||
79 | |||
80 | 1 | const DepthwiseArray& get_depthwise_methods() { | |
81 | // FP32 kernels with 3x3 filter. | ||
82 |
2/4✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
1 | static DepthwiseArray depthwise_methods{}; |
83 | 1 | depthwise_methods[0].name = "dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla"; | |
84 | 1 | depthwise_methods[0].rhs.get_rhs_packed_size = kai_rhs_get_dst_size_dwconv_pack_x32p1vlx1b_x32_x32_sme; | |
85 | 1 | depthwise_methods[0].rhs.pack = kai_run_rhs_dwconv_pack_x32p1vlx1b_x32_x32_sme; | |
86 | 1 | depthwise_methods[0].is_supported = cpu_has_sme2; | |
87 | 1 | depthwise_methods[0].filter = {3, 3}; | |
88 | |||
89 | 2 | const kai_dwconv_clamp_f32_f32_f32p_planar_ukernel& ukernel_f32 = | |
90 | 1 | get_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla(); | |
91 | 1 | depthwise_methods[0].data_type = DataType::FP32; | |
92 | 1 | depthwise_methods[0].acc_type = DataType::FP32; | |
93 | 1 | depthwise_methods[0].depthwise.get_m_step = ukernel_f32.get_m_step; | |
94 | 1 | depthwise_methods[0].depthwise.get_dst_size = ukernel_f32.get_dst_size; | |
95 | 1 | depthwise_methods[0].depthwise.get_dst_offset = ukernel_f32.get_dst_offset; | |
96 | 1 | depthwise_methods[0].depthwise.conv = ukernel_f32.run_dwconv; | |
97 | 1 | return depthwise_methods; | |
98 | 1 | } | |
99 | |||
100 | /// Test reference identification. | ||
101 | struct TestDataId { | ||
102 | using DT = std::underlying_type_t<DataType>; | ||
103 | MatMulShape in_shape; | ||
104 | MatMulShape rhs_shape; | ||
105 | Padding2D pad; | ||
106 | DataType dt; | ||
107 | DataType dt_acc; | ||
108 | float clamp_rate; | ||
109 | |||
110 | struct Hash { | ||
111 | 120 | size_t operator()(const TestDataId& test_id) const { | |
112 | 120 | return // | |
113 | 240 | (MatMulShape::Hash{}(test_id.in_shape) << 0) ^ // | |
114 | 240 | (MatMulShape::Hash{}(test_id.rhs_shape) << 1) ^ // | |
115 | 240 | (Padding2D::Hash{}(test_id.pad) << 2) ^ // | |
116 | 240 | (std::hash<DT>{}(static_cast<DT>(test_id.dt)) << 3) ^ // | |
117 | 240 | (std::hash<DT>{}(static_cast<DT>(test_id.dt_acc)) << 4) ^ // | |
118 | 120 | (std::hash<float>{}(test_id.clamp_rate) << 5); // | |
119 | } | ||
120 | }; | ||
121 | |||
122 | private: | ||
123 | ✗ | friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { | |
124 | ✗ | return // | |
125 | ✗ | lhs.in_shape == rhs.in_shape && // | |
126 | ✗ | lhs.rhs_shape == rhs.rhs_shape && // | |
127 | ✗ | lhs.pad == rhs.pad && // | |
128 | ✗ | lhs.dt == rhs.dt && // | |
129 | ✗ | lhs.dt_acc == rhs.dt_acc && // | |
130 | ✗ | lhs.clamp_rate == rhs.clamp_rate; // | |
131 | } | ||
132 | }; | ||
133 | |||
134 | /// Test reference data | ||
135 | struct TestData { | ||
136 | Buffer lhs; ///< LHS input matrix | ||
137 | Buffer rhs; ///< RHS input matrix | ||
138 | Buffer bias; ///< Bias vector | ||
139 | Buffer out; ///< Reference depthwise result | ||
140 | Buffer padding; ///< Padding buffer | ||
141 | Range<float> clamp_range; ///< Clamp range | ||
142 | }; | ||
143 | |||
144 | /// Generate reference data, caches it. | ||
145 | struct ReferenceGenerator { | ||
146 | /// Retrieve reference data for the provided test identification | ||
147 | 60 | static const TestData& get_test_reference(const TestDataId test_id, const MatMulShape& out_shape) { | |
148 |
3/4✓ Branch 0 taken 1 time.
✓ Branch 1 taken 59 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
60 | static std::unordered_map<TestDataId, TestData, TestDataId::Hash> m_data; |
149 |
2/5✗ Branch 0 not taken.
✓ Branch 1 taken 60 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
|
60 | if (const auto itr = m_data.find(test_id); itr != end(m_data)) { |
150 | ✗ | return itr->second; | |
151 | } | ||
152 | |||
153 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | return m_data[test_id] = generate_reference(test_id, out_shape); |
154 | 60 | } | |
155 | |||
156 | private: | ||
157 | /// Return incremented seed value | ||
158 | 180 | static size_t get_seed() { | |
159 | static size_t seed = 0; | ||
160 | 180 | return seed++; | |
161 | } | ||
162 | |||
163 | /// Generate reference data. | ||
164 | // NOTE : This block is currently FP32 specific - it is not datatype generic | ||
165 | 60 | static TestData generate_reference(const TestDataId& test_id, const MatMulShape& out_shape) { | |
166 | 960 | const auto& [in_shape, rhs_shape, pad, dt, acc_dt, clamp_rate] = test_id; | |
167 | |||
168 | // Generate random input data | ||
169 | 300 | Buffer lhs = fill_matrix_random(in_shape.m, in_shape.n * in_shape.k, DataFormat(dt), get_seed()); | |
170 |
6/12✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 60 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 60 times.
✗ Branch 11 not taken.
|
300 | Buffer rhs = fill_matrix_random(rhs_shape.m, rhs_shape.n * rhs_shape.k, DataFormat(dt), get_seed()); |
171 |
3/6✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
|
120 | Buffer bias = fill_matrix_random(1, out_shape.k, DataFormat(dt), get_seed()); |
172 | |||
173 | // Call reference function | ||
174 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
120 | Buffer out = depthwise_reference<float>( |
175 |
7/14✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 60 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 60 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 60 times.
✗ Branch 13 not taken.
|
300 | 1, in_shape.m, in_shape.n, in_shape.k, rhs_shape.m, rhs_shape.n, lhs.data(), rhs.data(), bias.data(), pad); |
176 | |||
177 | 120 | const auto [min, max] = | |
178 |
3/6✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
|
60 | find_clamp_range(dt, out.data(), out_shape.m * out_shape.n * out_shape.k, 1.0F - clamp_rate); |
179 |
5/10✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 60 times.
✗ Branch 9 not taken.
|
120 | Buffer out_clamped = clamp(dt, out.data(), out_shape.m * out_shape.n * out_shape.k, min, max); |
180 | |||
181 | // Populate reference data | ||
182 | 60 | TestData test_reference; | |
183 | 60 | test_reference.lhs = std::move(lhs); | |
184 | 60 | test_reference.rhs = std::move(rhs); | |
185 | 60 | test_reference.bias = std::move(bias); | |
186 | 60 | test_reference.out = std::move(out_clamped); | |
187 | 180 | test_reference.clamp_range = {min, max}; | |
188 | 60 | return test_reference; | |
189 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | }; |
190 | }; | ||
191 | |||
192 | /// Perform RHS packing for depthwise | ||
193 | 60 | Buffer pack_rhs(const RhsPackDepthwiseKernel& kernel, const MatMulShape& shape, const TestData& reference) { | |
194 | // Calculate size, and allocate buffer | ||
195 | 60 | const size_t dst_size = kernel.get_rhs_packed_size(shape.m, shape.n, shape.k); | |
196 | 60 | Buffer dst(dst_size); | |
197 | |||
198 | // RHS Pack API is subject to change. | ||
199 |
4/8✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
|
60 | kernel.pack(shape.m, shape.n, shape.m, shape.n, shape.k, reference.rhs.data(), reference.bias.data(), dst.data()); |
200 | 60 | return dst; | |
201 | 60 | } | |
202 | |||
203 | /// Perform Depthwise Operation using main kernel. | ||
204 | 60 | Buffer dwconv( | |
205 | const DepthwisePlanarKernel& kernel, const Rect& portion, const MatMulShape& in_shape, const MatMulShape& out_shape, | ||
206 | const Padding2D pad, const TestData& reference, const Buffer& rhs_packed, Range<float> clamp_range, DataType type) { | ||
207 | 60 | const size_t dst_size = kernel.get_dst_size(out_shape.m, out_shape.n, out_shape.k); | |
208 | 60 | Buffer dst(dst_size); | |
209 | |||
210 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | const size_t dt_size_bytes = data_type_size_in_bits(type) / 8; |
211 | 60 | const size_t stride_in_row = in_shape.n * in_shape.k * dt_size_bytes; | |
212 | 60 | const size_t dst_stride_row = out_shape.n * out_shape.k * dt_size_bytes; | |
213 | 60 | const size_t stride_col = out_shape.k * dt_size_bytes; | |
214 | |||
215 | // Loop the following. M-Step rows are handled at a time. | ||
216 |
5/8✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1098 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1038 times.
✓ Branch 5 taken 60 times.
✓ Branch 6 taken 1038 times.
✗ Branch 7 not taken.
|
1098 | for (size_t out_row = portion.start_row(); out_row < portion.end_row(); out_row += kernel.get_m_step()) { |
217 | 1038 | const int start_in_row = out_row - pad.top; | |
218 |
2/2✓ Branch 0 taken 993 times.
✓ Branch 1 taken 45 times.
|
1038 | const size_t pad_top = (start_in_row < 0) ? (-start_in_row) : 0; |
219 |
2/2✓ Branch 0 taken 993 times.
✓ Branch 1 taken 45 times.
|
1038 | const size_t in_row = (start_in_row < 0) ? 0 : start_in_row; |
220 | |||
221 |
1/2✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
|
1038 | const size_t valid_input_rows = (in_row < in_shape.m) ? (in_shape.m - in_row) : 0; |
222 | 1038 | const size_t valid_out_rows = (out_shape.m - out_row); | |
223 | |||
224 |
1/2✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
|
2076 | kernel.conv( |
225 |
3/6✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1038 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1038 times.
✗ Branch 5 not taken.
|
1038 | reference.lhs.data() + (in_row * stride_in_row), rhs_packed.data(), dst.data() + (out_row * dst_stride_row), |
226 | 1038 | stride_in_row, stride_col, dst_stride_row, stride_col, valid_input_rows, valid_out_rows, pad.left, pad_top, | |
227 | 1038 | 0.f, clamp_range.min, clamp_range.max); | |
228 | 1038 | } | |
229 | |||
230 | 60 | return dst; | |
231 | 60 | } | |
232 | } // namespace | ||
233 | |||
234 | /// End-to-end test for depthwise kernels | ||
235 |
7/14✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
|
182 | TEST_P(DepthwisePlanarTest, Output) { |
236 | 960 | const auto& [method, in_shape, padding, clamp_rate] = GetParam(); | |
237 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | if (not method.is_supported()) { |
238 | ✗ | GTEST_SKIP() << "Unsupported CPU feature"; | |
239 | } | ||
240 | |||
241 | // Calculate Shapes. | ||
242 | 300 | const int out_height = (in_shape.m + padding.top + padding.bottom + 1 - method.filter.first); | |
243 | 300 | const int out_width = (in_shape.n + padding.left + padding.right + 1 - method.filter.second); | |
244 |
4/16✗ Branch 0 not taken.
✓ Branch 1 taken 60 times.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 60 times.
|
60 | ASSERT_TRUE(out_height > 0 && out_width > 0); |
245 | |||
246 | 120 | const size_t dt_size_bytes = data_type_size_in_bits(method.data_type) / 8; | |
247 | 240 | MatMulShape rhs_shape = {method.filter.first, method.filter.second, in_shape.k}; | |
248 | 120 | MatMulShape out_shape = {static_cast<size_t>(out_height), static_cast<size_t>(out_width), (in_shape.k)}; | |
249 | |||
250 | // 1. Calculate reference. | ||
251 | 120 | const TestData& test_data = ReferenceGenerator::get_test_reference( | |
252 | 360 | {in_shape, rhs_shape, padding, method.data_type, method.acc_type, clamp_rate}, out_shape); | |
253 | |||
254 | // 2. Pack RHS (Weights+Bias) | ||
255 | 120 | Buffer rhs_packed = pack_rhs(method.rhs, rhs_shape, test_data); | |
256 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | const MatrixPortion out_portion{0, 0, 1, 1}; |
257 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
120 | const Rect portion = out_portion.compute_portion( |
258 |
3/6✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
|
120 | out_shape.m, out_shape.n * out_shape.k, method.depthwise.get_m_step(), (rhs_packed.size() / dt_size_bytes)); |
259 | |||
260 | // 3. Run Depthwise Kernel. | ||
261 |
4/8✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
|
240 | Buffer out = dwconv( |
262 | 180 | method.depthwise, portion, in_shape, out_shape, padding, test_data, rhs_packed, test_data.clamp_range, | |
263 | 60 | method.data_type); | |
264 | |||
265 | // 4. Compare with reference result. | ||
266 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | DefaultMismatchHandler handler(0, 0.0001, 0, 0.001); |
267 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
120 | const auto success = compare( |
268 |
2/4✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
|
60 | out.data(), test_data.out.data(), DataType::FP32, out_shape.m, out_shape.n * out_shape.k, portion, handler); |
269 |
4/16✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 60 times.
|
60 | ASSERT_TRUE(success); |
270 | 60 | } | |
271 | |||
272 | /// Name generator for test case | ||
273 | 120 | [[maybe_unused]] static void PrintTo(const DepthwiseParamsParams& param, std::ostream* os) { | |
274 | 600 | const auto& [method, shape, padding, clamp_rate] = param; | |
275 | 240 | *os << method.name << "__"; | |
276 | 120 | PrintTo(shape, os); | |
277 | 120 | *os << "__"; | |
278 | 120 | PrintTo(padding, os); | |
279 | 120 | *os << "__"; | |
280 | 240 | *os << "__clamp_rate_" << static_cast<int>(clamp_rate * 100); | |
281 | 120 | } | |
282 | |||
283 | /// Test parameter listing | ||
284 |
11/32✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 60 times.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
|
62 | INSTANTIATE_TEST_SUITE_P( |
285 | Depthwise, DepthwisePlanarTest, | ||
286 | testing::Combine( | ||
287 | testing::ValuesIn(get_depthwise_methods()), // | ||
288 | testing::ValuesIn({ | ||
289 | // clang-format off | ||
290 | // IN_HEIGHT, IN_WIDTH, IN_CHANNELS | ||
291 | MatMulShape{ 4, 4, 1}, // | ||
292 | MatMulShape{ 8, 4, 16}, // | ||
293 | MatMulShape{ 96, 33, 37}, // | ||
294 | MatMulShape{ 99, 22, 51}, // | ||
295 | MatMulShape{ 127, 127, 127}, // | ||
296 | // clang-format on | ||
297 | }), | ||
298 | testing::ValuesIn({ | ||
299 | // clang-format off | ||
300 | // pad_left, pad_right, pad_top, pad_bottom | ||
301 | Padding2D{0, 0, 0, 0}, | ||
302 | Padding2D{0, 1, 0, 1}, | ||
303 | Padding2D{1, 1, 1, 1}, | ||
304 | Padding2D{5, 11, 7, 3}, | ||
305 | // clang-format on | ||
306 | }), | ||
307 | testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})), // | ||
308 | testing::PrintToStringParamName()); | ||
309 | |||
310 | } // namespace kai::test | ||
311 |