File: input_metadata.h

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (113 lines) | stat: -rw-r--r-- 2,982 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#pragma once

#include <ATen/ExpandUtils.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>

#include <functional>
#include <sstream>
#include <string>
#include <variant>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/zeros.h>
#endif

namespace torch::autograd {

// c10::SmallVector of symbolic sizes, sized inline for
// c10::kDimVectorStaticSize dimensions.
using SymIntSmallVec = c10::SmallVector<c10::SymInt, c10::kDimVectorStaticSize>;

// The recorded shape of an input: either an explicit list of (symbolic) sizes
// or a tensor. NOTE(review): the tensor alternative is presumably used for
// C++ nested tensors (see is_cpp_nested_tensor() / shape_as_tensor()) — confirm
// in input_metadata.cpp.
using MetadataShape = std::variant<SymIntSmallVec, at::Tensor>;

/**
 * Records TensorOptions (dtype, device, layout), shape of the tensor, whether
 * or not the Python dispatch key is set (tensor subclass), whether it is a
 * nested tensor, and, where applicable, the stream the corresponding operation
 * took place on.
 *
 * If was_default_constructed() is true, then the corresponding input is not
 * used and may be an undefined tensor. The default constructor leaves that
 * flag at its in-class initial value of `true`; NOTE(review): the other
 * constructors are defined out of line and are expected to clear it — confirm
 * in input_metadata.cpp.
 */
struct TORCH_API InputMetadata {
  InputMetadata() = default;
  InputMetadata(
      const at::TensorOptions& options,
      MetadataShape input_shape,
      bool is_tensor_subclass,
      bool is_nested);
  // Captures metadata from an existing tensor (defined out of line).
  InputMetadata(const at::Tensor& t);

  // The recorded TensorOptions; dtype(), device() and layout() read from it.
  const at::TensorOptions& options() const {
    return options_;
  }

  caffe2::TypeMeta dtype() const {
    return options_.dtype();
  }

  at::Device device() const {
    return options_.device();
  }

  at::Layout layout() const {
    return options_.layout();
  }

  // The recorded stream. Unless a constructor overrides it, this is the
  // default stream of device().
  c10::Stream stream() const {
    return stream_;
  }

  bool is_tensor_subclass() const {
    return is_tensor_subclass_;
  }

  // NOTE(review): presumably returns a zero-filled tensor matching the
  // recorded options and shape — confirm in input_metadata.cpp.
  at::Tensor zeros_like() const;

  // Shape checks of an incoming gradient against the recorded shape
  // (defined out of line).
  bool is_same_shape(const at::Tensor& grad) const;

  bool is_expandable_to_shape(const at::Tensor& grad) const;

  // Takes `grad` by non-const reference; the implementation may modify or
  // move from it, so callers should treat it as consumed.
  at::Tensor reduce_grad(at::Tensor& grad) const;

  // `index` identifies the input; `format_error` maps a message to the final
  // error text. NOTE(review): presumably reduces `grad` to the recorded shape
  // when possible and raises otherwise — confirm in input_metadata.cpp.
  at::Tensor maybe_reduce(
      const size_t index,
      at::Tensor grad,
      const std::function<std::string(const std::string&)>& format_error) const;

  // Builds a diagnostic describing why `grad` does not match input `index`.
  std::stringstream incompatible_shape_error_message(
      const size_t index,
      const at::Tensor& grad) const;

  // True when this object came from the default constructor (see the class
  // comment: the corresponding input is then unused).
  bool was_default_constructed() const {
    return was_default_constructed_;
  }

  bool is_cpp_nested_tensor() const;

  bool is_nested_tensor() const {
    return is_nested_;
  }

  // Non-owning view of the recorded sizes. NOTE(review): presumably only
  // meaningful while shape_ holds the SymIntSmallVec alternative.
  c10::SymIntArrayRef shape_as_dim_vector() const;

  // Danger: not thread safe, caller must protect with lock
  SymIntSmallVec& mutable_shape_as_dim_vector();

 private:
  at::Tensor shape_as_tensor() const;
  bool is_nestedness_same(const at::Tensor& grad) const;
  bool maybe_expandable_to(const at::Tensor& grad) const;

  // Declaration order matters: stream_'s initializer calls device(), which
  // reads options_, so options_ must stay declared (and thus initialized)
  // before stream_.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::TensorOptions options_;
  MetadataShape shape_;
  c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
  bool is_tensor_subclass_ = false;
  bool is_nested_ = false;
  // Stays true only for default-constructed metadata (see class comment).
  bool was_default_constructed_ = true;
};
} // namespace torch::autograd