1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
|
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include "torch/csrc/Device.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/autograd/python_nn_functions.h"
#include "torch/csrc/autograd/python_return_types.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/autograd/utils/python_arg_parsing.h"
#include "torch/csrc/utils/pycfunction_helpers.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/structseq.h"
#include "torch/csrc/utils/tensor_memoryformats.h"
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif
using at::Tensor;
using at::Scalar;
using at::MemoryFormat;
using at::Generator;
using at::IntArrayRef;
using at::ArrayRef;
using namespace torch::autograd::utils;
namespace torch { namespace autograd {
// The torch._C._nn submodule object, assigned by initNNFunctions and passed to
// handle_torch_function by _parse_to. Non-owning: the reference is held by the
// parent module the submodule is attached to.
static PyObject* THPNNVariableFunctionsModule = nullptr;
// Python entry point torch._C._nn._parse_to.
//
// Parses the arguments of an nn.Module.to-style call and returns a 4-tuple
// (device, dtype, non_blocking, memory_format); device, dtype and
// memory_format are None when not supplied. If the arguments carry a
// __torch_function__ override, dispatch is delegated to
// handle_torch_function instead. C++ exceptions raised in the body are
// handled by the HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS wrapper.
static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  // Three accepted call forms: (device, dtype, ...), (dtype, ...), or a
  // reference tensor whose properties are used.
  static PythonArgParser parser({
    "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
    "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
    "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
  });
  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(r, args, kwargs, THPNNVariableFunctionsModule, "torch.nn", "_parse_to");
  }

  // allow_copy is false because a copy request is not wanted for nn.Module.to;
  // element 3 (the copy flag) is therefore never read below.
  auto conversion = parse_to_conversion(r, /*allow_copy*/ false);
  const auto& target_device = std::get<0>(conversion);
  const auto& target_dtype = std::get<1>(conversion);
  const auto non_blocking = std::get<2>(conversion);
  const auto& target_memory_format = std::get<4>(conversion);

  // Owned tuple: released automatically if anything below throws.
  THPObjectPtr result{PyTuple_New(4)};
  if (!result) {
    throw python_error();
  }

  // PyTuple_SET_ITEM steals a reference, so every slot (None included) must
  // receive an owned reference.
  auto owned_none = []() -> PyObject* {
    Py_INCREF(Py_None);
    return Py_None;
  };

  // Slots are built and stored one at a time so nothing leaks if a later
  // conversion throws.
  PyObject* device_item = target_device
      ? THPDevice_New(*target_device)
      : owned_none();
  PyTuple_SET_ITEM(result.get(), 0, device_item);

  PyObject* dtype_item = target_dtype
      ? torch::autograd::utils::wrap(torch::getTHPDtype(*target_dtype))
      : owned_none();
  PyTuple_SET_ITEM(result.get(), 1, dtype_item);

  PyTuple_SET_ITEM(result.get(), 2, torch::autograd::utils::wrap(non_blocking));

  PyObject* memory_format_item = target_memory_format.has_value()
      ? torch::utils::getTHPMemoryFormat(target_memory_format.value())
      : owned_none();
  PyTuple_SET_ITEM(result.get(), 3, memory_format_item);

  return result.release();
  END_HANDLE_TH_ERRORS
}
// generated forward declarations start here
${py_forwards}
// Method table registered for the torch._C._nn submodule (see the
// PyModuleDef in initNNFunctions): the hand-written _parse_to entry first,
// followed by the entries substituted for ${py_method_defs}.
static PyMethodDef nn_functions[] = {
{"_parse_to", castPyCFunctionWithKeywords(THPVariable__parse_to),
METH_VARARGS | METH_KEYWORDS, nullptr},
${py_method_defs}
// Sentinel: an all-zero entry (NULL name) terminates the table.
{NULL}
};
// Creates the torch._C._nn submodule from the nn_functions method table and
// attaches it to `module` under the attribute name "_nn".
//
// On success, THPNNVariableFunctionsModule is set to the new submodule as a
// non-owning pointer; `module` holds the reference.
// Throws python_error (with the Python error indicator set) if module creation
// or attachment fails.
void initNNFunctions(PyObject* module) {
  // m_size = -1: per the CPython docs, this module keeps its state in global
  // variables (e.g. THPNNVariableFunctionsModule) and does not support
  // sub-interpreters.
  static struct PyModuleDef def = {
    PyModuleDef_HEAD_INIT,
    "torch._C._nn",
    nullptr,
    -1,
    nn_functions
  };
  PyObject* nn = PyModule_Create(&def);
  if (!nn) {
    throw python_error();
  }
  // PyModule_AddObject steals the reference to `nn` only on success; on
  // failure we still own it and must release it ourselves.
  if (PyModule_AddObject(module, "_nn", nn) != 0) {
    Py_DECREF(nn);
    throw python_error();
  }
  // Publish the non-owning pointer only after registration succeeded, so it
  // never refers to a module that was released on the failure path.
  THPNNVariableFunctionsModule = nn;
}
// generated methods start here
${py_methods}
}} // namespace torch::autograd
|