TinyChatEngine
Loading...
Searching...
No Matches
Int4llamaForCausalLM.h
1#include "Int4llamaDecoder.h"
2
// Data members of a result struct whose opening line is not present in this
// listing (presumably Int4LlamaForCausalLM_output, the return type of
// Int4LlamaForCausalLM::forward -- TODO confirm against the original header).
// Logits as a float Matrix3D, regardless of backend.
4 Matrix3D<float> logits;
// past_keys / past_values: one Matrix3D per vector element. The element type
// is float16_t when QM_CUDA is defined and float otherwise.
// NOTE(review): presumably one entry per decoder layer -- confirm in decoder.
5#ifdef QM_CUDA
6 std::vector<Matrix3D<float16_t>> past_keys, past_values;
7#else
8 std::vector<Matrix3D<float>> past_keys, past_values;
9#endif
10};
// Data members of the Int4LlamaForCausalLM_input struct (its opening line is
// not present in this listing; the name is taken from the constructors below).
// Integer ids supplied to every constructor visible in this listing.
12 Matrix3D<int> input_ids;
// Float matrix supplied only by the constructors that set is_llava = true.
13 Matrix3D<float> image_embed;
// Second integer id matrix; only the three-argument image constructor sets it.
14 Matrix3D<int> second_input_ids;
// Neither flag has a default member initializer; each visible constructor
// assigns both explicitly in its body.
// true only when constructed with past_keys_ / past_values_.
15 bool has_past_keys_values;
// true only when constructed with an image_embed_ argument.
16 bool is_llava;
// Cached keys/values handed in by the caller; element type follows QM_CUDA
// (float16_t when defined, float otherwise).
17#ifdef QM_CUDA
18 std::vector<Matrix3D<float16_t>> past_keys, past_values;
19#else
20 std::vector<Matrix3D<float>> past_keys, past_values;
21#endif
22
// Builds an input from ids only.
// Copies input_ids_ into input_ids and clears both flags; every other member
// is left default-initialized (past_keys / past_values start empty).
// NOTE(review): this single-argument constructor is not `explicit`, so a
// Matrix3D<int> converts implicitly to Int4LlamaForCausalLM_input.
24 Int4LlamaForCausalLM_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) {
25 has_past_keys_values = false;
26 is_llava = false;
27 }
// Builds an input from ids plus previously produced keys/values.
// The signature is selected by QM_CUDA so the vector element type matches the
// past_keys / past_values members (float16_t with QM_CUDA, float without).
// Both vectors are taken by value and then copied into the members.
// Sets has_past_keys_values = true and is_llava = false.
28#ifdef QM_CUDA
29 Int4LlamaForCausalLM_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float16_t>> past_keys_,
30 std::vector<Matrix3D<float16_t>> past_values_)
31#else
32 Int4LlamaForCausalLM_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
33 std::vector<Matrix3D<float>> past_values_)
34#endif
35 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
36 has_past_keys_values = true;
37 is_llava = false;
38 }
// Builds an image-carrying input: ids, an image embedding, and a second id
// matrix. Sets is_llava = true and has_past_keys_values = false; past_keys and
// past_values are left default-initialized (empty).
// NOTE(review): presumably input_ids_ precedes the image and second_input_ids_
// follows it in the prompt -- confirm against the forward() implementation.
39 Int4LlamaForCausalLM_input(Matrix3D<int> input_ids_, Matrix3D<float> image_embed_, Matrix3D<int> second_input_ids_)
40 : input_ids(input_ids_), image_embed(image_embed_), second_input_ids(second_input_ids_) {
41 has_past_keys_values = false;
42 is_llava = true;
43 }
// NOTE(review): the signature line for this constructor (listing line 44) is
// missing from this listing. Judging by the initializer list it takes
// input_ids_ and image_embed_ only -- confirm against the original header.
// Stores ids and image embedding, sets is_llava = true and
// has_past_keys_values = false; second_input_ids is left default-initialized.
45 : input_ids(input_ids_), image_embed(image_embed_) {
46 has_past_keys_values = false;
47 is_llava = true;
48 }
49};
50
52 public:
// Constructor declaration only; the definition is not in this header listing.
// NOTE(review): param_path presumably names where the weights are read from --
// confirm in the implementation file.
53 Int4LlamaForCausalLM(std::string param_path, const struct model_config config);
// Declaration of the forward pass: takes an Int4LlamaForCausalLM_input by
// const reference and returns an Int4LlamaForCausalLM_output by value.
55 struct Int4LlamaForCausalLM_output forward(std::string param_path, const struct Int4LlamaForCausalLM_input& input);
// Raw float buffer pointer, null until something assigns it.
// NOTE(review): ownership/lifetime is not visible here -- confirm who frees it.
56 float* logits_output = nullptr;
// Backend-specific public members, selected by QM_CUDA.
// With QM_CUDA: free_cuda_memory() is declared (defined elsewhere), the weight
// buffer is an int* and a float16_t logits buffer is added.
// Without QM_CUDA: the weight buffer is a uint8_t*.
// Every raw pointer starts as nullptr so that an object whose buffers have not
// been set up yet never exposes an indeterminate pointer value; the non-CUDA
// lm_head_weight previously had no initializer, unlike its CUDA counterpart.
// NOTE(review): who allocates and frees these buffers is not visible here.
57#ifdef QM_CUDA
58 void free_cuda_memory();
59 int* lm_head_weight = nullptr;
60 float16_t* logits_output_half = nullptr;
61#else
62 uint8_t* lm_head_weight = nullptr;
63#endif
64
65 private:
// Private state of the model class.
// Fixed label string; NOTE(review): presumably used to tag profiling output --
// confirm in the implementation file.
66 std::string profile_name = "Int4LlamaForCausalLM";
// Decoder sub-object (type comes from Int4llamaDecoder.h, included above).
67 Int4llamaDecoder decoder;
// Linear layer named lm_head; its type is Linear_half_int4 when QM_CUDA is
// defined and Linear_FP_int4 otherwise.
68#ifdef QM_CUDA
69 Linear_half_int4 lm_head;
70#else
71 Linear_FP_int4 lm_head;
72#endif
73};
Definition Int4llamaForCausalLM.h:51
Definition Int4llamaDecoder.h:58
Definition linear.h:27
Definition common.h:34
Definition Int4llamaForCausalLM.h:11
Definition Int4llamaForCausalLM.h:3
Definition model.h:5