TinyChatEngine
Loading...
Searching...
No Matches
Generate.h
/*

Adapted from llama.cpp:
https://github.com/ggerganov/llama.cpp

*/
7
8#ifndef GENERATE_H
9#define GENERATE_H
10
11#include <algorithm>
12#include <cassert>
13#include <cstdio>
14#include <iostream>
15#include <queue>
16#include <random>
17#include <string>
18#include <unordered_map>
19#include <vector>
20
21#include "Fp32OPTForCausalLM.h"
22#include "Fp32llamaForCausalLM.h"
23#include "Fp32GPTBigCodeForCausalLM.h"
24#include "Int4OPTForCausalLM.h"
25#include "Int4llamaForCausalLM.h"
26#include "Int4GPTBigCodeForCausalLM.h"
27#include "Fp32CLIPVisionTransformer.h"
28#include "OPTForCausalLM.h"
29#include "OPTTokenizer.h"
30#include "operators.h"
31#include "utils.h"
32
// Shared Mersenne-Twister RNG used by the samplers below.
// An `inline` variable would give one shared instance across translation
// units, but inline variables require -std=c++17; `static` keeps this header
// usable on older standards at the cost of one private RNG copy per TU.
static std::mt19937 OPT_rng;
35
// One candidate vocabulary entry considered during a sampling step.
// NOTE(review): the extraction this header was recovered from dropped the
// closing `} name;` of both typedefs; they are restored here.
typedef struct OPT_token_data {
    int id;       // token id
    float logit;  // log-odds of the token
    float p;      // probability of the token
} OPT_token_data;

// A view over the candidate set for one sampling step.
typedef struct OPT_token_data_array {
    OPT_token_data* data;  // pointer to the candidate entries (ownership not established here)
    size_t size;           // number of candidates in `data`
    bool sorted;           // whether the entries are currently sorted (samplers use this to skip re-sorting)
} OPT_token_data_array;
// Decoding hyper-parameters for the Generate entry points in this header
// (adapted from llama.cpp's parameter struct; defaults preserved).
struct opt_params {
    int32_t seed = -1;        // RNG seed (-1 = non-fixed seed, chosen by the caller)
    int32_t n_threads = 1;    // TODO: fix this
    int32_t n_predict = 128;  // new tokens to predict
    int32_t n_parts = -1;     // amount of model parts (-1 = determine from model dimensions)
    int32_t n_ctx = 512;      // context size
    int32_t n_batch = 512;    // batch size for prompt processing (must be >=32 to use BLAS)
    int32_t n_keep = 0;       // number of tokens to keep from initial prompt
    int32_t n_vocab = 50272;  // vocabulary size (default matches the OPT tokenizer)

    // sampling parameters
    std::unordered_map<int, float> logit_bias;  // logit bias for specific tokens
    int32_t top_k = 40;                         // <= 0 to use vocab size
    float top_p = 0.95f;                        // 1.0 = disabled
    float tfs_z = 1.00f;                        // 1.0 = disabled (tail-free sampling)
    float typical_p = 1.00f;                    // 1.0 = disabled (locally-typical sampling)
    float temp = 0.80f;                         // 1.0 = disabled
    float repeat_penalty = 1.10f;               // 1.0 = disabled
    int32_t repeat_last_n = 64;                 // last n tokens to penalize (0 = disable penalty, -1 = context size)
    float frequency_penalty = 0.00f;            // 0.0 = disabled
    float presence_penalty = 0.00f;             // 0.0 = disabled
    int mirostat = 0;                           // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
    float mirostat_tau = 5.00f;                 // target entropy
    float mirostat_eta = 0.10f;                 // learning rate
};
73
74void sample_repetition_penalty(OPT_token_data_array* candidates, const int* last_tokens, size_t last_tokens_size,
75 float penalty);
76
77void sample_frequency_and_presence_penalties(OPT_token_data_array* candidates, const int* last_tokens_p,
78 size_t last_tokens_size, float alpha_frequency, float alpha_presence);
79
80int sample_token_greedy(OPT_token_data_array* candidates);
81
82void sample_temperature(OPT_token_data_array* candidates_p, float temp);
83
84void sample_softmax(OPT_token_data_array* candidates);
85
86int sample_token(OPT_token_data_array* candidates);
87
88void sample_top_k(OPT_token_data_array* candidates, int k, size_t min_keep);
89
90int sample_token_mirostat(const int n_vocab, OPT_token_data_array* candidates, float tau, float eta, int m, float* mu);
91
92int sample_token_mirostat_v2(OPT_token_data_array* candidates, float tau, float eta, float* mu);
93
94void sample_tail_free(OPT_token_data_array* candidates, float z, size_t min_keep);
95
96void sample_typical(OPT_token_data_array* candidates, float p, size_t min_keep);
97
98void sample_top_p(OPT_token_data_array* candidates, float p, size_t min_keep);
99
100std::vector<int> OPTGenerate(void* model, int model_type, std::vector<int> input_ids,
101 const struct opt_params generation_config, Encoder* encoder = NULL,
102 bool interactive = false, bool voicechat = false);
103
104enum { OPT_INT8, LLaMA_FP32, LLaMA_INT4, OPT_FP32, OPT_INT4, StarCoder_FP32, StarCoder_INT4, LLaVA_FP32, LLaVA_INT4, VILA_FP32, VILA_INT4};
// --- High-level generation entry points -----------------------------------
// Each takes the on-disk parameter path, an opaque model pointer (cast by
// `model_type`), the prompt `text`, decoding parameters, and the vocabulary
// path; returns the generated text.

/// Text generation for LLaMA-family models (FP32 or INT4, per model_type).
std::string LLaMAGenerate(std::string param_path, void* model, int model_type, std::string text,
                          const struct opt_params generation_config, std::string voc_path, bool interactive,
                          bool voicechat);

/// Text generation for GPTBigCode (StarCoder) models.
std::string GPTBigCodeGenerate(std::string param_path, void* model_ptr, int model_type, std::string text,
                               const struct opt_params generation_config, std::string voc_path, bool interactive);

/// Multimodal generation for LLaVA/VILA: combines a CLIP vision model
/// (clip_param_path/clip_model_ptr) with a LLaMA language model; `img_path`
/// points to the input image, `is_vila` selects the VILA variant.
std::string LLaVAGenerate(std::string llama_param_path, void* llama_model_ptr, std::string clip_param_path,
                          void* clip_model_ptr, int model_type, std::string text, std::string img_path,
                          const struct opt_params generation_config, std::string voc_path, bool interactive,
                          bool voicechat, bool is_vila);

/// Text generation for Mistral models.
std::string MistralGenerate(std::string param_path, void* model, int model_type, std::string text,
                            const struct opt_params generation_config, std::string voc_path, bool interactive,
                            bool voicechat);

/// Text generation for LLaMA-3 models.
std::string LLaMA3Generate(std::string param_path, void* model, int model_type, std::string text,
                           const struct opt_params generation_config, std::string voc_path, bool interactive,
                           bool voicechat);
121#endif // GENERATE_H
Definition OPTTokenizer.h:35
Definition Generate.h:42
Definition Generate.h:36
Definition Generate.h:48