File: _replicated_tensor_ddp_interop.py

Package: pytorch 1.13.1+dfsg-4 (Debian bookworm, area: main)
file content (46 lines)
import torch
from torch.distributed._shard.replicated_tensor import ReplicatedTensor

class ReplicatedTensorFunction(torch.autograd.Function):
    """
    Autograd function to ensure gradients are replicated between the
    replicated tensor and the original one.
    """
    @staticmethod
    def forward(ctx, inp, process_group=None):
        # set_materialize_grads(False) will ensure that None gradients stay as
        # None and are not filled with zeros.
        ctx.set_materialize_grads(False)
        return ReplicatedTensor(inp, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

def _make_replicated_tensor(tensor, process_group):
    # Wrap ``tensor`` in a ReplicatedTensor through the autograd function above
    # so gradients flow back to the original tensor, and carry over any
    # existing ``.grad`` so the replica exposes the same gradient state.
    replicated_tensor = ReplicatedTensorFunction.apply(tensor, process_group)
    replicated_tensor.grad = tensor.grad
    return replicated_tensor

def _replicate_module_recurse(module, process_group):
    # Build a data-parallel replica of ``module`` whose parameters are
    # ReplicatedTensors; buffers are shared as-is and child modules are
    # replicated recursively.
    replica = module._replicate_for_data_parallel()
    for param_name, param in module._parameters.items():
        if param is not None:
            setattr(replica, param_name, _make_replicated_tensor(param, process_group))
        else:
            setattr(replica, param_name, param)

    for buffer_name, buffer in module._buffers.items():
        setattr(replica, buffer_name, buffer)

    for module_name, child in module._modules.items():
        setattr(replica, module_name, _replicate_module_recurse(child, process_group))
    return replica

def _replicate_module(network, process_group):
    from torch.nn.parallel.replicate import _replicatable_module  # type: ignore[attr-defined]
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "children of ScriptModule")

    return _replicate_module_recurse(network, process_group)
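
For context, the sketch below (not part of the upstream file) shows one way these helpers might be exercised in a single process. The gloo backend, the MASTER_ADDR/MASTER_PORT defaults, and the _demo wrapper are illustrative assumptions; process_group is passed as None so ReplicatedTensor falls back to the default group, as the forward signature above suggests.

# Illustrative usage sketch only -- not part of the upstream module.
import os

import torch
import torch.distributed as dist
import torch.nn as nn


def _demo():
    # Minimal single-process process group so ReplicatedTensor has a group to use.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    net = nn.Linear(4, 2)
    replica = _replicate_module(net, None)

    # The replica's parameters are ReplicatedTensors built from the original
    # parameters, so a backward pass through the replica should populate the
    # gradients of ``net`` itself.
    replica(torch.randn(3, 4)).sum().backward()
    print(type(replica.weight).__name__)   # ReplicatedTensor (expected)
    print(net.weight.grad is not None)     # True (expected)

    dist.destroy_process_group()


if __name__ == "__main__":
    _demo()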