KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 95.7% 66 / 11 / 80
Functions: 87.5% 7 / 0 / 8
Branches: 83.3% 20 / 22 / 46

kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f16_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
8 #error This file must be compiled for AArch64 and FEAT_FP16.
9 #else // Architectural features check.
10
11 #include "kai_lhs_quant_pack_qsi8d32pscalef32_f16_neon.h"
12
13 #include <arm_neon.h>
14 #include <float.h>
15 #include <math.h>
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include "kai/kai_common.h"
20 #define FLT16_MAX 65504.0F
21 #define FLT16_MIN (-65504.0F)
22 static const size_t kai_num_bytes_sum = sizeof(float);
23 static const size_t kai_num_bytes_multiplier = sizeof(float);
24 static const size_t kai_bl_multiple_of = 32;
25
26 20907180 inline static size_t kai_get_num_bytes_per_block(size_t bl) {
27 20907180 return bl * sizeof(int8_t) + kai_num_bytes_multiplier + kai_num_bytes_sum;
28 }
29
30 20907180 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
31 KAI_ASSERT((k % bl) == 0);
32 20907180 return k / bl;
33 }
34
35 20857032 inline static size_t kai_get_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) {
36 20857032 KAI_UNUSED(kr);
37
38 20857032 return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block(bl);
39 }
40 size_t kai_get_m_step_lhs_quant_pack_qsi8d32pscalef32_f16_neon(size_t mr) {
41 return mr;
42 }
43 50148 size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f16_neon(size_t m_idx, size_t lhs_stride) {
44 50148 return m_idx * lhs_stride;
45 }
46 100296 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f16_neon(
47 size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
48 KAI_ASSUME((k % 2) == 0);
49 KAI_ASSUME((k % kr) == 0);
50 KAI_ASSUME((k % bl) == 0);
51 100296 KAI_UNUSED(sr);
52 100296 return (m_idx / mr) * kai_get_lhs_packed_stride(k, mr, kr, bl);
53 }
54 20706588 size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f16_neon(
55 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
56 KAI_ASSUME((k % 2) == 0);
57 KAI_ASSUME((k % kr) == 0);
58 KAI_ASSUME((k % bl) == 0);
59
60 20706588 KAI_UNUSED(sr);
61
62 20706588 const size_t num_rows = kai_roundup(m, mr) / mr;
63
64 41413176 return (num_rows * kai_get_lhs_packed_stride(k, mr, kr, bl));
65 20706588 }
66 50148 void kai_run_lhs_quant_pack_qsi8d32pscalef32_f16_neon(
67 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs,
68 size_t lhs_stride, void* lhs_packed) {
69 KAI_ASSERT((kr % sr) == 0);
70 KAI_ASSUME((bl % kr) == 0);
71 KAI_ASSUME((k % bl) == 0);
72 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
73
74
1/2
✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
50148 if (m == 0) {
75 return;
76 }
77 50148 const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, mr, kr, bl);
78 50148 const size_t num_rows = m;
79 50148 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
80 50148 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl);
81 50148 const size_t mr_block_size = mr * num_bytes_per_block;
82
83 50148 const int32_t k_block_len = (int32_t)(kr / sr);
84
85
2/2
✓ Branch 0 taken 507276 times.
✓ Branch 1 taken 50148 times.
557424 for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
86 507276 const float16_t* row_src_ptr = (const float16_t*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
87 507276 const size_t dst_x = ((row_idx + m_idx_start) % mr);
88
2/2
✓ Branch 0 taken 733824 times.
✓ Branch 1 taken 507276 times.
1241100 for (size_t b = 0; b < num_blocks_per_row; ++b) {
89 733824 const float16_t* src_ptr = row_src_ptr + b * bl;
90 733824 int8_t* dst_ptr = (int8_t*)lhs_packed + dst_x * k_block_len * sizeof(int8_t) + b * mr_block_size;
91 733824 int8_t* param_ptr = (int8_t*)lhs_packed + b * mr_block_size + bl * mr + dst_x * kai_num_bytes_sum;
92 // Find absmax for each block
93 733824 float16_t absmax = (float16_t)(-FLT16_MAX);
94 733824 int32_t k_idx = 0;
95 733824 float16x8_t vabsmax = vdupq_n_f16(-FLT16_MAX);
96
2/2
✓ Branch 0 taken 3817296 times.
✓ Branch 1 taken 733824 times.
4551120 for (; k_idx < ((int32_t)bl); k_idx += 8) {
97 3817296 const float16x8_t src = vabsq_f16(vld1q_f16(src_ptr + (size_t)k_idx));
98 3817296 vabsmax = vmaxq_f16(vabsmax, src);
99 3817296 }
100 // Get the absmax
101 733824 absmax = vmaxvq_f16(vabsmax);
102 // Maximum/minimum int8 values
103 733824 const float qmax = (float)INT8_MAX;
104
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 733824 times.
733824 const float scale0 = absmax == 0.0F ? 0.0F : qmax / absmax;
105 // Reciprocal to quantize
106
1/2
✓ Branch 0 taken 733824 times.
✗ Branch 1 not taken.
733824 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
107 733824 int32_t qsum = 0;
108 // Quantize the blocks
109
2/2
✓ Branch 0 taken 6390192 times.
✓ Branch 1 taken 733824 times.
7124016 for (k_idx = 0; k_idx < (int32_t)bl; k_idx += k_block_len) {
110
2/2
✓ Branch 0 taken 30538368 times.
✓ Branch 1 taken 6390192 times.
36928560 for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) {
111 // Clamp at the last valid k-index
112
2/2
✓ Branch 0 taken 30251592 times.
✓ Branch 1 taken 286776 times.
30538368 const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k - 1);
113
114 30538368 const float16_t src0_0 = *(src_ptr + k_idx_start);
115
116 // Scale the values
117 30538368 int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0));
118
119
1/2
✓ Branch 0 taken 30538368 times.
✗ Branch 1 not taken.
30538368 v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
120
2/2
✓ Branch 0 taken 29686833 times.
✓ Branch 1 taken 851535 times.
30538368 v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
121 30538368 qsum += v0_s32;
122
123 30538368 *(dst_ptr) = (int8_t)v0_s32;
124 30538368 dst_ptr += sizeof(int8_t);
125 30538368 }
126 6390192 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
127 6390192 }
128 733824 *((float*)(param_ptr)) = ((float)qsum) * recip_scale0;
129 733824 param_ptr += mr * kai_num_bytes_sum;
130 733824 *((float*)(param_ptr)) = recip_scale0;
131 733824 }
132 // Move to the next row if we have interleaved all Mr rows
133
2/2
✓ Branch 0 taken 264708 times.
✓ Branch 1 taken 242568 times.
507276 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
134 242568 lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
135 242568 }
136 507276 }
137 50148 }
138 #endif // Architectural features check.
139