Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #include "test/reference/matmul_pack.hpp" | ||
8 | |||
9 | #include <cstddef> | ||
10 | |||
11 | #include "test/common/buffer.hpp" | ||
12 | #include "test/common/round.hpp" | ||
13 | #include "test/reference/binary_elementwise.hpp" | ||
14 | #include "test/reference/pack.hpp" | ||
15 | #include "test/reference/pad.hpp" | ||
16 | #include "test/reference/reduce.hpp" | ||
17 | #include "test/reference/reorder.hpp" | ||
18 | |||
19 | namespace kai::test { | ||
20 | |||
21 | template <typename Data, typename Scale, typename ZeroPoint> | ||
22 | 927 | Buffer matmul_pack_rhs_nxk_static_quantized( | |
23 | const void* data, const void* scales, Scale lhs_scale, Scale dst_scale, const void* biases, | ||
24 | ZeroPoint lhs_zero_point, size_t n, size_t k, size_t block_height, size_t block_width) { | ||
25 | // The RHS data matrix is reordered according to the blocking parameters. | ||
26 | 927 | const auto reordered_data = reorder_block<Data>(data, n, k, block_height, block_width); | |
27 | |||
28 | // The effective per-channel scale: | ||
29 | // final_scales[n_index] = lhs_scale * rhs_scales[n_index] / dst_scale. | ||
30 | 927 | const auto scale_multiplier = lhs_scale / dst_scale; | |
31 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | auto combined_scales = mul<Scale>(scales, 1, n, &scale_multiplier, 1, 1); |
32 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
1854 | combined_scales = pad_matrix<Scale>( |
33 |
2/4✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
|
927 | combined_scales.data(), 1, n, 0, 0, round_up_multiple(n, block_height) - n, 0, 0); // Pads with 0s. |
34 | |||
35 | // The effective per-channel biases: | ||
36 | // final_biases[n_index] = biases[n_index] - lhs_zero_point * sum(data[n_index, :]). | ||
37 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | const auto row_sum_reduced = reduce_add_x<Data, ZeroPoint>(data, n, k); |
38 | // Reduced across width earlier, so lhs width is now 1 | ||
39 |
2/4✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
|
927 | const auto row_sum_times_lhs_zp = mul<ZeroPoint>(row_sum_reduced.data(), n, 1, &lhs_zero_point, 1, 1); |
40 |
2/4✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
|
927 | auto combined_biases = sub<ZeroPoint>(biases, 1, n, row_sum_times_lhs_zp.data(), 1, n); |
41 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
1854 | combined_biases = pad_matrix<ZeroPoint>( |
42 |
2/4✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
|
927 | combined_biases.data(), 1, n, 0, 0, round_up_multiple(n, block_height) - n, 0, 0); // Pads with 0s. |
43 | |||
44 | // Packs the effective biases followed by the data block followed by the effective scales for the block. | ||
45 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
1854 | auto packed_rhs = pack_zero_points_data_scales_per_block<ZeroPoint, Data, Scale>( |
46 |
4/8✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 927 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 927 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 927 times.
✗ Branch 7 not taken.
|
927 | combined_biases.data(), reordered_data.data(), combined_scales.data(), round_up_division(n, block_height), |
47 |
1/2✓ Branch 0 taken 927 times.
✗ Branch 1 not taken.
|
927 | block_height, block_height * round_up_multiple(k, block_width), block_height); |
48 | |||
49 | 927 | return packed_rhs; | |
50 | 927 | } | |
51 | |||
52 | template Buffer matmul_pack_rhs_nxk_static_quantized<int8_t>( | ||
53 | const void* data, const void* scales, float lhs_scale, float dst_scale, const void* biases, int32_t lhs_zero_point, | ||
54 | size_t n, size_t k, size_t block_height, size_t block_width); | ||
55 | |||
56 | } // namespace kai::test | ||
57 |