File: saved_variable.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (116 lines) | stat: -rw-r--r-- 4,451 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>

#include <ATen/core/Tensor.h>

#include <cstdint>
#include <memory>

namespace torch {
namespace autograd {

// Within autograd, `Variable` is simply another name for `at::Tensor`.
using Variable = at::Tensor;
// Forward declaration only: this header refers to `Node` exclusively through
// `std::shared_ptr<Node>` / `std::weak_ptr<Node>`.
struct Node;

// Exported C-string constant; it is only declared here and defined elsewhere.
// NOTE(review): presumably the error message reported when saved data is
// accessed after it has already been released by a previous backward pass --
// confirm against the definition.
TORCH_API extern const char* ERR_BACKWARD_TWICE;

/// A snapshot of a variable at a certain version. A `SavedVariable` stores
/// enough information to reconstruct a variable from a certain point in time.
///
/// Only the move constructor and move assignment operator are declared (both
/// defaulted) and no copy operations are declared, so instances are move-only.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_API SavedVariable {
 public:
  /// Creates an empty `SavedVariable`: every member keeps its in-class default
  /// (in particular `was_default_constructed_` stays `true`).
  SavedVariable() = default;
  /// Saves `variable`.
  /// NOTE(review): `is_output` presumably states that `variable` is an output
  /// of the function that is saving it, and `is_inplace_on_view` that the
  /// saving operation is an in-place operation on a view -- confirm against
  /// the definition in the corresponding .cpp file.
  SavedVariable(
      const Variable& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  /// Same as the constructor above, but the variable to save is optional.
  /// NOTE(review): an empty optional presumably corresponds to the "value to
  /// save is None" case described on `data_` below -- confirm.
  SavedVariable(
      const c10::optional<Variable>& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(SavedVariable&&) = default;
  SavedVariable& operator=(SavedVariable&&) = default;
  /// Clears the forward AD gradients held through `fw_grad_`, if any, before
  /// the members are destroyed.
  ~SavedVariable() {
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }

  /// Reconstructs the saved variable. Pass `saved_for` as the gradient
  /// function if constructing the `SavedVariable` with it would have caused a
  /// circular reference.
  /// When the `SavedVariable` was created with `is_inplace_on_view = true`,
  /// the passed-in gradient function is unused (see `weak_grad_fn_` below).
  Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;

  /// Registers a pack_hook/unpack_hook pair on this `SavedVariable` and takes
  /// ownership of it. As described on `hooks_` below, pack_hook is called upon
  /// registration and unpack_hook when unpacking.
  void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);

  /// Releases the saved variable; typically called during the backward pass
  /// (see case 2 in the comment on `data_` below).
  void reset_data();

 private:
  // This field contains either:
  // 1. the variable to save
  // 2. or its tensor_data.
  // If storing the variable itself would create a circular reference,
  // we fall into the second case and its metadata is also saved separately.
  // In that case, the grad_fn must be passed in to the unpack function when
  // reconstructing the Variable (except when we are doing an inplace operation
  // on a view, see below). The field saved_original_ below reflects the two
  // cases: its value is true in the first case and false in the second case.
  // The value data_.defined() can be false in three cases:
  // 1. SavedVariable was constructed without a Tensor (the value to save is
  // None), in that case was_default_constructed_ will be kept at true
  // 2. The saved variable has been released by calling
  // SavedVariable::reset_data(), typically during the backward pass
  // 3. Hooks have been registered. In that case, hooks_ will be defined
  // instead. Note that the value of saved_original_ only reflects what happened
  // during the construction of the SavedVariable. If saved_original_ is true,
  // we saved the original tensor in data_, but if the user registers hooks, we
  // will no longer have it (despite the saved_original_ still being true)
  at::Tensor data_;

  // This field is used to store the forward AD gradients associated with
  // the saved Tensor. Note that this shared_ptr must never be shared with
  // either the saved Tensor or the unpacked Tensor. See note [ Using
  // ForwardGrad ]
  std::shared_ptr<ForwardGrad> fw_grad_;

  // Weak version of grad_fn_ that prevents leaks in rebase_history() for
  // inplace views.
  // This variable is used when the user chooses to create a SavedVariable with
  // is_inplace_on_view = true.
  // In that case, the grad_fn passed in to the unpack function at unwrapping
  // time is unused.
  std::weak_ptr<Node> weak_grad_fn_;
  // NOTE(review): presumably the version counter of the saved variable, which
  // together with saved_version_ (the version recorded at save time) lets
  // unpack() detect in-place modifications made after saving -- confirm
  // against the .cpp file.
  c10::VariableVersion version_counter_;

  uint32_t saved_version_ = 0;
  // NOTE(review): presumably the output index of the saved variable within
  // its gradient function -- confirm.
  uint32_t output_nr_ = 0;
  // Stays true when no Tensor was given at construction (see the comment on
  // data_ above).
  bool was_default_constructed_ = true;
  // NOTE(review): is_inplace_on_view_ and is_output_ presumably mirror the
  // constructor arguments of the same names -- confirm.
  bool is_inplace_on_view_ = false;
  // True if the original variable was stored in data_ at construction, false
  // if its tensor_data was stored instead (see the comment on data_ above).
  bool saved_original_ = false;
  bool is_leaf_ = false;
  bool is_output_ = false;

  // Hooks are a pair of functions pack_hook/unpack_hook that provides
  // fine-grained control over how the SavedVariable should save its data.
  // pack_hook is called upon registration, while unpack_hook is called when
  // unpacking.
  std::unique_ptr<SavedVariableHooks> hooks_;
  // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if
  // hooks are defined. They are set before pack_hook is called and used after
  // unpack_hook is called.
  std::shared_ptr<Node> grad_fn_;
  std::weak_ptr<Node> grad_accumulator_;
  bool requires_grad_ = false;

  // Private helpers, defined out of line.
  // NOTE(review): judging by the names and the member comments above,
  // save_metadata() records the metadata of `data` that is "saved separately",
  // get_default_hooks() returns the hooks to use when none were registered
  // explicitly, and set_hooks_and_pack_data() stores `hooks` and runs its
  // pack_hook on `data` -- confirm against the definitions.
  void save_metadata(const Variable& data);
  static std::unique_ptr<SavedVariableHooks> get_default_hooks();
  void set_hooks_and_pack_data(
      std::unique_ptr<SavedVariableHooks>&& hooks,
      const Variable& data);
};
} // namespace autograd
} // namespace torch