TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
Int4GPTBigCodeAttention.h
1
#pragma once

#include <string>
#include <utility>

#include "common.h"
#include "operators.h"
6
struct
Int4GPTBigCodeAttention_output
{
7
Matrix3D<float>
attn_output;
8
Matrix3D<float>
attn_probs_reshaped;
9
std::pair<Matrix3D<float>,
Matrix3D<float>
> past_key_value;
10
};
11
struct
Int4GPTBigCodeAttention_input
{
12
Matrix3D<float>
hidden_states;
13
Matrix3D<float>
attention_mask;
14
Matrix3D<float>
past_key, past_value;
15
bool
has_past_key_value =
false
;
16
int
layer_idx;
17
18
Int4GPTBigCodeAttention_input
(
Matrix3D<float>
hidden_states_,
Matrix3D<float>
attention_mask_,
int
layer_idx_)
19
: hidden_states(hidden_states_), attention_mask(attention_mask_), layer_idx(layer_idx_) {}
20
21
Int4GPTBigCodeAttention_input
(
Matrix3D<float>
hidden_states_,
Matrix3D<float>
attention_mask_,
Matrix3D<float>
past_key_,
22
Matrix3D<float>
past_value_,
bool
has_past_key_value_,
int
layer_idx_)
23
: hidden_states(hidden_states_),
24
attention_mask(attention_mask_),
25
past_key(past_key_),
26
past_value(past_value_),
27
has_past_key_value(has_past_key_value_),
28
layer_idx(layer_idx_) {}
29
};
30
31
class
Int4GPTBigCodeAttention
{
32
public
:
33
Int4GPTBigCodeAttention
(std::string param_path,
const
struct
model_config
config);
34
Int4GPTBigCodeAttention
() {}
35
static
void
initialized_memory(
const
struct
model_config
config);
36
struct
Int4GPTBigCodeAttention_output
forward(const struct
Int4GPTBigCodeAttention_input
&input);
37
38
private
:
39
void
unshape(
Matrix3D<float>
shaped,
Matrix3D<float>
unshape,
int
sqlen);
40
void
shape_qkv(
Matrix3D<float>
unshape,
Matrix3D<float>
shaped_q,
Matrix3D<float>
shaped_k,
41
Matrix3D<float>
shaped_v,
int
sqlen);
42
int
embed_dim, num_heads, head_dim, kv_heads, kv_dim;
43
BMM_F32T
qk_bmm, pv_bmm;
44
Linear_FP_int4
c_attn, c_proj;
45
std::string profile_name =
"Int4GPTBigCodeAttention"
;
46
};
BMM_F32T
Definition
BMM_F32T.h:3
Int4GPTBigCodeAttention
Definition
Int4GPTBigCodeAttention.h:31
Linear_FP_int4
Definition
linear.h:27
Matrix3D
Definition
common.h:34
Int4GPTBigCodeAttention_input
Definition
Int4GPTBigCodeAttention.h:11
Int4GPTBigCodeAttention_output
Definition
Int4GPTBigCodeAttention.h:6
model_config
Definition
model.h:5
Generated by
1.11.0