KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 90.6% 48 / 16 / 69
Functions: 60.0% 3 / 0 / 5
Branches: 100.0% 10 / 32 / 42

kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
8 #error This file must be compiled for AArch64, FEAT_BF16.
9 #else // Architectural features check.
10
11 #include "kai_lhs_quant_pack_bf16p1x4_f32_neon.h"
12
13 #include <arm_neon.h>
14 #include <stddef.h>
15 #include <stdint.h>
16
17 #include "kai/kai_common.h"
18
19 static const size_t kai_mr = 1;
20 static const size_t kai_kr = 4;
21 static const size_t kai_sr = 1;
22
23 size_t kai_get_m_step_lhs_quant_pack_bf16p1x4_f32_neon(size_t mr) {
24 KAI_ASSUME(mr == kai_mr);
25 return mr;
26 }
27
28 576 size_t kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon(size_t m_idx, size_t lhs_stride) {
29 KAI_ASSUME(m_idx % kai_mr == 0);
30 576 return m_idx * lhs_stride;
31 }
32
33 size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p1x4_f32_neon(
34 size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
35 KAI_UNUSED(sr);
36 KAI_ASSUME(m_idx == 0);
37 KAI_ASSUME(mr == kai_mr);
38 KAI_ASSUME(kr == kai_kr);
39 KAI_ASSUME(sr == kai_sr);
40
41 return m_idx * kai_roundup(k, kr) * sizeof(uint16_t);
42 }
43
44 576 size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
45 576 KAI_UNUSED(sr);
46 KAI_ASSUME(mr == kai_mr);
47 KAI_ASSUME(kr == kai_kr);
48 KAI_ASSUME(sr == kai_sr);
49
50 576 return kai_roundup(m, mr) * kai_roundup(k, kr) * sizeof(uint16_t);
51 }
52
53 576 void kai_run_lhs_quant_pack_bf16p1x4_f32_neon(
54 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
55 void* lhs_packed) {
56 576 KAI_UNUSED(sr);
57 576 KAI_UNUSED(lhs_stride);
58
59 KAI_ASSUME(m == 1);
60 KAI_ASSUME(mr == kai_mr);
61 KAI_ASSUME(kr == kai_kr);
62 KAI_ASSUME(sr == kai_sr);
63
64 KAI_ASSUME(lhs != NULL);
65 KAI_ASSUME(lhs_packed != NULL);
66
67 KAI_ASSUME(m_idx_start == 0);
68
69 576 const float* lhs_ptr = lhs;
70 576 uint16_t* lhs_packed_ptr = lhs_packed;
71
72 // Unroll two 256-bit loops
73 576 size_t i = 0;
74
2/2
✓ Branch 0 taken 11760 times.
✓ Branch 1 taken 576 times.
12336 for (; i + 16 <= k; i += 16) {
75 11760 const float32x4x4_t val = vld1q_f32_x4(lhs_ptr);
76 11760 bfloat16x8x2_t bf_val;
77
78 11760 bf_val.val[0] = vcvtq_low_bf16_f32(val.val[0]);
79 11760 bf_val.val[0] = vcvtq_high_bf16_f32(bf_val.val[0], val.val[1]);
80 11760 bf_val.val[1] = vcvtq_low_bf16_f32(val.val[2]);
81 11760 bf_val.val[1] = vcvtq_high_bf16_f32(bf_val.val[1], val.val[3]);
82 11760 vst1q_bf16_x2((bfloat16_t*)(lhs_packed_ptr), bf_val);
83
84 11760 lhs_ptr += 16;
85 11760 lhs_packed_ptr += 16;
86 11760 }
87
88 // 1 load + 1 convert + 1 store
89
2/2
✓ Branch 0 taken 192 times.
✓ Branch 1 taken 576 times.
768 for (; i + 8 <= k; i += 8) {
90 192 const float32x4x2_t f32_val = vld1q_f32_x2(lhs_ptr);
91 192 bfloat16x8_t bf_val = vcvtq_low_bf16_f32(f32_val.val[0]);
92 192 bf_val = vcvtq_high_bf16_f32(bf_val, f32_val.val[1]);
93 192 vst1q_bf16((bfloat16_t*)(lhs_packed_ptr), bf_val);
94
95 192 lhs_ptr += 8;
96 192 lhs_packed_ptr += 8;
97 192 }
98
99
2/2
✓ Branch 0 taken 240 times.
✓ Branch 1 taken 576 times.
816 for (; i + 4 <= k; i += 4) {
100 240 const float32x4_t f32_val = vld1q_f32(lhs_ptr);
101 240 bfloat16x4_t bf_val = vcvt_bf16_f32(f32_val);
102 240 vst1_bf16((bfloat16_t*)(lhs_packed_ptr), bf_val);
103
104 240 lhs_ptr += 4;
105 240 lhs_packed_ptr += 4;
106 240 }
107
108
2/2
✓ Branch 0 taken 816 times.
✓ Branch 1 taken 576 times.
1392 for (; i < k; ++i) {
109 816 *lhs_packed_ptr = kai_cast_bf16_f32(*lhs_ptr);
110
111 816 ++lhs_ptr;
112 816 ++lhs_packed_ptr;
113 816 }
114
115 // Zero pad
116 576 const size_t rounded_up_k = kai_roundup(k, kr);
117
2/2
✓ Branch 0 taken 912 times.
✓ Branch 1 taken 576 times.
1488 for (; i < rounded_up_k; ++i) {
118 912 *lhs_packed_ptr = 0;
119 912 ++lhs_packed_ptr;
120 912 }
121 576 }
122
123 #endif // Architectural features check.
124