KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
Date: 2025-10-20 13:18:31
             Coverage   Exec   Excl   Total
Lines:          97.8%     45     11      57
Functions:     100.0%     16      0      16
Branches:       50.0%      1     22      24
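Note on the summary: line coverage is computed over executable, non-excluded lines, i.e. 45 of 46 (57 total minus the 11 excluded lines, which are exactly the KAI_ASSUME checks that carry no execution counts in the listing below), giving 97.8%. Branch coverage likewise counts 1 of the 2 instrumented outcomes (24 recorded, 22 excluded), i.e. 50.0%: the m == 0 early-return below was never exercised.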

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_DOTPROD)
7 #error "Dotprod extension required to compile this micro-kernel"
8 #else // Architectural features check.
9
10 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16
17 // Compute args
18 static const size_t kai_m_step = 1;
19 static const size_t kai_n_step = 4;
20 // Packing args
21 static const size_t kai_mr = 1;
22 static const size_t kai_nr = 4;
23 static const size_t kai_kr = 8;
24 static const size_t kai_sr = 2;
25 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
26 static const size_t kai_num_bytes_qvalue_lhs = 1;
27 static const size_t kai_num_bytes_multiplier_lhs = 2;
28 // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
29 // asymmetric))
30 static const size_t kai_recip_num_bytes_qvalue_rhs = 2;
31 static const size_t kai_num_bytes_multiplier_rhs = 2;
32 // DST format args
33 static const size_t kai_num_bytes_dst_value = 4;
34 // Extra args
35 static const size_t kai_bl = 32;
36
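With bl = 32, these constants give the per-block packed sizes computed by the helpers below: an LHS block is 32 * 1 + 2 = 34 bytes (32 int8 values plus one fp16 scale), and an RHS block is 32 / 2 + 2 = 18 bytes (32 four-bit values packed two per byte plus one fp16 scale).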
37 26 inline static size_t kai_num_bytes_per_block_lhs(size_t bl) {
38 26 return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs;
39 }
40
41 26 inline static size_t kai_num_bytes_per_block_rhs(size_t bl) {
42 KAI_ASSUME(bl == kai_bl);
43 26 size_t num_bytes_per_block_rhs = (bl / kai_recip_num_bytes_qvalue_rhs) + kai_num_bytes_multiplier_rhs;
44 52 return num_bytes_per_block_rhs;
45 26 }
46
47 56 inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
48 KAI_ASSUME(bl == kai_bl);
49 KAI_ASSUME((k % kai_bl) == 0);
50
51 56 return kai_roundup(k, bl) / bl;
52 }
53
54 26 inline static size_t kai_lhs_packed_stride(size_t k, size_t bl) {
55 26 return kai_mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block_lhs(bl);
56 }
57
58 26 inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) {
59 KAI_ASSUME(bl == kai_bl);
60 KAI_ASSUME((k % kai_bl) == 0);
61
62 26 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
63 26 const size_t num_bytes_per_block = kai_num_bytes_per_block_rhs(bl);
64
65 26 size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row);
66
67 52 return rhs_packed_stride;
68 26 }
69
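As a worked example (illustrative values, not from the report), the stride helpers above reduce to simple arithmetic for a hypothetical k = 256 with bl = 32:

    // Hypothetical k = 256, bl = 32 (not part of the kernel source):
    // num_blocks_per_row    = 256 / 32                = 8
    // kai_lhs_packed_stride = 1 * 8 * (32 * 1 + 2)    = 272 bytes per packed LHS row
    // kai_rhs_packed_stride = 4 * ((32 / 2 + 2) * 8)  = 576 bytes per packed 4-column RHS panel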
70 52 size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(void) {
71 52 return kai_m_step;
72 }
73
74 52 size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(void) {
75 52 return kai_n_step;
76 }
77
78 48 size_t kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(void) {
79 48 return kai_mr;
80 }
81
82 48 size_t kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(void) {
83 48 return kai_nr;
84 }
85
86 72 size_t kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(void) {
87 72 return kai_kr;
88 }
89
90 48 size_t kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(void) {
91 48 return kai_sr;
92 }
93
94 26 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(
95 size_t m_idx, size_t k, size_t bl) {
96 KAI_ASSUME((m_idx % kai_m_step) == 0);
97
98 26 return (m_idx / kai_mr) * kai_lhs_packed_stride(k, bl);
99 }
100
101 26 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(
102 size_t n_idx, size_t k, size_t bl) {
103 KAI_ASSUME((n_idx % kai_n_step) == 0);
104
105 26 return (n_idx / kai_nr) * kai_rhs_packed_stride(k, bl);
106 }
107
108 3 size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(
109 size_t m_idx, size_t n_idx, size_t dst_stride) {
110 KAI_ASSUME((m_idx % kai_m_step) == 0);
111 KAI_ASSUME((n_idx % kai_n_step) == 0);
112
113 3 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
114 }
115
116 3 size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(size_t m, size_t n) {
117 3 return m * n * kai_num_bytes_dst_value;
118 }
119
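For instance (illustrative numbers): with n = 8 and dst_stride = 8 * sizeof(float) = 32 bytes, the tile at m_idx = 2, n_idx = 4 starts at 4 * 4 + 2 * 32 = 80 bytes into dst, and a full 2 x 8 f32 destination occupies 2 * 8 * 4 = 64 bytes.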
120 4 void kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(
121 size_t m, //
122 size_t n, //
123 size_t k, //
124 size_t bl, //
125 const void* restrict lhs_packed, //
126 const void* restrict rhs_packed, //
127 float* restrict dst, // NOLINT(readability-non-const-parameter)
128 size_t dst_stride_row, //
129 size_t dst_stride_col, //
130 float scalar_min, //
131 float scalar_max) {
132 KAI_ASSUME(bl == kai_bl);
133 KAI_ASSUME(dst_stride_col == sizeof(float));
134
135 1/2 4 if (m == 0) {
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
136 return;
137 }
138 4 size_t num_blocks = kai_num_blocks_per_row(k, bl);
139 4 KAI_UNUSED(scalar_min);
140 4 KAI_UNUSED(scalar_max);
141
142 4 __asm__ __volatile__(
143 "mov x26, #0x22\n"
144 "movi v30.16b, #0xf0\n"
145 "mov x25, %x[m]\n"
146 "mul x26, %x[num_blocks], x26\n"
147 "1:" // Row loop
148 "mov x24, %x[rhs_packed]\n"
149 "mov x23, %x[n]\n"
150 "add x22, %x[dst], %x[dst_stride_row]\n"
151 "2:" // Column loop
152 "mov x21, %x[lhs_packed]\n"
153 "movi v29.16b, #0x0\n"
154 "mov x20, %x[num_blocks]\n"
155 "3:" // Block loop
156 "ldr d16, [x24, #0x0]\n"
157 "ld1r { v28.8h }, [x21]\n"
158 "add x24, x24, #0x8\n"
159 "add x21, x21, #0x2\n"
160 "ldr q27, [x24, #0x0]\n"
161 "ldr q26, [x21, #0x0]\n"
162 "movi v25.4s, #0x0\n"
163 "sub x20, x20, #0x1\n"
164 "ldr q24, [x24, #0x10]\n"
165 "ldr q23, [x24, #0x20]\n"
166 "ldr q22, [x24, #0x30]\n"
167 "ldr q21, [x21, #0x10]\n"
168 "fcvtl v28.4s, v28.4h\n"
169 "fcvtl v20.4s, v16.4h\n"
170 "shl v19.16b, v27.16b, #0x4\n"
171 "and v27.16b, v27.16b, v30.16b\n"
172 "add x24, x24, #0x40\n"
173 "add x21, x21, #0x20\n"
174 "shl v18.16b, v24.16b, #0x4\n"
175 "shl v17.16b, v23.16b, #0x4\n"
176 "shl v16.16b, v22.16b, #0x4\n"
177 "and v24.16b, v24.16b, v30.16b\n"
178 ".inst 0x4f9ae279 // sdot v25.4s, v19.16b, v26.4b[0]\n"
179 "and v23.16b, v23.16b, v30.16b\n"
180 "and v22.16b, v22.16b, v30.16b\n"
181 "fmul v20.4s, v20.4s, v28.4s\n"
182 ".inst 0x4fbae259 // sdot v25.4s, v18.16b, v26.4b[1]\n"
183 ".inst 0x4f9aea39 // sdot v25.4s, v17.16b, v26.4b[2]\n"
184 ".inst 0x4fbaea19 // sdot v25.4s, v16.16b, v26.4b[3]\n"
185 ".inst 0x4f95e379 // sdot v25.4s, v27.16b, v21.4b[0]\n"
186 ".inst 0x4fb5e319 // sdot v25.4s, v24.16b, v21.4b[1]\n"
187 ".inst 0x4f95eaf9 // sdot v25.4s, v23.16b, v21.4b[2]\n"
188 ".inst 0x4fb5ead9 // sdot v25.4s, v22.16b, v21.4b[3]\n"
189 "scvtf v25.4s, v25.4s, #0x4\n"
190 "fmla v29.4s, v25.4s, v20.4s\n"
191 "cbnz x20, 3b\n"
192 "cmp x23, #0x4\n"
193 "blt 4f\n"
194 "str q29, [%x[dst], #0x0]\n"
195 "b 7f\n"
196 "4:" // Partial output
197 "mov x20, %x[dst]\n"
198 "tbz x23, #1, 5f\n"
199 "st1 { v29.d }[0], [x20], #0x8\n"
200 "tbz x23, #0, 6f\n"
201 "st1 { v29.s }[2], [x20]\n"
202 "b 6f\n"
203 "5:" // Output block 0: partial_1_0
204 "st1 { v29.s }[0], [x20]\n"
205 "6:" // Output block 0: Done
206 "7:" // Stores done
207 "subs x23, x23, #0x4\n"
208 "add %x[dst], %x[dst], #0x10\n"
209 "bgt 2b\n"
210 "subs x25, x25, #0x1\n"
211 "add %x[lhs_packed], %x[lhs_packed], x26\n"
212 "mov %x[dst], x22\n"
213 "bgt 1b\n"
214 : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed)
215 4 : [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), [num_blocks] "r"(num_blocks),
216 4 [rhs_packed] "r"(rhs_packed)
217 : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
218 "v29", "v30", "x20", "x21", "x22", "x23", "x24", "x25", "x26");
219 4 }
220
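Structurally, the assembly nests three loops: a row loop over m (x25), a column loop over n in steps of 4 (x23, with the partial-output path at labels 4 through 6 storing 2- and 1-float remainders for n % 4), and a per-block loop over k (x20). Each block iteration widens the fp16 LHS and RHS scales to f32 (fcvtl), splits the packed 4-bit RHS values into low and high nibbles (shl #4 and an and with the 0xf0 mask held in v30), accumulates eight sdot products into v25, rescales the integer result with scvtf #0x4 (a divide by 16 that undoes the shl #4), and folds it into the f32 accumulator v29 via fmla against the combined LHS/RHS block scales in v20. The constant 0x22 (34) loaded into x26 is the per-block LHS footprint (32 int8 values plus an fp16 scale), which multiplied by the block count gives the packed LHS row stride.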
221 #endif // Architectural features check.
222
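For context, a typical dispatch of this micro-kernel looks like the sketch below. This is a hedged illustration, not part of the report: it assumes lhs_packed and rhs_packed were produced elsewhere by the matching qsi8d32p1x4 / qsi4c32p4x4 packing routines, and the helper name run_gemv_f32 is hypothetical.

    #include <float.h>
    #include <stddef.h>
    #include <stdint.h>

    #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"

    // Hypothetical driver: walks the output in m_step x n_step tiles and lets the
    // micro-kernel process each tile. Packing of lhs_packed/rhs_packed is assumed done.
    static void run_gemv_f32(
        size_t m, size_t n, size_t k, size_t bl,
        const void* lhs_packed, const void* rhs_packed, float* dst) {
        const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod();  // 1
        const size_t n_step = kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod();  // 4
        const size_t dst_stride_row = n * sizeof(float);

        for (size_t m_idx = 0; m_idx < m; m_idx += m_step) {
            for (size_t n_idx = 0; n_idx < n; n_idx += n_step) {
                // Byte offsets of this tile in the packed operands and the destination.
                const size_t lhs_offset =
                    kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(m_idx, k, bl);
                const size_t rhs_offset =
                    kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(n_idx, k, bl);
                const size_t dst_offset =
                    kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(m_idx, n_idx, dst_stride_row);
                const size_t n_block = (n - n_idx) < n_step ? (n - n_idx) : n_step;  // tail columns

                kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod(
                    m_step, n_block, k, bl,
                    (const uint8_t*)lhs_packed + lhs_offset,
                    (const uint8_t*)rhs_packed + rhs_offset,
                    (float*)((uint8_t*)dst + dst_offset),
                    dst_stride_row, sizeof(float),
                    -FLT_MAX, FLT_MAX);
            }
        }
    }

Splitting work at m_step / n_step boundaries is what makes the offset helpers usable for multithreaded dispatch: each (m_idx, n_idx) tile is independent of the others.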