TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
Fp32llamaDecoder.h
1
#include <cstdlib>
#include <string>
#include <utility>
#include <vector>

#include "Fp32llamaDecoderLayer.h"
#include "common.h"
#include "operators.h"
9
struct
Fp32llamaDecoder_output
{
10
Matrix3D<float>
last_hidden_state;
11
std::vector<Matrix3D<float>> past_keys, past_values;
12
};
13
struct
Fp32llamaDecoder_input
{
14
Matrix3D<int>
input_ids;
15
Matrix3D<float>
image_embed;
16
Matrix3D<int>
second_input_ids;
17
std::vector<Matrix3D<float>> past_keys, past_values;
18
bool
has_past_keys_values;
19
bool
is_llava;
20
21
Fp32llamaDecoder_input
() {}
22
Fp32llamaDecoder_input
(
Matrix3D<int>
input_ids_) : input_ids(input_ids_) {
23
has_past_keys_values =
false
;
24
is_llava =
false
;
25
}
26
Fp32llamaDecoder_input
(
Matrix3D<int>
input_ids_, std::vector<
Matrix3D<float>
> past_keys_,
27
std::vector<
Matrix3D<float>
> past_values_)
28
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
29
has_past_keys_values =
true
;
30
is_llava =
false
;
31
}
32
Fp32llamaDecoder_input
(
Matrix3D<int>
input_ids_,
Matrix3D<float>
image_embed_,
Matrix3D<int>
second_input_ids_)
33
: input_ids(input_ids_), image_embed(image_embed_), second_input_ids(second_input_ids_) {
34
has_past_keys_values =
false
;
35
is_llava =
true
;
36
}
37
Fp32llamaDecoder_input
(
Matrix3D<int>
input_ids_,
Matrix3D<float>
image_embed_)
38
: input_ids(input_ids_), image_embed(image_embed_) {
39
has_past_keys_values =
false
;
40
is_llava =
true
;
41
}
42
};
43
44
class
Fp32llamaDecoder
{
45
public
:
46
Fp32llamaDecoder
(std::string param_path,
const
struct
model_config
config);
47
Fp32llamaDecoder
(){};
48
Matrix3D<float>
prepare_decoder_attention_mask(
int
length,
int
past_length);
49
struct
Fp32llamaDecoder_output
forward(const struct
Fp32llamaDecoder_input
& input);
50
Embedding
embed_tokens;
51
LlamaRMSNorm
norm;
52
float
rms_norm_eps;
53
int
voc_size, embed_dim, padding_idx, hidden_dim, num_heads;
54
std::vector<Fp32llamaDecoderLayer> layers;
55
std::string profile_name =
"Fp32llamaDecoder"
;
56
57
private
:
58
float
* attention_mask_buf;
59
float
* pos_embeds_buf;
60
float
* last_hidden_states_buf;
61
float
* hidden_states_buf;
62
float
* inputs_embeds_buf;
63
float
* first_input_ids_buf;
64
float
* image_embed_buf;
65
float
* second_input_ids_buf;
66
};
Embedding
Definition
Embedding.h:5
Fp32llamaDecoder
Definition
Fp32llamaDecoder.h:44
LlamaRMSNorm
Definition
LlamaRMSNorm.h:4
Matrix3D
Definition
common.h:34
Fp32llamaDecoder_input
Definition
Fp32llamaDecoder.h:13
Fp32llamaDecoder_output
Definition
Fp32llamaDecoder.h:9
model_config
Definition
model.h:5
Generated by
1.11.0