test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #include <gtest/gtest.h> | ||
| 8 | |||
| 9 | #include <array> | ||
| 10 | #include <cstddef> | ||
| 11 | #include <cstdint> | ||
| 12 | #include <cstdlib> | ||
| 13 | #include <functional> | ||
| 14 | #include <optional> | ||
| 15 | #include <sstream> | ||
| 16 | #include <string> | ||
| 17 | #include <string_view> | ||
| 18 | #include <tuple> | ||
| 19 | #include <unordered_map> | ||
| 20 | |||
| 21 | #include "kai/kai_common.h" | ||
| 22 | #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa.h" | ||
| 23 | #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" | ||
| 24 | #include "kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h" | ||
| 25 | #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" | ||
| 26 | #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp_interface.h" | ||
| 27 | #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa.h" | ||
| 28 | #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h" | ||
| 29 | #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_interface.h" | ||
| 30 | #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h" | ||
| 31 | #include "kai/ukernels/matmul/pack/kai_lhs_pack_x8p2vlx4_x8_sme.h" | ||
| 32 | #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" | ||
| 33 | #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h" | ||
| 34 | #include "test/common/abi_checker.hpp" | ||
| 35 | #include "test/common/buffer.hpp" | ||
| 36 | #include "test/common/cpu_info.hpp" | ||
| 37 | #include "test/common/matmul_test_common.hpp" | ||
| 38 | #include "test/common/matrix_portion.hpp" | ||
| 39 | #include "test/common/memory.hpp" | ||
| 40 | #include "test/common/rect.hpp" | ||
| 41 | #include "test/common/seed.hpp" | ||
| 42 | #include "test/common/sme.hpp" | ||
| 43 | #include "test/reference/binary_elementwise.hpp" | ||
| 44 | #include "test/reference/clamp.hpp" | ||
| 45 | #include "test/reference/fill.hpp" | ||
| 46 | #include "test/reference/matmul.hpp" | ||
| 47 | #include "test/reference/matmul_pack.hpp" | ||
| 48 | #include "test/reference/quantize.hpp" | ||
| 49 | #include "test/reference/reduce.hpp" | ||
| 50 | #include "test/reference/reorder.hpp" | ||
| 51 | #include "test/reference/transpose.hpp" | ||
| 52 | |||
| 53 | namespace kai::test { | ||
| 54 | |||
| 55 | // Ensure static linkage for all functionality local to this test file | ||
| 56 | namespace { | ||
| 57 | |||
| 58 | struct KChunk { | ||
| 59 | size_t count; | ||
| 60 | size_t length; | ||
| 61 | }; | ||
| 62 | |||
| 63 |
3/6✓ Branch 0 taken 10008 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 10008 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10008 times.
✗ Branch 5 not taken.
|
10008 | struct LhsPackKernel { |
| 64 | std::function<size_t(size_t mr)> get_m_step; | ||
| 65 | std::function<size_t(size_t m_idx, size_t lhs_stride)> get_lhs_offset; | ||
| 66 | std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)> get_packed_lhs_offset; | ||
| 67 | std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> get_packed_lhs_size; | ||
| 68 | std::function<void( | ||
| 69 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||
| 70 | void* lhs_packed)> | ||
| 71 | pack; | ||
| 72 | }; | ||
| 73 | |||
| 74 |
2/4✓ Branch 0 taken 49986 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 49986 times.
✗ Branch 3 not taken.
|
49986 | struct LhsPackIndirectKernel { |
| 75 | std::function<size_t()> get_m_step; | ||
| 76 | std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_packed_lhs_offset; | ||
| 77 | std::function<size_t(size_t m, size_t k_chunk_count, size_t k_chunk_length)> get_packed_lhs_size; | ||
| 78 | std::function<void( | ||
| 79 | size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, | ||
| 80 | const void* zero, void* packed_lhs)> | ||
| 81 | pack; | ||
| 82 | }; | ||
| 83 | |||
| 84 |
5/10✓ Branch 0 taken 12528 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12528 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12528 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 12528 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 12528 times.
✗ Branch 9 not taken.
|
12528 | struct RhsPackKernel { |
| 85 | std::function<size_t()> get_n_step; | ||
| 86 | std::function<size_t(size_t n_idx)> get_rhs_offset; | ||
| 87 | std::function<size_t(size_t n_idx)> get_bias_offset; | ||
| 88 | std::function<size_t(size_t n_idx)> get_scale_offset; | ||
| 89 | std::function<size_t(size_t n_idx, size_t k)> get_packed_rhs_offset; | ||
| 90 | std::function<size_t(size_t n, size_t k)> get_packed_rhs_size; | ||
| 91 | std::function<void( | ||
| 92 | size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, | ||
| 93 | const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, | ||
| 94 | const struct kai_rhs_pack_qsi8cx_params* params)> | ||
| 95 | pack; | ||
| 96 | }; | ||
| 97 | |||
| 98 |
5/10✓ Branch 0 taken 49986 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 49986 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 49986 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 49986 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 49986 times.
✗ Branch 9 not taken.
|
49986 | struct RhsPackIndirectKernel { |
| 99 | std::function<size_t()> get_n_step; | ||
| 100 | std::function<size_t(size_t n_idx)> get_rhs_offset; | ||
| 101 | std::function<size_t(size_t n_idx)> get_bias_offset; | ||
| 102 | std::function<size_t(size_t n_idx)> get_scale_offset; | ||
| 103 | std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_packed_rhs_offset; | ||
| 104 | std::function<size_t(size_t n, size_t k_chunk_count, size_t k_chunk_length)> get_packed_rhs_size; | ||
| 105 | std::function<void( | ||
| 106 | size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride, const void* rhs, const void* bias, | ||
| 107 | const void* scale, void* rhs_packed, const kai_rhs_pack_qsi8cx_params* params)> | ||
| 108 | pack; | ||
| 109 | }; | ||
| 110 | |||
| 111 |
9/18✓ Branch 0 taken 12528 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12528 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12528 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 12528 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 12528 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12528 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 12528 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 12528 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 12528 times.
✗ Branch 17 not taken.
|
12528 | struct MatMulKernel { |
| 112 | std::function<size_t(void)> get_m_step; | ||
| 113 | std::function<size_t(void)> get_n_step; | ||
| 114 | std::function<size_t(void)> get_mr; | ||
| 115 | std::function<size_t(void)> get_nr; | ||
| 116 | std::function<size_t(void)> get_kr; | ||
| 117 | std::function<size_t(void)> get_sr; | ||
| 118 | std::function<size_t(size_t m_idx, size_t k)> get_packed_lhs_offset; | ||
| 119 | std::function<size_t(size_t n_idx, size_t k)> get_packed_rhs_offset; | ||
| 120 | std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride)> get_dst_offset; | ||
| 121 | std::function<size_t(size_t m, size_t n)> get_dst_size; | ||
| 122 | std::function<void( | ||
| 123 | size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, | ||
| 124 | size_t dst_stride_col, const kai_matmul_requantize32_params* params)> | ||
| 125 | matmul; | ||
| 126 | }; | ||
| 127 | |||
| 128 |
5/10✓ Branch 0 taken 49986 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 49986 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 49986 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 49986 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 49986 times.
✗ Branch 9 not taken.
|
49986 | struct MatMulIndirectKernel { |
| 129 | std::function<size_t(void)> get_m_step; | ||
| 130 | std::function<size_t(void)> get_n_step; | ||
| 131 | std::function<size_t(size_t m_idx, size_t k_chunk_count, size_t k_chunk_length)> get_lhs_packed_offset; | ||
| 132 | std::function<size_t(size_t n_idx, size_t k_chunk_count, size_t k_chunk_length)> get_rhs_packed_offset; | ||
| 133 | std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride_row)> get_dst_offset; | ||
| 134 | std::function<size_t(size_t m, size_t n)> get_dst_size; | ||
| 135 | std::function<void( | ||
| 136 | size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_lenght, const void* lhs_packed, const void* rhs_packed, | ||
| 137 | void* dst, size_t dst_stride_row, const kai_matmul_requantize32_params* params)> | ||
| 138 | imatmul; | ||
| 139 | }; | ||
| 140 | |||
| 141 | /// Make sure that interface matches for qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa | ||
| 142 | const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& | ||
| 143 | 3 | get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface() { | |
| 144 | static kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel ukernel; | ||
| 145 | |||
| 146 | 3 | ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 147 | 3 | ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 148 | 3 | ukernel.get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 149 | 3 | ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 150 | 3 | ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 151 | 3 | ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 152 | 3 | ukernel.get_lhs_packed_offset = | |
| 153 | kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | ||
| 154 | 3 | ukernel.get_rhs_packed_offset = | |
| 155 | kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | ||
| 156 | 3 | ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 157 | 3 | ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 158 | 3 | ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 159 | |||
| 160 | 3 | return ukernel; | |
| 161 | } | ||
| 162 | |||
| 163 | /// Make sure that interface matches for qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme_mopa | ||
| 164 | const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& | ||
| 165 | 3 | get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface() { | |
| 166 | static kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel ukernel; | ||
| 167 | |||
| 168 | 3 | ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 169 | 3 | ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 170 | 3 | ukernel.get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 171 | 3 | ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 172 | 3 | ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 173 | 3 | ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 174 | 3 | ukernel.get_lhs_packed_offset = | |
| 175 | kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | ||
| 176 | 3 | ukernel.get_rhs_packed_offset = | |
| 177 | kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | ||
| 178 | 3 | ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 179 | 3 | ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 180 | 3 | ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 181 | |||
| 182 | 3 | return ukernel; | |
| 183 | } | ||
| 184 | |||
| 185 | /// Make sure that interface matches for qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot | ||
| 186 | const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& | ||
| 187 | 3 | get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface() { | |
| 188 | static kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel; | ||
| 189 | |||
| 190 | 3 | ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
| 191 | 3 | ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
| 192 | 3 | ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
| 193 | 3 | ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
| 194 | 3 | ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
| 195 | 3 | ukernel.get_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
| 196 | 3 | ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
| 197 | 3 | ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
| 198 | 3 | ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
| 199 | 3 | ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; | |
| 200 | |||
| 201 | 3 | return ukernel; | |
| 202 | }; | ||
| 203 | |||
| 204 | /// Make sure that interface matches qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa | ||
| 205 | const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& | ||
| 206 | 3 | get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface() { | |
| 207 | static kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel; | ||
| 208 | |||
| 209 | 3 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 210 | 3 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 211 | 3 | ukernel.get_lhs_packed_offset = | |
| 212 | kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | ||
| 213 | 3 | ukernel.get_rhs_packed_offset = | |
| 214 | kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | ||
| 215 | 3 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 216 | 3 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 217 | 3 | ukernel.run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; | |
| 218 | |||
| 219 | 3 | return ukernel; | |
| 220 | }; | ||
| 221 | |||
| 222 | /// Make sure that interface matches qai8_qai8p2vlx4_qsi8cxps2vlx4b_2vlx2vl_sme_mopa | ||
| 223 | const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& | ||
| 224 | 3 | get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface() { | |
| 225 | static kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel; | ||
| 226 | |||
| 227 | 3 | ukernel.get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 228 | 3 | ukernel.get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 229 | 3 | ukernel.get_lhs_packed_offset = | |
| 230 | kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | ||
| 231 | 3 | ukernel.get_rhs_packed_offset = | |
| 232 | kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | ||
| 233 | 3 | ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 234 | 3 | ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 235 | 3 | ukernel.run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa; | |
| 236 | |||
| 237 | 3 | return ukernel; | |
| 238 | }; | ||
| 239 | |||
| 240 | 9 | const RhsPackKernel& get_rhs_pack() { | |
| 241 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
9 | static RhsPackKernel ukernel; |
| 242 | |||
| 243 | 9 | ukernel.get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 244 | 9 | ukernel.get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 245 | 9 | ukernel.get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 246 | 9 | ukernel.get_scale_offset = kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 247 | 9 | ukernel.get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 248 | 9 | ukernel.get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 249 | 9 | ukernel.pack = kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 250 | |||
| 251 | 9 | return ukernel; | |
| 252 | } | ||
| 253 | |||
| 254 | 6 | const LhsPackKernel& get_lhs_pack() { | |
| 255 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
6 | static LhsPackKernel ukernel; |
| 256 | |||
| 257 | 6 | ukernel.get_m_step = kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme; | |
| 258 | 6 | ukernel.get_lhs_offset = kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme; | |
| 259 | 6 | ukernel.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme; | |
| 260 | 6 | ukernel.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme; | |
| 261 | 6 | ukernel.pack = kai_run_lhs_pack_x8p2vlx4_x8_sme; | |
| 262 | |||
| 263 | 6 | return ukernel; | |
| 264 | } | ||
| 265 | |||
| 266 |
2/4✓ Branch 0 taken 12528 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12528 times.
✗ Branch 3 not taken.
|
12528 | struct MatMulVariant { |
| 267 | std::string_view name; ///< Test identification | ||
| 268 | MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr) | ||
| 269 | MatMulShape acc_step; ///< Accumulator shape for matmul (stepping) | ||
| 270 | |||
| 271 | std::function<bool(void)> is_supported; ///< HW support check | ||
| 272 | |||
| 273 | std::optional<LhsPackKernel> lhs_pack; ///< LHS packing micro-kernel interface | ||
| 274 | RhsPackKernel rhs_pack; ///< RHS packing micro-kernel interface | ||
| 275 | MatMulKernel matmul; ///< Matmul kernel interface | ||
| 276 | }; | ||
| 277 | |||
| 278 |
2/4✓ Branch 0 taken 49986 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 49986 times.
✗ Branch 3 not taken.
|
49986 | struct IndirectMatMulVariant { |
| 279 | std::string_view name; ///< Test identification | ||
| 280 | MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr) | ||
| 281 | MatMulShape acc_step; ///< Accumulator shape for matmul (stepping) | ||
| 282 | |||
| 283 | std::function<bool(void)> is_supported; ///< HW support check | ||
| 284 | |||
| 285 | LhsPackIndirectKernel lhs_pack; ///< LHS packing micro-kernel interface | ||
| 286 | RhsPackIndirectKernel rhs_pack; ///< RHS packing micro-kernel interface | ||
| 287 | MatMulIndirectKernel matmul; ///< Matmul kernel interface | ||
| 288 | }; | ||
| 289 | |||
| 290 | 3 | const auto& get_gemm_variants() { | |
| 291 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
3 | static std::array<MatMulVariant, 2> variants; |
| 292 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
6 | static const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& ukernel_sme2 = |
| 293 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface(); |
| 294 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
6 | static const kai_matmul_clamp_qai8_qai8p_qsi8cxpsb_ukernel& ukernel_sme = |
| 295 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
|
3 | get_matmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface(); |
| 296 | |||
| 297 | 3 | variants[0].name = "matmul_qai8_qai8p_qsi8cxp_sme"; | |
| 298 | 3 | variants[0].acc_pack.m = 2 * get_sme_vector_length<int32_t>(); | |
| 299 | 3 | variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>(); | |
| 300 | 3 | variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); | |
| 301 | 3 | variants[0].acc_step.m = 2 * get_sme_vector_length<int32_t>(); | |
| 302 | 3 | variants[0].acc_step.n = 2 * get_sme_vector_length<int32_t>(); | |
| 303 | 3 | variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t); | |
| 304 | 3 | variants[0].is_supported = cpu_has_sme; | |
| 305 | 3 | variants[0].lhs_pack = get_lhs_pack(); | |
| 306 | 3 | variants[0].rhs_pack = get_rhs_pack(); | |
| 307 | 3 | variants[0].matmul.get_m_step = ukernel_sme.get_m_step; | |
| 308 | 3 | variants[0].matmul.get_n_step = ukernel_sme.get_n_step; | |
| 309 | 3 | variants[0].matmul.get_mr = ukernel_sme.get_mr; | |
| 310 | 3 | variants[0].matmul.get_nr = ukernel_sme.get_nr; | |
| 311 | 3 | variants[0].matmul.get_kr = ukernel_sme.get_kr; | |
| 312 | 3 | variants[0].matmul.get_sr = ukernel_sme.get_sr; | |
| 313 | 3 | variants[0].matmul.get_packed_lhs_offset = ukernel_sme.get_lhs_packed_offset; | |
| 314 | 3 | variants[0].matmul.get_packed_rhs_offset = ukernel_sme.get_rhs_packed_offset; | |
| 315 | 3 | variants[0].matmul.get_dst_offset = ukernel_sme.get_dst_offset; | |
| 316 | 3 | variants[0].matmul.get_dst_size = ukernel_sme.get_dst_size; | |
| 317 | 3 | variants[0].matmul.matmul = ukernel_sme.run_matmul; | |
| 318 | |||
| 319 | 3 | variants[1].name = "matmul_qai8_qai8p_qsi8cxp_sme2"; | |
| 320 | 3 | variants[1].acc_pack.m = 2 * get_sme_vector_length<int32_t>(); | |
| 321 | 3 | variants[1].acc_pack.n = 2 * get_sme_vector_length<int32_t>(); | |
| 322 | 3 | variants[1].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); | |
| 323 | 3 | variants[1].acc_step.m = 2 * get_sme_vector_length<int32_t>(); | |
| 324 | 3 | variants[1].acc_step.n = 2 * get_sme_vector_length<int32_t>(); | |
| 325 | 3 | variants[1].acc_step.k = sizeof(int32_t) / sizeof(int8_t); | |
| 326 | 3 | variants[1].is_supported = cpu_has_sme2; | |
| 327 | 3 | variants[1].lhs_pack = get_lhs_pack(); | |
| 328 | 3 | variants[1].rhs_pack = get_rhs_pack(); | |
| 329 | 3 | variants[1].matmul.get_m_step = ukernel_sme2.get_m_step; | |
| 330 | 3 | variants[1].matmul.get_n_step = ukernel_sme2.get_n_step; | |
| 331 | 3 | variants[1].matmul.get_mr = ukernel_sme2.get_mr; | |
| 332 | 3 | variants[1].matmul.get_nr = ukernel_sme2.get_nr; | |
| 333 | 3 | variants[1].matmul.get_kr = ukernel_sme2.get_kr; | |
| 334 | 3 | variants[1].matmul.get_sr = ukernel_sme2.get_sr; | |
| 335 | 3 | variants[1].matmul.get_packed_lhs_offset = ukernel_sme2.get_lhs_packed_offset; | |
| 336 | 3 | variants[1].matmul.get_packed_rhs_offset = ukernel_sme2.get_rhs_packed_offset; | |
| 337 | 3 | variants[1].matmul.get_dst_offset = ukernel_sme2.get_dst_offset; | |
| 338 | 3 | variants[1].matmul.get_dst_size = ukernel_sme2.get_dst_size; | |
| 339 | 3 | variants[1].matmul.matmul = ukernel_sme2.run_matmul; | |
| 340 | |||
| 341 | 3 | return variants; | |
| 342 | ✗ | } | |
| 343 | |||
| 344 | 9 | const auto& get_indirect_gemm_variants() { | |
| 345 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
9 | static std::array<IndirectMatMulVariant, 2> variants; |
| 346 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
12 | static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel_sme = |
| 347 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxp2vlx4sb_2vlx2vl_sme_mopa_interface(); |
| 348 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
12 | static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel_sme2 = |
| 349 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
|
3 | get_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa_interface(); |
| 350 | |||
| 351 | 9 | variants[0].name = "imatmul_qai8_qai8p_qsi8cxp_sme"; | |
| 352 | 9 | variants[0].acc_pack.m = 2 * get_sme_vector_length<int32_t>(); | |
| 353 | 9 | variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>(); | |
| 354 | 9 | variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); | |
| 355 | 9 | variants[0].acc_step.m = 2 * get_sme_vector_length<int32_t>(); | |
| 356 | 9 | variants[0].acc_step.n = 2 * get_sme_vector_length<int32_t>(); | |
| 357 | 9 | variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t); | |
| 358 | 9 | variants[0].is_supported = cpu_has_sme; | |
| 359 | 9 | variants[0].lhs_pack.get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
| 360 | 9 | variants[0].lhs_pack.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
| 361 | 9 | variants[0].lhs_pack.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
| 362 | 9 | variants[0].lhs_pack.pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
| 363 | 9 | variants[0].rhs_pack.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 364 | 9 | variants[0].rhs_pack.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 365 | 9 | variants[0].rhs_pack.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 366 | 9 | variants[0].rhs_pack.get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 367 | 9 | variants[0].rhs_pack.get_packed_rhs_offset = | |
| 368 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | ||
| 369 | 9 | variants[0].rhs_pack.get_packed_rhs_size = | |
| 370 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | ||
| 371 | 9 | variants[0].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 372 | 9 | variants[0].matmul.get_m_step = ukernel_sme.get_m_step; | |
| 373 | 9 | variants[0].matmul.get_n_step = ukernel_sme.get_n_step; | |
| 374 | 9 | variants[0].matmul.get_lhs_packed_offset = ukernel_sme.get_lhs_packed_offset; | |
| 375 | 9 | variants[0].matmul.get_rhs_packed_offset = ukernel_sme.get_rhs_packed_offset; | |
| 376 | 9 | variants[0].matmul.get_dst_offset = ukernel_sme.get_dst_offset; | |
| 377 | 9 | variants[0].matmul.get_dst_size = ukernel_sme.get_dst_size; | |
| 378 | 9 | variants[0].matmul.imatmul = ukernel_sme.run_imatmul; | |
| 379 | |||
| 380 | 9 | variants[1].name = "imatmul_qai8_qai8p_qsi8cxp_sme2"; | |
| 381 | 9 | variants[1].acc_pack.m = 2 * get_sme_vector_length<int32_t>(); | |
| 382 | 9 | variants[1].acc_pack.n = 2 * get_sme_vector_length<int32_t>(); | |
| 383 | 9 | variants[1].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); | |
| 384 | 9 | variants[1].acc_step.m = 2 * get_sme_vector_length<int32_t>(); | |
| 385 | 9 | variants[1].acc_step.n = 2 * get_sme_vector_length<int32_t>(); | |
| 386 | 9 | variants[1].acc_step.k = sizeof(int32_t) / sizeof(int8_t); | |
| 387 | 9 | variants[1].is_supported = cpu_has_sme2; | |
| 388 | 9 | variants[1].lhs_pack.get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
| 389 | 9 | variants[1].lhs_pack.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
| 390 | 9 | variants[1].lhs_pack.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
| 391 | 9 | variants[1].lhs_pack.pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme; | |
| 392 | 9 | variants[1].rhs_pack.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 393 | 9 | variants[1].rhs_pack.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 394 | 9 | variants[1].rhs_pack.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 395 | 9 | variants[1].rhs_pack.get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 396 | 9 | variants[1].rhs_pack.get_packed_rhs_offset = | |
| 397 | kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | ||
| 398 | 9 | variants[1].rhs_pack.get_packed_rhs_size = | |
| 399 | kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | ||
| 400 | 9 | variants[1].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; | |
| 401 | 9 | variants[1].matmul.get_m_step = ukernel_sme2.get_m_step; | |
| 402 | 9 | variants[1].matmul.get_n_step = ukernel_sme2.get_n_step; | |
| 403 | 9 | variants[1].matmul.get_lhs_packed_offset = ukernel_sme2.get_lhs_packed_offset; | |
| 404 | 9 | variants[1].matmul.get_rhs_packed_offset = ukernel_sme2.get_rhs_packed_offset; | |
| 405 | 9 | variants[1].matmul.get_dst_offset = ukernel_sme2.get_dst_offset; | |
| 406 | 9 | variants[1].matmul.get_dst_size = ukernel_sme2.get_dst_size; | |
| 407 | 9 | variants[1].matmul.imatmul = ukernel_sme2.run_imatmul; | |
| 408 | |||
| 409 | 9 | return variants; | |
| 410 | ✗ | } | |
| 411 | |||
| 412 | 3 | const auto& get_gemv_variants() { | |
| 413 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
3 | static std::array<MatMulVariant, 1> variants; |
| 414 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
6 | static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel = |
| 415 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface(); |
| 416 | |||
| 417 | 3 | variants[0].name = "matmul_qai8_qai8_qsi8cxp"; | |
| 418 | 3 | variants[0].acc_pack.m = 1; | |
| 419 | 3 | variants[0].acc_pack.n = 2 * get_sme_vector_length<int32_t>(); | |
| 420 | 3 | variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); | |
| 421 | 3 | variants[0].acc_step.m = 1; | |
| 422 | 3 | variants[0].acc_step.n = 16 * get_sme_vector_length<int32_t>(); | |
| 423 | 3 | variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t); | |
| 424 | 3 | variants[0].is_supported = cpu_has_sme2; | |
| 425 | 3 | variants[0].lhs_pack = std::nullopt; | |
| 426 | 3 | variants[0].rhs_pack = get_rhs_pack(); | |
| 427 | 3 | variants[0].matmul.get_m_step = ukernel.get_m_step; | |
| 428 | 3 | variants[0].matmul.get_n_step = ukernel.get_n_step; | |
| 429 | 171 | variants[0].matmul.get_mr = []() -> size_t { return 1; }; | |
| 430 | 3 | variants[0].matmul.get_nr = ukernel.get_nr; | |
| 431 | 3 | variants[0].matmul.get_kr = ukernel.get_kr; | |
| 432 | 3 | variants[0].matmul.get_sr = ukernel.get_sr; | |
| 433 | 3 | variants[0].matmul.get_packed_lhs_offset = nullptr; | |
| 434 | 3 | variants[0].matmul.get_packed_rhs_offset = ukernel.get_rhs_packed_offset; | |
| 435 | 3 | variants[0].matmul.get_dst_offset = ukernel.get_dst_offset; | |
| 436 | 3 | variants[0].matmul.get_dst_size = ukernel.get_dst_size; | |
| 437 | 3 | variants[0].matmul.matmul = ukernel.run_matmul; | |
| 438 | |||
| 439 | 3 | return variants; | |
| 440 | ✗ | } | |
| 441 | |||
| 442 | /// Quantization parameters | ||
| 443 | struct Quant { | ||
| 444 | float scale; | ||
| 445 | int32_t zero_point; | ||
| 446 | }; | ||
| 447 | |||
| 448 | /// Reference test data | ||
| 449 | struct TestReference { | ||
| 450 | Range<int8_t> clamp; | ||
| 451 | |||
| 452 | Quant qa_lhs; | ||
| 453 | Quant qa_dst; | ||
| 454 | |||
| 455 | Buffer lhs_qai8; | ||
| 456 | Buffer lhs_qai8_scales; | ||
| 457 | Buffer lhs_qai8_zero_points; | ||
| 458 | Buffer lhs_qai8_indirect; | ||
| 459 | Buffer lhs_qai8_indirect_packed; | ||
| 460 | Buffer lhs_qai8_indirect_padding; | ||
| 461 | size_t lhs_qai8_indirect_offset; | ||
| 462 | |||
| 463 | Buffer rhs_qsi8; | ||
| 464 | Buffer rhs_scales; | ||
| 465 | |||
| 466 | Buffer bias_qsi32; | ||
| 467 | |||
| 468 | Buffer dst_qsi8_clamped; | ||
| 469 | |||
| 470 | Buffer packed_lhs; | ||
| 471 | Buffer packed_rhs; | ||
| 472 | }; | ||
| 473 | |||
| 474 | constexpr int8_t padding_value = 0; | ||
| 475 | |||
| 476 | // Functionality for hashing generated test data. | ||
| 477 | // This is particularly useful for portion testing | ||
| 478 | // which reuses the exact same data for all portions | ||
| 479 | struct TestDataId { | ||
| 480 | MatMulShape shape; | ||
| 481 | MatMulShape shape_pack; | ||
| 482 | size_t chunk_len; | ||
| 483 | bool pad_testing; | ||
| 484 | float clamp_keep_ratio; | ||
| 485 | |||
| 486 | struct Hash { | ||
| 487 | 4685 | size_t operator()(const TestDataId& id) const { | |
| 488 | 4685 | return // | |
| 489 | 9370 | (MatMulShape::Hash{}(id.shape) << 0) ^ // | |
| 490 | 9370 | (MatMulShape::Hash{}(id.shape_pack) << 1) ^ // | |
| 491 | 9370 | (std::hash<size_t>{}(id.chunk_len) << 2) ^ // | |
| 492 | 9370 | (std::hash<bool>{}(id.pad_testing) << 3) ^ // | |
| 493 | 4685 | (std::hash<float>{}(id.clamp_keep_ratio) << 4); | |
| 494 | } | ||
| 495 | }; | ||
| 496 | |||
| 497 | private: | ||
| 498 | 4489 | friend bool operator==(const TestDataId& lhs, const TestDataId& rhs) { | |
| 499 | 4489 | return // | |
| 500 |
2/2✓ Branch 0 taken 3659 times.
✓ Branch 1 taken 830 times.
|
4489 | lhs.shape == rhs.shape && // |
| 501 |
1/2✓ Branch 0 taken 3659 times.
✗ Branch 1 not taken.
|
3659 | lhs.shape_pack == rhs.shape_pack && // |
| 502 |
2/2✓ Branch 0 taken 3643 times.
✓ Branch 1 taken 16 times.
|
3659 | lhs.chunk_len == rhs.chunk_len && // |
| 503 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 3643 times.
|
3643 | lhs.pad_testing == rhs.pad_testing && // |
| 504 | 3643 | lhs.clamp_keep_ratio == rhs.clamp_keep_ratio; | |
| 505 | } | ||
| 506 | }; | ||
| 507 | |||
| 508 | // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) | ||
| 509 | 3 | std::unordered_map<TestDataId, TestReference, TestDataId::Hash> g_data; | |
| 510 | // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) | ||
| 511 | |||
| 512 | /// Generate test reference data | ||
| 513 | 4164 | const TestReference& get_test_reference(const TestDataId& test_data_id) { | |
| 514 | // ============================================================ | ||
| 515 | // Generates input and reference output data | ||
| 516 | // ============================================================ | ||
| 517 | |||
| 518 | // Attempt to find test data in cache | ||
| 519 | 4164 | const auto data_it = g_data.find(test_data_id); | |
| 520 |
2/2✓ Branch 0 taken 3643 times.
✓ Branch 1 taken 521 times.
|
4164 | if (data_it != g_data.end()) { |
| 521 | 3643 | return data_it->second; | |
| 522 | } | ||
| 523 | |||
| 524 | 5210 | const auto& [shape, pack_shape, k_chunk_len, pad_testing, clamp_keep_ratio] = test_data_id; | |
| 525 | |||
| 526 | // Seed the random generator. | ||
| 527 |
8/18✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 521 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 521 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 521 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 521 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 521 times.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
|
1563 | const auto key = std::string("Qai8Qai8Qsi8_cache:") + std::to_string(shape.m) + "x" + std::to_string(shape.n) + |
| 528 |
7/14✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 521 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 521 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 521 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 521 times.
✗ Branch 13 not taken.
|
1042 | "x" + std::to_string(shape.k) + ":" + std::to_string(clamp_keep_ratio); |
| 529 |
1/2✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
|
521 | auto& feed = seed_stream(key); |
| 530 | |||
| 531 | // Generates the input data in floating-point. | ||
| 532 |
4/8✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 521 times.
✗ Branch 7 not taken.
|
1563 | Buffer lhs_f32 = fill_random<float>(shape.m * shape.k, feed()); |
| 533 |
4/8✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 521 times.
✗ Branch 7 not taken.
|
1563 | const Buffer rhs_f32 = fill_random<float>(shape.k * shape.n, feed()); |
| 534 |
3/6✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 521 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 521 times.
✗ Branch 5 not taken.
|
1042 | const Buffer bias_f32 = fill_random<float>(shape.n, feed()); |
| 535 | |||
| 536 | // Quantizes the input data. | ||
| 537 | // * LHS: 8-bit asymmetric per-matrix quantization. | ||
| 538 | // * RHS: 8-bit symmetric per-channel quantization. | ||
| 539 | // * Bias: 32-bit symmetric per-channel quantization. | ||
| 540 | |||
| 541 |
1/2✓ Branch 0 taken 521 times.
✗ Branch 1 not taken.
|
521 | QuantizationInfo lhs_qinfo{}; |
| 542 | lhs_qinfo.quant_width = shape.m * shape.k; | ||
| 543 | lhs_qinfo.dst_type = DataType::QAI8; | ||
| 544 | lhs_qinfo.scale_type = DataType::FP32; | ||
| 545 | lhs_qinfo.zero_point_type = DataType::I32; | ||
| 546 | auto [lhs_ref_quant, lhs_qoutputs] = | ||
| 547 | quantize_dynamic(lhs_f32.data(), DataType::FP32, 1, shape.m * shape.k, lhs_qinfo); | ||
| 548 | const auto lhs_scale = read_array<float>(lhs_qoutputs.scales.data(), 0); | ||
| 549 | const auto lhs_zero_point = read_array<int32_t>(lhs_qoutputs.zero_points.data(), 0); | ||
| 550 | |||
| 551 | const size_t k_chunk_count = shape.k / k_chunk_len; | ||
| 552 | assert(k_chunk_count * k_chunk_len == shape.k); | ||
| 553 | |||
| 554 | // Setup an indirection buffer, where each "row" contains `k_chunk_count` | ||
| 555 | // pointers to chunks of length `k_chunk_len` in the input_buffer | ||
| 556 | Buffer lhs_qai8_indirect(shape.m * k_chunk_count * sizeof(void*)); | ||
| 557 | Buffer lhs_padding(k_chunk_len, padding_value); | ||
| 558 | auto* lhs_qai8_indirect_ptr = reinterpret_cast<uint8_t**>(lhs_qai8_indirect.data()); | ||
| 559 | for (size_t m_i = 0; m_i < shape.m; ++m_i) { | ||
| 560 | for (size_t k_chunk_idx = 0; k_chunk_idx < k_chunk_count; ++k_chunk_idx) { | ||
| 561 | const size_t idx = m_i * k_chunk_count + k_chunk_idx; | ||
| 562 | if (pad_testing and m_i == 0) { | ||
| 563 | // Push padding pointers for first row | ||
| 564 | lhs_qai8_indirect_ptr[idx] = reinterpret_cast<uint8_t*>(lhs_padding.data()); | ||
| 565 | } else { | ||
| 566 | uintptr_t offset = m_i * shape.k + k_chunk_idx * k_chunk_len; | ||
| 567 | lhs_qai8_indirect_ptr[idx] = reinterpret_cast<uint8_t*>(offset); | ||
| 568 | } | ||
| 569 | } | ||
| 570 | } | ||
| 571 | const auto indirection_base = reinterpret_cast<uintptr_t>(lhs_ref_quant.data()); | ||
| 572 | |||
| 573 | // Reorder indirection pointers to layout the packing micro-kernel expects | ||
| 574 | Buffer lhs_qai8_indirect_packed = reorder_block<const void*>( | ||
| 575 | reinterpret_cast<const void*>(lhs_qai8_indirect.data()), shape.m, k_chunk_count, pack_shape.m, 1); | ||
| 576 | |||
| 577 | // Transpose, then quantize symmetrically, then transpose back. This will give one | ||
| 578 | // quantization value for each column | ||
| 579 | const auto rhs_f32_t = transpose<float>(rhs_f32.data(), shape.k, shape.n); | ||
| 580 | |||
| 581 | QuantizationInfo rhs_qinfo{}; | ||
| 582 | rhs_qinfo.quant_width = shape.k; | ||
| 583 | rhs_qinfo.dst_type = DataType::QSI8; | ||
| 584 | rhs_qinfo.scale_type = DataType::FP32; | ||
| 585 | auto [rhs_ref_quant_t, rhs_qoutputs] = | ||
| 586 | quantize_dynamic(rhs_f32_t.data(), DataType::FP32, shape.n, shape.k, rhs_qinfo); | ||
| 587 | auto rhs_qsi8 = transpose<int8_t>(rhs_ref_quant_t.data(), shape.n, shape.k); | ||
| 588 | |||
| 589 | // Multiply all bias values with the LHS scale | ||
| 590 | const auto bias_scales = mul<float>(&lhs_scale, 1, 1, rhs_qoutputs.scales.data(), 1, shape.n); | ||
| 591 | // Calculate quantized bias values, by treating bias as column, and | ||
| 592 | // scale using RHS scales. This will scale each bias value indiviually | ||
| 593 | auto bias_qsi32 = | ||
| 594 | quantize_symmetric_per_block<float, int32_t, float>(bias_f32.data(), bias_scales.data(), shape.n, 1, 1); | ||
| 595 | |||
| 596 | // Runs the reference implementation of matmul to produce floating-point result. | ||
| 597 | const void* const* lhs_iptr = reinterpret_cast<const void* const*>(lhs_qai8_indirect.data()); | ||
| 598 | const auto ref_dst_f32 = | ||
| 599 | indirect_matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>( | ||
| 600 | shape.m, shape.n, k_chunk_count, k_chunk_len, // matmul shape | ||
| 601 | lhs_iptr, indirection_base, lhs_padding.data(), // LHS indirection, offset and padding | ||
| 602 | &lhs_scale, &lhs_zero_point, // LHS, scaling factor and zero point | ||
| 603 | shape.m, shape.k, // LHS quantization window shape | ||
| 604 | rhs_ref_quant_t.data(), rhs_qoutputs.scales.data(), nullptr, // RHS scaling factors | ||
| 605 | 1, shape.k, // RHS quantization window shape | ||
| 606 | bias_qsi32.data(), bias_scales.data(), nullptr, // Bias, scaling and zero points | ||
| 607 | 1 // Bias quantization window shape | ||
| 608 | ); | ||
| 609 | |||
| 610 | // Computes the output quantization information and clamping limits. | ||
| 611 | // | ||
| 612 | // To get a realistic value for the output quantization information and clamping limits | ||
| 613 | // and avoid uncontrolled saturation problem, these information will be calculated | ||
| 614 | // based on the reference floating-point output. | ||
| 615 | // | ||
| 616 | // The clamping limits will be slightly narrower than the actual range of the output | ||
| 617 | // so that a portion of the output will be clampped. | ||
| 618 | const auto [dst_scales, dst_zero_points] = | ||
| 619 | compute_asymmetric_per_block_quantization_info<float, int8_t, float, int32_t>( | ||
| 620 | ref_dst_f32.data(), 1, shape.m * shape.n, shape.m * shape.n); | ||
| 621 | const auto dst_scale = read_array<float>(dst_scales.data(), 0); | ||
| 622 | const auto dst_zero_point = read_array<int32_t>(dst_zero_points.data(), 0); | ||
| 623 | |||
| 624 | const auto ref_dst_f32_min = reduce_min<float>(ref_dst_f32.data(), shape.m * shape.n); | ||
| 625 | const auto ref_dst_f32_max = reduce_max<float>(ref_dst_f32.data(), shape.m * shape.n); | ||
| 626 | const auto ref_dst_f32_range = ref_dst_f32_max - ref_dst_f32_min; | ||
| 627 | |||
| 628 | const auto ref_dst_f32_clamp_min = ref_dst_f32_min + ref_dst_f32_range * (1.0F - clamp_keep_ratio) / 2; | ||
| 629 | const auto ref_dst_f32_clamp_max = ref_dst_f32_max - ref_dst_f32_range * (1.0F - clamp_keep_ratio) / 2; | ||
| 630 | const auto dst_qai8_clamp_min = | ||
| 631 | quantize_asymmetric<float, int8_t, int32_t>(ref_dst_f32_clamp_min, dst_scale, dst_zero_point); | ||
| 632 | const auto dst_qai8_clamp_max = | ||
| 633 | quantize_asymmetric<float, int8_t, int32_t>(ref_dst_f32_clamp_max, dst_scale, dst_zero_point); | ||
| 634 | |||
| 635 | // Clamps and quantizes the reference output matrix. | ||
| 636 | const auto ref_dst_f32_clamped = | ||
| 637 | clamp<float>(ref_dst_f32.data(), shape.m * shape.n, ref_dst_f32_clamp_min, ref_dst_f32_clamp_max); | ||
| 638 | auto ref_dst_qsi8_clamped = quantize_asymmetric_per_block<float, int8_t, float, int32_t>( | ||
| 639 | ref_dst_f32_clamped.data(), &dst_scale, &dst_zero_point, // values, scales, zero point | ||
| 640 | 1, shape.m * shape.n, // data shape | ||
| 641 | shape.m * shape.n // quantization window width | ||
| 642 | ); | ||
| 643 | |||
| 644 | // Runs the reference implementation of the packing micro-kernels. | ||
| 645 | // | ||
| 646 | // The reference packing micro-kernels cannot be executed earlier | ||
| 647 | // because we need the reference floating-point output first to have | ||
| 648 | // the quantization information. | ||
| 649 | auto packed_lhs = reorder_block<int8_t>(lhs_ref_quant.data(), shape.m, shape.k, pack_shape.m, pack_shape.k); | ||
| 650 | auto packed_rhs = matmul_pack_rhs_nxk_static_quantized<int8_t, float, int32_t>( | ||
| 651 | rhs_ref_quant_t.data(), rhs_qoutputs.scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, | ||
| 652 | shape.n, shape.k, pack_shape.n, pack_shape.k); | ||
| 653 | |||
| 654 | TestReference& reference = g_data[test_data_id]; | ||
| 655 | reference.clamp.min = dst_qai8_clamp_min; | ||
| 656 | reference.clamp.max = dst_qai8_clamp_max; | ||
| 657 | reference.qa_lhs.scale = lhs_scale; | ||
| 658 | reference.qa_lhs.zero_point = lhs_zero_point; | ||
| 659 | reference.qa_dst.scale = dst_scale; | ||
| 660 | reference.qa_dst.zero_point = dst_zero_point; | ||
| 661 | reference.lhs_qai8 = std::move(lhs_ref_quant); | ||
| 662 | reference.lhs_qai8_scales = std::move(lhs_qoutputs.scales); | ||
| 663 | reference.lhs_qai8_zero_points = std::move(lhs_qoutputs.zero_points); | ||
| 664 | reference.lhs_qai8_indirect = std::move(lhs_qai8_indirect); | ||
| 665 | reference.lhs_qai8_indirect_packed = std::move(lhs_qai8_indirect_packed); | ||
| 666 | reference.lhs_qai8_indirect_padding = std::move(lhs_padding); | ||
| 667 | reference.lhs_qai8_indirect_offset = indirection_base; | ||
| 668 | reference.rhs_qsi8 = std::move(rhs_qsi8); | ||
| 669 | reference.rhs_scales = std::move(rhs_qoutputs.scales); | ||
| 670 | reference.bias_qsi32 = std::move(bias_qsi32); | ||
| 671 | reference.dst_qsi8_clamped = std::move(ref_dst_qsi8_clamped); | ||
| 672 | reference.packed_lhs = std::move(packed_lhs); | ||
| 673 | reference.packed_rhs = std::move(packed_rhs); | ||
| 674 | |||
| 675 | return reference; | ||
| 676 | ✗ | } | |
| 677 | |||
| 678 | /// Test LHS packing | ||
| 679 | 666 | void test_lhs_pack( | |
| 680 | const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { | ||
| 681 | − | KAI_ASSUME_ALWAYS(variant.lhs_pack.has_value()); | |
| 682 | |||
| 683 | 1332 | const auto imp_packed_lhs_size = | |
| 684 | 666 | variant.lhs_pack->get_packed_lhs_size(shape.m, shape.k, variant.acc_pack.m, variant.acc_pack.k, 1); | |
| 685 |
2/12✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 666 times.
|
666 | ASSERT_EQ(imp_packed_lhs_size, reference.packed_lhs.size()); |
| 686 | |||
| 687 | 666 | Buffer imp_packed_lhs(imp_packed_lhs_size, 0); | |
| 688 |
2/4✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 666 times.
✗ Branch 3 not taken.
|
666 | const auto imp_lhs_offset = variant.lhs_pack->get_lhs_offset(output_area.start_row(), shape.k * sizeof(int8_t)); |
| 689 |
1/2✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
|
666 | const auto imp_packed_lhs_offset = variant.lhs_pack->get_packed_lhs_offset( |
| 690 |
1/2✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
|
666 | output_area.start_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1); |
| 691 | |||
| 692 |
1/2✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
|
666 | abi_check( |
| 693 |
1/2✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
|
666 | variant.lhs_pack->pack, output_area.height(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1, 0, |
| 694 | 666 | reference.lhs_qai8.data() + imp_lhs_offset, shape.k * sizeof(int8_t), | |
| 695 | 666 | imp_packed_lhs.data() + imp_packed_lhs_offset); | |
| 696 | |||
| 697 |
3/4✓ Branch 0 taken 666 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 42 times.
✓ Branch 3 taken 624 times.
|
708 | const auto imp_packed_lhs_end_offset = output_area.end_row() < shape.m |
| 698 |
1/2✓ Branch 0 taken 42 times.
✗ Branch 1 not taken.
|
42 | ? variant.lhs_pack->get_packed_lhs_offset( |
| 699 |
1/2✓ Branch 0 taken 42 times.
✗ Branch 1 not taken.
|
42 | output_area.end_row(), shape.k, variant.acc_pack.m, variant.acc_pack.k, 1) |
| 700 | 624 | : imp_packed_lhs_size; | |
| 701 | |||
| 702 | 666 | const auto* imp_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_lhs.data()); | |
| 703 | 666 | const auto* ref_packed_lhs_ptr = reinterpret_cast<const uint8_t*>(reference.packed_lhs.data()); | |
| 704 | |||
| 705 |
4/6✓ Branch 0 taken 680346 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 679680 times.
✓ Branch 3 taken 666 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 666 times.
|
680346 | for (size_t i = 0; i < reference.packed_lhs.size(); ++i) { |
| 706 |
4/4✓ Branch 0 taken 651264 times.
✓ Branch 1 taken 28416 times.
✓ Branch 2 taken 42240 times.
✓ Branch 3 taken 609024 times.
|
679680 | if (i >= imp_packed_lhs_offset && i < imp_packed_lhs_end_offset) { |
| 707 |
3/14✗ Branch 0 not taken.
✓ Branch 1 taken 609024 times.
✓ Branch 2 taken 609024 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 609024 times.
|
609024 | ASSERT_EQ(imp_packed_lhs_ptr[i], ref_packed_lhs_ptr[i]); |
| 708 | 609024 | } else { | |
| 709 |
3/14✗ Branch 0 not taken.
✓ Branch 1 taken 70656 times.
✓ Branch 2 taken 70656 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 70656 times.
|
70656 | ASSERT_EQ(imp_packed_lhs_ptr[i], 0); |
| 710 | } | ||
| 711 | 679680 | } | |
| 712 | 666 | } | |
| 713 | |||
| 714 | /// Test RHS packing | ||
| 715 | 834 | void test_rhs_pack( | |
| 716 | const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { | ||
| 717 | 834 | const auto imp_packed_rhs_size = variant.rhs_pack.get_packed_rhs_size(shape.n, shape.k); | |
| 718 |
2/12✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 834 times.
|
834 | ASSERT_EQ(imp_packed_rhs_size, reference.packed_rhs.size()); |
| 719 | 834 | Buffer imp_packed_rhs(imp_packed_rhs_size, 0); | |
| 720 | |||
| 721 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | const auto imp_rhs_offset = variant.rhs_pack.get_rhs_offset(output_area.start_col()); |
| 722 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | const auto imp_bias_offset = variant.rhs_pack.get_bias_offset(output_area.start_col()); |
| 723 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | const auto imp_scale_offset = variant.rhs_pack.get_scale_offset(output_area.start_col()); |
| 724 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | const auto imp_packed_rhs_offset = variant.rhs_pack.get_packed_rhs_offset(output_area.start_col(), shape.k); |
| 725 | |||
| 726 | 834 | kai_rhs_pack_qsi8cx_params imp_pack_rhs_params{}; | |
| 727 | 834 | imp_pack_rhs_params.lhs_zero_point = reference.qa_lhs.zero_point; | |
| 728 | 834 | imp_pack_rhs_params.scale_multiplier = reference.qa_lhs.scale / reference.qa_dst.scale; | |
| 729 | |||
| 730 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
834 | abi_check( |
| 731 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
834 | variant.rhs_pack.pack, 1, output_area.width(), shape.k, variant.acc_pack.n, variant.acc_pack.k, 1, |
| 732 | 834 | shape.n * sizeof(int8_t), reference.rhs_qsi8.data() + imp_rhs_offset, | |
| 733 | 834 | reference.bias_qsi32.data() + imp_bias_offset, reference.rhs_scales.data() + imp_scale_offset, | |
| 734 | 834 | imp_packed_rhs.data() + imp_packed_rhs_offset, 0, &imp_pack_rhs_params); | |
| 735 | |||
| 736 |
3/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 150 times.
✓ Branch 3 taken 684 times.
|
984 | const auto imp_packed_rhs_end_offset = output_area.end_col() < shape.n |
| 737 |
2/4✓ Branch 0 taken 150 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 150 times.
✗ Branch 3 not taken.
|
150 | ? variant.rhs_pack.get_packed_rhs_offset(output_area.end_col(), shape.k) |
| 738 | 684 | : imp_packed_rhs_size; | |
| 739 | |||
| 740 | 834 | size_t mismatches = 0; | |
| 741 | 834 | const auto* imp_packed_rhs_ptr = reinterpret_cast<const uint8_t*>(imp_packed_rhs.data()); | |
| 742 | 834 | const auto* ref_packed_rhs_ptr = reinterpret_cast<const uint8_t*>(reference.packed_rhs.data()); | |
| 743 | |||
| 744 |
2/2✓ Branch 0 taken 3380736 times.
✓ Branch 1 taken 834 times.
|
3381570 | for (size_t i = 0; i < reference.packed_rhs.size(); ++i) { |
| 745 |
4/4✓ Branch 0 taken 2843136 times.
✓ Branch 1 taken 537600 times.
✓ Branch 2 taken 690816 times.
✓ Branch 3 taken 2152320 times.
|
3380736 | if (i >= imp_packed_rhs_offset && i < imp_packed_rhs_end_offset) { |
| 746 |
1/2✓ Branch 0 taken 2152320 times.
✗ Branch 1 not taken.
|
2152320 | if (imp_packed_rhs_ptr[i] != ref_packed_rhs_ptr[i]) { |
| 747 | ✗ | mismatches += 1; | |
| 748 | ✗ | } | |
| 749 | 2152320 | } else { | |
| 750 |
1/2✓ Branch 0 taken 1228416 times.
✗ Branch 1 not taken.
|
1228416 | if (imp_packed_rhs_ptr[i] != 0) { |
| 751 | ✗ | mismatches += 1; | |
| 752 | ✗ | } | |
| 753 | } | ||
| 754 | 3380736 | } | |
| 755 |
3/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
834 | ASSERT_EQ(mismatches, 0) << "There are an unexpected amount of mismatches in RHS packing"; |
| 756 | 834 | } | |
| 757 | |||
| 758 | 4164 | void compare_matmul_result( | |
| 759 | const MatMulShape& shape, const Rect& output_area, const Buffer& actual, const Buffer& reference) { | ||
| 760 | 4164 | size_t mismatches = 0; | |
| 761 | 4164 | bool printed_row = false; | |
| 762 | 4164 | std::ostringstream sstream; | |
| 763 |
2/2✓ Branch 0 taken 94884 times.
✓ Branch 1 taken 4164 times.
|
99048 | for (size_t m_i = 0; m_i < shape.m; ++m_i) { |
| 764 |
2/2✓ Branch 0 taken 7690212 times.
✓ Branch 1 taken 94884 times.
|
7785096 | for (size_t n_i = 0; n_i < shape.n; ++n_i) { |
| 765 | 7690212 | const auto i = m_i * shape.n + n_i; | |
| 766 |
6/8✓ Branch 0 taken 7690212 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6881252 times.
✓ Branch 3 taken 808960 times.
✓ Branch 4 taken 6881252 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6015836 times.
✓ Branch 7 taken 865416 times.
|
13706048 | const auto in_area = m_i >= output_area.start_row() && m_i < output_area.end_row() && |
| 767 |
4/6✓ Branch 0 taken 6015836 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5388508 times.
✓ Branch 3 taken 627328 times.
✓ Branch 4 taken 5388508 times.
✗ Branch 5 not taken.
|
6015836 | n_i >= output_area.start_col() && n_i < output_area.end_col(); |
| 768 | |||
| 769 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 7690212 times.
|
7690212 | const auto imp_value = read_array<int8_t>(actual.data(), i); |
| 770 |
3/4✓ Branch 0 taken 4807837 times.
✓ Branch 1 taken 2882375 times.
✓ Branch 2 taken 4807837 times.
✗ Branch 3 not taken.
|
7690212 | const auto ref_value = in_area ? read_array<int8_t>(reference.data(), i) : 0; |
| 771 | 7690212 | const auto error = std::abs(imp_value - ref_value); | |
| 772 | 7690212 | const auto threshold = in_area ? 1 : 0; | |
| 773 | 7690212 | const bool mismatch = error > threshold; | |
| 774 |
1/2✓ Branch 0 taken 7690212 times.
✗ Branch 1 not taken.
|
7690212 | if (mismatch) { |
| 775 | ✗ | if (not printed_row) { | |
| 776 | ✗ | sstream << " row=" << m_i << ", columns: "; | |
| 777 | ✗ | printed_row = true; | |
| 778 | ✗ | } | |
| 779 | ✗ | sstream << n_i << ", "; | |
| 780 | ✗ | } | |
| 781 | 7690212 | mismatches += static_cast<size_t>(mismatch); | |
| 782 | 7690212 | } | |
| 783 |
1/2✓ Branch 0 taken 94884 times.
✗ Branch 1 not taken.
|
94884 | if (printed_row) { |
| 784 | ✗ | sstream << "\n"; | |
| 785 | ✗ | } | |
| 786 | 94884 | printed_row = false; | |
| 787 | 94884 | } | |
| 788 |
3/20✓ Branch 0 taken 4164 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4164 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✓ Branch 19 taken 4164 times.
|
4164 | ASSERT_EQ(mismatches, 0) << "Mismatches between reference result and actual result:\n" << sstream.str(); |
| 789 | 4164 | } | |
| 790 | |||
| 791 | /// Test MatMul of GEMM/GEMV like kernel | ||
| 792 | 834 | void test_matmul( | |
| 793 | const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { | ||
| 794 | 834 | const auto imp_dst_size = variant.matmul.get_dst_size(shape.m, shape.n); | |
| 795 |
2/12✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 834 times.
|
834 | ASSERT_EQ(imp_dst_size, reference.dst_qsi8_clamped.size()); |
| 796 | |||
| 797 | 834 | Buffer imp_dst(imp_dst_size, 0); | |
| 798 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
3336 | const auto [imp_lhs_offset, lhs_data] = [&]() -> std::tuple<size_t, const Buffer&> { |
| 799 |
2/2✓ Branch 0 taken 666 times.
✓ Branch 1 taken 168 times.
|
834 | if (variant.lhs_pack.has_value()) { |
| 800 | 666 | return {variant.matmul.get_packed_lhs_offset(output_area.start_row(), shape.k), reference.packed_lhs}; | |
| 801 | } | ||
| 802 | 168 | return {output_area.start_row() * shape.k, reference.lhs_qai8}; | |
| 803 | 834 | }(); | |
| 804 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | const size_t imp_packed_rhs_offset = variant.matmul.get_packed_rhs_offset(output_area.start_col(), shape.k); |
| 805 | 1668 | const size_t imp_dst_offset = | |
| 806 |
3/6✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
|
834 | variant.matmul.get_dst_offset(output_area.start_row(), output_area.start_col(), shape.n * sizeof(int8_t)); |
| 807 |
5/18✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 834 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 834 times.
|
834 | ASSERT_EQ(imp_dst_offset, output_area.start_row() * shape.n + output_area.start_col()); |
| 808 | |||
| 809 | 834 | kai_matmul_requantize32_params imp_main_params{}; | |
| 810 | 834 | imp_main_params.min_value = reference.clamp.min; | |
| 811 | 834 | imp_main_params.max_value = reference.clamp.max; | |
| 812 | 834 | imp_main_params.output_zero_point = reference.qa_dst.zero_point; | |
| 813 | |||
| 814 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
834 | abi_check( |
| 815 |
2/4✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
|
834 | variant.matmul.matmul, output_area.height(), output_area.width(), shape.k, lhs_data.data() + imp_lhs_offset, |
| 816 | 834 | reference.packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t), | |
| 817 | 834 | sizeof(int8_t), &imp_main_params); | |
| 818 | |||
| 819 |
1/2✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
|
834 | compare_matmul_result(shape, output_area, imp_dst, reference.dst_qsi8_clamped); |
| 820 | 834 | } | |
| 821 | |||
| 822 | } // namespace | ||
| 823 | |||
| 824 | using MatMulQuantizedTest = testing::TestWithParam<std::tuple<MatMulVariant, MatMulShape, MatrixPortion, float>>; | ||
| 825 | using IndirectMatMulQuantizedTestParams = std::tuple<IndirectMatMulVariant, MatMulShape, size_t, MatrixPortion, float>; | ||
| 826 | using IndirectMatMulQuantizedTest = testing::TestWithParam<IndirectMatMulQuantizedTestParams>; | ||
| 827 | |||
| 828 | 2502 | static std::string test_description( | |
| 829 | const MatMulVariant& variant, // | ||
| 830 | const MatMulShape& shape, // | ||
| 831 | const MatrixPortion& portion, float clamp_keep_ratio) { | ||
| 832 | 2502 | std::ostringstream sstream; | |
| 833 | |||
| 834 |
2/4✓ Branch 0 taken 2502 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2502 times.
✗ Branch 3 not taken.
|
2502 | sstream << test_description(variant.name, shape, portion, true, clamp_keep_ratio); |
| 835 | |||
| 836 |
1/2✓ Branch 0 taken 2502 times.
✗ Branch 1 not taken.
|
2502 | return sstream.str(); |
| 837 | 2502 | }; | |
| 838 | |||
| 839 | 19980 | [[maybe_unused]] static void PrintTo(const IndirectMatMulQuantizedTestParams& param, std::ostream* os) { | |
| 840 | 119880 | const auto& [variant, shape, k_chunk_length, portion, clamp_keep_ratio] = param; | |
| 841 | |||
| 842 | 39960 | *os << variant.name << "__"; | |
| 843 | 19980 | PrintTo(shape, os); | |
| 844 | 39960 | *os << "__K_chunk_length_" << k_chunk_length; | |
| 845 | 39960 | *os << "__clamp_keep_ratio_" << static_cast<int>(clamp_keep_ratio * 100) << "__"; | |
| 846 | 19980 | PrintTo(portion, os); | |
| 847 | 19980 | }; | |
| 848 | |||
| 849 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
4176 | TEST_P(MatMulQuantizedTest, EndToEnd) { |
| 850 | 14178 | const auto& [variant, shape, output_portion, clamp_keep_ratio] = GetParam(); | |
| 851 | |||
| 852 |
2/2✓ Branch 0 taken 834 times.
✓ Branch 1 taken 834 times.
|
1668 | if (!variant.is_supported()) { |
| 853 |
3/6✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
|
834 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 854 | } | ||
| 855 | |||
| 856 | 4170 | TestDataId test_data_id{shape, variant.acc_pack, shape.k, false, clamp_keep_ratio}; | |
| 857 | 834 | const TestReference& reference = get_test_reference(test_data_id); | |
| 858 | |||
| 859 | // Check scheduling parameters | ||
| 860 | 1668 | const auto imp_mr = variant.matmul.get_mr(); | |
| 861 | 1668 | const auto imp_nr = variant.matmul.get_nr(); | |
| 862 | 1668 | const auto imp_kr = variant.matmul.get_kr(); | |
| 863 | 1668 | const auto imp_sr = variant.matmul.get_sr(); | |
| 864 | |||
| 865 |
4/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
1668 | ASSERT_EQ(imp_mr, variant.acc_pack.m); |
| 866 |
4/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
1668 | ASSERT_EQ(imp_nr, variant.acc_pack.n); |
| 867 |
4/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
1668 | ASSERT_EQ(imp_kr, variant.acc_pack.k); |
| 868 |
3/14✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 834 times.
|
834 | ASSERT_EQ(imp_sr, 1); |
| 869 | |||
| 870 | // Check that stepping is a multiple of accumulation | ||
| 871 | 1668 | const auto imp_m_step = variant.matmul.get_m_step(); | |
| 872 | 1668 | const auto imp_n_step = variant.matmul.get_n_step(); | |
| 873 |
4/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
1668 | ASSERT_EQ(imp_m_step, variant.acc_step.m); |
| 874 |
4/16✓ Branch 0 taken 834 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 834 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 834 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 834 times.
|
1668 | ASSERT_EQ(imp_n_step, variant.acc_step.n); |
| 875 | |||
| 876 | // Test kernels. Note that packing and actual stepping might not be the same | ||
| 877 | 4170 | const auto pack_portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_pack.m, variant.acc_pack.n); | |
| 878 | 834 | const auto matmul_portion = | |
| 879 | 3336 | output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n); | |
| 880 |
2/2✓ Branch 0 taken 168 times.
✓ Branch 1 taken 666 times.
|
834 | if (variant.lhs_pack.has_value()) { |
| 881 | 666 | test_lhs_pack(shape, variant, pack_portion, reference); | |
| 882 | 666 | } | |
| 883 | 834 | test_rhs_pack(shape, variant, pack_portion, reference); | |
| 884 | 834 | test_matmul(shape, variant, matmul_portion, reference); | |
| 885 | 1668 | } | |
| 886 | |||
| 887 | namespace imatmul { | ||
| 888 | |||
| 889 | /// Perform LHS IMATMUL packing | ||
| 890 | 3330 | static Buffer lhs_pack( | |
| 891 | const LhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t m, | ||
| 892 | const KChunk& k_chunk) { | ||
| 893 | 6660 | const void* const* indirection_pointer = | |
| 894 | 3330 | reinterpret_cast<const void* const*>(reference.lhs_qai8_indirect_packed.data()); | |
| 895 | |||
| 896 | // Allocate buffer | ||
| 897 | 3330 | const size_t dst_size = variant.get_packed_lhs_size(m, k_chunk.count, k_chunk.length); | |
| 898 | 3330 | Buffer packed(dst_size); | |
| 899 | |||
| 900 | // Calculate offsets | ||
| 901 |
1/2✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
|
3330 | const size_t input_offset = portion.start_row() * k_chunk.count; |
| 902 |
2/4✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
|
3330 | const size_t dst_offset = variant.get_packed_lhs_offset(portion.start_row(), k_chunk.count, k_chunk.length); |
| 903 | |||
| 904 |
1/2✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
|
3330 | abi_check( |
| 905 | 3330 | variant.pack, // Kernel | |
| 906 |
1/2✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
|
3330 | portion.height(), k_chunk.count, k_chunk.length, // Dimensions |
| 907 | 3330 | indirection_pointer + input_offset, // Indirection input | |
| 908 | 3330 | reference.lhs_qai8_indirect_offset, // chunk offset | |
| 909 | 3330 | reference.lhs_qai8_indirect_padding.data(), // padding pointer | |
| 910 | 3330 | packed.data() + dst_offset); | |
| 911 | |||
| 912 | 3330 | return packed; | |
| 913 | 3330 | } | |
| 914 | |||
| 915 | /// Perform RHS IMATMUL packing | ||
| 916 | 3330 | static Buffer rhs_pack( | |
| 917 | const RhsPackIndirectKernel& variant, const Rect& portion, const TestReference& reference, size_t n, | ||
| 918 | const KChunk& k_chunk) { | ||
| 919 | // Allocate output buffer | ||
| 920 | 3330 | const size_t dst_size = variant.get_packed_rhs_size(n, k_chunk.count, k_chunk.length); | |
| 921 | 3330 | Buffer packed(dst_size); | |
| 922 | |||
| 923 | // Caluclate effective quantization parameters | ||
| 924 | 9990 | const kai_rhs_pack_qsi8cx_params quantization{ | |
| 925 | 3330 | reference.qa_lhs.zero_point, | |
| 926 | 3330 | reference.qa_lhs.scale / reference.qa_dst.scale, | |
| 927 | }; | ||
| 928 | |||
| 929 | // Calculate offsets | ||
| 930 |
2/4✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
|
3330 | const size_t rhs_offset = variant.get_rhs_offset(portion.start_col()); |
| 931 |
2/4✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
|
3330 | const size_t bias_offset = variant.get_bias_offset(portion.start_col()); |
| 932 |
2/4✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
|
3330 | const size_t scale_offset = variant.get_scale_offset(portion.start_col()); |
| 933 |
2/4✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
|
3330 | const size_t dst_offset = variant.get_packed_rhs_offset(portion.start_col(), k_chunk.count, k_chunk.length); |
| 934 | |||
| 935 | // Pack | ||
| 936 |
1/2✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
|
3330 | abi_check( |
| 937 | 3330 | variant.pack, // Kernel | |
| 938 |
1/2✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
|
3330 | portion.width(), k_chunk.count, k_chunk.length, // Dimensions |
| 939 | 3330 | n * sizeof(uint8_t), // Row stride | |
| 940 | 3330 | reference.rhs_qsi8.data() + rhs_offset, // RHS matrix | |
| 941 | 3330 | reference.bias_qsi32.data() + bias_offset, // Bias | |
| 942 | 3330 | reference.rhs_scales.data() + scale_offset, // Scales | |
| 943 | 3330 | packed.data() + dst_offset, // Output | |
| 944 | 3330 | &quantization); | |
| 945 | |||
| 946 | 3330 | return packed; | |
| 947 | 3330 | } | |
| 948 | |||
| 949 | /// Calculate the matmul result from IMATMUL kernels | ||
| 950 | 3330 | static Buffer matmul( | |
| 951 | const MatMulIndirectKernel& variant, const Rect& portion, const TestReference& reference, const Buffer& packed_lhs, | ||
| 952 | const Buffer& packed_rhs, const MatMulShape& shape, const KChunk& k_chunk) { | ||
| 953 | // Calculate portion offsets. | ||
| 954 | 3330 | size_t dst_offset = variant.get_dst_offset(portion.start_row(), portion.start_col(), shape.n); | |
| 955 | 3330 | size_t lhs_offset = variant.get_lhs_packed_offset(portion.start_row(), k_chunk.count, k_chunk.length); | |
| 956 | 3330 | size_t rhs_offset = variant.get_rhs_packed_offset(portion.start_col(), k_chunk.count, k_chunk.length); | |
| 957 | |||
| 958 | // Allocate output buffer | ||
| 959 | 3330 | const size_t dst_size = variant.get_dst_size(shape.m, shape.n); | |
| 960 | 3330 | Buffer dst(dst_size, 0); | |
| 961 | |||
| 962 | // Calculate geffective uantization parameters | ||
| 963 | 3330 | kai_matmul_requantize32_params requantization{}; | |
| 964 | 3330 | requantization.min_value = reference.clamp.min; | |
| 965 | 3330 | requantization.max_value = reference.clamp.max; | |
| 966 | 3330 | requantization.output_zero_point = reference.qa_dst.zero_point; | |
| 967 | |||
| 968 | // Call matmul kernel | ||
| 969 |
1/2✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
|
3330 | abi_check( |
| 970 | 3330 | variant.imatmul, // Kernel | |
| 971 |
2/4✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
|
3330 | portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions |
| 972 | 3330 | packed_lhs.data() + lhs_offset, // LHS | |
| 973 | 3330 | packed_rhs.data() + rhs_offset, // RHS | |
| 974 | 3330 | dst.data() + dst_offset, // DST | |
| 975 | 3330 | shape.n * sizeof(uint8_t), &requantization); | |
| 976 | |||
| 977 | 3330 | return dst; | |
| 978 | 3330 | } | |
| 979 | } // namespace imatmul | ||
| 980 | |||
| 981 |
8/16✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
|
16656 | TEST_P(IndirectMatMulQuantizedTest, EndToEnd) { |
| 982 | /* This is a bit special, as shape.k must be k_chunk_len * k_chunk_count | ||
| 983 | * so instead of inventing a new special kind of shape, simply multiply | ||
| 984 | * with `k_chunk_len` here */ | ||
| 985 | 39960 | const auto& [variant, shape_k_chunk, k_chunk_len, output_portion, clamp_keep_ratio] = GetParam(); | |
| 986 | 19980 | const KChunk k_chunk{shape_k_chunk.k, k_chunk_len}; | |
| 987 | 19980 | MatMulShape shape{shape_k_chunk.m, shape_k_chunk.n, k_chunk.count * k_chunk.length}; | |
| 988 | |||
| 989 |
2/2✓ Branch 0 taken 3330 times.
✓ Branch 1 taken 3330 times.
|
6660 | if (!variant.is_supported()) { |
| 990 |
3/6✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3330 times.
✗ Branch 5 not taken.
|
3330 | GTEST_SKIP() << "Unsupported CPU feature"; |
| 991 | } | ||
| 992 | |||
| 993 | // Toggle padding testst when LHS has more than one row | ||
| 994 | 9990 | TestDataId test_data_id{shape, variant.acc_pack, k_chunk.length, shape.m > 1, clamp_keep_ratio}; | |
| 995 | 3330 | const TestReference& reference = get_test_reference(test_data_id); | |
| 996 | 13320 | const Rect portion = output_portion.compute_portion(shape.m, shape.n, variant.acc_step.m, variant.acc_step.n); | |
| 997 | |||
| 998 | 6660 | Buffer packed_lhs = imatmul::lhs_pack(variant.lhs_pack, portion, reference, shape.m, k_chunk); | |
| 999 |
2/4✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
|
6660 | Buffer packed_rhs = imatmul::rhs_pack(variant.rhs_pack, portion, reference, shape.n, k_chunk); |
| 1000 |
2/4✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3330 times.
✗ Branch 3 not taken.
|
6660 | Buffer impl_result = imatmul::matmul(variant.matmul, portion, reference, packed_lhs, packed_rhs, shape, k_chunk); |
| 1001 |
1/2✓ Branch 0 taken 3330 times.
✗ Branch 1 not taken.
|
3330 | compare_matmul_result(shape, portion, impl_result, reference.dst_qsi8_clamped); |
| 1002 | 6660 | } | |
| 1003 | |||
| 1004 | static constexpr std::array shapes{ | ||
| 1005 | // clang-format off | ||
| 1006 | MatMulShape{ 1, 1, 1}, | ||
| 1007 | MatMulShape{ 1, 16, 4}, | ||
| 1008 | MatMulShape{ 1, 16, 16}, | ||
| 1009 | MatMulShape{ 1, 17, 4}, | ||
| 1010 | MatMulShape{ 1, 19, 24}, | ||
| 1011 | MatMulShape{ 1, 32, 4}, | ||
| 1012 | MatMulShape{ 1, 32, 32}, | ||
| 1013 | MatMulShape{ 1, 33,200}, | ||
| 1014 | MatMulShape{ 1, 49, 21}, | ||
| 1015 | MatMulShape{ 1, 64, 4}, | ||
| 1016 | MatMulShape{ 1, 65, 4}, | ||
| 1017 | MatMulShape{ 1, 300, 10}, | ||
| 1018 | MatMulShape{ 1, 512, 4}, | ||
| 1019 | MatMulShape{ 1, 1523, 10}, | ||
| 1020 | MatMulShape{ 2, 195, 50}, | ||
| 1021 | MatMulShape{ 3, 6, 6}, | ||
| 1022 | MatMulShape{ 3, 28, 25}, | ||
| 1023 | MatMulShape{ 3, 184,177}, | ||
| 1024 | MatMulShape{ 4, 16, 27}, | ||
| 1025 | MatMulShape{ 5, 136, 23}, | ||
| 1026 | MatMulShape{ 6, 18, 31}, | ||
| 1027 | MatMulShape{ 6, 28, 1}, | ||
| 1028 | MatMulShape{ 6, 29, 24}, | ||
| 1029 | MatMulShape{ 16, 16, 4}, | ||
| 1030 | MatMulShape{ 20, 30, 40}, | ||
| 1031 | MatMulShape{ 23, 1, 43}, | ||
| 1032 | MatMulShape{ 32, 14, 1}, | ||
| 1033 | MatMulShape{ 32, 16, 27}, | ||
| 1034 | MatMulShape{ 32, 32, 3}, | ||
| 1035 | MatMulShape{ 32, 32, 4}, | ||
| 1036 | MatMulShape{ 33, 29, 24}, | ||
| 1037 | MatMulShape{ 64, 64, 3}, | ||
| 1038 | MatMulShape{ 64, 64, 4}, | ||
| 1039 | MatMulShape{ 96, 96, 3}, | ||
| 1040 | MatMulShape{123, 85, 45}, | ||
| 1041 | MatMulShape{128, 128, 3}, | ||
| 1042 | MatMulShape{130, 130, 6}, | ||
| 1043 | // clang-format on | ||
| 1044 | }; | ||
| 1045 | |||
| 1046 | static constexpr std::array portions{ | ||
| 1047 | // clang-format off | ||
| 1048 | // (Start row , start col , height , width) | ||
| 1049 | MatrixPortion( 0 , 0 , 1 , 1) , // Full matrix. | ||
| 1050 | MatrixPortion( 0 , 0 , 1 , 0.5) , // Left half | ||
| 1051 | MatrixPortion( 0 , 0 , 0.5 , 1) , // Upper half | ||
| 1052 | MatrixPortion( 0 , 0.5 , 1 , 0.5) , // Right half | ||
| 1053 | MatrixPortion( 0.5 , 0 , 0.5 , 1) , // Bottom half | ||
| 1054 | MatrixPortion( 0.4 , 0.4 , 0.3 , 0.3) , // Center ninth | ||
| 1055 | // clang-format on | ||
| 1056 | }; | ||
| 1057 | |||
| 1058 |
18/56✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✓ Branch 20 taken 666 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✓ Branch 22 taken 1332 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
|
4002 | INSTANTIATE_TEST_SUITE_P( |
| 1059 | matmul_clamp_qai8_qai8p_qsi8cxp, MatMulQuantizedTest, | ||
| 1060 | testing::Combine( | ||
| 1061 | testing::ValuesIn(get_gemm_variants()), // | ||
| 1062 | testing::ValuesIn(shapes), // | ||
| 1063 | testing::ValuesIn({ | ||
| 1064 | // clang-format off | ||
| 1065 | MatrixPortion( 0, 0, 1, 1), // Full matrix. | ||
| 1066 | MatrixPortion( 0, 0, 0.25, 0.25), // Top-left corner. | ||
| 1067 | MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. | ||
| 1068 | // clang-format on | ||
| 1069 | }), | ||
| 1070 | testing::ValuesIn(std::initializer_list<float>{0.0F, 0.1F, 0.5F})), | ||
| 1071 | [](const auto& info) -> std::string { | ||
| 1072 | return test_description( | ||
| 1073 | std::get<MatMulVariant>(info.param), // | ||
| 1074 | std::get<MatMulShape>(info.param), // | ||
| 1075 | std::get<MatrixPortion>(info.param), // | ||
| 1076 | std::get<float>(info.param)); | ||
| 1077 | }); | ||
| 1078 | |||
| 1079 |
18/56✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✓ Branch 20 taken 168 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✓ Branch 22 taken 336 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
|
1014 | INSTANTIATE_TEST_SUITE_P( |
| 1080 | matmul_clamp_qai8_qai8_qsi8cxp, MatMulQuantizedTest, | ||
| 1081 | testing::Combine( | ||
| 1082 | testing::ValuesIn(get_gemv_variants()), | ||
| 1083 | testing::ValuesIn({ | ||
| 1084 | // clang-format off | ||
| 1085 | MatMulShape{ 1, 1, 1}, | ||
| 1086 | MatMulShape{ 1, 16, 4}, | ||
| 1087 | MatMulShape{ 1, 16, 16}, | ||
| 1088 | MatMulShape{ 1, 17, 4}, | ||
| 1089 | MatMulShape{ 1, 19, 24}, | ||
| 1090 | MatMulShape{ 1, 32, 4}, | ||
| 1091 | MatMulShape{ 1, 32, 32}, | ||
| 1092 | MatMulShape{ 1, 33,200}, | ||
| 1093 | MatMulShape{ 1, 49, 21}, | ||
| 1094 | MatMulShape{ 1, 64, 4}, | ||
| 1095 | MatMulShape{ 1, 65, 4}, | ||
| 1096 | MatMulShape{ 1, 300, 10}, | ||
| 1097 | MatMulShape{ 1, 512, 4}, | ||
| 1098 | MatMulShape{ 1, 1523, 10}, | ||
| 1099 | // clang-format on | ||
| 1100 | }), | ||
| 1101 | testing::ValuesIn({ | ||
| 1102 | // clang-format off | ||
| 1103 | MatrixPortion(0, 0, 1, 1), // Full matrix. | ||
| 1104 | MatrixPortion(0, .5, 1, .5), // Right half | ||
| 1105 | MatrixPortion(0, 0, 1, .5), // Left half | ||
| 1106 | MatrixPortion(0, .25, 1, .5) // Middle half | ||
| 1107 | // clang-format on | ||
| 1108 | }), | ||
| 1109 | // Clamp range | ||
| 1110 | testing::ValuesIn(std::initializer_list<float>({1.0f, 0.9f, 0.5f}))), // clamp_keep_ratio | ||
| 1111 | [](const auto& info) -> std::string { | ||
| 1112 | return test_description( | ||
| 1113 | std::get<MatMulVariant>(info.param), // | ||
| 1114 | std::get<MatMulShape>(info.param), // | ||
| 1115 | std::get<MatrixPortion>(info.param), // | ||
| 1116 | std::get<float>(info.param)); | ||
| 1117 | }); | ||
| 1118 | |||
| 1119 |
20/64✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 2664 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 5328 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
|
7998 | INSTANTIATE_TEST_SUITE_P( |
| 1120 | ShapesSmallKC, IndirectMatMulQuantizedTest, | ||
| 1121 | testing::Combine( | ||
| 1122 | testing::ValuesIn(get_indirect_gemm_variants()), // | ||
| 1123 | testing::ValuesIn(shapes), // | ||
| 1124 | // k_chunk_len | ||
| 1125 | testing::ValuesIn(std::initializer_list<size_t>{1, 2, 3, 4, 8, 11}), // | ||
| 1126 | testing::ValuesIn(portions), // | ||
| 1127 | // Clamp range | ||
| 1128 | testing::Values(0.1F)), | ||
| 1129 | testing::PrintToStringParamName()); | ||
| 1130 | |||
| 1131 |
20/64✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 22 taken 444 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 24 taken 888 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
|
1338 | INSTANTIATE_TEST_SUITE_P( |
| 1132 | ShapesKC32, IndirectMatMulQuantizedTest, | ||
| 1133 | testing::Combine( | ||
| 1134 | testing::ValuesIn(get_indirect_gemm_variants()), // | ||
| 1135 | testing::ValuesIn(shapes), // | ||
| 1136 | // k_chunk_len | ||
| 1137 | testing::ValuesIn(std::initializer_list<size_t>{32}), // | ||
| 1138 | testing::ValuesIn(portions), // | ||
| 1139 | // Clamp range | ||
| 1140 | testing::Values(0.1F)), | ||
| 1141 | testing::PrintToStringParamName()); | ||
| 1142 | |||
| 1143 |
22/72✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 10 taken 1 time.
✗ Branch 11 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 time.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 time.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 time.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 time.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 time.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 time.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✓ Branch 24 taken 222 times.
✗ Branch 25 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✓ Branch 26 taken 444 times.
✗ Branch 27 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
|
672 | INSTANTIATE_TEST_SUITE_P( |
| 1144 | Clamp, IndirectMatMulQuantizedTest, | ||
| 1145 | testing::Combine( | ||
| 1146 | testing::ValuesIn(get_indirect_gemm_variants()), // | ||
| 1147 | testing::ValuesIn(shapes), // | ||
| 1148 | // k_chunk_len | ||
| 1149 | testing::ValuesIn(std::initializer_list<size_t>{1}), // | ||
| 1150 | testing::Values(MatrixPortion(0, 0, 1, 1)), // | ||
| 1151 | // Clamp range | ||
| 1152 | testing::ValuesIn(std::initializer_list<float>{1.0f, 0.9f, 0.5f})), // clamp_keep_ratio | ||
| 1153 | testing::PrintToStringParamName()); | ||
| 1154 | |||
| 1155 | } // namespace kai::test | ||
| 1156 |