KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 97.8% 45 5 51
Functions: 100.0% 14 0 14
Branches: 50.0% 1 10 12

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD)
7 #error "Dotprod extension required to compile this micro-kernel"
8 #else // Architectural features check.
9
10 #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16
17 // Compute args
18 static const size_t kai_m_step = 1;
19 static const size_t kai_n_step = 4;
20 // Packing args
21 static const size_t kai_mr = 1;
22 static const size_t kai_nr = 4;
23 static const size_t kai_kr = 4;
24 static const size_t kai_sr = 1;
25 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
26 static const size_t kai_num_bytes_qvalue_lhs = 1;
27 static const size_t kai_num_bytes_multiplier_lhs = 4;
28 static const size_t kai_num_bytes_zp_lhs = 4;
29 // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
30 // asymmetric))
31 static const size_t kai_num_bytes_qvalue_rhs = 1;
32 static const size_t kai_num_bytes_multiplier_rhs = 4;
33 static const size_t kai_num_bytes_rsum_rhs = 4;
34 // DST format args
35 static const size_t kai_num_bytes_dst_value = 4;
36 // Extra args
37 static const size_t kai_num_bytes_bias = 4;
38 static const size_t kai_k_multiple_of = 32;
39
40 557 inline static size_t kai_k_roundedup(size_t k) {
41 557 return kai_roundup(k, kai_k_multiple_of);
42 }
43
44 211 inline static size_t kai_lhs_packed_stride(size_t k) {
45 211 const size_t k_internal = kai_k_roundedup(k);
46 211 size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
47 // Since the LHS matrix is asymmetric with per-row quantization, we must include
48 // the number of bytes to hold the zero point value
49 211 lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
50
51 422 return lhs_packed_stride;
52 211 }
53
54 211 inline static size_t kai_rhs_packed_stride(size_t k) {
55 211 const size_t k_internal = kai_k_roundedup(k);
56 211 size_t rhs_packed_stride = kai_nr * (k_internal * kai_num_bytes_qvalue_rhs);
57 211 rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs;
58 // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
59 // the number of bytes for the reduction sum
60 211 rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
61 // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
62 211 rhs_packed_stride += kai_nr * kai_num_bytes_bias;
63
64 422 return rhs_packed_stride;
65 211 }
66
67 231 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) {
68 231 return kai_m_step;
69 }
70
71 231 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) {
72 231 return kai_n_step;
73 }
74
75 231 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) {
76 231 return kai_mr;
77 }
78
79 231 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) {
80 231 return kai_nr;
81 }
82
83 308 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) {
84 308 return kai_kr;
85 }
86
87 308 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(void) {
88 308 return kai_sr;
89 }
90
91 211 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(size_t m_idx, size_t k) {
92 KAI_ASSUME((m_idx % kai_m_step) == 0);
93
94 211 return (m_idx / kai_mr) * kai_lhs_packed_stride(k);
95 }
96
97 211 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(size_t n_idx, size_t k) {
98 KAI_ASSUME((n_idx % kai_n_step) == 0);
99
100 211 return (n_idx / kai_nr) * kai_rhs_packed_stride(k);
101 }
102
103 134 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(
104 size_t m_idx, size_t n_idx, size_t dst_stride) {
105 KAI_ASSUME((m_idx % kai_m_step) == 0);
106 KAI_ASSUME((n_idx % kai_n_step) == 0);
107
108 134 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
109 }
110
111 134 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(size_t m, size_t n) {
112 134 return m * n * kai_num_bytes_dst_value;
113 }
114
115 135 void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod(
116 size_t m, //
117 size_t n, //
118 size_t k, //
119 const void* restrict lhs_packed, //
120 const void* restrict rhs_packed, //
121 float* restrict dst, // NOLINT(readability-non-const-parameter)
122 size_t dst_stride_row, //
123 size_t dst_stride_col, //
124 float scalar_min, //
125 float scalar_max) {
126 KAI_ASSUME(dst_stride_col == sizeof(float));
127
128
1/2
✓ Branch 0 taken 135 times.
✗ Branch 1 not taken.
135 if (m == 0) {
129 return;
130 }
131
132 135 const size_t k_internal = kai_k_roundedup(k);
133 135 const size_t num_blocks = k_internal / kai_k_multiple_of;
134 135 const float clamp_vals[2] = {scalar_min, scalar_max};
135
136 270 __asm__ __volatile__(
137 "mov x26, #0x20\n"
138 "mov x20, #0x8\n"
139 "mov x25, %x[m]\n"
140 "madd x26, %x[num_blocks], x26, x20\n"
141 "1:" // Row loop
142 "mov x24, %x[rhs_packed]\n"
143 "mov x23, %x[n]\n"
144 "add x22, %x[dst], %x[dst_stride_row]\n"
145 "2:" // Column loop
146 "mov x21, %x[lhs_packed]\n"
147 "movi v25.4s, #0x0\n"
148 "mov x20, %x[num_blocks]\n"
149 "3:" // Sub block loop
150 "ldr q16, [x24, #0x0]\n"
151 "ldr q24, [x21, #0x0]\n"
152 "subs x20, x20, #0x1\n"
153 "ldr q23, [x24, #0x10]\n"
154 "ldr q22, [x24, #0x20]\n"
155 "ldr q21, [x24, #0x30]\n"
156 "ldr q20, [x24, #0x40]\n"
157 "ldr q19, [x21, #0x10]\n"
158 "ldr q18, [x24, #0x50]\n"
159 ".inst 0x4f98e219 // sdot v25.4s, v16.16b, v24.4b[0]\n"
160 "add x21, x21, #0x20\n"
161 "ldr q17, [x24, #0x60]\n"
162 "ldr q16, [x24, #0x70]\n"
163 "add x24, x24, #0x80\n"
164 ".inst 0x4fb8e2f9 // sdot v25.4s, v23.16b, v24.4b[1]\n"
165 ".inst 0x4f98ead9 // sdot v25.4s, v22.16b, v24.4b[2]\n"
166 ".inst 0x4fb8eab9 // sdot v25.4s, v21.16b, v24.4b[3]\n"
167 ".inst 0x4f93e299 // sdot v25.4s, v20.16b, v19.4b[0]\n"
168 ".inst 0x4fb3e259 // sdot v25.4s, v18.16b, v19.4b[1]\n"
169 ".inst 0x4f93ea39 // sdot v25.4s, v17.16b, v19.4b[2]\n"
170 ".inst 0x4fb3ea19 // sdot v25.4s, v16.16b, v19.4b[3]\n"
171 "bgt 3b\n"
172 "ldr q22, [x24, #0x0]\n"
173 "ld1r { v21.4s }, [x21]\n"
174 "add x21, x21, #0x4\n"
175 "add x20, %x[clamp_vals], #0x4\n"
176 "ld1r { v20.4s }, [x21]\n"
177 "ldr q16, [x24, #0x10]\n"
178 "cmp x23, #0x4\n"
179 "ldr q19, [x24, #0x20]\n"
180 "ld1r { v18.4s }, [%x[clamp_vals]]\n"
181 "add x24, x24, #0x30\n"
182 "ld1r { v17.4s }, [x20]\n"
183 "mla v25.4s, v22.4s, v21.s[0]\n"
184 "fmul v16.4s, v16.4s, v20.4s\n"
185 "scvtf v25.4s, v25.4s\n"
186 "fmul v16.4s, v25.4s, v16.4s\n"
187 "fadd v16.4s, v16.4s, v19.4s\n"
188 "fmax v16.4s, v16.4s, v18.4s\n"
189 "fmin v16.4s, v16.4s, v17.4s\n"
190 "blt 4f\n"
191 "str q16, [%x[dst], #0x0]\n"
192 "b 7f\n"
193 "4:" // Partial output
194 "mov x20, %x[dst]\n"
195 "tbz x23, #1, 5f\n"
196 "st1 { v16.d }[0], [x20], #0x8\n"
197 "tbz x23, #0, 6f\n"
198 "st1 { v16.s }[2], [x20]\n"
199 "b 6f\n"
200 "5:" // Output block 0: partial_1_0
201 "st1 { v16.s }[0], [x20]\n"
202 "6:" // Output block 0: Done
203 "7:" // Stores done
204 "subs x23, x23, #0x4\n"
205 "add %x[dst], %x[dst], #0x10\n"
206 "bgt 2b\n"
207 "subs x25, x25, #0x1\n"
208 "add %x[lhs_packed], %x[lhs_packed], x26\n"
209 "mov %x[dst], x22\n"
210 "bgt 1b\n"
211 : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed)
212 135 : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n),
213 135 [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed)
214 : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "x20", "x21", "x22",
215 "x23", "x24", "x25", "x26");
216 135 }
217
218 #endif // Architectural features check.
219