TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
Fp32OPTForCausalLM.h
1
#include <string>
#include <utility>
#include <vector>

#include "Fp32OPTDecoder.h"
2
3
struct
Fp32OPTForCausalLM_output
{
4
Matrix3D<float>
logits;
5
std::vector<Matrix3D<float>> past_keys, past_values;
6
};
7
struct
Fp32OPTForCausalLM_input
{
8
Matrix3D<int>
input_ids;
9
std::vector<Matrix3D<float>> past_keys, past_values;
10
bool
has_past_keys_values;
11
12
Fp32OPTForCausalLM_input
(
Matrix3D<int>
input_ids_) : input_ids(input_ids_) { has_past_keys_values =
false
; }
13
Fp32OPTForCausalLM_input
(
Matrix3D<int>
input_ids_, std::vector<
Matrix3D<float>
> past_keys_,
14
std::vector<
Matrix3D<float>
> past_values_)
15
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
16
has_past_keys_values =
true
;
17
}
18
};
19
20
class
Fp32OPTForCausalLM
{
21
public
:
22
Fp32OPTForCausalLM
(std::string param_path,
const
struct
model_config
config);
23
struct
Fp32OPTForCausalLM_output
forward(const struct
Fp32OPTForCausalLM_input
& input);
24
25
private
:
26
Fp32OPTDecoder
decoder;
27
Linear_FP
lm_head;
28
std::string profile_name =
"Fp32OPTForCausalLM"
;
29
float
* logits_output;
30
float
* lm_head_weight;
31
};
Fp32OPTDecoder
Definition
Fp32OPTDecoder.h:26
Fp32OPTForCausalLM
Definition
Fp32OPTForCausalLM.h:20
Linear_FP
Definition
linear.h:6
Matrix3D
Definition
common.h:34
Fp32OPTForCausalLM_input
Definition
Fp32OPTForCausalLM.h:7
Fp32OPTForCausalLM_output
Definition
Fp32OPTForCausalLM.h:3
model_config
Definition
model.h:5
Generated by
1.11.0