KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 100.0% 39 / 8 / 47
Functions: 100.0% 8 / 0 / 8
Branches: -% 0 / 16 / 16

kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10
11 #include "kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h"
12
13 #include <stddef.h>
14 #include <stdint.h>
15 #include <string.h>
16
17 #include "kai/kai_common.h"
18
19 enum {
20 NR = 2,
21 KR = 4,
22 MAX_N_STEP = NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint8_t)) / KR),
23 };
24
25 typedef struct {
26 const void* bias_ptr;
27 const void* scale_ptr;
28 int32_t input_zero_point;
29 float scale_multiplier;
30 size_t width;
31 size_t height;
32 size_t k_chunk_count;
33 size_t in_stride;
34 size_t out_stride;
35 const void* in;
36 void* out;
37 const void* pad_row;
38 } KernelArgs;
39
40 static const size_t kai_num_bytes_input = sizeof(uint8_t);
41 static const size_t kai_num_bytes_output = sizeof(uint8_t);
42 static const size_t kai_num_bytes_bias = sizeof(int32_t);
43 static const size_t kai_num_bytes_scale = sizeof(float);
44
45 void kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(const KernelArgs* args_ptr);
46
47 33300 size_t kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) {
48 33300 return NR * kai_get_sme_vector_length_u8() / KR;
49 }
50
51 3330 size_t kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) {
52 KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0);
53
54 3330 return n_idx * kai_num_bytes_input;
55 }
56
57 3330 size_t kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) {
58 3330 return n_idx * kai_num_bytes_bias;
59 }
60
61 3330 size_t kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) {
62 3330 return n_idx * kai_num_bytes_scale;
63 }
64
65 9990 size_t kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
66 size_t k_chunk_count, size_t k_chunk_length) {
67 19980 return kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() *
68 9990 (kai_num_bytes_bias + k_chunk_count * kai_roundup(k_chunk_length, KR) * kai_num_bytes_output +
69 kai_num_bytes_scale);
70 }
71
72 6660 size_t kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
73 size_t n_idx, size_t k_chunk_count, size_t k_chunk_length) {
74 KAI_ASSUME(n_idx % kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0);
75
76 6660 const size_t block_idx = n_idx / kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme();
77 19980 return block_idx *
78 6660 kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length);
79 6660 }
80
81 3330 size_t kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
82 size_t n, size_t k_chunk_count, size_t k_chunk_length) {
83 3330 const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme());
84 6660 return kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
85 3330 n_nr_blocks, k_chunk_count, k_chunk_length);
86 3330 }
87
88 3330 void kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
89 size_t n, size_t k_chunk_count, size_t k_chunk_length, size_t rhs_stride_row, const void* rhs, const void* bias,
90 const void* scale, void* rhs_packed, const struct kai_rhs_pack_qsi8cx_params* params) {
91 KAI_ASSUME(rhs != NULL);
92 KAI_ASSUME(bias != NULL);
93 KAI_ASSUME(scale != NULL);
94 KAI_ASSUME(rhs_packed != NULL);
95 KAI_ASSUME(params != NULL);
96
97 KAI_ASSERT(kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() <= MAX_N_STEP);
98 static const uint8_t pad_row[MAX_N_STEP] = {0};
99
100 3330 KernelArgs args;
101 3330 args.bias_ptr = bias;
102 3330 args.scale_ptr = scale;
103 3330 args.height = k_chunk_length;
104 3330 args.width = n;
105 3330 args.in = rhs;
106 3330 args.out = rhs_packed;
107 3330 args.k_chunk_count = k_chunk_count;
108 3330 args.in_stride = rhs_stride_row;
109 3330 args.out_stride =
110 3330 kai_get_rhs_packed_stride_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k_chunk_count, k_chunk_length);
111 3330 args.input_zero_point = params->lhs_zero_point;
112 3330 args.scale_multiplier = params->scale_multiplier;
113 3330 args.pad_row = pad_row;
114
115 3330 kai_commit_za();
116
117 3330 kai_kernel_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(&args);
118 3330 }
119
120 #endif // Architectural features check.
121