File: comm.hpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (129 lines) | stat: -rw-r--r-- 3,880 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#pragma once

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#include <utility>

namespace c10d {

// Broadcasts many tensors to all processes in the process group.
//
// Only the declaration is visible in this header; the definition lives
// elsewhere.
//
//   process_group: the process group whose processes receive the tensors.
//   tensors:       the tensors to broadcast.
//   buffer_size:   NOTE(review): presumably an upper bound on the size of each
//                  coalesced buffer (units not visible here) -- confirm
//                  against the definition.
//   rank:          defaults to 0. NOTE(review): presumably the rank whose
//                  tensor values are broadcast to the others -- confirm
//                  against the definition.
TORCH_API void broadcast_coalesced(
    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
    at::TensorList tensors,
    size_t buffer_size,
    int rank = 0);

// This class passes bucket contents tensor to DDP communication hook.
class TORCH_API GradBucket {
 public:
  explicit GradBucket(
      size_t index,
      size_t bucket_count,
      const at::Tensor& tensor,
      const std::vector<size_t>& offsets,
      const std::vector<size_t>& lengths,
      const std::vector<c10::IntArrayRef>& sizes_vec,
      const std::vector<at::Tensor>& parameters)
      : index_(index),
        bucket_count_(bucket_count),
        buffer_(tensor),
        offsets_(offsets),
        lengths_(lengths),
        sizes_vec_(sizes_vec),
        parameters_(parameters) {}

  // Returns the index of the bucket, which is unique across all the buckets.
  size_t getIndex() const {
    return index_;
  }

  const at::Tensor& getBuffer() const {
    return buffer_;
  }

  // Returns a mutable buffer compared with the above method.
  at::Tensor& getBufferRef() {
    return buffer_;
  }

  // Overwrites the buffer at a specific index.
  void setBuffer(at::Tensor& buffer) {
    buffer_ = buffer;
  }

  // Each tensor in the list that getGradients corresponds to a
  // parameter.
  std::vector<at::Tensor> getGradients() const;

  // Returns model parameters belonging to this bucket. They are returned in the
  // same order as gradient tensors via getGradients(). For example,
  // getParameters[i] will have its gradient stored in
  // getGradients[i]
  const std::vector<at::Tensor> getParameters() const {
    return parameters_;
  }

  // Returns whther this bucket is the last bucket to allreduce in an iteration.
  bool isLast() const {
    return index_ == bucket_count_ - 1;
  }

 private:
  size_t index_;
  size_t bucket_count_;
  at::Tensor buffer_;

  // Per-variable info in buffer_.
  std::vector<size_t> offsets_;
  std::vector<size_t> lengths_;
  std::vector<c10::IntArrayRef> sizes_vec_;
  // Model parameters for this bucket.
  const std::vector<at::Tensor> parameters_;
};

// Base class of both `PythonCommHook` and `CppCommHook`.
// Implementations must provide 1) a `runHook` method that communicates
// gradients asynchronously, and 2) a `parseHookResult` method that converts
// the hook result into a tensor.
class TORCH_API CommHookInterface {
 public:
  // Virtual so that a derived hook can be destroyed through a pointer to this
  // base class.
  virtual ~CommHookInterface() = default;

  // Passes the input grad bucket to the registered communication hook.
  // Once the tensors in the bucket are ready, kicks off the hook
  // asynchronously and returns a future that holds the communication results.
  virtual c10::intrusive_ptr<c10::ivalue::Future> runHook(
      GradBucket& bucket) = 0;

  // Returns the resulting tensor once the communication hook result is
  // ready. The resulting tensor will then be copied to the grads of
  // individual parameters.
  //   result: the value held by the future that runHook returned.
  virtual at::Tensor parseHookResult(
      const c10::IValue& result) = 0;
};

namespace detail {
// Converts the result of a C++ communication hook into a tensor.
// This helper function is called both by CppCommHookInterface below and inside
// reducer. Only the declaration is visible here; it is defined elsewhere.
//
// NOTE(review): unlike the other declarations in this header, this one is not
// marked TORCH_API, yet it is called from the header-defined template
// CppCommHookInterface<T>::parseHookResult. Confirm whether the symbol needs
// to be exported for instantiations outside this library.
 at::Tensor parseCppCommHookResult(const c10::IValue& result);
} // namespace detail

// This CppCommHook interface only requires implementing the runHook method,
// which may make use of a state of type T.
//
// The class itself stays abstract: runHook, which is pure virtual in
// CommHookInterface, is not overridden here.
template <typename T>
class CppCommHookInterface : public CommHookInterface {
 public:
  // Stores `state` for use by subclasses.
  // The state is taken by value and moved into place, so callers passing an
  // rvalue avoid a copy while callers passing an lvalue still pay exactly one.
  explicit CppCommHookInterface(T state) : state_(std::move(state)) {}

  ~CppCommHookInterface() override = default;

  // Converts the hook result into a tensor by delegating to
  // detail::parseCppCommHookResult.
  at::Tensor parseHookResult(const c10::IValue& result) override {
    return detail::parseCppCommHookResult(result);
  }

 protected:
  // Hook state, accessible to subclasses (e.g. from their runHook).
  T state_;
};

} // namespace c10d