TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
Int4GPTBigCodeForCausalLM.h
1
#include "Int4GPTBigCodeDecoder.h"
2
3
struct
Int4GPTBigCodeForCausalLM_output
{
4
Matrix3D<float>
logits;
5
std::vector<Matrix3D<float>> past_keys, past_values;
6
};
7
struct
Int4GPTBigCodeForCausalLM_input
{
8
Matrix3D<int>
input_ids;
9
std::vector<Matrix3D<float>> past_keys, past_values;
10
bool
has_past_keys_values;
11
12
Int4GPTBigCodeForCausalLM_input
() {}
13
Int4GPTBigCodeForCausalLM_input
(
Matrix3D<int>
input_ids_) : input_ids(input_ids_) { has_past_keys_values =
false
; }
14
Int4GPTBigCodeForCausalLM_input
(
Matrix3D<int>
input_ids_, std::vector<
Matrix3D<float>
> past_keys_,
15
std::vector<
Matrix3D<float>
> past_values_)
16
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
17
has_past_keys_values =
true
;
18
}
19
};
20
21
class
Int4GPTBigCodeForCausalLM
{
22
public
:
23
Int4GPTBigCodeForCausalLM
(std::string param_path,
const
struct
model_config
config);
24
struct
Int4GPTBigCodeForCausalLM_output
forward(std::string param_path, const struct
Int4GPTBigCodeForCausalLM_input
& input);
25
26
private
:
27
Int4GPTBigCodeDecoder
decoder;
28
Linear_FP_int4
lm_head;
29
std::string profile_name = "
Int4GPTBigCodeForCausalLM
";
30
float* logits_output;
31
uint8_t* lm_head_weight;
32
};
Int4GPTBigCodeDecoder
Definition
Int4GPTBigCodeDecoder.h:26
Int4GPTBigCodeForCausalLM
Definition
Int4GPTBigCodeForCausalLM.h:21
Linear_FP_int4
Definition
linear.h:27
Matrix3D
Definition
common.h:34
Int4GPTBigCodeForCausalLM_input
Definition
Int4GPTBigCodeForCausalLM.h:7
Int4GPTBigCodeForCausalLM_output
Definition
Int4GPTBigCodeForCausalLM.h:3
model_config
Definition
model.h:5
Generated by Doxygen 1.11.0