File: utils.py

import torch


RPC_SPARSE = "rpc_sparse"
RPC_DENSE = "rpc_dense"


def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    A helper function creates a list containing the indices, values, and size
    of a coalesced sparse tensor.
    Args:
        sparse_tensor (torch.Tensor): sparse_coo_tensor represented as a list
    """
    sparse_tensor = sparse_tensor.coalesce()
    return [sparse_tensor.indices(), sparse_tensor.values(), sparse_tensor.size()]
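

# A minimal, illustrative sketch (not part of the upstream module): it builds a
# tiny sparse COO tensor with made-up values and shows the [indices, values,
# size] list produced by sparse_tensor_to_rpc_format. The helper name
# `_example_sparse_to_rpc_format` is hypothetical.
def _example_sparse_to_rpc_format():
    indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    values = torch.tensor([3.0, 4.0, 5.0])
    sparse = torch.sparse_coo_tensor(indices, values, (2, 3))
    rpc_list = sparse_tensor_to_rpc_format(sparse)
    # rpc_list[0] is the indices tensor, rpc_list[1] the values tensor,
    # and rpc_list[2] the size torch.Size([2, 3]).
    return rpc_list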


def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    A helper function creates a sparse_coo_tensor from indices, values, and size.
    Args:
        sparse_rpc_format (list): sparse_coo_tensor represented as a list
    """
    return torch.sparse_coo_tensor(
        sparse_rpc_format[0], sparse_rpc_format[1], sparse_rpc_format[2]
    ).coalesce()
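

# Illustrative round-trip sketch (an assumption for demonstration, not part of
# the upstream module): converting a sparse tensor to the RPC list format and
# back should preserve its dense contents. `_example_rpc_round_trip` is a
# hypothetical helper name.
def _example_rpc_round_trip():
    sparse = torch.sparse_coo_tensor(
        torch.tensor([[0, 1], [2, 0]]), torch.tensor([3.0, 4.0]), (2, 3)
    )
    restored = sparse_rpc_format_to_tensor(sparse_tensor_to_rpc_format(sparse))
    # The reconstructed tensor matches the original once both are densified.
    assert torch.equal(restored.to_dense(), sparse.to_dense())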


def process_bucket_with_remote_server(state, bucket):
    r"""
    Processes a gradient bucket passed by a DDP communication hook
    during .backward(). The method supports processing sparse and dense
    tensors. It records RPC future completion time metric for the trainer.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    tensor = bucket.buffer()
    if not cref.use_cuda_rpc:
        tensor = tensor.cpu()
    sparse = tensor.is_sparse
    if sparse:
        tensor = sparse_tensor_to_rpc_format(tensor)
    b_index = bucket.get_index()
    server_args = [cref.server_rref, state.batch_number, b_index, tensor]
    key = state.get_key(b_index)
    cref.record_start("hook_future_metric", key, RPC_SPARSE if sparse else RPC_DENSE)
    fut = cref.server_rref.rpc_async().average_gradient(*server_args)

    def callback(fut):
        cref.record_end("hook_future_metric", key)
        tensor = fut.wait()
        if type(tensor) is list:
            tensor = sparse_rpc_format_to_tensor(tensor)
        tensor = tensor.cuda(cref.rank)
        return [tensor]

    return fut.then(callback)
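

# Hedged usage sketch (not part of the upstream module):
# process_bucket_with_remote_server matches the DDP communication hook
# signature (state, bucket) -> Future, so a trainer could register it on a
# DistributedDataParallel model. `ddp_model` and `trainer_state` are
# hypothetical arguments; the benchmark expects the state object to expose
# `cref`, `batch_number`, and `get_key`.
def _example_register_hook(ddp_model, trainer_state):
    ddp_model.register_comm_hook(trainer_state, process_bucket_with_remote_server)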