1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289
|
import glob
import logging
import os
import os.path as osp
import warnings
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import psutil
import torch

from torch_geometric.data import HeteroData
def get_numa_nodes_cores(
    node_root: str = '/sys/devices/system/node',
) -> Dict[int, Any]:
    r"""Parses NUMA node information into a dictionary.

    Every NUMA node ID is mapped to a sorted list of
    ``(core_id, sibling_cpu_ids)`` pairs, restricted to CPUs that are online:

    .. code-block::

        {<node_id>: [(<core_id>, [<sibling_thread_id_0>, <sibling_thread_id_1>
        ...]), ...], ...}

        # For example:
        {0: [(0, [0, 4]), (1, [1, 5])], 1: [(2, [2, 6]), (3, [3, 7])]}

    If not available, returns an empty dictionary. If node directories exist
    but cannot be parsed, a warning is emitted and an empty dictionary is
    returned.

    Args:
        node_root (str, optional): Directory containing the ``node<N>``
            sysfs entries. (default: :obj:`"/sys/devices/system/node"`)
    """
    # Escape the root so that glob metacharacters in it are taken literally.
    numa_node_paths = glob.glob(osp.join(glob.escape(node_root), 'node[0-9]*'))
    if not numa_node_paths:
        return {}

    nodes: Dict[int, Any] = {}
    try:
        for node_path in numa_node_paths:
            numa_node_id = int(osp.basename(node_path)[len('node'):])

            thread_siblings: Dict[int, List[int]] = {}
            cpu_pattern = osp.join(glob.escape(node_path), 'cpu[0-9]*')
            for cpu_dir in glob.glob(cpu_pattern):
                cpu_id = int(osp.basename(cpu_dir)[len('cpu'):])
                if cpu_id > 0:
                    with open(osp.join(cpu_dir, 'online')) as online_file:
                        core_online = int(online_file.read().splitlines()[0])
                else:
                    core_online = 1  # cpu0 is always online (special case)

                if core_online == 1:
                    core_id_path = osp.join(cpu_dir, 'topology', 'core_id')
                    with open(core_id_path) as core_id_file:
                        core_id = int(core_id_file.read().strip())
                    # CPUs sharing a `core_id` are hardware thread siblings:
                    thread_siblings.setdefault(core_id, []).append(cpu_id)

            nodes[numa_node_id] = sorted(
                (core_id, sorted(cpus))
                for core_id, cpus in thread_siblings.items())
    except (OSError, ValueError, IndexError):
        # Previously `Warning(...)` only built an object and never warned.
        warnings.warn('Failed to read NUMA info')
        return {}

    return nodes
class WorkerInitWrapper:
    r"""Wraps the :attr:`worker_init_fn` argument for
    :class:`torch.utils.data.DataLoader` workers.

    Calling the wrapper forwards the worker ID to the wrapped function. If no
    function was given (:obj:`None`), calling the wrapper does nothing, so it
    can always be invoked when chaining init functions.

    Args:
        func (callable, optional): The worker init function to wrap, or
            :obj:`None`.
    """
    def __init__(self, func: Optional[Callable[[int], Any]]) -> None:
        self.func = func

    def __call__(self, worker_id: int) -> None:
        if self.func is not None:
            self.func(worker_id)
class LogMemoryMixin:
    r"""A context manager to enable logging of memory consumption in
    :class:`~torch.utils.data.DataLoader` workers.

    The host class is expected to provide a :obj:`worker_init_fn` attribute
    (as :class:`~torch.utils.data.DataLoader` does).
    """
    def _mem_init_fn(self, worker_id: int) -> None:
        # Runs as the worker init function: log the worker's resident memory.
        proc = psutil.Process(os.getpid())
        memory = proc.memory_info().rss / (1024 * 1024)
        logging.debug(f"Worker {worker_id} @ PID {proc.pid}: {memory:.2f} MB")

        # Chain worker init functions. The saved function lives in an
        # attribute owned by this mixin only: sharing one attribute across
        # mixins made nested contexts overwrite each other's saved function,
        # turning the chain into infinite recursion.
        old_init_fn = self._mem_old_worker_init_fn
        if old_init_fn is not None:
            old_init_fn(worker_id)

    @contextmanager
    def enable_memory_log(self) -> Iterator[None]:
        r"""Logs (at ``DEBUG`` level) the memory usage of every worker when
        it starts. Any previously set :obj:`worker_init_fn` is still called
        afterwards, and it is restored unchanged when the context exits.
        """
        self._mem_old_worker_init_fn = self.worker_init_fn
        try:
            self.worker_init_fn = self._mem_init_fn
            yield
        finally:
            # Restore the exact original function (not a wrapper around it):
            self.worker_init_fn = self._mem_old_worker_init_fn
