TinyChatEngine
Loading...
Searching...
No Matches
operators.h
1#ifndef OPERATORS_H
2#define OPERATORS_H
3#include <cassert>
4
5#include "common.h"
6#include "matmul.h"
7
// Block/tile dimension constant visible to every translation unit that
// includes this header.
// NOTE(review): its consumers are not visible here; presumably it is a tiling
// size used by the matmul/BMM operators -- confirm before changing it.
#define BLK_SIZE 16

// Runtime-configurable thread count shared by the operator implementations.
// This header only declares it; it must be defined exactly once in some source
// file. The name suggests it is the number of CPU worker threads -- confirm.
// Do NOT re-enable the compile-time define below: it would macro-expand the
// `extern` declaration that follows into `extern int 8;`.
// #define NUM_THREAD 8
extern int NUM_THREAD;
11
12// include all ops
13#include "ops/BMM_F32T.h"
14#include "ops/BMM_S8T_S8N_F32T.h"
15#include "ops/BMM_S8T_S8N_S8T.h"
16#include "ops/Embedding.h"
17#include "ops/LayerNorm.h"
18#include "ops/LayerNormQ.h"
19#include "ops/LlamaRMSNorm.h"
20#include "ops/RotaryPosEmb.h"
21#include "ops/W8A8B8O8Linear.h"
22#include "ops/W8A8B8O8LinearReLU.h"
23#include "ops/W8A8BFP32OFP32Linear.h"
24#include "ops/arg_max.h"
25#include "ops/linear.h"
26#include "ops/Conv2D.h"
27#include "ops/Gelu.h"
28
29void softmax(const Matrix3D<float> &input, Matrix3D<float> &output, int dim);
30void batch_Add(const Matrix3D<float> &input, const Matrix3D<float> &input2, Matrix3D<float> &output);
31template <typename T>
32void linear(Matrix3D<T> &a, Matrix3D<T> &b, Matrix3D<T> &c);
33
34
35#ifdef QM_CUDA
36#include "ops/cuda/BMM_F16T.cuh"
37#include "ops/cuda/Embedding.cuh"
38#include "ops/cuda/LlamaRMSNorm.cuh"
39#include "ops/cuda/RotaryPosEmb.cuh"
40
// CUDA kernels for the batched add and softmax operators.
// Only declarations are visible here. The launch configuration (grid/block
// shape, shared memory) and any size/alignment preconditions are defined with
// the kernel bodies. NOTE(review): document them there.
// Matrix3D arguments are passed by value, so each launch copies the Matrix3D
// object itself into kernel parameter space. Presumably its data pointer must
// already reference device-accessible memory -- confirm.
// Top-level const on a by-value parameter is not part of the function type, so
// the const/non-const mix below does not affect matching with the definitions.

// FP32 batched add (presumably element-wise: output = input + input2).
__global__ void batch_Add_float(const Matrix3D<float> input, const Matrix3D<float> input2, Matrix3D<float> output);

// FP16 batched add.
__global__ void batch_Add_cuda(const Matrix3D<float16_t> input, const Matrix3D<float16_t> input2,
                               Matrix3D<float16_t> output);

// FP16 batched add. The name suggests it processes element pairs with half2
// arithmetic -- confirm, and check for even-length/alignment requirements.
__global__ void batch_Add_cuda_half2(Matrix3D<float16_t> input, Matrix3D<float16_t> input2, Matrix3D<float16_t> output);

// FP32 softmax kernel. Unlike the host softmax() it takes no `dim` argument, so
// the reduction axis is presumably fixed inside the kernel -- confirm which.
__global__ void softmax_float(Matrix3D<float> input, Matrix3D<float> output);

// FP16 softmax kernel. It has the same fixed-axis caveat as softmax_float.
__global__ void softmax_cuda(Matrix3D<float16_t> input, Matrix3D<float16_t> output);
47#endif
48
49#endif // OPERATORS_H
Definition common.h:34