TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
Fp32GPTBigCodeForCausalLM.h
1
#include "Fp32GPTBigCodeDecoder.h"
2
3
struct
Fp32GPTBigCodeForCausalLM_output
{
4
Matrix3D<float>
logits;
5
std::vector<Matrix3D<float>> past_keys, past_values;
6
};
7
struct
Fp32GPTBigCodeForCausalLM_input
{
8
Matrix3D<int>
input_ids;
9
std::vector<Matrix3D<float>> past_keys, past_values;
10
bool
has_past_keys_values;
11
12
Fp32GPTBigCodeForCausalLM_input
() {}
13
Fp32GPTBigCodeForCausalLM_input
(
Matrix3D<int>
input_ids_) : input_ids(input_ids_) { has_past_keys_values =
false
; }
14
Fp32GPTBigCodeForCausalLM_input
(
Matrix3D<int>
input_ids_, std::vector<
Matrix3D<float>
> past_keys_,
15
std::vector<
Matrix3D<float>
> past_values_)
16
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
17
has_past_keys_values =
true
;
18
}
19
};
20
21
class
Fp32GPTBigCodeForCausalLM
{
22
public
:
23
Fp32GPTBigCodeForCausalLM
(std::string param_path,
const
struct
model_config
config);
24
struct
Fp32GPTBigCodeForCausalLM_output
forward(const struct
Fp32GPTBigCodeForCausalLM_input
& input);
25
26
private
:
27
Fp32GPTBigCodeDecoder
decoder;
28
Linear_FP
lm_head;
29
std::string profile_name =
"Fp32GPTBigCodeForCausalLM"
;
30
float
* logits_output;
31
float
* lm_head_weight;
32
};
Fp32GPTBigCodeDecoder
Definition
Fp32GPTBigCodeDecoder.h:26
Fp32GPTBigCodeForCausalLM
Definition
Fp32GPTBigCodeForCausalLM.h:21
Linear_FP
Definition
linear.h:6
Matrix3D
Definition
common.h:34
Fp32GPTBigCodeForCausalLM_input
Definition
Fp32GPTBigCodeForCausalLM.h:7
Fp32GPTBigCodeForCausalLM_output
Definition
Fp32GPTBigCodeForCausalLM.h:3
model_config
Definition
model.h:5
Generated by Doxygen 1.11.0