1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
|
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/cast_op.h"
#include "caffe2/utils/conversions.h"
namespace caffe2 {
// Element-wise cast kernel.
//
// For every index produced by CUDA_1D_KERNEL_LOOP with bound N, reads X[idx],
// converts it with convert::To<SrcType, DstType>, and stores the result in
// Y[idx]. The conversion goes through convert::To rather than a plain
// static_cast.
//
// N: loop bound handed to CUDA_1D_KERNEL_LOOP.
// X: source buffer, read at each produced index.
// Y: destination buffer, written at each produced index.
template <typename DstType, typename SrcType>
__global__ void CastKernel(const int N, const SrcType* X, DstType* Y) {
  CUDA_1D_KERNEL_LOOP(idx, N) {
    const SrcType src = X[idx];
    Y[idx] = convert::To<SrcType, DstType>(src);
  }
}
// Casts Input(0) from SrcType to DstType on the GPU.
//
// Output(0) is requested with the same sizes as the input and element type
// DstType, then CastKernel is launched on the context's CUDA stream to fill
// it. An empty input produces an (empty) output without launching a kernel.
//
// Returns true on success. Throws (via CAFFE_ENFORCE_LT) when the input has
// INT_MAX or more elements, because the element count is passed to the kernel
// as an `int`.
template <>
template <typename DstType, typename SrcType>
bool CastOp<CUDAContext>::DoRunWithType() {
  auto& input = Input(0);

  // The element count is narrowed to `int` below. A debug-only check would
  // vanish in release builds and let an oversized tensor silently wrap, so
  // this is enforced unconditionally and before any output is allocated.
  CAFFE_ENFORCE_LT(
      input.numel(),
      INT_MAX,
      "CastOp<CUDAContext>: input has too many elements to be indexed with int");

  auto* output = Output(0, input.sizes(), at::dtype<DstType>());
  const auto* data = input.template data<SrcType>();
  auto* out = output->template mutable_data<DstType>();

  const int N = static_cast<int>(input.numel());
  if (N == 0) {
    // skip the rest of the computation if input is empty
    return true;
  }

  CastKernel<DstType, SrcType>
      <<<CAFFE_GET_BLOCKS(N),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(N, data, out);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
// Generic dispatcher for a fixed destination type.
//
// Hands this operator and Input(0) to DispatchHelper together with the list
// of source element types accepted for DstType. at::Half is not part of this
// list. NOTE(review): DispatchHelper presumably picks the matching
// DoRunWithType instantiation from the input's element type -- confirm
// against its definition.
template <>
template <typename DstType>
bool CastOp<CUDAContext>::DoRunWithDstType() {
  using SupportedSrcTypes = TensorTypes<
      float,
      int32_t,
      bool,
      uint8_t,
      int8_t,
      uint16_t,
      int16_t,
      int64_t,
      double>;
  return DispatchHelper<SupportedSrcTypes, DstType>::call(this, Input(0));
}
// Specialization for a float destination. Its source type list additionally
// contains at::Half, so this is the version that allows casting _from_ fp16
// (to float).
template <>
template <>
bool CastOp<CUDAContext>::DoRunWithDstType<float>() {
  using SupportedSrcTypes = TensorTypes<
      float,
      at::Half,
      int32_t,
      bool,
      uint8_t,
      int8_t,
      uint16_t,
      int16_t,
      int64_t,
      double>;
  return DispatchHelper<SupportedSrcTypes, float /* DstType */>::call(
      this, Input(0));
}
// Specialization for an at::Half destination, i.e. casting _to_ fp16.
// Only float and at::Half are listed as source types.
template <>
template <>
bool CastOp<CUDAContext>::DoRunWithDstType<at::Half>() {
  using SupportedSrcTypes = TensorTypes<float, at::Half>;
  return DispatchHelper<SupportedSrcTypes, at::Half /* DstType */>::call(
      this, Input(0));
}
// Chooses the member function stored in body_ from the requested destination
// data type `to`.
//
// Each supported type maps to the DoRunWithDstType instantiation for the
// corresponding C++ element type. BYTE is reported through LOG(FATAL);
// STRING, UNDEFINED and any value not listed raise via CAFFE_THROW.
template <>
void CastOp<CUDAContext>::SetBody(TensorProto_DataType to) {
  using Op = CastOp<CUDAContext>;

  switch (to) {
    case TensorProto_DataType_FLOAT:
      body_ = &Op::DoRunWithDstType<float>;
      return;
    case TensorProto_DataType_INT32:
      body_ = &Op::DoRunWithDstType<int>;
      return;
    case TensorProto_DataType_BYTE:
      LOG(FATAL) << "BYTE is deprecated";
      return;
    case TensorProto_DataType_STRING:
      // Throws, so control never continues into the next label.
      CAFFE_THROW("Casting to and from strings is not supported yet");
    case TensorProto_DataType_BOOL:
      body_ = &Op::DoRunWithDstType<bool>;
      return;
    case TensorProto_DataType_UINT8:
      body_ = &Op::DoRunWithDstType<uint8_t>;
      return;
    case TensorProto_DataType_INT8:
      body_ = &Op::DoRunWithDstType<int8_t>;
      return;
    case TensorProto_DataType_UINT16:
      body_ = &Op::DoRunWithDstType<uint16_t>;
      return;
    case TensorProto_DataType_INT16:
      body_ = &Op::DoRunWithDstType<int16_t>;
      return;
    case TensorProto_DataType_INT64:
      body_ = &Op::DoRunWithDstType<int64_t>;
      return;
    case TensorProto_DataType_FLOAT16:
      body_ = &Op::DoRunWithDstType<at::Half>;
      return;
    case TensorProto_DataType_DOUBLE:
      body_ = &Op::DoRunWithDstType<double>;
      return;
    case TensorProto_DataType_UNDEFINED:
      // Throws, so control never continues into the default label.
      CAFFE_THROW("Cast op must have 'to' argument of type DataType");
    default:
      CAFFE_THROW("Unexpected 'to' argument value: ", to);
  }
}
// Registers CastOp<CUDAContext> as the CUDA implementation of the operator
// named "Cast".
REGISTER_CUDA_OPERATOR(Cast, CastOp<CUDAContext>);
} // namespace caffe2
|