KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 96.8% 90 13 106
Functions: 87.5% 7 0 8
Branches: 81.8% 18 28 50

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__aarch64__) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64.
9 #else // Architectural features check.
10
11 #include "kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.h"
12
13 #include <arm_neon.h>
14 #include <float.h>
15 #include <stddef.h>
16 #include <stdint.h>
17
18 #include "kai/kai_common.h"
19
// Per-block metadata sizes in bytes: each quantized block carries one float sum and one float
// multiplier. kai_bl_multiple_of is the granularity the run function requires of the block
// length bl (it checks bl % kai_bl_multiple_of == 0).
20 static const size_t kai_num_bytes_sum = sizeof(float);
21 static const size_t kai_num_bytes_multiplier = sizeof(float);
22 static const size_t kai_bl_multiple_of = 32;
23
// Returns the packed size in bytes of one quantized block of a single row:
// bl int8 values plus one float multiplier and one float sum.
24 4794320 inline static size_t kai_get_num_bytes_per_block(size_t bl) {
25 4794320 return bl * sizeof(int8_t) + kai_num_bytes_multiplier + kai_num_bytes_sum;
26 }
27
// Returns how many blocks of bl elements make up a row of k elements.
// k is asserted to be an exact multiple of bl, so the division has no remainder.
28 4794320 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
29 KAI_ASSERT((k % bl) == 0);
30 4794320 return k / bl;
31 }
32
// Returns the size in bytes of one packed group of mr rows:
// mr * (blocks per row) * (bytes per block). kr does not affect the result.
33 4783864 inline static size_t kai_get_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) {
34 4783864 KAI_UNUSED(kr);
35
36 4783864 return mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block(bl);
37 }
38
// Returns the m step for this packing routine. The value is always 1; mr is ignored.
39 size_t kai_get_m_step_lhs_quant_pack_qsi8d32pscalef32_f32_neon(size_t mr) {
40 KAI_UNUSED(mr);
41 return 1;
42 }
43
// Returns the offset of row m_idx in the unpacked LHS matrix as m_idx * lhs_stride.
// The run function in this file applies lhs_stride through a uint8_t pointer, so the result
// is presumably a byte offset — confirm against the public header.
44 10904 size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(size_t m_idx, size_t lhs_stride) {
45 10904 return m_idx * lhs_stride;
46 }
47
// Returns the byte offset in the packed LHS buffer of the mr-row group that starts at row m_idx.
//
// Preconditions (checked with KAI_ASSUME): k is even, k is a multiple of kr and of bl, and
// m_idx is a multiple of mr. sr is accepted but not used in the computation.
48 10904 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
49 size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
50 KAI_ASSUME((k % 2) == 0);
51 KAI_ASSUME((k % kr) == 0);
52 KAI_ASSUME((k % bl) == 0);
53 KAI_ASSUME((m_idx % mr) == 0);
54
55 10904 KAI_UNUSED(sr);
// One packed stride per complete group of mr rows that precedes m_idx.
56 10904 return (m_idx / mr) * kai_get_lhs_packed_stride(k, mr, kr, bl);
57 }
58
// Returns the number of bytes needed for the packed LHS buffer holding m rows of k elements.
//
// Preconditions (checked with KAI_ASSUME): k is even and a multiple of both kr and bl.
// sr is accepted but not used in the computation.
59 4762504 size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
60 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
61 KAI_ASSUME((k % 2) == 0);
62 KAI_ASSUME((k % kr) == 0);
63 KAI_ASSUME((k % bl) == 0);
64
65 4762504 KAI_UNUSED(sr);
66
// Number of mr-row groups; kai_roundup presumably rounds m up to a multiple of mr so that a
// partially filled last group still gets a full stride — confirm in kai_common.h.
67 4762504 const size_t num_rows = kai_roundup(m, mr) / mr;
68
69 9525008 return (num_rows * kai_get_lhs_packed_stride(k, mr, kr, bl));
70 4762504 }
// Quantizes m rows of the float LHS matrix to int8, using one scale per block of bl values,
// and writes the result interleaved into lhs_packed.
//
// For every block the code takes the maximum absolute value, derives
// scale = INT8_MAX / absmax (0 when absmax is 0), converts the scaled values to int8 with
// rounding and clamping, and stores two floats alongside the data: the sum of the quantized
// values multiplied by the reciprocal scale, and the reciprocal scale itself.
//
// Within a group of mr rows, each block occupies mr * (bl + 2 * sizeof(float)) bytes:
// bl * mr int8 values interleaved across the rows in chunks of kr / sr values, followed by
// mr float sums and then mr float reciprocal scales.
//
// Parameters:
//   m            Number of rows to process; the function returns immediately when m is 0.
//   k            Row length in elements; must be a multiple of bl.
//   bl           Block length; must be a multiple of kr and of kai_bl_multiple_of.
//   mr, kr, sr   Packing parameters; kr must be a multiple of sr and kr / sr must be 8 or 4.
//   m_idx_start  Added to the local row index both when locating the source row and when
//                selecting the row slot inside the mr-row group.
//   lhs          Source matrix of floats.
//   lhs_stride   Distance between consecutive source rows, applied through a uint8_t pointer
//                (i.e. in bytes).
//   lhs_packed   Destination buffer for the packed output.
71 10904 void kai_run_lhs_quant_pack_qsi8d32pscalef32_f32_neon(
72 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
73 size_t lhs_stride, void* lhs_packed) {
74 KAI_ASSERT((kr % sr) == 0);
75 KAI_ASSUME((bl % kr) == 0);
76 KAI_ASSUME((k % bl) == 0);
77 KAI_ASSUME((bl % kai_bl_multiple_of) == 0);
78 KAI_ASSERT(((kr / sr) == 8) || ((kr / sr) == 4));
79
80
2/2
✓ Branch 0 taken 10456 times.
✓ Branch 1 taken 448 times.
10904 if (m == 0) {
81 448 return;
82 }
83 10456 const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, mr, kr, bl);
84 10456 const size_t num_rows = m;
85 10456 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
86 10456 const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl);
// Bytes taken by one block position across all mr interleaved rows (data + sums + scales).
87 10456 const size_t mr_block_size = mr * num_bytes_per_block;
88
// Number of consecutive values of one row that are stored together before the next row's chunk.
89 10456 const int32_t k_block_len = (int32_t)(kr / sr);
90
91
2/2
✓ Branch 0 taken 113000 times.
✓ Branch 1 taken 10456 times.
123456 for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
92 113000 const float* row_src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
// Slot of this row inside the current group of mr interleaved rows.
93 113000 const size_t dst_idx = ((row_idx + m_idx_start) % mr);
94
95
2/2
✓ Branch 0 taken 163344 times.
✓ Branch 1 taken 113000 times.
276344 for (size_t blk_idx = 0; blk_idx < num_blocks_per_row; ++blk_idx) {
96 163344 const float* src_ptr = row_src_ptr + blk_idx * bl;
// dst_ptr: first int8 chunk of this row inside the block.
// param_ptr: this row's float sum, located after the bl * mr int8 values of the block.
97 163344 int8_t* dst_ptr = (int8_t*)lhs_packed + dst_idx * k_block_len * sizeof(int8_t) + blk_idx * mr_block_size;
98 163344 int8_t* param_ptr = (int8_t*)lhs_packed + blk_idx * mr_block_size + bl * mr + dst_idx * kai_num_bytes_sum;
99
100 // Find absmax for each block
101 163344 float absmax = -FLT_MAX;
102 163344 int32_t k_idx = 0;
103 163344 float32x4_t vabsmax = vdupq_n_f32(-FLT_MAX);
104
2/2
✓ Branch 0 taken 850144 times.
✓ Branch 1 taken 163344 times.
// Eight floats per iteration; bl is required to be a multiple of kai_bl_multiple_of, so the
// loop covers the block exactly.
1013488 for (; k_idx < ((int32_t)bl); k_idx += 8) {
105 850144 const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx);
106 850144 const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx);
107 // Calculate the max
108 850144 vabsmax = vmaxq_f32(vabsq_f32(src0_0), vmaxq_f32(vabsmax, vabsq_f32(src0_1)));
109 850144 }
110 // Get the absmax
111 163344 absmax = vmaxvq_f32(vabsmax);
112
113 // Maximum int8 value (only the upper bound is needed to derive the scale)
114 163344 const float qmax = (float)INT8_MAX;
115
116 // Get the scale and reciprocal to quantize
117
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 163344 times.
// An all-zero block gets scale 0 (and reciprocal 0 below) instead of dividing by zero.
163344 const float scale0 = absmax == 0.0F ? 0.0F : qmax / absmax;
118
1/2
✓ Branch 0 taken 163344 times.
✗ Branch 1 not taken.
163344 const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F;
119
120 163344 int32_t qsum = 0;
121 // Quantize the blocks
122
2/2
✓ Branch 0 taken 1494760 times.
✓ Branch 1 taken 163344 times.
1658104 for (k_idx = 0; k_idx <= (int32_t)bl - k_block_len; k_idx += k_block_len) {
123 // Clamp at the last valid k-index
// NOTE(review): k_idx stays below bl, and the preconditions give bl <= k whenever this loop
// runs, so the clamp looks like a no-op here — confirm before relying on it.
124
1/2
✓ Branch 0 taken 1494760 times.
✗ Branch 1 not taken.
1494760 const size_t k_idx_start = KAI_MIN((size_t)k_idx, k - 1);
125
2/2
✓ Branch 0 taken 205528 times.
✓ Branch 1 taken 1289232 times.
1494760 if (k_block_len == 8) {
126 205528 const float32x4_t vsrc_0 = vld1q_f32(src_ptr + k_idx_start);
127 205528 const float32x4_t vsrc_1 = vld1q_f32(src_ptr + k_idx_start + 4);
128
129 // Scale the values
130 205528 float32x4_t v0_f32 = vmulq_n_f32(vsrc_0, scale0);
131 205528 float32x4_t v1_f32 = vmulq_n_f32(vsrc_1, scale0);
132
// Convert to int32 with round-to-nearest, then narrow with saturation to int16.
133 205528 int16x4_t v0_s16 = vqmovn_s32(vcvtnq_s32_f32(v0_f32));
134 205528 int16x4_t v1_s16 = vqmovn_s32(vcvtnq_s32_f32(v1_f32));
135 205528 int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16);
136
// Clamp to the int8 range while still in int16 so the sum below uses the final values.
137 205528 v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN));
138 205528 v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX));
139
140 // Update the sum
141 205528 qsum += vaddvq_s16(v_s16);
142
143 205528 int8x8_t v0_s8 = vqmovn_s16(v_s16);
144 205528 vst1_s8(dst_ptr, v0_s8);
145 205528 dst_ptr += 8 * sizeof(int8_t);
146
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1289232 times.
1494760 } else if (k_block_len == 4) {
// Same steps as the 8-wide path, on four values loaded as two 2-lane vectors.
147 1289232 const float32x2_t vsrc_0 = vld1_f32(src_ptr + k_idx_start);
148 1289232 const float32x2_t vsrc_1 = vld1_f32(src_ptr + k_idx_start + 2);
149
150 // Scale the values
151 1289232 float32x2_t v0_f32 = vmul_n_f32(vsrc_0, scale0);
152 1289232 float32x2_t v1_f32 = vmul_n_f32(vsrc_1, scale0);
153
154 1289232 int32x2_t v0_s32 = vcvtn_s32_f32(v0_f32);
155 1289232 int32x2_t v1_s32 = vcvtn_s32_f32(v1_f32);
156 1289232 int16x4_t v_s16 = vqmovn_s32(vcombine_s32(v0_s32, v1_s32));
157
158 1289232 v_s16 = vmax_s16(v_s16, vdup_n_s16(INT8_MIN));
159 1289232 v_s16 = vmin_s16(v_s16, vdup_n_s16(INT8_MAX));
160
161 // Update the sum
162 1289232 qsum += vaddv_s16(v_s16);
163
164 1289232 dst_ptr[0] = vqmovnh_s16(vget_lane_s16(v_s16, 0));
165 1289232 dst_ptr[1] = vqmovnh_s16(vget_lane_s16(v_s16, 1));
166 1289232 dst_ptr[2] = vqmovnh_s16(vget_lane_s16(v_s16, 2));
167 1289232 dst_ptr[3] = vqmovnh_s16(vget_lane_s16(v_s16, 3));
168 1289232 dst_ptr += 4 * sizeof(int8_t);
169 1289232 }
// Skip the chunks that belong to the other mr - 1 rows of the group.
170 1494760 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
171 1494760 }
// Store the block sum (sum of quantized values times the reciprocal scale), then step over
// the mr sums to this row's slot in the scale area and store the reciprocal scale.
172 163344 *((float*)(param_ptr)) = ((float)qsum) * recip_scale0;
173 163344 param_ptr += mr * kai_num_bytes_sum;
174 163344 *((float*)(param_ptr)) = recip_scale0;
175 163344 }
176 // Move to the next row if we have interleaved all Mr rows
177
2/2
✓ Branch 0 taken 62092 times.
✓ Branch 1 taken 50908 times.
113000 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
178 50908 lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
179 50908 }
180 113000 }
181 10904 }
182 #endif // Architectural features check.
183