// Declaration only: a static function that receives a model_config by value and returns nothing.
// NOTE(review): by its name this presumably performs one-time setup of the memory used by this
// layer; the definition is not visible in this chunk, so confirm against the implementation.
58 static void initialized_memory(
const struct model_config config);
// Five raw int pointers (q, k, v, o and a combined qkv weight), every one starting out as nullptr.
// NOTE(review): the int element type presumably holds packed/quantized weight data — TODO confirm.
// Where these are allocated and released is not visible in this chunk.
62 int *q_weight = nullptr, *k_weight = nullptr, *v_weight = nullptr, *o_weight = nullptr, *qkv_weight = nullptr;
// Declaration only: takes no arguments and returns nothing.
// NOTE(review): by its name this presumably releases CUDA allocations made elsewhere; which
// buffers it covers cannot be determined from this chunk — verify against the definition.
66 void free_cuda_memory();
// Half-precision cos/sin buffer pointers, both initially nullptr.
67 half *cos_buf = nullptr, *sin_buf = nullptr;
// Single-precision declaration of the same two names, also initially nullptr.
// NOTE(review): as shown here this repeats the names declared on the previous line; the two
// variants are presumably selected by a preprocessor conditional that is not visible in this
// chunk — confirm before touching either line.
69 float *cos_buf = nullptr, *sin_buf = nullptr;
// Four integer dimension fields, declared without initializers.
// NOTE(review): by name these look like the embedding width, the number of query heads, the
// number of key/value heads, and the per-head width; where they are assigned is not visible here.
74 int embed_dim, num_heads, num_kv_heads, head_dim;
// Operator objects held by value; their types are declared elsewhere in the project.
// NOTE(review): names suggest an output projection and a fused QKV projection — confirm.
76 Linear_half_int4 o_proj, qkv_proj;
// NOTE(review): presumably applies rotary positional embedding — confirm against its definition.
77 RotaryPosEmb_cuda rotary_pos_emb;
// NOTE(review): names suggest the query*key and probability*value batched products — confirm.
78 BMM_F16T qk_bmm, pv_bmm;
// Declaration only: takes two Matrix3D<float> values plus num_heads, head_dim and sqlen.
// NOTE(review): presumably reads `shaped` and writes `unshape`, i.e. the reverse of shape();
// the direction and the exact layouts are not established by this chunk — confirm.
// Be aware that the second parameter carries the same name as the function itself.
84 void unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int num_heads, int head_dim, int sqlen);
// Declaration only: takes two Matrix3D<float> values plus num_heads, head_dim and sqlen.
// NOTE(review): presumably reads `unshape` and writes `shaped` (a per-head rearrangement);
// the direction and the exact layouts are not established by this chunk — confirm.
85 void shape(Matrix3D<float> unshape, Matrix3D<float> shaped, int num_heads, int head_dim, int sqlen);
// Declaration only: takes one `unshape` matrix, three `shaped_*` matrices (q, k, v) and sqlen.
// Unlike shape()/unshape(), no head count or head width is passed in.
// NOTE(review): presumably splits a combined QKV tensor into separate per-head q/k/v outputs
// using dimensions stored elsewhere — confirm against the definition.
86 void shape_qkv(Matrix3D<float> unshape, Matrix3D<float> shaped_q, Matrix3D<float> shaped_k,
87 Matrix3D<float> shaped_v, int sqlen);
// Declaration only: takes an input and an output Matrix3D<float> plus four integer sizes.
// Note the argument order here is (sqlen, head_dim), the opposite of shape()/unshape().
// NOTE(review): presumably replicates key/value heads so num_kv_heads matches num_heads;
// the definition is not visible in this chunk — confirm.
88 void repeat(Matrix3D<float> input, Matrix3D<float> output, int num_heads, int num_kv_heads, int sqlen, int head_dim);