TinyChatEngine
Loading...
Searching...
No Matches
Fp32GPTBigCodeForCausalLM.h
1#include "Fp32GPTBigCodeDecoder.h"
2
4 Matrix3D<float> logits;
5 std::vector<Matrix3D<float>> past_keys, past_values;
6};
8 Matrix3D<int> input_ids;
9 std::vector<Matrix3D<float>> past_keys, past_values;
10 bool has_past_keys_values;
11
13 Fp32GPTBigCodeForCausalLM_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
14 Fp32GPTBigCodeForCausalLM_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
15 std::vector<Matrix3D<float>> past_values_)
16 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
17 has_past_keys_values = true;
18 }
19};
20
22 public:
23 Fp32GPTBigCodeForCausalLM(std::string param_path, const struct model_config config);
24 struct Fp32GPTBigCodeForCausalLM_output forward(const struct Fp32GPTBigCodeForCausalLM_input& input);
25
26 private:
28 Linear_FP lm_head;
29 std::string profile_name = "Fp32GPTBigCodeForCausalLM";
30 float* logits_output;
31 float* lm_head_weight;
32};
Definition Fp32GPTBigCodeDecoder.h:26
Definition Fp32GPTBigCodeForCausalLM.h:21
Definition linear.h:6
Definition common.h:34
Definition Fp32GPTBigCodeForCausalLM.h:7
Definition Fp32GPTBigCodeForCausalLM.h:3
Definition model.h:5