// TinyChatEngine — llm/include/nn_modules/Int4OPTDecoderLayer.h
#pragma once

#include <string>
#include <utility>

#include "Int4OPTAttention.h"
#include "common.h"
#include "operators.h"

struct
Int4OPTDecoderLayer_output
{
6
Matrix3D<float>
hidden_states;
7
Matrix3D<float>
attentions;
8
std::pair<Matrix3D<float>,
Matrix3D<float>
> past_key_value;
9
10
Int4OPTDecoderLayer_output
(
Matrix3D<float>
hidden_states_,
Matrix3D<float>
attentions_,
11
std::pair<
Matrix3D<float>
,
Matrix3D<float>
> past_key_value_) {
12
hidden_states = hidden_states_;
13
attentions = attentions_;
14
past_key_value = past_key_value_;
15
};
16
};
17
struct
Int4OPTDecoderLayer_input
{
18
Matrix3D<float>
hidden_states;
19
Matrix3D<float>
attention_mask;
20
Matrix3D<float>
past_key, past_value;
21
bool
has_past_key_value =
false
;
22
23
Int4OPTDecoderLayer_input
(
Matrix3D<float>
&hidden_states_,
Matrix3D<float>
&attention_mask_) {
24
hidden_states = hidden_states_;
25
attention_mask = attention_mask_;
26
has_past_key_value =
false
;
27
}
28
29
Int4OPTDecoderLayer_input
(
Matrix3D<float>
&hidden_states_,
Matrix3D<float>
&attention_mask_,
30
Matrix3D<float>
past_key_,
Matrix3D<float>
past_value_) {
31
hidden_states = hidden_states_;
32
attention_mask = attention_mask_;
33
past_key = past_key_;
34
past_value = past_value_;
35
has_past_key_value =
true
;
36
}
37
};
38
39
class
Int4OPTDecoderLayer
{
40
public
:
41
Int4OPTDecoderLayer
(std::string param_path,
const
struct
model_config
config,
int
layer_idx);
42
struct
Int4OPTDecoderLayer_output
forward(const struct
Int4OPTDecoderLayer_input
&input);
43
44
int
embed_dim, num_attention_heads, hidden_dim, layer_idx;
45
LayerNorm
self_attn_layer_norm, final_layer_norm;
// from torch_int.nn
46
Linear_FP_int4
fc1, fc2;
47
Int4OPTAttention
attn;
48
std::string profile_name =
"Int4OPTDecoderLayer"
;
49
};
// Cross-references from the generated documentation (line numbers refer to the
// original generated listing and may drift as files change):
//   Int4OPTAttention            Int4OPTAttention.h:31
//   Int4OPTDecoderLayer         Int4OPTDecoderLayer.h:39
//   LayerNorm                   LayerNorm.h:8
//   Linear_FP_int4              linear.h:27
//   Matrix3D                    common.h:34
//   Int4OPTDecoderLayer_input   Int4OPTDecoderLayer.h:17
//   Int4OPTDecoderLayer_output  Int4OPTDecoderLayer.h:5
//   model_config                model.h:5
// Generated by Doxygen 1.11.0