TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
Int8OPTAttention.h
1
#include <utility>
2
3
#include "common.h"
4
#include "operators.h"
5
6
struct
Int8OPTAttention_output
{
7
Matrix3D<float>
attn_output;
8
Matrix3D<int8_t>
attn_probs_reshaped;
9
std::pair<Matrix3D<int8_t>,
Matrix3D<int8_t>
> past_key_value;
10
};
11
struct
Int8OPTAttention_input
{
12
Matrix3D<int8_t>
hidden_states;
13
Matrix3D<float>
attention_mask;
14
Matrix3D<int8_t>
past_key, past_value;
15
bool
has_past_key_value =
false
;
16
int
layer_idx;
17
18
Int8OPTAttention_input
(
Matrix3D<int8_t>
hidden_states_,
Matrix3D<float>
attention_mask_,
int
layer_idx_)
19
: hidden_states(hidden_states_), attention_mask(attention_mask_), layer_idx(layer_idx_) {}
20
21
Int8OPTAttention_input
(
Matrix3D<int8_t>
hidden_states_,
Matrix3D<float>
attention_mask_,
Matrix3D<int8_t>
past_key_,
22
Matrix3D<int8_t>
past_value_,
bool
has_past_key_value_,
int
layer_idx_)
23
: hidden_states(hidden_states_),
24
attention_mask(attention_mask_),
25
past_key(past_key_),
26
past_value(past_value_),
27
has_past_key_value(has_past_key_value_),
28
layer_idx(layer_idx_) {}
29
};
30
31
class
Int8OPTAttention
{
32
public
:
33
Int8OPTAttention
(std::string param_path,
const
struct
model_config
config,
BMM_S8T_S8N_F32T
&qk_bmm,
34
BMM_S8T_S8N_S8T
&pv_bmm,
W8A8B8O8Linear
&k_proj,
W8A8B8O8Linear
&v_proj,
W8A8B8O8Linear
&q_proj,
35
W8A8BFP32OFP32Linear
&out_proj);
36
Int8OPTAttention
() {}
37
static
void
initialized_memory(
const
struct
model_config
config);
38
struct
Int8OPTAttention_output
forward(const struct
Int8OPTAttention_input
&input);
39
40
private
:
41
void
unshape(
Matrix3D<int8_t>
shaped,
Matrix3D<int8_t>
unshape,
int
sqlen);
42
void
shpae(
Matrix3D<int8_t>
unshape,
Matrix3D<int8_t>
shaped,
int
sqlen);
43
int
embed_dim, num_heads, head_dim;
44
BMM_S8T_S8N_F32T
qk_bmm;
45
BMM_S8T_S8N_S8T
pv_bmm;
46
W8A8B8O8Linear
k_proj, v_proj, q_proj;
47
W8A8BFP32OFP32Linear
out_proj;
48
std::string profile_name =
"Int8OPTAttention"
;
49
};
BMM_S8T_S8N_F32T
Definition
BMM_S8T_S8N_F32T.h:7
BMM_S8T_S8N_S8T
Definition
BMM_S8T_S8N_S8T.h:7
Int8OPTAttention
Definition
Int8OPTAttention.h:31
Matrix3D
Definition
common.h:34
W8A8B8O8Linear
Definition
W8A8B8O8Linear.h:10
W8A8BFP32OFP32Linear
Definition
W8A8BFP32OFP32Linear.h:9
Int8OPTAttention_input
Definition
Int8OPTAttention.h:11
Int8OPTAttention_output
Definition
Int8OPTAttention.h:6
model_config
Definition
model.h:5
Generated by Doxygen 1.11.0