TinyChatEngine
Loading...
Searching...
No Matches
llm
include
nn_modules
Fp32CLIPVisionTransformer.h
1
#include <cstdlib>
#include <string>
#include <utility>
#include <vector>

#include "Fp32CLIPEncoder.h"
#include "common.h"
#include "operators.h"
8
9
struct
Fp32CLIPVisionTransformer_output
{
10
Matrix3D<float>
last_hidden_state;
11
std::vector<Matrix3D<float>> past_keys, past_values;
12
};
13
struct
Fp32CLIPVisionTransformer_input
{
14
Matrix3D<float>
input_image;
15
std::vector<Matrix3D<float>> past_keys, past_values;
16
bool
has_past_keys_values;
17
18
Fp32CLIPVisionTransformer_input
() {}
19
Fp32CLIPVisionTransformer_input
(
Matrix3D<float>
input_image_) : input_image(input_image_) { has_past_keys_values =
false
; }
20
Fp32CLIPVisionTransformer_input
(
Matrix3D<float>
input_image_, std::vector<
Matrix3D<float>
> past_keys_,
21
std::vector<
Matrix3D<float>
> past_values_)
22
: input_image(input_image_), past_keys(past_keys_), past_values(past_values_) {
23
has_past_keys_values =
true
;
24
}
25
};
26
27
class
Fp32CLIPVisionTransformer
{
28
public
:
29
Fp32CLIPVisionTransformer
(std::string param_path,
const
struct
model_config
config,
bool
is_vila);
30
Fp32CLIPVisionTransformer
(){};
31
struct
Fp32CLIPVisionTransformer_output
forward(const struct
Fp32CLIPVisionTransformer_input
& input,
bool
is_vila);
32
Embedding
embed_positions;
33
Conv2D
embed_patch;
34
LayerNorm
pre_layernorm;
35
Linear_FP
mm_proj_0, mm_proj_2;
36
int
voc_size, embed_dim, padding_idx, hidden_dim, num_heads, image_size, patch_size, num_patches, num_positions,
37
projection_dim, mmproj_dim;
38
std::vector<Fp32CLIPEncoderLayer> layers;
39
std::string profile_name =
"Fp32CLIPVisionTransformer"
;
40
41
private
:
42
Fp32CLIPEncoder
encoder;
43
float
* patch_embeds_buf;
44
float
* class_embeds_buf;
45
float
* pos_embeds_buf;
46
float
* last_hidden_states_buf;
47
float
* hidden_states_buf;
48
float
* embeddings_buf;
49
float
* mm_proj_0_arr;
50
float
* mm_proj_2_arr;
51
};
Conv2D
Definition
Conv2D.h:17
Embedding
Definition
Embedding.h:5
Fp32CLIPEncoder
Definition
Fp32CLIPEncoder.h:30
Fp32CLIPVisionTransformer
Definition
Fp32CLIPVisionTransformer.h:27
LayerNorm
Definition
LayerNorm.h:8
Linear_FP
Definition
linear.h:6
Matrix3D
Definition
common.h:34
Fp32CLIPVisionTransformer_input
Definition
Fp32CLIPVisionTransformer.h:13
Fp32CLIPVisionTransformer_output
Definition
Fp32CLIPVisionTransformer.h:9
model_config
Definition
model.h:5
Generated by
1.11.0