TinyChatEngine: Int4GPTBigCodeForCausalLM.h
#pragma once

#include <utility>
#include <vector>

#include "Int4GPTBigCodeDecoder.h"
2
4 Matrix3D<float> logits;
5 std::vector<Matrix3D<float>> past_keys, past_values;
6};
8 Matrix3D<int> input_ids;
9 std::vector<Matrix3D<float>> past_keys, past_values;
10 bool has_past_keys_values;
11
13 Int4GPTBigCodeForCausalLM_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
14 Int4GPTBigCodeForCausalLM_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
15 std::vector<Matrix3D<float>> past_values_)
16 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
17 has_past_keys_values = true;
18 }
19};
20
22 public:
23 Int4GPTBigCodeForCausalLM(std::string param_path, const struct model_config config);
24 struct Int4GPTBigCodeForCausalLM_output forward(std::string param_path, const struct Int4GPTBigCodeForCausalLM_input& input);
25
26 private:
28 Linear_FP_int4 lm_head;
29 std::string profile_name = "Int4GPTBigCodeForCausalLM";
30 float* logits_output;
31 uint8_t* lm_head_weight;
32};
Definition Int4GPTBigCodeDecoder.h:26
Definition Int4GPTBigCodeForCausalLM.h:21
Definition linear.h:27
Definition common.h:34
Definition Int4GPTBigCodeForCausalLM.h:7
Definition Int4GPTBigCodeForCausalLM.h:3
Definition model.h:5