TinyChatEngine
llm/include/nn_modules/Fp32llamaAttention.h
#include <string>  // std::string (added so the header is self-contained)
#include <utility>

#include "common.h"
#include "operators.h"

// Outputs of one attention forward pass.
struct Fp32llamaAttention_output {
    Matrix3D<float> attn_output;
    Matrix3D<float> attn_probs_reshaped;
    // Updated key/value cache for this layer.
    std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value;
};

// Inputs to the attention layer. The second constructor additionally passes
// the cached key/value tensors from a previous forward call.
struct Fp32llamaAttention_input {
    Matrix3D<float> hidden_states;
    Matrix3D<float> attention_mask;
    Matrix3D<float> past_key, past_value;
    bool has_past_key_value = false;
    int layer_idx;

    Fp32llamaAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, int layer_idx_)
        : hidden_states(hidden_states_), attention_mask(attention_mask_), layer_idx(layer_idx_) {}

    Fp32llamaAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, Matrix3D<float> past_key_,
                             Matrix3D<float> past_value_, bool has_past_key_value_, int layer_idx_)
        : hidden_states(hidden_states_),
          attention_mask(attention_mask_),
          past_key(past_key_),
          past_value(past_value_),
          has_past_key_value(has_past_key_value_),
          layer_idx(layer_idx_) {}
};

// FP32 multi-head self-attention block for the LLaMA architecture.
class Fp32llamaAttention {
   public:
    Fp32llamaAttention(std::string param_path, const struct model_config config);
    Fp32llamaAttention() {}
    static void initialized_memory(const struct model_config config);
    struct Fp32llamaAttention_output forward(const struct Fp32llamaAttention_input &input);

   private:
    // Reshape between the flattened (sqlen x embed_dim) layout and the
    // per-head (num_heads x sqlen x head_dim) layout.
    void unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int sqlen);
    void shape(Matrix3D<float> unshape, Matrix3D<float> shaped, int sqlen);
    int embed_dim, num_heads, head_dim;
    Linear_FP k_proj, v_proj, q_proj, o_proj;
    RotaryPosEmb rotary_pos_emb;
    BMM_F32T qk_bmm, pv_bmm;
    std::string profile_name = "Fp32llamaAttention";
};
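The two Fp32llamaAttention_input constructors map onto the two phases of autoregressive inference: the short form is for the first (prefill) pass, when no key/value cache exists yet, and the long form is for later decode passes, which feed the past_key_value pair returned by the previous forward call back in. Below is a minimal usage sketch, not part of the TinyChatEngine API: it assumes Matrix3D<float> wraps an externally owned buffer via a (data, dim_x, dim_y, dim_z) constructor from common.h, and that a constructed Fp32llamaAttention instance for layer 0 is already available; the sizes and the run_attention_once helper are illustrative only.

#include <vector>

#include "Fp32llamaAttention.h"

// Illustrative sizes only; real values come from model_config.
static const int sqlen = 8, embed_dim = 4096;

void run_attention_once(Fp32llamaAttention &attn) {
    // Backing buffers for the activations (assumes Matrix3D does not own its memory).
    std::vector<float> hidden_buf(1 * sqlen * embed_dim);
    std::vector<float> mask_buf(1 * sqlen * sqlen);

    // Assumed Matrix3D<float>(data, dim_x, dim_y, dim_z) constructor from common.h.
    Matrix3D<float> hidden_states(hidden_buf.data(), 1, sqlen, embed_dim);
    Matrix3D<float> attention_mask(mask_buf.data(), 1, sqlen, sqlen);

    // Prefill: no cached key/value yet, so use the short constructor.
    Fp32llamaAttention_input prefill_input(hidden_states, attention_mask, /*layer_idx=*/0);
    Fp32llamaAttention_output out = attn.forward(prefill_input);

    // Decode: pass the returned KV cache back in through the long constructor.
    // A real decode step would build hidden_states/attention_mask for just the new
    // token; the same buffers are reused here only to show the constructor plumbing.
    Fp32llamaAttention_input decode_input(hidden_states, attention_mask,
                                          out.past_key_value.first, out.past_key_value.second,
                                          /*has_past_key_value=*/true, /*layer_idx=*/0);
    Fp32llamaAttention_output next = attn.forward(decode_input);
    (void)next;
}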