1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
|
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/ATen_fwd.h>
#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
#include <algorithm>
namespace torch::nested {
/// Nested tensor
///
/// Packs a list of tensors into a single nested tensor. C++ counterpart of
/// `torch.nested.nested_tensor`; see
/// https://pytorch.org/docs/main/nested.html#torch.nested.nested_tensor
///
/// The Python entry point is implemented directly on the Python object so
/// that `torch.nested.nested_tensor` can be constructed from arbitrarily
/// nested Python objects (for now, only Python lists and lists of Tensors);
/// see torch/csrc/autograd/python_nested_functions_manual.cpp for that
/// implementation. The overloads in this header are the C++ implementation.
///
/// @param nested_tensor_data the component tensors to pack.
/// @param options only dtype, device, pinned_memory and requires_grad are
///   read. An unset dtype or device is forwarded as `std::nullopt`, leaving
///   the choice to `_nested_tensor_from_tensor_list` (the Python API documents
///   this as inheriting from the leftmost tensor in the list).
/// @return the nested tensor; `requires_grad` is enabled on it when `options`
///   explicitly requests it.
inline at::Tensor nested_tensor(
    at::TensorList nested_tensor_data,
    const at::TensorOptions& options = {}) {
  // Use the *_opt accessors: dtype()/device() would substitute the global
  // default dtype and CPU for unset fields, silently converting e.g. integer
  // or CUDA inputs instead of preserving them.
  auto out = at::_nested_tensor_from_tensor_list(
      nested_tensor_data,
      c10::optTypeMetaToScalarType(options.dtype_opt()),
      std::nullopt,
      options.device_opt(),
      options.pinned_memory());
  if (options.has_requires_grad() && options.requires_grad()) {
    out.requires_grad_(true);
  }
  return out;
}
/// Nested tensor from nested initializer lists, e.g. one list of values per
/// component tensor.
///
/// Every element of `nested_tensor_data` must be an initializer list; any
/// other kind of TensorDataContainer is rejected. Each element is converted to
/// a dense tensor using `options`, and the results are packed into a nested
/// tensor.
///
/// @param nested_tensor_data one initializer list per component tensor.
/// @param options dtype/device/pinned_memory/requires_grad for the result.
///   An unset dtype or device is forwarded as `std::nullopt`, so the packed
///   result keeps whatever the per-component conversion produced.
/// @return the nested tensor; `requires_grad` is enabled on it when `options`
///   explicitly requests it.
/// @throws c10::Error if any element is not an initializer list.
inline at::Tensor nested_tensor(
    at::ArrayRef<detail::TensorDataContainer> nested_tensor_data,
    const at::TensorOptions& options = {}) {
  for (const auto& tdc : nested_tensor_data) {
    TORCH_CHECK(
        tdc.is_init_list(),
        "nested_tensor() not implemented for these parameters");
  }
  // ATen factory functions reject TensorOptions with requires_grad set, so
  // strip it for the per-component conversion (as torch::tensor does) and
  // apply it to the final nested tensor below instead.
  const at::TensorOptions component_options =
      options.requires_grad(std::nullopt);
  // Construct a TensorList using nested_tensor_data
  std::vector<at::Tensor> tensor_list(nested_tensor_data.size());
  std::transform(
      nested_tensor_data.begin(),
      nested_tensor_data.end(),
      tensor_list.begin(),
      [&component_options](const detail::TensorDataContainer& tdc) {
        return tdc.convert_to_tensor(component_options);
      });
  // *_opt accessors: dtype() would force the global default dtype even when
  // the caller did not ask for one, overriding the components' dtype.
  auto out = at::_nested_tensor_from_tensor_list(
      tensor_list,
      c10::optTypeMetaToScalarType(options.dtype_opt()),
      std::nullopt,
      options.device_opt(),
      options.pinned_memory());
  if (options.has_requires_grad() && options.requires_grad()) {
    out.requires_grad_(true);
  }
  return out;
}
/// As Nested Tensor
///
/// Builds a nested tensor from `list` by calling
/// `at::_nested_tensor_from_tensor_list`. C++ counterpart of
/// `torch.nested.as_nested_tensor`; see
/// https://pytorch.org/docs/main/nested.html#torch.nested.as_nested_tensor
///
/// @param list the component tensors.
/// @param dtype optional dtype for the result; forwarded unchanged.
/// @param device optional device for the result; forwarded unchanged.
/// @return the nested tensor produced by the ATen operator.
inline at::Tensor as_nested_tensor(
    at::TensorList list,
    std::optional<at::ScalarType> dtype = std::nullopt,
    std::optional<at::Device> device = std::nullopt) {
  // This API exposes neither layout nor pin_memory, so both stay unset.
  constexpr auto kUnset = std::nullopt;
  auto nested = at::_nested_tensor_from_tensor_list(
      list,
      /*dtype=*/dtype,
      /*layout=*/kUnset,
      /*device=*/device,
      /*pin_memory=*/kUnset);
  return nested;
}
/// Nested to padded tensor
///
/// Returns a padded (dense) tensor built from the nested tensor `self` by
/// calling `at::nested_to_padded_tensor`. C++ counterpart of
/// `torch.nested.to_padded_tensor`; see
/// https://pytorch.org/docs/main/nested.html#torch.nested.to_padded_tensor
/// for the exact semantics of `padding` and `output_size`.
///
/// @param self the nested tensor to convert.
/// @param padding the padding value, forwarded unchanged.
/// @param output_size optional output size, forwarded unchanged.
/// @return the tensor produced by the ATen operator.
inline at::Tensor to_padded_tensor(
    const at::Tensor& self,
    double padding,
    at::OptionalIntArrayRef output_size = std::nullopt) {
  // Pure delegation to the ATen operator; no extra processing here.
  auto padded = at::nested_to_padded_tensor(
      self, /*padding=*/padding, /*output_size=*/output_size);
  return padded;
}
} // namespace torch::nested
|