TinyChatEngine
Loading...
Searching...
No Matches
Fp32GPTBigCodeDecoder.h
1#include <cstdlib>
2#include <string>
3#include <vector>
4
5#include "Fp32GPTBigCodeDecoderLayer.h"
6#include "common.h"
7#include "operators.h"
8
10 Matrix3D<float> last_hidden_state;
11 std::vector<Matrix3D<float>> past_keys, past_values;
12};
14 Matrix3D<int> input_ids;
15 std::vector<Matrix3D<float>> past_keys, past_values;
16 bool has_past_keys_values;
17
// Constructs a decoder input from token ids only.
// past_keys / past_values are left default-constructed and has_past_keys_values is set to false.
// NOTE(review): single-argument constructor is not `explicit`, so Matrix3D<int> converts implicitly — confirm this is intended.
18 Fp32GPTBigCodeDecoder_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
// Constructs a decoder input from token ids plus previously produced keys and values.
// Both vectors are taken by value and copied into the members; has_past_keys_values is set to true.
// NOTE(review): presumably one past_keys/past_values entry per decoder layer — confirm against forward().
19 Fp32GPTBigCodeDecoder_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
20 std::vector<Matrix3D<float>> past_values_)
21 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
22 has_past_keys_values = true;
23 }
24};
25
27 public:
28 Fp32GPTBigCodeDecoder(std::string param_path, const struct model_config config);
30 Matrix3D<float> prepare_decoder_attention_mask(int length, int past_length);
// Declaration only; the implementation is not visible in this header.
// NOTE(review): `sql_length` presumably means the sequence length (cf. `length` in prepare_decoder_attention_mask) — confirm and consider renaming.
31 Matrix3D<float> get_position_embed(int sql_length, int past_length);
32 struct Fp32GPTBigCodeDecoder_output forward(const struct Fp32GPTBigCodeDecoder_input& input);
33 Embedding wte, wpe;
34 int voc_size, embed_dim, padding_idx, hidden_dim, num_heads, max_position_embeddings;
35 std::vector<Fp32GPTBigCodeDecoderLayer> layers;
36 LayerNorm ln_f;
37 std::string profile_name = "Fp32GPTBigCodeDecoder";
38
39 private:
40 float* attention_mask_buf;
41 float* pos_embeds_buf;
42 float* last_hidden_states_buf;
43 float* hidden_states_buf;
44};
Definition Embedding.h:5
Definition Fp32GPTBigCodeDecoder.h:26
Definition LayerNorm.h:8
Definition common.h:34
Definition Fp32GPTBigCodeDecoder.h:13
Definition Fp32GPTBigCodeDecoder.h:9
Definition model.h:5