File: base.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (43 lines) | stat: -rw-r--r-- 1,615 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from typing import Any, Callable

from torch.utils.data.dataloader import (
    _BaseDataLoaderIter,
    _MultiProcessingDataLoaderIter,
)


class DataLoaderIterator:
    r"""A data loader iterator extended by a simple post transformation
    function :meth:`transform_fn`. While the iterator may request items from
    different sub-processes, :meth:`transform_fn` will always be executed in
    the main process.

    This iterator is used in PyG's sampler classes, and is responsible for
    feature fetching and filtering data objects after sampling has taken place
    in a sub-process. This has the following advantages:

    * We do not need to share feature matrices across processes, which
      avoids errors caused by too many open file handles.
    * We can execute any expensive post-processing commands on the main thread
      with full parallelization power (which usually executes faster).
    * It lets us naturally support data already being present on the GPU.

    Args:
        iterator (_BaseDataLoaderIter): The wrapped data loader iterator that
            produces the raw items.
        transform_fn (callable): The function applied to every item returned
            by :obj:`iterator` before it is handed to the caller.
    """
    def __init__(self, iterator: _BaseDataLoaderIter, transform_fn: Callable):
        self.iterator = iterator
        self.transform_fn = transform_fn

    def __iter__(self) -> 'DataLoaderIterator':
        return self

    def _reset(self, loader: Any, first_iter: bool = False):
        # Forward the reset to the wrapped iterator unchanged.
        # NOTE(review): presumably invoked by PyTorch's DataLoader when an
        # existing iterator is reused (e.g. persistent workers) — confirm.
        self.iterator._reset(loader, first_iter)

    def __len__(self) -> int:
        return len(self.iterator)

    def __next__(self) -> Any:
        # The raw item may have been produced in a worker sub-process, but the
        # transformation itself always runs here, in the calling process.
        return self.transform_fn(next(self.iterator))

    def __del__(self) -> None:
        # `__del__` may run on a partially-constructed instance (e.g. when
        # `__init__` raised or was bypassed), so `self.iterator` is not
        # guaranteed to exist; never raise from a finalizer.
        iterator = getattr(self, 'iterator', None)
        if isinstance(iterator, _MultiProcessingDataLoaderIter):
            # Eagerly trigger the wrapped multi-process iterator's own
            # finalizer (presumably shutting down its worker processes)
            # instead of waiting for it to be garbage collected.
            iterator.__del__()