TinyChatEngine
Loading...
Searching...
No Matches
Int8OPTDecoder.h
1#include <cstdlib>
2#include <string>
3#include <vector>
4
5#include "Int8OPTDecoderLayer.h"
6#include "common.h"
7#include "operators.h"
8
// NOTE(review): the opening line of this definition (listing line 9) is not
// visible in this view; presumably `struct Int8OPTDecoder_output {`, the type
// returned by Int8OPTDecoder::forward() -- confirm against the real header.
// Float 3-D matrix; presumably the decoder's final hidden states -- TODO confirm.
10 Matrix3D<float> last_hidden_state;
// Two vectors of int8 3-D matrices; presumably the cached attention keys and
// values handed back to the caller for the next forward() call -- TODO confirm.
11 std::vector<Matrix3D<int8_t>> past_keys, past_values;
12};
// Input bundle for the decoder (Int8OPTDecoder_input, per the constructors below).
// NOTE(review): the opening line of this definition (listing line 13) is not
// visible in this view.
// Integer 3-D matrix; presumably token ids -- TODO confirm layout with callers.
14 Matrix3D<int> input_ids;
// Optional int8 key/value matrices supplied by the caller; only populated by
// the three-argument constructor.
15 std::vector<Matrix3D<int8_t>> past_keys, past_values;
// True only when the object was built with past keys/values (see constructors).
16 bool has_past_keys_values;
17
// Constructs an input carrying only input_ids; past_keys/past_values stay
// default-constructed (empty vectors) and has_past_keys_values is set to false.
18 Int8OPTDecoder_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
// Constructs an input carrying input_ids plus past keys and values. The
// vectors are taken by value and copied into the members, and
// has_past_keys_values is set to true.
19 Int8OPTDecoder_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<int8_t>> past_keys_,
20 std::vector<Matrix3D<int8_t>> past_values_)
21 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
22 has_past_keys_values = true;
23 }
24};
25
// Declaration body of Int8OPTDecoder (name per the constructor below).
// NOTE(review): the opening line of this definition (listing line 26) and
// listing line 29 are not visible in this view.
27 public:
// Constructs the decoder from a parameter path and a model configuration.
// NOTE(review): presumably param_path is the directory holding the weights --
// confirm in the implementation file.
28 Int8OPTDecoder(std::string param_path, const struct model_config config);
// Returns a float 3-D matrix built from the current length and the past length.
// The mask contents are not visible in this header -- see the implementation.
30 Matrix3D<float> prepare_decoder_attention_mask(int length, int past_length);
// Returns a float 3-D matrix of position embeddings for the given lengths.
// NOTE(review): `sql_length` looks like it means sequence length -- confirm.
31 Matrix3D<float> get_position_embed(int sql_length, int past_length);
// Runs the decoder on `input` and returns an Int8OPTDecoder_output by value.
32 struct Int8OPTDecoder_output forward(const struct Int8OPTDecoder_input& input);
// Embedding operators for tokens and for positions.
33 Embedding embed_tokens, embed_positions;
// Integer model dimensions; presumably filled from the model_config passed to
// the constructor -- TODO confirm.
34 int voc_size, embed_dim, padding_idx, hidden_dim, num_heads;
// The decoder layers held by this decoder.
35 std::vector<Int8OPTDecoderLayer> layers;
// Layer normalization operator; presumably applied after the last layer -- confirm.
36 LayerNorm final_layer_norm;
// Name string defaulting to "Int8OPTDecoder"; presumably a profiling label -- confirm.
37 std::string profile_name = "Int8OPTDecoder";
38
39 private:
// Raw float buffers. Allocation, size, and ownership are not visible in this
// header -- NOTE(review): check the implementation for who allocates and frees
// them before copying or destroying instances of this class.
40 float* attention_mask_buf;
41 float* pos_embeds_buf;
42 float* last_hidden_states_buf;
43 float* hidden_states_buf;
44};
Definition Embedding.h:5
Definition Int8OPTDecoder.h:26
Definition LayerNorm.h:8
Definition common.h:34
Definition Int8OPTDecoder.h:13
Definition Int8OPTDecoder.h:9
Definition model.h:5