TinyChatEngine
Loading...
Searching...
No Matches
Fp32llamaForCausalLM.h
#include <utility>

#include "Fp32llamaDecoder.h"
2
4 Matrix3D<float> logits;
5 std::vector<Matrix3D<float>> past_keys, past_values;
6};
8 Matrix3D<int> input_ids;
9 Matrix3D<float> image_embed;
10 Matrix3D<int> second_input_ids;
11 std::vector<Matrix3D<float>> past_keys, past_values;
12 bool has_past_keys_values;
13 bool is_llava;
14
16 Fp32LlamaForCausalLM_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) {
17 has_past_keys_values = false;
18 is_llava = false;
19 }
20 Fp32LlamaForCausalLM_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
21 std::vector<Matrix3D<float>> past_values_)
22 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
23 has_past_keys_values = true;
24 is_llava = false;
25 }
26 Fp32LlamaForCausalLM_input(Matrix3D<int> input_ids_, Matrix3D<float> image_embed_, Matrix3D<int> second_input_ids_)
27 : input_ids(input_ids_), image_embed(image_embed_), second_input_ids(second_input_ids_) {
28 has_past_keys_values = false;
29 is_llava = true;
30 }
32 : input_ids(input_ids_), image_embed(image_embed_) {
33 has_past_keys_values = false;
34 is_llava = true;
35 }
36};
37
39 public:
40 Fp32LlamaForCausalLM(std::string param_path, const struct model_config config);
41
42 struct Fp32LlamaForCausalLM_output forward(const struct Fp32LlamaForCausalLM_input& input);
43
44 private:
45 Fp32llamaDecoder decoder;
46 Linear_FP lm_head;
47 std::string profile_name = "Fp32LlamaForCausalLM";
48 float* logits_output;
49 float* lm_head_weight;
50};
Definition Fp32llamaForCausalLM.h:38
Definition Fp32llamaDecoder.h:44
Definition linear.h:6
Definition common.h:34
Definition Fp32llamaForCausalLM.h:7
Definition Fp32llamaForCausalLM.h:3
Definition model.h:5