KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 99.2% 128 / 3 / 132
Functions: 100.0% 7 / 0 / 7
Branches: 89.5% 34 / 4 / 42

kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_lhs_quant_pack_qai8dxp_f32.h"
7
8 #if defined(__aarch64__)
9 #include <arm_neon.h>
10 #endif
11 #include <float.h>
12 #include <math.h>
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 static const size_t kai_num_bytes_per_multiplier = sizeof(float);
19 static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
20
21 84244 inline static size_t kai_k_roundedup(size_t k) {
22 // Round up k to be a multiple of 32.
23 84244 size_t kai_k_multiple_of = 32;
24 168488 return kai_roundup(k, kai_k_multiple_of);
25 84244 }
26
27 65550 inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) {
28 65550 KAI_UNUSED(kr);
29 65550 KAI_UNUSED(sr);
30
31 65550 const size_t k_internal = kai_k_roundedup(k);
32
33 KAI_ASSERT((k_internal % 2) == 0);
34
35 131100 return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
36 65550 }
37
38 600 size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr) {
39 600 return mr;
40 }
41
42 18694 size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride) {
43 18694 return m_idx * lhs_stride;
44 }
45
46 28162 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
47 // It always points to the beginning of the row
48 28162 return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr);
49 }
50
51 18694 size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
52 18694 const size_t num_rows = kai_roundup(m, mr) / mr;
53
54 37388 return num_rows * kai_lhs_packed_stride(k, mr, kr, sr);
55 18694 }
56
57 18694 void kai_run_lhs_quant_pack_qai8dxp_f32(
58 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs,
59 size_t lhs_stride, void* restrict lhs_packed) {
60 KAI_ASSERT((kr % sr) == 0);
61
62
1/2
✓ Branch 0 taken 18694 times.
✗ Branch 1 not taken.
18694 if (m == 0) {
63 return;
64 }
65
66 18694 const size_t num_rows = m;
67
68 18694 const float* src_ptr = lhs;
69
70 18694 const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr);
71 18694 const size_t k_internal = kai_k_roundedup(k);
72 18694 const int32_t k_block_len = (int32_t)(kr / sr);
73
74 18694 const int32_t num_blocks_k = (int32_t)(k / k_block_len);
75 18694 const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len);
76
77
2/2
✓ Branch 0 taken 404124 times.
✓ Branch 1 taken 18694 times.
422818 for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
78 404124 float max0 = -FLT_MAX;
79 404124 float min0 = FLT_MAX;
80
81 // Find min/max for each channel
82 404124 int32_t k_idx = 0;
83
84 #if defined(__aarch64__)
85 404124 float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
86 404124 float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);
87
88
2/2
✓ Branch 0 taken 5236988 times.
✓ Branch 1 taken 404124 times.
5641112 for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
89 5236988 const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx);
90 5236988 const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx);
91
92 // Calculate the max
93 5236988 vmax0 = vmaxq_f32(src0_0, vmax0);
94 5236988 vmax0 = vmaxq_f32(vmax0, src0_1);
95
96 // Calculate the min
97 5236988 vmin0 = vminq_f32(src0_0, vmin0);
98 5236988 vmin0 = vminq_f32(vmin0, src0_1);
99 5236988 }
100 // Get the max/min
101 404124 max0 = vmaxvq_f32(vmax0);
102 404124 min0 = vminvq_f32(vmin0);
103 #endif
104
2/2
✓ Branch 0 taken 472972 times.
✓ Branch 1 taken 404124 times.
877096 for (; k_idx < (int32_t)k; ++k_idx) {
105 472972 const float src0_0 = *(src_ptr + (size_t)k_idx);
106 472972 max0 = fmaxf(src0_0, max0);
107 472972 min0 = fminf(src0_0, min0);
108 472972 }
109
110 // Maximum/minimum int8 values
111 404124 const float qmin = (float)INT8_MIN;
112 404124 const float qmax = (float)INT8_MAX;
113
114 404124 const float rmin0 = fminf(0.0F, min0);
115 404124 const float rmax0 = fmaxf(0.0F, max0);
116
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 404124 times.
404124 const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
117
118 // Reciprocal to quantize
119
1/2
✓ Branch 0 taken 404124 times.
✗ Branch 1 not taken.
404124 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
120
121 404124 const float descaled_min0 = rmin0 * scale0;
122 404124 const float descaled_max0 = rmax0 * scale0;
123
124 404124 const float zero_point_from_min_error0 = qmin + descaled_min0;
125 404124 const float zero_point_from_max_error0 = qmax + descaled_max0;
126
127 808248 float zero_point0 =
128
1/2
✓ Branch 0 taken 404124 times.
✗ Branch 1 not taken.
404124 zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0;
129
130 404124 zero_point0 = fmaxf(zero_point0, qmin);
131 404124 zero_point0 = fminf(zero_point0, qmax);
132
133 // Round to nearest integer
134 404124 const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
135
136 404124 const size_t dst_x = ((row_idx + m_idx_start) % mr);
137
138 404124 uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t));
139
140 // Quantize the channels
141 404124 int32_t block_idx = 0;
142
143 #if defined(__aarch64__)
144
2/2
✓ Branch 0 taken 212602 times.
✓ Branch 1 taken 191522 times.
404124 if (k_block_len == 8) {
145
2/2
✓ Branch 0 taken 2622004 times.
✓ Branch 1 taken 212602 times.
2834606 for (; block_idx < num_blocks_k; ++block_idx) {
146 // Clamp at the last valid k-index
147 2622004 const int32_t k_idx_start = block_idx * k_block_len;
148
149 2622004 const float32x4_t src_0 = vld1q_f32(src_ptr + k_idx_start);
150 2622004 const float32x4_t src_1 = vld1q_f32(src_ptr + k_idx_start + 4);
151
152 // Scale the values
153 2622004 float32x4_t v0_f32 = vmulq_n_f32(src_0, scale0);
154 2622004 float32x4_t v1_f32 = vmulq_n_f32(src_1, scale0);
155 2622004 int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32);
156 2622004 int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32);
157
158 2622004 int16x4_t v0_s16 = vqmovn_s32(v0_s32);
159 2622004 int16x4_t v1_s16 = vqmovn_s32(v1_s32);
160 2622004 int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);
161
162 // Add zero points
163 2622004 int16_t nzp_s16 = (int16_t)nudged_zero_point0;
164 2622004 int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16);
165 2622004 v_s16 = vaddq_s16(v_s16, vnzp_s16);
166 2622004 v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
167 2622004 v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));
168
169 2622004 int8x8_t v0_s8 = vqmovn_s16(v_s16);
170 2622004 vst1_s8((int8_t*)(dst_ptr), v0_s8);
171 2622004 dst_ptr += 8 * sizeof(int8_t);
172 2622004 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
173 2622004 }
174 212602 } else
175 #endif
176 {
177
2/2
✓ Branch 0 taken 5271502 times.
✓ Branch 1 taken 191522 times.
5463024 for (; block_idx < num_blocks_k; ++block_idx) {
178
2/2
✓ Branch 0 taken 21086008 times.
✓ Branch 1 taken 5271502 times.
26357510 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
179 21086008 const int32_t k_idx_start = (block_idx * k_block_len) + k_block_idx;
180
181 21086008 const float src0_0 = *(src_ptr + k_idx_start);
182
183 // Scale the values
184 21086008 int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0));
185
186 21086008 v0_s32 = v0_s32 + nudged_zero_point0;
187
2/2
✓ Branch 0 taken 21045251 times.
✓ Branch 1 taken 40757 times.
21086008 v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
188
2/2
✓ Branch 0 taken 20859772 times.
✓ Branch 1 taken 226236 times.
21086008 v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
189
190 21086008 *((int8_t*)(dst_ptr)) = (int8_t)v0_s32;
191 21086008 dst_ptr += sizeof(int8_t);
192 21086008 }
193 5271502 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
194 5271502 }
195 }
196
197
2/2
✓ Branch 0 taken 704038 times.
✓ Branch 1 taken 404124 times.
1108162 for (; block_idx < num_blocks_k_internal; ++block_idx) {
198 // left over k
199
2/2
✓ Branch 0 taken 3704392 times.
✓ Branch 1 taken 704038 times.
4408430 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
200 // Clamp at the last valid k-index
201
2/2
✓ Branch 0 taken 201798 times.
✓ Branch 1 taken 3502594 times.
3704392 const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
202
203 3704392 const float src0_0 = *(src_ptr + k_idx_start);
204
205 // Scale the values
206 3704392 int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0));
207
208 3704392 v0_s32 = v0_s32 + nudged_zero_point0;
209
2/2
✓ Branch 0 taken 3703101 times.
✓ Branch 1 taken 1291 times.
3704392 v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
210
2/2
✓ Branch 0 taken 3595023 times.
✓ Branch 1 taken 109369 times.
3704392 v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
211
212 3704392 *((int8_t*)(dst_ptr)) = (int8_t)v0_s32;
213 3704392 dst_ptr += sizeof(int8_t);
214 3704392 }
215 704038 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
216 704038 }
217
218 404124 dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));
219
220 404124 dst_ptr += dst_x * kai_num_bytes_per_offset;
221
222 // LHS offset at the beginning of the row
223 404124 *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
224
225 // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier
226 KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
227
228 404124 dst_ptr += mr * kai_num_bytes_per_offset;
229
230 // Store the scale quantization params
231 404124 *((float*)(dst_ptr)) = recip_scale0;
232
233 404124 src_ptr += (lhs_stride / sizeof(float));
234
235 // Move to the next row if we have interleaved all Mr rows
236
2/2
✓ Branch 0 taken 197228 times.
✓ Branch 1 taken 206896 times.
404124 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
237 206896 lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
238 206896 }
239 404124 }
240 18694 }
241