Line | Branch | Exec | Source |
---|---|---|---|
1 | // | ||
2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
3 | // | ||
4 | // SPDX-License-Identifier: Apache-2.0 | ||
5 | // | ||
6 | |||
7 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) | ||
8 | #error This file must be compiled for AArch64, FEAT_BF16. | ||
9 | #else // Architectural features check. | ||
10 | |||
11 | #include "kai_lhs_quant_pack_bf16p1x4_f32_neon.h" | ||
12 | |||
13 | #include <arm_neon.h> | ||
14 | #include <stddef.h> | ||
15 | #include <stdint.h> | ||
16 | |||
17 | #include "kai/kai_common.h" | ||
18 | |||
19 | static const size_t kai_mr = 1; | ||
20 | static const size_t kai_kr = 4; | ||
21 | static const size_t kai_sr = 1; | ||
22 | |||
23 | ✗ | size_t kai_get_m_step_lhs_quant_pack_bf16p1x4_f32_neon(size_t mr) { | |
24 | − | KAI_ASSUME(mr == kai_mr); | |
25 | ✗ | return mr; | |
26 | } | ||
27 | |||
28 | 92 | size_t kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon(size_t m_idx, size_t lhs_stride) { | |
29 | − | KAI_ASSUME(m_idx % kai_mr == 0); | |
30 | 92 | return m_idx * lhs_stride; | |
31 | } | ||
32 | |||
33 | ✗ | size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p1x4_f32_neon( | |
34 | size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | ||
35 | ✗ | KAI_UNUSED(sr); | |
36 | − | KAI_ASSUME(m_idx == 0); | |
37 | − | KAI_ASSUME(mr == kai_mr); | |
38 | − | KAI_ASSUME(kr == kai_kr); | |
39 | − | KAI_ASSUME(sr == kai_sr); | |
40 | |||
41 | ✗ | return m_idx * kai_roundup(k, kr) * sizeof(uint16_t); | |
42 | } | ||
43 | |||
44 | 92 | size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
45 | 92 | KAI_UNUSED(sr); | |
46 | − | KAI_ASSUME(mr == kai_mr); | |
47 | − | KAI_ASSUME(kr == kai_kr); | |
48 | − | KAI_ASSUME(sr == kai_sr); | |
49 | |||
50 | 92 | return kai_roundup(m, mr) * kai_roundup(k, kr) * sizeof(uint16_t); | |
51 | } | ||
52 | |||
53 | 92 | void kai_run_lhs_quant_pack_bf16p1x4_f32_neon( | |
54 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||
55 | void* lhs_packed) { | ||
56 | 92 | KAI_UNUSED(sr); | |
57 | 92 | KAI_UNUSED(lhs_stride); | |
58 | |||
59 | − | KAI_ASSUME(m == 1); | |
60 | − | KAI_ASSUME(mr == kai_mr); | |
61 | − | KAI_ASSUME(kr == kai_kr); | |
62 | − | KAI_ASSUME(sr == kai_sr); | |
63 | |||
64 | − | KAI_ASSUME(lhs != NULL); | |
65 | − | KAI_ASSUME(lhs_packed != NULL); | |
66 | |||
67 | − | KAI_ASSUME(m_idx_start == 0); | |
68 | |||
69 | 92 | const float* lhs_ptr = lhs; | |
70 | 92 | uint16_t* lhs_packed_ptr = lhs_packed; | |
71 | |||
72 | // Unroll two 256-bit loops | ||
73 | 92 | size_t i = 0; | |
74 | 2/2: ✓ Branch 0 taken 1834 times. ✓ Branch 1 taken 92 times. | 1926 | for (; i + 16 <= k; i += 16) { |
75 | 1834 | const float32x4x4_t val = vld1q_f32_x4(lhs_ptr); | |
76 | 1834 | bfloat16x8x2_t bf_val; | |
77 | |||
78 | 1834 | bf_val.val[0] = vcvtq_low_bf16_f32(val.val[0]); | |
79 | 1834 | bf_val.val[0] = vcvtq_high_bf16_f32(bf_val.val[0], val.val[1]); | |
80 | 1834 | bf_val.val[1] = vcvtq_low_bf16_f32(val.val[2]); | |
81 | 1834 | bf_val.val[1] = vcvtq_high_bf16_f32(bf_val.val[1], val.val[3]); | |
82 | 1834 | vst1q_bf16_x2((bfloat16_t*)(lhs_packed_ptr), bf_val); | |
83 | |||
84 | 1834 | lhs_ptr += 16; | |
85 | 1834 | lhs_packed_ptr += 16; | |
86 | 1834 | } | |
87 | |||
88 | // 1 load + 1 convert + 1 store | ||
89 | 2/2: ✓ Branch 0 taken 30 times. ✓ Branch 1 taken 92 times. | 122 | for (; i + 8 <= k; i += 8) { |
90 | 30 | const float32x4x2_t f32_val = vld1q_f32_x2(lhs_ptr); | |
91 | 30 | bfloat16x8_t bf_val = vcvtq_low_bf16_f32(f32_val.val[0]); | |
92 | 30 | bf_val = vcvtq_high_bf16_f32(bf_val, f32_val.val[1]); | |
93 | 30 | vst1q_bf16((bfloat16_t*)(lhs_packed_ptr), bf_val); | |
94 | |||
95 | 30 | lhs_ptr += 8; | |
96 | 30 | lhs_packed_ptr += 8; | |
97 | 30 | } | |
98 | |||
99 | 2/2: ✓ Branch 0 taken 38 times. ✓ Branch 1 taken 92 times. | 130 | for (; i + 4 <= k; i += 4) { |
100 | 38 | const float32x4_t f32_val = vld1q_f32(lhs_ptr); | |
101 | 38 | bfloat16x4_t bf_val = vcvt_bf16_f32(f32_val); | |
102 | 38 | vst1_bf16((bfloat16_t*)(lhs_packed_ptr), bf_val); | |
103 | |||
104 | 38 | lhs_ptr += 4; | |
105 | 38 | lhs_packed_ptr += 4; | |
106 | 38 | } | |
107 | |||
108 | 2/2: ✓ Branch 0 taken 128 times. ✓ Branch 1 taken 92 times. | 220 | for (; i < k; ++i) { |
109 | 128 | *lhs_packed_ptr = kai_cast_bf16_f32(*lhs_ptr); | |
110 | |||
111 | 128 | ++lhs_ptr; | |
112 | 128 | ++lhs_packed_ptr; | |
113 | 128 | } | |
114 | |||
115 | // Zero pad | ||
116 | 92 | const size_t rounded_up_k = kai_roundup(k, kr); | |
117 | 2/2: ✓ Branch 0 taken 144 times. ✓ Branch 1 taken 92 times. | 236 | for (; i < rounded_up_k; ++i) { |
118 | 144 | *lhs_packed_ptr = 0; | |
119 | 144 | ++lhs_packed_ptr; | |
120 | 144 | } | |
121 | 92 | } | |
122 | |||
123 | #endif // Architectural features check. | ||
124 |