KleidiAI Coverage Report


Directory: ./
File: test/reference/binary_elementwise.cpp
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 73.5% 36 5 54
Functions: 34.9% 15 0 43
Branches: 22.6% 65 88 376

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/binary_elementwise.hpp"
8
9 #include <algorithm>
10 #include <cstddef>
11 #include <cstdint>
12
13 #include "kai/kai_common.h"
14 #include "test/common/buffer.hpp"
15 #include "test/common/data_type.hpp"
16 #include "test/common/float16.hpp"
17 #include "test/common/int4.hpp"
18 #include "test/common/memory.hpp"
19
20 namespace kai::test {
21
22 namespace {
23
24 /// Binary element-wise operator.
25 enum class BinaryElementwiseOperator : uint32_t {
26 ADD, ///< Addition.
27 SUB, ///< Subtraction.
28 MUL, ///< Multiplication.
29 DIV, ///< Division.
30 };
31
32 /// Scalar binary element-wise function.
33 ///
34 /// @tparam op Binary element-wise operator to perform.
35 /// @tparam T Data type.
36 ///
37 /// @param[in] lhs LHS operand.
38 /// @param[in] rhs RHS operand.
39 ///
40 /// @return The result of the operation.
41 template <const BinaryElementwiseOperator op, typename T>
42 6216353 T scalar_binary_elementwise(T lhs, T rhs) {
43 if constexpr (op == BinaryElementwiseOperator::ADD) {
44 5794841 return lhs + rhs;
45 } else if constexpr (op == BinaryElementwiseOperator::SUB) {
46 105378 return lhs - rhs;
47 } else if constexpr (op == BinaryElementwiseOperator::MUL) {
48 316134 return lhs * rhs;
49 } else if constexpr (op == BinaryElementwiseOperator::DIV) {
50 return lhs / rhs;
51 } else {
52 KAI_ERROR("Unsupported binary element-wise operator!");
53 }
54 }
55
56 /// Binary element-wise function.
57 ///
58 /// @tparam op Binary element-wise operator to perform.
59 /// @tparam T Data type.
60 ///
61 /// @param[in] lhs LHS data buffer.
62 /// @param[in] rhs RHS data buffer.
63 /// @param[in] lhs_height LHS height.
64 /// @param[in] lhs_width LHS width.
65 /// @param[in] rhs_height RHS height.
66 /// @param[in] rhs_width RHS width.
67 ///
68 /// @return The result data buffer.
69 template <const BinaryElementwiseOperator op, typename T>
70 6086 Buffer binary_elementwise_any_op_type(
71 const void* lhs, const void* rhs, size_t lhs_height, size_t lhs_width, size_t rhs_height, size_t rhs_width) {
72 6086 const auto height = std::max(lhs_height, rhs_height);
73 6086 const auto width = std::max(lhs_width, rhs_width);
74
75 KAI_ASSUME(width * size_in_bits<T> % 8 == 0);
76 6086 Buffer dst(height * width * size_in_bits<T> / 8);
77
78
10/32
✓ Branch 0 taken 927 times.
✓ Branch 1 taken 927 times.
✓ Branch 2 taken 1854 times.
✓ Branch 3 taken 1854 times.
✓ Branch 4 taken 927 times.
✓ Branch 5 taken 105378 times.
✓ Branch 6 taken 1191 times.
✓ Branch 7 taken 36473 times.
✓ Branch 8 taken 1187 times.
✓ Branch 9 taken 36766 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
187484 for (size_t y = 0; y < height; ++y) {
79
10/32
✓ Branch 0 taken 927 times.
✓ Branch 1 taken 105378 times.
✓ Branch 2 taken 1854 times.
✓ Branch 3 taken 210756 times.
✓ Branch 4 taken 105378 times.
✓ Branch 5 taken 105378 times.
✓ Branch 6 taken 2892060 times.
✓ Branch 7 taken 36473 times.
✓ Branch 8 taken 2902781 times.
✓ Branch 9 taken 36766 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
6397751 for (size_t x = 0; x < width; ++x) {
80
8/32
✗ Branch 0 not taken.
✓ Branch 1 taken 105378 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 210756 times.
✓ Branch 4 taken 105330 times.
✓ Branch 5 taken 48 times.
✓ Branch 6 taken 2872507 times.
✓ Branch 7 taken 19553 times.
✓ Branch 8 taken 2886286 times.
✓ Branch 9 taken 16495 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
6216353 const auto lhs_y = lhs_height > 1 ? y : 0;
81
9/32
✓ Branch 0 taken 105330 times.
✓ Branch 1 taken 48 times.
✓ Branch 2 taken 105330 times.
✓ Branch 3 taken 105426 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 105378 times.
✓ Branch 6 taken 2891162 times.
✓ Branch 7 taken 898 times.
✓ Branch 8 taken 2901811 times.
✓ Branch 9 taken 970 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
6216353 const auto lhs_x = lhs_width > 1 ? x : 0;
82
2/14
✓ Branch 0 taken 2892060 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2902781 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
6216353 const auto lhs_value = read_array<T>(lhs, lhs_y * lhs_width + lhs_x);
83
84
5/32
✗ Branch 0 not taken.
✓ Branch 1 taken 105378 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 210756 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 105378 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2892060 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2902781 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
6216353 const auto rhs_y = rhs_height > 1 ? y : 0;
85
9/32
✓ Branch 0 taken 105330 times.
✓ Branch 1 taken 48 times.
✓ Branch 2 taken 105330 times.
✓ Branch 3 taken 105426 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 105378 times.
✓ Branch 6 taken 2891162 times.
✓ Branch 7 taken 898 times.
✓ Branch 8 taken 2901811 times.
✓ Branch 9 taken 970 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
6216353 const auto rhs_x = rhs_width > 1 ? x : 0;
86
2/14
✗ Branch 0 not taken.
✓ Branch 1 taken 2892060 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2902781 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
6216353 const auto rhs_value = read_array<T>(rhs, rhs_y * rhs_width + rhs_x);
87
88
5/32
✓ Branch 0 taken 105378 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 210756 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 105378 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2892060 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2902781 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
6216353 const auto dst_value = scalar_binary_elementwise<op, T>(lhs_value, rhs_value);
89
3/16
✓ Branch 0 taken 2892060 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2892060 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2902781 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
6216353 write_array<T>(dst.data(), y * width + x, dst_value);
90 6216353 }
91 181398 }
92
93 6086 return dst;
94 6086 }
95
96 template <const BinaryElementwiseOperator op>
97 2378 Buffer binary_elementwise_any_type(
98 const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, //
99 const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) {
100 KAI_ASSUME(lhs_dt == rhs_dt);
101 KAI_ASSUME(lhs_height == 1 || rhs_height == 1 || lhs_height == rhs_height);
102 KAI_ASSUME(lhs_width == 1 || rhs_width == 1 || lhs_width == rhs_width);
103
104
2/20
✓ Branch 0 taken 1191 times.
✓ Branch 1 taken 1187 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
2378 switch (lhs_dt) {
105 case DataType::FP32:
106 1191 return binary_elementwise_any_op_type<op, float>(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
107
108 case DataType::FP16:
109 1187 return binary_elementwise_any_op_type<op, Float16>(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
110
111 case DataType::I32:
112 return binary_elementwise_any_op_type<op, int32_t>(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
113
114 case DataType::QSU4:
115 return binary_elementwise_any_op_type<op, UInt4>(lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
116
117 default:
118 KAI_ERROR("Unsupported data type!");
119 }
120 2378 }
121
122 } // namespace
123
124 2378 Buffer add(
125 const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, //
126 const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) {
127 2378 return binary_elementwise_any_type<BinaryElementwiseOperator::ADD>(
128 2378 lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width);
129 }
130
131 Buffer sub(
132 const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, //
133 const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) {
134 return binary_elementwise_any_type<BinaryElementwiseOperator::SUB>(
135 lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width);
136 }
137
138 template <typename T>
139 927 Buffer sub(
140 const void* lhs, size_t lhs_height, size_t lhs_width, //
141 const void* rhs, size_t rhs_height, size_t rhs_width) {
142 927 return binary_elementwise_any_op_type<BinaryElementwiseOperator::SUB, T>(
143 927 lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
144 }
145
146 template Buffer sub<int32_t>(
147 const void* lhs, size_t lhs_height, size_t lhs_width, //
148 const void* rhs, size_t rhs_height, size_t rhs_width);
149
150 Buffer mul(
151 const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, //
152 const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) {
153 return binary_elementwise_any_type<BinaryElementwiseOperator::MUL>(
154 lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width);
155 }
156
157 template <typename T>
158 2781 Buffer mul(
159 const void* lhs, size_t lhs_height, size_t lhs_width, //
160 const void* rhs, size_t rhs_height, size_t rhs_width) {
161 2781 return binary_elementwise_any_op_type<BinaryElementwiseOperator::MUL, T>(
162 2781 lhs, rhs, lhs_height, lhs_width, rhs_height, rhs_width);
163 }
164
165 template Buffer mul<float>(
166 const void* lhs, size_t lhs_height, size_t lhs_width, //
167 const void* rhs, size_t rhs_height, size_t rhs_width);
168
169 template Buffer mul<int32_t>(
170 const void* lhs, size_t lhs_height, size_t lhs_width, //
171 const void* rhs, size_t rhs_height, size_t rhs_width);
172
173 Buffer div(
174 const void* lhs, DataType lhs_dt, size_t lhs_height, size_t lhs_width, //
175 const void* rhs, DataType rhs_dt, size_t rhs_height, size_t rhs_width) {
176 return binary_elementwise_any_type<BinaryElementwiseOperator::DIV>(
177 lhs, lhs_dt, lhs_height, lhs_width, rhs, rhs_dt, rhs_height, rhs_width);
178 }
179
180 } // namespace kai::test
181