TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
OPTForCausalLM.h
1
#pragma once

#include <string>
#include <utility>
#include <vector>

#include "Int8OPTDecoder.h"
2
3
struct
OPTForCausalLM_output
{
4
Matrix3D<float>
logits;
5
std::vector<Matrix3D<int8_t>> past_keys, past_values;
6
};
7
struct
OPTForCausalLM_input
{
8
Matrix3D<int>
input_ids;
9
std::vector<Matrix3D<int8_t>> past_keys, past_values;
10
bool
has_past_keys_values;
11
12
OPTForCausalLM_input
(
Matrix3D<int>
input_ids_) : input_ids(input_ids_) { has_past_keys_values =
false
; }
13
OPTForCausalLM_input
(
Matrix3D<int>
input_ids_, std::vector<
Matrix3D<int8_t>
> past_keys_,
14
std::vector<
Matrix3D<int8_t>
> past_values_)
15
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
16
has_past_keys_values =
true
;
17
}
18
};
19
20
class
OPTForCausalLM
{
21
public
:
22
OPTForCausalLM
(std::string param_path,
const
struct
model_config
config);
23
struct
OPTForCausalLM_output
forward(const struct
OPTForCausalLM_input
& input);
24
25
private
:
26
Int8OPTDecoder
decoder;
27
Linear_FP
lm_head;
28
std::string profile_name =
"OPTForCausalLM"
;
29
float
* logits_output;
30
float
* lm_head_weight;
31
};
Int8OPTDecoder
Definition
Int8OPTDecoder.h:26
Linear_FP
Definition
linear.h:6
Matrix3D
Definition
common.h:34
OPTForCausalLM
Definition
OPTForCausalLM.h:20
OPTForCausalLM_input
Definition
OPTForCausalLM.h:7
OPTForCausalLM_output
Definition
OPTForCausalLM.h:3
model_config
Definition
model.h:5
Generated by Doxygen 1.11.0