test/reference/matmul_pack.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include "test/reference/matmul_pack.hpp" | ||
| 8 | |||
| 9 | #include <cstddef> | ||
| 10 | |||
| 11 | #include "test/common/buffer.hpp" | ||
| 12 | #include "test/common/round.hpp" | ||
| 13 | #include "test/reference/binary_elementwise.hpp" | ||
| 14 | #include "test/reference/pack.hpp" | ||
| 15 | #include "test/reference/pad.hpp" | ||
| 16 | #include "test/reference/reduce.hpp" | ||
| 17 | #include "test/reference/reorder.hpp" | ||
| 18 | |||
| 19 | namespace kai::test { | ||
| 20 | |||
| 21 | template <typename Data, typename Scale, typename ZeroPoint> | ||
| 22 | 521 | Buffer matmul_pack_rhs_nxk_static_quantized( | |
| 23 | const void* data, const void* scales, Scale lhs_scale, Scale dst_scale, const void* biases, | ||
| 24 | ZeroPoint lhs_zero_point, size_t n, size_t k, size_t block_height, size_t block_width) { | ||
| 25 | // The RHS data matrix is reordered according to the blocking parameters. | ||
| 26 | 521 | const auto reordered_data = reorder_block<Data>(data, n, k, block_height, block_width); | |
| 27 | |||
| 28 | // The effective per-channel scale: | ||
| 29 | // final_scales[n_index] = lhs_scale * rhs_scales[n_index] / dst_scale. | ||
| 30 | 521 | const auto scale_multiplier = lhs_scale / dst_scale; | |
| 31 |
1/2✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
|
521 | auto combined_scales = mul<Scale>(scales, 1, n, &scale_multiplier, 1, 1); |
| 32 |
1/2✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
|
1042 | combined_scales = pad_matrix<Scale>( |
| 33 |
2/4✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
|
521 | combined_scales.data(), 1, n, 0, 0, round_up_multiple(n, block_height) - n, 0, 0); // Pads with 0s. |
| 34 | |||
| 35 | // The effective per-channel biases: | ||
| 36 | // final_biases[n_index] = biases[n_index] - lhs_zero_point * sum(data[n_index, :]). | ||
| 37 |
1/2✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
|
521 | const auto row_sum_reduced = reduce_add_x<Data, ZeroPoint>(data, n, k); |
| 38 | // Reduced across width earlier, so lhs width is now 1 | ||
| 39 |
2/4✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
|
521 | const auto row_sum_times_lhs_zp = mul<ZeroPoint>(row_sum_reduced.data(), n, 1, &lhs_zero_point, 1, 1); |
| 40 |
2/4✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
|
521 | auto combined_biases = sub<ZeroPoint>(biases, 1, n, row_sum_times_lhs_zp.data(), 1, n); |
| 41 |
1/2✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
|
1042 | combined_biases = pad_matrix<ZeroPoint>( |
| 42 |
2/4✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
|
521 | combined_biases.data(), 1, n, 0, 0, round_up_multiple(n, block_height) - n, 0, 0); // Pads with 0s. |
| 43 | |||
| 44 | // Packs the effective biases followed by the data block followed by the effective scales for the block. | ||
| 45 |
1/2✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
|
1042 | auto packed_rhs = pack_zero_points_data_scales_per_block<ZeroPoint, Data, Scale>( |
| 46 |
4/8✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 521 times.
✗ Branch 7 not taken.
|
521 | combined_biases.data(), reordered_data.data(), combined_scales.data(), round_up_division(n, block_height), |
| 47 |
1/2✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
|
521 | block_height, block_height * round_up_multiple(k, block_width), block_height); |
| 48 | |||
| 49 | 521 | return packed_rhs; | |
| 50 | 521 | } | |
| 51 | |||
| 52 | template Buffer matmul_pack_rhs_nxk_static_quantized<int8_t>( | ||
| 53 | const void* data, const void* scales, float lhs_scale, float dst_scale, const void* biases, int32_t lhs_zero_point, | ||
| 54 | size_t n, size_t k, size_t block_height, size_t block_width); | ||
| 55 | |||
| 56 | } // namespace kai::test | ||
| 57 |