TinyChatEngine
Loading...
Searching...
No Matches
Fp32llamaDecoder.h
1#include <cstdlib>
2#include <string>
3#include <vector>
4
5#include "Fp32llamaDecoderLayer.h"
6#include "common.h"
7#include "operators.h"
8
10 Matrix3D<float> last_hidden_state;
11 std::vector<Matrix3D<float>> past_keys, past_values;
12};
14 Matrix3D<int> input_ids;  // initialized by every constructor visible in this listing
15 Matrix3D<float> image_embed;  // initialized only by the constructors that take image_embed_
16 Matrix3D<int> second_input_ids;  // initialized only by the (input_ids_, image_embed_, second_input_ids_) constructor
17 std::vector<Matrix3D<float>> past_keys, past_values;  // initialized only by the constructor that takes past_keys_/past_values_
18 bool has_past_keys_values;  // among the visible constructors, true only when past_keys_/past_values_ were passed
19 bool is_llava;  // among the visible constructors, true when image_embed_ was passed, false otherwise
20
// Builds an input that carries only input_ids. image_embed, second_input_ids, past_keys and
// past_values are absent from the initializer list, so they are default-initialized.
22 Fp32llamaDecoder_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) {
23 has_past_keys_values = false;  // this overload receives no past_keys_/past_values_
24 is_llava = false;  // this overload receives no image_embed_
25 }
// Builds an input from input_ids_ plus caller-supplied past_keys_/past_values_. The by-value
// arguments are copied into the past_keys/past_values members; image_embed and second_input_ids
// are absent from the initializer list, so they are default-initialized.
26 Fp32llamaDecoder_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
27 std::vector<Matrix3D<float>> past_values_)
28 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
29 has_past_keys_values = true;  // records that past_keys/past_values were provided
30 is_llava = false;  // this overload receives no image_embed_
31 }
// Builds an input from input_ids_, image_embed_ and second_input_ids_. past_keys and past_values
// are absent from the initializer list, so they start as empty vectors.
32 Fp32llamaDecoder_input(Matrix3D<int> input_ids_, Matrix3D<float> image_embed_, Matrix3D<int> second_input_ids_)
33 : input_ids(input_ids_), image_embed(image_embed_), second_input_ids(second_input_ids_) {
34 has_past_keys_values = false;  // this overload receives no past_keys_/past_values_
35 is_llava = true;  // set by the overloads that take image_embed_; presumably marks LLaVA-style image input - TODO confirm
36 }
38 : input_ids(input_ids_), image_embed(image_embed_) {
39 has_past_keys_values = false;
40 is_llava = true;
41 }
42};
43
45 public:
46 Fp32llamaDecoder(std::string param_path, const struct model_config config);
48 Matrix3D<float> prepare_decoder_attention_mask(int length, int past_length);
49 struct Fp32llamaDecoder_output forward(const struct Fp32llamaDecoder_input& input);
50 Embedding embed_tokens;
51 LlamaRMSNorm norm;
52 float rms_norm_eps;
53 int voc_size, embed_dim, padding_idx, hidden_dim, num_heads;
54 std::vector<Fp32llamaDecoderLayer> layers;
55 std::string profile_name = "Fp32llamaDecoder";
56
57 private:
58 float* attention_mask_buf;
59 float* pos_embeds_buf;
60 float* last_hidden_states_buf;
61 float* hidden_states_buf;
62 float* inputs_embeds_buf;
63 float* first_input_ids_buf;
64 float* image_embed_buf;
65 float* second_input_ids_buf;
66};
Definition Embedding.h:5
Definition Fp32llamaDecoder.h:44
Definition LlamaRMSNorm.h:4
Definition common.h:34
Definition Fp32llamaDecoder.h:13
Definition Fp32llamaDecoder.h:9
Definition model.h:5