File: utils.py

package: pytorch 1.13.1+dfsg-4 (area: main; suite: bookworm)
file content (68 lines, 2,093 bytes):
import torch

RPC_SPARSE = "rpc_sparse"
RPC_DENSE = "rpc_dense"


def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    A helper function creates a list containing the indices, values, and size
    of a coalesced sparse tensor.
    Args:
        sparse_tensor (torch.Tensor): sparse_coo_tensor represented as a list
    """
    sparse_tensor = sparse_tensor.coalesce()
    return [sparse_tensor.indices(), sparse_tensor.values(), sparse_tensor.size()]


def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    A helper function creates a sparse_coo_tensor from indices, values, and size.
    Args:
        sparse_rpc_format (list): sparse_coo_tensor represented as a list
    """
    return torch.sparse_coo_tensor(
        sparse_rpc_format[0], sparse_rpc_format[1], sparse_rpc_format[2]
    ).coalesce()
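

# Editor's sketch (not part of the original file): a round-trip check for the
# two helpers above. It builds a small sparse COO tensor, converts it to the
# RPC list format, reconstructs it, and verifies nothing was lost. The function
# name is hypothetical and exists only for illustration.
def _example_sparse_round_trip():
    indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    values = torch.tensor([3.0, 4.0, 5.0])
    original = torch.sparse_coo_tensor(indices, values, (2, 3))
    rpc_format = sparse_tensor_to_rpc_format(original)  # [indices, values, size]
    restored = sparse_rpc_format_to_tensor(rpc_format)
    assert torch.equal(restored.to_dense(), original.to_dense())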


def process_bucket_with_remote_server(state, bucket):
    r"""
    Processes a gradient bucket passed by a DDP communication hook
    during .backward(). The method supports processing sparse and dense
    tensors. It records RPC future completion time metric for the trainer.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    tensor = bucket.buffer()
    # Move the bucket off the GPU when CUDA-aware RPC is not enabled.
    if not cref.use_cuda_rpc:
        tensor = tensor.cpu()
    sparse = tensor.is_sparse
    if sparse:
        tensor = sparse_tensor_to_rpc_format(tensor)
    b_index = bucket.get_index()
    server_args = [
        cref.server_rref,
        state.batch_number,
        b_index,
        tensor
    ]
    key = state.get_key(b_index)
    # Start timing how long the server takes to return the averaged gradient.
    cref.record_start(
        "hook_future_metric",
        key,
        RPC_SPARSE if sparse else RPC_DENSE
    )
    fut = cref.server_rref.rpc_async().average_gradient(*server_args)

    def callback(fut):
        # The future has already completed by the time the callback runs.
        cref.record_end("hook_future_metric", key)
        tensor = fut.wait()
        # A list response is the RPC representation of a sparse tensor.
        if isinstance(tensor, list):
            tensor = sparse_rpc_format_to_tensor(tensor)
        # Move the averaged gradient back to the trainer's device.
        tensor = tensor.cuda(cref.rank)
        return [tensor]

    return fut.then(callback)
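

A minimal sketch of how this helper could back a DDP communication hook. The
module, rank, and hook_state names below are assumptions for illustration; the
hook state must expose the attributes process_bucket_with_remote_server expects
(cref with server_rref, use_cuda_rpc, rank, record_start/record_end, plus
batch_number and get_key), which are defined by the caller, not in this file.

    from torch.nn.parallel import DistributedDataParallel as DDP

    def remote_server_hook(state, bucket):
        # Delegate each gradient bucket to the remote server and return the
        # future holding the averaged gradient.
        return process_bucket_with_remote_server(state, bucket)

    # ddp_model = DDP(module, device_ids=[rank])
    # ddp_model.register_comm_hook(state=hook_state, hook=remote_server_hook)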