TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
Fp32GPTBigCodeDecoder.h
1
#include <cstdlib>
2
#include <string>
3
#include <vector>
4
5
#include "Fp32GPTBigCodeDecoderLayer.h"
6
#include "common.h"
7
#include "operators.h"
8
9
struct
Fp32GPTBigCodeDecoder_output
{
10
Matrix3D<float>
last_hidden_state;
11
std::vector<Matrix3D<float>> past_keys, past_values;
12
};
13
struct
Fp32GPTBigCodeDecoder_input
{
14
Matrix3D<int>
input_ids;
15
std::vector<Matrix3D<float>> past_keys, past_values;
16
bool
has_past_keys_values;
17
18
Fp32GPTBigCodeDecoder_input
(
Matrix3D<int>
input_ids_) : input_ids(input_ids_) { has_past_keys_values =
false
; }
19
Fp32GPTBigCodeDecoder_input
(
Matrix3D<int>
input_ids_, std::vector<
Matrix3D<float>
> past_keys_,
20
std::vector<
Matrix3D<float>
> past_values_)
21
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
22
has_past_keys_values =
true
;
23
}
24
};
25
26
class
Fp32GPTBigCodeDecoder
{
27
public
:
28
Fp32GPTBigCodeDecoder
(std::string param_path,
const
struct
model_config
config);
29
Fp32GPTBigCodeDecoder
(){};
30
Matrix3D<float>
prepare_decoder_attention_mask(
int
length,
int
past_length);
31
Matrix3D<float>
get_position_embed(
int
sql_length,
int
past_length);
32
struct
Fp32GPTBigCodeDecoder_output
forward(const struct
Fp32GPTBigCodeDecoder_input
& input);
33
Embedding
wte, wpe;
34
int
voc_size, embed_dim, padding_idx, hidden_dim, num_heads, max_position_embeddings;
35
std::vector<Fp32GPTBigCodeDecoderLayer> layers;
36
LayerNorm
ln_f;
37
std::string profile_name =
"Fp32GPTBigCodeDecoder"
;
38
39
private
:
40
float
* attention_mask_buf;
41
float
* pos_embeds_buf;
42
float
* last_hidden_states_buf;
43
float
* hidden_states_buf;
44
};
Embedding
Definition
Embedding.h:5
Fp32GPTBigCodeDecoder
Definition
Fp32GPTBigCodeDecoder.h:26
LayerNorm
Definition
LayerNorm.h:8
Matrix3D
Definition
common.h:34
Fp32GPTBigCodeDecoder_input
Definition
Fp32GPTBigCodeDecoder.h:13
Fp32GPTBigCodeDecoder_output
Definition
Fp32GPTBigCodeDecoder.h:9
model_config
Definition
model.h:5
Generated by Doxygen 1.11.0