TinyChatEngine
Loading...
Searching...
No Matches
Int4llamaDecoder.h
1#include <cstdlib>
2#include <string>
3#include <vector>
4
5#include "Int4llamaDecoderLayer.h"
6#include "common.h"
7#include "operators.h"
8
// Result fields of a decoder pass: the final hidden state plus two vectors of key/value matrices.
// NOTE(review): presumably these are the members of the Int4llamaDecoder_output that forward()
// returns; the struct's opening line is missing from this listing — confirm against the real header.
// The element type is build-dependent: float16_t when QM_CUDA is defined, float otherwise.
10#ifdef QM_CUDA
11 Matrix3D<float16_t> last_hidden_state;
12 std::vector<Matrix3D<float16_t>> past_keys, past_values;
13#else
14 Matrix3D<float> last_hidden_state;
15 std::vector<Matrix3D<float>> past_keys, past_values;
16#endif
17};
// Data members of the decoder input. Which of them are populated depends on the constructor used.
// Set by every visible constructor.
19 Matrix3D<int> input_ids;
// image_embed and second_input_ids are only set by the constructors that also set is_llava = true.
20 Matrix3D<float> image_embed;
21 Matrix3D<int> second_input_ids;
// True only when the object was built with explicit past_keys_/past_values_ arguments.
22 bool has_past_keys_values;
// True only when the object was built with an image embedding.
23 bool is_llava;
// Key/value matrices supplied by the caller; float16_t elements under QM_CUDA, float otherwise.
// NOTE(review): presumably one entry per decoder layer — confirm against forward().
24#ifdef QM_CUDA
25 std::vector<Matrix3D<float16_t>> past_keys, past_values;
26#else
27 std::vector<Matrix3D<float>> past_keys, past_values;
28#endif
29
31 Int4llamaDecoder_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) {
32 has_past_keys_values = false;
33 is_llava = false;
34 }
35#ifdef QM_CUDA
36 Int4llamaDecoder_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float16_t>> past_keys_,
37 std::vector<Matrix3D<float16_t>> past_values_)
38#else
39 Int4llamaDecoder_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
40 std::vector<Matrix3D<float>> past_values_)
41#endif
42 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
43 has_past_keys_values = true;
44 is_llava = false;
45 }
46 Int4llamaDecoder_input(Matrix3D<int> input_ids_, Matrix3D<float> image_embed_, Matrix3D<int> second_input_ids_)
47 : input_ids(input_ids_), image_embed(image_embed_), second_input_ids(second_input_ids_) {
48 has_past_keys_values = false;
49 is_llava = true;
50 }
52 : input_ids(input_ids_), image_embed(image_embed_) {
53 has_past_keys_values = false;
54 is_llava = true;
55 }
56};
57
// Public interface and configuration fields of the decoder.
// NOTE(review): the class's opening line (listing line 58) and listing line 61 are absent from this
// extract — check the real header for what they declare.
59 public:
// Constructs the decoder from a parameter path and a model configuration (passed by value).
60 Int4llamaDecoder(std::string param_path, const struct model_config config);
// Returns a float mask in every build (declared outside the QM_CUDA conditional).
// NOTE(review): presumably a causal mask sized from length and past_length — confirm in the .cc.
62 Matrix3D<float> prepare_decoder_attention_mask(int length, int past_length);
// Runs the decoder on `input` (taken by const reference) and returns the output struct by value.
63 struct Int4llamaDecoder_output forward(std::string param_path, const struct Int4llamaDecoder_input& input);
// Model dimensions; no in-class initializers, so they hold no defined value until assigned.
64 int voc_size, embed_dim, padding_idx, hidden_dim, num_heads;
65 float rms_norm_eps;
66 std::vector<Int4llamaDecoderLayer> layers;
// Label string; NOTE(review): presumably used to tag profiling output — confirm at the use site.
67 std::string profile_name = "Int4llamaDecoder";
68#ifdef QM_CUDA
69 void free_cuda_memory();
70 Embedding embed_tokens;
71 LlamaRMSNorm_cuda norm;
72
73 float16_t* attention_mask_buf = nullptr;
74 float16_t* last_hidden_states_buf = nullptr;
75 float* hidden_states_buf = nullptr;
76 float16_t* hidden_states_half_buf = nullptr;
77#else
78 Embedding embed_tokens;
79 LlamaRMSNorm norm;
80
81 float* attention_mask_buf;
82 float* pos_embeds_buf;
83 float* last_hidden_states_buf;
84 float* hidden_states_buf;
85 float* inputs_embeds_buf;
86 float* first_input_ids_buf;
87 float* image_embed_buf;
88 float* second_input_ids_buf;
89#endif
90 float* norm_weight_ptr = nullptr;
91};
Definition Embedding.h:5
Definition Int4llamaDecoder.h:58
Definition LlamaRMSNorm.h:4
Definition common.h:34
Definition Int4llamaDecoder.h:18
Definition Int4llamaDecoder.h:9
Definition model.h:5