TinyChatEngine
Loading...
Searching...
No Matches
Int4OPTForCausalLM.h
1#include "Int4OPTDecoder.h"
2
3struct Int4OPTForCausalLM_output {
4 Matrix3D<float> logits;
5 std::vector<Matrix3D<float>> past_keys, past_values;
6};
7struct Int4OPTForCausalLM_input {
8 Matrix3D<int> input_ids;
9 std::vector<Matrix3D<float>> past_keys, past_values;
10 bool has_past_keys_values;
11
12 Int4OPTForCausalLM_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
13 Int4OPTForCausalLM_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
14 std::vector<Matrix3D<float>> past_values_)
15 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
16 has_past_keys_values = true;
17 }
18};
19
20class Int4OPTForCausalLM {
21 public:
22 Int4OPTForCausalLM(std::string param_path, const struct model_config config);
23 struct Int4OPTForCausalLM_output forward(const struct Int4OPTForCausalLM_input& input);
24
25 private:
26 Int4OPTDecoder decoder;
27 Linear_FP_int4 lm_head;
28 std::string profile_name = "Int4OPTForCausalLM";
29 float* logits_output;
30 uint8_t* lm_head_weight;
31};
Definition Int4OPTDecoder.h:26
Definition Int4OPTForCausalLM.h:20
Definition linear.h:27
Definition common.h:34
Definition Int4OPTForCausalLM.h:7
Definition Int4OPTForCausalLM.h:3
Definition model.h:5