// TinyChatEngine — model.h (model architecture configurations)
1#ifndef MODEL_H
2#define MODEL_H
3#include <cstring>
4
6 int batch;
7 int num_heads;
8 int num_kv_heads;
9 int num_layers;
10 int max_sqlen;
11 int embed_dim;
12 int hidden_dim;
13 int vocsize;
14 int padding_idx;
15 float rms_norm_eps; // RMSNorm epsilon (only for LLaMA models)
16 // Below are for Clip models
17 int image_size;
18 int patch_size;
19 int projection_dim;
20 int mmproj_dim;
21
22 model_config() : model_config(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6, 0, 0, 0, 0) {}
23 model_config(int batch, int num_heads, int num_layers, int max_sqlen, int embed_dim, int hidden_dim, int vocsize,
24 int padding_idx, float rms_norm_eps)
25 : batch(batch),
26 num_heads(num_heads),
27 num_layers(num_layers),
28 max_sqlen(max_sqlen),
29 embed_dim(embed_dim),
30 hidden_dim(hidden_dim),
31 vocsize(vocsize),
32 padding_idx(padding_idx),
33 rms_norm_eps(rms_norm_eps) {}
34 // GQA/MQA models
35 model_config(int batch, int num_heads, int num_kv_heads, int num_layers, int max_sqlen, int embed_dim, int hidden_dim, int vocsize,
36 int padding_idx, float rms_norm_eps)
37 : batch(batch),
38 num_heads(num_heads),
39 num_kv_heads(num_kv_heads),
40 num_layers(num_layers),
41 max_sqlen(max_sqlen),
42 embed_dim(embed_dim),
43 hidden_dim(hidden_dim),
44 vocsize(vocsize),
45 padding_idx(padding_idx),
46 rms_norm_eps(rms_norm_eps) {}
47 // Clip models
48 model_config(int batch, int num_heads, int num_layers, int max_sqlen, int embed_dim, int hidden_dim, int vocsize,
49 int padding_idx, float rms_norm_eps, int image_size, int patch_size, int projection_dim, int mmproj_dim)
50 : batch(batch),
51 num_heads(num_heads),
52 num_layers(num_layers),
53 max_sqlen(max_sqlen),
54 embed_dim(embed_dim),
55 hidden_dim(hidden_dim),
56 vocsize(vocsize),
57 padding_idx(padding_idx),
58 rms_norm_eps(rms_norm_eps),
59 image_size(image_size),
60 patch_size(patch_size),
61 projection_dim(projection_dim),
62 mmproj_dim(mmproj_dim) {}
63};
64
// Identifiers for the supported models (argument to get_opt_model_config()).
enum {
    OPT_125M,
    OPT_1_3B,
    OPT_6_7B,
    LLaMA_7B,
    LLaMA_13B,
    CodeLLaMA_7B,
    CodeLLaMA_13B,
    StarCoder_15_5B,
    LLaVA_7B,
    LLaVA_13B,
    VILA_2_7B,
    VILA_7B,
    VILA_13B,
    Clip_ViT_Large,
    Mistral_7B,
    LLaMA_3_8B,
    VILA1_5_8B
};
// Supported numeric precisions for weights/activations.
enum { FP32, QINT8, INT4 };

68const struct model_config opt_6_7B(1, 32, 32, 2048, 4096, 16384, 50272, 1, 0);
69const struct model_config opt_1_3B(1, 32, 24, 2048, 2048, 8192, 50272, 1, 0);
70const struct model_config opt_125m(1, 12, 12, 2048, 768, 3072, 50272, 1, 0);
71const struct model_config llama_7B(1, 32, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6);
72const struct model_config llama_13B(1, 40, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-6);
73const struct model_config codellama_7B(1, 32, 32, 32, 2048, 4096, 11008, 32016, 1, 1e-5);
74const struct model_config codellama_13B(1, 40, 40, 40, 2048, 5120, 13824, 32016, 1, 1e-5);
75const struct model_config starcoder_15_5B(1, 48, 40, 2048, 6144, 24576, 49152, 1, 0);
76const struct model_config llava_7B(1, 32, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
77const struct model_config llava_13B(1, 40, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
78const struct model_config vila_2_7B(1, 20, 20, 32, 2048, 2560, 6912, 32000, 1, 1e-5);
79const struct model_config vila_7B(1, 32, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
80const struct model_config vila_13B(1, 40, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
81const struct model_config clip_vit_large(1, 16, 23, 2048, 1024, 4096, 0, 1, 0, 336, 14, 768, 4096); // llava's and vila's clip model uses only 23 layers out of 24
82const struct model_config mistral_7B(1, 32, 8, 32, 2048, 4096, 14336, 32000, 1, 1e-5);
83const struct model_config llama_3_8B(1, 32, 8, 32, 2048, 4096, 14336, 128256, 1, 1e-5);
84
85static struct model_config get_opt_model_config(int choise) {
86 struct model_config ret;
87 switch (choise) {
88 case OPT_125M:
89 ret = opt_125m;
90 break;
91 case OPT_1_3B:
92 ret = opt_1_3B;
93 break;
94 case OPT_6_7B:
95 ret = opt_6_7B;
96 break;
97 case LLaMA_7B:
98 ret = llama_7B;
99 break;
100 case LLaMA_13B:
101 ret = llama_13B;
102 break;
103 case CodeLLaMA_7B:
104 ret = codellama_7B;
105 break;
106 case CodeLLaMA_13B:
107 ret = codellama_13B;
108 break;
109 case StarCoder_15_5B:
110 ret = starcoder_15_5B;
111 break;
112 case LLaVA_7B:
113 ret = llava_7B;
114 break;
115 case LLaVA_13B:
116 ret = llava_13B;
117 break;
118 case VILA_2_7B:
119 ret = vila_2_7B;
120 break;
121 case VILA_7B:
122 ret = vila_7B;
123 break;
124 case VILA_13B:
125 ret = vila_13B;
126 break;
127 case Clip_ViT_Large:
128 ret = clip_vit_large;
129 break;
130 case Mistral_7B:
131 ret = mistral_7B;
132 break;
133 case LLaMA_3_8B:
134 ret = llama_3_8B;
135 break;
136 case VILA1_5_8B:
137 ret = vila_7B;
138 break;
139 default:
140 throw("Unsupported model choice.");
141 break;
142 }
143 return ret;
144}
145
146#endif