TinyChatEngine
Loading...
Searching...
No Matches
Fp32GPTBigCodeAttention.h
1#include <utility>
2
3#include "common.h"
4#include "operators.h"
5
// NOTE(review): the opening line of this struct (header line 6) is missing from this
// rendering -- the embedded numbering jumps from 5 to 7. The "Definition
// Fp32GPTBigCodeAttention.h:6" entry and the return type of forward() suggest it reads
// `struct Fp32GPTBigCodeAttention_output {` -- confirm against the real header.
//
// Presumably the result bundle returned by Fp32GPTBigCodeAttention::forward().
// Attention result tensor; its shape/layout is not visible in this header.
7 Matrix3D<float> attn_output;
// Attention probabilities (per the name, after a reshape); layout not visible here.
8 Matrix3D<float> attn_probs_reshaped;
// Two tensors, presumably the (key, value) pair a caller feeds back as past_key /
// past_value on the next call -- verify in forward() which element holds the key.
9 std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value;
10};
// NOTE(review): the opening line of this struct (header line 11) is missing from this
// rendering -- the numbering jumps from 10 to 12. The constructor names below show the
// type is Fp32GPTBigCodeAttention_input; the struct-vs-class keyword is not visible.
//
// Arguments for one call to Fp32GPTBigCodeAttention::forward(), which takes this type
// by const reference. Both constructors take every Matrix3D by value and copy it into
// the members. Whether that copy shares or duplicates the underlying buffer depends on
// Matrix3D, which is declared elsewhere and not visible here.
// Input activations for the layer; shape not visible here.
12 Matrix3D<float> hidden_states;
// Attention mask; whether it is additive or boolean is not visible here.
13 Matrix3D<float> attention_mask;
// Cached key/value tensors. The first constructor leaves them default-constructed.
// Presumably they are only read when has_past_key_value is true -- confirm in forward().
14 Matrix3D<float> past_key, past_value;
// Whether past_key/past_value carry data. Defaults to false. Only the second constructor
// changes it, and it copies the caller's value rather than deriving it from the tensors.
15 bool has_past_key_value = false;
// Layer index (per the name); set by both constructors.
16 int layer_idx;
17
// Input without a key/value cache: past_key/past_value are default-constructed and
// has_past_key_value keeps its default of false.
18 Fp32GPTBigCodeAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, int layer_idx_)
19 : hidden_states(hidden_states_), attention_mask(attention_mask_), layer_idx(layer_idx_) {}
20
// Input with a key/value cache. has_past_key_value_ is trusted as given, so callers
// must pass true when past_key_/past_value_ hold valid data (NOTE(review): confirm
// that forward() relies only on this flag).
21 Fp32GPTBigCodeAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, Matrix3D<float> past_key_,
22 Matrix3D<float> past_value_, bool has_past_key_value_, int layer_idx_)
23 : hidden_states(hidden_states_),
24 attention_mask(attention_mask_),
25 past_key(past_key_),
26 past_value(past_value_),
27 has_past_key_value(has_past_key_value_),
28 layer_idx(layer_idx_) {}
29};
30
// NOTE(review): two header lines are missing from this rendering, visible as numbering
// jumps of 30 -> 32 and 33 -> 35:
//   - Line 31 is the class head. The "Definition Fp32GPTBigCodeAttention.h:31" entry and
//     the constructor name indicate it opens Fp32GPTBigCodeAttention.
//   - Line 34's content cannot be recovered from here.
32 public:
// Builds the layer from `config`. param_path presumably locates the stored weights --
// confirm in the .cpp. config is passed as a const value, so the struct is copied.
33 Fp32GPTBigCodeAttention(std::string param_path, const struct model_config config);
// Static member: it has no `this`, so whatever memory it prepares cannot be
// per-instance. The allocation details live in the .cpp.
35 static void initialized_memory(const struct model_config config);
// Runs this attention layer on `input` and returns the output bundle: attention result,
// attention probabilities, and key/value pair.
36 struct Fp32GPTBigCodeAttention_output forward(const struct Fp32GPTBigCodeAttention_input &input);
37
38 private:
// Private layout helpers; sqlen is presumably the sequence length. Judging by the names,
// unshape merges per-head tensors back into one matrix and shape_qkv splits a fused QKV
// matrix into Q, K and V -- confirm in the .cpp.
// NOTE(review): unshape's second parameter is also named `unshape`. That is harmless in
// this declaration, but if the .cpp definition reuses the name, it hides the member
// function inside the body.
39 void unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int sqlen);
40 void shape_qkv(Matrix3D<float> unshape, Matrix3D<float> shaped_q, Matrix3D<float> shaped_k,
41 Matrix3D<float> shaped_v, int sqlen);
// Attention-score scaling factor (presumably 1/sqrt(head_dim); the value is set elsewhere).
// NOTE(review): scaling and the int dimensions below have no in-class initializers, so
// their values depend entirely on the constructor that ran.
42 float scaling;
// Model dimensions. kv_heads/kv_dim are presumably separate from num_heads/embed_dim
// because GPTBigCode can use multi-query attention -- confirm.
43 int embed_dim, num_heads, head_dim, kv_heads, kv_dim;
// Batched matmul operators; the names suggest Q*K^T scores (qk) and probs*V (pv).
44 BMM_F32T qk_bmm, pv_bmm;
// Linear layers; the names suggest the fused QKV input projection (c_attn) and the
// output projection (c_proj).
45 Linear_FP c_attn, c_proj;
// Label presumably handed to the profiler; defaults to the class name.
46 std::string profile_name = "Fp32GPTBigCodeAttention";
47};
Definition BMM_F32T.h:3
Definition Fp32GPTBigCodeAttention.h:31
Definition linear.h:6
Definition common.h:34
Definition Fp32GPTBigCodeAttention.h:11
Definition Fp32GPTBigCodeAttention.h:6
Definition model.h:5