KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 79.1% 34 / 1 / 44
Functions: 80.0% 8 / 0 / 10
Branches: 50.0% 57 / 0 / 114

test/reference/cast.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/reference/cast.hpp"
8
9 #include <cstddef>
10 #include <cstdint>
11
12 #include "kai/kai_common.h"
13 #include "test/common/bfloat16.hpp"
14 #include "test/common/buffer.hpp"
15 #include "test/common/data_type.hpp"
16 #include "test/common/float16.hpp"
17 #include "test/common/memory.hpp"
18 #include "test/common/round.hpp"
19
20 namespace kai::test {
21
22 template <typename DstType, typename SrcType>
23 33292 Buffer cast(const void* src, size_t length) {
24 33292 Buffer dst(round_up_division(length * size_in_bits<DstType>, 8));
25
26
10/12
✓ Branch 0 taken 11512369 times.
✓ Branch 1 taken 15866 times.
✓ Branch 2 taken 1768584 times.
✓ Branch 3 taken 528 times.
✓ Branch 4 taken 20936666 times.
✓ Branch 5 taken 1722 times.
✓ Branch 6 taken 9659812 times.
✓ Branch 7 taken 14816 times.
✓ Branch 8 taken 2913180 times.
✓ Branch 9 taken 360 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
46823903 for (size_t i = 0; i < length; ++i) {
27
20/48
✓ Branch 0 taken 11512369 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 11512369 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11512369 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 11512369 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1768584 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1768584 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1768584 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1768584 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 20936666 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 20936666 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 20936666 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 20936666 times.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 25 taken 9659812 times.
✓ Branch 26 taken 9659812 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 9659812 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 9659812 times.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✓ Branch 33 taken 2913180 times.
✓ Branch 34 taken 2913180 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 2913180 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 2913180 times.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
46790611 write_array(dst.data(), i, static_cast<DstType>(read_array<SrcType>(src, i)));
28 46790611 }
29
30 33292 return dst;
31 33292 }
32
33 template <>
34 Buffer cast<BFloat16<false>, Float16>(const void* src, size_t length) {
35 Buffer dst(round_up_division(length * size_in_bits<BFloat16<>>, 8));
36
37 for (size_t i = 0; i < length; ++i) {
38 float interim = static_cast<float>(read_array<Float16>(src, i));
39 write_array(dst.data(), i, BFloat16<false>(interim));
40 }
41
42 return dst;
43 }
44
45 template <>
46 240 Buffer cast<BFloat16<true>, Float16>(const void* src, size_t length) {
47 240 Buffer dst(round_up_division(length * size_in_bits<BFloat16<>>, 8));
48
49
2/2
✓ Branch 0 taken 3561696 times.
✓ Branch 1 taken 240 times.
3561936 for (size_t i = 0; i < length; ++i) {
50
2/4
✓ Branch 0 taken 3561696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3561696 times.
✗ Branch 3 not taken.
3561696 float interim = static_cast<float>(read_array<Float16>(src, i));
51
3/6
✓ Branch 0 taken 3561696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3561696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3561696 times.
✗ Branch 5 not taken.
3561696 write_array(dst.data(), i, BFloat16<true>(interim));
52 3561696 }
53
54 240 return dst;
55 240 }
56
57 template Buffer cast<Float16, float>(const void* src, size_t length);
58 template Buffer cast<BFloat16<false>, float>(const void* src, size_t length);
59 template Buffer cast<BFloat16<true>, float>(const void* src, size_t length);
60 template Buffer cast<float, Float16>(const void* src, size_t length);
61 template Buffer cast<float, BFloat16<false>>(const void* src, size_t length);
62 template Buffer cast<float, BFloat16<true>>(const void* src, size_t length);
63
64 3198 Buffer cast(const void* src, kai::test::DataType src_dt, DataType dst_dt, size_t height, size_t width) {
65 3198 const auto length = height * width;
66
67
1/4
✗ Branch 0 not taken.
✓ Branch 1 taken 3198 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
3198 if (src_dt == DataType::BF16 && dst_dt == DataType::FP32) {
68 return cast<float, BFloat16<>>(src, length);
69
4/4
✓ Branch 0 taken 1820 times.
✓ Branch 1 taken 1378 times.
✓ Branch 2 taken 1580 times.
✓ Branch 3 taken 240 times.
3198 } else if (src_dt == DataType::FP16 && dst_dt == DataType::BF16) {
70 240 return cast<BFloat16<>, Float16>(src, length);
71
4/4
✓ Branch 0 taken 1378 times.
✓ Branch 1 taken 1580 times.
✓ Branch 2 taken 790 times.
✓ Branch 3 taken 588 times.
2958 } else if (src_dt == DataType::FP32 && dst_dt == DataType::BF16) {
72 588 return cast<BFloat16<>, float>(src, length);
73
3/4
✓ Branch 0 taken 790 times.
✓ Branch 1 taken 1580 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 790 times.
2370 } else if (src_dt == DataType::FP32 && dst_dt == DataType::FP16) {
74 790 return cast<Float16, float>(src, length);
75
1/2
✓ Branch 0 taken 1580 times.
✗ Branch 1 not taken.
1580 } else if (src_dt == DataType::FP16 && dst_dt == DataType::FP32) {
76 1580 return cast<float, Float16>(src, length);
77 }
78 KAI_ERROR("Unsupported cast data type!");
79 3198 }
80
81 33960 Buffer cast_qsu4_qsi4(const void* src, size_t length) {
82 33960 Buffer dst(round_up_division(length, 2));
83
84
2/2
✓ Branch 0 taken 208407696 times.
✓ Branch 1 taken 33960 times.
208441656 for (size_t i = 0; i < length; ++i) {
85
5/10
✓ Branch 0 taken 208407696 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 208407696 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 208407696 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 208407696 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 208407696 times.
✗ Branch 9 not taken.
208407696 write_array(dst.data(), i, static_cast<UInt4>(static_cast<int32_t>(read_array<Int4>(src, i)) + 8));
86 208407696 }
87
88 33960 return dst;
89 33960 }
90
91 } // namespace kai::test
92