TinyChatEngine
Loading...
Searching...
No Matches
Fp32llamaAttention.h
#include <string>
#include <utility>

#include "common.h"
#include "operators.h"
7 Matrix3D<float> attn_output;
8 Matrix3D<float> attn_probs_reshaped;
9 std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value;
10};
12 Matrix3D<float> hidden_states;
13 Matrix3D<float> attention_mask;
14 Matrix3D<float> past_key, past_value;
15 bool has_past_key_value = false;
16 int layer_idx;
17
18 Fp32llamaAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, int layer_idx_)
19 : hidden_states(hidden_states_), attention_mask(attention_mask_), layer_idx(layer_idx_) {}
20
21 Fp32llamaAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, Matrix3D<float> past_key_,
22 Matrix3D<float> past_value_, bool has_past_key_value_, int layer_idx_)
23 : hidden_states(hidden_states_),
24 attention_mask(attention_mask_),
25 past_key(past_key_),
26 past_value(past_value_),
27 has_past_key_value(has_past_key_value_),
28 layer_idx(layer_idx_) {}
29};
30
32 public:
33 Fp32llamaAttention(std::string param_path, const struct model_config config);
35 static void initialized_memory(const struct model_config config);
36 struct Fp32llamaAttention_output forward(const struct Fp32llamaAttention_input &input);
37
38 private:
39 void unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int sqlen);
40 void shape(Matrix3D<float> unshape, Matrix3D<float> shaped, int sqlen);
41 int embed_dim, num_heads, head_dim;
42 Linear_FP k_proj, v_proj, q_proj, o_proj;
43 RotaryPosEmb rotary_pos_emb;
44 BMM_F32T qk_bmm, pv_bmm;
45 std::string profile_name = "Fp32llamaAttention";
46};
Definition BMM_F32T.h:3
Definition Fp32llamaAttention.h:31
Definition linear.h:6
Definition common.h:34
Definition RotaryPosEmb.h:6
Definition Fp32llamaAttention.h:11
Definition Fp32llamaAttention.h:6
Definition model.h:5