KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 40 / 4 / 44
Functions: 100.0% 9 / 0 / 9
Branches: -% 0 / 8 / 8

kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10 #include "kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include "kai/kai_common.h"
16
17 typedef struct {
18 const void* A;
19 const void* B;
20 void* C;
21 uint64_t ldcb;
22 uint64_t M;
23 uint64_t N;
24 uint64_t K;
25 int32_t min;
26 int32_t max;
27 int32_t result_zero_point;
28 void* accumulator_buffer;
29 uint64_t flags;
30 } KernelArgs;
31
32 static const size_t kai_mr = 2;
33 static const size_t kai_nr = 2;
34 static const size_t kai_kr = 4;
35
36 void kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(KernelArgs* args);
37
38 // Returns a kernel-specific constant derived from the vector length: kai_get_sme_vector_length_u8() / kai_kr.
39 9990 static size_t kai_get_kernel_vec_length_constant(void) {
40 9990 const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u8() / kai_kr;
41 19980 return kernel_vec_length_constant;
42 9990 }
43
44 3330 size_t kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) {
45 3330 return kai_mr * kai_get_kernel_vec_length_constant();
46 }
47
48 6660 size_t kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(void) {
49 6660 return kai_nr * kai_get_kernel_vec_length_constant();
50 }
51
52 1665 size_t kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(
53 size_t m_idx, size_t k_chunk_count, size_t k_chunk_length) {
54 KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0);
55 1665 return m_idx * k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(int8_t);
56 }
57
58 1665 static size_t kai_get_rhs_packed_stride_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(
59 size_t k_chunk_count, size_t k_chunk_length) {
60 3330 return kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() *
61 1665 (sizeof(int32_t) + k_chunk_count * kai_roundup(k_chunk_length, kai_kr) * sizeof(int8_t) + sizeof(float));
62 }
63
64 1665 size_t kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(
65 size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) {
66 KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0);
67 1665 const size_t block_idx = n_idx / kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa();
68 4995 return block_idx *
69 1665 kai_get_rhs_packed_stride_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(
70 1665 k_chunk_count, k_chunk_length);
71 1665 }
72
73 1665 size_t kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(
74 size_t m_idx, size_t n_idx, size_t dst_stride_row) {
75 KAI_ASSUME(m_idx % kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0);
76 KAI_ASSUME(n_idx % kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa() == 0);
77
78 1665 return m_idx * dst_stride_row + n_idx * sizeof(int8_t);
79 }
80
81 1665 size_t kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(size_t m, size_t n) {
82 1665 return m * n * sizeof(int8_t);
83 }
84
85 1666 void kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(
86 size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed,
87 void* dst, size_t dst_stride_row, const struct kai_matmul_requantize32_params* params) {
88 1666 KernelArgs args;
89
90 1666 args.A = lhs_packed;
91 1666 args.B = rhs_packed;
92 1666 args.C = dst;
93 1666 args.ldcb = dst_stride_row;
94 1666 args.M = m;
95 1666 args.N = n;
96 1666 args.K = k_chunk_count * kai_roundup(k_chunk_length, kai_kr);
97 1666 args.min = params->min_value;
98 1666 args.max = params->max_value;
99 1666 args.result_zero_point = params->output_zero_point;
100 1666 args.accumulator_buffer = NULL;
101 1666 args.flags = 0;
102
103 1666 kai_commit_za();
104
105 1666 kai_kernel_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa(&args);
106 1666 }
107
108 #endif // Architectural features check.
109