TinyChatEngine
Loading...
Searching...
No Matches
Int4GPTBigCodeDecoderLayer.h
1#include "Int4GPTBigCodeAttention.h"
2#include "common.h"
3#include "operators.h"
4
6 Matrix3D<float> hidden_states;
7 Matrix3D<float> attentions;
8 std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value;
9
11 std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value_) {
12 hidden_states = hidden_states_;
13 attentions = attentions_;
14 past_key_value = past_key_value_;
15 };
16};
18 Matrix3D<float> hidden_states;
19 Matrix3D<float> attention_mask;
20 Matrix3D<float> past_key, past_value;
21 bool has_past_key_value = false;
22
// Constructs a decoder-layer input that carries no cached key/value tensors.
// Stores the supplied hidden states and attention mask in the corresponding
// members and records that no past key/value is attached.
//
// hidden_states_   value assigned to the hidden_states member.
// attention_mask_  value assigned to the attention_mask member.
23 Int4GPTBigCodeDecoderLayer_input(Matrix3D<float> &hidden_states_, Matrix3D<float> &attention_mask_) {
24 hidden_states = hidden_states_;
25 attention_mask = attention_mask_;
26 has_past_key_value = false;  // past_key / past_value are not assigned by this overload
27 }
28
30 Matrix3D<float> past_key_, Matrix3D<float> past_value_) {
31 hidden_states = hidden_states_;
32 attention_mask = attention_mask_;
33 past_key = past_key_;
34 past_value = past_value_;
35 has_past_key_value = true;
36 }
37};
38
40 public:
41 Int4GPTBigCodeDecoderLayer(std::string param_path, const struct model_config config, int layer_idx);
43
44 int embed_dim, num_attention_heads, hidden_dim, layer_idx;
45 LayerNorm ln_1, ln_2; // from torch_int.nn
46 Linear_FP_int4 fc1, fc2;
48 std::string profile_name = "Int4GPTBigCodeDecoderLayer";
49};
Definition Int4GPTBigCodeAttention.h:31
Definition Int4GPTBigCodeDecoderLayer.h:39
Definition LayerNorm.h:8
Definition linear.h:27
Definition common.h:34
Definition Int4GPTBigCodeDecoderLayer.h:17
Definition Int4GPTBigCodeDecoderLayer.h:5
Definition model.h:5