KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 94.9% 37 5 44
Functions: 90.9% 10 0 11
Branches: -% 0 10 10

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15
16 #include "kai/kai_common.h"
17
18 typedef struct {
19 int32_t c_offset;
20 int32_t maxval;
21 int32_t minval;
22 const void* A_ptr;
23 const void* B_ptr;
24 size_t N;
25 size_t K;
26 void* output_ptr;
27 uint64_t flags;
28 } KernelArgs;
29
30 void kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(KernelArgs* args_ptr);
31
32 static const size_t kai_m_step = 1;
33 static const size_t kai_nr = 2;
34 static const size_t kai_n_step = 16;
35 static const size_t kai_kr = 4;
36 static const size_t kai_sr = 1;
37
38 168 size_t kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) {
39 168 return kai_m_step;
40 }
41
42 840 size_t kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) {
43 840 return kai_n_step * kai_get_sme_vector_length_u8() / kai_kr;
44 }
45
46 168 size_t kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) {
47 168 return kai_nr * kai_get_sme_vector_length_u8() / kai_kr;
48 }
49
50 168 size_t kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) {
51 168 return kai_kr;
52 }
53
54 168 size_t kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(void) {
55 168 return kai_sr;
56 }
57
58 size_t kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t m_idx, size_t k) {
59 KAI_ASSUME(m_idx == 0);
60
61 return m_idx * k;
62 }
63
64 168 static size_t kai_get_rhs_packed_stride_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t k) {
65 336 return kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot() *
66 168 (kai_roundup(k, kai_kr) * sizeof(int8_t) + sizeof(int32_t) + sizeof(int32_t));
67 }
68
69 168 size_t kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t n_idx, size_t k) {
70 KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot() == 0);
71
72 168 const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot();
73 336 return block_idx * kai_get_rhs_packed_stride_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(k);
74 168 }
75
76 168 size_t kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(
77 size_t m_idx, size_t n_idx, size_t dst_stride) {
78 KAI_ASSUME(m_idx == 0);
79 KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot() == 0);
80
81 168 return (m_idx * dst_stride) + (n_idx * sizeof(int8_t));
82 }
83
84 168 size_t kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(size_t m, size_t n) {
85 168 return m * n * sizeof(int8_t);
86 }
87
88 169 void kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(
89 size_t m, size_t n, size_t k, const void* lhs, const void* rhs_packed, void* dst, size_t dst_stride_row,
90 size_t dst_stride_col, const struct kai_matmul_requantize32_params* params) {
91 169 KAI_UNUSED(dst_stride_row);
92 169 KAI_UNUSED(dst_stride_col);
93
94 KAI_ASSUME(m == 1);
95
96 169 uint64_t flags = 2;
97
98 169 KernelArgs args;
99 169 args.c_offset = params->output_zero_point;
100 169 args.maxval = params->max_value;
101 169 args.minval = params->min_value;
102 169 args.A_ptr = lhs;
103 169 args.B_ptr = rhs_packed;
104 169 args.N = n;
105 169 args.K = k;
106 169 args.output_ptr = dst;
107 169 args.flags = flags;
108
109 169 kai_kernel_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot(&args);
110 169 }
111
112 #endif // Architectural features check.
113