Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) | ||
8 | #error This file must be compiled for AArch64 and FEAT_FP16. | ||
9 | #else // Architectural features check. | ||
10 | |||
11 | #include "kai_lhs_quant_pack_qsi8d32pscalef32_f16_neon.h" | ||
12 | |||
13 | #include <arm_neon.h> | ||
14 | #include <float.h> | ||
15 | #include <math.h> | ||
16 | #include <stddef.h> | ||
17 | #include <stdint.h> | ||
18 | |||
19 | #include "kai/kai_common.h" | ||
20 | #define FLT16_MAX 65504.0F | ||
21 | #define FLT16_MIN (-65504.0F) | ||
22 | static const size_t kai_num_bytes_sum = sizeof(float); | ||
23 | static const size_t kai_num_bytes_multiplier = sizeof(float); | ||
24 | static const size_t kai_bl_multiple_of = 32; | ||
25 | |||
26 | 4805224 | inline static size_t kai_get_num_bytes_per_block(size_t bl) { | |
27 | 4805224 | return bl * sizeof(int8_t) + kai_num_bytes_multiplier + kai_num_bytes_sum; | |
28 | } | ||
29 | |||
30 | 4805224 | inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { | |
31 | − | KAI_ASSERT((k % bl) == 0); | |
32 | 4805224 | return k / bl; | |
33 | } | ||
34 | |||
35 | 4794768 | inline static size_t kai_get_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) { | |
36 | 4794768 | KAI_UNUSED(kr); | |
37 | |||
38 | 4794768 | return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block(bl); | |
39 | } | ||
40 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qsi8d32pscalef32_f16_neon(size_t mr) { | |
41 | ✗ | KAI_UNUSED(mr); | |
42 | ✗ | return 1; | |
43 | } | ||
44 | 10904 | size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f16_neon(size_t m_idx, size_t lhs_stride) { | |
45 | 10904 | return m_idx * lhs_stride; | |
46 | } | ||
47 | 21808 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f16_neon( | |
48 | size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
49 | − | KAI_ASSUME((k % 2) == 0); | |
50 | − | KAI_ASSUME((k % kr) == 0); | |
51 | − | KAI_ASSUME((k % bl) == 0); | |
52 | 21808 | KAI_UNUSED(sr); | |
53 | 21808 | return (m_idx / mr) * kai_get_lhs_packed_stride(k, mr, kr, bl); | |
54 | } | ||
55 | 4762504 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f16_neon( | |
56 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
57 | − | KAI_ASSUME((k % 2) == 0); | |
58 | − | KAI_ASSUME((k % kr) == 0); | |
59 | − | KAI_ASSUME((k % bl) == 0); | |
60 | |||
61 | 4762504 | KAI_UNUSED(sr); | |
62 | |||
63 | 4762504 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
64 | |||
65 | 9525008 | return (num_rows * kai_get_lhs_packed_stride(k, mr, kr, bl)); | |
66 | 4762504 | } | |
67 | 10904 | void kai_run_lhs_quant_pack_qsi8d32pscalef32_f16_neon( | |
68 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, | ||
69 | size_t lhs_stride, void* lhs_packed) { | ||
70 | − | KAI_ASSERT((kr % sr) == 0); | |
71 | − | KAI_ASSUME((bl % kr) == 0); | |
72 | − | KAI_ASSUME((k % bl) == 0); | |
73 | − | KAI_ASSUME((bl % kai_bl_multiple_of) == 0); | |
74 | |||
75 |
2/2✓ Branch 0 taken 10456 times.
✓ Branch 1 taken 448 times.
|
10904 | if (m == 0) { |
76 | 448 | return; | |
77 | } | ||
78 | 10456 | const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, mr, kr, bl); | |
79 | 10456 | const size_t num_rows = m; | |
80 | 10456 | const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); | |
81 | 10456 | const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl); | |
82 | 10456 | const size_t mr_block_size = mr * num_bytes_per_block; | |
83 | |||
84 | 10456 | const int32_t k_block_len = (int32_t)(kr / sr); | |
85 | |||
86 |
2/2✓ Branch 0 taken 113000 times.
✓ Branch 1 taken 10456 times.
|
123456 | for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { |
87 | 113000 | const float16_t* row_src_ptr = (const float16_t*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride); | |
88 | 113000 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
89 |
2/2✓ Branch 0 taken 163344 times.
✓ Branch 1 taken 113000 times.
|
276344 | for (size_t b = 0; b < num_blocks_per_row; ++b) { |
90 | 163344 | const float16_t* src_ptr = row_src_ptr + b * bl; | |
91 | 163344 | int8_t* dst_ptr = (int8_t*)lhs_packed + dst_x * k_block_len * sizeof(int8_t) + b * mr_block_size; | |
92 | 163344 | int8_t* param_ptr = (int8_t*)lhs_packed + b * mr_block_size + bl * mr + dst_x * kai_num_bytes_sum; | |
93 | // Find absmax for each block | ||
94 | 163344 | float16_t absmax = (float16_t)(-FLT16_MAX); | |
95 | 163344 | int32_t k_idx = 0; | |
96 | 163344 | float16x8_t vabsmax = vdupq_n_f16(-FLT16_MAX); | |
97 |
2/2✓ Branch 0 taken 850144 times.
✓ Branch 1 taken 163344 times.
|
1013488 | for (; k_idx < ((int32_t)bl); k_idx += 8) { |
98 | 850144 | const float16x8_t src = vabsq_f16(vld1q_f16(src_ptr + (size_t)k_idx)); | |
99 | 850144 | vabsmax = vmaxq_f16(vabsmax, src); | |
100 | 850144 | } | |
101 | // Get the absmax | ||
102 | 163344 | absmax = vmaxvq_f16(vabsmax); | |
103 | // Maximum/minimum int8 values | ||
104 | 163344 | const float qmax = (float)INT8_MAX; | |
105 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 163344 times.
|
163344 | const float scale0 = absmax == 0.0F ? 0.0F : qmax / absmax; |
106 | // Reciprocal to quantize | ||
107 |
1/2✓ Branch 0 taken 163344 times.
✗ Branch 1 not taken.
|
163344 | const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; |
108 | 163344 | int32_t qsum = 0; | |
109 | // Quantize the blocks | ||
110 |
2/2✓ Branch 0 taken 1494760 times.
✓ Branch 1 taken 163344 times.
|
1658104 | for (k_idx = 0; k_idx < (int32_t)bl; k_idx += k_block_len) { |
111 |
2/2✓ Branch 0 taken 6801152 times.
✓ Branch 1 taken 1494760 times.
|
8295912 | for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) { |
112 | // Clamp at the last valid k-index | ||
113 |
2/2✓ Branch 0 taken 6737344 times.
✓ Branch 1 taken 63808 times.
|
6801152 | const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k - 1); |
114 | |||
115 | 6801152 | const float16_t src0_0 = *(src_ptr + k_idx_start); | |
116 | |||
117 | // Scale the values | ||
118 | 6801152 | int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); | |
119 | |||
120 |
1/2✓ Branch 0 taken 6801152 times.
✗ Branch 1 not taken.
|
6801152 | v0_s32 = KAI_MAX(v0_s32, INT8_MIN); |
121 |
2/2✓ Branch 0 taken 6717584 times.
✓ Branch 1 taken 83568 times.
|
6801152 | v0_s32 = KAI_MIN(v0_s32, INT8_MAX); |
122 | 6801152 | qsum += v0_s32; | |
123 | |||
124 | 6801152 | *(dst_ptr) = (int8_t)v0_s32; | |
125 | 6801152 | dst_ptr += sizeof(int8_t); | |
126 | 6801152 | } | |
127 | 1494760 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
128 | 1494760 | } | |
129 | 163344 | *((float*)(param_ptr)) = ((float)qsum) * recip_scale0; | |
130 | 163344 | param_ptr += mr * kai_num_bytes_sum; | |
131 | 163344 | *((float*)(param_ptr)) = recip_scale0; | |
132 | 163344 | } | |
133 | // Move to the next row if we have interleaved all Mr rows | ||
134 |
2/2✓ Branch 0 taken 62092 times.
✓ Branch 1 taken 50908 times.
|
113000 | if ((((row_idx + 1) + m_idx_start) % mr) == 0) { |
135 | 50908 | lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride); | |
136 | 50908 | } | |
137 | 113000 | } | |
138 | 10904 | } | |
139 | #endif // Architectural features check. | ||
140 |