TinyChatEngine
matmul.h
#ifndef MATMUL_H
#define MATMUL_H
#include <stdint.h>
#ifdef _WIN32
#define NOMINMAX
#include <winsock2.h>
#else
#include <sys/time.h>
#endif

#include "half.hpp"  // Third-party header
typedef half_float::half naive_float16_t;

#ifdef QM_CUDA
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
typedef half float16_t;
#elif defined(__ARM_NEON)
typedef __fp16 float16_t;
#elif defined(__x86_64__)
// x86_64 does not natively support fp16, so we use the `half_float` library to
// provide fp16 through a software-based implementation.
typedef half_float::half float16_t;
#else
// Unsupported platform (we only support CUDA, Arm, and x86_64). Using uint16_t as float16_t.
typedef uint16_t float16_t;
#endif

#ifdef QM_ARM
#ifdef __ARM_FEATURE_DOTPROD
#include <arm_neon.h>
// Native implementation using vdotq_s32 when available
static inline int32x4_t my_vdotq_s32(int32x4_t accum, int8x16_t a, int8x16_t b) { return vdotq_s32(accum, a, b); }

#else
#include <arm_neon.h>
// Fallback implementation when vdotq_s32 is not available
static inline int32x4_t my_vdotq_s32(int32x4_t accum, int8x16_t a, int8x16_t b) {
    // Multiply and widen results to 16-bit integers
    int16x8_t result_low = vmull_s8(vget_low_s8(a), vget_low_s8(b));
    int16x8_t result_high = vmull_s8(vget_high_s8(a), vget_high_s8(b));

    // Sum pairs of 16-bit values and accumulate into 32-bit integers
    return vaddq_s32(accum, vaddq_s32(vaddl_s16(vget_low_s16(result_low), vget_high_s16(result_low)),
                                      vaddl_s16(vget_low_s16(result_high), vget_high_s16(result_high))));
}
#endif
#endif
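
#ifdef QM_ARM
// Illustrative sketch (not used by the kernels; `dot_i8x16_example` is a
// hypothetical helper): accumulate a 16-element int8 dot product through
// my_vdotq_s32 and reduce the four 32-bit lanes. vaddvq_s32 assumes AArch64.
static inline int32_t dot_i8x16_example(const int8_t *a, const int8_t *b) {
    int32x4_t acc = vdupq_n_s32(0);                     // zeroed accumulator
    acc = my_vdotq_s32(acc, vld1q_s8(a), vld1q_s8(b));  // native or emulated sdot
    return vaddvq_s32(acc);                             // horizontal sum of the lanes
}
#endif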

// Data structures
struct quantization_params {
    float scale;
    bool per_channel = false;
    int32_t zero_point;
    int8_t q_min = -128, q_max = 127;
};
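// Reference sketch (the kernels implement their own variants): with these
// parameters, affine quantization maps a float x to
//   q = clamp(round(x / scale) + zero_point, q_min, q_max)
// and dequantization recovers x ~= scale * (q - zero_point).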

struct matrix {
    int row;
    int column;
    float *data_ptr;
    float16_t *half_data_ptr;
    naive_float16_t *fp16_data_ptr;
    int32_t *int32_data_ptr;
    int8_t *int8_data_ptr;
    uint8_t *uint8_data_ptr;
    uint8_t *int4_data_ptr;
    struct quantization_params qparams;
    int length() { return row * column; }
};

struct optimization_params {
    int blk_size;
    int num_thread = 8;
};

struct matmul_params {
    struct matrix A, B, C, bias;
    struct optimization_params opt_params;
    float alpha, beta;
    float16_t half_alpha;
    // for int4
    float *scales, *offset, *zero_point;
    float16_t *half_scales;
    naive_float16_t *fp16_scales;
    int *int32_zero_point;
    int block_size;
    // for int8 activation
    float *A_scales;
    int8_t A_zero_point;
};
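// Note (inferred from the field names, not stated in this header): the int4
// path appears to use group-wise quantization, where each run of `block_size`
// consecutive weights shares one entry of `scales` (and of `offset` /
// `zero_point` when offsets are used), i.e. for weight index i and group
// g = i / block_size: w_i ~= scales[g] * (q_i - zero_point[g]).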

struct thread_args {
    const struct matrix *A;
    const struct matrix *B;
    const struct matrix *C;
    const struct matmul_params *params;
    int start_i, end_i, blk_size;
};
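// Sketch of how thread_args is typically consumed (hypothetical worker; the
// actual workers live in the kernel implementations): each thread computes
// the row range [start_i, end_i) of C, e.g. with pthreads:
//
//   static void *matmul_worker(void *arg) {
//       struct thread_args *ta = (struct thread_args *)arg;
//       for (int i = ta->start_i; i < ta->end_i; i++) {
//           /* compute row i of ta->C from ta->A and ta->B */
//       }
//       return NULL;
//   }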

#ifndef MAX
#define MAX(A, B) ((A) > (B) ? (A) : (B))
#endif
#ifndef MIN
#define MIN(A, B) ((A) < (B) ? (A) : (B))
#endif

namespace matmul {
class MatmulOperator {
   public:
    void mat_mul_transposed(const struct matmul_params *params);
    void mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params);
    void mat_mul_accelerator_transposed_fastover_column_bias(const struct matmul_params *params);
    void mat_mul_accelerator_untransposed_fastover_column(const struct matmul_params *params);
    // int8
    void naive_mat_mul_int8(const struct matmul_params *params);
    void mat_mul_accelerator_int8_fast_32unroll_over_column(const struct matmul_params *params);
    void mat_mul_accelerator_int8_fast_2x2_32unroll(const struct matmul_params *params);
    void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias(const struct matmul_params *params);
    void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_batch(const struct matmul_params *params);
    void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32(const struct matmul_params *params);
    void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32_batch(const struct matmul_params *params);
    void mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32(const struct matmul_params *params);
    void mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32_over_column(const struct matmul_params *params);
    // void mat_mul_accelerator_int8_fast_2x2_omp(const struct matmul_params *params);
    // int4
    void mat_mul_accelerator_int4_fast(const struct matmul_params *params);
    void mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params *params);
    void mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params *params);
    void gemv_accelerator_int8_int4_fast_no_offset(struct matmul_params *params);
    void gemm_accelerator_int8_int4_fast_no_offset(struct matmul_params *params);
    void gemm_accelerator_int8_int4_fast_no_offset_v2(struct matmul_params *params);
    void cblas_gemm_accelerator_no_offset(struct matmul_params *params);
    void naive_mat_mul_int4(const struct matmul_params *params);
    void naive_mat_mul_int4_with_offset(const struct matmul_params *params);
    // cuda
    void naive_mat_mul_fp16_int4(const struct matmul_params *params);
    // void naive_mat_mul_fp16_int4_gemv(const struct matmul_params *params);
    void mat_mul_cuda(const struct matmul_params *params);
    //// GEMM
    void gemm_forward_cuda(const struct matmul_params *params, int split_k_iters);
    void gemm_forward_cuda_8splits(const struct matmul_params *params, float16_t *split_8_buffer);
    void gemm_forward_cuda_half(const struct matmul_params *params, int split_k_iters);
    void gemm_forward_cuda_half_test(const struct matmul_params *params, int split_k_iters);
    //// GEMV
    void gemv_forward_cuda(const struct matmul_params *params);

   private:
    float interval_to_us(struct timeval *start, struct timeval *end);
    void CHECK_MATRICES(const struct matrix *A, const struct matrix *B, const struct matrix *C);
    void CHECK_MATRICES_int4weight(const struct matrix *A, const struct matrix *B, const struct matrix *C);
};
}  // namespace matmul
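
// Example usage (a sketch under assumptions: FP32 end-to-end and, as the
// function name suggests, B stored transposed as an n-by-k matrix; the
// buffer names are hypothetical). Kept under `#if 0` as a non-compiled
// reference.
#if 0
static void example_fp32_matmul(float *A_buf, float *B_buf, float *C_buf, int m, int n, int k) {
    struct matmul_params params = {};
    params.A.row = m; params.A.column = k; params.A.data_ptr = A_buf;
    params.B.row = n; params.B.column = k; params.B.data_ptr = B_buf;  // B^T layout
    params.C.row = m; params.C.column = n; params.C.data_ptr = C_buf;
    matmul::MatmulOperator op;
    op.mat_mul_transposed(&params);  // C = A * B^T
}
#endif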

#endif