KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 98.2% 55 11 67
Functions: 100.0% 16 0 16
Branches: 50.0% 1 22 24

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(_M_ARM64)
7 #error "I8mm extension required to compile this micro-kernel"
8 #else // Architectural features check.
9
10 #include "kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.h"
11
12 #include <stddef.h>
13
14 #include "kai/kai_common.h"
15
16 typedef struct {
17 float* dst;
18 const void* lhs_packed;
19 const void* rhs_packed;
20 const float* clamp_vals;
21 size_t dst_stride_row;
22 size_t m;
23 size_t n;
24 size_t num_blocks;
25 size_t num_subblocks;
26 } KernelArgs;
27
28 void kai_kernel_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(KernelArgs* args_ptr);
29
30 // Compute args
31 static const size_t kai_m_step = 8;
32 static const size_t kai_n_step = 4;
33 // Packing args
34 static const size_t kai_mr = 4;
35 static const size_t kai_nr = 4;
36 static const size_t kai_kr = 16;
37 static const size_t kai_sr = 2;
38 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
39 static const size_t kai_num_bytes_qvalue_lhs = 1;
40 static const size_t kai_num_bytes_multiplier_lhs = 4;
41 static const size_t kai_num_bytes_sum_lhs = 4;
42 // RHS format args (num. bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is
43 // asymmetric))
44 static const size_t kai_num_bytes_recip_qvalue_rhs = 2;
45 static const size_t kai_num_bytes_multiplier_rhs = 4;
46 static const size_t kai_num_bytes_offset_rhs = 4;
47
48 // DST format args
49 static const size_t kai_num_bytes_dst_value = 4;
50 // Extra args
51 static const size_t kai_num_bytes_bias = 4;
52 static const size_t kai_bl = 32;
53
54 516 inline static size_t kai_get_num_bytes_per_block_lhs(size_t bl) {
55 516 return (bl * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs;
56 }
57
58 516 inline static size_t kai_get_num_bytes_per_block_rhs(size_t bl) {
59 KAI_ASSUME((bl % kai_bl) == 0);
60 1032 size_t num_bytes_per_block_rhs =
61 516 (bl / kai_num_bytes_recip_qvalue_rhs) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs;
62 1032 return num_bytes_per_block_rhs;
63 516 }
64
65 1549 inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) {
66 KAI_ASSUME((bl % kai_bl) == 0);
67
68 1549 return kai_roundup(k, bl) / bl;
69 }
70
71 516 inline static size_t kai_get_lhs_packed_stride(size_t k, size_t bl) {
72 516 return kai_mr * kai_get_num_blocks_per_row(k, bl) * kai_get_num_bytes_per_block_lhs(bl);
73 }
74
75 516 inline static size_t kai_get_rhs_packed_stride(size_t k, size_t bl) {
76 KAI_ASSUME((bl % kai_bl) == 0);
77
78 516 const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl);
79 516 const size_t num_bytes_per_block = kai_get_num_bytes_per_block_rhs(bl);
80
81 516 size_t rhs_packed_stride = kai_nr * (num_bytes_per_block * num_blocks_per_row);
82 // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias
83 516 rhs_packed_stride += kai_nr * kai_num_bytes_bias;
84
85 1032 return rhs_packed_stride;
86 516 }
87
88 1036 size_t kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) {
89 1036 return kai_m_step;
90 }
91
92 1036 size_t kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) {
93 1036 return kai_n_step;
94 }
95
96 1036 size_t kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) {
97 1036 return kai_mr;
98 }
99
100 1036 size_t kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) {
101 1036 return kai_nr;
102 }
103
104 1036 size_t kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) {
105 1036 return kai_kr;
106 }
107
108 1036 size_t kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(void) {
109 1036 return kai_sr;
110 }
111
112 516 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(
113 size_t m_idx, size_t k, size_t bl) {
114 KAI_ASSUME((m_idx % kai_m_step) == 0);
115
116 516 return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k, bl);
117 }
118
119 516 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(
120 size_t n_idx, size_t k, size_t bl) {
121 KAI_ASSUME((k % bl) == 0);
122 KAI_ASSUME((n_idx % kai_n_step) == 0);
123
124 516 return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k, bl);
125 }
126
127 516 size_t kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(
128 size_t m_idx, size_t n_idx, size_t dst_stride) {
129 KAI_ASSUME((m_idx % kai_m_step) == 0);
130 KAI_ASSUME((n_idx % kai_n_step) == 0);
131
132 516 return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride;
133 }
134
135 516 size_t kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(size_t m, size_t n) {
136 516 return m * n * kai_num_bytes_dst_value;
137 }
138
139 517 void kai_run_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(
140 size_t m, //
141 size_t n, //
142 size_t k, //
143 size_t bl, //
144 const void* restrict lhs_packed, //
145 const void* restrict rhs_packed, //
146 float* restrict dst, // NOLINT(readability-non-const-parameter)
147 size_t dst_stride_row, //
148 size_t dst_stride_col, //
149 float scalar_min, //
150 float scalar_max) {
151 KAI_ASSUME(dst_stride_col == sizeof(float));
152 KAI_ASSUME((k % bl) == 0);
153 KAI_ASSUME((bl % kai_bl) == 0);
154
155
1/2
✓ Branch 0 taken 517 times.
✗ Branch 1 not taken.
517 if (m == 0) {
156 return;
157 }
158 517 const size_t num_subblocks = bl / kai_bl;
159 517 const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);
160 517 const float clamp_vals[2] = {scalar_min, scalar_max};
161
162 517 KernelArgs args;
163
164 517 args.dst = dst;
165 517 args.lhs_packed = lhs_packed;
166 517 args.rhs_packed = rhs_packed;
167 517 args.clamp_vals = clamp_vals;
168 517 args.dst_stride_row = dst_stride_row;
169 517 args.m = m;
170 517 args.n = n;
171 517 args.num_blocks = num_blocks;
172 517 args.num_subblocks = num_subblocks;
173
174 517 kai_kernel_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm(&args);
175 517 }
176
177 #endif // Architectural features check.
178