TinyChatEngine
Loading...
Searching...
No Matches
Int4llamaAttention.h
1#include <utility>
2
3#include "common.h"
4#include "operators.h"
5
7#ifdef QM_CUDA
    // CUDA build: attention activations and the KV cache live in fp16 (float16_t).
8    Matrix3D<float16_t> attn_output;
9    Matrix3D<float16_t> attn_probs_reshaped;
    // (key, value) pair handed back to the caller so the KV cache can be
    // threaded into the next decoding step.
10    std::pair<Matrix3D<float16_t>, Matrix3D<float16_t>> past_key_value;
11#else
    // CPU build: identical fields, but in fp32.
12    Matrix3D<float> attn_output;
13    Matrix3D<float> attn_probs_reshaped;
14    std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value;
15#endif
16};
17
    // Whether past_key/past_value hold a valid KV cache from a previous step.
19    bool has_past_key_value = false;
    // Index of this attention layer within the model (used to select weights/caches).
20    int layer_idx;
21#ifdef QM_CUDA
    // CUDA build: all input tensors are fp16.
22    Matrix3D<float16_t> hidden_states;
23    Matrix3D<float16_t> attention_mask;
24    Matrix3D<float16_t> past_key, past_value;
25
    // Constructor for the no-cache case (e.g. the first/prefill pass):
    // past_key/past_value are left default-constructed and has_past_key_value stays false.
26    Int4llamaAttention_input(Matrix3D<float16_t> hidden_states_, Matrix3D<float16_t> attention_mask_, int layer_idx_)
27#else
    // CPU build: same inputs in fp32.
28    Matrix3D<float> hidden_states;
29    Matrix3D<float> attention_mask;
30    Matrix3D<float> past_key, past_value;
31
32    Int4llamaAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, int layer_idx_)
33#endif
34        : hidden_states(hidden_states_), attention_mask(attention_mask_), layer_idx(layer_idx_) {
35    }
36
    // Constructor for subsequent (decode) steps: additionally carries the KV
    // cache produced by the previous forward() call.
    // NOTE(review): the first line of the fp16 overload's signature (original
    // line 38) is elided in this generated listing — confirm against the header.
37#ifdef QM_CUDA
39                             Matrix3D<float16_t> past_key_, Matrix3D<float16_t> past_value_, bool has_past_key_value_,
40                             int layer_idx_)
41#else
42    Int4llamaAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, Matrix3D<float> past_key_,
43                             Matrix3D<float> past_value_, bool has_past_key_value_, int layer_idx_)
44#endif
45        : hidden_states(hidden_states_),
46          attention_mask(attention_mask_),
47          past_key(past_key_),
48          past_value(past_value_),
49          has_past_key_value(has_past_key_value_),
50          layer_idx(layer_idx_) {
51    }
52};
53
55 public:
    // Loads this layer's quantized attention weights from param_path for layer layer_idx.
56    Int4llamaAttention(std::string param_path, const struct model_config config, int layer_idx);
    // One-time allocation of buffers shared by all layers, sized from config.
58    static void initialized_memory(const struct model_config config);
    // Runs one attention pass over input; returns attn_output plus the updated
    // (key, value) cache pair (see Int4llamaAttention_output).
59    struct Int4llamaAttention_output forward(std::string param_path, const struct Int4llamaAttention_input &input);
60
61#if !(DEC_SHARED_MEM)
    // Raw weight buffers for the q/k/v/o (and fused qkv) projections.
    // NOTE(review): stored as int* — presumably packed int4 quantized weights; confirm packing.
62    int *q_weight = nullptr, *k_weight = nullptr, *v_weight = nullptr, *o_weight = nullptr, *qkv_weight = nullptr;
63#endif
64
65#ifdef QM_CUDA
    // Releases the CUDA-side allocations owned by this layer.
66    void free_cuda_memory();
    // Likely precomputed cos/sin tables for the rotary embedding — TODO confirm.
67    half *cos_buf = nullptr, *sin_buf = nullptr;
68#else
69    float *cos_buf = nullptr, *sin_buf = nullptr;
70#endif
71
72 private:
    // Name used when profiling this module.
73    std::string profile_name = "Int4llamaAttention";
74    int embed_dim, num_heads, num_kv_heads, head_dim;
75#ifdef QM_CUDA
    // CUDA build: fused QKV projection plus fp16 batched matmuls for QK^T and P*V.
76    Linear_half_int4 o_proj, qkv_proj;
77    RotaryPosEmb_cuda rotary_pos_emb;
78    BMM_F16T qk_bmm, pv_bmm;
79    int max_sqlen;
80#else
    // CPU build: separate and fused int4 projections, fp32 batched matmuls.
81    Linear_FP_int4 k_proj, v_proj, q_proj, o_proj, qkv_proj;
82    RotaryPosEmb rotary_pos_emb;
83    BMM_F32T qk_bmm, pv_bmm;
    // Layout helpers: shape() splits [sqlen, embed_dim] into per-head form and
    // unshape() merges it back; shape_qkv() splits a fused QKV activation into
    // the three per-head tensors.
84    void unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int num_heads, int head_dim, int sqlen);
85    void shape(Matrix3D<float> unshape, Matrix3D<float> shaped, int num_heads, int head_dim, int sqlen);
86    void shape_qkv(Matrix3D<float> unshape, Matrix3D<float> shaped_q, Matrix3D<float> shaped_k,
87                   Matrix3D<float> shaped_v, int sqlen);
    // Expands num_kv_heads KV heads to num_heads (grouped-query attention style
    // repetition — TODO confirm against the .cc implementation).
88    void repeat(Matrix3D<float> input, Matrix3D<float> output, int num_heads, int num_kv_heads, int sqlen, int head_dim);
89
90#endif
91};
Definition BMM_F32T.h:3
Definition Int4llamaAttention.h:54
Definition linear.h:27
Definition common.h:34
Definition RotaryPosEmb.h:6
Definition Int4llamaAttention.h:18
Definition Int4llamaAttention.h:6
Definition model.h:5