File: nested.h

#pragma once

#include <ATen/ATen.h>
#include <ATen/core/ATen_fwd.h>
#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
#include <algorithm>

namespace torch::nested {

/// Nested tensor
///
/// See
/// https://pytorch.org/docs/main/nested.html#torch.nested.nested_tensor
///
/// A minimal usage sketch (tensor shapes and values are illustrative, not
/// from this header):
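/// ```
/// std::vector<at::Tensor> tensors = {torch::randn({2, 5}), torch::randn({3, 5})};
/// auto nt = torch::nested::nested_tensor(tensors);
/// ```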
// The Python binding is implemented on Python objects so that
// torch.nested.nested_tensor can be constructed from arbitrarily nested
// Python objects - for now, only Python lists and lists of Tensors.
// See torch/csrc/autograd/python_nested_functions_manual.cpp for the Python
// implementation; the C++ implementation follows here.
inline at::Tensor nested_tensor(
    at::TensorList nested_tensor_data,
    const at::TensorOptions& options = {}) {
  auto out = at::_nested_tensor_from_tensor_list(
      nested_tensor_data,
      c10::typeMetaToScalarType(options.dtype()),
      std::nullopt,
      options.device(),
      options.pinned_memory());
  if (options.has_requires_grad() && options.requires_grad()) {
    out.requires_grad_(true);
  }
  return out;
}

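/// Overload taking nested initializer lists of scalars, mirroring the Python
/// form `torch.nested.nested_tensor([[1, 2], [3, 4, 5]])`. A minimal sketch
/// (values are illustrative):
/// ```
/// auto nt = torch::nested::nested_tensor({{1, 2}, {3, 4, 5}});
/// ```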
inline at::Tensor nested_tensor(
    at::ArrayRef<detail::TensorDataContainer> nested_tensor_data,
    const at::TensorOptions& options = {}) {
  for (const auto& tdc : nested_tensor_data) {
    TORCH_CHECK(
        tdc.is_init_list(),
        "nested_tensor() not implemented for these parameters");
  }
  // Construct a TensorList using nested_tensor_data
  std::vector<at::Tensor> tensor_list(nested_tensor_data.size());
  std::transform(
      nested_tensor_data.begin(),
      nested_tensor_data.end(),
      tensor_list.begin(),
      [&](const detail::TensorDataContainer& tdc) {
        return tdc.convert_to_tensor(options);
      });
  auto out = at::_nested_tensor_from_tensor_list(
      tensor_list,
      c10::typeMetaToScalarType(options.dtype()),
      std::nullopt,
      options.device(),
      options.pinned_memory());
  if (options.has_requires_grad() && options.requires_grad()) {
    out.requires_grad_(true);
  }
  return out;
}

/// As Nested Tensor
///
/// See
/// https://pytorch.org/docs/main/nested.html#torch.nested.as_nested_tensor
///
/// A minimal usage sketch (tensor shapes are illustrative):
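/// ```
/// std::vector<at::Tensor> tensors = {torch::randn({2, 5}), torch::randn({3, 5})};
/// auto nt = torch::nested::as_nested_tensor(tensors);
/// ```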
inline at::Tensor as_nested_tensor(
    at::TensorList list,
    std::optional<at::ScalarType> dtype = std::nullopt,
    std::optional<at::Device> device = std::nullopt) {
  return at::_nested_tensor_from_tensor_list(
      list, dtype, std::nullopt, device, std::nullopt);
}

/// Nested to padded tensor
///
/// See
/// https://pytorch.org/docs/main/nested.html#torch.nested.to_padded_tensor
///
/// A minimal usage sketch (shapes and the padding value are illustrative):
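/// ```
/// std::vector<at::Tensor> tensors = {torch::randn({2, 3}), torch::randn({4, 3})};
/// auto nt = torch::nested::nested_tensor(tensors);
/// // Pads the ragged entries with 0.0 into a dense (2, 4, 3) tensor.
/// auto padded = torch::nested::to_padded_tensor(nt, /*padding=*/0.0);
/// ```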
inline at::Tensor to_padded_tensor(
    const at::Tensor& self,
    double padding,
    at::OptionalIntArrayRef output_size = std::nullopt) {
  return at::nested_to_padded_tensor(self, padding, output_size);
}

} // namespace torch::nested