TinyChatEngine
Loading...
Searching...
No Matches
Int8OPTDecoderLayer.h
1#include "Int8OPTAttention.h"
2#include "common.h"
3#include "operators.h"
4
6 Matrix3D<float> hidden_states;
7 Matrix3D<int8_t> attentions;
8 std::pair<Matrix3D<int8_t>, Matrix3D<int8_t>> past_key_value;
9
11 std::pair<Matrix3D<int8_t>, Matrix3D<int8_t>> past_key_value_) {
12 hidden_states = hidden_states_;
13 attentions = attentions_;
14 past_key_value = past_key_value_;
15 };
16};
18 Matrix3D<float> hidden_states;
19 Matrix3D<float> attention_mask;
20 Matrix3D<int8_t> past_key, past_value;
21 bool has_past_key_value = false;
22
23 Int8OPTDecoderLayer_input(Matrix3D<float> &hidden_states_, Matrix3D<float> &attention_mask_) {
24 hidden_states = hidden_states_;
25 attention_mask = attention_mask_;
26 has_past_key_value = false;
27 }
28
29 Int8OPTDecoderLayer_input(Matrix3D<float> &hidden_states_, Matrix3D<float> &attention_mask_,
30 Matrix3D<int8_t> past_key_, Matrix3D<int8_t> past_value_) {
31 hidden_states = hidden_states_;
32 attention_mask = attention_mask_;
33 past_key = past_key_;
34 past_value = past_value_;
35 has_past_key_value = true;
36 }
37};
38
40 public:
41 Int8OPTDecoderLayer(std::string param_path, const struct model_config config, int layer_idx,
42 LayerNormQ self_attn_layer_norm, LayerNormQ final_layer_norm, W8A8B8O8LinearReLU fc1,
44 W8A8B8O8Linear k_proj, W8A8B8O8Linear v_proj, W8A8B8O8Linear q_proj,
45 W8A8BFP32OFP32Linear out_proj);
46 struct Int8OPTDecoderLayer_output forward(const struct Int8OPTDecoderLayer_input &input);
47
48 int embed_dim, num_attention_heads, hidden_dim, layer_idx;
49 LayerNormQ self_attn_layer_norm, final_layer_norm; // from torch_int.nn
53 std::string profile_name = "Int8OPTDecoderLayer";
54};
Definition BMM_S8T_S8N_F32T.h:7
Definition BMM_S8T_S8N_S8T.h:7
Definition Int8OPTAttention.h:31
Definition Int8OPTDecoderLayer.h:39
Definition LayerNormQ.h:8
Definition common.h:34
Definition W8A8B8O8LinearReLU.h:10
Definition W8A8B8O8Linear.h:10
Definition W8A8BFP32OFP32Linear.h:9
Definition Int8OPTDecoderLayer.h:17
Definition Int8OPTDecoderLayer.h:5
Definition model.h:5