class MultithreadingMixin:
    r"""A context manager to enable multi-threading in
    :class:`~torch.utils.data.DataLoader` workers.
    It changes the default value of threads used in the loader from :obj:`1`
    to :obj:`worker_threads`.

    The host class is expected to provide :obj:`num_workers` and
    :obj:`worker_init_fn` attributes (as
    :class:`~torch.utils.data.DataLoader` does).
    """
    def _mt_init_fn(self, worker_id: int) -> None:
        # Runs as the worker init function: size this worker's thread pool.
        try:
            torch.set_num_threads(int(self._worker_threads))
        except (RuntimeError, ValueError) as e:
            # `torch.set_num_threads` rejects non-positive counts with a
            # `RuntimeError`; `int()` raises `ValueError` on malformed input.
            raise ValueError(f"Cannot set {self._worker_threads} threads "
                             f"in worker {worker_id}") from e

        # Chain worker init functions. The saved function lives in an
        # attribute owned by this mixin only, so nesting with other mixins'
        # contexts cannot overwrite it (which previously caused infinite
        # recursion).
        old_init_fn = self._mt_old_worker_init_fn
        if old_init_fn is not None:
            old_init_fn(worker_id)

    @contextmanager
    def enable_multithreading(
        self,
        worker_threads: Optional[int] = None,
    ) -> Iterator[None]:
        r"""Enables multithreading in worker subprocesses.
        This option requires changing the start method from :obj:`"fork"` to
        :obj:`"spawn"`.

        .. code-block:: python

            def run():
                loader = NeighborLoader(data, num_workers=3)
                with loader.enable_multithreading(10):
                    for batch in loader:
                        pass

            if __name__ == '__main__':
                torch.multiprocessing.set_start_method('spawn')
                run()

        Args:
            worker_threads (int, optional): The number of threads to use in
                each worker process.
                By default, the threads available to the main process are
                split evenly across workers, with at least one per worker.
                (default: :obj:`max(1, torch.get_num_threads() //
                num_workers)`)

        Raises:
            ValueError: If :obj:`num_workers` is not positive, if
                :obj:`worker_threads` is not positive or exceeds
                :obj:`torch.get_num_threads()`, or if the multiprocessing
                start method is not :obj:`"spawn"`.
        """
        # Validate the worker count first: the default below divides by it.
        if not self.num_workers > 0:
            raise ValueError(f"'enable_multithreading' needs to be performed "
                             f"with at least one worker "
                             f"(got {self.num_workers})")

        if worker_threads is None:
            worker_threads = max(1,
                                 torch.get_num_threads() // self.num_workers)

        if worker_threads < 1:
            raise ValueError(f"'worker_threads' should be a positive number "
                             f"(got {worker_threads})")

        if worker_threads > torch.get_num_threads():
            raise ValueError(f"'worker_threads' should be smaller than the "
                             f"total available number of threads "
                             f"{torch.get_num_threads()} "
                             f"(got {worker_threads})")

        context = torch.multiprocessing.get_context()._name
        if context != 'spawn':
            raise ValueError(f"'enable_multithreading' can only be used with "
                             f"the 'spawn' multiprocessing context "
                             f"(got {context})")

        self._worker_threads = worker_threads
        self._mt_old_worker_init_fn = self.worker_init_fn
        try:
            logging.debug(f"Using {worker_threads} threads in each worker")
            self.worker_init_fn = self._mt_init_fn
            yield
        finally:
            # Restore the exact original function (not a wrapper around it):
            self.worker_init_fn = self._mt_old_worker_init_fn
class AffinityMixin:
    r"""A context manager to enable CPU affinity for data loader workers
    (only used when running on CPU devices).

    Affinitization places data loader workers threads on specific CPU cores.
    In effect, it allows for more efficient local memory allocation and reduces
    remote memory calls.
    Every time a process or thread moves from one core to another, registers
    and caches need to be flushed and reloaded.
    This can become very costly if it happens often, and our threads may also
    no longer be close to their data, or be able to share data in a cache.

    See `here <https://pytorch-geometric.readthedocs.io/en/latest/advanced/
    cpu_affinity.html>`__ for the accompanying tutorial.

    .. warning::

        To correctly affinitize compute threads (*i.e.* with
        :obj:`KMP_AFFINITY`), please make sure that you exclude
        :obj:`loader_cores` from the list of cores available for the main
        process.
        Otherwise, this will cause core oversubscription and degrade
        performance.

    .. code-block:: python

        loader = NeighborLoader(data, num_workers=3)
        with loader.enable_cpu_affinity(loader_cores=[0, 1, 2]):
            for batch in loader:
                pass
    """
    def _aff_init_fn(self, worker_id: int) -> None:
        # Runs as the worker init function: pin this worker to its core(s).
        try:
            worker_cores = self.loader_cores[worker_id]
            if not isinstance(worker_cores, (list, tuple)):
                worker_cores = [worker_cores]
            worker_cores = list(worker_cores)

            if torch.multiprocessing.get_context()._name == 'spawn':
                # Spawned workers use one thread per assigned core:
                torch.set_num_threads(len(worker_cores))

            psutil.Process().cpu_affinity(worker_cores)

        except IndexError as e:
            raise ValueError(f"Cannot use CPU affinity for worker ID "
                             f"{worker_id} on CPU {self.loader_cores}") from e

        # Chain worker init functions. The saved function lives in an
        # attribute owned by this mixin only, so nesting with other mixins'
        # contexts cannot overwrite it (which previously caused infinite
        # recursion).
        old_init_fn = self._aff_old_worker_init_fn
        if old_init_fn is not None:
            old_init_fn(worker_id)

    def _default_loader_cores(self) -> Union[List[List[int]], List[int]]:
        # Computes the default `loader_cores` when the user supplied none.
        numa_node0 = get_numa_nodes_cores().get(0, [])
        if len(numa_node0) > self.num_workers:
            # Take one thread per each node 0 core:
            node0_cores = sorted(cpus[0] for _, cpus in numa_node0)
        else:
            # `cpu_count(logical=False)` may return `None` when undetermined:
            num_cores = (psutil.cpu_count(logical=False)
                         or psutil.cpu_count() or 1)
            node0_cores = list(range(num_cores))

        if len(node0_cores) < self.num_workers:
            raise ValueError(
                f"More workers (got {self.num_workers}) than available "
                f"cores (got {len(node0_cores)})")

        # Set default loader core IDs:
        if torch.multiprocessing.get_context()._name == 'spawn':
            # Give each worker an equal, contiguous slice of the actual core
            # IDs (not of the positions 0..N-1, which may belong to other
            # NUMA nodes or be hyper-thread siblings):
            pool_size = len(node0_cores) // self.num_workers
            return [
                node0_cores[pool_size * i:pool_size * (i + 1)]
                for i in range(self.num_workers)
            ]
        return node0_cores[:self.num_workers]

    @contextmanager
    def enable_cpu_affinity(
        self,
        loader_cores: Optional[Union[List[List[int]], List[int]]] = None,
    ) -> Iterator[None]:
        r"""Enables CPU affinity.

        Args:
            loader_cores ([int] or [[int]], optional): List of CPU cores to
                which data loader workers should affinitize to, with one
                entry (a core ID or a list of core IDs) per worker.
                By default, it will affinitize to :obj:`numa0` cores.
                If used with :obj:`"spawn"` multiprocessing context, it will
                automatically enable multithreading and use multiple cores
                per each worker.

        Raises:
            ValueError: If there is no worker, if the number of
                :obj:`loader_cores` entries does not match the number of
                workers, or if there are more workers than available cores.
        """
        # Reads `num_workers`, `worker_init_fn` and `data` from the host
        # class (presumably a `DataLoader` subclass).
        if not self.num_workers > 0:
            raise ValueError(
                f"'enable_cpu_affinity' should be used with at least one "
                f"worker (got {self.num_workers})")
        if loader_cores and len(loader_cores) != self.num_workers:
            raise ValueError(
                f"The number of loader cores (got {len(loader_cores)}) "
                f"in 'enable_cpu_affinity' should match with the number "
                f"of workers (got {self.num_workers})")
        if isinstance(self.data, HeteroData):
            warnings.warn(
                "Due to conflicting parallelization methods it is not advised "
                "to use affinitization with 'HeteroData' datasets. "
                "Use `enable_multithreading` for better performance.")

        if loader_cores:
            self.loader_cores = loader_cores[:]
        else:
            self.loader_cores = self._default_loader_cores()

        self._aff_old_worker_init_fn = self.worker_init_fn
        try:
            self.worker_init_fn = self._aff_init_fn
            logging.debug(f"{self.num_workers} data loader workers are "
                          f"assigned to CPUs {self.loader_cores}")
            yield
        finally:
            # Restore the exact original function (not a wrapper around it):
            self.worker_init_fn = self._aff_old_worker_init_fn
|