KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c
Date: 2025-10-20 13:18:31
            Coverage   Exec   Excl   Total
Lines:         97.6%     41      7      49
Functions:    100.0%     14      0      14
Branches:      50.0%      1     14      16
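Note: the percentages are taken over non-excluded items, i.e. Exec / (Total - Excl): lines 41 / (49 - 7) = 97.6%, functions 14 / 14 = 100.0%, branches 1 / (16 - 14) = 50.0%.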

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__ARM_FEATURE_DOTPROD)
8 #error "Dotprod extension required to compile this micro-kernel"
9 #else // Architectural features check.
10 #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"
11
12 #include <arm_neon.h>
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 static const size_t kai_m_step = 1;
19 static const size_t kai_n_step = 4;
20 static const size_t kai_mr = 1;
21 static const size_t kai_nr = 4;
22 static const size_t kai_kr = 16;
23 static const size_t kai_sr = 2;
24 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
25 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
26 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
27 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
28 static const size_t kai_num_bytes_bias = sizeof(float);
29
30 641 inline static size_t kai_k_roundedup(size_t k) {
31 // Since we pack a float and int32 value at the end of the row,
32 // we must make sure that k is a multiple of 4 for alignment
33 641 size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
34 1282 return kai_roundup(k, kr_sr_roundedup4);
35 641 }
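Note: with kai_kr = 16 and kai_sr = 2, the rounding granule is kai_roundup(16 * 2, 4) = 32, so k is padded to the next multiple of 32; e.g. k = 100 rounds up to k_internal = 128.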
36
37 240 inline static size_t kai_lhs_packed_stride(size_t k) {
38 240 const size_t k_internal = kai_k_roundedup(k);
39
40 KAI_ASSERT((k_internal % 2) == 0);
41
42 480 return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
43 240 }
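Note: each packed LHS row (kai_mr = 1) holds k_internal int8 values followed by a 4-byte float multiplier and a 4-byte int32 offset; e.g. k = 100 gives k_internal = 128 and a stride of 1 * (128 + 4 + 4) = 136 bytes.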
44
45 240 inline static size_t kai_rhs_packed_stride(size_t k) {
46 240 const size_t k_internal = kai_k_roundedup(k);
47
48 KAI_ASSERT((k_internal % 2) == 0);
49
50 480 return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
51 240 }
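Note: the 4-bit weights pack two per byte, so each output column carries k_internal / 2 weight bytes plus a 4-byte multiplier, a 4-byte int32 column sum, and a 4-byte bias, times kai_nr = 4 columns; e.g. k = 100 gives 4 * (64 + 4 + 4 + 4) = 304 bytes.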
52
53 320 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
54 320 return kai_m_step;
55 }
56
57 320 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
58 320 return kai_n_step;
59 }
60
61 240 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
62 240 return kai_mr;
63 }
64
65 240 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
66 240 return kai_nr;
67 }
68
69 320 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
70 320 return kai_kr;
71 }
72
73 320 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
74 320 return kai_sr;
75 }
76
77 240 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) {
78 KAI_ASSERT((m_idx % kai_m_step) == 0);
79
80 240 return (m_idx / kai_mr) * kai_lhs_packed_stride(k);
81 }
82
83 240 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k) {
84 KAI_ASSERT((n_idx % kai_n_step) == 0);
85
86 240 return (n_idx / kai_nr) * kai_rhs_packed_stride(k);
87 }
88
89 160 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(
90 size_t m_idx, size_t n_idx, size_t dst_stride) {
91 KAI_ASSERT((m_idx % kai_m_step) == 0);
92 KAI_ASSERT((n_idx % kai_n_step) == 0);
93
94 160 return (n_idx * sizeof(float)) + m_idx * dst_stride;
95 }
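Note: the destination is row-major f32, so the byte offset is n_idx * sizeof(float) + m_idx * dst_stride; e.g. for a dense n = 16 output (dst_stride = 64 bytes), m_idx = 2 and n_idx = 8 land at 8 * 4 + 2 * 64 = 160 bytes.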
96
97 160 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n) {
98 160 return m * n * sizeof(float);
99 }
100
101 161 void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(
102 size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed,
103 float* restrict dst, // NOLINT(readability-non-const-parameter)
104 size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
105 KAI_ASSERT(dst_stride_col == sizeof(float));
106
107 1/2 161 if (m == 0) {
            ✓ Branch 0 taken 161 times.
            ✗ Branch 1 not taken.
108 return;
109 }
110
111 161 const size_t k_internal = kai_k_roundedup(k);
112
113 161 size_t num_blocks = k_internal / 32;
114
115 161 float clamp_vals[2] = {scalar_min, scalar_max};
116 322 __asm__ __volatile__(
117 "mov x26, #0x20\n"
118 "mov x20, #0x8\n"
119 "movi v30.16b, #0xf0\n"
120 "mov x25, %x[m]\n"
121 "madd x26, %x[num_blocks], x26, x20\n"
122 "1:" // Row loop
123 "mov x24, %x[rhs_packed]\n"
124 "mov x23, %x[n]\n"
125 "add x22, %x[dst], %x[dst_stride_row]\n"
126 "2:" // Column loop
127 "mov x21, %x[lhs_packed]\n"
128 "movi v29.4s, #0x0\n"
129 "movi v28.4s, #0x0\n"
130 "mov x20, %x[num_blocks]\n"
131 "3:" // Sub block loop
132 "ldr q27, [x24, #0x0]\n"
133 "ldr q26, [x24, #0x10]\n"
134 "subs x20, x20, #0x1\n"
135 "ld1r { v25.2d }, [x21], #0x8\n"
136 "ldr q24, [x24, #0x20]\n"
137 "ldr q23, [x24, #0x30]\n"
138 "add x24, x24, #0x40\n"
139 "ld1r { v22.2d }, [x21], #0x8\n"
140 "ld1r { v21.2d }, [x21], #0x8\n"
141 "shl v20.16b, v27.16b, #0x4\n"
142 "shl v19.16b, v26.16b, #0x4\n"
143 "ld1r { v18.2d }, [x21], #0x8\n"
144 "shl v17.16b, v24.16b, #0x4\n"
145 "and v27.16b, v27.16b, v30.16b\n"
146 "shl v16.16b, v23.16b, #0x4\n"
147 "and v26.16b, v26.16b, v30.16b\n"
148 ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n"
149 ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n"
150 "and v24.16b, v24.16b, v30.16b\n"
151 "and v23.16b, v23.16b, v30.16b\n"
152 ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n"
153 ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n"
154 ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n"
155 ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n"
156 ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n"
157 ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n"
158 "bgt 3b\n"
159 "ldr q22, [x24, #0x0]\n"
160 "ld1r { v21.4s }, [x21]\n"
161 "addp v29.4s, v29.4s, v28.4s\n"
162 "add x21, x21, #0x4\n"
163 "ld1r { v20.4s }, [x21]\n"
164 "ldr q16, [x24, #0x10]\n"
165 "add x20, %x[clamp_vals], #0x4\n"
166 "cmp x23, #0x4\n"
167 "ldr q19, [x24, #0x20]\n"
168 "ld1r { v18.4s }, [%x[clamp_vals]]\n"
169 "add x24, x24, #0x30\n"
170 "ld1r { v17.4s }, [x20]\n"
171 "mla v29.4s, v22.4s, v21.s[0]\n"
172 "fmul v16.4s, v16.4s, v20.4s\n"
173 "scvtf v29.4s, v29.4s\n"
174 "fmul v16.4s, v29.4s, v16.4s\n"
175 "fadd v16.4s, v16.4s, v19.4s\n"
176 "fmax v16.4s, v16.4s, v18.4s\n"
177 "fmin v16.4s, v16.4s, v17.4s\n"
178 "blt 4f\n"
179 "str q16, [%x[dst], #0x0]\n"
180 "b 7f\n"
181 "4:" // Partial output
182 "mov x20, %x[dst]\n"
183 "tbz x23, #1, 5f\n"
184 "st1 { v16.d }[0], [x20], #0x8\n"
185 "tbz x23, #0, 6f\n"
186 "st1 { v16.s }[2], [x20]\n"
187 "b 6f\n"
188 "5:" // Output block 0: partial_1_0
189 "st1 { v16.s }[0], [x20]\n"
190 "6:" // Output block 0: Done
191 "7:" // Stores done
192 "subs x23, x23, #0x4\n"
193 "add %x[dst], %x[dst], #0x10\n"
194 "bgt 2b\n"
195 "subs x25, x25, #0x1\n"
196 "add %x[lhs_packed], %x[lhs_packed], x26\n"
197 "mov %x[dst], x22\n"
198 "bgt 1b\n"
199 : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed)
200 161 : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n),
201 161 [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed)
202 : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
203 "v29", "v30", "x20", "x21", "x22", "x23", "x24", "x25", "x26");
204 161 }
205
206 #endif // Architectural features check.
207
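
Note: for orientation, a minimal sketch of how a caller might drive the micro-kernel covered above, assuming lhs_packed and rhs_packed were already produced in the packed layouts this kernel expects (the driver function itself is illustrative, not part of the library):

    #include <float.h>
    #include <stddef.h>

    #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"

    // Hypothetical driver: computes the full m x n f32 output in one call.
    // lhs_packed / rhs_packed are assumed to be in the kernel's packed formats.
    static void matmul_f32_sketch(
        size_t m, size_t n, size_t k,
        const void* lhs_packed, const void* rhs_packed, float* dst) {
        const size_t dst_stride_row = n * sizeof(float);  // dense row-major output
        const size_t dst_stride_col = sizeof(float);      // required by the kernel's assert

        // One call covers the whole output: the assembly loops over rows and
        // 4-column blocks internally. Passing the full float range disables clamping.
        kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(
            m, n, k, lhs_packed, rhs_packed, dst,
            dst_stride_row, dst_stride_col, -FLT_MAX, FLT_MAX);
    }

When splitting work across threads instead, the kai_get_*_offset_* helpers above return the byte offsets of a tile's packed operands and destination; per the asserts, m_idx and n_idx must be multiples of m_step (1) and n_step (4).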