File: container.h

Package: pytorch 1.13.1+dfsg-4 (Debian bookworm)
#pragma once

#include <atomic>
#include <cstdint>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <torch/csrc/distributed/autograd/context/context.h>

namespace torch {
namespace distributed {
namespace autograd {

// Singleton class per worker that is responsible for storing the distributed
// autograd context for each autograd pass and for cleaning up the data for an
// autograd pass once it is done.
//
// Each autograd pass is assigned a unique autograd_context_id and all data for
// that pass (DistAutogradContext) is stored in this container indexed by the
// autograd_context_id. The autograd_context_id itself is a 64-bit globally
// unique id: the first 16 bits are the worker_id and the remaining 48 bits
// are an auto-incrementing id for each worker.
//
// This container is also responsible for maintaining a globally unique message
// id, which is used to associate send/recv autograd function pairs. The format
// is the same as for the autograd_context_id: a 64-bit integer whose first
// 16 bits are the worker id and whose remaining 48 bits auto-increment.
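//
// Illustrative sketch of the id layout described above; kAutoIncrementBits
// and the variable names here are hypothetical, not part of this header:
//
//   constexpr int kAutoIncrementBits = 48;
//   // The worker_id occupies the top 16 bits; the low 48 bits auto-increment.
//   int64_t first_id = static_cast<int64_t>(worker_id) << kAutoIncrementBits;
//   // Recover the worker that generated a given id:
//   auto owner_worker = static_cast<int16_t>(context_id >> kAutoIncrementBits);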
class TORCH_API DistAutogradContainer {
 public:
  explicit DistAutogradContainer(uint32_t num_shards);

  // One-time initialization of the container.
  static DistAutogradContainer& init(int64_t worker_id);

  // Retrieves the singleton instance of the container; ensures the container
  // has been initialized.
  static DistAutogradContainer& getInstance();

  // Create a new context for a distributed autograd pass.
  const ContextPtr newContext();

  // Clean up resources for a given context_id once the autograd pass is done.
  // Sends an RPC to the other workers this worker knows about, telling them
  // to clean up their context as well. Throws an exception if the context_id
  // does not exist.
  void releaseContext(int64_t context_id);
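
  // A minimal usage sketch of the lifecycle above, assuming a
  // DistAutogradContext::contextId() accessor on the returned context:
  //
  //   DistAutogradContainer::init(/*worker_id=*/0);
  //   auto& container = DistAutogradContainer::getInstance();
  //   ContextPtr ctx = container.newContext();
  //   // ... run the distributed forward/backward pass ...
  //   container.releaseContext(ctx->contextId());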

  // Releases an autograd context if it is present on this node. Also sends an
  // RPC to the other workers this worker knows about, telling them to clean up
  // their context. Does nothing if the context is not present on this node.
  void releaseContextIfPresent(int64_t context_id);

  // Checks that the passed-in context_id corresponds to a valid context;
  // throws if it does not.
  void isValidContext(int64_t context_id);

  // Retrieve the autograd context for a given context_id.
  ContextPtr retrieveContext(int64_t context_id);

  // Retrieves the currently active autograd context for the current thread.
  ContextPtr currentContext();

  // Checks whether or not the current thread has a valid autograd context.
  bool hasValidContext() const;
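
  // Sketch of a guarded read of the thread local context at a hypothetical
  // call site:
  //
  //   auto& container = DistAutogradContainer::getInstance();
  //   if (container.hasValidContext()) {
  //     ContextPtr ctx = container.currentContext();
  //     // ... record send/recv autograd functions on ctx ...
  //   }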

  // Generate a new autograd_message_id for send/recv autograd functions.
  int64_t newAutogradMessageId();

  // Creates a new autograd context with the provided context_id. If a context
  // already exists with the provided context_id, we just return it.
  // This does not set the current context for the current thread.
  ContextPtr getOrCreateContext(int64_t context_id);

  // Retrieves the maximum possible autograd_context_id/autograd_message_id that
  // can be generated by this worker.
  int64_t getMaxId();
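
  // Given the 16/48-bit split above, the maximum id for worker w plausibly
  // looks like the following (hypothetical expression, not part of this API):
  //
  //   int64_t max_id = (static_cast<int64_t>(w) << 48) | ((1LL << 48) - 1);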

  // Retrieves the worker ID for this node.
  rpc::worker_id_t getWorkerId() const;

  // Sets the thread local current context id if there is no valid context yet.
  static void setCurrentContextId(int64_t contextId);

  // Forcibly sets the thread local current context id. Should only be used in
  // cases where you know what you're doing and need to override the thread
  // local. Otherwise, use setCurrentContextId instead.
  static void forceCurrentContextId(int64_t contextId);

  // Clears the thread local current context id.
  void clearCurrentContext();

  // Returns the number of autograd contexts in the container.
  size_t numAutogradContexts() const;

  // Returns the current context id for this thread.
  static int64_t currentContextId();

  DistAutogradContainer(const DistAutogradContainer&) = delete;
  DistAutogradContainer& operator=(const DistAutogradContainer&) = delete;
  DistAutogradContainer(DistAutogradContainer&&) = delete;
  DistAutogradContainer& operator=(DistAutogradContainer&&) = delete;

 private:
  // Default number of shards for the map storing autograd contexts. We'd like
  // this to be a power of 2, and we don't expect a value much higher than the
  // number of cores to provide much benefit.
  static constexpr uint32_t kNumDefaultShards = 128;

  // Use cache line size for alignment.
  static constexpr int kCacheLineSize = 64;

  // Structure holding one shard of the sharded autograd context map with its
  // associated lock. Align to cache line size to avoid contention between
  // adjacent entries.
  struct alignas(kCacheLineSize) ContextsShard {
    // Lock for this shard.
    mutable std::mutex lock;

    // Map storing autograd contexts for this shard.
    std::unordered_map<int64_t, ContextPtr> contexts;
  };

  DistAutogradContainer();
  ~DistAutogradContainer() = default;

  static DistAutogradContainer& getInstanceInternal();

  // Retrieves the shard for a given context_id.
  ContextsShard& getShard(int64_t context_id);
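
  // A plausible sketch of the lookup, assuming num_shards_ is a power of 2 so
  // that masking is equivalent to a modulo:
  //
  //   ContextsShard& getShard(int64_t context_id) {
  //     return autograd_contexts_[context_id & (num_shards_ - 1)];
  //   }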

  // Sends an RPC to the workers that have a context corresponding to the
  // passed-in context_id. This function should be called with the lock held.
  void sendReleaseContextRpc(
      const std::unordered_set<rpc::worker_id_t>& workerIds,
      int64_t context_id);

  // Erases context_id from the autograd context map, and resets the thread
  // local current context id if it corresponds to the passed-in context_id.
  // This function should be called with the lock held.
  void eraseContextIdAndReset(ContextsShard& shard, int64_t context_id);

  // Compute the number of shards for the autograd_contexts_ map.
  static uint32_t computeNumShards();

  // Auto-incrementing context id used to identify unique autograd passes.
  // Initialized with the first 16 bits being the worker_id.
  std::atomic<int64_t> next_context_id_;

  // Unique id to identify a worker in the distributed setting.
  int16_t worker_id_;

  // Whether or not the container has been initialized appropriately.
  bool initialized_;

  // Sharded autograd context map.
  std::vector<ContextsShard> autograd_contexts_;

  // Number of shards for the sharded autograd_contexts_ map.
  uint32_t num_shards_;

  // Autograd message id to identify unique send/recv autograd function pairs.
  std::atomic<int64_t> next_autograd_message_id_;

  // Maximum allowed value for autograd_context_id or autograd_message_id.
  int64_t max_id_;
};

} // namespace autograd
} // namespace distributed
} // namespace torch