// CUDA specializations and operator registrations for BatchMatMulOp.
#include "caffe2/operators/batch_matmul_op.h"
#include "caffe2/core/context_gpu.h"
namespace caffe2 {
// Entry point of the default-engine CUDA BatchMatMul operator.
//
// Forwards to DispatchHelper with the type list {float, at::Half}, handing it
// this operator and the first input. The typed implementation that
// DispatchHelper ends up invoking is declared outside this file (presumably in
// batch_matmul_op.h -- confirm there). Returns whatever the dispatch returns.
template <>
bool BatchMatMulOp<CUDAContext, DefaultEngine>::RunOnDevice() {
  // Element types this specialization offers to the dispatcher.
  using DispatchTypes = TensorTypes<float, at::Half>;
  return DispatchHelper<DispatchTypes>::call(this, Input(0));
}
// Registers BatchMatMulOp<CUDAContext> (engine template argument left at its
// default) under the operator name BatchMatMul via the CUDA registration macro.
REGISTER_CUDA_OPERATOR(BatchMatMul, BatchMatMulOp<CUDAContext>);
#if !defined(USE_ROCM)
// Entry point of the TensorCoreEngine variant of the CUDA BatchMatMul operator.
// This specialization is only compiled when USE_ROCM is not defined.
//
// Like the default-engine specialization, it forwards to DispatchHelper with
// the type list {float, at::Half}, passing this operator and the first input,
// and returns the dispatch result unchanged.
template <>
bool BatchMatMulOp<CUDAContext, TensorCoreEngine>::RunOnDevice() {
  // Element types this specialization offers to the dispatcher.
  using DispatchTypes = TensorTypes<float, at::Half>;
  return DispatchHelper<DispatchTypes>::call(this, Input(0));
}
// Registers the TensorCoreEngine specialization under the operator name
// BatchMatMul with the engine tag TENSORCORE. Only compiled when USE_ROCM is
// not defined (see the surrounding #if).
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
BatchMatMul,
TENSORCORE,
BatchMatMulOp<CUDAContext, TensorCoreEngine>);
#endif
} // namespace caffe2
|