TinyChatEngine
Loading...
Searching...
No Matches
Fp32OPTForCausalLM.h
1#include "Fp32OPTDecoder.h"
2
3struct Fp32OPTForCausalLM_output {
4 Matrix3D<float> logits;
5 std::vector<Matrix3D<float>> past_keys, past_values;
6};
7struct Fp32OPTForCausalLM_input {
8 Matrix3D<int> input_ids;
9 std::vector<Matrix3D<float>> past_keys, past_values;
10 bool has_past_keys_values;
11
12 Fp32OPTForCausalLM_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
13 Fp32OPTForCausalLM_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
14 std::vector<Matrix3D<float>> past_values_)
15 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
16 has_past_keys_values = true;
17 }
18};
19
20class Fp32OPTForCausalLM {
21 public:
22 Fp32OPTForCausalLM(std::string param_path, const struct model_config config);
23 struct Fp32OPTForCausalLM_output forward(const struct Fp32OPTForCausalLM_input& input);
24
25 private:
26 Fp32OPTDecoder decoder;
27 Linear_FP lm_head;
28 std::string profile_name = "Fp32OPTForCausalLM";
29 float* logits_output;
30 float* lm_head_weight;
31};
Definition Fp32OPTDecoder.h:26
Definition Fp32OPTForCausalLM.h:20
Definition linear.h:6
Definition common.h:34
Definition Fp32OPTForCausalLM.h:7
Definition Fp32OPTForCausalLM.h:3
Definition model.h:5