kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.c
| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | // | ||
| 2 | // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
| 3 | // | ||
| 4 | // SPDX-License-Identifier: Apache-2.0 | ||
| 5 | // | ||
| 6 | |||
| 7 | #if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) | ||
| 8 | #error This file must be compiled for AArch64, FEAT_BF16. | ||
| 9 | #else // Architectural features check. | ||
| 10 | |||
| 11 | #include "kai_lhs_quant_pack_bf16p1x4_f32_neon.h" | ||
| 12 | |||
| 13 | #include <arm_neon.h> | ||
| 14 | #include <stddef.h> | ||
| 15 | #include <stdint.h> | ||
| 16 | |||
| 17 | #include "kai/kai_common.h" | ||
| 18 | |||
| 19 | static const size_t kai_mr = 1; | ||
| 20 | static const size_t kai_kr = 4; | ||
| 21 | static const size_t kai_sr = 1; | ||
| 22 | |||
| 23 | ✗ | size_t kai_get_m_step_lhs_quant_pack_bf16p1x4_f32_neon(size_t mr) { | |
| 24 | − | KAI_ASSUME(mr == kai_mr); | |
| 25 | ✗ | return mr; | |
| 26 | } | ||
| 27 | |||
| 28 | 576 | size_t kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon(size_t m_idx, size_t lhs_stride) { | |
| 29 | − | KAI_ASSUME(m_idx % kai_mr == 0); | |
| 30 | 576 | return m_idx * lhs_stride; | |
| 31 | } | ||
| 32 | |||
| 33 | ✗ | size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p1x4_f32_neon( | |
| 34 | size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { | ||
| 35 | ✗ | KAI_UNUSED(sr); | |
| 36 | − | KAI_ASSUME(m_idx == 0); | |
| 37 | − | KAI_ASSUME(mr == kai_mr); | |
| 38 | − | KAI_ASSUME(kr == kai_kr); | |
| 39 | − | KAI_ASSUME(sr == kai_sr); | |
| 40 | |||
| 41 | ✗ | return m_idx * kai_roundup(k, kr) * sizeof(uint16_t); | |
| 42 | } | ||
| 43 | |||
| 44 | 576 | size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { | |
| 45 | 576 | KAI_UNUSED(sr); | |
| 46 | − | KAI_ASSUME(mr == kai_mr); | |
| 47 | − | KAI_ASSUME(kr == kai_kr); | |
| 48 | − | KAI_ASSUME(sr == kai_sr); | |
| 49 | |||
| 50 | 576 | return kai_roundup(m, mr) * kai_roundup(k, kr) * sizeof(uint16_t); | |
| 51 | } | ||
| 52 | |||
| 53 | 576 | void kai_run_lhs_quant_pack_bf16p1x4_f32_neon( | |
| 54 | size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||
| 55 | void* lhs_packed) { | ||
| 56 | 576 | KAI_UNUSED(sr); | |
| 57 | 576 | KAI_UNUSED(lhs_stride); | |
| 58 | |||
| 59 | − | KAI_ASSUME(m == 1); | |
| 60 | − | KAI_ASSUME(mr == kai_mr); | |
| 61 | − | KAI_ASSUME(kr == kai_kr); | |
| 62 | − | KAI_ASSUME(sr == kai_sr); | |
| 63 | |||
| 64 | − | KAI_ASSUME(lhs != NULL); | |
| 65 | − | KAI_ASSUME(lhs_packed != NULL); | |
| 66 | |||
| 67 | − | KAI_ASSUME(m_idx_start == 0); | |
| 68 | |||
| 69 | 576 | const float* lhs_ptr = lhs; | |
| 70 | 576 | uint16_t* lhs_packed_ptr = lhs_packed; | |
| 71 | |||
| 72 | // Unroll two 256-bit loops | ||
| 73 | 576 | size_t i = 0; | |
| 74 | 2/2 ✓ Branch 0 taken 11760 times. ✓ Branch 1 taken 576 times. | 12336 | for (; i + 16 <= k; i += 16) { |
| 75 | 11760 | const float32x4x4_t val = vld1q_f32_x4(lhs_ptr); | |
| 76 | 11760 | bfloat16x8x2_t bf_val; | |
| 77 | |||
| 78 | 11760 | bf_val.val[0] = vcvtq_low_bf16_f32(val.val[0]); | |
| 79 | 11760 | bf_val.val[0] = vcvtq_high_bf16_f32(bf_val.val[0], val.val[1]); | |
| 80 | 11760 | bf_val.val[1] = vcvtq_low_bf16_f32(val.val[2]); | |
| 81 | 11760 | bf_val.val[1] = vcvtq_high_bf16_f32(bf_val.val[1], val.val[3]); | |
| 82 | 11760 | vst1q_bf16_x2((bfloat16_t*)(lhs_packed_ptr), bf_val); | |
| 83 | |||
| 84 | 11760 | lhs_ptr += 16; | |
| 85 | 11760 | lhs_packed_ptr += 16; | |
| 86 | 11760 | } | |
| 87 | |||
| 88 | // 1 load + 1 convert + 1 store | ||
| 89 | 2/2 ✓ Branch 0 taken 192 times. ✓ Branch 1 taken 576 times. | 768 | for (; i + 8 <= k; i += 8) { |
| 90 | 192 | const float32x4x2_t f32_val = vld1q_f32_x2(lhs_ptr); | |
| 91 | 192 | bfloat16x8_t bf_val = vcvtq_low_bf16_f32(f32_val.val[0]); | |
| 92 | 192 | bf_val = vcvtq_high_bf16_f32(bf_val, f32_val.val[1]); | |
| 93 | 192 | vst1q_bf16((bfloat16_t*)(lhs_packed_ptr), bf_val); | |
| 94 | |||
| 95 | 192 | lhs_ptr += 8; | |
| 96 | 192 | lhs_packed_ptr += 8; | |
| 97 | 192 | } | |
| 98 | |||
| 99 | 2/2 ✓ Branch 0 taken 240 times. ✓ Branch 1 taken 576 times. | 816 | for (; i + 4 <= k; i += 4) { |
| 100 | 240 | const float32x4_t f32_val = vld1q_f32(lhs_ptr); | |
| 101 | 240 | bfloat16x4_t bf_val = vcvt_bf16_f32(f32_val); | |
| 102 | 240 | vst1_bf16((bfloat16_t*)(lhs_packed_ptr), bf_val); | |
| 103 | |||
| 104 | 240 | lhs_ptr += 4; | |
| 105 | 240 | lhs_packed_ptr += 4; | |
| 106 | 240 | } | |
| 107 | |||
| 108 | 2/2 ✓ Branch 0 taken 816 times. ✓ Branch 1 taken 576 times. | 1392 | for (; i < k; ++i) { |
| 109 | 816 | *lhs_packed_ptr = kai_cast_bf16_f32(*lhs_ptr); | |
| 110 | |||
| 111 | 816 | ++lhs_ptr; | |
| 112 | 816 | ++lhs_packed_ptr; | |
| 113 | 816 | } | |
| 114 | |||
| 115 | // Zero pad | ||
| 116 | 576 | const size_t rounded_up_k = kai_roundup(k, kr); | |
| 117 | 2/2 ✓ Branch 0 taken 912 times. ✓ Branch 1 taken 576 times. | 1488 | for (; i < rounded_up_k; ++i) { |
| 118 | 912 | *lhs_packed_ptr = 0; | |
| 119 | 912 | ++lhs_packed_ptr; | |
| 120 | 912 | } | |
| 121 | 576 | } | |
| 122 | |||
| 123 | #endif // Architectural features check. | ||
| 124 |