1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
|
import warnings
from contextlib import nullcontext
from functools import partial
from typing import Any, Optional
import torch
from torch.utils.data import DataLoader
from torch_geometric.typing import WITH_IPEX
class DeviceHelper:
    r"""Resolves which accelerator device to use and manages an optional
    side stream for asynchronous host-to-device transfers.

    Args:
        device (torch.device, optional): The device to use. If :obj:`None`,
            picks :obj:`"cuda"`, then :obj:`"xpu"`, then :obj:`"cpu"`,
            depending on availability. (default: :obj:`None`)
    """
    def __init__(self, device: Optional[torch.device] = None):
        with_cuda = torch.cuda.is_available()
        # XPU support is only probed when the IPEX extension is present:
        with_xpu = torch.xpu.is_available() if WITH_IPEX else False

        if device is None:  # Auto-select the best available backend:
            if with_cuda:
                device = 'cuda'
            elif with_xpu:
                device = 'xpu'
            else:
                device = 'cpu'

        self.device = torch.device(device)
        self.is_gpu = self.device.type in ['cuda', 'xpu']

        # Gracefully fall back to CPU if the requested accelerator is
        # missing (e.g., a 'cuda' device on a CPU-only build):
        if ((self.device.type == 'cuda' and not with_cuda)
                or (self.device.type == 'xpu' and not with_xpu)):
            warnings.warn(
                f"Requested device '{self.device.type}' is not "
                f"available, falling back to CPU", stacklevel=2)
            self.device = torch.device('cpu')
            # BUG FIX: Keep `is_gpu` consistent with the CPU fallback.
            # Previously it stayed `True`, so `module` resolved to the
            # wrong namespace and `maybe_init_stream()` tried to create
            # a GPU stream on a machine without one.
            self.is_gpu = False

        self.stream = None
        # A no-op context manager until a real stream is initialized:
        self.stream_context = nullcontext
        # The backend namespace (`torch.cuda` or `torch.xpu`) on GPU:
        self.module = getattr(torch, self.device.type) if self.is_gpu else None

    def maybe_init_stream(self) -> None:
        r"""Creates a side stream (and a matching context manager) on GPU
        devices; a no-op on CPU."""
        if self.is_gpu:
            self.stream = self.module.Stream()
            self.stream_context = partial(
                self.module.stream,
                stream=self.stream,
            )

    def maybe_wait_stream(self) -> None:
        r"""Makes the current stream wait for outstanding work on the side
        stream; a no-op if no side stream was initialized."""
        if self.stream is not None:
            self.module.current_stream().wait_stream(self.stream)
class PrefetchLoader:
    r"""A GPU prefetcher class for asynchronously transferring data of a
    :class:`torch.utils.data.DataLoader` from host memory to device memory.

    Args:
        loader (torch.utils.data.DataLoader): The data loader.
        device (torch.device, optional): The device to load the data to.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        loader: DataLoader,
        device: Optional[torch.device] = None,
    ):
        self.loader = loader
        self.device_helper = DeviceHelper(device)

    def non_blocking_transfer(self, batch: Any) -> Any:
        r"""Recursively moves :obj:`batch` to the target device using pinned
        memory and non-blocking copies; returns :obj:`batch` unchanged when
        running on CPU."""
        if not self.device_helper.is_gpu:
            return batch  # Nothing to transfer on CPU.
        if isinstance(batch, (list, tuple)):
            # NOTE(review): Tuples are converted to lists here — presumably
            # acceptable for downstream consumers; confirm before changing.
            return [self.non_blocking_transfer(v) for v in batch]
        if isinstance(batch, dict):
            return {k: self.non_blocking_transfer(v) for k, v in batch.items()}
        # Pin host memory first so the device copy can run asynchronously
        # and overlap with computation on the main stream:
        batch = batch.pin_memory()
        return batch.to(self.device_helper.device, non_blocking=True)

    def __iter__(self) -> Any:
        first = True
        self.device_helper.maybe_init_stream()

        batch = None
        for next_batch in self.loader:
            # Transfer the *next* batch on the side stream while the
            # consumer processes the previously yielded batch:
            with self.device_helper.stream_context():
                next_batch = self.non_blocking_transfer(next_batch)

            if not first:
                yield batch
            else:
                first = False  # Hold the first batch back one iteration.

            # Ensure the side-stream copy finished before handing over:
            self.device_helper.maybe_wait_stream()
            batch = next_batch

        # BUG FIX: Only yield the trailing batch if the underlying loader
        # produced at least one batch. Previously an empty loader yielded
        # a spurious `None`, contradicting `len(self) == 0`.
        if not first:
            yield batch

    def __len__(self) -> int:
        return len(self.loader)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.loader})'
|