File: dist_utils.py

Package: pytorch-audio 0.7.2-1 (Debian bullseye)

import os
import csv
import types
import logging

import torch
import torch.distributed as dist


def _info_on_master(self, *args, **kwargs):
    """``Logger.info`` that is a no-op on every rank except global rank 0."""
    if dist.get_rank() == 0:
        self.info(*args, **kwargs)


def getLogger(name):
    """Get logging.Logger module with additional ``info_on_master`` method."""
    logger = logging.getLogger(name)
    logger.info_on_master = types.MethodType(_info_on_master, logger)
    return logger
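
# Usage sketch (illustrative, not part of the original module): ``info_on_master``
# forwards to ``Logger.info`` only on global rank 0, so status lines are not
# duplicated once per worker. It calls ``dist.get_rank()``, so the process group
# should already be initialized when it is used.
#
#   logger = getLogger("train")
#   logger.info_on_master("epoch %d finished, loss %.4f", epoch, loss)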


_LG = getLogger(__name__)


def setup_distributed(
    world_size, rank, local_rank, backend="nccl", init_method="env://"
):
    """Perform env setup and initialization for distributed training"""
    if init_method == "env://":
        _set_env_vars(world_size, rank, local_rank)
    if world_size > 1 and "OMP_NUM_THREADS" not in os.environ:
        _LG.info("Setting OMP_NUM_THREADS == 1")
        os.environ["OMP_NUM_THREADS"] = "1"
    params = {
        "backend": backend,
        "init_method": init_method,
        "world_size": world_size,
        "rank": rank,
    }
    _LG.info("Initializing distributed process group with %s", params)
    dist.init_process_group(**params)
    _LG.info("Initialized distributed process group.")


def _set_env_vars(world_size, rank, local_rank):
    """Populate the environment variables expected by the ``env://`` init method."""
    for key, default in [("MASTER_ADDR", "127.0.0.1"), ("MASTER_PORT", "29500")]:
        if key not in os.environ:
            os.environ[key] = default

    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(local_rank)


def save_on_master(path, obj):
    """Serialize ``obj`` to ``path`` with ``torch.save``, only on the rank-0 process."""
    if dist.get_rank() == 0:
        _LG.info("Saving %s", path)
        torch.save(obj, path)


def write_csv_on_master(path, *rows):
    """Append ``rows`` to the CSV file at ``path``, only on the rank-0 process."""
    if dist.get_rank() == 0:
        with open(path, "a", newline="") as fileobj:
            writer = csv.writer(fileobj)
            for row in rows:
                writer.writerow(row)
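
# Example (sketch, hypothetical file names): both ``save_on_master`` and
# ``write_csv_on_master`` are no-ops on non-zero ranks, so they can be called
# unconditionally from every worker once the process group is initialized.
#
#   save_on_master("checkpoint.pt", model.state_dict())
#   write_csv_on_master("metrics.csv", [epoch, train_loss, valid_loss])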


def synchronize_params(path, device, *modules):
    """Copy the parameters of rank 0 to every other rank via a file at ``path``."""
    if dist.get_world_size() < 2:
        return
    rank = dist.get_rank()
    if rank == 0:
        _LG.info("[Parameter Sync]: Saving parameters to a temp file...")
        torch.save({f"{i}": m.state_dict() for i, m in enumerate(modules)}, path)
    dist.barrier()
    if rank != 0:
        _LG.info("[Parameter Sync]:    Loading parameters...")
        data = torch.load(path, map_location=device)
        for i, m in enumerate(modules):
            m.load_state_dict(data[f"{i}"])
    dist.barrier()
    if rank == 0:
        _LG.info("[Parameter Sync]: Removing the temp file...")
        os.remove(path)
    _LG.info_on_master("[Parameter Sync]: Complete.")