KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi4cxp/kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.1% 52 5 58
Functions: 100.0% 14 0 14
Branches: 50.0% 1 10 12

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
// Architectural features check: refuse to compile this translation unit on targets that define none of the
// macros listed below.
// NOTE(review): the conditions are joined with &&, so the #error only fires when *none* of __aarch64__,
// __ARM_FEATURE_DOTPROD, __ARM_FEATURE_FP16_VECTOR_ARITHMETIC and _M_ARM64 is defined. For example, an AArch64
// build without +dotprod/+fp16 still takes the #else branch. The error message suggests both features are
// required — confirm that the looser check is intended.
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \
    !defined(_M_ARM64)
#error "Dotprod extension and fp16 vector arithmetic required to compile this micro-kernel"
#else // Architectural features check.
10
11 #include "kai_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
// Argument block handed to the micro-kernel entry point declared below. The kernel receives a single pointer to
// this struct instead of individual arguments.
// NOTE(review): the kernel is defined outside this file (presumably an assembly routine that reads the fields by
// offset). If so, do not reorder, add or retype fields without updating the kernel — TODO confirm.
typedef struct {
    uint16_t* dst;            // Output buffer of 16-bit elements (presumably fp16 bit patterns, per the kernel name).
    const void* lhs_packed;   // Packed LHS data; opaque to this file.
    const void* rhs_packed;   // Packed RHS data; opaque to this file.
    const float* clamp_vals;  // Clamp bounds; kai_run points this at a {scalar_min, scalar_max} array.
    size_t dst_stride_row;    // Row stride of dst, as passed to kai_run.
    size_t m;                 // Number of LHS/output rows, as passed to kai_run.
    size_t n;                 // Number of RHS/output columns, as passed to kai_run.
    size_t num_blocks;        // Set by kai_run to (k rounded up to kai_k_multiple_of) / kai_bl.
} KernelArgs;

// Micro-kernel entry point. Its definition is not in this file; kai_run calls it with a filled-in KernelArgs.
void kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(KernelArgs* args_ptr);
30
// Compute args: granularity that m_idx / n_idx must respect in the offset getters (checked with KAI_ASSUME).
static const size_t kai_m_step = 1;
static const size_t kai_n_step = 4;
// Packing args: reported to callers through the getters. kai_mr and kai_nr are the number of rows / columns per
// packed LHS / RHS block and scale the packed strides. kai_kr and kai_sr are not used in this file's own
// arithmetic (presumably consumed by the packing routines — confirm).
static const size_t kai_mr = 1;
static const size_t kai_nr = 4;
static const size_t kai_kr = 16;
static const size_t kai_sr = 2;
// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
static const size_t kai_num_bytes_qvalue_lhs = 1;
static const size_t kai_num_bytes_multiplier_lhs = 4;
static const size_t kai_num_bytes_zp_lhs = 4;
// RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
// asymmetric)).
// kai_num_bytes_recip_qvalue_rhs is the *reciprocal* of bytes per value: k values occupy k / 2 bytes, i.e. two
// 4-bit values per byte.
static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
static const size_t kai_num_bytes_multiplier_rhs = 4;
static const size_t kai_num_bytes_rsum_rhs = 4;
// DST format args: bytes per output element (matches the sizeof(uint16_t) column stride required by kai_run).
static const size_t kai_num_bytes_dst_value = 2;
// Extra args: bytes of bias per RHS column (packed together with the RHS), and the multiple that k is rounded up
// to before any stride or block count is computed.
static const size_t kai_num_bytes_bias = 4;
static const size_t kai_k_multiple_of = 32;

// Divisor applied to the rounded-up k to obtain the num_blocks value passed to the kernel.
static const size_t kai_bl = 32;
55
56 175 inline static size_t kai_get_k_roundedup(size_t k) {
57 175 return kai_roundup(k, kai_k_multiple_of);
58 }
59
60 58 inline static size_t kai_get_lhs_packed_stride(size_t k) {
61 58 const size_t k_internal = kai_get_k_roundedup(k);
62 58 size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
63 // Since the LHS matrix is asymmetric with per-row quantization, we must include the
64 // the number of bytes to hold the zero point value
65 58 lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
66
67 116 return lhs_packed_stride;
68 58 }
69
70 58 inline static size_t kai_get_rhs_packed_stride(size_t k) {
71 58 const size_t k_internal = kai_get_k_roundedup(k);
72 58 size_t rhs_packed_stride = kai_nr * (k_internal / kai_num_bytes_recip_qvalue_rhs);
73
74 58 rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs;
75 // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
76 // the number of bytes for the reduction sum
77 58 rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
78 // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
79 58 rhs_packed_stride += kai_nr * kai_num_bytes_bias;
80
81 116 return rhs_packed_stride;
82 58 }
83
84 84 size_t kai_get_m_step_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) {
85 84 return kai_m_step;
86 }
87
88 84 size_t kai_get_n_step_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) {
89 84 return kai_n_step;
90 }
91
92 224 size_t kai_get_mr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) {
93 224 return kai_mr;
94 }
95
96 224 size_t kai_get_nr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) {
97 224 return kai_nr;
98 }
99
100 224 size_t kai_get_kr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) {
101 224 return kai_kr;
102 }
103
104 224 size_t kai_get_sr_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(void) {
105 224 return kai_sr;
106 }
107
108 58 size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(size_t m_idx, size_t k) {
109 KAI_ASSUME((m_idx % kai_m_step) == 0);
110
111 58 return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
112 }
113
114 58 size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(size_t n_idx, size_t k) {
115 KAI_ASSUME((n_idx % kai_n_step) == 0);
116
117 58 return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k);
118 }
119
120 58 size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(
121 size_t m_idx, size_t n_idx, size_t dst_stride) {
122 KAI_ASSUME((m_idx % kai_m_step) == 0);
123 KAI_ASSUME((n_idx % kai_n_step) == 0);
124
125 58 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
126 }
127
128 58 size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(size_t m, size_t n) {
129 58 return m * n * kai_num_bytes_dst_value;
130 }
131
132 59 void kai_run_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(
133 size_t m, //
134 size_t n, //
135 size_t k, //
136 const void* restrict lhs_packed, //
137 const void* restrict rhs_packed, //
138 void* restrict dst, // NOLINT(readability-non-const-parameter)
139 size_t dst_stride_row, //
140 size_t dst_stride_col, //
141 float scalar_min, //
142 float scalar_max) {
143 KAI_ASSUME(dst_stride_col == sizeof(uint16_t));
144
1/2
✓ Branch 0 taken 59 times.
✗ Branch 1 not taken.
59 if (m == 0) {
145 return;
146 }
147 59 const size_t k_internal = kai_get_k_roundedup(k);
148 59 size_t num_blocks = k_internal / kai_bl;
149 59 const float clamp_vals[2] = {scalar_min, scalar_max};
150
151 59 KernelArgs args;
152
153 59 args.dst = dst;
154 59 args.lhs_packed = lhs_packed;
155 59 args.rhs_packed = rhs_packed;
156 59 args.clamp_vals = clamp_vals;
157 59 args.dst_stride_row = dst_stride_row;
158 59 args.m = m;
159 59 args.n = n;
160 59 args.num_blocks = num_blocks;
161
162 59 kai_kernel_matmul_clamp_f16_qai8dxp1x8_qsi4cxp4x8_1x4_neon_dotprod(&args);
163 59 }
164
165 #endif // Architectural features check.
166