22 model_config() :
model_config(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6, 0, 0, 0, 0) {}
23 model_config(
int batch,
int num_heads,
int num_layers,
int max_sqlen,
int embed_dim,
int hidden_dim,
int vocsize,
24 int padding_idx,
float rms_norm_eps)
27 num_layers(num_layers),
30 hidden_dim(hidden_dim),
32 padding_idx(padding_idx),
33 rms_norm_eps(rms_norm_eps) {}
35 model_config(
int batch,
int num_heads,
int num_kv_heads,
int num_layers,
int max_sqlen,
int embed_dim,
int hidden_dim,
int vocsize,
36 int padding_idx,
float rms_norm_eps)
39 num_kv_heads(num_kv_heads),
40 num_layers(num_layers),
43 hidden_dim(hidden_dim),
45 padding_idx(padding_idx),
46 rms_norm_eps(rms_norm_eps) {}
48 model_config(
int batch,
int num_heads,
int num_layers,
int max_sqlen,
int embed_dim,
int hidden_dim,
int vocsize,
49 int padding_idx,
float rms_norm_eps,
int image_size,
int patch_size,
int projection_dim,
int mmproj_dim)
52 num_layers(num_layers),
55 hidden_dim(hidden_dim),
57 padding_idx(padding_idx),
58 rms_norm_eps(rms_norm_eps),
59 image_size(image_size),
60 patch_size(patch_size),
61 projection_dim(projection_dim),
62 mmproj_dim(mmproj_dim) {}
// Identifiers for the model presets; switched on by get_opt_model_config().
// Values are implicit and sequential from 0 (OPT_125M == 0 ... VILA1_5_8B == 16),
// so new entries should be appended to keep existing numeric values stable.
enum {
    OPT_125M,
    OPT_1_3B,
    OPT_6_7B,
    LLaMA_7B,
    LLaMA_13B,
    CodeLLaMA_7B,
    CodeLLaMA_13B,
    StarCoder_15_5B,
    LLaVA_7B,
    LLaVA_13B,
    VILA_2_7B,
    VILA_7B,
    VILA_13B,
    Clip_ViT_Large,
    Mistral_7B,
    LLaMA_3_8B,
    VILA1_5_8B
};
// Precision selectors with implicit values FP32 == 0, QINT8 == 1, INT4 == 2.
// NOTE(review): presumably chooses the weight/kernel data type -- confirm
// against the callers that consume these values.
enum {
    FP32,
    QINT8,
    INT4
};
68const struct model_config opt_6_7B(1, 32, 32, 2048, 4096, 16384, 50272, 1, 0);
69const struct model_config opt_1_3B(1, 32, 24, 2048, 2048, 8192, 50272, 1, 0);
70const struct model_config opt_125m(1, 12, 12, 2048, 768, 3072, 50272, 1, 0);
71const struct model_config llama_7B(1, 32, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6);
72const struct model_config llama_13B(1, 40, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-6);
73const struct model_config codellama_7B(1, 32, 32, 32, 2048, 4096, 11008, 32016, 1, 1e-5);
74const struct model_config codellama_13B(1, 40, 40, 40, 2048, 5120, 13824, 32016, 1, 1e-5);
75const struct model_config starcoder_15_5B(1, 48, 40, 2048, 6144, 24576, 49152, 1, 0);
76const struct model_config llava_7B(1, 32, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
77const struct model_config llava_13B(1, 40, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
78const struct model_config vila_2_7B(1, 20, 20, 32, 2048, 2560, 6912, 32000, 1, 1e-5);
79const struct model_config vila_7B(1, 32, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
80const struct model_config vila_13B(1, 40, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
81const struct model_config clip_vit_large(1, 16, 23, 2048, 1024, 4096, 0, 1, 0, 336, 14, 768, 4096);
82const struct model_config mistral_7B(1, 32, 8, 32, 2048, 4096, 14336, 32000, 1, 1e-5);
83const struct model_config llama_3_8B(1, 32, 8, 32, 2048, 4096, 14336, 128256, 1, 1e-5);
85static struct model_config get_opt_model_config(int choise) {
109 case StarCoder_15_5B:
110 ret = starcoder_15_5B;
128 ret = clip_vit_large;
140 throw(
"Unsupported model choice.");