import os
from typing import Any, Tuple

import intel_extension_for_pytorch as ipex
import oneccl_bindings_for_pytorch  # noqa
import torch.distributed as dist

from benchmark.multi_gpu.training.common import (
    get_predefined_args,
    run,
    supported_sets,
)
from benchmark.utils import get_dataset


def get_dist_params() -> Tuple[int, int, str]:
    master_addr = "127.0.0.1"
    master_port = "29500"
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port

    # Prefer the MPI-provided rank/size (PMI_RANK/PMI_SIZE are set per
    # process by MPI launchers); fall back to the standard torch.distributed
    # environment variables. Cast the fallback values to int so the return
    # types match the annotation and `dist.init_process_group()`.
    mpi_rank = int(os.environ.get("PMI_RANK", -1))
    mpi_world_size = int(os.environ.get("PMI_SIZE", -1))
    rank = mpi_rank if mpi_world_size > 0 else int(os.environ.get("RANK", 0))
    world_size = (mpi_world_size if mpi_world_size > 0 else
                  int(os.environ.get("WORLD_SIZE", 1)))

    # Re-export the resolved values so that any library reading the
    # environment sees a consistent view.
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    init_method = f"tcp://{master_addr}:{master_port}"
    return rank, world_size, init_method
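

# A minimal launch sketch (assumption: an MPI launcher such as `mpiexec` is
# used, which sets PMI_RANK/PMI_SIZE for each spawned process; the script
# filename below is a placeholder):
#
#   mpiexec -n 2 python training_benchmark_xpu.py --dataset <name> --root <path>
#
# Without MPI, RANK and WORLD_SIZE can instead be exported manually before
# launching one process per device.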


def custom_optimizer(model: Any, optimizer: Any) -> Tuple[Any, Any]:
    # Apply IPEX kernel/graph optimizations to the model/optimizer pair.
    return ipex.optimize(model, optimizer=optimizer)
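

# Assumption about the call contract: `run()` from
# benchmark.multi_gpu.training.common is expected to invoke this hook once
# per process, after the model has been placed on the XPU device, roughly as
#
#   model, optimizer = custom_optimizer(model, optimizer)
#
# so that IPEX can optimize both objects before the training loop starts.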


if __name__ == '__main__':
    rank, world_size, init_method = get_dist_params()
    dist.init_process_group(backend="ccl", init_method=init_method,
                            world_size=world_size, rank=rank)

    argparser = get_predefined_args()
    args = argparser.parse_args()
    setattr(args, 'device', 'xpu')

    assert args.dataset in supported_sets.keys(), \
        f"Dataset {args.dataset} isn't supported."

    # If the dataset is not yet present, only the rank=0 process downloads
    # and processes it; the other ranks wait at the barrier and then read
    # the cached copy instead of re-downloading it.
    if rank == 0:
        data, num_classes = get_dataset(args.dataset, args.root)
    dist.barrier()
    if rank != 0:
        data, num_classes = get_dataset(args.dataset, args.root)

    run(rank, world_size, args, num_classes, data, custom_optimizer)
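
    # Cleanup note (assumption): `run()` may or may not tear down the
    # process group itself; if it does not, an explicit
    # `dist.destroy_process_group()` at this point would release the oneCCL
    # communicator cleanly before the interpreter exits.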