TinyChatEngine
Loading...
Searching...
No Matches
Fp32CLIPVisionTransformer.h
1#include <cstdlib>
2#include <string>
3#include <vector>
4
5#include "Fp32CLIPEncoder.h"
6#include "common.h"
7#include "operators.h"
8
// Result bundle, presumably `struct Fp32CLIPVisionTransformer_output` (the return type named in
// Fp32CLIPVisionTransformer::forward's declaration) — TODO confirm.
// NOTE(review): the opening `struct ... {` line (header line 9) is missing from this extracted
// listing, as are the leading line-number prefixes' real indentation — restore before compiling.
// Final hidden state produced by the transformer.
10 Matrix3D<float> last_hidden_state;
// Key/value caches returned alongside the hidden state; presumably one entry per encoder layer —
// TODO confirm against the forward() implementation.
11 std::vector<Matrix3D<float>> past_keys, past_values;
12};
// Input bundle for Fp32CLIPVisionTransformer::forward (which takes a
// `const struct Fp32CLIPVisionTransformer_input&`).
// NOTE(review): the opening `struct Fp32CLIPVisionTransformer_input {` line (header line 13) is
// missing from this extracted listing — restore before compiling.
// Image tensor fed to the vision transformer.
14 Matrix3D<float> input_image;
// Optional cached keys/values; left empty (default-constructed) by the image-only constructor.
15 std::vector<Matrix3D<float>> past_keys, past_values;
// false when built from an image only, true when built with caches (see the two constructors).
16 bool has_past_keys_values;
17
// NOTE(review): header line 18 is missing from this listing (possibly a default constructor). If it
// does not set has_past_keys_values, reading that bool is undefined behavior — consider an in-class
// initializer `bool has_past_keys_values = false;`.
// Image-only constructor: no cached keys/values.
// NOTE(review): this single-argument constructor is not `explicit`, so a Matrix3D<float> converts
// implicitly to this input type.
19 Fp32CLIPVisionTransformer_input(Matrix3D<float> input_image_) : input_image(input_image_) { has_past_keys_values = false; }
// Constructor that also supplies cached keys/values and sets has_past_keys_values to true.
// NOTE(review): parameters are taken by value and then copied again into the members; using
// std::move(past_keys_) / std::move(past_values_) (and input_image_) would avoid the second copy of
// each std::vector.
20 Fp32CLIPVisionTransformer_input(Matrix3D<float> input_image_, std::vector<Matrix3D<float>> past_keys_,
21 std::vector<Matrix3D<float>> past_values_)
22 : input_image(input_image_), past_keys(past_keys_), past_values(past_values_) {
23 has_past_keys_values = true;
24 }
25};
26
// FP32 CLIP vision transformer model (class Fp32CLIPVisionTransformer).
// NOTE(review): the opening `class Fp32CLIPVisionTransformer {` line (header line 27) and header
// line 30 are missing from this extracted listing — restore before compiling.
28 public:
// Constructs the model from parameters located at param_path using the given model_config.
// is_vila presumably toggles VILA-specific behavior — TODO confirm in the implementation file.
// NOTE(review): `const struct model_config config` and `std::string param_path` are passed by value;
// `const model_config&` / `const std::string&` would avoid copies.
29 Fp32CLIPVisionTransformer(std::string param_path, const struct model_config config, bool is_vila);
// Runs the model on `input` and returns a Fp32CLIPVisionTransformer_output; is_vila has the same
// (unconfirmed) meaning as in the constructor.
31 struct Fp32CLIPVisionTransformer_output forward(const struct Fp32CLIPVisionTransformer_input& input, bool is_vila);
// Sub-modules (project types Embedding, Conv2D, LayerNorm, Linear_FP). Names suggest positional
// embedding, patch embedding, pre-encoder layer norm, and a two-stage multimodal projector — TODO
// confirm against forward().
// NOTE(review): these and the fields below are public and mutable; consider moving them under
// `private:` if no external code needs them.
32 Embedding embed_positions;
33 Conv2D embed_patch;
34 LayerNorm pre_layernorm;
35 Linear_FP mm_proj_0, mm_proj_2;
// Model dimensions/hyperparameters; presumably populated from model_config in the constructor.
36 int voc_size, embed_dim, padding_idx, hidden_dim, num_heads, image_size, patch_size, num_patches, num_positions,
37 projection_dim, mmproj_dim;
// Encoder layers (project type Fp32CLIPEncoderLayer).
38 std::vector<Fp32CLIPEncoderLayer> layers;
// Label for this model, defaulting to "Fp32CLIPVisionTransformer"; presumably used by profiling.
39 std::string profile_name = "Fp32CLIPVisionTransformer";
40
41 private:
// Encoder stack (project type Fp32CLIPEncoder).
42 Fp32CLIPEncoder encoder;
// Raw float buffers, presumably scratch/activation and projector-weight storage used by forward().
// NOTE(review): ownership, allocation and release are not visible in this header, and no destructor
// appears in the visible lines (header line 30 is missing). The pointers have no in-class
// initializers, so confirm the constructor sets every one of them and that something frees them.
43 float* patch_embeds_buf;
44 float* class_embeds_buf;
45 float* pos_embeds_buf;
46 float* last_hidden_states_buf;
47 float* hidden_states_buf;
48 float* embeddings_buf;
49 float* mm_proj_0_arr;
50 float* mm_proj_2_arr;
51};
Definition Conv2D.h:17
Definition Embedding.h:5
Definition Fp32CLIPEncoder.h:30
Definition Fp32CLIPVisionTransformer.h:27
Definition LayerNorm.h:8
Definition linear.h:6
Definition common.h:34
Definition Fp32CLIPVisionTransformer.h:13
Definition Fp32CLIPVisionTransformer.h:9
Definition model.h:5