TinyChatEngine
Int8OPTAttention.h
1#include <utility>
2
3#include "common.h"
4#include "operators.h"
5
7 Matrix3D<float> attn_output;
8 Matrix3D<int8_t> attn_probs_reshaped;
9 std::pair<Matrix3D<int8_t>, Matrix3D<int8_t>> past_key_value;
10};
12 Matrix3D<int8_t> hidden_states;
13 Matrix3D<float> attention_mask;
14 Matrix3D<int8_t> past_key, past_value;
15 bool has_past_key_value = false;
16 int layer_idx;
17
18 Int8OPTAttention_input(Matrix3D<int8_t> hidden_states_, Matrix3D<float> attention_mask_, int layer_idx_)
19 : hidden_states(hidden_states_), attention_mask(attention_mask_), layer_idx(layer_idx_) {}
20
21 Int8OPTAttention_input(Matrix3D<int8_t> hidden_states_, Matrix3D<float> attention_mask_, Matrix3D<int8_t> past_key_,
22 Matrix3D<int8_t> past_value_, bool has_past_key_value_, int layer_idx_)
23 : hidden_states(hidden_states_),
24 attention_mask(attention_mask_),
25 past_key(past_key_),
26 past_value(past_value_),
27 has_past_key_value(has_past_key_value_),
28 layer_idx(layer_idx_) {}
29};
30
32 public:
33 Int8OPTAttention(std::string param_path, const struct model_config config, BMM_S8T_S8N_F32T &qk_bmm,
34 BMM_S8T_S8N_S8T &pv_bmm, W8A8B8O8Linear &k_proj, W8A8B8O8Linear &v_proj, W8A8B8O8Linear &q_proj,
35 W8A8BFP32OFP32Linear &out_proj);
37 static void initialized_memory(const struct model_config config);
38 struct Int8OPTAttention_output forward(const struct Int8OPTAttention_input &input);
39
40 private:
41 void unshape(Matrix3D<int8_t> shaped, Matrix3D<int8_t> unshape, int sqlen);
42 void shpae(Matrix3D<int8_t> unshape, Matrix3D<int8_t> shaped, int sqlen);
43 int embed_dim, num_heads, head_dim;
44 BMM_S8T_S8N_F32T qk_bmm;
45 BMM_S8T_S8N_S8T pv_bmm;
46 W8A8B8O8Linear k_proj, v_proj, q_proj;
47 W8A8BFP32OFP32Linear out_proj;
48 std::string profile_name = "Int8OPTAttention";
49};
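BMM_S8T_S8N_F32T: Definition BMM_S8T_S8N_F32T.h:7
BMM_S8T_S8N_S8T: Definition BMM_S8T_S8N_S8T.h:7
Int8OPTAttention: Definition Int8OPTAttention.h:31
Matrix3D (presumed): Definition common.h:34
W8A8B8O8Linear: Definition W8A8B8O8Linear.h:10
W8A8BFP32OFP32Linear: Definition W8A8BFP32OFP32Linear.h:9
Int8OPTAttention_input: Definition Int8OPTAttention.h:11
Int8OPTAttention_output: Definition Int8OPTAttention.h:6
model_config (presumed): Definition model.h:5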