import os
import csv
import types
import logging
import torch
import torch.distributed as dist


def _info_on_master(self, *args, **kwargs):
    # Bound to Logger instances as ``info_on_master``; logs only on the rank-0 process.
    if dist.get_rank() == 0:
        self.info(*args, **kwargs)


def getLogger(name):
    """Get a ``logging.Logger`` instance with an additional ``info_on_master`` method."""
    logger = logging.getLogger(name)
    logger.info_on_master = types.MethodType(_info_on_master, logger)
    return logger


_LG = getLogger(__name__)


def setup_distributed(
    world_size, rank, local_rank, backend="nccl", init_method="env://"
):
    """Perform environment setup and initialization for distributed training."""
    if init_method == "env://":
        _set_env_vars(world_size, rank, local_rank)
    if world_size > 1 and "OMP_NUM_THREADS" not in os.environ:
        _LG.info("Setting OMP_NUM_THREADS == 1")
        os.environ["OMP_NUM_THREADS"] = "1"
    params = {
        "backend": backend,
        "init_method": init_method,
        "world_size": world_size,
        "rank": rank,
    }
    _LG.info("Initializing distributed process group with %s", params)
    dist.init_process_group(**params)
    _LG.info("Initialized distributed process group.")


def _set_env_vars(world_size, rank, local_rank):
    # Provide default rendezvous address/port if the launcher did not set them.
    for key, default in [("MASTER_ADDR", "127.0.0.1"), ("MASTER_PORT", "29500")]:
        if key not in os.environ:
            os.environ[key] = default
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(local_rank)


def save_on_master(path, obj):
    """Serialize ``obj`` to ``path`` with ``torch.save``, only on the rank-0 process."""
    if dist.get_rank() == 0:
        _LG.info("Saving %s", path)
        torch.save(obj, path)


def write_csv_on_master(path, *rows):
    """Append ``rows`` to the CSV file at ``path``, only on the rank-0 process."""
    if dist.get_rank() == 0:
        with open(path, "a", newline="") as fileobj:
            writer = csv.writer(fileobj)
            for row in rows:
                writer.writerow(row)
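

# Usage sketch (hypothetical names, shown only for illustration): a training
# loop could record per-epoch metrics without duplicating rows across ranks:
#
#     write_csv_on_master(metrics_path, [epoch, train_loss, valid_loss])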


def synchronize_params(path, device, *modules):
    """Synchronize module parameters across processes by saving them on rank 0
    and loading them on all other ranks via a shared temporary file."""
    if dist.get_world_size() < 2:
        return
    rank = dist.get_rank()
    if rank == 0:
        _LG.info("[Parameter Sync]: Saving parameters to a temp file...")
        torch.save({f"{i}": m.state_dict() for i, m in enumerate(modules)}, path)
    dist.barrier()
    if rank != 0:
        _LG.info("[Parameter Sync]: Loading parameters...")
        data = torch.load(path, map_location=device)
        for i, m in enumerate(modules):
            m.load_state_dict(data[f"{i}"])
    dist.barrier()
    if rank == 0:
        _LG.info("[Parameter Sync]: Removing the temp file...")
        os.remove(path)
    _LG.info_on_master("[Parameter Sync]: Complete.")
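

# Usage sketch (hypothetical names and path, shown only for illustration):
# before wrapping modules in DistributedDataParallel, all ranks can be forced
# to start from identical weights via a path visible to every process:
#
#     synchronize_params("/tmp/init_params.pt", device, encoder, decoder)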