TinyChatEngine
Int4llamaDecoderLayer.h
#include "Int4llamaAttention.h"
#include "common.h"
#include "operators.h"

// Output of one decoder layer: the transformed hidden states, the attention
// probabilities, and the updated key/value cache for this layer.
struct Int4llamaDecoderLayer_output {
#ifdef QM_CUDA
    Matrix3D<float16_t> hidden_states;
    Matrix3D<float16_t> attentions;
    std::pair<Matrix3D<float16_t>, Matrix3D<float16_t>> past_key_value;

    Int4llamaDecoderLayer_output(Matrix3D<float16_t> hidden_states_, Matrix3D<float16_t> attentions_,
                                 std::pair<Matrix3D<float16_t>, Matrix3D<float16_t>> past_key_value_) {
#else
    Matrix3D<float> hidden_states;
    Matrix3D<float> attentions;
    std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value;

    Int4llamaDecoderLayer_output(Matrix3D<float> hidden_states_, Matrix3D<float> attentions_,
                                 std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value_) {
#endif
        hidden_states = hidden_states_;
        attentions = attentions_;
        past_key_value = past_key_value_;
    }
};
// Input to one decoder layer: hidden states, attention mask, and optionally
// the cached keys/values from earlier decoding steps.
struct Int4llamaDecoderLayer_input {
    bool has_past_key_value = false;
#ifdef QM_CUDA
    Matrix3D<float16_t> hidden_states;
    Matrix3D<float16_t> attention_mask;
    Matrix3D<float16_t> past_key, past_value;

    Int4llamaDecoderLayer_input(Matrix3D<float16_t> &hidden_states_, Matrix3D<float16_t> &attention_mask_) {
#else
    Matrix3D<float> hidden_states;
    Matrix3D<float> attention_mask;
    Matrix3D<float> past_key, past_value;

    Int4llamaDecoderLayer_input(Matrix3D<float> &hidden_states_, Matrix3D<float> &attention_mask_) {
#endif
        hidden_states = hidden_states_;
        attention_mask = attention_mask_;
        has_past_key_value = false;
    }

#ifdef QM_CUDA
    Int4llamaDecoderLayer_input(Matrix3D<float16_t> &hidden_states_, Matrix3D<float16_t> &attention_mask_,
                                Matrix3D<float16_t> past_key_, Matrix3D<float16_t> past_value_) {
#else
    Int4llamaDecoderLayer_input(Matrix3D<float> &hidden_states_, Matrix3D<float> &attention_mask_,
                                Matrix3D<float> past_key_, Matrix3D<float> past_value_) {
#endif
        hidden_states = hidden_states_;
        attention_mask = attention_mask_;
        past_key = past_key_;
        past_value = past_value_;
        has_past_key_value = true;
    }
};
// One 4-bit-quantized LLaMA decoder layer: RMSNorm -> self-attention ->
// RMSNorm -> gated MLP (gate/up/down projections).
class Int4llamaDecoderLayer {
   public:
    Int4llamaDecoderLayer(std::string param_path, const struct model_config config, int layer_idx);

    struct Int4llamaDecoderLayer_output forward(std::string param_path, const struct Int4llamaDecoderLayer_input &input, int layer_idx);

    std::string profile_name = "Int4llamaDecoderLayer";
    int embed_dim, num_attention_heads, hidden_dim, layer_idx;
    float rms_norm_eps;

    Int4llamaAttention attn;
#ifdef QM_CUDA
    void free_cuda_memory();
    LlamaRMSNorm_cuda input_layernorm, post_attention_layernorm;
    Linear_half_int4 gate_proj, down_proj, up_proj;

#if !(DEC_SHARED_MEM)
    int *gate_proj_weight = nullptr, *down_proj_weight = nullptr, *up_proj_weight = nullptr;
#endif

#else
    LlamaRMSNorm input_layernorm, post_attention_layernorm;  // from torch_int.nn
    Linear_FP_int4 gate_proj, down_proj, up_proj;
#endif
    float *input_layernorm_weight_ptr = nullptr;
    float *post_attention_layernorm_ptr = nullptr;
};
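For orientation, the sketch below shows how these declarations fit together on the CPU path (QM_CUDA not defined): build an input without a KV cache for the first token, run forward, then feed the returned past_key_value back in for the next decoding step. The Matrix3D constructor signature, the tensor dimensions, and the parameter path are illustrative assumptions, not the project's documented calling convention.

// Minimal usage sketch, CPU path (QM_CUDA not defined). The Matrix3D
// constructor (data pointer + three dims), the shapes, and the path string
// below are assumptions for illustration only.
#include <vector>

#include "Int4llamaDecoderLayer.h"

void decoder_layer_sketch(const struct model_config &config) {
    const int sqlen = 1;         // decoding a single token
    const int embed_dim = 4096;  // hypothetical LLaMA-7B hidden width

    std::vector<float> hidden_buf(sqlen * embed_dim, 0.0f);
    std::vector<float> mask_buf(sqlen * sqlen, 0.0f);

    // Assumed Matrix3D layout: (dim_x, dim_y, dim_z) over a caller-owned buffer.
    Matrix3D<float> hidden_states(hidden_buf.data(), 1, sqlen, embed_dim);
    Matrix3D<float> attention_mask(mask_buf.data(), 1, sqlen, sqlen);

    // First decoding step: no KV cache yet, so has_past_key_value stays false.
    Int4llamaDecoderLayer_input input(hidden_states, attention_mask);

    // "models/LLaMA_7B/decoder/layer0" is a placeholder path, not the
    // project's real on-disk layout.
    Int4llamaDecoderLayer layer("models/LLaMA_7B/decoder/layer0", config, /*layer_idx=*/0);
    struct Int4llamaDecoderLayer_output out =
        layer.forward("models/LLaMA_7B/decoder/layer0", input, /*layer_idx=*/0);

    // Later steps feed the returned KV cache back in through the second
    // constructor, which sets has_past_key_value = true.
    Int4llamaDecoderLayer_input next(out.hidden_states, attention_mask,
                                     out.past_key_value.first, out.past_key_value.second);
    (void)next;
}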