TinyChatEngine
llm/include/nn_modules/Fp32CLIPEncoderLayer.h
#pragma once

#include <string>
#include <utility>

#include "Fp32CLIPAttention.h"
#include "common.h"
#include "operators.h"
4
5
struct
Fp32CLIPEncoderLayer_output
{
6
Matrix3D<float>
hidden_states;
7
Matrix3D<float>
attentions;
8
std::pair<Matrix3D<float>,
Matrix3D<float>
> past_key_value;
9
10
Fp32CLIPEncoderLayer_output
(
Matrix3D<float>
hidden_states_,
Matrix3D<float>
attentions_,
11
std::pair<
Matrix3D<float>
,
Matrix3D<float>
> past_key_value_) {
12
hidden_states = hidden_states_;
13
attentions = attentions_;
14
past_key_value = past_key_value_;
15
};
16
};
17
struct
Fp32CLIPEncoderLayer_input
{
18
Matrix3D<float>
hidden_states;
19
Matrix3D<float>
attention_mask;
20
Matrix3D<float>
past_key, past_value;
21
bool
has_past_key_value =
false
;
22
23
Fp32CLIPEncoderLayer_input
(
Matrix3D<float>
&hidden_states_,
Matrix3D<float>
attention_mask_) {
24
hidden_states = hidden_states_;
25
attention_mask = attention_mask_;
26
has_past_key_value =
false
;
27
}
28
29
Fp32CLIPEncoderLayer_input
(
Matrix3D<float>
&hidden_states_,
Matrix3D<float>
attention_mask_,
30
Matrix3D<float>
past_key_,
Matrix3D<float>
past_value_) {
31
hidden_states = hidden_states_;
32
attention_mask = attention_mask_;
33
past_key = past_key_;
34
past_value = past_value_;
35
has_past_key_value =
true
;
36
}
37
};
38
39
class
Fp32CLIPEncoderLayer
{
40
public
:
41
Fp32CLIPEncoderLayer
(std::string param_path,
const
struct
model_config
config,
int
layer_idx);
42
struct
Fp32CLIPEncoderLayer_output
forward(const struct
Fp32CLIPEncoderLayer_input
&input);
43
44
int
embed_dim, num_attention_heads, hidden_dim, layer_idx;
45
LayerNorm
layer_norm1, layer_norm2;
46
Linear_FP
mlp_fc1, mlp_fc2;
47
Fp32CLIPAttention
attn;
48
std::string profile_name =
"Fp32CLIPEncoderLayer"
;
49
};
Fp32CLIPAttention: defined at Fp32CLIPAttention.h:31
Fp32CLIPEncoderLayer: defined at Fp32CLIPEncoderLayer.h:39
LayerNorm: defined at LayerNorm.h:8
Linear_FP: defined at linear.h:6
Matrix3D: defined at common.h:34
Fp32CLIPEncoderLayer_input: defined at Fp32CLIPEncoderLayer.h:17
Fp32CLIPEncoderLayer_output: defined at Fp32CLIPEncoderLayer.h:5
model_config: defined at model.h:5
Generated by Doxygen 1.11.0