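"""Autograd-aware communication primitives for torch.nn.parallel: Broadcast,
ReduceAddCoalesced, Gather, and Scatter wrap the collectives in
torch.nn.parallel.comm so that cross-device copies take part in autograd."""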
import warnings
from typing import List, Optional

import torch
from torch._utils import _get_device_index
from torch.autograd import Function
from torch.nn.parallel import comm
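
# Broadcast.apply(target_gpus, *inputs) copies every input tensor onto each of
# the target GPUs; the backward pass reduce-adds the gradients back onto the
# device that held the original inputs.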
class Broadcast(Function):
    @staticmethod
    def forward(ctx, target_gpus, *inputs):
        assert all(
            i.device.type != "cpu" for i in inputs
        ), "Broadcast function not implemented for CPU tensors"
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.target_gpus = target_gpus
        if len(inputs) == 0:
            return ()
        ctx.num_inputs = len(inputs)
        ctx.input_device = inputs[0].get_device()
        outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
        non_differentiables = []
        for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
            if not input_requires_grad:
                non_differentiables.extend(output[idx] for output in outputs)
        ctx.mark_non_differentiable(*non_differentiables)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None,) + ReduceAddCoalesced.apply(
            ctx.input_device, ctx.num_inputs, *grad_outputs
        )
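
# ReduceAddCoalesced.apply(destination, num_inputs, *grads) receives gradients
# grouped per source GPU (num_inputs tensors per group), sums the corresponding
# tensors across GPUs onto the destination device, and broadcasts the incoming
# gradients back to the source GPUs in its backward pass.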
class ReduceAddCoalesced(Function):
    @staticmethod
    def forward(ctx, destination, num_inputs, *grads):
        ctx.target_gpus = [
            grads[i].get_device() for i in range(0, len(grads), num_inputs)
        ]
        grads_ = [grads[i : i + num_inputs] for i in range(0, len(grads), num_inputs)]
        return comm.reduce_add_coalesced(grads_, destination)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (
            None,
            None,
        ) + Broadcast.apply(ctx.target_gpus, *grad_outputs)
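
# Gather.apply(target_device, dim, *inputs) concatenates the per-GPU inputs
# along `dim` on the target device; the backward pass scatters the gradient
# back into chunks that match the original input sizes.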
class Gather(Function):
    @staticmethod
    def forward(ctx, target_device, dim, *inputs):
        assert all(
            i.device.type != "cpu" for i in inputs
        ), "Gather function not implemented for CPU tensors"
        if target_device == "cpu":
            ctx.target_device = "cpu"
        else:
            target_device = _get_device_index(target_device, True)
            ctx.target_device = target_device
        ctx.dim = dim
        ctx.input_gpus = tuple(i.get_device() for i in inputs)
        if all(t.dim() == 0 for t in inputs) and dim == 0:
            inputs = tuple(t.view(1) for t in inputs)
            warnings.warn(
                "Was asked to gather along dimension 0, but all "
                "input tensors were scalars; will instead unsqueeze "
                "and return a vector."
            )
            ctx.unsqueezed_scalar = True
        else:
            ctx.unsqueezed_scalar = False
        ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
        return comm.gather(inputs, ctx.dim, ctx.target_device)

    @staticmethod
    def backward(ctx, grad_output):
        scattered_grads = Scatter.apply(
            ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output
        )
        if ctx.unsqueezed_scalar:
            scattered_grads = tuple(g[0] for g in scattered_grads)
        return (None, None) + scattered_grads
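
# Scatter.apply(target_gpus, chunk_sizes, dim, input) splits `input` along `dim`
# and copies each chunk to its target GPU, performing CPU-to-GPU copies on
# background streams; the backward pass gathers the gradients back onto the
# source device.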
class Scatter(Function):
    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.dim = dim
        ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
        streams = None
        if torch.cuda.is_available() and ctx.input_device == -1:
            # Perform CPU to GPU copies in a background stream
            streams = [
                _get_stream(torch.device("cuda", device)) for device in target_gpus
            ]
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    main_stream = torch.cuda.current_stream()
                    main_stream.wait_stream(streams[i])
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)

# background streams used for copying
_streams: Optional[List[Optional[torch.Stream]]] = None


def _get_stream(device: torch.device):
    """Get a background stream for copying between CPU and target device."""
    global _streams
    if device.type == "cpu":
        return None
    device_mod = getattr(torch, device.type, None)
    if device_mod is None:
        return None
    if _streams is None:
        _streams = [None] * device_mod.device_count()
    if _streams[device.index] is None:
        _streams[device.index] = device_mod.Stream(device.index)
    return _streams[device.index]