KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c
Date: 2025-10-20 13:18:31
             Coverage   Exec   Excl   Total
Lines:          97.6%     41      7      49
Functions:     100.0%     14      0      14
Branches:       50.0%      1     14      16

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__ARM_FEATURE_DOTPROD)
8 #error "Dotprod extension required to compile this micro-kernel"
9 #else // Architectural features check.
10 #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h"
11
12 #include <arm_neon.h>
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 static const size_t kai_m_step = 1;
19 static const size_t kai_n_step = 8;
20 static const size_t kai_mr = 1;
21 static const size_t kai_nr = 8;
22 static const size_t kai_kr = 16;
23 static const size_t kai_sr = 2;
24 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
25 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
26 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
27 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
28 static const size_t kai_num_bytes_bias = sizeof(float);
29
30 641 inline static size_t kai_k_roundedup(size_t k) {
31 // Since we pack a float and an int32 value at the end of each row,
32 // we must make sure that the padded k is a multiple of 4 for alignment.
33 641 size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
34 1282 return kai_roundup(k, kr_sr_roundedup4);
35 641 }
36
37 240 inline static size_t kai_lhs_packed_stride(size_t k) {
38 240 const size_t k_internal = kai_k_roundedup(k);
39
40 KAI_ASSERT((k_internal % 2) == 0);
41
42 480 return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
43 240 }
44
45 240 inline static size_t kai_rhs_packed_stride(size_t k) {
46 240 const size_t k_internal = kai_k_roundedup(k);
47
48 KAI_ASSERT((k_internal % 2) == 0);
49
50 480 return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
51 240 }
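
As a hedged aside (not part of the measured source): the three helpers above are plain arithmetic, so a worked example may help. A minimal standalone sketch, with roundup() standing in for kai_roundup(), the constants copied from this file, and an illustrative k = 1000:

    /* Sketch of the packing-stride arithmetic above, assuming
     * kai_roundup(a, b) rounds a up to the next multiple of b. */
    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>

    static size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }

    int main(void) {
        const size_t k = 1000;
        /* kai_kr * kai_sr = 16 * 2 = 32, already a multiple of 4. */
        const size_t k_internal = roundup(k, roundup(16 * 2, 4)); /* 1024 */
        /* Packed LHS row: k_internal int8 values + f32 scale + i32 offset. */
        const size_t lhs_stride = 1 * (k_internal * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
        /* Packed RHS block of nr = 8 columns: two int4 values per byte, plus a
         * per-column f32 scale, i32 sum of the quantized values, and f32 bias. */
        const size_t rhs_stride = 8 * (k_internal / 2 + sizeof(float) + sizeof(int32_t) + sizeof(float));
        assert(k_internal == 1024);
        assert(lhs_stride == 1032);
        assert(rhs_stride == 4192);
        return 0;
    }

The offset getters below then just scale these strides: the LHS offset for row m_idx is (m_idx / kai_mr) * lhs_stride, and likewise for the RHS with kai_nr.
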
52
53 320 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) {
54 320 return kai_m_step;
55 }
56
57 320 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) {
58 320 return kai_n_step;
59 }
60
61 240 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) {
62 240 return kai_mr;
63 }
64
65 240 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) {
66 240 return kai_nr;
67 }
68
69 320 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) {
70 320 return kai_kr;
71 }
72
73 320 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) {
74 320 return kai_sr;
75 }
76
77 240 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) {
78 KAI_ASSERT((m_idx % kai_m_step) == 0);
79
80 240 return (m_idx / kai_mr) * kai_lhs_packed_stride(k);
81 }
82
83 240 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t n_idx, size_t k) {
84 KAI_ASSERT((n_idx % kai_n_step) == 0);
85
86 240 return (n_idx / kai_nr) * kai_rhs_packed_stride(k);
87 }
88
89 160 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(
90 size_t m_idx, size_t n_idx, size_t dst_stride) {
91 KAI_ASSERT((m_idx % kai_m_step) == 0);
92 KAI_ASSERT((n_idx % kai_n_step) == 0);
93
94 160 return (n_idx * sizeof(float)) + m_idx * dst_stride;
95 }
96
97 160 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t m, size_t n) {
98 160 return m * n * sizeof(float);
99 }
100
101 161 void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(
102 size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed,
103 float* restrict dst, // NOLINT(readability-non-const-parameter)
104 size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
105 KAI_ASSERT(dst_stride_col == sizeof(float));
106
1/2
✓ Branch 0 taken 161 times.
✗ Branch 1 not taken.
107 161 if (m == 0) {
108 return;
109 }
110
111 161 const size_t k_internal = kai_k_roundedup(k);
112
113 161 size_t num_blocks = k_internal / 32;
114
115 161 float clamp_vals[2] = {scalar_min, scalar_max};
116 322 __asm__ __volatile__(
117 "mov x26, #0x20\n"
118 "mov x20, #0x8\n"
119 "movi v5.16b, #0xf0\n"
120 "mov x25, %x[m]\n"
121 "madd x26, %x[num_blocks], x26, x20\n"
122 "1:" // Row loop
123 "mov x24, %x[rhs_packed]\n"
124 "mov x23, %x[n]\n"
125 "add x22, %x[dst], %x[dst_stride_row]\n"
126 "2:" // Column loop
127 "mov x21, %x[lhs_packed]\n"
128 "movi v4.4s, #0x0\n"
129 "movi v3.4s, #0x0\n"
130 "mov x20, %x[num_blocks]\n"
131 "movi v2.4s, #0x0\n"
132 "movi v1.4s, #0x0\n"
133 "3:" // Sub block loop
134 "ldr q0, [x24, #0x0]\n"
135 "ldr q31, [x24, #0x10]\n"
136 "subs x20, x20, #0x1\n"
137 "ldr q30, [x24, #0x20]\n"
138 "ldr q29, [x24, #0x30]\n"
139 "ld1r { v28.2d }, [x21], #0x8\n"
140 "ldr q27, [x24, #0x40]\n"
141 "ldr q26, [x24, #0x50]\n"
142 "ldr q25, [x24, #0x60]\n"
143 "shl v24.16b, v0.16b, #0x4\n"
144 "shl v18.16b, v31.16b, #0x4\n"
145 "ldr q23, [x24, #0x70]\n"
146 "shl v17.16b, v30.16b, #0x4\n"
147 "shl v16.16b, v29.16b, #0x4\n"
148 "add x24, x24, #0x80\n"
149 "ld1r { v22.2d }, [x21], #0x8\n"
150 "shl v21.16b, v27.16b, #0x4\n"
151 "and v0.16b, v0.16b, v5.16b\n"
152 "ld1r { v20.2d }, [x21], #0x8\n"
153 "ld1r { v19.2d }, [x21], #0x8\n"
154 ".inst 0x4e9c9704 // sdot v4.4s, v24.16b, v28.16b\n"
155 ".inst 0x4e9c9643 // sdot v3.4s, v18.16b, v28.16b\n"
156 "shl v18.16b, v26.16b, #0x4\n"
157 ".inst 0x4e9c9622 // sdot v2.4s, v17.16b, v28.16b\n"
158 ".inst 0x4e9c9601 // sdot v1.4s, v16.16b, v28.16b\n"
159 "shl v17.16b, v25.16b, #0x4\n"
160 "shl v16.16b, v23.16b, #0x4\n"
161 "and v31.16b, v31.16b, v5.16b\n"
162 "and v30.16b, v30.16b, v5.16b\n"
163 "and v29.16b, v29.16b, v5.16b\n"
164 ".inst 0x4e9696a4 // sdot v4.4s, v21.16b, v22.16b\n"
165 ".inst 0x4e969643 // sdot v3.4s, v18.16b, v22.16b\n"
166 "and v27.16b, v27.16b, v5.16b\n"
167 ".inst 0x4e969622 // sdot v2.4s, v17.16b, v22.16b\n"
168 ".inst 0x4e969601 // sdot v1.4s, v16.16b, v22.16b\n"
169 "and v26.16b, v26.16b, v5.16b\n"
170 "and v25.16b, v25.16b, v5.16b\n"
171 "and v23.16b, v23.16b, v5.16b\n"
172 ".inst 0x4e949404 // sdot v4.4s, v0.16b, v20.16b\n"
173 ".inst 0x4e9497e3 // sdot v3.4s, v31.16b, v20.16b\n"
174 ".inst 0x4e9497c2 // sdot v2.4s, v30.16b, v20.16b\n"
175 ".inst 0x4e9497a1 // sdot v1.4s, v29.16b, v20.16b\n"
176 ".inst 0x4e939764 // sdot v4.4s, v27.16b, v19.16b\n"
177 ".inst 0x4e939743 // sdot v3.4s, v26.16b, v19.16b\n"
178 ".inst 0x4e939722 // sdot v2.4s, v25.16b, v19.16b\n"
179 ".inst 0x4e9396e1 // sdot v1.4s, v23.16b, v19.16b\n"
180 "bgt 3b\n"
181 "ldr q25, [x24, #0x0]\n"
182 "ldr q24, [x24, #0x10]\n"
183 "addp v4.4s, v4.4s, v3.4s\n"
184 "addp v2.4s, v2.4s, v1.4s\n"
185 "ld1r { v23.4s }, [x21]\n"
186 "ldr q22, [x24, #0x20]\n"
187 "add x21, x21, #0x4\n"
188 "add x20, %x[clamp_vals], #0x4\n"
189 "ld1r { v17.4s }, [x21]\n"
190 "ldr q16, [x24, #0x30]\n"
191 "cmp x23, #0x8\n"
192 "ldr q21, [x24, #0x40]\n"
193 "ldr q20, [x24, #0x50]\n"
194 "add x24, x24, #0x60\n"
195 "ld1r { v19.4s }, [%x[clamp_vals]]\n"
196 "ld1r { v18.4s }, [x20]\n"
197 "mla v4.4s, v25.4s, v23.s[0]\n"
198 "mla v2.4s, v24.4s, v23.s[0]\n"
199 "fmul v22.4s, v22.4s, v17.4s\n"
200 "fmul v16.4s, v16.4s, v17.4s\n"
201 "scvtf v4.4s, v4.4s\n"
202 "scvtf v2.4s, v2.4s\n"
203 "fmul v17.4s, v4.4s, v22.4s\n"
204 "fmul v16.4s, v2.4s, v16.4s\n"
205 "fadd v17.4s, v17.4s, v21.4s\n"
206 "fadd v16.4s, v16.4s, v20.4s\n"
207 "fmax v17.4s, v17.4s, v19.4s\n"
208 "fmax v16.4s, v16.4s, v19.4s\n"
209 "fmin v17.4s, v17.4s, v18.4s\n"
210 "fmin v16.4s, v16.4s, v18.4s\n"
211 "blt 4f\n"
212 "str q17, [%x[dst], #0x0]\n"
213 "str q16, [%x[dst], #0x10]\n"
214 "b 9f\n"
215 "4:" // Partial output
216 "mov x20, %x[dst]\n"
217 "tbz x23, #2, 6f\n"
218 "st1 { v17.4s }, [x20], #0x10\n"
219 "tbz x23, #1, 5f\n"
220 "st1 { v16.d }[0], [x20], #0x8\n"
221 "tbz x23, #0, 8f\n"
222 "st1 { v16.s }[2], [x20]\n"
223 "b 8f\n"
224 "5:" // Output block 0: partial_1_4
225 "tbz x23, #0, 8f\n"
226 "st1 { v16.s }[0], [x20]\n"
227 "b 8f\n"
228 "6:" // Output block 0: partial_2_0
229 "tbz x23, #1, 7f\n"
230 "st1 { v17.d }[0], [x20], #0x8\n"
231 "tbz x23, #0, 8f\n"
232 "st1 { v17.s }[2], [x20]\n"
233 "b 8f\n"
234 "7:" // Output block 0: partial_1_0
235 "st1 { v17.s }[0], [x20]\n"
236 "8:" // Output block 0: Done
237 "9:" // Stores done
238 "subs x23, x23, #0x8\n"
239 "add %x[dst], %x[dst], #0x20\n"
240 "bgt 2b\n"
241 "subs x25, x25, #0x1\n"
242 "add %x[lhs_packed], %x[lhs_packed], x26\n"
243 "mov %x[dst], x22\n"
244 "bgt 1b\n"
245 : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed)
246 161 : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n),
247 161 [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed)
248 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
249 "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26");
250 161 }
251
252 #endif // Architectural features check.
253
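
A closing hedged note on the sub-block loop above: each RHS byte packs two signed 4-bit weights, and the "shl v, #4" / "and v, 0xF0" pair extracts both nibbles as 16 times their value, so they can feed sdot without a separate sign-extension step; the factor of 16 is assumed to be cancelled elsewhere (e.g. folded into the scales during RHS packing). A scalar model of just that idiom:

    /* Scalar model of the nibble extraction in the sub-block loop: (b << 4)
     * moves the low nibble into the top bits, and (b & 0xF0) keeps the high
     * nibble there, so each reads back as 16 * the signed 4-bit weight. */
    #include <assert.h>
    #include <stdint.h>

    static int32_t low_nibble_x16(uint8_t b)  { return (int8_t)(uint8_t)(b << 4); }
    static int32_t high_nibble_x16(uint8_t b) { return (int8_t)(b & 0xF0); }

    int main(void) {
        /* 0x2D packs low nibble 0xD (-3 as signed 4-bit) and high nibble 0x2 (+2). */
        assert(low_nibble_x16(0x2D) == -3 * 16);
        assert(high_nibble_x16(0x2D) == 2 * 16);
        return 0;
    }

Read this way, the epilogue after the loop appears to compute clamp((float)(iacc + lhs_offset * rhs_sums) * lhs_scale * rhs_scales + bias, scalar_min, scalar_max) per output column, matching the mla/scvtf/fmul/fadd/fmax/fmin sequence; treat that as an interpretation of the assembly rather than a specification.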