KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 96.4% 27 / 0 / 28
Functions: 100.0% 3 / 0 / 3
Branches: 60.5% 23 / 0 / 38

test/nextgen/reference/dequantize.cpp
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include "test/nextgen/reference/dequantize.hpp"
8
9 #include <cstddef>
10 #include <cstdint>
11 #include <tuple>
12 #include <type_traits>
13
14 #include "test/common/assert.hpp"
15 #include "test/common/buffer.hpp"
16 #include "test/common/data_type.hpp"
17 #include "test/common/int4.hpp"
18 #include "test/common/memory.hpp"
19 #include "test/common/numeric_limits.hpp"
20 #include "test/common/round.hpp"
21 #include "test/common/span.hpp"
22 #include "test/common/type_traits.hpp"
23
24 namespace kai::test {
25
26 namespace {
27
28 template <typename FpData, typename QData, typename QScale, typename QZp>
29 400 [[nodiscard]] Buffer dequantize_linear(
30 size_t height, size_t width, size_t block_height, size_t block_width, Span<const std::byte> qdata,
31 Span<const std::byte> qscale, Span<const std::byte> qzp) {
32 400 Buffer fp_data(height * round_up_division(width * size_in_bits<FpData>, 8));
33
34
2/4
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 200 times.
✗ Branch 3 not taken.
400 const size_t quant_width = round_up_division(width, block_width);
35
36
4/4
✓ Branch 0 taken 200 times.
✓ Branch 1 taken 15452 times.
✓ Branch 2 taken 200 times.
✓ Branch 3 taken 15923 times.
31775 for (size_t row = 0; row < height; ++row) {
37
4/4
✓ Branch 0 taken 1274204 times.
✓ Branch 1 taken 15452 times.
✓ Branch 2 taken 1272221 times.
✓ Branch 3 taken 15923 times.
2577800 for (size_t col = 0; col < width; ++col) {
38
2/4
✓ Branch 0 taken 1274204 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1272221 times.
✗ Branch 3 not taken.
2546425 const QData qdata_value = read_2d<QData>(qdata, width, row, col);
39
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1272221 times.
2546425 FpData fp_value = static_cast<FpData>(qdata_value);
40
41 if constexpr (!std::is_same_v<QZp, void>) {
42
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1274204 times.
1274204 const QZp qzp_value = read_2d<QZp>(qzp, quant_width, row / block_height, col / block_width);
43 1274204 fp_value -= static_cast<FpData>(qzp_value);
44 1274204 } else if constexpr (is_unsigned<QData>) {
45 static_assert(size_in_bits<QData> <= 64);
46 1272221 constexpr FpData zp_value = static_cast<FpData>(static_cast<uint64_t>(1) << (size_in_bits<QData> - 1));
47 1272221 fp_value -= zp_value;
48 1272221 }
49
50
2/4
✓ Branch 0 taken 1274204 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1272221 times.
✗ Branch 3 not taken.
2546425 const QScale qscale_value = read_2d<QScale>(qscale, quant_width, row / block_height, col / block_width);
51 2546425 fp_value *= static_cast<FpData>(qscale_value);
52
53
4/8
✓ Branch 0 taken 1274204 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1274204 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1272221 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1272221 times.
✗ Branch 7 not taken.
2546425 write_2d<FpData>(fp_data.view(), width, row, col, fp_value);
54 2546425 }
55 31375 }
56
57 400 return fp_data;
58 400 }
59
60 } // namespace
61
62 400 DequantizeLinearFn make_dequantize_linear(
63 DataType fp_dtype, DataType qdata_dtype, DataType qscale_dtype, DataType qzp_dtype) {
64 400 const auto dtypes = std::make_tuple(fp_dtype, qdata_dtype, qscale_dtype, qzp_dtype);
65
66
2/2
✓ Branch 0 taken 200 times.
✓ Branch 1 taken 200 times.
400 if (dtypes == std::make_tuple(DataType::FP32, DataType::I8, DataType::FP32, DataType::I32)) {
67 200 return dequantize_linear<float, int8_t, float, int32_t>;
68 }
69
70
1/2
✓ Branch 0 taken 200 times.
✗ Branch 1 not taken.
200 if (dtypes == std::make_tuple(DataType::FP32, DataType::U4, DataType::FP32, DataType::UNKNOWN)) {
71 200 return dequantize_linear<float, UInt4, float, void>;
72 }
73
74 KAI_TEST_ERROR("Not implemented.");
75 400 }
76
77 } // namespace kai::test
78