TinyChatEngine
Loading...
Searching...
No Matches
OPTForCausalLM.h
#include <utility>

#include "Int8OPTDecoder.h"
2
4 Matrix3D<float> logits;
5 std::vector<Matrix3D<int8_t>> past_keys, past_values;
6};
8 Matrix3D<int> input_ids;
9 std::vector<Matrix3D<int8_t>> past_keys, past_values;
10 bool has_past_keys_values;
11
12 OPTForCausalLM_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
13 OPTForCausalLM_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<int8_t>> past_keys_,
14 std::vector<Matrix3D<int8_t>> past_values_)
15 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
16 has_past_keys_values = true;
17 }
18};
19
21 public:
22 OPTForCausalLM(std::string param_path, const struct model_config config);
23 struct OPTForCausalLM_output forward(const struct OPTForCausalLM_input& input);
24
25 private:
26 Int8OPTDecoder decoder;
27 Linear_FP lm_head;
28 std::string profile_name = "OPTForCausalLM";
29 float* logits_output;
30 float* lm_head_weight;
31};
Definition Int8OPTDecoder.h:26
Definition linear.h:6
Definition common.h:34
Definition OPTForCausalLM.h:20
Definition OPTForCausalLM.h:7
Definition OPTForCausalLM.h:3
Definition model.h:5