22 bool has_past_keys_values;
25 std::vector<Matrix3D<float16_t>> past_keys, past_values;
27 std::vector<Matrix3D<float>> past_keys, past_values;
32 has_past_keys_values =
false;
42 : input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
43 has_past_keys_values =
true;
47 : input_ids(input_ids_), image_embed(image_embed_), second_input_ids(second_input_ids_) {
48 has_past_keys_values =
false;
52 : input_ids(input_ids_), image_embed(image_embed_) {
53 has_past_keys_values =
false;
62 Matrix3D<float> prepare_decoder_attention_mask(
int length,
int past_length);
64 int voc_size, embed_dim, padding_idx, hidden_dim, num_heads;
66 std::vector<Int4llamaDecoderLayer> layers;
69 void free_cuda_memory();
71 LlamaRMSNorm_cuda norm;
73 float16_t* attention_mask_buf = nullptr;
74 float16_t* last_hidden_states_buf = nullptr;
75 float* hidden_states_buf = nullptr;
76 float16_t* hidden_states_half_buf = nullptr;
81 float* attention_mask_buf;
82 float* pos_embeds_buf;
83 float* last_hidden_states_buf;
84 float* hidden_states_buf;
85 float* inputs_embeds_buf;
86 float* first_input_ids_buf;
87 float* image_embed_buf;
88 float* second_input_ids_buf;
90 float* norm_weight_ptr = nullptr;