KleidiAI Coverage Report


Directory: ./
Coverage thresholds: low ≥ 0%, medium ≥ 75.0%, high ≥ 90.0%

            Coverage   Exec / Excl / Total
Lines:         97.6%     41 /    7 /    49
Functions:    100.0%     14 /    0 /    14
Branches:      50.0%      1 /   14 /    16
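(Percentages are computed over non-excluded items: lines 41 / (49 - 7) = 97.6%,
functions 14 / 14 = 100.0%, branches 1 / (16 - 14) = 50.0%. The 50% branch
figure comes from the single instrumented if (m == 0) early-out below, whose
taken path is never exercised.)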

kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if !defined(__ARM_FEATURE_DOTPROD)
8 #error "Dotprod extension required to compile this micro-kernel"
9 #else // Architectural features check.
10 #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"
11
12 #include <arm_neon.h>
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 static const size_t kai_m_step = 1;
19 static const size_t kai_n_step = 4;
20 static const size_t kai_mr = 1;
21 static const size_t kai_nr = 4;
22 static const size_t kai_kr = 16;
23 static const size_t kai_sr = 2;
24 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
25 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
26 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
27 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
28 static const size_t kai_num_bytes_bias = sizeof(float);
29
30 3846 inline static size_t kai_k_roundedup(size_t k) {
31 // Since we pack a float and int32 value at the end of the row,
32 // we must make sure that k is a multiple of 4 for alignment
33 3846 size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
34 7692 return kai_roundup(k, kr_sr_roundedup4);
35 3846 }
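
Worked example of the round-up (a sketch, assuming kai_roundup() from
kai/kai_common.h is the usual round-up-to-a-multiple helper; the stand-in
below is hypothetical):

    #include <stddef.h>
    #include <stdio.h>

    /* Hypothetical stand-in for kai_roundup(a, b): round a up to a multiple of b. */
    static size_t roundup(size_t a, size_t b) {
        return ((a + b - 1) / b) * b;
    }

    int main(void) {
        /* With kai_kr = 16 and kai_sr = 2, the step is roundup(32, 4) = 32,
         * so k is rounded up to the next multiple of 32. */
        const size_t step = roundup(16 * 2, 4);
        printf("%zu %zu %zu\n", roundup(17, step), roundup(32, step), roundup(33, step));
        /* prints: 32 32 64 */
        return 0;
    }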
36
37 1440 inline static size_t kai_lhs_packed_stride(size_t k) {
38 1440 const size_t k_internal = kai_k_roundedup(k);
39
40 KAI_ASSERT((k_internal % 2) == 0);
41
42 2880 return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
43 1440 }
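
For example, with k = 32 (so k_internal = 32) and kai_mr = 1, the packed LHS
stride works out to 1 * (32 * sizeof(int8_t) + 4 + 4) = 40 bytes per row:
32 int8 values followed by the float multiplier and int32 offset noted in
the comment above.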
44
45 1440 inline static size_t kai_rhs_packed_stride(size_t k) {
46 1440 const size_t k_internal = kai_k_roundedup(k);
47
48 KAI_ASSERT((k_internal % 2) == 0);
49
50 2880 return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
51 1440 }
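
Likewise for the RHS: with k = 32 and kai_nr = 4, the stride is
4 * (32 / 2 + 4 + 4 + 4) = 112 bytes per block of four columns, since the
4-bit weights take k_internal / 2 bytes per column and each column also
carries a float multiplier, an int32 sum, and a float bias.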
52
53 1920 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
54 1920 return kai_m_step;
55 }
56
57 1920 size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
58 1920 return kai_n_step;
59 }
60
61 1440 size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
62 1440 return kai_mr;
63 }
64
65 1440 size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
66 1440 return kai_nr;
67 }
68
69 1920 size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
70 1920 return kai_kr;
71 }
72
73 1920 size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) {
74 1920 return kai_sr;
75 }
76
77 1440 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) {
78 KAI_ASSERT((m_idx % kai_m_step) == 0);
79
80 1440 return (m_idx / kai_mr) * kai_lhs_packed_stride(k);
81 }
82
83 1440 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k) {
84 KAI_ASSERT((n_idx % kai_n_step) == 0);
85
86 1440 return (n_idx / kai_nr) * kai_rhs_packed_stride(k);
87 }
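
With kai_mr = 1 and kai_nr = 4, each valid m_idx advances the LHS offset by
one packed-row stride and each valid n_idx (a multiple of 4) advances the
RHS offset by one packed-block stride; e.g. for k = 32, n_idx = 8 gives
(8 / 4) * 112 = 224 bytes.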
88
89 960 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(
90 size_t m_idx, size_t n_idx, size_t dst_stride) {
91 KAI_ASSERT((m_idx % kai_m_step) == 0);
92 KAI_ASSERT((n_idx % kai_n_step) == 0);
93
94 960 return (n_idx * sizeof(float)) + m_idx * dst_stride;
95 }
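
The destination is a plain row-major f32 matrix, so the offset is
n_idx * sizeof(float) plus m_idx full rows; e.g. with n = 8 and
dst_stride = 8 * 4 = 32, (m_idx = 2, n_idx = 4) maps to
4 * 4 + 2 * 32 = 80 bytes.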
96
97 960 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n) {
98 960 return m * n * sizeof(float);
99 }
100
101 966 void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(
102 size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed,
103 float* restrict dst, // NOLINT(readability-non-const-parameter)
104 size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
105 KAI_ASSERT(dst_stride_col == sizeof(float));
106
106
107 1/2 966 if (m == 0) {
    ✓ Branch 0 taken 966 times.
    ✗ Branch 1 not taken.
108 return;
109 }
110
111 966 const size_t k_internal = kai_k_roundedup(k);
112
113 966 size_t num_blocks = k_internal / 32;
114
115 966 float clamp_vals[2] = {scalar_min, scalar_max};
116 1932 __asm__ __volatile__(
117 "mov x26, #0x20\n"
118 "mov x20, #0x8\n"
119 "movi v30.16b, #0xf0\n"
120 "mov x25, %x[m]\n"
121 "madd x26, %x[num_blocks], x26, x20\n"
122 "1:" // Row loop
123 "mov x24, %x[rhs_packed]\n"
124 "mov x23, %x[n]\n"
125 "add x22, %x[dst], %x[dst_stride_row]\n"
126 "2:" // Column loop
127 "mov x21, %x[lhs_packed]\n"
128 "movi v29.4s, #0x0\n"
129 "movi v28.4s, #0x0\n"
130 "mov x20, %x[num_blocks]\n"
131 "3:" // Sub block loop
132 "ldr q27, [x24, #0x0]\n"
133 "ldr q26, [x24, #0x10]\n"
134 "subs x20, x20, #0x1\n"
135 "ld1r { v25.2d }, [x21], #0x8\n"
136 "ldr q24, [x24, #0x20]\n"
137 "ldr q23, [x24, #0x30]\n"
138 "add x24, x24, #0x40\n"
139 "ld1r { v22.2d }, [x21], #0x8\n"
140 "ld1r { v21.2d }, [x21], #0x8\n"
141 "shl v20.16b, v27.16b, #0x4\n"
142 "shl v19.16b, v26.16b, #0x4\n"
143 "ld1r { v18.2d }, [x21], #0x8\n"
144 "shl v17.16b, v24.16b, #0x4\n"
145 "and v27.16b, v27.16b, v30.16b\n"
146 "shl v16.16b, v23.16b, #0x4\n"
147 "and v26.16b, v26.16b, v30.16b\n"
148 ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n"
149 ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n"
150 "and v24.16b, v24.16b, v30.16b\n"
151 "and v23.16b, v23.16b, v30.16b\n"
152 ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n"
153 ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n"
154 ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n"
155 ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n"
156 ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n"
157 ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n"
158 "bgt 3b\n"
159 "ldr q22, [x24, #0x0]\n"
160 "ld1r { v21.4s }, [x21]\n"
161 "addp v29.4s, v29.4s, v28.4s\n"
162 "add x21, x21, #0x4\n"
163 "ld1r { v20.4s }, [x21]\n"
164 "ldr q16, [x24, #0x10]\n"
165 "add x20, %x[clamp_vals], #0x4\n"
166 "cmp x23, #0x4\n"
167 "ldr q19, [x24, #0x20]\n"
168 "ld1r { v18.4s }, [%x[clamp_vals]]\n"
169 "add x24, x24, #0x30\n"
170 "ld1r { v17.4s }, [x20]\n"
171 "mla v29.4s, v22.4s, v21.s[0]\n"
172 "fmul v16.4s, v16.4s, v20.4s\n"
173 "scvtf v29.4s, v29.4s\n"
174 "fmul v16.4s, v29.4s, v16.4s\n"
175 "fadd v16.4s, v16.4s, v19.4s\n"
176 "fmax v16.4s, v16.4s, v18.4s\n"
177 "fmin v16.4s, v16.4s, v17.4s\n"
178 "blt 4f\n"
179 "str q16, [%x[dst], #0x0]\n"
180 "b 7f\n"
181 "4:" // Partial output
182 "mov x20, %x[dst]\n"
183 "tbz x23, #1, 5f\n"
184 "st1 { v16.d }[0], [x20], #0x8\n"
185 "tbz x23, #0, 6f\n"
186 "st1 { v16.s }[2], [x20]\n"
187 "b 6f\n"
188 "5:" // Output block 0: partial_1_0
189 "st1 { v16.s }[0], [x20]\n"
190 "6:" // Output block 0: Done
191 "7:" // Stores done
192 "subs x23, x23, #0x4\n"
193 "add %x[dst], %x[dst], #0x10\n"
194 "bgt 2b\n"
195 "subs x25, x25, #0x1\n"
196 "add %x[lhs_packed], %x[lhs_packed], x26\n"
197 "mov %x[dst], x22\n"
198 "bgt 1b\n"
199 : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed)
200 966 : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n),
201 966 [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed)
202 : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
203 "v29", "v30", "x20", "x21", "x22", "x23", "x24", "x25", "x26");
204 966 }
205
206 #endif // Architectural features check.
207
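
For reference, a minimal calling sketch (a hypothetical driver, not part of
the report): it assumes lhs_packed and rhs_packed were produced by the
matching qai8dxp / qsi4cxp packing routines, and run_tile is an invented
name. The kernel loops over all m rows and n columns internally, so one call
covers the whole output:

    #include <float.h>
    #include <stddef.h>

    #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"

    static void run_tile(
        size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst) {
        /* Row stride of the row-major f32 destination, in bytes. */
        const size_t dst_stride_row = n * sizeof(float);

        /* -FLT_MAX / FLT_MAX disable the clamp; pass tighter bounds to fuse
         * an activation clamp into the matmul. */
        kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(
            m, n, k, lhs_packed, rhs_packed, dst, dst_stride_row, sizeof(float), -FLT_MAX, FLT_MAX);
    }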