KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 96.7% 89 / 13 / 105
Functions: 87.5% 7 / 0 / 8
Branches: 77.3% 17 / 28 / 50

kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64.
9 #else // Architectural features check.
10
11 #include "kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h"
12
13 #include <arm_neon.h>
14 #include <float.h>
15 #include <stddef.h>
16 #include <stdint.h>
17
18 #include "kai/kai_common.h"
19
20 static const size_t kai_num_bytes_sum = sizeof(float);
21 static const size_t kai_num_bytes_multiplier = sizeof(float);
22 static const size_t kai_bl_multiple_of = 32;
23
24 20857032 inline static size_t kai_get_num_bytes_per_block(size_t bl) {
25 20857032 return bl * sizeof(int8_t) + kai_num_bytes_multiplier + kai_num_bytes_sum;
26 }
27
28 20857032 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
29 KAI_ASSERT((k % bl) == 0);
30 20857032 return k / bl;
31 }
32
33 20806884 inline static size_t kai_get_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) {
34 20806884 KAI_UNUSED(kr);
35
36 20806884 return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block(bl);
37 }
38
39 size_t kai_get_m_step_lhs_quant_pack_qsi8d32pscalef32_f32_neon(size_t mr) {
40 return mr;
41 }
42
43 50148 size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(size_t m_idx, size_t lhs_stride) {
44 50148 return m_idx * lhs_stride;
45 }
46
47 50148 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
48 size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
49 KAI_ASSUME((k % 2) == 0);
50 KAI_ASSUME((k % kr) == 0);
51 KAI_ASSUME((k % bl) == 0);
52 KAI_ASSUME((m_idx % mr) == 0);
53
54 50148 KAI_UNUSED(sr);
55 50148 return (m_idx / mr) * kai_get_lhs_packed_stride(k, mr, kr, bl);
56 }
57
58 20706588 size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
59 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
60 KAI_ASSUME((k % 2) == 0);
61 KAI_ASSUME((k % kr) == 0);
62 KAI_ASSUME((k % bl) == 0);
63
64 20706588 KAI_UNUSED(sr);
65
66 20706588 const size_t num_rows = kai_roundup(m, mr) / mr;
67
68 41413176 return (num_rows * kai_get_lhs_packed_stride(k, mr, kr, bl));
69 20706588 }
70 50148 void kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
71 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
72 size_t lhs_stride, void* lhs_packed) {
73 KAI_ASSERT((kr % sr) == 0);
74 KAI_ASSUME((bl % kr) == 0);
75 KAI_ASSUME((k % bl) == 0);
76 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
77 KAI_ASSERT(((kr / sr) == 8) || ((kr / sr) == 4));
78
79
1/2
✓ Branch 0 taken 50148 times.
✗ Branch 1 not taken.
50148 if (m == 0) {
80 return;
81 }
82 50148 const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, mr, kr, bl);
83 50148 const size_t num_rows = m;
84 50148 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
85 50148 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl);
86 50148 const size_t mr_block_size = mr * num_bytes_per_block;
87
88 50148 const int32_t k_block_len = (int32_t)(kr / sr);
89
90
2/2
✓ Branch 0 taken 507276 times.
✓ Branch 1 taken 50148 times.
557424 for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
91 507276 const float* row_src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
92 507276 const size_t dst_idx = ((row_idx + m_idx_start) % mr);
93
94
2/2
✓ Branch 0 taken 733824 times.
✓ Branch 1 taken 507276 times.
1241100 for (size_t blk_idx = 0; blk_idx < num_blocks_per_row; ++blk_idx) {
95 733824 const float* src_ptr = row_src_ptr + blk_idx * bl;
96 733824 int8_t* dst_ptr = (int8_t*)lhs_packed + dst_idx * k_block_len * sizeof(int8_t) + blk_idx * mr_block_size;
97 733824 int8_t* param_ptr = (int8_t*)lhs_packed + blk_idx * mr_block_size + bl * mr + dst_idx * kai_num_bytes_sum;
98
99 // Find absmax for each block
100 733824 float absmax = -FLT_MAX;
101 733824 int32_t k_idx = 0;
102 733824 float32x4_t vabsmax = vdupq_n_f32(-FLT_MAX);
103
2/2
✓ Branch 0 taken 3817296 times.
✓ Branch 1 taken 733824 times.
4551120 for (; k_idx < ((int32_t)bl); k_idx += 8) {
104 3817296 const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx);
105 3817296 const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx);
106 // Calculate the max
107 3817296 vabsmax = vmaxq_f32(vabsq_f32(src0_0), vmaxq_f32(vabsmax, vabsq_f32(src0_1)));
108 3817296 }
109 // Get the absmax
110 733824 absmax = vmaxvq_f32(vabsmax);
111
112 // Maximum/minimum int8 values
113 733824 const float qmax = (float)INT8_MAX;
114
115 // Get the scale and reciprocal to quantize
116
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 733824 times.
733824 const float scale0 = absmax == 0.0F ? 0.0F : qmax / absmax;
117
1/2
✓ Branch 0 taken 733824 times.
✗ Branch 1 not taken.
733824 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
118
119 733824 int32_t qsum = 0;
120 // Quantize the blocks
121
2/2
✓ Branch 0 taken 6390192 times.
✓ Branch 1 taken 733824 times.
7124016 for (k_idx = 0; k_idx <= (int32_t)bl - k_block_len; k_idx += k_block_len) {
122 // Clamp at the last valid k-index
123
1/2
✓ Branch 0 taken 6390192 times.
✗ Branch 1 not taken.
6390192 const size_t k_idx_start = KAI_MIN((size_t)k_idx, k - 1);
124
2/2
✓ Branch 0 taken 1244400 times.
✓ Branch 1 taken 5145792 times.
6390192 if (k_block_len == 8) {
125 1244400 const float32x4_t vsrc_0 = vld1q_f32(src_ptr + k_idx_start);
126 1244400 const float32x4_t vsrc_1 = vld1q_f32(src_ptr + k_idx_start + 4);
127
128 // Scale the values
129 1244400 float32x4_t v0_f32 = vmulq_n_f32(vsrc_0, scale0);
130 1244400 float32x4_t v1_f32 = vmulq_n_f32(vsrc_1, scale0);
131
132 1244400 int16x4_t v0_s16 = vqmovn_s32(vcvtnq_s32_f32(v0_f32));
133 1244400 int16x4_t v1_s16 = vqmovn_s32(vcvtnq_s32_f32(v1_f32));
134 1244400 int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);
135
136 1244400 v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
137 1244400 v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));
138
139 // Update the sum
140 1244400 qsum += vaddvq_s16(v_s16);
141
142 1244400 int8x8_t v0_s8 = vqmovn_s16(v_s16);
143 1244400 vst1_s8(dst_ptr, v0_s8);
144 1244400 dst_ptr += 8 * sizeof(int8_t);
145
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 5145792 times.
6390192 } else if (k_block_len == 4) {
146 5145792 const float32x2_t vsrc_0 = vld1_f32(src_ptr + k_idx_start);
147 5145792 const float32x2_t vsrc_1 = vld1_f32(src_ptr + k_idx_start + 2);
148
149 // Scale the values
150 5145792 float32x2_t v0_f32 = vmul_n_f32(vsrc_0, scale0);
151 5145792 float32x2_t v1_f32 = vmul_n_f32(vsrc_1, scale0);
152
153 5145792 int32x2_t v0_s32 = vcvtn_s32_f32(v0_f32);
154 5145792 int32x2_t v1_s32 = vcvtn_s32_f32(v1_f32);
155 5145792 int16x4_t v_s16 = vqmovn_s32(vcombine_s32(v0_s32, v1_s32));
156
157 5145792 v_s16 = vmax_s16(v_s16, vdup_n_s16(INT8_MIN));
158 5145792 v_s16 = vmin_s16(v_s16, vdup_n_s16(INT8_MAX));
159
160 // Update the sum
161 5145792 qsum += vaddv_s16(v_s16);
162
163 5145792 dst_ptr[0] = vqmovnh_s16(vget_lane_s16(v_s16, 0));
164 5145792 dst_ptr[1] = vqmovnh_s16(vget_lane_s16(v_s16, 1));
165 5145792 dst_ptr[2] = vqmovnh_s16(vget_lane_s16(v_s16, 2));
166 5145792 dst_ptr[3] = vqmovnh_s16(vget_lane_s16(v_s16, 3));
167 5145792 dst_ptr += 4 * sizeof(int8_t);
168 5145792 }
169 6390192 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
170 6390192 }
171 733824 *((float*)(param_ptr)) = ((float)qsum) * recip_scale0;
172 733824 param_ptr += mr * kai_num_bytes_sum;
173 733824 *((float*)(param_ptr)) = recip_scale0;
174 733824 }
175 // Move to the next row if we have interleaved all Mr rows
176
2/2
✓ Branch 0 taken 264708 times.
✓ Branch 1 taken 242568 times.
507276 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
177 242568 lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
178 242568 }
179 507276 }
180 50148 }
181 #endif // Architectural features check.
182