Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if !defined(__aarch64__) && !defined(_M_ARM64) | ||
8 | #error This file must be compiled for AArch64. | ||
9 | #else // Architectural features check. | ||
10 | |||
11 | #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" | ||
12 | |||
13 | #include <math.h> | ||
14 | #include <stddef.h> | ||
15 | #include <stdint.h> | ||
16 | |||
17 | #include "kai/kai_common.h" | ||
18 | |||
19 | static const size_t kai_num_bytes_multiplier = sizeof(uint16_t); | ||
20 | |||
21 | 124 | inline static size_t kai_num_bytes_per_block(size_t bl) { | |
22 | 124 | return bl * sizeof(int8_t) + kai_num_bytes_multiplier; | |
23 | } | ||
24 | |||
25 | 150 | inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { | |
26 | − | KAI_ASSERT((k % bl) == 0); | |
27 | 150 | return k / bl; | |
28 | } | ||
29 | |||
30 | 124 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) { | |
31 | 124 | KAI_UNUSED(kr); | |
32 | |||
33 | 124 | return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl); | |
34 | } | ||
35 | |||
36 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qsi8d32p_f32_neon(size_t mr) { | |
37 | ✗ | KAI_UNUSED(mr); | |
38 | ✗ | return 1; | |
39 | } | ||
40 | |||
41 | 26 | size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon(size_t m_idx, size_t lhs_stride) { | |
42 | 26 | return m_idx * lhs_stride; | |
43 | } | ||
44 | |||
45 | 72 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon( | |
46 | size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
47 | − | KAI_ASSUME((k % 2) == 0); | |
48 | − | KAI_ASSUME((k % kr) == 0); | |
49 | − | KAI_ASSUME((k % bl) == 0); | |
50 | − | KAI_ASSUME((m_idx % mr) == 0); | |
51 | |||
52 | 72 | KAI_UNUSED(sr); | |
53 | 72 | KAI_UNUSED(kr); | |
54 | |||
55 | // The scales are stored after all the mr packed quantized values | ||
56 | 72 | return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl); | |
57 | } | ||
58 | |||
59 | 26 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon( | |
60 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
61 | − | KAI_ASSUME((k % 2) == 0); | |
62 | − | KAI_ASSUME((k % kr) == 0); | |
63 | − | KAI_ASSUME((k % bl) == 0); | |
64 | |||
65 | 26 | KAI_UNUSED(sr); | |
66 | 26 | KAI_UNUSED(kr); | |
67 | |||
68 | 26 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
69 | |||
70 | 52 | return (num_rows * kai_lhs_packed_stride(k, mr, kr, bl)); | |
71 | 26 | } | |
72 | |||
73 | 26 | void kai_run_lhs_quant_pack_qsi8d32p_f32_neon( | |
74 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, | ||
75 | size_t lhs_stride, void* lhs_packed) { | ||
76 | − | KAI_ASSUME((bl % kr) == 0); | |
77 | − | KAI_ASSUME((k % bl) == 0); | |
78 | − | KAI_ASSUME(kr == 4); | |
79 | − | KAI_ASSUME(bl == 32); | |
80 | 26 | KAI_UNUSED(sr); | |
81 | 26 | KAI_UNUSED(m_idx_start); | |
82 | 26 | KAI_UNUSED(lhs_stride); | |
83 | |||
84 | 1/2 ✓ Branch 0 taken 26 times. ✗ Branch 1 not taken. | 26 | if (m == 0) { |
85 | ✗ | return; | |
86 | } | ||
87 | |||
88 | 26 | const size_t num_blocks = kai_num_blocks_per_row(k, bl); | |
89 | 26 | const size_t lhs_packed_stride = kai_lhs_packed_stride(k, mr, kr, bl); | |
90 | |||
91 | 26 | const float* lhs_ptr = lhs; | |
92 | 26 | int8_t* lhs_packed_start_ptr = lhs_packed; | |
93 | |||
94 | 2/2 ✓ Branch 0 taken 598 times. ✓ Branch 1 taken 26 times. | 624 | for (size_t m_idx = 0; m_idx < m; m_idx++) { |
95 | 598 | int8_t* lhs_packed_ptr = lhs_packed_start_ptr; | |
96 | 1196 | uint16_t* lhs_packed_scales = | |
97 | 598 | (uint16_t*)(lhs_packed_ptr + lhs_packed_stride - ((mr * num_blocks) * kai_num_bytes_multiplier)); | |
98 | |||
99 | 598 | lhs_packed_ptr += (m_idx % mr) * kr; | |
100 | 598 | lhs_packed_scales += (m_idx % mr); | |
101 | |||
102 | 2/2 ✓ Branch 0 taken 1130 times. ✓ Branch 1 taken 598 times. | 1728 | for (size_t block_idx = 0; block_idx < num_blocks; block_idx++) { |
103 | // Maximum absolute value of the block elements | ||
104 | 1130 | float amax = 0.0F; | |
105 | |||
106 | 2/2 ✓ Branch 0 taken 36160 times. ✓ Branch 1 taken 1130 times. | 37290 | for (size_t bl_idx = 0; bl_idx < bl; bl_idx++) { |
107 | 2/2 ✓ Branch 0 taken 31502 times. ✓ Branch 1 taken 4658 times. | 36160 | amax = KAI_MAX(amax, fabsf(lhs_ptr[bl_idx])); |
108 | 36160 | } | |
109 | |||
110 | 1130 | const float sf = amax / ((1 << 7) - 1); | |
111 | |||
112 | 1/2 ✓ Branch 0 taken 1130 times. ✗ Branch 1 not taken. | 1130 | const float sf_inv = sf ? 1.0F / sf : 0.0F; |
113 | |||
114 | 2/2 ✓ Branch 0 taken 9040 times. ✓ Branch 1 taken 1130 times. | 10170 | for (size_t bl_idx = 0; bl_idx < bl; bl_idx += kr) { |
115 | 2/2 ✓ Branch 0 taken 36160 times. ✓ Branch 1 taken 9040 times. | 45200 | for (size_t kr_idx = 0; kr_idx < kr; ++kr_idx) { |
116 | 36160 | int32_t v0_s32 = (int32_t)(roundf(lhs_ptr[kr_idx] * sf_inv)); | |
117 | 36160 | lhs_packed_ptr[kr_idx] = (int8_t)v0_s32; | |
118 | 36160 | } | |
119 | 9040 | lhs_ptr += kr; | |
120 | 9040 | lhs_packed_ptr += mr * kr; | |
121 | 9040 | } | |
122 | |||
123 | // Num_blocks (rows) x Mr (cols) | ||
124 | 1130 | lhs_packed_scales[0] = kai_cast_f16_f32(sf); | |
125 | |||
126 | 1130 | lhs_packed_scales += mr; | |
127 | 1130 | } | |
128 | 2/2 ✓ Branch 0 taken 567 times. ✓ Branch 1 taken 31 times. | 598 | if (((m_idx + 1) % mr) == 0) { |
129 | 31 | lhs_packed_start_ptr += lhs_packed_stride; | |
130 | 31 | } | |
131 | 598 | } | |
132 | 26 | } | |
133 | #endif // Architectural features check. | ||
134 |