KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 71.4% 35 / 5 / 54
Functions: 30.2% 13 / 0 / 43
Branches: 17.0% 49 / 88 / 376

test/reference/binary_elementwise.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/binary_elementwise.hpp"
8
9 #include <algorithm>
10 #include <cstddef>
11 #include <cstdint>
12
13 #include "kai/kai_common.h"
14 #include "test/common/buffer.hpp"
15 #include "test/common/data_type.hpp"
16 #include "test/common/float16.hpp"
17 #include "test/common/int4.hpp"
18 #include "test/common/memory.hpp"
19
20 namespace kai::test {
21
22 namespace {
23
24 /// Binary element-wise operator.
25 enum class BinaryElementwiseOperator : uint32_t {
26 ADD, ///< Addition.
27 SUB, ///< Subtraction.
28 MUL, ///< Multiplication.
29 DIV, ///< Division.
30 };
31
32 /// Scalar binary element-wise function.
33 ///
34 /// @tparam op Binary element-wise operator to perform.
35 /// @tparam T Data type.
36 ///
37 /// @param[in] lhs LHS operand.
38 /// @param[in] rhs RHS operand.
39 ///
40 /// @return The result of the operation.
41 template <const BinaryElementwiseOperator op, typename T>
42 3312191 T scalar_binary_elementwise(T lhs, T rhs) {
43 if constexpr (op == BinaryElementwiseOperator::ADD) {
44 3069139 return lhs + rhs;
45 } else if constexpr (op == BinaryElementwiseOperator::SUB) {
46 60763 return lhs - rhs;
47 } else if constexpr (op == BinaryElementwiseOperator::MUL) {
48 182289 return lhs * rhs;
49 } else if constexpr (op == BinaryElementwiseOperator::DIV) {
50 return lhs / rhs;
51 } else {
52 KAI_ERROR("Unsupported binary element-wise operator!");
53 }
54 }
55
56 /// Binary element-wise function.
57 ///
58 /// @tparam op Binary element-wise operator to perform.
59 /// @tparam T Data type.
60 ///
61 /// @param[in] lhs LHS data buffer.
62 /// @param[in] rhs RHS data buffer.
63 /// @param[in] lhs_height LHS height.
64 /// @param[in] lhs_width LHS width.
65 /// @param[in] rhs_height RHS height.
66 /// @param[in] rhs_width RHS width.
67 ///
68 /// @return The result data buffer.
69 template <const BinaryElementwiseOperator op, typename T>
70 3685 Buffer binary_elementwise_any_op_type(
71 const void* lhs, const void* rhs, size_t lhs_height, size_t lhs_width, size_t rhs_height, size_t rhs_width) {
72 3685 const auto height = std::max(lhs_height, rhs_height);
73 3685 const auto width = std::max(lhs_width, rhs_width);
74
75 KAI_ASSUME_ALWAYS(width * size_in_bits<T> % 8 == 0);
76 3685 Buffer dst(height * width * size_in_bits<T> / 8);
77
78
8/32
✓ Branch 0 taken 521 times.
✓ Branch 1 taken 521 times.
✓ Branch 2 taken 1042 times.
✓ Branch 3 taken 1042 times.
✓ Branch 4 taken 521 times.
✓ Branch 5 taken 60763 times.
✓ Branch 6 taken 1601 times.
✓ Branch 7 taken 41031 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
107042 for (size_t y = 0; y < height; ++y) {
79
8/32
✓ Branch 0 taken 521 times.
✓ Branch 1 taken 60763 times.
✓ Branch 2 taken 1042 times.
✓ Branch 3 taken 121526 times.
✓ Branch 4 taken 60763 times.
✓ Branch 5 taken 60763 times.
✓ Branch 6 taken 3069139 times.
✓ Branch 7 taken 41031 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
3415548 for (size_t x = 0; x < width; ++x) {
80
6/32
✗ Branch 0 not taken.
✓ Branch 1 taken 60763 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 121526 times.
✓ Branch 4 taken 60736 times.
✓ Branch 5 taken 27 times.
✓ Branch 6 taken 3013888 times.
✓ Branch 7 taken 55251 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
3312191 const auto lhs_y = lhs_height > 1 ? y : 0;
81
7/32
✓ Branch 0 taken 60736 times.
✓ Branch 1 taken 27 times.
✓ Branch 2 taken 60736 times.
✓ Branch 3 taken 60790 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 60763 times.
✓ Branch 6 taken 3066481 times.
✓ Branch 7 taken 2658 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
3312191 const auto lhs_x = lhs_width > 1 ? x : 0;
82
1/14
✓ Branch 0 taken 3069139 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
3312191 const auto lhs_value = read_array<T>(lhs, lhs_y * lhs_width + lhs_x);
83
84
4/32
✗ Branch 0 not taken.
✓ Branch 1 taken 60763 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 121526 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 60763 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3069139 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
3312191 const auto rhs_y = rhs_height > 1 ? y : 0;
85
7/32
✓ Branch 0 taken 60736 times.
✓ Branch 1 taken 27 times.
✓ Branch 2 taken 60736 times.
✓ Branch 3 taken 60790 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 60763 times.
✓ Branch 6 taken 3066481 times.
✓ Branch 7 taken 2658 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
3312191 const auto rhs_x = rhs_width > 1 ? x : 0;
86
1/14
✗ Branch 0 not taken.
✓ Branch 1 taken 3069139 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
3312191 const auto rhs_value = read_array<T>(rhs, rhs_y * rhs_width + rhs_x);
87
88
4/32
✓ Branch 0 taken 60763 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 121526 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 60763 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3069139 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
3312191 const auto dst_value = scalar_binary_elementwise<op, T>(lhs_value, rhs_value);
89
2/16
✓ Branch 0 taken 3069139 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3069139 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
3312191 write_array<T>(dst.data(), y * width + x, dst_value);
90 3312191 }
91 103357 }
92
93 3685 return dst;
94 3685 }
95
96 template <const BinaryElementwiseOperator op>
97 1601 Buffer binary_elementwise_any_type(
98 const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, //
99 const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) {
100 KAI_ASSUME_ALWAYS(lhs_dt == rhs_dt);
101 KAI_ASSUME_ALWAYS(lhs_height == 1 || rhs_height == 1 || lhs_height == rhs_height);
102 KAI_ASSUME_ALWAYS(lhs_width == 1 || rhs_width == 1 || lhs_width == rhs_width);
103
104
1/20
✓ Branch 0 taken 1601 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
1601 switch (lhs_dt) {
105 case DataType::FP32:
106 1601 return binary_elementwise_any_op_type<op, float>(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
107
108 case DataType::FP16:
109 return binary_elementwise_any_op_type<op, Float16>(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
110
111 case DataType::I32:
112 return binary_elementwise_any_op_type<op, int32_t>(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
113
114 case DataType::QSU4:
115 return binary_elementwise_any_op_type<op, UInt4>(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
116
117 default:
118 KAI_ERROR("Unsupported data type!");
119 }
120 1601 }
121
122 } // namespace
123
124 1601 Buffer add(
125 const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, //
126 const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) {
127 1601 return binary_elementwise_any_type<BinaryElementwiseOperator::ADD>(
128 1601 lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width);
129 }
130
131 Buffer sub(
132 const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, //
133 const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) {
134 return binary_elementwise_any_type<BinaryElementwiseOperator::SUB>(
135 lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width);
136 }
137
138 template <typename T>
139 521 Buffer sub(
140 const void* lhs, size_t lhs_height, size_t lhs_width, //
141 const void* rhs, size_t rhs_height, size_t rhs_width) {
142 521 return binary_elementwise_any_op_type<BinaryElementwiseOperator::SUB, T>(
143 521 lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
144 }
145
146 template Buffer sub<int32_t>(
147 const void* lhs, size_t lhs_height, size_t lhs_width, //
148 const void* rhs, size_t rhs_height, size_t rhs_width);
149
150 Buffer mul(
151 const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, //
152 const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) {
153 return binary_elementwise_any_type<BinaryElementwiseOperator::MUL>(
154 lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width);
155 }
156
157 template <typename T>
158 1563 Buffer mul(
159 const void* lhs, size_t lhs_height, size_t lhs_width, //
160 const void* rhs, size_t rhs_height, size_t rhs_width) {
161 1563 return binary_elementwise_any_op_type<BinaryElementwiseOperator::MUL, T>(
162 1563 lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
163 }
164
165 template Buffer mul<float>(
166 const void* lhs, size_t lhs_height, size_t lhs_width, //
167 const void* rhs, size_t rhs_height, size_t rhs_width);
168
169 template Buffer mul<int32_t>(
170 const void* lhs, size_t lhs_height, size_t lhs_width, //
171 const void* rhs, size_t rhs_height, size_t rhs_width);
172
173 Buffer div(
174 const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, //
175 const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) {
176 return binary_elementwise_any_type<BinaryElementwiseOperator::DIV>(
177 lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width);
178 }
179
180 } // namespace kai::test
181