Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | #include "kai_lhs_quant_pack_qsi8d32p_f32.h" | ||
7 | |||
8 | #include <math.h> | ||
9 | #include <stddef.h> | ||
10 | #include <stdint.h> | ||
11 | |||
12 | #include "kai/kai_common.h" | ||
13 | |||
14 | static const size_t kai_num_bytes_multiplier = sizeof(uint16_t); | ||
15 | |||
16 | 415 | inline static size_t kai_num_bytes_per_block(size_t bl) { | |
17 | 415 | return bl * sizeof(int8_t) + kai_num_bytes_multiplier; | |
18 | } | ||
19 | |||
20 | 415 | inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { | |
21 | − | KAI_ASSERT((k % bl) == 0); | |
22 | 415 | return k / bl; | |
23 | } | ||
24 | |||
25 | 340 | inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) { | |
26 | 340 | KAI_UNUSED(kr); | |
27 | 340 | return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl); | |
28 | } | ||
29 | |||
30 | ✗ | size_t kai_get_m_step_lhs_quant_pack_qsi8d32p_f32(size_t mr) { | |
31 | ✗ | KAI_UNUSED(mr); | |
32 | ✗ | return 1; | |
33 | } | ||
34 | |||
35 | 75 | size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32(size_t m_idx, size_t lhs_stride) { | |
36 | 75 | return m_idx * lhs_stride; | |
37 | } | ||
38 | |||
39 | 190 | size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32( | |
40 | size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
41 | − | KAI_ASSUME((k % 2) == 0); | |
42 | − | KAI_ASSUME((k % kr) == 0); | |
43 | − | KAI_ASSUME((k % bl) == 0); | |
44 | |||
45 | 190 | KAI_UNUSED(sr); | |
46 | 190 | KAI_UNUSED(kr); | |
47 | |||
48 | 190 | return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl); | |
49 | } | ||
50 | |||
51 | 75 | size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32( | |
52 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||
53 | − | KAI_ASSUME((k % 2) == 0); | |
54 | − | KAI_ASSUME((k % kr) == 0); | |
55 | − | KAI_ASSUME((k % bl) == 0); | |
56 | |||
57 | 75 | KAI_UNUSED(sr); | |
58 | 75 | KAI_UNUSED(kr); | |
59 | |||
60 | 75 | const size_t num_rows = kai_roundup(m, mr) / mr; | |
61 | |||
62 | 150 | return num_rows * kai_lhs_packed_stride(k, mr, kr, bl); | |
63 | 75 | } | |
64 | |||
65 | 75 | void kai_run_lhs_quant_pack_qsi8d32p_f32( | |
66 | size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, | ||
67 | size_t lhs_stride, void* lhs_packed) { | ||
68 | 1/2 ✓ Branch 0 taken 75 times. ✗ Branch 1 not taken. | 75 | if (m == 0) { |
69 | ✗ | return; | |
70 | } | ||
71 | |||
72 | 75 | const size_t num_rows = m; | |
73 | 75 | const size_t k_block_len = kr / sr; | |
74 | 75 | const size_t lhs_packed_stride = kai_lhs_packed_stride(k, mr, kr, bl); | |
75 | 75 | const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); | |
76 | 75 | const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); | |
77 | |||
78 | 2/2 ✓ Branch 0 taken 1791 times. ✓ Branch 1 taken 75 times. | 1866 | for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { |
79 | 1791 | const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride); | |
80 | |||
81 | 2/2 ✓ Branch 0 taken 3387 times. ✓ Branch 1 taken 1791 times. | 5178 | for (size_t b = 0; b < num_blocks_per_row; ++b) { |
82 | 3387 | float abs_max = 0.0F; | |
83 | |||
84 | 3387 | const size_t dst_x = ((row_idx + m_idx_start) % mr); | |
85 | 3387 | int8_t* dst_ptr = (int8_t*)lhs_packed + (b * mr) * num_bytes_per_block; | |
86 | |||
87 | 2/2 ✓ Branch 0 taken 108384 times. ✓ Branch 1 taken 3387 times. | 111771 | for (size_t idx_v = 0; idx_v < bl; ++idx_v) { |
88 | 108384 | const float val = src_ptr[idx_v]; | |
89 | 2/2 ✓ Branch 0 taken 94431 times. ✓ Branch 1 taken 13953 times. | 108384 | abs_max = KAI_MAX(abs_max, fabsf(val)); |
90 | 108384 | } | |
91 | |||
92 | // Calculate scale and reciprocal | ||
93 | 3387 | const float scale = abs_max / ((1 << 7) - 1); | |
94 | 1/2 ✓ Branch 0 taken 3387 times. ✗ Branch 1 not taken. | 3387 | const float rep_scale = scale ? 1.0F / scale : 0.0F; |
95 | |||
96 | 3387 | *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale); | |
97 | 3387 | dst_ptr += mr * kai_num_bytes_multiplier; | |
98 | |||
99 | 3387 | dst_ptr += dst_x * k_block_len * sizeof(int8_t); | |
100 | |||
101 | // Quantize and pack the block | ||
102 | 2/2 ✓ Branch 0 taken 18068 times. ✓ Branch 1 taken 3387 times. | 21455 | for (size_t k_idx = 0; k_idx < bl; k_idx += k_block_len) { |
103 | 2/2 ✓ Branch 0 taken 108384 times. ✓ Branch 1 taken 18068 times. | 126452 | for (size_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { |
104 | // Clamp at the last valid k-index | ||
105 | 2/2 ✓ Branch 0 taken 108189 times. ✓ Branch 1 taken 195 times. | 108384 | const size_t k_idx_start = KAI_MIN(k_idx + k_block_idx, k - 1); |
106 | |||
107 | 108384 | const float src0_0 = *(src_ptr + k_idx_start); | |
108 | |||
109 | // Scale the values | ||
110 | 108384 | int32_t v0_s32 = (int32_t)(roundf(src0_0 * rep_scale)); | |
111 | |||
112 | 108384 | *dst_ptr = (int8_t)v0_s32; | |
113 | 108384 | dst_ptr += sizeof(int8_t); | |
114 | 108384 | } | |
115 | 18068 | dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); | |
116 | 18068 | } | |
117 | |||
118 | 3387 | src_ptr += bl; | |
119 | 3387 | } | |
120 | // Move to the next row if we have interleaved all Mr rows | ||
121 | 2/2 ✓ Branch 0 taken 1353 times. ✓ Branch 1 taken 438 times. | 1791 | if ((((row_idx + 1) + m_idx_start) % mr) == 0) { |
122 | 438 | lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride); | |
123 | 438 | } | |
124 | 1791 | } | |
125 | 75 | } | |
126 |