KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 90.6% 48 16 69
Functions: 60.0% 3 0 5
Branches: 100.0% 10 32 42

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
8 #error This file must be compiled for AArch64, FEAT_BF16.
9 #else // Architectural features check.
10
11 #include "kai_lhs_quant_pack_bf16p1x4_f32_neon.h"
12
13 #include <arm_neon.h>
14 #include <stddef.h>
15 #include <stdint.h>
16
17 #include "kai/kai_common.h"
18
19 static const size_t kai_mr = 1;
20 static const size_t kai_kr = 4;
21 static const size_t kai_sr = 1;
22
23 size_t kai_get_m_step_lhs_quant_pack_bf16p1x4_f32_neon(size_t mr) {
24 KAI_ASSUME(mr == kai_mr);
25 return mr;
26 }
27
28 92 size_t kai_get_lhs_offset_lhs_quant_pack_bf16p1x4_f32_neon(size_t m_idx, size_t lhs_stride) {
29 KAI_ASSUME(m_idx % kai_mr == 0);
30 92 return m_idx * lhs_stride;
31 }
32
33 size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p1x4_f32_neon(
34 size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
35 KAI_UNUSED(sr);
36 KAI_ASSUME(m_idx == 0);
37 KAI_ASSUME(mr == kai_mr);
38 KAI_ASSUME(kr == kai_kr);
39 KAI_ASSUME(sr == kai_sr);
40
41 return m_idx * kai_roundup(k, kr) * sizeof(uint16_t);
42 }
43
44 92 size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
45 92 KAI_UNUSED(sr);
46 KAI_ASSUME(mr == kai_mr);
47 KAI_ASSUME(kr == kai_kr);
48 KAI_ASSUME(sr == kai_sr);
49
50 92 return kai_roundup(m, mr) * kai_roundup(k, kr) * sizeof(uint16_t);
51 }
52
53 92 void kai_run_lhs_quant_pack_bf16p1x4_f32_neon(
54 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
55 void* lhs_packed) {
56 92 KAI_UNUSED(sr);
57 92 KAI_UNUSED(lhs_stride);
58
59 KAI_ASSUME(m == 1);
60 KAI_ASSUME(mr == kai_mr);
61 KAI_ASSUME(kr == kai_kr);
62 KAI_ASSUME(sr == kai_sr);
63
64 KAI_ASSUME(lhs != NULL);
65 KAI_ASSUME(lhs_packed != NULL);
66
67 KAI_ASSUME(m_idx_start == 0);
68
69 92 const float* lhs_ptr = lhs;
70 92 uint16_t* lhs_packed_ptr = lhs_packed;
71
72 // Unroll two 256-bit loops
73 92 size_t i = 0;
74
2/2
✓ Branch 0 taken 1834 times.
✓ Branch 1 taken 92 times.
1926 for (; i + 16 <= k; i += 16) {
75 1834 const float32x4x4_t val = vld1q_f32_x4(lhs_ptr);
76 1834 bfloat16x8x2_t bf_val;
77
78 1834 bf_val.val[0] = vcvtq_low_bf16_f32(val.val[0]);
79 1834 bf_val.val[0] = vcvtq_high_bf16_f32(bf_val.val[0], val.val[1]);
80 1834 bf_val.val[1] = vcvtq_low_bf16_f32(val.val[2]);
81 1834 bf_val.val[1] = vcvtq_high_bf16_f32(bf_val.val[1], val.val[3]);
82 1834 vst1q_bf16_x2((bfloat16_t*)(lhs_packed_ptr), bf_val);
83
84 1834 lhs_ptr += 16;
85 1834 lhs_packed_ptr += 16;
86 1834 }
87
88 // 1 load + 1 convert + 1 store
89
2/2
✓ Branch 0 taken 30 times.
✓ Branch 1 taken 92 times.
122 for (; i + 8 <= k; i += 8) {
90 30 const float32x4x2_t f32_val = vld1q_f32_x2(lhs_ptr);
91 30 bfloat16x8_t bf_val = vcvtq_low_bf16_f32(f32_val.val[0]);
92 30 bf_val = vcvtq_high_bf16_f32(bf_val, f32_val.val[1]);
93 30 vst1q_bf16((bfloat16_t*)(lhs_packed_ptr), bf_val);
94
95 30 lhs_ptr += 8;
96 30 lhs_packed_ptr += 8;
97 30 }
98
99
2/2
✓ Branch 0 taken 38 times.
✓ Branch 1 taken 92 times.
130 for (; i + 4 <= k; i += 4) {
100 38 const float32x4_t f32_val = vld1q_f32(lhs_ptr);
101 38 bfloat16x4_t bf_val = vcvt_bf16_f32(f32_val);
102 38 vst1_bf16((bfloat16_t*)(lhs_packed_ptr), bf_val);
103
104 38 lhs_ptr += 4;
105 38 lhs_packed_ptr += 4;
106 38 }
107
108
2/2
✓ Branch 0 taken 128 times.
✓ Branch 1 taken 92 times.
220 for (; i < k; ++i) {
109 128 *lhs_packed_ptr = kai_cast_bf16_f32(*lhs_ptr);
110
111 128 ++lhs_ptr;
112 128 ++lhs_packed_ptr;
113 128 }
114
115 // Zero pad
116 92 const size_t rounded_up_k = kai_roundup(k, kr);
117
2/2
✓ Branch 0 taken 144 times.
✓ Branch 1 taken 92 times.
236 for (; i < rounded_up_k; ++i) {
118 144 *lhs_packed_ptr = 0;
119 144 ++lhs_packed_ptr;
120 144 }
121 92 }
122
123 #endif // Architectural features check.
124