KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 100.0% 34 13 47
Functions: 100.0% 8 0 8
Branches: -% 0 26 26

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64)
8 #error This file must be compiled for AArch64, FEAT_SVE2.
9 #else // Architectural features check.
10 #include "kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14 #include <string.h>
15
16 #include "kai/kai_common.h"
17
18 enum {
19 NR = 2,
20 KR = 4,
21 MAX_N_STEP = NR * ((KAI_SME_VEC_LENGTH_MAX_BYTES / sizeof(uint8_t)) / KR),
22 };
23
24 typedef struct {
25 const void* bias_ptr;
26 const void* scale_ptr;
27 int32_t input_zero_point;
28 float scale_multiplier;
29 size_t width;
30 size_t height;
31 size_t in_stride;
32 size_t out_stride;
33 const void* in;
34 void* out;
35 const void* pad_row;
36 } KernelArgs;
37
38 static const size_t kai_num_bytes_input = sizeof(uint8_t);
39 static const size_t kai_num_bytes_output = sizeof(uint8_t);
40 static const size_t kai_num_bytes_bias = sizeof(int32_t);
41 static const size_t kai_num_bytes_scale = sizeof(float);
42
43 void kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(const KernelArgs* args_ptr);
44
45 9696 size_t kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(void) {
46 9696 return NR * kai_get_sme_vector_length_u8() / KR;
47 }
48
49 834 size_t kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) {
50 KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0);
51
52 834 return n_idx * kai_num_bytes_input;
53 }
54
55 834 size_t kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) {
56 834 return n_idx * kai_num_bytes_bias;
57 }
58
59 834 size_t kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx) {
60 834 return n_idx * kai_num_bytes_scale;
61 }
62
63 2676 size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t k) {
64 5352 return kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() *
65 2676 (kai_num_bytes_bias + kai_roundup(k, KR) * kai_num_bytes_output + kai_num_bytes_scale);
66 }
67
68 1842 size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n_idx, size_t k) {
69 KAI_ASSUME(n_idx % kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() == 0);
70
71 1842 const size_t block_idx = n_idx / kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme();
72 3684 return block_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(k);
73 1842 }
74
75 834 size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(size_t n, size_t k) {
76 834 const size_t n_nr_blocks = kai_roundup(n, kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme());
77 1668 return kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(n_nr_blocks, k);
78 834 }
79
80 834 void kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
81 size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride_row, const void* rhs,
82 const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
83 const struct kai_rhs_pack_qsi8cx_params* params) {
84 KAI_ASSUME(num_groups == 1);
85 KAI_ASSUME(nr == kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme());
86 KAI_ASSUME(kr == KR);
87 KAI_ASSUME(sr == 1);
88 KAI_ASSUME(rhs != NULL);
89 KAI_ASSUME(bias != NULL);
90 KAI_ASSUME(scale != NULL);
91 KAI_ASSUME(rhs_packed != NULL);
92 KAI_ASSUME(extra_bytes == 0);
93 KAI_ASSUME(params != NULL);
94
95 KAI_ASSERT(kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme() <= MAX_N_STEP);
96 static const uint8_t pad_row[MAX_N_STEP] = {0};
97
98 834 KernelArgs args;
99 834 args.bias_ptr = bias;
100 834 args.scale_ptr = scale;
101 834 args.height = k;
102 834 args.width = n;
103 834 args.in = rhs;
104 834 args.out = rhs_packed;
105 834 args.in_stride = rhs_stride_row;
106 834 args.out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(args.height);
107 834 args.input_zero_point = params->lhs_zero_point;
108 834 args.scale_multiplier = params->scale_multiplier;
109 834 args.pad_row = pad_row;
110
111 834 kai_kernel_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(&args);
112 834 }
113
114 #endif // Architectural features check.
115