TinyChatEngine
Loading...
Searching...
No Matches
Embedding.h
#pragma once

#include <cassert>
#include <string>

#include "common.h"
4
5class Embedding {
6 public:
7 Embedding(int embed_dim_, int voc_size_, int padding_idx_, Matrix3D<float> lookup_)
8 : embed_dim(embed_dim_), voc_size(voc_size_), padding_idx(padding_idx_), lookup(lookup_) {
9 assert(lookup_.m_dim_y == voc_size_);
10 assert(lookup_.m_dim_z == embed_dim_);
11 }
12 Embedding(){};
13 void forward(Matrix3D<int> input_id, Matrix3D<float> output);
14 int embed_dim, voc_size, padding_idx;
15 Matrix3D<float> lookup;
16
17 private:
18 std::string profile_name = "Embedding";
19};
20
21void load_Embedding_params(Embedding &op, std::string prefix);
Definition Embedding.h:5
Definition common.h:34