TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
Int4GPTBigCodeDecoder.h
#pragma once

#include <cstdlib>
#include <string>
#include <utility>
#include <vector>

#include "Int4GPTBigCodeDecoderLayer.h"
#include "common.h"
#include "operators.h"
8
9
struct
Int4GPTBigCodeDecoder_output
{
10
Matrix3D<float>
last_hidden_state;
11
std::vector<Matrix3D<float>> past_keys, past_values;
12
};
13
struct
Int4GPTBigCodeDecoder_input
{
14
Matrix3D<int>
input_ids;
15
std::vector<Matrix3D<float>> past_keys, past_values;
16
bool
has_past_keys_values;
17
18
Int4GPTBigCodeDecoder_input
(
Matrix3D<int>
input_ids_) : input_ids(input_ids_) { has_past_keys_values =
false
; }
19
Int4GPTBigCodeDecoder_input
(
Matrix3D<int>
input_ids_, std::vector<
Matrix3D<float>
> past_keys_,
20
std::vector<
Matrix3D<float>
> past_values_)
21
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
22
has_past_keys_values =
true
;
23
}
24
};
25
26
class
Int4GPTBigCodeDecoder
{
27
public
:
28
Int4GPTBigCodeDecoder
(std::string param_path,
const
struct
model_config
config);
29
Int4GPTBigCodeDecoder
(){};
30
Matrix3D<float>
prepare_decoder_attention_mask(
int
length,
int
past_length);
31
Matrix3D<float>
get_position_embed(
int
sql_length,
int
past_length);
32
struct
Int4GPTBigCodeDecoder_output
forward(const struct
Int4GPTBigCodeDecoder_input
& input);
33
Embedding
wte, wpe;
34
int
voc_size, embed_dim, padding_idx, hidden_dim, num_heads, max_position_embeddings;
35
std::vector<Int4GPTBigCodeDecoderLayer> layers;
36
LayerNorm
ln_f;
37
std::string profile_name =
"Int4GPTBigCodeDecoder"
;
38
39
private
:
40
float
* attention_mask_buf;
41
float
* pos_embeds_buf;
42
float
* last_hidden_states_buf;
43
float
* hidden_states_buf;
44
};
Embedding
Definition
Embedding.h:5
Int4GPTBigCodeDecoder
Definition
Int4GPTBigCodeDecoder.h:26
LayerNorm
Definition
LayerNorm.h:8
Matrix3D
Definition
common.h:34
Int4GPTBigCodeDecoder_input
Definition
Int4GPTBigCodeDecoder.h:13
Int4GPTBigCodeDecoder_output
Definition
Int4GPTBigCodeDecoder.h:9
model_config
Definition
model.h:5
Generated by
1.11.0