KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.c
Date: 2025-10-20 13:18:31
             Coverage   Exec   Excl   Total
Lines:          97.8%     45      5      51
Functions:     100.0%     14      0      14
Branches:       50.0%      1     10      12

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if !defined(__aarch64__) || !defined(__ARM_FEATURE_DOTPROD)
7 #error "Dotprod extension required to compile this micro-kernel"
8 #else // Architectural features check.
9
10 #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16
17 // Compute args
18 static const size_t kai_m_step = 1;
19 static const size_t kai_n_step = 4;
20 // Packing args
21 static const size_t kai_mr = 1;
22 static const size_t kai_nr = 4;
23 static const size_t kai_kr = 8;
24 static const size_t kai_sr = 2;
25 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
26 static const size_t kai_num_bytes_qvalue_lhs = 1;
27 static const size_t kai_num_bytes_multiplier_lhs = 4;
28 static const size_t kai_num_bytes_zp_lhs = 4;
29 // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
30 // asymmetric))
31 static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
32 static const size_t kai_num_bytes_multiplier_rhs = 4;
33 static const size_t kai_num_bytes_rsum_rhs = 4;
34 // DST format args
35 static const size_t kai_num_bytes_dst_value = 4;
36 // Extra args
37 static const size_t kai_num_bytes_bias = 4;
38 static const size_t kai_k_multiple_of = 32;
39 static const size_t kai_bl = 32;
40
41 641 inline static size_t kai_get_k_roundedup(size_t k) {
42 641 return kai_roundup(k, kai_k_multiple_of);
43 }
44
45 240 inline static size_t kai_get_lhs_packed_stride(size_t k) {
46 240 const size_t k_internal = kai_get_k_roundedup(k);
47 240 size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
48 // Since the LHS matrix is asymmetric with per-row quantization, we must include
49 // the number of bytes to hold the zero point value
50 240 lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
51
52 480 return lhs_packed_stride;
53 240 }
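
Note: as a worked example of the LHS stride above (k = 100 chosen purely for
illustration): k_internal = kai_roundup(100, 32) = 128, so

    lhs_packed_stride = 1 * (128 * 1 + 4) + 1 * 4 = 136 bytes per packed row

i.e. 128 quantized int8 values, one f32 multiplier, and one f32 zero point,
with kai_mr = 1.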
54
55 240 inline static size_t kai_get_rhs_packed_stride(size_t k) {
56 240 const size_t k_internal = kai_get_k_roundedup(k);
57 240 size_t rhs_packed_stride = kai_nr * (k_internal / kai_num_bytes_recip_qvalue_rhs);
58
59 240 rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs;
60 // Since the LHS matrix is quantized asymmetrically with per-row quantization, we also include
61 // the number of bytes for the reduction sum
62 240 rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
63 // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
64 240 rhs_packed_stride += kai_nr * kai_num_bytes_bias;
65
66 480 return rhs_packed_stride;
67 240 }
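
Note: the same k = 100 example for the RHS stride: k_internal = 128, and with
kai_num_bytes_recip_qvalue_rhs = 2 (two 4-bit values packed per byte) the
stride is

    rhs_packed_stride = 4 * (128 / 2) + 4 * 4 + 4 * 4 + 4 * 4 = 304 bytes

covering 4 rows of packed int4 data plus per-row f32 multipliers, reduction
sums, and biases (kai_nr = 4).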
68
69 320 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) {
70 320 return kai_m_step;
71 }
72
73 320 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) {
74 320 return kai_n_step;
75 }
76
77 240 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) {
78 240 return kai_mr;
79 }
80
81 240 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) {
82 240 return kai_nr;
83 }
84
85 320 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) {
86 320 return kai_kr;
87 }
88
89 320 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(void) {
90 320 return kai_sr;
91 }
92
93 240 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(size_t m_idx, size_t k) {
94 KAI_ASSUME((m_idx % kai_m_step) == 0);
95
96 240 return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
97 }
98
99 240 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(size_t n_idx, size_t k) {
100 KAI_ASSUME((n_idx % kai_n_step) == 0);
101
102 240 return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k);
103 }
104
105 160 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(
106 size_t m_idx, size_t n_idx, size_t dst_stride) {
107 KAI_ASSUME((m_idx % kai_m_step) == 0);
108 KAI_ASSUME((n_idx % kai_n_step) == 0);
109
110 160 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
111 }
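
Note: a worked example of the destination offset (values chosen for
illustration): with n = 16 the caller's row stride is dst_stride =
16 * sizeof(float) = 64 bytes, so the tile at (m_idx = 2, n_idx = 8) starts
at 8 * 4 + 2 * 64 = 160 bytes into dst.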
112
113 160 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(size_t m, size_t n) {
114 160 return m * n * kai_num_bytes_dst_value;
115 }
116
117 161 void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(
118 size_t m, //
119 size_t n, //
120 size_t k, //
121 const void* restrict lhs_packed, //
122 const void* restrict rhs_packed, //
123 float* restrict dst, // NOLINT(readability-non-const-parameter)
124 size_t dst_stride_row, //
125 size_t dst_stride_col, //
126 float scalar_min, //
127 float scalar_max) {
128 KAI_ASSUME(dst_stride_col == sizeof(float));
129
130 1/2 161 if (m == 0) {
✓ Branch 0 taken 161 times.
✗ Branch 1 not taken.
131 return;
132 }
133
134 161 const size_t k_internal = kai_get_k_roundedup(k);
135 161 size_t num_blocks = k_internal / kai_bl;
136 161 float clamp_vals[2] = {scalar_min, scalar_max};
137
138 322 __asm__ __volatile__(
139 "mov x26, #0x20\n"
140 "mov x20, #0x8\n"
141 "movi v27.16b, #0xf0\n"
142 "mov x25, %x[m]\n"
143 "madd x26, %x[num_blocks], x26, x20\n"
144 "1:" // Row loop
145 "mov x24, %x[rhs_packed]\n"
146 "mov x23, %x[n]\n"
147 "add x22, %x[dst], %x[dst_stride_row]\n"
148 "2:" // Column loop
149 "mov x21, %x[lhs_packed]\n"
150 "movi v26.4s, #0x0\n"
151 "mov x20, %x[num_blocks]\n"
152 "3:" // Sub block loop
153 "ldr q25, [x24, #0x0]\n"
154 "ldr q24, [x21, #0x0]\n"
155 "subs x20, x20, #0x1\n"
156 "ldr q23, [x24, #0x10]\n"
157 "ldr q22, [x24, #0x20]\n"
158 "ldr q21, [x24, #0x30]\n"
159 "ldr q20, [x21, #0x10]\n"
160 "add x24, x24, #0x40\n"
161 "add x21, x21, #0x20\n"
162 "shl v19.16b, v25.16b, #0x4\n"
163 "and v25.16b, v25.16b, v27.16b\n"
164 "shl v18.16b, v23.16b, #0x4\n"
165 "shl v17.16b, v22.16b, #0x4\n"
166 "shl v16.16b, v21.16b, #0x4\n"
167 "and v23.16b, v23.16b, v27.16b\n"
168 ".inst 0x4f98e27a // sdot v26.4s, v19.16b, v24.4b[0]\n"
169 "and v22.16b, v22.16b, v27.16b\n"
170 "and v21.16b, v21.16b, v27.16b\n"
171 ".inst 0x4fb8e25a // sdot v26.4s, v18.16b, v24.4b[1]\n"
172 ".inst 0x4f98ea3a // sdot v26.4s, v17.16b, v24.4b[2]\n"
173 ".inst 0x4fb8ea1a // sdot v26.4s, v16.16b, v24.4b[3]\n"
174 ".inst 0x4f94e33a // sdot v26.4s, v25.16b, v20.4b[0]\n"
175 ".inst 0x4fb4e2fa // sdot v26.4s, v23.16b, v20.4b[1]\n"
176 ".inst 0x4f94eada // sdot v26.4s, v22.16b, v20.4b[2]\n"
177 ".inst 0x4fb4eaba // sdot v26.4s, v21.16b, v20.4b[3]\n"
178 "bgt 3b\n"
179 "ldr q22, [x24, #0x0]\n"
180 "ld1r { v21.4s }, [x21]\n"
181 "add x21, x21, #0x4\n"
182 "add x20, %x[clamp_vals], #0x4\n"
183 "ld1r { v20.4s }, [x21]\n"
184 "ldr q16, [x24, #0x10]\n"
185 "cmp x23, #0x4\n"
186 "ldr q19, [x24, #0x20]\n"
187 "ld1r { v18.4s }, [%x[clamp_vals]]\n"
188 "add x24, x24, #0x30\n"
189 "ld1r { v17.4s }, [x20]\n"
190 "mla v26.4s, v22.4s, v21.s[0]\n"
191 "fmul v16.4s, v16.4s, v20.4s\n"
192 "scvtf v26.4s, v26.4s\n"
193 "fmul v16.4s, v26.4s, v16.4s\n"
194 "fadd v16.4s, v16.4s, v19.4s\n"
195 "fmax v16.4s, v16.4s, v18.4s\n"
196 "fmin v16.4s, v16.4s, v17.4s\n"
197 "blt 4f\n"
198 "str q16, [%x[dst], #0x0]\n"
199 "b 7f\n"
200 "4:" // Partial output
201 "mov x20, %x[dst]\n"
202 "tbz x23, #1, 5f\n"
203 "st1 { v16.d }[0], [x20], #0x8\n"
204 "tbz x23, #0, 6f\n"
205 "st1 { v16.s }[2], [x20]\n"
206 "b 6f\n"
207 "5:" // Output block 0: partial_1_0
208 "st1 { v16.s }[0], [x20]\n"
209 "6:" // Output block 0: Done
210 "7:" // Stores done
211 "subs x23, x23, #0x4\n"
212 "add %x[dst], %x[dst], #0x10\n"
213 "bgt 2b\n"
214 "subs x25, x25, #0x1\n"
215 "add %x[lhs_packed], %x[lhs_packed], x26\n"
216 "mov %x[dst], x22\n"
217 "bgt 1b\n"
218 : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed)
219 161 : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n),
220 161 [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed)
221 : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "x20",
222 "x21", "x22", "x23", "x24", "x25", "x26");
223 161 }
224
225 #endif // Architectural features check.
226
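
Appendix: a minimal caller sketch for this micro-kernel, assuming the packed
LHS/RHS buffers come from the matching qai8dxp/qsi4cxp packing routines (which
are outside this report) and that no clamping is wanted. The getter names are
the ones exercised above; run_full_matmul is a hypothetical wrapper.

#include <float.h>
#include <stdint.h>

#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod.h"

static void run_full_matmul(
    size_t m, size_t n, size_t k,                    // operand shapes
    const void* lhs_packed, const void* rhs_packed,  // packed operands
    float* dst) {                                    // m x n row-major output
    const size_t dst_stride_row = n * sizeof(float);

    // This variant computes the whole m x n output in one call; the offset
    // helpers (all zero here) matter when the work is split across threads.
    const size_t lhs_offset =
        kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(0, k);
    const size_t rhs_offset =
        kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(0, k);
    const size_t dst_offset =
        kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(0, 0, dst_stride_row);

    kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod(
        m, n, k,
        (const void*)((const uint8_t*)lhs_packed + lhs_offset),
        (const void*)((const uint8_t*)rhs_packed + rhs_offset),
        (float*)((uint8_t*)dst + dst_offset),
        dst_stride_row,
        sizeof(float),       // dst_stride_col: asserted to be sizeof(float)
        -FLT_MAX, FLT_MAX);  // scalar_min/scalar_max: effectively no clamp
}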