TinyChatEngine
Loading...
Searching...
No Matches
llm / include / nn_modules / Fp32CLIPAttention.h
#pragma once

#include <string>
#include <utility>

#include "common.h"
#include "operators.h"
6
struct
Fp32CLIPAttention_output
{
7
Matrix3D<float>
attn_output;
8
Matrix3D<float>
attn_probs_reshaped;
9
std::pair<Matrix3D<float>,
Matrix3D<float>
> past_key_value;
10
};
11
struct
Fp32CLIPAttention_input
{
12
Matrix3D<float>
hidden_states;
13
Matrix3D<float>
attention_mask;
14
Matrix3D<float>
past_key, past_value;
15
bool
has_past_key_value =
false
;
16
int
layer_idx;
17
18
Fp32CLIPAttention_input
(
Matrix3D<float>
hidden_states_,
Matrix3D<float>
attention_mask_,
int
layer_idx_)
19
: hidden_states(hidden_states_), attention_mask(attention_mask_), layer_idx(layer_idx_) {}
20
21
Fp32CLIPAttention_input
(
Matrix3D<float>
hidden_states_,
Matrix3D<float>
attention_mask_,
Matrix3D<float>
past_key_,
22
Matrix3D<float>
past_value_,
bool
has_past_key_value_,
int
layer_idx_)
23
: hidden_states(hidden_states_),
24
attention_mask(attention_mask_),
25
past_key(past_key_),
26
past_value(past_value_),
27
has_past_key_value(has_past_key_value_),
28
layer_idx(layer_idx_) {}
29
};
30
31
class
Fp32CLIPAttention
{
32
public
:
33
Fp32CLIPAttention
(std::string param_path,
const
struct
model_config
config);
34
Fp32CLIPAttention
() {}
35
static
void
initialized_memory(
const
struct
model_config
config);
36
struct
Fp32CLIPAttention_output
forward(const struct
Fp32CLIPAttention_input
&input);
37
38
private
:
39
void
unshape(
Matrix3D<float>
shaped,
Matrix3D<float>
unshape,
int
sqlen);
40
void
shape(
Matrix3D<float>
unshape,
Matrix3D<float>
shaped,
int
sqlen);
41
// void shape_qkv(Matrix3D<float> unshape, Matrix3D<float> shaped_q, Matrix3D<float> shaped_k,
42
// Matrix3D<float> shaped_v, int sqlen);
43
int
embed_dim, num_heads, head_dim;
44
Linear_FP
k_proj, v_proj, q_proj, out_proj, qkv_proj;
45
BMM_F32T
qk_bmm, pv_bmm;
46
std::string profile_name =
"Fp32CLIPAttention"
;
47
};
BMM_F32T — Definition: BMM_F32T.h:3
Fp32CLIPAttention — Definition: Fp32CLIPAttention.h:31
Linear_FP — Definition: linear.h:6
Matrix3D — Definition: common.h:34
Fp32CLIPAttention_input — Definition: Fp32CLIPAttention.h:11
Fp32CLIPAttention_output — Definition: Fp32CLIPAttention.h:6
model_config — Definition: model.h:5
Generated by Doxygen 1.11.0