KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f16_qai8dxp_qsi8cxp/kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.1% 51 5 57
Functions: 100.0% 14 0 14
Branches: 50.0% 1 10 12

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if !defined(__ARM_FEATURE_MATMUL_INT8) || !defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
7 #error "I8mm extension and fp16 vector arithmetic required to compile this micro-kernel"
8 #else // Architectural features check.
9
10 #include "kai_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16
17 typedef struct {
18 uint16_t* dst;
19 const void* lhs_packed;
20 const void* rhs_packed;
21 const float* clamp_vals;
22 size_t dst_stride_row;
23 size_t m;
24 size_t n;
25 size_t num_blocks;
26 } KernelArgs;
27
28 void kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(KernelArgs* args_ptr);
29
30 // Compute args
31 static const size_t kai_m_step = 16;
32 static const size_t kai_n_step = 4;
33 // Packing args
34 static const size_t kai_mr = 4;
35 static const size_t kai_nr = 4;
36 static const size_t kai_kr = 8;
37 static const size_t kai_sr = 1;
38 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
39 static const size_t kai_num_bytes_qvalue_lhs = 1;
40 static const size_t kai_num_bytes_multiplier_lhs = 4;
41 static const size_t kai_num_bytes_zp_lhs = 4;
42 // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
43 // asymmetric))
44 static const size_t kai_num_bytes_qvalue_rhs = 1;
45 static const size_t kai_num_bytes_multiplier_rhs = 4;
46 static const size_t kai_num_bytes_rsum_rhs = 4;
47 // DST format args
48 static const size_t kai_num_bytes_dst_value = 2;
49 // Extra args
50 static const size_t kai_num_bytes_bias = 4;
51 static const size_t kai_k_multiple_of = 32;
52
53 667 inline static size_t kai_get_k_roundedup(size_t k) {
54 667 return kai_roundup(k, kai_k_multiple_of);
55 }
56
57 222 inline static size_t kai_get_lhs_packed_stride(size_t k) {
58 222 const size_t k_internal = kai_get_k_roundedup(k);
59 222 size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs);
60 // Since the LHS matrix is asymmetric with per-row quantization, we must include the
61 // the number of bytes to hold the zero point value
62 222 lhs_packed_stride += kai_mr * kai_num_bytes_zp_lhs;
63
64 444 return lhs_packed_stride;
65 222 }
66
67 222 inline static size_t kai_get_rhs_packed_stride(size_t k) {
68 222 const size_t k_internal = kai_get_k_roundedup(k);
69 222 size_t rhs_packed_stride = kai_nr * (k_internal * kai_num_bytes_qvalue_rhs);
70
71 222 rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs;
72 // Since the LHS matrix is quantized asymmetric with per-row quantization, we also include
73 // the number of bytes for the reduction sum
74 222 rhs_packed_stride += kai_nr * kai_num_bytes_rsum_rhs;
75 // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
76 222 rhs_packed_stride += kai_nr * kai_num_bytes_bias;
77
78 444 return rhs_packed_stride;
79 222 }
80
81 224 size_t kai_get_m_step_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) {
82 224 return kai_m_step;
83 }
84
85 224 size_t kai_get_n_step_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) {
86 224 return kai_n_step;
87 }
88
89 224 size_t kai_get_mr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) {
90 224 return kai_mr;
91 }
92
93 224 size_t kai_get_nr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) {
94 224 return kai_nr;
95 }
96
97 224 size_t kai_get_kr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) {
98 224 return kai_kr;
99 }
100
101 224 size_t kai_get_sr_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(void) {
102 224 return kai_sr;
103 }
104
105 222 size_t kai_get_lhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(size_t m_idx, size_t k) {
106 KAI_ASSUME((m_idx % kai_m_step) == 0);
107
108 222 return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k);
109 }
110
111 222 size_t kai_get_rhs_packed_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(size_t n_idx, size_t k) {
112 KAI_ASSUME((n_idx % kai_n_step) == 0);
113
114 222 return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k);
115 }
116
117 222 size_t kai_get_dst_offset_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(
118 size_t m_idx, size_t n_idx, size_t dst_stride) {
119 KAI_ASSUME((m_idx % kai_m_step) == 0);
120 KAI_ASSUME((n_idx % kai_n_step) == 0);
121
122 222 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
123 }
124
125 222 size_t kai_get_dst_size_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(size_t m, size_t n) {
126 222 return m * n * kai_num_bytes_dst_value;
127 }
128
129 223 void kai_run_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(
130 size_t m, //
131 size_t n, //
132 size_t k, //
133 const void* restrict lhs_packed, //
134 const void* restrict rhs_packed, //
135 void* restrict dst, // NOLINT(readability-non-const-parameter)
136 size_t dst_stride_row, //
137 size_t dst_stride_col, //
138 float scalar_min, //
139 float scalar_max) {
140 KAI_ASSUME(dst_stride_col == sizeof(uint16_t));
141
142
1/2
✓ Branch 0 taken 223 times.
✗ Branch 1 not taken.
223 if (m == 0) {
143 return;
144 }
145 223 const size_t num_blocks = kai_get_k_roundedup(k) / kai_k_multiple_of;
146 223 const float clamp_vals[2] = {scalar_min, scalar_max};
147
148 223 KernelArgs args;
149
150 223 args.dst = dst;
151 223 args.lhs_packed = lhs_packed;
152 223 args.rhs_packed = rhs_packed;
153 223 args.clamp_vals = clamp_vals;
154 223 args.dst_stride_row = dst_stride_row;
155 223 args.m = m;
156 223 args.n = n;
157 223 args.num_blocks = num_blocks;
158
159 223 kai_kernel_matmul_clamp_f16_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm(&args);
160 223 }
161
162 #endif // Architectural features check.
163