KleidiAI Coverage Report


Directory: ./
File: kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
Date: 2025-10-20 13:18:31
Coverage Exec Excl Total
Lines: 93.5% 58 7 69
Functions: 87.5% 7 0 8
Branches: 90.0% 18 14 34

Line Branch Exec Source
1 //
2 // SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
7
8 #include <math.h>
9 #include <stddef.h>
10 #include <stdint.h>
11
12 #include "kai/kai_common.h"
13
14 static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
15
16 415 inline static size_t kai_num_bytes_per_block(size_t bl) {
17 415 return bl * sizeof(int8_t) + kai_num_bytes_multiplier;
18 }
19
20 415 inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
21 KAI_ASSERT((k % bl) == 0);
22 415 return k / bl;
23 }
24
25 340 inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t bl) {
26 340 KAI_UNUSED(kr);
27 340 return mr * kai_num_blocks_per_row(k, bl) * kai_num_bytes_per_block(bl);
28 }
29
30 size_t kai_get_m_step_lhs_quant_pack_qsi8d32p_f32(size_t mr) {
31 KAI_UNUSED(mr);
32 return 1;
33 }
34
35 75 size_t kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32(size_t m_idx, size_t lhs_stride) {
36 75 return m_idx * lhs_stride;
37 }
38
39 190 size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32(
40 size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
41 KAI_ASSUME((k % 2) == 0);
42 KAI_ASSUME((k % kr) == 0);
43 KAI_ASSUME((k % bl) == 0);
44
45 190 KAI_UNUSED(sr);
46 190 KAI_UNUSED(kr);
47
48 190 return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, bl);
49 }
50
51 75 size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32(
52 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
53 KAI_ASSUME((k % 2) == 0);
54 KAI_ASSUME((k % kr) == 0);
55 KAI_ASSUME((k % bl) == 0);
56
57 75 KAI_UNUSED(sr);
58 75 KAI_UNUSED(kr);
59
60 75 const size_t num_rows = kai_roundup(m, mr) / mr;
61
62 150 return num_rows * kai_lhs_packed_stride(k, mr, kr, bl);
63 75 }
64
65 75 void kai_run_lhs_quant_pack_qsi8d32p_f32(
66 size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
67 size_t lhs_stride, void* lhs_packed) {
68
1/2
✓ Branch 0 taken 75 times.
✗ Branch 1 not taken.
75 if (m == 0) {
69 return;
70 }
71
72 75 const size_t num_rows = m;
73 75 const size_t k_block_len = kr / sr;
74 75 const size_t lhs_packed_stride = kai_lhs_packed_stride(k, mr, kr, bl);
75 75 const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
76 75 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
77
78
2/2
✓ Branch 0 taken 1791 times.
✓ Branch 1 taken 75 times.
1866 for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
79 1791 const float* src_ptr = (const float*)((const uint8_t*)lhs + (row_idx + m_idx_start) * lhs_stride);
80
81
2/2
✓ Branch 0 taken 3387 times.
✓ Branch 1 taken 1791 times.
5178 for (size_t b = 0; b < num_blocks_per_row; ++b) {
82 3387 float abs_max = 0.0F;
83
84 3387 const size_t dst_x = ((row_idx + m_idx_start) % mr);
85 3387 int8_t* dst_ptr = (int8_t*)lhs_packed + (b * mr) * num_bytes_per_block;
86
87
2/2
✓ Branch 0 taken 108384 times.
✓ Branch 1 taken 3387 times.
111771 for (size_t idx_v = 0; idx_v < bl; ++idx_v) {
88 108384 const float val = src_ptr[idx_v];
89
2/2
✓ Branch 0 taken 94431 times.
✓ Branch 1 taken 13953 times.
108384 abs_max = KAI_MAX(abs_max, fabsf(val));
90 108384 }
91
92 // Calculate scale and reciprocal
93 3387 const float scale = abs_max / ((1 << 7) - 1);
94
1/2
✓ Branch 0 taken 3387 times.
✗ Branch 1 not taken.
3387 const float rep_scale = scale ? 1.0F / scale : 0.0F;
95
96 3387 *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale);
97 3387 dst_ptr += mr * kai_num_bytes_multiplier;
98
99 3387 dst_ptr += dst_x * k_block_len * sizeof(int8_t);
100
101 // Quantize and pack the block
102
2/2
✓ Branch 0 taken 18068 times.
✓ Branch 1 taken 3387 times.
21455 for (size_t k_idx = 0; k_idx < bl; k_idx += k_block_len) {
103
2/2
✓ Branch 0 taken 108384 times.
✓ Branch 1 taken 18068 times.
126452 for (size_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
104 // Clamp at the last valid k-index
105
2/2
✓ Branch 0 taken 108189 times.
✓ Branch 1 taken 195 times.
108384 const size_t k_idx_start = KAI_MIN(k_idx + k_block_idx, k - 1);
106
107 108384 const float src0_0 = *(src_ptr + k_idx_start);
108
109 // Scale the values
110 108384 int32_t v0_s32 = (int32_t)(roundf(src0_0 * rep_scale));
111
112 108384 *dst_ptr = (int8_t)v0_s32;
113 108384 dst_ptr += sizeof(int8_t);
114 108384 }
115 18068 dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
116 18068 }
117
118 3387 src_ptr += bl;
119 3387 }
120 // Move to the next row if we have interleaved all Mr rows
121
2/2
✓ Branch 0 taken 1353 times.
✓ Branch 1 taken 438 times.
1791 if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
122 438 lhs_packed = (void*)((int8_t*)lhs_packed + lhs_packed_stride);
123 438 }
124 1791 }
125 75 }
126