KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 98.0% 97 / 0 / 99
Functions: 96.0% 24 / 0 / 25
Branches: 90.0% 9 / 0 / 10

benchmark/dwconv/dwconv_runner.hpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #pragma once
8
9 #include <cfloat>
10 #include <cstddef>
11 #include <cstdint>
12 #include <limits>
13 #include <test/common/data_type.hpp>
14
15 #include "dwconv_interface.hpp"
16
17 namespace kai::benchmark {
18
19 using DataType = test::DataType;
20
21 6 inline size_t data_type_size_bytes(const DataType dt) {
22 6 return test::data_type_size_in_bits(dt) / 8;
23 }
24
25 // Base runner that abstracts common configuration and exposes a uniform API.
26 class DwConvRunner {
27 public:
28 1 DwConvRunner(const DwConvTraits& traits, DataType src_type, DataType dst_type) :
29 1 m_traits(traits), m_src_type(src_type), m_dst_type(dst_type) {
30 1 }
31
32 1 virtual ~DwConvRunner() = default;
33
34 1 void set_input_dims(size_t height, size_t width) {
35 1 m_input_height = height;
36 1 m_input_width = width;
37 1 }
38 1 void set_output_dims(size_t height, size_t width) {
39 1 m_output_height = height;
40 1 m_output_width = width;
41 1 }
42 1 void set_channels(size_t channels) {
43 1 m_num_channels = channels;
44 1 }
45 1 void set_padding(size_t top, size_t bottom, size_t left, size_t right) {
46 1 m_pad_top = top;
47 1 m_pad_bottom = bottom;
48 1 m_pad_left = left;
49 1 m_pad_right = right;
50 1 }
51 1 void set_clamp(float min_val, float max_val) {
52 1 m_clamp_min = min_val;
53 1 m_clamp_max = max_val;
54 1 }
55
56 // API to allow derived classes to stash RHS in the shape they need.
57 virtual void prepare(
58 const void* /* rhs_packed */, const void* /* weights */, const void* /* bias */, const void* /* qp */) {
59 // No-op
60 }
61
62 // Uniform run call from benchmark layer. Implements common tiling and delegates kernel call.
63 1 void run(const void* src, void* dst) {
64 1 const size_t m_step = traits().get_m_step();
65 1 const size_t filter_height = traits().get_filter_height();
66
67 1 const size_t in_stride_row = in_stride_row_bytes();
68 1 const size_t in_stride_col = in_stride_col_bytes();
69 1 const size_t dst_stride_row = dst_stride_row_bytes();
70 1 const size_t dst_stride_col = dst_stride_col_bytes();
71
72 1 const uint8_t* src_ptr = reinterpret_cast<const uint8_t*>(src);
73 1 uint8_t* dst_ptr = reinterpret_cast<uint8_t*>(dst);
74
75
2/2
✓ Branch 0 taken 1 time.
✓ Branch 1 taken 8 times.
9 for (size_t out_row = 0; out_row < output_height(); out_row += m_step) {
76
1/2
✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
8 const size_t valid_dst_rows = (out_row + m_step <= output_height()) ? m_step : (output_height() - out_row);
77
2/2
✓ Branch 0 taken 7 times.
✓ Branch 1 taken 1 time.
8 const size_t in_row = (out_row > pad_top()) ? (out_row - pad_top()) : 0;
78
2/2
✓ Branch 0 taken 7 times.
✓ Branch 1 taken 1 time.
8 const size_t valid_input_rows = (in_row + filter_height + m_step - 1 <= input_height())
79 7 ? (filter_height + m_step - 1)
80 1 : (input_height() - in_row);
81
82 8 const size_t src_offset = traits().get_src_offset(in_row, in_stride_row);
83 8 const size_t dst_offset = traits().get_dst_offset(out_row, dst_stride_row);
84
85
2/2
✓ Branch 0 taken 1 time.
✓ Branch 1 taken 7 times.
8 const size_t tile_pad_top = (out_row < pad_top()) ? (pad_top() - out_row) : 0;
86 8 const size_t tile_pad_left = pad_left();
87
88 8 call_kernel(
89 8 src_ptr + src_offset, dst_ptr + dst_offset, in_stride_row, in_stride_col, dst_stride_row,
90 8 dst_stride_col, valid_input_rows, valid_dst_rows, tile_pad_left, tile_pad_top);
91 8 }
92 1 }
93
94 protected:
95 // Derived classes implement the actual micro-kernel invocation for a tile
96 virtual void call_kernel(
97 const uint8_t* src_tile, uint8_t* dst_tile, size_t in_stride_row, size_t in_stride_col, size_t dst_stride_row,
98 size_t dst_stride_col, size_t valid_input_rows, size_t valid_dst_rows, size_t tile_pad_left,
99 size_t tile_pad_top) = 0;
100
101 // Helpers usable by derived classes
102 1 size_t in_stride_row_bytes() const {
103 1 return m_input_width * m_num_channels * data_type_size_bytes(m_src_type);
104 }
105 1 size_t in_stride_col_bytes() const {
106 1 return m_num_channels * data_type_size_bytes(m_src_type);
107 }
108 1 size_t dst_stride_row_bytes() const {
109 1 return m_output_width * m_num_channels * data_type_size_bytes(m_dst_type);
110 }
111 1 size_t dst_stride_col_bytes() const {
112 1 return m_num_channels * data_type_size_bytes(m_dst_type);
113 }
114
115 18 const DwConvTraits& traits() const {
116 18 return m_traits;
117 }
118 9 size_t input_height() const {
119 9 return m_input_height;
120 }
121 17 size_t output_height() const {
122 17 return m_output_height;
123 }
124 24 size_t pad_top() const {
125 24 return m_pad_top;
126 }
127 size_t pad_bottom() const {
128 return m_pad_bottom;
129 }
130 8 size_t pad_left() const {
131 8 return m_pad_left;
132 }
133 size_t pad_right() const {
134 return m_pad_right;
135 }
136 8 float clamp_min() const {
137 8 return m_clamp_min;
138 }
139 8 float clamp_max() const {
140 8 return m_clamp_max;
141 }
142
143 private:
144 DwConvTraits m_traits{};
145 DataType m_src_type{DataType::FP32};
146 DataType m_dst_type{DataType::FP32};
147
148 1 size_t m_input_height{0};
149 1 size_t m_input_width{0};
150 1 size_t m_output_height{0};
151 1 size_t m_output_width{0};
152 1 size_t m_num_channels{0};
153 1 size_t m_pad_top{0};
154 1 size_t m_pad_bottom{0};
155 1 size_t m_pad_left{0};
156 1 size_t m_pad_right{0};
157 1 float m_clamp_min{-std::numeric_limits<float>::infinity()};
158 1 float m_clamp_max{std::numeric_limits<float>::infinity()};
159 };
160
161 // Packed FP32 runner
162 class DwConvPackedFloatRunner : public DwConvRunner {
163 public:
164 2 DwConvPackedFloatRunner(
165 const DwConvPackedFloatInterface& iface, const DwConvTraits& traits, DataType src_type, DataType dst_type) :
166 2 DwConvRunner(traits, src_type, dst_type), m_iface(iface) {
167 2 }
168
169 1 void prepare(
170 const void* rhs_packed, const void* /* weights */, const void* /* bias */, const void* /* qp */) override {
171 1 m_rhs_packed = rhs_packed;
172 1 }
173
174 protected:
175 8 void call_kernel(
176 const uint8_t* src_tile, uint8_t* dst_tile, size_t in_stride_row, size_t in_stride_col, size_t dst_stride_row,
177 size_t dst_stride_col, size_t valid_input_rows, size_t valid_dst_rows, size_t tile_pad_left,
178 size_t tile_pad_top) override {
179 16 m_iface.run_dwconv(
180 8 src_tile, m_rhs_packed, dst_tile, in_stride_row, in_stride_col, dst_stride_row, dst_stride_col,
181 8 valid_input_rows, valid_dst_rows, tile_pad_left, tile_pad_top,
182 0.0f, // pad_value
183 8 clamp_min(), clamp_max());
184 8 }
185
186 private:
187 DwConvPackedFloatInterface m_iface{};
188 1 const void* m_rhs_packed{nullptr};
189 };
190
191 // Split FP32 runner
192 class DwConvSplitFloatRunner : public DwConvRunner {
193 public:
194 DwConvSplitFloatRunner(
195 const DwConvSplitFloatInterface& iface, const DwConvTraits& traits, DataType src_type, DataType dst_type) :
196 DwConvRunner(traits, src_type, dst_type), m_iface(iface) {
197 }
198
199 void prepare(const void* /* rhs_packed */, const void* weights, const void* bias, const void* /* qp */) override {
200 m_weights = static_cast<const float*>(weights);
201 m_bias = static_cast<const float*>(bias);
202 }
203
204 protected:
205 void call_kernel(
206 const uint8_t* src_tile, uint8_t* dst_tile, size_t in_stride_row, size_t in_stride_col, size_t dst_stride_row,
207 size_t dst_stride_col, size_t valid_input_rows, size_t valid_dst_rows, size_t tile_pad_left,
208 size_t tile_pad_top) override {
209 m_iface.run_dwconv(
210 src_tile, m_weights, m_bias, dst_tile, in_stride_row, in_stride_col, dst_stride_row, dst_stride_col,
211 valid_input_rows, valid_dst_rows, tile_pad_left, tile_pad_top, 0.0f, clamp_min(), clamp_max());
212 }
213
214 private:
215 DwConvSplitFloatInterface m_iface{};
216 const float* m_weights{nullptr};
217 const float* m_bias{nullptr};
218 };
219
220 } // namespace kai::benchmark
221