KleidiAI Coverage Report


Directory: ./
File: test/tests/dwconv_test.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 92.2% 118 0 128
Functions: 96.0% 24 0 25
Branches: 42.3% 96 0 227

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/dwconv.hpp"
8
9 #include <gtest/gtest.h>
10
11 #include <array>
12 #include <cstddef>
13 #include <iostream>
14 #include <string_view>
15 #include <tuple>
16 #include <unordered_map>
17
18 #include "kai/ukernels/dwconv/dwconv_f32_f32_f32p/kai_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla.h"
19 #include "kai/ukernels/dwconv/dwconv_f32_f32_f32p/kai_dwconv_clamp_f32_f32_f32p_interface.h"
20 #include "kai/ukernels/dwconv/pack/kai_rhs_dwconv_pack_x32p1vlx1b_x32_x32_sme.h"
21 #include "test/common/buffer.hpp"
22 #include "test/common/compare.hpp"
23 #include "test/common/cpu_info.hpp"
24 #include "test/common/matmul_test_common.hpp"
25 #include "test/common/matrix_portion.hpp"
26 #include "test/reference/clamp.hpp"
27 #include "test/reference/fill.hpp"
28
29 namespace kai::test {
30
31 namespace {
32
33 /// Interface for depthwise kernel.
34
2/4
✓ Branch 0 taken 302 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 302 times.
✗ Branch 3 not taken.
302 struct DepthwisePlanarKernel {
35 std::function<size_t(size_t m, size_t n, size_t k)> get_dst_size;
36 std::function<size_t(size_t m, size_t n)> get_dst_offset;
37 std::function<size_t(void)> get_m_step;
38 std::function<void(
39 const void* inptr, const void* packed_rhs, void* outptr_start, size_t stride_in_row, size_t stride_in_col,
40 size_t dst_stride_row, size_t dst_stride_col, unsigned int valid_input_rows, unsigned int valid_out_rows,
41 unsigned int pad_left, unsigned int pad_top, float pad_value, float clamp_min, float clamp_max)>
42 conv;
43 };
44
45 // Rhs packing micro-kernel.
46 302 struct RhsPackDepthwiseKernel {
47 std::function<size_t(size_t fh, size_t fw, size_t nc)> get_rhs_packed_size;
48 std::function<void(
49 size_t filter_height, size_t filter_width, size_t height, size_t width, size_t num_channels, const void* rhs,
50 const void* bias, void* rhs_packed)>
51 pack;
52 };
53
54 /// Description of a Depthwise kernel set
55
1/2
✓ Branch 0 taken 302 times.
✗ Branch 1 not taken.
302 struct Depthwise {
56 std::string_view name;
57 std::function<bool(void)> is_supported;
58 std::pair<unsigned int, unsigned int> filter;
59 DataType data_type;
60 DataType acc_type;
61 RhsPackDepthwiseKernel rhs;
62 DepthwisePlanarKernel depthwise;
63 };
64
65 /// Convenience types for testing.
66 using DepthwiseArray = std::array<Depthwise, 1>;
67 using DepthwiseParamsParams = std::tuple<Depthwise, MatMulShape, Padding2D, float>;
68 using DepthwisePlanarTest = testing::TestWithParam<DepthwiseParamsParams>;
69
70 /// Use interface for depthwise kernel
71 1 const kai_dwconv_clamp_f32_f32_f32p_planar_ukernel& get_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla() {
72 static kai_dwconv_clamp_f32_f32_f32p_planar_ukernel ukernel;
73 1 ukernel.get_m_step = kai_get_m_step_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla;
74 1 ukernel.get_dst_offset = kai_get_dst_offset_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla;
75 1 ukernel.get_dst_size = kai_get_dst_size_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla;
76 1 ukernel.run_dwconv = kai_run_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla;
77 1 return ukernel;
78 }
79
80 1 const DepthwiseArray& get_depthwise_methods() {
81 // FP32 kernels with 3x3 filter.
82
2/4
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
1 static DepthwiseArray depthwise_methods{};
83 1 depthwise_methods[0].name = "dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla";
84 1 depthwise_methods[0].rhs.get_rhs_packed_size = kai_rhs_get_dst_size_dwconv_pack_x32p1vlx1b_x32_x32_sme;
85 1 depthwise_methods[0].rhs.pack = kai_run_rhs_dwconv_pack_x32p1vlx1b_x32_x32_sme;
86 1 depthwise_methods[0].is_supported = cpu_has_sme2;
87 1 depthwise_methods[0].filter = {3, 3};
88
89 2 const kai_dwconv_clamp_f32_f32_f32p_planar_ukernel& ukernel_f32 =
90 1 get_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla();
91 1 depthwise_methods[0].data_type = DataType::FP32;
92 1 depthwise_methods[0].acc_type = DataType::FP32;
93 1 depthwise_methods[0].depthwise.get_m_step = ukernel_f32.get_m_step;
94 1 depthwise_methods[0].depthwise.get_dst_size = ukernel_f32.get_dst_size;
95 1 depthwise_methods[0].depthwise.get_dst_offset = ukernel_f32.get_dst_offset;
96 1 depthwise_methods[0].depthwise.conv = ukernel_f32.run_dwconv;
97 1 return depthwise_methods;
98 1 }
99
100 /// Test reference identification.
101 struct TestDataId {
102 using DT = std::underlying_type_t<DataType>;
103 MatMulShape in_shape;
104 MatMulShape rhs_shape;
105 Padding2D pad;
106 DataType dt;
107 DataType dt_acc;
108 float clamp_rate;
109
110 struct Hash {
111 120 size_t operator()(const TestDataId& test_id) const {
112 120 return //
113 240 (MatMulShape::Hash{}(test_id.in_shape) << 0) ^ //
114 240 (MatMulShape::Hash{}(test_id.rhs_shape) << 1) ^ //
115 240 (Padding2D::Hash{}(test_id.pad) << 2) ^ //
116 240 (std::hash<DT>{}(static_cast<DT>(test_id.dt)) << 3) ^ //
117 240 (std::hash<DT>{}(static_cast<DT>(test_id.dt_acc)) << 4) ^ //
118 120 (std::hash<float>{}(test_id.clamp_rate) << 5); //
119 }
120 };
121
122 private:
123 friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) {
124 return //
125 lhs.in_shape == rhs.in_shape && //
126 lhs.rhs_shape == rhs.rhs_shape && //
127 lhs.pad == rhs.pad && //
128 lhs.dt == rhs.dt && //
129 lhs.dt_acc == rhs.dt_acc && //
130 lhs.clamp_rate == rhs.clamp_rate; //
131 }
132 };
133
134 /// Test reference data
135 struct TestData {
136 Buffer lhs; ///< LHS input matrix
137 Buffer rhs; ///< RHS input matrix
138 Buffer bias; ///< Bias vector
139 Buffer out; ///< Reference depthwise result
140 Buffer padding; ///< Padding buffer
141 Range<float> clamp_range; ///< Clamp range
142 };
143
144 /// Generate reference data, caches it.
145 struct ReferenceGenerator {
146 /// Retrieve reference data for the provided test identification
147 60 static const TestData& get_test_reference(const TestDataId test_id, const MatMulShape& out_shape) {
148
3/4
✓ Branch 0 taken 1 time.
✓ Branch 1 taken 59 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 time.
60 static std::unordered_map<TestDataId, TestData, TestDataId::Hash> m_data;
149
2/5
✗ Branch 0 not taken.
✓ Branch 1 taken 60 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
60 if (const auto itr = m_data.find(test_id); itr != end(m_data)) {
150 return itr->second;
151 }
152
153
1/2
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
60 return m_data[test_id] = generate_reference(test_id, out_shape);
154 60 }
155
156 private:
157 /// Return incremented seed value
158 180 static size_t get_seed() {
159 static size_t seed = 0;
160 180 return seed++;
161 }
162
163 /// Generate reference data.
164 // NOTE : This block is currently FP32 specific - it is not datatype generic
165 60 static TestData generate_reference(const TestDataId& test_id, const MatMulShape& out_shape) {
166 960 const auto& [in_shape, rhs_shape, pad, dt, acc_dt, clamp_rate] = test_id;
167
168 // Generate random input data
169 300 Buffer lhs = fill_matrix_random(in_shape.m, in_shape.n * in_shape.k, DataFormat(dt), get_seed());
170
6/12
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 60 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 60 times.
✗ Branch 11 not taken.
300 Buffer rhs = fill_matrix_random(rhs_shape.m, rhs_shape.n * rhs_shape.k, DataFormat(dt), get_seed());
171
3/6
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
120 Buffer bias = fill_matrix_random(1, out_shape.k, DataFormat(dt), get_seed());
172
173 // Call reference function
174
1/2
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
120 Buffer out = depthwise_reference<float>(
175
7/14
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 60 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 60 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 60 times.
✗ Branch 13 not taken.
300 1, in_shape.m, in_shape.n, in_shape.k, rhs_shape.m, rhs_shape.n, lhs.data(), rhs.data(), bias.data(), pad);
176
177 120 const auto [min, max] =
178
3/6
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
60 find_clamp_range(dt, out.data(), out_shape.m * out_shape.n * out_shape.k, 1.0F - clamp_rate);
179
5/10
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 60 times.
✗ Branch 9 not taken.
120 Buffer out_clamped = clamp(dt, out.data(), out_shape.m * out_shape.n * out_shape.k, min, max);
180
181 // Populate reference data
182 60 TestData test_reference;
183 60 test_reference.lhs = std::move(lhs);
184 60 test_reference.rhs = std::move(rhs);
185 60 test_reference.bias = std::move(bias);
186 60 test_reference.out = std::move(out_clamped);
187 180 test_reference.clamp_range = {min, max};
188 60 return test_reference;
189
1/2
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
60 };
190 };
191
192 /// Perform RHS packing for depthwise
193 60 Buffer pack_rhs(const RhsPackDepthwiseKernel& kernel, const MatMulShape& shape, const TestData& reference) {
194 // Calculate size, and allocate buffer
195 60 const size_t dst_size = kernel.get_rhs_packed_size(shape.m, shape.n, shape.k);
196 60 Buffer dst(dst_size);
197
198 // RHS Pack API is subject to change.
199
4/8
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
60 kernel.pack(shape.m, shape.n, shape.m, shape.n, shape.k, reference.rhs.data(), reference.bias.data(), dst.data());
200 60 return dst;
201 60 }
202
203 /// Perform Depthwise Operation using main kernel.
204 60 Buffer dwconv(
205 const DepthwisePlanarKernel& kernel, const Rect& portion, const MatMulShape& in_shape, const MatMulShape& out_shape,
206 const Padding2D pad, const TestData& reference, const Buffer& rhs_packed, Range<float> clamp_range, DataType type) {
207 60 const size_t dst_size = kernel.get_dst_size(out_shape.m, out_shape.n, out_shape.k);
208 60 Buffer dst(dst_size);
209
210
1/2
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
60 const size_t dt_size_bytes = data_type_size_in_bits(type) / 8;
211 60 const size_t stride_in_row = in_shape.n * in_shape.k * dt_size_bytes;
212 60 const size_t dst_stride_row = out_shape.n * out_shape.k * dt_size_bytes;
213 60 const size_t stride_col = out_shape.k * dt_size_bytes;
214
215 // Loop the following. M-Step rows are handled at a time.
216
5/8
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1098 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1038 times.
✓ Branch 5 taken 60 times.
✓ Branch 6 taken 1038 times.
✗ Branch 7 not taken.
1098 for (size_t out_row = portion.start_row(); out_row < portion.end_row(); out_row += kernel.get_m_step()) {
217 1038 const int start_in_row = out_row - pad.top;
218
2/2
✓ Branch 0 taken 993 times.
✓ Branch 1 taken 45 times.
1038 const size_t pad_top = (start_in_row < 0) ? (-start_in_row) : 0;
219
2/2
✓ Branch 0 taken 993 times.
✓ Branch 1 taken 45 times.
1038 const size_t in_row = (start_in_row < 0) ? 0 : start_in_row;
220
221
1/2
✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
1038 const size_t valid_input_rows = (in_row < in_shape.m) ? (in_shape.m - in_row) : 0;
222 1038 const size_t valid_out_rows = (out_shape.m - out_row);
223
224
1/2
✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
2076 kernel.conv(
225
3/6
✓ Branch 0 taken 1038 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1038 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1038 times.
✗ Branch 5 not taken.
1038 reference.lhs.data() + (in_row * stride_in_row), rhs_packed.data(), dst.data() + (out_row * dst_stride_row),
226 1038 stride_in_row, stride_col, dst_stride_row, stride_col, valid_input_rows, valid_out_rows, pad.left, pad_top,
227 1038 0.f, clamp_range.min, clamp_range.max);
228 1038 }
229
230 60 return dst;
231 60 }
232 } // namespace
233
234 /// End-to-end test for depthwise kernels
235
7/14
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
182 TEST_P(DepthwisePlanarTest, Output) {
236 960 const auto& [method, in_shape, padding, clamp_rate] = GetParam();
237
1/2
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
60 if (not method.is_supported()) {
238 GTEST_SKIP() << "Unsupported CPU feature";
239 }
240
241 // Calculate Shapes.
242 300 const int out_height = (in_shape.m + padding.top + padding.bottom + 1 - method.filter.first);
243 300 const int out_width = (in_shape.n + padding.left + padding.right + 1 - method.filter.second);
244
4/16
✗ Branch 0 not taken.
✓ Branch 1 taken 60 times.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 60 times.
60 ASSERT_TRUE(out_height > 0 && out_width > 0);
245
246 120 const size_t dt_size_bytes = data_type_size_in_bits(method.data_type) / 8;
247 240 MatMulShape rhs_shape = {method.filter.first, method.filter.second, in_shape.k};
248 120 MatMulShape out_shape = {static_cast<size_t>(out_height), static_cast<size_t>(out_width), (in_shape.k)};
249
250 // 1. Calculate reference.
251 120 const TestData& test_data = ReferenceGenerator::get_test_reference(
252 360 {in_shape, rhs_shape, padding, method.data_type, method.acc_type, clamp_rate}, out_shape);
253
254 // 2. Pack RHS (Weights+Bias)
255 120 Buffer rhs_packed = pack_rhs(method.rhs, rhs_shape, test_data);
256
1/2
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
60 const MatrixPortion out_portion{0, 0, 1, 1};
257
1/2
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
120 const Rect portion = out_portion.compute_portion(
258
3/6
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
120 out_shape.m, out_shape.n * out_shape.k, method.depthwise.get_m_step(), (rhs_packed.size() / dt_size_bytes));
259
260 // 3. Run Depthwise Kernel.
261
4/8
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 60 times.
✗ Branch 7 not taken.
240 Buffer out = dwconv(
262 180 method.depthwise, portion, in_shape, out_shape, padding, test_data, rhs_packed, test_data.clamp_range,
263 60 method.data_type);
264
265 // 4. Compare with reference result.
266
1/2
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
60 DefaultMismatchHandler handler(0, 0.0001, 0, 0.001);
267
1/2
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
120 const auto success = compare(
268
2/4
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
60 out.data(), test_data.out.data(), DataType::FP32, out_shape.m, out_shape.n * out_shape.k, portion, handler);
269
4/16
✓ Branch 0 taken 60 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 60 times.
60 ASSERT_TRUE(success);
270 60 }
271
272 /// Name generator for test case
273 120 [[maybe_unused]] static void PrintTo(const DepthwiseParamsParams& param, std::ostream* os) {
274 600 const auto& [method, shape, padding, clamp_rate] = param;
275 240 *os << method.name << "__";
276 120 PrintTo(shape, os);
277 120 *os << "__";
278 120 PrintTo(padding, os);
279 120 *os << "__";
280 240 *os << "__clamp_rate_" << static_cast<int>(clamp_rate * 100);
281 120 }
282
283 /// Test parameter listing
284
11/32
✓ Branch 0 taken 1 time.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 time.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 time.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 time.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 time.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✗ Branch 19 not taken.
✓ Branch 20 taken 60 times.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
62 INSTANTIATE_TEST_SUITE_P(
285 Depthwise, DepthwisePlanarTest,
286 testing::Combine(
287 testing::ValuesIn(get_depthwise_methods()), //
288 testing::ValuesIn({
289 // clang-format off
290 // IN_HEIGHT, IN_WIDTH, IN_CHANNELS
291 MatMulShape{ 4, 4, 1}, //
292 MatMulShape{ 8, 4, 16}, //
293 MatMulShape{ 96, 33, 37}, //
294 MatMulShape{ 99, 22, 51}, //
295 MatMulShape{ 127, 127, 127}, //
296 // clang-format on
297 }),
298 testing::ValuesIn({
299 // clang-format off
300 // pad_left, pad_right, pad_top, pad_bottom
301 Padding2D{0, 0, 0, 0},
302 Padding2D{0, 1, 0, 1},
303 Padding2D{1, 1, 1, 1},
304 Padding2D{5, 11, 7, 3},
305 // clang-format on
306 }),
307 testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})), //
308 testing::PrintToStringParamName());
309
310 } // namespace kai::test
311