TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
Int4OPTForCausalLM.h
1
#include "Int4OPTDecoder.h"
2
3
struct
Int4OPTForCausalLM_output
{
4
Matrix3D<float>
logits;
5
std::vector<Matrix3D<float>> past_keys, past_values;
6
};
7
struct
Int4OPTForCausalLM_input
{
8
Matrix3D<int>
input_ids;
9
std::vector<Matrix3D<float>> past_keys, past_values;
10
bool
has_past_keys_values;
11
12
Int4OPTForCausalLM_input
(
Matrix3D<int>
input_ids_) : input_ids(input_ids_) { has_past_keys_values =
false
; }
13
Int4OPTForCausalLM_input
(
Matrix3D<int>
input_ids_, std::vector<
Matrix3D<float>
> past_keys_,
14
std::vector<
Matrix3D<float>
> past_values_)
15
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
16
has_past_keys_values =
true
;
17
}
18
};
19
20
class
Int4OPTForCausalLM
{
21
public
:
22
Int4OPTForCausalLM
(std::string param_path,
const
struct
model_config
config);
23
struct
Int4OPTForCausalLM_output
forward(const struct
Int4OPTForCausalLM_input
& input);
24
25
private
:
26
Int4OPTDecoder
decoder;
27
Linear_FP_int4
lm_head;
28
std::string profile_name =
"Int4OPTForCausalLM"
;
29
float
* logits_output;
30
uint8_t* lm_head_weight;
31
};
Int4OPTDecoder
Definition
Int4OPTDecoder.h:26
Int4OPTForCausalLM
Definition
Int4OPTForCausalLM.h:20
Linear_FP_int4
Definition
linear.h:27
Matrix3D
Definition
common.h:34
Int4OPTForCausalLM_input
Definition
Int4OPTForCausalLM.h:7
Int4OPTForCausalLM_output
Definition
Int4OPTForCausalLM.h:3
model_config
Definition
model.h:5
Generated by Doxygen 1.11.0