KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f16_qsi8d32p_qai4c32p/kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.2% 55 11 67
Functions: 100.0% 16 0 16
Branches: 50.0% 1 22 24

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \
7 !defined(_M_ARM64)
8 #error "Dotprod extension and fp16 vector arithmetic required to compile this micro-kernel"
9 #else // Architectural features check.
10
11 #include "kai_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
// Argument pack handed to the micro-kernel entry point declared below. The entry point
// is not defined in this file, so the number, order and types of these fields are part
// of the contract with it and must not be changed.
// NOTE(review): presumably the kernel reads these fields at fixed offsets — confirm
// against the kernel implementation before editing this struct.
typedef struct {
    void* dst;                // Output buffer base pointer.
    const void* lhs_packed;   // Packed LHS buffer.
    const void* rhs_packed;   // Packed RHS buffer (bias packed alongside it).
    const float* clamp_vals;  // Pointer to two floats: {min, max} clamp bounds.
    size_t dst_stride_row;    // Distance in bytes between consecutive output rows.
    size_t m;                 // Number of output rows.
    size_t n;                 // Number of output columns.
    size_t num_blocks;        // Quantization blocks per row along K.
    size_t num_subblocks;     // kai_bl-sized sub-blocks per quantization block (bl / kai_bl).
} KernelArgs;

// Micro-kernel entry point; its definition lives outside this file.
void kai_kernel_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(KernelArgs* args_ptr);

// Output tile produced per kernel step (rows x columns).
static const size_t kai_m_step = 1;
static const size_t kai_n_step = 4;

// Packing parameters the LHS/RHS packing routines must be configured with.
static const size_t kai_mr = 1;
static const size_t kai_nr = 4;
static const size_t kai_kr = 16;
static const size_t kai_sr = 2;

// Packed LHS block layout, in bytes: per quantized value, per-block multiplier, per-block sum.
static const size_t kai_num_bytes_qvalue_lhs = 1;
static const size_t kai_num_bytes_multiplier_lhs = 4;
static const size_t kai_num_bytes_sum_lhs = 4;

// Packed RHS block layout. The first constant is a reciprocal: the number of quantized
// values stored per byte (bl values occupy bl / 2 bytes), not a byte count.
static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
static const size_t kai_num_bytes_multiplier_rhs = 4;
static const size_t kai_num_bytes_offset_rhs = 4;

// Output element size in bytes.
static const size_t kai_num_bytes_dst_value = 2;

// Bias bytes stored per output column together with the packed RHS.
static const size_t kai_num_bytes_bias = 4;
// Base block length along K; the caller's bl must be a multiple of it.
static const size_t kai_bl = 32;
55
56 138 inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) {
57 138 return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs;
58 }
59
60 138 inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
61 KAI_ASSUME((bl % kai_bl) == 0);
62 276 size_t num_bytes_per_block_rhs =
63 138 (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs;
64 276 return num_bytes_per_block_rhs;
65 138 }
66
67 415 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
68 KAI_ASSUME((bl % kai_bl) == 0);
69
70 415 return kai_roundup(k, bl) / bl;
71 }
72
73 138 inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) {
74 138 return kai_mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl);
75 }
76
77 138 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
78 KAI_ASSUME((bl % kai_bl) == 0);
79
80 138 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
81 138 const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
82
83 138 size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row);
84 // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
85 138 rhs_packed_stride += kai_nr * kai_num_bytes_bias;
86
87 276 return rhs_packed_stride;
88 138 }
89
90 714 size_t kai_get_m_step_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(void) {
91 714 return kai_m_step;
92 }
93
94 714 size_t kai_get_n_step_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(void) {
95 714 return kai_n_step;
96 }
97
98 1036 size_t kai_get_mr_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(void) {
99 1036 return kai_mr;
100 }
101
102 1036 size_t kai_get_nr_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(void) {
103 1036 return kai_nr;
104 }
105
106 1036 size_t kai_get_kr_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(void) {
107 1036 return kai_kr;
108 }
109
110 1036 size_t kai_get_sr_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(void) {
111 1036 return kai_sr;
112 }
113
114 138 size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(
115 size_t m_idx, size_t k, size_t bl) {
116 KAI_ASSUME((m_idx % kai_m_step) == 0);
117
118 138 return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k, bl);
119 }
120
121 138 size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(
122 size_t n_idx, size_t k, size_t bl) {
123 KAI_ASSUME((k % bl) == 0);
124 KAI_ASSUME((n_idx % kai_n_step) == 0);
125
126 138 return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl);
127 }
128
129 138 size_t kai_get_dst_offset_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(
130 size_t m_idx, size_t n_idx, size_t dst_stride) {
131 KAI_ASSUME((m_idx % kai_m_step) == 0);
132 KAI_ASSUME((n_idx % kai_n_step) == 0);
133
134 138 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
135 }
136
137 138 size_t kai_get_dst_size_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(size_t m, size_t n) {
138 138 return m * n * kai_num_bytes_dst_value;
139 }
140
141 139 void kai_run_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(
142 size_t m, //
143 size_t n, //
144 size_t k, //
145 size_t bl, //
146 const void* restrict lhs_packed, //
147 const void* restrict rhs_packed, //
148 void* restrict dst, // NOLINT(readability-non-const-parameter)
149 size_t dst_stride_row, //
150 size_t dst_stride_col, //
151 float scalar_min, //
152 float scalar_max) {
153 KAI_ASSUME(dst_stride_col == sizeof(uint16_t));
154 KAI_ASSUME((k % bl) == 0);
155 KAI_ASSUME((bl % kai_bl) == 0);
156
157
1/2
✓ Branch 0 taken 139 times.
✗ Branch 1 not taken.
139 if (m == 0) {
158 return;
159 }
160 139 const size_t num_subblocks = bl / kai_bl;
161 139 const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);
162 139 const float clamp_vals[2] = {scalar_min, scalar_max};
163
164 139 KernelArgs args;
165
166 139 args.dst = dst;
167 139 args.lhs_packed = lhs_packed;
168 139 args.rhs_packed = rhs_packed;
169 139 args.clamp_vals = clamp_vals;
170 139 args.dst_stride_row = dst_stride_row;
171 139 args.m = m;
172 139 args.n = n;
173 139 args.num_blocks = num_blocks;
174 139 args.num_subblocks = num_subblocks;
175
176 139 kai_kernel_matmul_clamp_f16_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod(&args);
177 139 }
178
179 #endif // Architectural features check.
180