File: prefetch.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (105 lines) | stat: -rw-r--r-- 3,232 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import warnings
from contextlib import nullcontext
from functools import partial
from typing import Any, Optional

import torch
from torch.utils.data import DataLoader

from torch_geometric.typing import WITH_IPEX


class DeviceHelper:
    r"""Resolves the device to transfer data to and manages the optional
    device stream used while issuing those transfers.

    Args:
        device (torch.device, optional): The requested device. If set to
            :obj:`None`, :obj:`"cuda"` is picked when available, then
            :obj:`"xpu"` when available, and :obj:`"cpu"` otherwise. If an
            explicitly requested :obj:`"cuda"` or :obj:`"xpu"` device is not
            available, a warning is emitted and :obj:`"cpu"` is used instead.
            (default: :obj:`None`)
    """
    def __init__(self, device: Optional[torch.device] = None):
        with_cuda = torch.cuda.is_available()
        # XPU availability is only probed when `WITH_IPEX` is truthy.
        with_xpu = torch.xpu.is_available() if WITH_IPEX else False

        if device is None:
            if with_cuda:
                device = 'cuda'
            elif with_xpu:
                device = 'xpu'
            else:
                device = 'cpu'

        self.device = torch.device(device)

        if ((self.device.type == 'cuda' and not with_cuda)
                or (self.device.type == 'xpu' and not with_xpu)):
            warnings.warn(
                f"Requested device '{self.device.type}' is not "
                f"available, falling back to CPU", stacklevel=2)
            self.device = torch.device('cpu')

        # Derive the flag from the *resolved* device, i.e. after a potential
        # CPU fallback. Otherwise, a fallen-back helper would still report
        # `is_gpu=True` and drive pinning/stream logic on a CPU-only host.
        self.is_gpu = self.device.type in ['cuda', 'xpu']

        # No stream by default; `maybe_init_stream` swaps these for GPU use.
        self.stream = None
        self.stream_context = nullcontext
        # Backend namespace (`torch.cuda` or `torch.xpu`) of the device:
        self.module = getattr(torch, self.device.type) if self.is_gpu else None

    def maybe_init_stream(self) -> None:
        r"""Creates a new stream on GPU devices and makes
        :obj:`stream_context` enter it. Does nothing otherwise.
        """
        if self.is_gpu:
            self.stream = self.module.Stream()
            self.stream_context = partial(
                self.module.stream,
                stream=self.stream,
            )

    def maybe_wait_stream(self) -> None:
        r"""Makes the backend's current stream wait on the helper's stream,
        in case one has been created.
        """
        if self.stream is not None:
            self.module.current_stream().wait_stream(self.stream)


class PrefetchLoader:
    r"""A GPU prefetcher class for asynchronously transferring data of a
    :class:`torch.utils.data.DataLoader` from host memory to device memory.

    Iterating yields exactly the batches of the wrapped loader, in order. In
    particular, an empty loader yields nothing.

    Args:
        loader (torch.utils.data.DataLoader): The data loader.
        device (torch.device, optional): The device to load the data to.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        loader: DataLoader,
        device: Optional[torch.device] = None,
    ):
        self.loader = loader
        self.device_helper = DeviceHelper(device)

    def non_blocking_transfer(self, batch: Any) -> Any:
        r"""Pins :obj:`batch` and moves it to the target device with
        :obj:`non_blocking=True`.

        On non-GPU devices, :obj:`batch` is returned untouched. Lists, tuples
        and dictionaries are traversed recursively, where lists and tuples
        both come back as :obj:`list`. Any other object needs to provide
        :meth:`pin_memory` and :meth:`to`.
        """
        if not self.device_helper.is_gpu:
            return batch
        if isinstance(batch, (list, tuple)):
            return [self.non_blocking_transfer(v) for v in batch]
        if isinstance(batch, dict):
            return {k: self.non_blocking_transfer(v) for k, v in batch.items()}

        batch = batch.pin_memory()
        return batch.to(self.device_helper.device, non_blocking=True)

    def __iter__(self) -> Any:
        self.device_helper.maybe_init_stream()

        # An explicit flag (rather than a `None` check) tracks whether a batch
        # is waiting to be handed out, so that an empty loader yields nothing.
        has_pending = False
        pending = None
        for next_batch in self.loader:

            # Issue the transfer of the upcoming batch inside the helper's
            # stream context *before* handing out the previous batch:
            with self.device_helper.stream_context():
                next_batch = self.non_blocking_transfer(next_batch)

            if has_pending:
                yield pending

            # Let the current stream wait on the transfer stream before the
            # freshly transferred batch gets handed out:
            self.device_helper.maybe_wait_stream()

            pending = next_batch
            has_pending = True

        if has_pending:
            yield pending

    def __len__(self) -> int:
        return len(self.loader)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.loader})'