KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 96.9% 126 3 133
Functions: 85.7% 6 0 7
Branches: 89.5% 34 4 42

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_lhs_quant_pack_qai8dxp_f32.h"
7
8 #if defined(__aarch64__)
9 #include <arm_neon.h>
10 #endif
11 #include <float.h>
12 #include <math.h>
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 static const size_t kai_num_bytes_per_multiplier = sizeof(float);
19 static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
20
21 17176 inline static size_t kai_k_roundedup(size_t k) {
22 // Round up k to be a multiple of 32.
23 17176 size_t kai_k_multiple_of = 32;
24 34352 return kai_roundup(k, kai_k_multiple_of);
25 17176 }
26
27 13236 inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) {
28 13236 KAI_UNUSED(kr);
29 13236 KAI_UNUSED(sr);
30
31 13236 const size_t k_internal = kai_k_roundedup(k);
32
33 KAI_ASSERT((k_internal % 2) == 0);
34
35 26472 return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
36 13236 }
37
38 size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr) {
39 KAI_UNUSED(mr);
40 return 1;
41 }
42
43 3940 size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride) {
44 3940 return m_idx * lhs_stride;
45 }
46
47 5356 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
48 // It always points to the beginning of the row
49 5356 return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr);
50 }
51
52 3940 size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
53 3940 const size_t num_rows = kai_roundup(m, mr) / mr;
54
55 7880 return num_rows * kai_lhs_packed_stride(k, mr, kr, sr);
56 3940 }
57
58 3940 void kai_run_lhs_quant_pack_qai8dxp_f32(
59 size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs,
60 size_t lhs_stride, void* restrict lhs_packed) {
61 KAI_ASSERT((kr % sr) == 0);
62
63
1/2
✓ Branch 0 taken 3940 times.
✗ Branch 1 not taken.
3940 if (m == 0) {
64 return;
65 }
66
67 3940 const size_t num_rows = m;
68
69 3940 const float* src_ptr = lhs;
70
71 3940 const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr);
72 3940 const size_t k_internal = kai_k_roundedup(k);
73 3940 const int32_t k_block_len = (int32_t)(kr / sr);
74
75 3940 const int32_t num_blocks_k = (int32_t)(k / k_block_len);
76 3940 const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len);
77
78
2/2
✓ Branch 0 taken 80332 times.
✓ Branch 1 taken 3940 times.
84272 for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
79 80332 float max0 = -FLT_MAX;
80 80332 float min0 = FLT_MAX;
81
82 // Find min/max for each channel
83 80332 int32_t k_idx = 0;
84
85 #if defined(__aarch64__)
86 80332 float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
87 80332 float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);
88
89
2/2
✓ Branch 0 taken 987704 times.
✓ Branch 1 taken 80332 times.
1068036 for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
90 987704 const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx);
91 987704 const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx);
92
93 // Calculate the max
94 987704 vmax0 = vmaxq_f32(src0_0, vmax0);
95 987704 vmax0 = vmaxq_f32(vmax0, src0_1);
96
97 // Calculate the min
98 987704 vmin0 = vminq_f32(src0_0, vmin0);
99 987704 vmin0 = vminq_f32(vmin0, src0_1);
100 987704 }
101 // Get the max/min
102 80332 max0 = vmaxvq_f32(vmax0);
103 80332 min0 = vminvq_f32(vmin0);
104 #endif
105
2/2
✓ Branch 0 taken 76600 times.
✓ Branch 1 taken 80332 times.
156932 for (; k_idx < (int32_t)k; ++k_idx) {
106 76600 const float src0_0 = *(src_ptr + (size_t)k_idx);
107 76600 max0 = fmaxf(src0_0, max0);
108 76600 min0 = fminf(src0_0, min0);
109 76600 }
110
111 // Maximum/minimum int8 values
112 80332 const float qmin = (float)INT8_MIN;
113 80332 const float qmax = (float)INT8_MAX;
114
115 80332 const float rmin0 = fminf(0.0F, min0);
116 80332 const float rmax0 = fmaxf(0.0F, max0);
117
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 80332 times.
80332 const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0);
118
119 // Reciprocal to quantize
120
1/2
✓ Branch 0 taken 80332 times.
✗ Branch 1 not taken.
80332 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
121
122 80332 const float descaled_min0 = rmin0 * scale0;
123 80332 const float descaled_max0 = rmax0 * scale0;
124
125 80332 const float zero_point_from_min_error0 = qmin + descaled_min0;
126 80332 const float zero_point_from_max_error0 = qmax + descaled_max0;
127
128 160664 float zero_point0 =
129
1/2
✓ Branch 0 taken 80332 times.
✗ Branch 1 not taken.
80332 zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0;
130
131 80332 zero_point0 = fmaxf(zero_point0, qmin);
132 80332 zero_point0 = fminf(zero_point0, qmax);
133
134 // Round to nearest integer
135 80332 const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0);
136
137 80332 const size_t dst_x = ((row_idx + m_idx_start) % mr);
138
139 80332 uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t));
140
141 // Quantize the channels
142 80332 int32_t block_idx = 0;
143
144 #if defined(__aarch64__)
145
2/2
✓ Branch 0 taken 45550 times.
✓ Branch 1 taken 34782 times.
80332 if (k_block_len == 8) {
146
2/2
✓ Branch 0 taken 600230 times.
✓ Branch 1 taken 45550 times.
645780 for (; block_idx < num_blocks_k; ++block_idx) {
147 // Clamp at the last valid k-index
148 600230 const int32_t k_idx_start = block_idx * k_block_len;
149
150 600230 const float32x4_t src_0 = vld1q_f32(src_ptr + k_idx_start);
151 600230 const float32x4_t src_1 = vld1q_f32(src_ptr + k_idx_start + 4);
152
153 // Scale the values
154 600230 float32x4_t v0_f32 = vmulq_n_f32(src_0, scale0);
155 600230 float32x4_t v1_f32 = vmulq_n_f32(src_1, scale0);
156 600230 int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32);
157 600230 int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32);
158
159 600230 int16x4_t v0_s16 = vqmovn_s32(v0_s32);
160 600230 int16x4_t v1_s16 = vqmovn_s32(v1_s32);
161 600230 int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);
162
163 // Add zero points
164 600230 int16_t nzp_s16 = (int16_t)nudged_zero_point0;
165 600230 int16x8_t vnzp_s16 = vdupq_n_s16(nzp_s16);
166 600230 v_s16 = vaddq_s16(v_s16, vnzp_s16);
167 600230 v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
168 600230 v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));
169
170 600230 int8x8_t v0_s8 = vqmovn_s16(v_s16);
171 600230 vst1_s8((int8_t*)(dst_ptr), v0_s8);
172 600230 dst_ptr += 8 * sizeof(int8_t);
173 600230 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
174 600230 }
175 45550 } else
176 #endif
177 {
178
2/2
✓ Branch 0 taken 781226 times.
✓ Branch 1 taken 34782 times.
816008 for (; block_idx < num_blocks_k; ++block_idx) {
179
2/2
✓ Branch 0 taken 3124904 times.
✓ Branch 1 taken 781226 times.
3906130 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
180 3124904 const int32_t k_idx_start = (block_idx * k_block_len) + k_block_idx;
181
182 3124904 const float src0_0 = *(src_ptr + k_idx_start);
183
184 // Scale the values
185 3124904 int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0));
186
187 3124904 v0_s32 = v0_s32 + nudged_zero_point0;
188
2/2
✓ Branch 0 taken 3118356 times.
✓ Branch 1 taken 6548 times.
3124904 v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
189
2/2
✓ Branch 0 taken 3084798 times.
✓ Branch 1 taken 40106 times.
3124904 v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
190
191 3124904 *((int8_t*)(dst_ptr)) = (int8_t)v0_s32;
192 3124904 dst_ptr += sizeof(int8_t);
193 3124904 }
194 781226 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
195 781226 }
196 }
197
198
2/2
✓ Branch 0 taken 117664 times.
✓ Branch 1 taken 80332 times.
197996 for (; block_idx < num_blocks_k_internal; ++block_idx) {
199 // left over k
200
2/2
✓ Branch 0 taken 618280 times.
✓ Branch 1 taken 117664 times.
735944 for (int32_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
201 // Clamp at the last valid k-index
202
2/2
✓ Branch 0 taken 34360 times.
✓ Branch 1 taken 583920 times.
618280 const size_t k_idx_start = KAI_MIN((size_t)((block_idx * k_block_len) + k_block_idx), k - 1);
203
204 618280 const float src0_0 = *(src_ptr + k_idx_start);
205
206 // Scale the values
207 618280 int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0));
208
209 618280 v0_s32 = v0_s32 + nudged_zero_point0;
210
2/2
✓ Branch 0 taken 618254 times.
✓ Branch 1 taken 26 times.
618280 v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
211
2/2
✓ Branch 0 taken 605826 times.
✓ Branch 1 taken 12454 times.
618280 v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
212
213 618280 *((int8_t*)(dst_ptr)) = (int8_t)v0_s32;
214 618280 dst_ptr += sizeof(int8_t);
215 618280 }
216 117664 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
217 117664 }
218
219 80332 dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));
220
221 80332 dst_ptr += dst_x * kai_num_bytes_per_offset;
222
223 // LHS offset at the beginning of the row
224 80332 *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
225
226 // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier
227 KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
228
229 80332 dst_ptr += mr * kai_num_bytes_per_offset;
230
231 // Store the scale quantization params
232 80332 *((float*)(dst_ptr)) = recip_scale0;
233
234 80332 src_ptr += (lhs_stride / sizeof(float));
235
236 // Move to the next row if we have interleaved all Mr rows
237
2/2
✓ Branch 0 taken 37912 times.
✓ Branch 1 taken 42420 times.
80332 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
238 42420 lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
239 42420 }
240 80332 }
241 3940 }
242