test/nextgen/quantization/asymm_linear_quantizer.hpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #pragma once | ||
| 8 | |||
| 9 | #include <cstddef> | ||
| 10 | |||
| 11 | #include "test/common/buffer.hpp" | ||
| 12 | #include "test/common/data_type.hpp" | ||
| 13 | #include "test/nextgen/functions/round.hpp" | ||
| 14 | #include "test/nextgen/harness/tensor.hpp" | ||
| 15 | #include "test/nextgen/quantization/quantizer.hpp" | ||
| 16 | |||
| 17 | namespace kai::test { | ||
| 18 | |||
| 19 | /// Asymmetric linear quantizer: a Quantizer implementation configured by data types, rounding modes and block size. | ||
| 20 | class AsymmLinearQuantizer : public Quantizer { | ||
| 21 | public: | ||
| 22 | /// Creates a new asymmetric linear quantizer; the arguments are stored in the members unchanged (the body is empty). | ||
| 23 | /// | ||
| 24 | /// @param[in] qdata_dtype The quantized data type. | ||
| 25 | /// @param[in] qscale_dtype The quantization scale data type. | ||
| 26 | /// @param[in] qzp_dtype The quantization zero-point data type. | ||
| 27 | /// @param[in] qdata_round_mode The rounding mode to calculate quantized data. | ||
| 28 | /// @param[in] qzp_round_mode The rounding mode to calculate quantization zero-point. | ||
| 29 | /// @param[in] block_height The quantization block height (0 if it's full height). | ||
| 30 | /// @param[in] block_width The quantization block width (0 if it's full width). | ||
| 31 | 16 | AsymmLinearQuantizer( | |
| 32 | DataType qdata_dtype, DataType qscale_dtype, DataType qzp_dtype, RoundMode qdata_round_mode, | ||
| 33 | RoundMode qzp_round_mode, size_t block_height, size_t block_width) : | ||
| 34 | 12 | m_qdata_dtype(qdata_dtype), | |
| 35 | 12 | m_qscale_dtype(qscale_dtype), | |
| 36 | 12 | m_qzp_dtype(qzp_dtype), | |
| 37 | 12 | m_qdata_round_mode(qdata_round_mode), | |
| 38 | 12 | m_qzp_round_mode(qzp_round_mode), | |
| 39 | 12 | m_block_height(block_height), | |
| 40 | 28 | m_block_width(block_width) { | |
| 41 | 16 | } | |
| 42 | |||
| 43 | void dynamic_quantize( | ||
| 44 | DataType fp_dtype, Span<const size_t> shape, Span<const std::byte> fp_data, Tensor& qdata, Tensor& qscale, | ||
| 45 | Tensor& qzp) const override; | ||
| 46 | [[nodiscard]] Buffer dequantize( | ||
| 47 | DataType fp_dtype, Span<const size_t> shape, Span<const std::byte> qdata, Span<const std::byte> qscale, | ||
| 48 | Span<const std::byte> qzp) const override; | ||
| 49 | |||
| 50 | private: | ||
| 51 | DataType m_qdata_dtype; | ||
| 52 | DataType m_qscale_dtype; | ||
| 53 | DataType m_qzp_dtype; | ||
| 54 | |||
| 55 | RoundMode m_qdata_round_mode; | ||
| 56 | RoundMode m_qzp_round_mode; | ||
| 57 | |||
| 58 | size_t m_block_height; | ||
| 59 | size_t m_block_width; | ||
| 60 | }; | ||
| 61 | |||
| 62 | } // namespace kai::test | ||
| 63 |