1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798
|
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/python_torch_functions.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/utils/cuda_lazy_init.h>
#include <torch/csrc/utils/out_types.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/structseq.h>
#include <torch/csrc/utils/tensor_layouts.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>
#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <Python.h>
#include <fmt/format.h>
#include <pybind11/pybind11.h>
#include <algorithm>
#include <array>
#include <cstring>
#include <utility>
#include <vector>
using at::ArrayRef;
using at::Backend;
using at::Device;
using at::DeviceGuard;
using at::Dimname;
using at::DimnameList;
using at::Generator;
using at::IntArrayRef;
using at::Layout;
using at::OptionalDeviceGuard;
using at::Scalar;
using at::ScalarType;
using at::Tensor;
using at::TensorList;
using at::TensorOptions;
using torch::utils::check_out_type_matches;
using namespace torch::autograd::utils;
namespace torch {
namespace autograd {
// The single instance of the `_VariableFunctionsClass` type. It is created in
// initTorchFunctions (via PyType_GenericNew) and added to the module passed
// there as `_VariableFunctions`. The bindings in this file pass it to
// handle_torch_function as the torch API object. PyModule_AddObject steals the
// reference, so after initialization this pointer is borrowed from the module.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* THPVariableFunctionsModule = nullptr;
// Computes `range` into the caller-provided `result` tensor. The GIL is
// released for the duration of the call, and the device of `result` (when it
// has one) is made current.
inline Tensor dispatch_range(
    const Scalar& start,
    const Scalar& end,
    const Scalar& step,
    Tensor result) {
  pybind11::gil_scoped_release release_gil;
  OptionalDeviceGuard guard(device_of(result));
  Tensor& filled = at::range_out(result, start, end, step);
  return filled;
}
// Allocates a new `range` tensor described by `options`. First
// torch::utils::maybe_initialize_cuda is given a chance to set up CUDA for
// `options`. Then the GIL is released and the device from `options` is made
// current while the tensor is built.
inline Tensor dispatch_range(
    const Scalar& start,
    const Scalar& end,
    const Scalar& step,
    const TensorOptions& options) {
  torch::utils::maybe_initialize_cuda(options);
  pybind11::gil_scoped_release release_gil;
  DeviceGuard guard(options.device());
  Tensor produced = torch::range(start, end, step, options);
  return produced;
}
// Python binding for the deprecated
// torch.range(start, end, step=1, *, out, dtype, layout, device, requires_grad).
// Every call emits a UserWarning. If PyErr_WarnEx reports failure (for example,
// when warnings are configured as errors), the pending Python error is
// propagated. Without `out`, a new tensor is built from dtype/layout/device/
// requires_grad. With `out`, the out tensor is first validated against any
// explicitly given dtype/device and the layout, then filled in place.
static PyObject* THPVariable_range(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "range(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
  });
  ParsedArgs<8> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx != 0) {
    Py_RETURN_NONE;
  }
  const int warn_status = PyErr_WarnEx(
      PyExc_UserWarning,
      "torch.range is deprecated and will be removed in a future release "
      "because its behavior is inconsistent with Python's range builtin. "
      "Instead, use torch.arange, which produces values in [start, end).",
      1);
  if (warn_status != 0) {
    throw python_error();
  }
  const bool has_out = !r.isNone(3);
  if (!has_out) {
    const auto options = TensorOptions()
                             .dtype(r.scalartype(4))
                             .device(r.device(6))
                             .layout(r.layout(5))
                             .requires_grad(r.toBool(7));
    return wrap(
        dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), options));
  }
  Tensor out = r.tensor(3);
  check_out_type_matches(
      out, r.scalartype(4), r.isNone(4), r.layout(5), r.device(6), r.isNone(6));
  Tensor result = dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), out);
  result.set_requires_grad(r.toBool(7));
  return wrap(result);
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.as_tensor(data, *, dtype=None, device=None).
// Implemented on the Python object so that `data` can be an arbitrarily nested
// Python object: list, tuple, NumPy array, scalar, etc. __torch_function__
// overrides take precedence. Otherwise the parsed arguments go to
// torch::utils::as_tensor with the current default dispatch key and scalar
// type.
static PyObject* THPVariable_as_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)",
  });
  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (!r.has_torch_function()) {
    jit::tracer::warn("torch.as_tensor", jit::tracer::WARN_CONSTRUCTOR);
    const auto dispatch_key = torch::tensors::get_default_dispatch_key();
    const auto scalar_type = torch::tensors::get_default_scalar_type();
    return THPVariable_Wrap(
        torch::utils::as_tensor(dispatch_key, scalar_type, r));
  }
  return handle_torch_function(
      r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.from_numpy(array), registered with METH_O so `arg`
// is the single positional argument. It is bound by hand because a raw
// PyObject parameter is not currently natively declarable (see
// ATen/native/README.md).
static PyObject* THPVariable_from_numpy(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  jit::tracer::warn("torch.from_numpy", jit::tracer::WARN_CONSTRUCTOR);
  auto converted = torch::utils::tensor_from_numpy(arg);
  return THPVariable_Wrap(converted);
  END_HANDLE_TH_ERRORS
}
// Returns self.nonzero(), computed with the GIL released and the device of
// `self` (if any) made current.
static Tensor dispatch_nonzero(const Tensor& self) {
  pybind11::gil_scoped_release release_gil;
  OptionalDeviceGuard guard(device_of(self));
  Tensor indices = self.nonzero();
  return indices;
}
// Writes nonzero(self) into `out` via at::nonzero_out. The GIL is released and
// the device of `self` (if any) is made current for the call.
static Tensor dispatch_nonzero(const Tensor& self, Tensor out) {
  pybind11::gil_scoped_release release_gil;
  OptionalDeviceGuard guard(device_of(self));
  Tensor& written = at::nonzero_out(out, self);
  return written;
}
// Returns self.nonzero_numpy() (a vector of tensors), computed with the GIL
// released and the device of `self` (if any) made current.
static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor& self) {
  pybind11::gil_scoped_release release_gil;
  OptionalDeviceGuard guard(device_of(self));
  std::vector<Tensor> pieces = self.nonzero_numpy();
  return pieces;
}
// Forward declaration: torch_functions_manual refers to this binding before its
// definition, which appears after the table.
static PyObject* THPVariable_nonzero(
PyObject* self,
PyObject* args,
PyObject* kwargs);
// Defines `static PyObject* THPVariable_<NAME>(self, args, kwargs)`, a Python
// binding for a sparse compressed-layout tensor constructor.
//   NAME       - binding name; the generated function calls
//                torch::utils::<NAME>_ctor and warns as "torch.<NAME>".
//   NARGS      - ParsedArgs capacity (each invocation below passes the argument
//                count of its longest signature).
//   SIGNATURES - parenthesized brace-list of PythonArgParser signatures.
// The generated function first defers to __torch_function__ overrides. If
// there are none, it emits a JIT tracer constructor warning and forwards the
// parsed arguments, together with the default dispatch key and scalar type,
// to the matching _ctor.
#define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NARGS, SIGNATURES) \
static PyObject* THPVariable_##NAME( \
PyObject* self, PyObject* args, PyObject* kwargs) { \
HANDLE_TH_ERRORS \
static PythonArgParser parser SIGNATURES; \
ParsedArgs<NARGS> parsed_args; \
auto r = parser.parse(args, kwargs, parsed_args); \
if (r.has_torch_function()) { \
return handle_torch_function( \
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); \
} \
jit::tracer::warn("torch." #NAME, jit::tracer::WARN_CONSTRUCTOR); \
return THPVariable_Wrap(torch::utils::NAME##_ctor( \
torch::tensors::get_default_dispatch_key(), \
torch::tensors::get_default_scalar_type(), \
r)); \
END_HANDLE_TH_ERRORS \
}
// Public sparse compressed constructors. Each accepts two signatures: with an
// explicit `size` (9 arguments) and without one (8 arguments). Both accept
// dtype, layout, device, pin_memory and requires_grad keywords.
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
sparse_compressed_tensor,
9,
({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
sparse_csr_tensor,
9,
({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
sparse_csc_tensor,
9,
({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
sparse_bsr_tensor,
9,
({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
sparse_bsc_tensor,
9,
({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
// Private *_unsafe constructors. Each has a single signature that requires an
// explicit `size` and takes no pin_memory keyword. Only the generic
// _sparse_compressed_tensor_unsafe also accepts a `layout` keyword.
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
_sparse_compressed_tensor_unsafe,
8,
({"_sparse_compressed_tensor_unsafe(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool requires_grad=False)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
_sparse_csr_tensor_unsafe,
7,
({"_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
_sparse_csc_tensor_unsafe,
7,
({"_sparse_csc_tensor_unsafe(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
_sparse_bsr_tensor_unsafe,
7,
({"_sparse_bsr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
_sparse_bsc_tensor_unsafe,
7,
({"_sparse_bsc_tensor_unsafe(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"}))
// Python binding for torch.sparse_coo_tensor. Three call forms are accepted:
// (indices, values), (indices, values, size) and (size). __torch_function__
// overrides take precedence. Otherwise the parsed arguments go to
// torch::utils::sparse_coo_tensor_ctor with the current default dispatch key
// and scalar type.
static PyObject* THPVariable_sparse_coo_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
      "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
      "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
  });
  ParsedArgs<6> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  const bool overridden = r.has_torch_function();
  if (overridden) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  jit::tracer::warn("torch.sparse_coo_tensor", jit::tracer::WARN_CONSTRUCTOR);
  auto built = torch::utils::sparse_coo_tensor_ctor(
      torch::tensors::get_default_dispatch_key(),
      torch::tensors::get_default_scalar_type(),
      r);
  return THPVariable_Wrap(built);
  END_HANDLE_TH_ERRORS
}
// Python binding for torch._sparse_coo_tensor_unsafe(indices, values, size,
// ...). __torch_function__ overrides take precedence. Otherwise the parsed
// arguments go to torch::utils::_sparse_coo_tensor_unsafe_ctor with the
// current default dispatch key and scalar type.
static PyObject* THPVariable__sparse_coo_tensor_unsafe(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "_sparse_coo_tensor_unsafe(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
  });
  ParsedArgs<6> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (!r.has_torch_function()) {
    jit::tracer::warn(
        "torch._sparse_coo_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR);
    const auto dispatch_key = torch::tensors::get_default_dispatch_key();
    const auto scalar_type = torch::tensors::get_default_scalar_type();
    return THPVariable_Wrap(torch::utils::_sparse_coo_tensor_unsafe_ctor(
        dispatch_key, scalar_type, r));
  }
  return handle_torch_function(
      r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.tensor(data, *, dtype, device, pin_memory,
// requires_grad, names). Implemented on the Python object so that `data` can
// be an arbitrarily nested Python object: list, tuple, NumPy array, scalar,
// etc. __torch_function__ overrides take precedence. Otherwise the parsed
// arguments go to torch::utils::tensor_ctor with the current default dispatch
// key and scalar type.
static PyObject* THPVariable_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)",
  });
  constexpr int kTensorCtorArgs = 6;
  ParsedArgs<kTensorCtorArgs> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (!r.has_torch_function()) {
    jit::tracer::warn("torch.tensor", jit::tracer::WARN_CONSTRUCTOR);
    auto created = torch::utils::tensor_ctor(
        torch::tensors::get_default_dispatch_key(),
        torch::tensors::get_default_scalar_type(),
        r);
    return THPVariable_Wrap(created);
  }
  return handle_torch_function(
      r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.get_device(input). It wraps and returns
// input.get_device(). The parser is constructed with traceable=false.
static PyObject* THPVariable_get_device(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "get_device(Tensor input)",
      },
      /*traceable=*/false);
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx != 0) {
    Py_RETURN_NONE;
  }
  const auto device_index = r.tensor(0).get_device();
  return wrap(device_index);
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.frombuffer(buffer, *, dtype, count=-1, offset=0,
// requires_grad=False). It rejects objects that do not implement the Python
// buffer protocol (TORCH_CHECK_VALUE) before delegating to
// torch::utils::tensor_frombuffer.
static PyObject* THPVariable_frombuffer(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "frombuffer(PyObject* buffer, *, ScalarType dtype, int64_t count=-1, int64_t offset=0, bool requires_grad=False)",
      },
      /*traceable=*/false);
  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx != 0) {
    Py_RETURN_NONE;
  }
  PyObject* const source = r.pyobject(0);
  const auto element_type = r.scalartype(1);
  const auto element_count = r.toInt64(2);
  const auto byte_offset = r.toInt64(3);
  const bool wants_grad = r.toBool(4);
  TORCH_CHECK_VALUE(
      PyObject_CheckBuffer(source) != 0,
      "object does not implement Python buffer protocol.");
  return wrap(torch::utils::tensor_frombuffer(
      source, element_type, element_count, byte_offset, wants_grad));
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.asarray(obj, *, dtype=None, device=None, copy=None,
// requires_grad=False). The optional arguments are forwarded unchanged, as
// optionals, to torch::utils::asarray.
static PyObject* THPVariable_asarray(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "asarray(PyObject* obj, *, ScalarType? dtype=None, Device? device=None, bool? copy=None, bool requires_grad=False)",
      },
      /*traceable=*/false);
  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx != 0) {
    Py_RETURN_NONE;
  }
  return wrap(torch::utils::asarray(
      r.pyobject(0),
      r.scalartypeOptional(1),
      r.deviceOptional(2),
      r.toBoolOptional(3),
      r.toBool(4)));
  END_HANDLE_TH_ERRORS
}
// Forward declaration: torch_functions_manual refers to this binding before its
// definition, which appears after the table.
static PyObject* THPVariable_numel(
PyObject* self_,
PyObject* args,
PyObject* kwargs);
// Python binding for torch._to_functional_tensor(t). It returns the result of
// at::functionalization::impl::to_functional_tensor(t).
static PyObject* THPVariable__to_functional_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"_to_functional_tensor(Tensor t)"}, /*traceable=*/true);
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  const auto input = r.tensor(0);
  return wrap(at::functionalization::impl::to_functional_tensor(input));
  END_HANDLE_TH_ERRORS
}
// Python binding for torch._from_functional_tensor(t). It returns the result of
// at::functionalization::impl::from_functional_tensor(t).
static PyObject* THPVariable__from_functional_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"_from_functional_tensor(Tensor t)"}, /*traceable=*/true);
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  const auto input = r.tensor(0);
  return wrap(at::functionalization::impl::from_functional_tensor(input));
  END_HANDLE_TH_ERRORS
}
// Python binding for torch._is_functional_tensor(t). It returns Python True or
// False according to at::functionalization::impl::isFunctionalTensor(t).
static PyObject* THPVariable__is_functional_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"_is_functional_tensor(Tensor t)"}, /*traceable=*/true);
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  const auto candidate = r.tensor(0);
  if (!at::functionalization::impl::isFunctionalTensor(candidate)) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
  END_HANDLE_TH_ERRORS
}
// Python binding for torch._sync(t). It calls
// at::functionalization::impl::sync(t) and returns None.
// Passing a non-functional tensor is a caller error, not a PyTorch bug. It is
// therefore rejected with TORCH_CHECK and a descriptive message; an
// internal-assert failure would ask the user to report a bug.
static PyObject* THPVariable__sync(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({"_sync(Tensor t)"}, /*traceable=*/true);
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  auto self_ = r.tensor(0);
  TORCH_CHECK(
      at::functionalization::impl::isFunctionalTensor(self_),
      "_sync(): expected a functional tensor (e.g. one returned by "
      "_to_functional_tensor), but got a non-functional tensor");
  at::functionalization::impl::sync(self_);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding for torch._enable_functionalization(*, reapply_views=False).
// It adds DispatchKey::Functionalize to the thread-local included dispatch-key
// set. When reapply_views is true, it also sets the thread-local reapply-views
// flag.
// Enabling while Functionalize is already included is reachable from user code
// (calling this twice), so it is a TORCH_CHECK error rather than an internal
// assertion that would ask the user to report a PyTorch bug.
static PyObject* THPVariable__enable_functionalization(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"_enable_functionalization(*, bool reapply_views=False)"},
      /*traceable=*/true);
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  const auto reapply_views = r.toBool(0);
  TORCH_CHECK(
      !c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Functionalize),
      "multiple layers of mode-style functionalization nesting is not"
      " currently supported, outside of the functionalize() transform");
  c10::impl::tls_set_dispatch_key_included(
      at::DispatchKey::Functionalize, true);
  if (reapply_views) {
    at::functionalization::impl::setFunctionalizationReapplyViewsTLS(true);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Python binding for torch._disable_functionalization(), the counterpart of
// _enable_functionalization. It removes DispatchKey::Functionalize from the
// thread-local included dispatch-key set and clears the thread-local
// reapply-views flag. `args`/`kwargs` are never parsed.
static PyObject* THPVariable__disable_functionalization(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  constexpr bool kEnabled = false;
  c10::impl::tls_set_dispatch_key_included(
      at::DispatchKey::Functionalize, kEnabled);
  at::functionalization::impl::setFunctionalizationReapplyViewsTLS(kEnabled);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Hand-written bindings exposed on torch._C._VariableFunctionsClass.
// gatherTorchFunctions copies this table first, then appends the sharded
// entries, the alias names and the {nullptr} sentinel. This table therefore
// carries no sentinel of its own.
// XXX: ops that are bound here are not exposed to the C++ API or the JIT.
// Any new op added here should come with a comment explaining why it is not
// registered through native_functions.yaml, and should be tagged cpp / JIT.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef torch_functions_manual[] = {
{"asarray",
castPyCFunctionWithKeywords(THPVariable_asarray),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"as_tensor",
castPyCFunctionWithKeywords(THPVariable_as_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
// from_numpy takes exactly one positional argument (METH_O).
{"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, nullptr},
{"frombuffer",
castPyCFunctionWithKeywords(THPVariable_frombuffer),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
// Functionalization helpers.
{"_is_functional_tensor",
castPyCFunctionWithKeywords(THPVariable__is_functional_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_to_functional_tensor",
castPyCFunctionWithKeywords(THPVariable__to_functional_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_from_functional_tensor",
castPyCFunctionWithKeywords(THPVariable__from_functional_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_sync",
castPyCFunctionWithKeywords(THPVariable__sync),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_enable_functionalization",
castPyCFunctionWithKeywords(THPVariable__enable_functionalization),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_disable_functionalization",
castPyCFunctionWithKeywords(THPVariable__disable_functionalization),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
// THPVariable_nonzero is forward-declared above; its definition follows this table.
{"nonzero",
castPyCFunctionWithKeywords(THPVariable_nonzero),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"range",
castPyCFunctionWithKeywords(THPVariable_range),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
// Sparse tensor constructors.
{"sparse_coo_tensor",
castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_sparse_coo_tensor_unsafe",
castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_sparse_compressed_tensor_unsafe",
castPyCFunctionWithKeywords(THPVariable__sparse_compressed_tensor_unsafe),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"sparse_compressed_tensor",
castPyCFunctionWithKeywords(THPVariable_sparse_compressed_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"sparse_csr_tensor",
castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"sparse_csc_tensor",
castPyCFunctionWithKeywords(THPVariable_sparse_csc_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"sparse_bsr_tensor",
castPyCFunctionWithKeywords(THPVariable_sparse_bsr_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"sparse_bsc_tensor",
castPyCFunctionWithKeywords(THPVariable_sparse_bsc_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_sparse_csr_tensor_unsafe",
castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor_unsafe),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_sparse_csc_tensor_unsafe",
castPyCFunctionWithKeywords(THPVariable__sparse_csc_tensor_unsafe),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_sparse_bsr_tensor_unsafe",
castPyCFunctionWithKeywords(THPVariable__sparse_bsr_tensor_unsafe),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_sparse_bsc_tensor_unsafe",
castPyCFunctionWithKeywords(THPVariable__sparse_bsc_tensor_unsafe),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"tensor",
castPyCFunctionWithKeywords(THPVariable_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"get_device",
castPyCFunctionWithKeywords(THPVariable_get_device),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
// THPVariable_numel is forward-declared above; its definition follows this table.
{"numel",
castPyCFunctionWithKeywords(THPVariable_numel),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
};
// Python binding for torch.nonzero(input, *, as_tuple=False, out=None).
// With as_tuple=True it returns the tensors produced by nonzero_numpy, and
// passing `out` is an error. Otherwise it returns a single tensor, written into
// `out` when one is supplied.
static PyObject* THPVariable_nonzero(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)",
  });
  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  const bool out_given = !r.isNone(2);
  const bool tuple_requested = r.toBool(1);
  if (!tuple_requested) {
    return out_given ? wrap(dispatch_nonzero(r.tensor(0), r.tensor(2)))
                     : wrap(dispatch_nonzero(r.tensor(0)));
  }
  TORCH_CHECK(
      !out_given,
      "nonzero does not support the out kwarg when as_tuple is True");
  return wrap(dispatch_nonzero_numpy(r.tensor(0)));
  END_HANDLE_TH_ERRORS
}
// Python binding for torch.numel(input). __torch_function__ overrides take
// precedence; otherwise it wraps and returns input.numel(). The parser is
// constructed with traceable=false.
static PyObject* THPVariable_numel(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "numel(Tensor input)",
      },
      /*traceable=*/false);
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  if (r.idx != 0) {
    Py_RETURN_NONE;
  }
  const auto element_count = r.tensor(0).numel();
  return wrap(element_count);
  END_HANDLE_TH_ERRORS
}
// Sharded function definitions. Each appends its share of PyMethodDef entries
// to `torch_functions`. They are defined outside this file (presumably
// generated; the shard count must match num_shards in
// tools/autograd/gen_python_functions.py).
void gatherTorchFunctions_0(std::vector<PyMethodDef>& torch_functions);
void gatherTorchFunctions_1(std::vector<PyMethodDef>& torch_functions);
void gatherTorchFunctions_2(std::vector<PyMethodDef>& torch_functions);
// Replaces the contents of `torch_functions` with every method to expose on
// torch._C._VariableFunctionsClass. The order is:
//   1. the hand-written entries of torch_functions_manual,
//   2. the entries appended by each shard,
//   3. one copy of a canonical entry per legacy alias name,
//   4. the {nullptr} sentinel that terminates a PyMethodDef array.
// Fails with TORCH_INTERNAL_ASSERT if an alias's canonical function is missing.
void gatherTorchFunctions(std::vector<PyMethodDef>& torch_functions) {
  constexpr size_t num_functions =
      sizeof(torch_functions_manual) / sizeof(torch_functions_manual[0]);
  torch_functions.assign(
      torch_functions_manual, torch_functions_manual + num_functions);
  // NOTE: Must be synced with num_shards in
  // tools/autograd/gen_python_functions.py
  gatherTorchFunctions_0(torch_functions);
  gatherTorchFunctions_1(torch_functions);
  gatherTorchFunctions_2(torch_functions);

  // (canonical function, alias name). The table is never modified, so it is a
  // compile-time constant rather than a mutable function-local static.
  static constexpr std::array<std::pair<const char*, const char*>, 4> aliases{
      {{"sspaddmm", "saddmm"},
       {"mm", "spmm"},
       {"mm", "dsmm"},
       {"hspmm", "hsmm"}}};
  for (const auto& alias : aliases) {
    const auto it = std::find_if(
        torch_functions.begin(),
        torch_functions.end(),
        [&](const PyMethodDef& def) {
          return strcmp(def.ml_name, alias.first) == 0;
        });
    TORCH_INTERNAL_ASSERT(
        it != torch_functions.end(),
        "Failed to create function alias from ",
        alias.first,
        " to ",
        alias.second);
    // Copy the entry before push_back, which may invalidate `it`; the alias
    // shares the canonical function pointer, flags and doc.
    PyMethodDef alias_def = *it;
    alias_def.ml_name = alias.second;
    torch_functions.push_back(alias_def);
  }
  torch_functions.push_back({nullptr});
  torch_functions.shrink_to_fit();
}
// Statically allocated type object for torch._C._VariableFunctionsClass.
// Only tp_name and tp_flags are set here. initTorchFunctions fills tp_methods
// with the gathered method table before calling PyType_Ready, and creates the
// one instance via PyType_GenericNew (tp_new is left null).
static PyTypeObject THPVariableFunctions = {
PyVarObject_HEAD_INIT(
nullptr,
0) "torch._C._VariableFunctionsClass", /* tp_name */
0, /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
nullptr /* tp_new */
};
// Builds the method table, readies the static THPVariableFunctions type, and
// adds two attributes to `module`:
//   "_VariableFunctionsClass" - the type itself;
//   "_VariableFunctions"      - its single instance, also stored (borrowed) in
//                               THPVariableFunctionsModule.
// Throws python_error, with the Python exception set, if any step fails. On
// failure, references that PyModule_AddObject did not steal are released, so
// nothing leaks and the global is not left pointing at an unowned object.
void initTorchFunctions(PyObject* module) {
  // Must outlive the type: tp_methods points into this vector's storage.
  static std::vector<PyMethodDef> torch_functions;
  gatherTorchFunctions(torch_functions);
  THPVariableFunctions.tp_methods = torch_functions.data();

  if (PyType_Ready(&THPVariableFunctions) < 0) {
    throw python_error();
  }
  // Extra reference that is never released (presumably so the statically
  // allocated type object's refcount can never reach zero).
  Py_INCREF(&THPVariableFunctions);

  // Reference for the module attribute. PyModule_AddObject steals it only on
  // success, so give it back on failure.
  Py_INCREF(&THPVariableFunctions);
  if (PyModule_AddObject(
          module,
          "_VariableFunctionsClass",
          reinterpret_cast<PyObject*>(&THPVariableFunctions)) < 0) {
    Py_DECREF(&THPVariableFunctions);
    throw python_error();
  }

  // PyType_GenericNew returns a new reference, or nullptr with an exception set.
  THPVariableFunctionsModule =
      PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
  if (THPVariableFunctionsModule == nullptr) {
    throw python_error();
  }
  // On success the module owns the reference and the global is borrowed. On
  // failure nothing was stolen: drop our reference and reset the global.
  if (PyModule_AddObject(
          module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
    Py_CLEAR(THPVariableFunctionsModule);
    throw python_error();
  }
}
} // namespace autograd
} // namespace torch
|