KleidiAI Coverage Report


Directory: ./
Coverage: low: ≥ 0% medium: ≥ 75.0% high: ≥ 90.0%
Coverage Exec / Excl / Total
Lines: 95.1% 58 / 7 / 68
Functions: 87.5% 7 / 0 / 8
Branches: 90.0% 18 / 14 / 34

kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
7
8 #include <math.h>
9 #include <stddef.h>
10 #include <stdint.h>
11
12 #include "kai/kai_common.h"
13
14 static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
15
16 3724 inline static size_t kai_num_bytes_per_block(size_t bl) {
17 3724 return bl * sizeof(int8_t) + kai_num_bytes_multiplier;
18 }
19
20 3724 inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
21 KAI_ASSERT((k % bl) == 0);
22 3724 return k / bl;
23 }
24
25 3032 inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) {
26 3032 KAI_UNUSED(kr);
27 3032 return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl);
28 }
29
30 size_t kai_get_m_step_lhs_quant_pack_qsi8d32p_f32(size_t mr) {
31 return mr;
32 }
33
34 692 size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32(size_t m_idx, size_t lhs_stride) {
35 692 return m_idx * lhs_stride;
36 }
37
38 1648 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32(
39 size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
40 KAI_ASSUME((k % 2) == 0);
41 KAI_ASSUME((k % kr) == 0);
42 KAI_ASSUME((k % bl) == 0);
43
44 1648 KAI_UNUSED(sr);
45 1648 KAI_UNUSED(kr);
46
47 1648 return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl);
48 }
49
50 692 size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32(
51 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
52 KAI_ASSUME((k % 2) == 0);
53 KAI_ASSUME((k % kr) == 0);
54 KAI_ASSUME((k % bl) == 0);
55
56 692 KAI_UNUSED(sr);
57 692 KAI_UNUSED(kr);
58
59 692 const size_t num_rows = kai_roundup(m, mr) / mr;
60
61 1384 return num_rows * kai_lhs_packed_stride(k, mr, kr, bl);
62 692 }
63
64 692 void kai_run_lhs_quant_pack_qsi8d32p_f32(
65 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
66 size_t lhs_stride, void* lhs_packed) {
67
1/2
✓ Branch 0 taken 692 times.
✗ Branch 1 not taken.
692 if (m == 0) {
68 return;
69 }
70
71 692 const size_t num_rows = m;
72 692 const size_t k_block_len = kr / sr;
73 692 const size_t lhs_packed_stride = kai_lhs_packed_stride(k, mr, kr, bl);
74 692 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
75 692 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
76
77
2/2
✓ Branch 0 taken 12076 times.
✓ Branch 1 taken 692 times.
12768 for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
78 12076 const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
79
80
2/2
✓ Branch 0 taken 23360 times.
✓ Branch 1 taken 12076 times.
35436 for (size_t b = 0; b < num_blocks_per_row; ++b) {
81 23360 float abs_max = 0.0F;
82
83 23360 const size_t dst_x = ((row_idx + m_idx_start) % mr);
84 23360 int8_t* dst_ptr = (int8_t*)lhs_packed + (b * mr) * num_bytes_per_block;
85
86
2/2
✓ Branch 0 taken 747520 times.
✓ Branch 1 taken 23360 times.
770880 for (size_t idx_v = 0; idx_v < bl; ++idx_v) {
87 747520 const float val = src_ptr[idx_v];
88
2/2
✓ Branch 0 taken 651584 times.
✓ Branch 1 taken 95936 times.
747520 abs_max = KAI_MAX(abs_max, fabsf(val));
89 747520 }
90
91 // Calculate scale and reciprocal
92 23360 const float scale = abs_max / ((1 << 7) - 1);
93
1/2
✓ Branch 0 taken 23360 times.
✗ Branch 1 not taken.
23360 const float rep_scale = scale ? 1.0F / scale : 0.0F;
94
95 23360 *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale);
96 23360 dst_ptr += mr * kai_num_bytes_multiplier;
97
98 23360 dst_ptr += dst_x * k_block_len * sizeof(int8_t);
99
100 // Quantize and pack the block
101
2/2
✓ Branch 0 taken 102768 times.
✓ Branch 1 taken 23360 times.
126128 for (size_t k_idx = 0; k_idx < bl; k_idx += k_block_len) {
102
2/2
✓ Branch 0 taken 747520 times.
✓ Branch 1 taken 102768 times.
850288 for (size_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
103 // Clamp at the last valid k-index
104
2/2
✓ Branch 0 taken 746728 times.
✓ Branch 1 taken 792 times.
747520 const size_t k_idx_start = KAI_MIN(k_idx + k_block_idx, k - 1);
105
106 747520 const float src0_0 = *(src_ptr + k_idx_start);
107
108 // Scale the values
109 747520 int32_t v0_s32 = (int32_t)(roundf(src0_0 * rep_scale));
110
111 747520 *dst_ptr = (int8_t)v0_s32;
112 747520 dst_ptr += sizeof(int8_t);
113 747520 }
114 102768 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
115 102768 }
116
117 23360 src_ptr += bl;
118 23360 }
119 // Move to the next row if we have interleaved all Mr rows
120
2/2
✓ Branch 0 taken 9040 times.
✓ Branch 1 taken 3036 times.
12076 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
121 3036 lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
122 3036 }
123 12076 }
124 692 }
125