1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
|
#pragma once
#include <cstddef>
#include <cstdint>
#include <vector>

#include <c10/macros/Export.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
//
// The operations declared in this header are intended as user-facing functions.
// The user will provide the necessary input TensorViews and the function will
// create the correct intermediate nodes and return the output TensorViews.
//
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Note on the declarations below: top-level `const` on by-value parameters is
// omitted on purpose. It is not part of the function type, so it has no
// meaning in a declaration; the out-of-line definitions may still declare
// those parameters `const`.

// ---------------------------------------------------------------------------
// Result bundles. Every field defaults to nullptr.
// ---------------------------------------------------------------------------

// Returned by layer_norm, batch_norm and instance_norm.
// NOTE(review): presumably a field is left nullptr when the implementation
// does not produce that tensor — confirm against the definitions.
struct ForwardNormResult {
  TensorView* output = nullptr;
  TensorView* mean = nullptr;
  TensorView* invstd = nullptr;
};

// Returned by layer_norm_backward, batch_norm_backward and
// instance_norm_backward.
// NOTE(review): presumably only the gradients selected by the `output_mask`
// argument are populated; the rest stay nullptr — confirm.
struct BackwardNormResult {
  TensorView* grad_input = nullptr;
  TensorView* grad_weight = nullptr;
  TensorView* grad_bias = nullptr;
};

// Returned by rms_norm. Unlike ForwardNormResult there is no `mean` field.
struct ForwardRMSNormResult {
  TensorView* output = nullptr;
  TensorView* invstd = nullptr;
};

// Returned by rms_norm_backward. There is no `grad_bias` field, consistent
// with rms_norm taking no `bias` parameter.
struct BackwardRMSNormResult {
  TensorView* grad_input = nullptr;
  TensorView* grad_weight = nullptr;
};

// Returned by variance_mean.
struct VarMeanResult {
  TensorView* var = nullptr;
  TensorView* mean = nullptr;
};

// ---------------------------------------------------------------------------
// Reductions of `x` over the axes listed in `dims`.
// NOTE(review): `keepdim` presumably keeps reduced axes with extent 1, as in
// ATen — confirm against the implementation.
// ---------------------------------------------------------------------------

TORCH_CUDA_CU_API TensorView* mean(
    TensorView* x,
    const std::vector<int>& dims,
    bool keepdim);

// The two `variance` overloads differ only in the third parameter:
// `bool unbiased` vs. `int64_t correction`. A plain `int` argument (e.g. the
// literal `1`) converts equally well to `bool` and to `int64_t`, so such a
// call is ambiguous and fails to compile. Pass `true`/`false` to select the
// first overload, or an `int64_t` value to select the second.
TORCH_CUDA_CU_API TensorView* variance(
    TensorView* x,
    const std::vector<int>& dims,
    bool unbiased,
    bool keepdim);

TORCH_CUDA_CU_API TensorView* variance(
    TensorView* x,
    const std::vector<int>& dims,
    int64_t correction,
    bool keepdim);

// Returns a VarMeanResult holding both a `var` and a `mean` tensor.
TORCH_CUDA_CU_API VarMeanResult variance_mean(
    TensorView* x,
    const std::vector<int>& dims,
    int64_t correction,
    bool keepdim);

TORCH_CUDA_CU_API TensorView* standard_deviation(
    TensorView* x,
    const std::vector<int>& dims,
    bool unbiased,
    bool keepdim);

// ---------------------------------------------------------------------------
// Softmax family along a single axis `dim`.
// NOTE(review): in the *_backward variants, `y` is presumably the output of
// the matching forward op and `dy` the incoming gradient — confirm.
// ---------------------------------------------------------------------------

TORCH_CUDA_CU_API TensorView* softmax(TensorView* x, int dim);

TORCH_CUDA_CU_API TensorView* softmax_backward(
    TensorView* dy,
    TensorView* y,
    int dim);

TORCH_CUDA_CU_API TensorView* log_softmax(TensorView* x, int dim);

TORCH_CUDA_CU_API TensorView* log_softmax_backward(
    TensorView* dy,
    TensorView* y,
    int dim);

