1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
|
#pragma once
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/comm.hpp>
namespace c10d {
// Selector for the built-in C++ communication hooks.
// Every enumerator carries an explicit value, starting at 1 (no enumerator
// equals 0). The underlying type is spelled out; `int` is what a scoped
// enumeration uses when none is given, so the representation is unchanged.
// NOTE(review): presumably ALLREDUCE / FP16_COMPRESS select AllReduceCommHook /
// FP16CompressCommHook, and other code may rely on the numeric values --
// confirm before renumbering.
enum class BuiltinCommHookType : int {
  ALLREDUCE = 1,
  FP16_COMPRESS = 2,
};
// Built-in C++ communication hook whose state is a ProcessGroup handle.
// Only the declaration lives in this header; runHook() is defined elsewhere.
// NOTE(review): going by the name and by the remark on _AllReduceBySumCommHook
// in this file, runHook() presumably allreduces the bucket and performs the
// division inside the hook -- confirm against the implementation.
class AllReduceCommHook
    : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
  // Private shorthand for the templated base; used only to forward the
  // constructor argument.
  using HookBase = CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>;

 public:
  // Forwards `state` (the process group handle) to the base-class constructor.
  // This class declares no data members of its own.
  explicit AllReduceCommHook(const c10::intrusive_ptr<ProcessGroup>& state)
      : HookBase(state) {}

  ~AllReduceCommHook() override = default;

  // Hook entry point overriding the base-class declaration. Takes the bucket
  // by mutable reference and returns an intrusive_ptr to a
  // c10::ivalue::Future.
  // NOTE(review): what the future resolves to is not visible from this header.
  c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
};
// Built-in C++ communication hook whose state is a ProcessGroup handle.
// Only the declaration lives in this header; runHook() is defined elsewhere.
// NOTE(review): the name suggests gradients are compressed to FP16 around the
// communication step -- the header does not show this; confirm against the
// implementation.
class FP16CompressCommHook
    : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
  // Private shorthand for the templated base; used only to forward the
  // constructor argument.
  using HookBase = CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>;

 public:
  // Forwards `state` (the process group handle) to the base-class constructor.
  // This class declares no data members of its own.
  explicit FP16CompressCommHook(const c10::intrusive_ptr<ProcessGroup>& state)
      : HookBase(state) {}

  ~FP16CompressCommHook() override = default;

  // Hook entry point overriding the base-class declaration. Takes the bucket
  // by mutable reference and returns an intrusive_ptr to a
  // c10::ivalue::Future.
  // NOTE(review): what the future resolves to is not visible from this header.
  c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
};
// Nearly identical to AllReduceCommHook, except that no division happens
// inside the hook. Leaving the division out lets the copy and the division be
// fused, which saves one pass over all the input parameters when the user has
// not provided a communication hook.
// Internal use only: this is not released as a public built-in communication
// hook.
// NOTE(review): an identifier starting with an underscore followed by an
// uppercase letter is reserved for the implementation in C++. Renaming would
// affect code outside this header, so the name is kept as is.
class _AllReduceBySumCommHook
    : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
  // Private shorthand for the templated base; used only to forward the
  // constructor argument.
  using HookBase = CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>;

 public:
  // Forwards `state` (the process group handle) to the base-class constructor.
  // This class declares no data members of its own.
  explicit _AllReduceBySumCommHook(
      const c10::intrusive_ptr<ProcessGroup>& state)
      : HookBase(state) {}

  ~_AllReduceBySumCommHook() override = default;

  // Hook entry point overriding the base-class declaration. Takes the bucket
  // by mutable reference and returns an intrusive_ptr to a
  // c10::ivalue::Future.
  // NOTE(review): what the future resolves to is not visible from this header.
  c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
};
} // namespace c10d
|