# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import os
import shutil
import tempfile
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple
import torch.distributed as dist
def with_temp_dir(
    func: Optional[Callable] = None,
) -> Optional[Callable]:
    """
    Wrapper to initialize temp directory for distributed checkpoint.

    The decorated test method gets a shared ``self.temp_dir``: when a
    process group is initialized, rank 0 creates the directory and
    broadcasts its path to every other rank; otherwise the sole process
    creates its own. The directory is removed after ``func`` returns,
    even if it raises.
    """
    assert func is not None

    @wraps(func)
    def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None:
        if dist.is_initialized():
            # Only create temp_dir when rank is 0
            if dist.get_rank() == 0:
                temp_dir = tempfile.mkdtemp()
                print(f"Using temp directory: {temp_dir}")
            else:
                temp_dir = ""
            object_list = [temp_dir]

            # Broadcast temp_dir to all the other ranks
            os.sync()
            dist.broadcast_object_list(object_list)
            self.temp_dir = object_list[0]
            os.sync()
        else:
            temp_dir = tempfile.mkdtemp()
            print(f"No process group initialized, using temp directory: {temp_dir}")
            self.temp_dir = temp_dir
        try:
            func(self, *args, **kwargs)
        finally:
            # Only the owner of the directory should delete it: rank 0 in
            # the distributed case (the path is shared by all ranks), or
            # the single process when no group is initialized. Letting
            # every rank call rmtree on the same path is a delete race.
            if not dist.is_initialized() or dist.get_rank() == 0:
                shutil.rmtree(self.temp_dir, ignore_errors=True)

    return wrapper