// ---------------------------------------------------------------------------
// Layer / RMS normalization.
// Each forward op has two overloads: one describes the normalized region by
// `norm_shape`, the other only by a dimension count `kNormShapeNumDims`.
// NOTE(review): the count presumably refers to the trailing axes of `x`, and
// `weight`/`bias` presumably may be nullptr to skip the affine step —
// confirm against the implementation.
// ---------------------------------------------------------------------------

TORCH_CUDA_CU_API ForwardNormResult layer_norm(
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* weight,
    TensorView* bias,
    Val* eps);

TORCH_CUDA_CU_API ForwardNormResult layer_norm(
    TensorView* x,
    size_t kNormShapeNumDims,
    TensorView* weight,
    TensorView* bias,
    Val* eps);

// Like layer_norm but without a `bias` parameter; the result has no `mean`.
TORCH_CUDA_CU_API ForwardRMSNormResult rms_norm(
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* weight,
    Val* eps);

TORCH_CUDA_CU_API ForwardRMSNormResult rms_norm(
    TensorView* x,
    size_t kNormShapeNumDims,
    TensorView* weight,
    Val* eps);

// Takes (dy, x, ...) in that order.
// NOTE(review): `mean`/`rstd` presumably are the `mean`/`invstd` tensors
// produced by layer_norm, and `output_mask` presumably selects which of
// grad_input / grad_weight / grad_bias to compute — confirm.
TORCH_CUDA_CU_API BackwardNormResult layer_norm_backward(
    TensorView* dy,
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* mean,
    TensorView* rstd,
    TensorView* weight,
    TensorView* bias,
    const std::vector<bool>& output_mask);

// Takes (dy, x, ...) in that order. There is no `mean` or `bias` input.
TORCH_CUDA_CU_API BackwardRMSNormResult rms_norm_backward(
    TensorView* dy,
    TensorView* x,
    const std::vector<int64_t>& norm_shape,
    TensorView* rstd,
    TensorView* weight,
    const std::vector<bool>& output_mask);

// ---------------------------------------------------------------------------
// Batch / instance normalization. `channels_last` defaults to false.
// CAUTION: the *_backward functions in this section take (x, dy, ...), the
// reverse of layer_norm_backward / rms_norm_backward, which take (dy, x, ...).
// ---------------------------------------------------------------------------

// NOTE(review): `running_mean`, `running_var` and `momentum` presumably feed
// a running-statistics update when `kTraining` is true — confirm.
TORCH_CUDA_CU_API ForwardNormResult batch_norm(
    TensorView* x,
    TensorView* weight,
    TensorView* bias,
    TensorView* running_mean,
    TensorView* running_var,
    bool kTraining,
    Val* momentum,
    Val* eps,
    bool channels_last = false);

TORCH_CUDA_CU_API BackwardNormResult batch_norm_backward(
    TensorView* x,
    TensorView* dy,
    TensorView* weight,
    TensorView* running_mean,
    TensorView* running_var,
    TensorView* save_mean,
    TensorView* save_invstd,
    bool kTraining,
    Val* eps,
    const std::vector<bool>& output_mask,
    bool channels_last = false);

// `kUseInputStats` sits in the same position as batch_norm's `kTraining`.
// NOTE(review): presumably true selects statistics computed from `x` rather
// than running_mean/running_var — confirm against the implementation.
TORCH_CUDA_CU_API ForwardNormResult instance_norm(
    TensorView* x,
    TensorView* weight,
    TensorView* bias,
    TensorView* running_mean,
    TensorView* running_var,
    bool kUseInputStats,
    Val* momentum,
    Val* eps,
    bool channels_last = false);

// The flag here is named `kTraining`, whereas instance_norm names the flag in
// the corresponding position `kUseInputStats`.
TORCH_CUDA_CU_API BackwardNormResult instance_norm_backward(
    TensorView* x,
    TensorView* dy,
    TensorView* weight,
    TensorView* running_mean,
    TensorView* running_var,
    TensorView* save_mean,
    TensorView* save_invstd,
    bool kTraining,
    Val* eps,
    const std::vector<bool>& output_mask,
    bool channels_last = false);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
|