test/tests/dwconv_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/reference/dwconv.hpp" | ||
| 8 | |||
| 9 | #include <gtest/gtest.h> | ||
| 10 | |||
| 11 | #include <array> | ||
| 12 | #include <cstddef> | ||
| 13 | #include <iostream> | ||
| 14 | #include <string_view> | ||
| 15 | #include <tuple> | ||
| 16 | #include <unordered_map> | ||
| 17 | |||
| 18 | #include "kai/ukernels/dwconv/dwconv_f32_f32_f32p/kai_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla.h" | ||
| 19 | #include "kai/ukernels/dwconv/dwconv_f32_f32_f32p/kai_dwconv_clamp_f32_f32_f32p_interface.h" | ||
| 20 | #include "kai/ukernels/dwconv/pack/kai_rhs_dwconv_pack_x32p1vlx1b_x32_x32_sme.h" | ||
| 21 | #include "test/common/abi_checker.hpp" | ||
| 22 | #include "test/common/buffer.hpp" | ||
| 23 | #include "test/common/compare.hpp" | ||
| 24 | #include "test/common/cpu_info.hpp" | ||
| 25 | #include "test/common/matmul_test_common.hpp" | ||
| 26 | #include "test/common/matrix_portion.hpp" | ||
| 27 | #include "test/common/seed.hpp" | ||
| 28 | #include "test/reference/clamp.hpp" | ||
| 29 | #include "test/reference/fill.hpp" | ||
| 30 | |||
| 31 | namespace kai::test { | ||
| 32 | |||
| 33 | namespace { | ||
| 34 | |||
| 35 | /// Interface for depthwise kernel. | ||
| 36 |
2/4✓ Branch 0 taken 906 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 906 times.
✗ Branch 3 not taken.
|
906 | struct DepthwisePlanarKernel { |
| 37 | std::function<size_t(size_t m, size_t n, size_t k)> get_dst_size; | ||
| 38 | std::function<size_t(size_t m, size_t n)> get_dst_offset; | ||
| 39 | std::function<size_t(void)> get_m_step; | ||
| 40 | std::function<void( | ||
| 41 | const void* inptr, const void* packed_rhs, void* outptr_start, size_t stride_in_row, size_t stride_in_col, | ||
| 42 | size_t dst_stride_row, size_t dst_stride_col, unsigned int valid_input_rows, unsigned int valid_out_rows, | ||
| 43 | unsigned int pad_left, unsigned int pad_top, float pad_value, float clamp_min, float clamp_max)> | ||
| 44 | conv; | ||
| 45 | }; | ||
| 46 | |||
| 47 | // Rhs packing micro-kernel. | ||
| 48 | 906 | struct RhsPackDepthwiseKernel { | |
| 49 | std::function<size_t(size_t fh, size_t fw, size_t nc)> get_rhs_packed_size; | ||
| 50 | std::function<void( | ||
| 51 | size_t filter_height, size_t filter_width, size_t height, size_t width, size_t num_channels, const void* rhs, | ||
| 52 | const void* bias, void* rhs_packed)> | ||
| 53 | pack; | ||
| 54 | }; | ||
| 55 | |||
| 56 | /// Description of a Depthwise kernel set | ||
| 57 |
1/2✓ Branch 0 taken 906 times.
✗ Branch 1 not taken.
|
906 | struct Depthwise { |
| 58 | std::string_view name; | ||
| 59 | std::function<bool(void)> is_supported; | ||
| 60 | std::pair<unsigned int, unsigned int> filter; | ||
| 61 | DataType data_type; | ||
| 62 | DataType acc_type; | ||
| 63 | RhsPackDepthwiseKernel rhs; | ||
| 64 | DepthwisePlanarKernel depthwise; | ||
| 65 | }; | ||
| 66 | |||
| 67 | /// Convenience types for testing. | ||
| 68 | using DepthwiseArray = std::array<Depthwise, 1>; | ||
| 69 | using DepthwiseParamsParams = std::tuple<Depthwise, MatMulShape, Padding2D, float>; | ||
| 70 | using DepthwisePlanarTest = testing::TestWithParam<DepthwiseParamsParams>; | ||
| 71 | |||
| 72 | /// Use interface for depthwise kernel | ||
| 73 | 3 | const kai_dwconv_clamp_f32_f32_f32p_planar_ukernel& get_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla() { | |
| 74 | static kai_dwconv_clamp_f32_f32_f32p_planar_ukernel ukernel; | ||
| 75 | 3 | ukernel.get_m_step = kai_get_m_step_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla; | |
| 76 | 3 | ukernel.get_dst_offset = kai_get_dst_offset_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla; | |
| 77 | 3 | ukernel.get_dst_size = kai_get_dst_size_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla; | |
| 78 | 3 | ukernel.run_dwconv = kai_run_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla; | |
| 79 | 3 | return ukernel; | |
| 80 | } | ||
| 81 | |||
| 82 | 3 | const DepthwiseArray& get_depthwise_methods() { | |
| 83 | // FP32 kernels with 3x3 filter. | ||
| 84 |
4/8✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
|
3 | static DepthwiseArray depthwise_methods{}; |
| 85 | 3 | depthwise_methods[0].name = "dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla"; | |
| 86 | 3 | depthwise_methods[0].rhs.get_rhs_packed_size = kai_rhs_get_dst_size_dwconv_pack_x32p1vlx1b_x32_x32_sme; | |
| 87 | 3 | depthwise_methods[0].rhs.pack = kai_run_rhs_dwconv_pack_x32p1vlx1b_x32_x32_sme; | |
| 88 | 3 | depthwise_methods[0].is_supported = cpu_has_sme2; | |
| 89 | 3 | depthwise_methods[0].filter = {3, 3}; | |
| 90 | |||
| 91 | 6 | const kai_dwconv_clamp_f32_f32_f32p_planar_ukernel& ukernel_f32 = | |
| 92 | 3 | get_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla(); | |
| 93 | 3 | depthwise_methods[0].data_type = DataType::FP32; | |
| 94 | 3 | depthwise_methods[0].acc_type = DataType::FP32; | |
| 95 | 3 | depthwise_methods[0].depthwise.get_m_step = ukernel_f32.get_m_step; | |
| 96 | 3 | depthwise_methods[0].depthwise.get_dst_size = ukernel_f32.get_dst_size; | |
| 97 | 3 | depthwise_methods[0].depthwise.get_dst_offset = ukernel_f32.get_dst_offset; | |
| 98 | 3 | depthwise_methods[0].depthwise.conv = ukernel_f32.run_dwconv; | |
| 99 | 3 | return depthwise_methods; | |
| 100 | 3 | } | |
| 101 | |||
| 102 | /// Test reference identification. | ||
| 103 | struct TestDataId { | ||
| 104 | using DT = std::underlying_type_t<DataType>; | ||
| 105 | MatMulShape in_shape; | ||
| 106 | MatMulShape rhs_shape; | ||
| 107 | Padding2D pad; | ||
| 108 | DataType dt; | ||
| 109 | DataType dt_acc; | ||
| 110 | float clamp_keep_ratio; | ||
| 111 | |||
| 112 | struct Hash { | ||
| 113 | 180 | size_t operator()(const TestDataId& test_id) const { | |
| 114 | 180 | return // | |
| 115 | 360 | (MatMulShape::Hash{}(test_id.in_shape) << 0) ^ // | |
| 116 | 360 | (MatMulShape::Hash{}(test_id.rhs_shape) << 1) ^ // | |
| 117 | 360 | (Padding2D::Hash{}(test_id.pad) << 2) ^ // | |
| 118 | 360 | (std::hash<DT>{}(static_cast<DT>(test_id.dt)) << 3) ^ // | |
| 119 | 360 | (std::hash<DT>{}(static_cast<DT>(test_id.dt_acc)) << 4) ^ // | |
| 120 | 180 | (std::hash<float>{}(test_id.clamp_keep_ratio) << 5); // | |
| 121 | } | ||
| 122 | }; | ||
| 123 | |||
| 124 | private: | ||
| 125 | ✗ | friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { | |
| 126 | ✗ | return // | |
| 127 | ✗ | lhs.in_shape == rhs.in_shape && // | |
| 128 | ✗ | lhs.rhs_shape == rhs.rhs_shape && // | |
| 129 | ✗ | lhs.pad == rhs.pad && // | |
| 130 | ✗ | lhs.dt == rhs.dt && // | |
| 131 | ✗ | lhs.dt_acc == rhs.dt_acc && // | |
| 132 | ✗ | lhs.clamp_keep_ratio == rhs.clamp_keep_ratio; // | |
| 133 | } | ||
| 134 | }; | ||
| 135 | |||
| 136 | /// Test reference data | ||
| 137 | struct TestData { | ||
| 138 | Buffer lhs; ///< LHS input matrix | ||
| 139 | Buffer rhs; ///< RHS input matrix | ||
| 140 | Buffer bias; ///< Bias vector | ||
| 141 | Buffer out; ///< Reference depthwise result | ||
| 142 | Buffer padding; ///< Padding buffer | ||
| 143 | Range<float> clamp_range; ///< Clamp range | ||
| 144 | }; | ||
| 145 | |||
| 146 | /// Generate reference data, caches it. | ||
| 147 | struct ReferenceGenerator { | ||
| 148 | /// Retrieve reference data for the provided test identification | ||
| 149 | 60 | static const TestData& get_test_reference(const TestDataId test_id, const MatMulShape& out_shape) { | |
| 150 |
3/4✓ Branch 0 taken 1 time.
✓ Branch 1 taken 59 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
|
60 | static std::unordered_map<TestDataId, TestData, TestDataId::Hash> m_data; |
| 151 |
2/5✗ Branch 0 not taken.
✓ Branch 1 taken 60 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
|
60 | if (const auto itr = m_data.find(test_id); itr != end(m_data)) { |
| 152 | ✗ | return itr->second; | |
| 153 | } | ||
| 154 | |||
| 155 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | return m_data[test_id] = generate_reference(test_id, out_shape); |
| 156 | 60 | } | |
| 157 | |||
| 158 | private: | ||
| 159 | /// Generate reference data. | ||
| 160 | // NOTE : This block is currently FP32 specific - it is not datatype generic | ||
| 161 | 60 | static TestData generate_reference(const TestDataId& test_id, const MatMulShape& out_shape) { | |
| 162 | 960 | const auto& [in_shape, rhs_shape, pad, dt, acc_dt, clamp_keep_ratio] = test_id; | |
| 163 | |||
| 164 | // Stable key derived from the cache identifier. | ||
| 165 | 60 | const auto key_hash = static_cast<std::uint32_t>(TestDataId::Hash{}(test_id)); | |
| 166 |
2/4✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
|
60 | const auto key = std::string("dwconv_cache:") + std::to_string(key_hash); |
| 167 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | auto& feed = seed_stream(key); |
| 168 | |||
| 169 | // Generate random input data | ||
| 170 |
6/12✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 60 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 60 times.
✗ Branch 11 not taken.
|
300 | Buffer lhs = fill_matrix_random(in_shape.m, in_shape.n * in_shape.k, DataFormat(dt), feed()); |
| 171 |
6/12✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 60 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 60 times.
✗ Branch 11 not taken.
|
300 | Buffer rhs = fill_matrix_random(rhs_shape.m, rhs_shape.n * rhs_shape.k, DataFormat(dt), feed()); |
| 172 |
3/6✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
|
120 | Buffer bias = fill_matrix_random(1, out_shape.k, DataFormat(dt), feed()); |
| 173 | |||
| 174 | // Call reference function | ||
| 175 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
120 | Buffer out = depthwise_reference<float>( |
| 176 |
7/14✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 60 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 60 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 60 times.
✗ Branch 13 not taken.
|
300 | 1, in_shape.m, in_shape.n, in_shape.k, rhs_shape.m, rhs_shape.n, lhs.data(), rhs.data(), bias.data(), pad); |
| 177 | |||
| 178 | 120 | const auto [min, max] = | |
| 179 |
3/6✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
|
60 | find_clamp_range(dt, out.data(), out_shape.m * out_shape.n * out_shape.k, clamp_keep_ratio); |
| 180 |
5/10✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 60 times.
✗ Branch 9 not taken.
|
120 | Buffer out_clamped = clamp(dt, out.data(), out_shape.m * out_shape.n * out_shape.k, min, max); |
| 181 | |||
| 182 | // Populate reference data | ||
| 183 | 60 | TestData test_reference; | |
| 184 | 60 | test_reference.lhs = std::move(lhs); | |
| 185 | 60 | test_reference.rhs = std::move(rhs); | |
| 186 | 60 | test_reference.bias = std::move(bias); | |
| 187 | 60 | test_reference.out = std::move(out_clamped); | |
| 188 | 180 | test_reference.clamp_range = {min, max}; | |
| 189 | 60 | return test_reference; | |
| 190 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | }; |
| 191 | }; | ||
| 192 | |||
| 193 | /// Perform RHS packing for depthwise | ||
| 194 | 60 | Buffer pack_rhs(const RhsPackDepthwiseKernel& kernel, const MatMulShape& shape, const TestData& reference) { | |
| 195 | // Calculate size, and allocate buffer | ||
| 196 | 60 | const size_t dst_size = kernel.get_rhs_packed_size(shape.m, shape.n, shape.k); | |
| 197 | 60 | Buffer dst(dst_size); | |
| 198 | |||
| 199 | // RHS Pack API is subject to change. | ||
| 200 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | abi_check( |
| 201 |
2/4✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
|
60 | kernel.pack, shape.m, shape.n, shape.m, shape.n, shape.k, reference.rhs.data(), reference.bias.data(), |
| 202 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | dst.data()); |
| 203 | 60 | return dst; | |
| 204 | 60 | } | |
| 205 | |||
| 206 | /// Perform Depthwise Operation using main kernel. | ||
| 207 | 60 | Buffer dwconv( | |
| 208 | const DepthwisePlanarKernel& kernel, const Rect& portion, const MatMulShape& in_shape, const MatMulShape& out_shape, | ||
| 209 | const Padding2D pad, const TestData& reference, const Buffer& rhs_packed, Range<float> clamp_range, DataType type) { | ||
| 210 | 60 | const size_t dst_size = kernel.get_dst_size(out_shape.m, out_shape.n, out_shape.k); | |
| 211 | 60 | Buffer dst(dst_size); | |
| 212 | |||
| 213 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | const size_t dt_size_bytes = data_type_size_in_bits(type) / 8; |
| 214 | 60 | const size_t stride_in_row = in_shape.n * in_shape.k * dt_size_bytes; | |
| 215 | 60 | const size_t dst_stride_row = out_shape.n * out_shape.k * dt_size_bytes; | |
| 216 | 60 | const size_t stride_col = out_shape.k * dt_size_bytes; | |
| 217 | |||
| 218 | // Loop the following. M-Step rows are handled at a time. | ||
| 219 |
5/8✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1098 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1038 times.
✓ Branch 5 taken 60 times.
✓ Branch 6 taken 1038 times.
✗ Branch 7 not taken.
|
1098 | for (size_t out_row = portion.start_row(); out_row < portion.end_row(); out_row += kernel.get_m_step()) { |
| 220 | 1038 | const int start_in_row = out_row - pad.top; | |
| 221 |
2/2✓ Branch 0 taken 993 times.
✓ Branch 1 taken 45 times.
|
1038 | const size_t pad_top = (start_in_row < 0) ? (-start_in_row) : 0; |
| 222 |
2/2✓ Branch 0 taken 993 times.
✓ Branch 1 taken 45 times.
|
1038 | const size_t in_row = (start_in_row < 0) ? 0 : start_in_row; |
| 223 | |||
| 224 |
1/2✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
|
1038 | const size_t valid_input_rows = (in_row < in_shape.m) ? (in_shape.m - in_row) : 0; |
| 225 | 1038 | const size_t valid_out_rows = (out_shape.m - out_row); | |
| 226 | |||
| 227 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1038 times.
|
1038 | abi_check( |
| 228 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1038 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1038 times.
|
1038 | kernel.conv, reference.lhs.data() + (in_row * stride_in_row), rhs_packed.data(), |
| 229 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1038 times.
|
1038 | dst.data() + (out_row * dst_stride_row), stride_in_row, stride_col, dst_stride_row, stride_col, |
| 230 | 1038 | valid_input_rows, valid_out_rows, pad.left, pad_top, 0.f, clamp_range.min, clamp_range.max); | |
| 231 | 1038 | } | |
| 232 | |||
| 233 | 60 | return dst; | |
| 234 | 60 | } | |
| 235 | } // namespace | ||
| 236 | |||
| 237 | /// End-to-end test for depthwise kernels | ||
| 238 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
306 | TEST_P(DepthwisePlanarTest, Output) { |
| 239 | 1020 | const auto& [method, in_shape, padding, clamp_keep_ratio] = GetParam(); | |
| 240 |
2/2✓ Branch 0 taken 60 times.
✓ Branch 1 taken 60 times.
|
120 | if (not method.is_supported()) { |
| 241 |
3/6✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
|
60 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 242 | } | ||
| 243 | |||
| 244 | // Calculate Shapes. | ||
| 245 | 300 | const int out_height = (in_shape.m + padding.top + padding.bottom + 1 - method.filter.first); | |
| 246 | 300 | const int out_width = (in_shape.n + padding.left + padding.right + 1 - method.filter.second); | |
| 247 |
4/16✗ Branch 0 not taken.
✓ Branch 1 taken 60 times.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 60 times.
|
60 | ASSERT_TRUE(out_height > 0 && out_width > 0); |
| 248 | |||
| 249 | 120 | const size_t dt_size_bytes = data_type_size_in_bits(method.data_type) / 8; | |
| 250 | 240 | MatMulShape rhs_shape = {method.filter.first, method.filter.second, in_shape.k}; | |
| 251 | 120 | MatMulShape out_shape = {static_cast<size_t>(out_height), static_cast<size_t>(out_width), (in_shape.k)}; | |
| 252 | |||
| 253 | // 1. Calculate reference. | ||
| 254 | 120 | const TestData& test_data = ReferenceGenerator::get_test_reference( | |
| 255 | 360 | {in_shape, rhs_shape, padding, method.data_type, method.acc_type, clamp_keep_ratio}, out_shape); | |
| 256 | |||
| 257 | // 2. Pack RHS (Weights+Bias) | ||
| 258 | 120 | Buffer rhs_packed = pack_rhs(method.rhs, rhs_shape, test_data); | |
| 259 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | const MatrixPortion out_portion{0, 0, 1, 1}; |
| 260 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
120 | const Rect portion = out_portion.compute_portion( |
| 261 |
3/6✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
|
120 | out_shape.m, out_shape.n * out_shape.k, method.depthwise.get_m_step(), (rhs_packed.size() / dt_size_bytes)); |
| 262 | |||
| 263 | // 3. Run Depthwise Kernel. | ||
| 264 |
4/8✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
|
240 | Buffer out = dwconv( |
| 265 | 180 | method.depthwise, portion, in_shape, out_shape, padding, test_data, rhs_packed, test_data.clamp_range, | |
| 266 | 60 | method.data_type); | |
| 267 | |||
| 268 | // 4. Compare with reference result. | ||
| 269 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
60 | DefaultMismatchHandler handler(0, 0.0001, 0, 0.001); |
| 270 |
1/2✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
|
120 | const auto success = compare( |
| 271 |
2/4✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
|
60 | out.data(), test_data.out.data(), DataType::FP32, out_shape.m, out_shape.n * out_shape.k, portion, handler); |
| 272 |
4/16✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 60 times.
|
60 | ASSERT_TRUE(success); |
| 273 | 120 | } | |
| 274 | |||
| 275 | /// Name generator for test case | ||
| 276 | 360 | [[maybe_unused]] static void PrintTo(const DepthwiseParamsParams& param, std::ostream* os) { | |
| 277 | 1800 | const auto& [method, shape, padding, clamp_keep_ratio] = param; | |
| 278 | 720 | *os << method.name << "__"; | |
| 279 | 360 | PrintTo(shape, os); | |
| 280 | 360 | *os << "__"; | |
| 281 | 360 | PrintTo(padding, os); | |
| 282 | 360 | *os << "__"; | |
| 283 | 720 | *os << "__clamp_keep_ratio_" << static_cast<int>(clamp_keep_ratio * 100); | |
| 284 | 360 | } | |
| 285 | |||
| 286 | /// Test parameter listing | ||
| 287 |
18/56✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✓ Branch 20 taken 60 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✓ Branch 22 taken 120 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
|
186 | INSTANTIATE_TEST_SUITE_P( |
| 288 | Depthwise, DepthwisePlanarTest, | ||
| 289 | testing::Combine( | ||
| 290 | testing::ValuesIn(get_depthwise_methods()), // | ||
| 291 | testing::ValuesIn({ | ||
| 292 | // clang-format off | ||
| 293 | // IN_HEIGHT, IN_WIDTH, IN_CHANNELS | ||
| 294 | MatMulShape{ 4, 4, 1}, // | ||
| 295 | MatMulShape{ 8, 4, 16}, // | ||
| 296 | MatMulShape{ 96, 33, 37}, // | ||
| 297 | MatMulShape{ 99, 22, 51}, // | ||
| 298 | MatMulShape{ 127, 127, 127}, // | ||
| 299 | // clang-format on | ||
| 300 | }), | ||
| 301 | testing::ValuesIn({ | ||
| 302 | // clang-format off | ||
| 303 | // pad_left, pad_right, pad_top, pad_bottom | ||
| 304 | Padding2D{0, 0, 0, 0}, | ||
| 305 | Padding2D{0, 1, 0, 1}, | ||
| 306 | Padding2D{1, 1, 1, 1}, | ||
| 307 | Padding2D{5, 11, 7, 3}, | ||
| 308 | // clang-format on | ||
| 309 | }), | ||
| 310 | testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio | ||
| 311 | testing::PrintToStringParamName()); | ||
| 312 | |||
| 313 | } // namespace kai::test | ||
| 314 |