File: hooks.py

from utils import process_bucket_with_remote_server

import torch
import torch.distributed as c10d


def allreduce_hook(state, bucket):
    r"""
    A DDP communication hook that averages gradients with the process group's allreduce implementation.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    tensor = bucket.buffer()
    # Pre-divide by the world size so the summed allreduce result is an average.
    tensors = [tensor / state.process_group.size()]
    key = state.get_key(bucket.get_index())
    if tensor.is_sparse:
        tensor = tensor.coalesce()
    tensor_type = "sparse" if tensor.is_sparse else "dense"
    # Time the asynchronous allreduce; record_end is called from the callback
    # once the future completes.
    cref.record_start(
        "hook_future_metric", key, f"{cref.backend}_{tensor_type}_allreduce"
    )
    fut = state.process_group.allreduce(tensors).get_future()

    def callback(fut):
        cref.record_end("hook_future_metric", key)
        return fut.wait()

    return fut.then(callback)
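

# Illustrative sketch (not part of the upstream benchmark): a hook with the
# signature above is attached to a DistributedDataParallel model through
# ``register_comm_hook``. ``net``, ``trainer_state`` and ``rank`` below are
# hypothetical placeholders for the benchmark's model, state object and
# device rank.
def _example_register_allreduce_hook(net, trainer_state, rank):
    model = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank])
    model.register_comm_hook(trainer_state, allreduce_hook)
    return model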


def hybrid_hook(state, bucket):
    r"""
    A DDP communication hook that uses the default (Gloo) process
    group for sparse gradients and a non-default NCCL process
    group for dense gradients.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    tensor = bucket.buffer()
    key = state.get_key(bucket.get_index())

    if tensor.is_sparse:
        # Sparse gradients: blocking allreduce on the default (Gloo) group,
        # then wrap the averaged tensor in an already-completed future.
        cref.record_start("hook_c10d_metric", key, "gloo_sparse_allreduce")
        tensor = tensor.coalesce()
        tensor = tensor / state.process_group.size()
        c10d.all_reduce(tensor, op=c10d.ReduceOp.SUM)
        cref.record_end("hook_c10d_metric", key)
        fut = torch.futures.Future()
        fut.set_result([tensor])
    else:
        # Dense gradients: asynchronous allreduce on the non-default NCCL group.
        cref.record_start("hook_future_metric", key, "nccl_dense_allreduce")
        tensors = [bucket.buffer() / state.process_group.size()]
        fut = state.process_group.allreduce(tensors).get_future()

        def callback(fut):
            cref.record_end("hook_future_metric", key)
            return fut.wait()

        fut = fut.then(callback)
    return fut
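

# Illustrative sketch (not part of the upstream benchmark): hybrid_hook assumes
# that the default process group is Gloo (used implicitly by c10d.all_reduce in
# the sparse branch) and that ``state.process_group`` is a separate NCCL group
# for the dense branch. A setup along those lines could look like this; how the
# benchmark actually builds its state object is outside this file.
def _example_hybrid_process_groups(rank, world_size):
    # Requires the usual rendezvous environment variables (MASTER_ADDR, etc.).
    c10d.init_process_group("gloo", rank=rank, world_size=world_size)
    # Non-default NCCL group that the state would expose as ``state.process_group``.
    nccl_group = c10d.new_group(backend="nccl")
    return nccl_group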


def rpc_hook(state, bucket):
    r"""
    A DDP communication hook that averages sparse and dense tensors using
    the process_bucket_with_remote_server method.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    return process_bucket_with_remote_server(state, bucket)


def sparse_rpc_hook(state, bucket):
    r"""
    A DDP communication hook that uses the current backend's allreduce
    implementation for dense tensors and a remote server for sparse tensors.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    tensor = bucket.buffer()
    if tensor.is_sparse:
        # Sparse gradients are shipped to the remote parameter server.
        return process_bucket_with_remote_server(state, bucket)
    else:
        cref = state.cref
        # Pre-divide by the world size so the summed allreduce result is an average.
        tensors = [tensor / state.process_group.size()]
        key = state.get_key(bucket.get_index())
        cref.record_start("hook_future_metric", key, f"{cref.backend}_dense_allreduce")
        fut = state.process_group.allreduce(tensors).get_future()

        def callback(fut):
            cref.record_end("hook_future_metric", key)
            return fut.wait()

        return fut.then(callback)
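

# Illustrative sketch (not part of the upstream benchmark): every hook above
# shares the same contract -- it takes the trainer state and a GradBucket and
# returns a torch.futures.Future that completes with the reduced bucket
# contents. A simple name-to-hook lookup, as a benchmark driver might use to
# pick a hook, is shown below purely as an example.
_EXAMPLE_HOOK_MAP = {
    "allreduce": allreduce_hook,
    "hybrid": hybrid_hook,
    "rpc": rpc_hook,
    "sparse_rpc": sparse_rpc_hook,
}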