File: benchmark_ddp_rpc.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (365 lines) | stat: -rw-r--r-- 11,803 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
import argparse
import io
import os
import random
import shlex
import subprocess
import time

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions
from torch.distributed.rpc.backend_registry import BackendType
from torch.nn.parallel import DistributedDataParallel as DDP


# Config
# Process layout: NUM_TRAINERS trainers, then NUM_PS parameter servers,
# then a single master process.
NUM_TRAINERS, NUM_PS = 8, 8

# Size of every nn.EmbeddingBag table created on a parameter server:
# number of rows and width of each row.
NUM_EMBEDDINGS, EMBEDDING_DIM = 300, 64

# Number of leading epochs per trainer whose timings are discarded.
WARMUP_CYCLES = 5


class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part. The dense part is a
    stack of nn.Linear layers that is replicated across all trainers using
    DistributedDataParallel. The sparse part has nn.EmbeddingBags stored on
    multiple parameter servers.

    The model holds a Remote Reference to the embedding tables on the
    parameter servers.

    Args:
        emb_rref_list: remote references to the embedding tables; ``forward``
            calls ``rpc_sync().forward(indices, offsets)`` on each of them.
        device: device for the dense part, also used as the single DDP
            ``device_ids`` entry (the trainer passes its rank here).
    """

    # Input width of the first dense layer. The concatenated embedding
    # lookups are folded into rows of this width before the dense part.
    _FC_INPUT_DIM = 512

    def __init__(self, emb_rref_list, device):
        super().__init__()
        self.emb_rref_list = emb_rref_list
        fc1 = torch.nn.Linear(self._FC_INPUT_DIM, 256)
        fc2 = torch.nn.Linear(256, 128)
        relu = torch.nn.ReLU()
        fc3 = torch.nn.Linear(128, 64)
        fc4 = torch.nn.Linear(64, 32)
        fc5 = torch.nn.Linear(32, 8)
        sec = nn.Sequential(fc1, fc2, relu, fc3, fc4, fc5)
        self.ddp = DDP(sec.to(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        r"""
        Look up ``indices``/``offsets`` in every remote table, then run the
        dense part on the concatenated result.

        Returns:
            The DDP output, with 8 columns. If the combined embedding width
            is ``k * 512``, each looked-up bag is folded into ``k`` rows.

        Raises:
            ValueError: if there are no tables, or if the combined embedding
                width is not a positive multiple of 512.
        """
        if not self.emb_rref_list:
            raise ValueError("HybridModel needs at least one embedding table")

        # One synchronous RPC per table: embedding_sum(indices, offsets).
        emb_lookups = [
            emb_rref.rpc_sync().forward(indices, offsets)
            for emb_rref in self.emb_rref_list
        ]
        # Concatenate once, after every lookup has returned.
        emb_lookups_cat = torch.cat(emb_lookups, dim=1)

        # The combined width must fold evenly into rows of the FC input width.
        combined_dim = emb_lookups_cat.shape[1]
        if combined_dim < self._FC_INPUT_DIM or combined_dim % self._FC_INPUT_DIM:
            raise ValueError(
                "combined embedding width {} must be a positive multiple of {}".format(
                    combined_dim, self._FC_INPUT_DIM
                )
            )
        dim_normalizer = combined_dim // self._FC_INPUT_DIM
        emb_lookups_reshaped = emb_lookups_cat.reshape(
            [emb_lookups_cat.shape[0] * dim_normalizer, self._FC_INPUT_DIM]
        )

        return self.ddp(emb_lookups_reshaped)


def _retrieve_embedding_parameters(emb_rref):
    return [RRef(p) for p in emb_rref.local_value().parameters()]


def _print_header():
    """Print the title row of the percentile table.

    Emits a newline, a 10-character blank label column, one
    ``sec/epoch``/``epoch/sec`` column pair for each of the four percentiles
    that ``_print_benchmark`` reports, and a closing newline.
    """
    column_pair = "%14s%10s" % ("sec/epoch", "epoch/sec")
    _print_cont("\n")
    _print_cont("%10s" % "")
    for _percentile in (50, 75, 90, 95):
        _print_cont(column_pair)
    _print_cont("\n")


def _print_benchmark(prefix, nelem, measurements):
    """Print one row of the percentile table.

    Args:
        prefix: row label, e.g. ``"Trainer0"`` or ``"All"``.
        nelem: sample count divided by each percentile duration to give the
            per-second rate in the second value of every column.
        measurements: epoch durations in seconds, either a flat list or a
            list of per-trainer lists (as passed for the ``"All"`` row).
    """
    # Flatten explicitly: np.hstack accepts both a flat list of floats and
    # per-trainer lists of different lengths. Percentiles need no pre-sort.
    durations = np.hstack(measurements)
    _print_cont("%8s:" % prefix)
    for p in [50, 75, 90, 95]:
        v = np.percentile(durations, p)
        _print_cont("  p%02d:  %1.3fs  %6d/s" % (p, v, nelem / v))
    _print_cont("\n")


def _print_cont(msg):
    print(msg, end="", flush=True)


def _run_printable(cmd):
    proc = subprocess.run(shlex.split(cmd), capture_output=True)  # type: ignore[call-overload]
    assert proc.returncode == 0

    buffer = io.BytesIO()
    torch.save(proc.stdout.decode("utf-8"), buffer)
    input_tensor = torch.ByteTensor(list(buffer.getvalue()))
    input_length = torch.IntTensor([input_tensor.size(0)])

    output = []
    buffer = io.BytesIO(np.asarray(input_tensor).tobytes())
    output.append(torch.load(buffer))
    return output


def _run_trainer(emb_rref_list, rank, num_epochs=100, batches_per_epoch=10):
    r"""
    Train the hybrid model on this trainer and time every epoch.

    Each trainer runs a forward pass which involves an embedding lookup on
    every parameter server in ``emb_rref_list`` and running the dense
    nn.Linear stack locally. During the backward pass, DDP is responsible for
    aggregating the gradients for the dense part and distributed autograd
    ensures gradient updates are propagated to the parameter servers.

    Args:
        emb_rref_list: RRefs to the embedding tables; entry ``i`` is owned by
            the worker named ``"ps{i}"``.
        rank: this trainer's rank, also used as its CUDA device index.
        num_epochs: number of timed epochs run after the warm-up epochs.
        batches_per_epoch: number of random batches in each epoch.

    Returns:
        ``(rank, measurements, batch_size)``: the trainer rank, the duration
        in seconds of each timed epoch, and the number of samples (bags)
        processed in the final epoch.
    """
    # Setup the model.
    model = HybridModel(emb_rref_list, rank)

    # DistributedOptimizer needs an RRef to every parameter it updates. The
    # embedding parameters are wrapped on their owning parameter server.
    model_parameter_rrefs = []
    for ind, emb_rref in enumerate(emb_rref_list):
        ps_name = "ps{}".format(ind)
        model_parameter_rrefs.extend(
            rpc.rpc_sync(ps_name, _retrieve_embedding_parameters, args=(emb_rref,))
        )
    # model.parameters() only includes local (dense) parameters.
    model_parameter_rrefs.extend(RRef(param) for param in model.parameters())

    # Setup distributed optimizer
    opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05)
    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch(rank):
        # Yield `batches_per_epoch` random batches: 20-50 table indices split
        # into bags that start 1-10 positions apart, plus one random class
        # label in [0, 8) per bag, placed on this trainer's GPU.
        for _ in range(batches_per_epoch):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
            offsets = []
            start = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
            batch_size = len(offsets)

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)

            yield indices, offsets_tensor, target

    measurements = []
    batch_size = 0
    # Include warm-up cycles during training; they are dropped below.
    for _ in range(num_epochs + WARMUP_CYCLES):
        # perf_counter is monotonic, unlike time.time(), so wall-clock
        # adjustments cannot distort the epoch timings.
        epoch_start = time.perf_counter()
        batch_size = 0

        for indices, offsets, target in get_next_batch(rank):
            batch_size += len(target)

            # Each iteration creates its own distributed autograd context,
            # which hosts separate grads, so no zeroing is needed.
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                # HybridModel may fold each bag into several rows; give every
                # row the label of the bag it came from.
                rows_per_bag = output.shape[0] // target.shape[0]
                if rows_per_bag != 1:
                    target = target.repeat_interleave(rows_per_bag)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Run distributed optimizer. Gradients propagated all the way to the parameter servers
                opt.step(context_id)

        measurements.append(time.perf_counter() - epoch_start)

    # Throw away warm-up measurements
    return rank, measurements[WARMUP_CYCLES:], batch_size


def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, runs this rank's role, and shuts
    down RPC.

    Rank layout (with ``M = NUM_TRAINERS + NUM_PS``):
        * ``0 .. NUM_TRAINERS - 1``: trainers, which also form the DDP
          process group;
        * ``NUM_TRAINERS .. M - 1``: parameter servers ``ps0 .. ps{NUM_PS-1}``;
        * ``M``: master, which creates the embedding tables, drives the
          trainers and prints the results.

    Args:
        rank: global rank of this process.
        world_size: total number of RPC workers, expected to be ``M + 1``.
    """
    # Every worker rendezvouses for RPC at MASTER_ADDR:MASTER_PORT; the
    # fallbacks reproduce the former hard-coded tcp://localhost:29500. The
    # trainers' process group uses the next port to avoid a conflict with
    # the RPC rendezvous.
    master_addr = os.environ.get("MASTER_ADDR", "localhost")
    master_port = int(os.environ.get("MASTER_PORT", "29500"))
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://{}:{}".format(master_addr, master_port)

    master_rank = NUM_TRAINERS + NUM_PS

    if rank == master_rank:
        # Master: same backend options as every other worker, so all ranks
        # meet at the same rendezvous address.
        rpc.init_rpc(
            "master",
            rank=rank,
            backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        # Build the Embedding tables on the Parameter Servers.
        emb_rref_list = [
            rpc.remote(
                "ps{}".format(index),
                torch.nn.EmbeddingBag,
                args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
                kwargs={"mode": "sum"},
            )
            for index in range(NUM_PS)
        ]

        # Run training loop on the trainers.
        futs = [
            rpc.rpc_async(
                "trainer{}".format(trainer_rank),
                _run_trainer,
                args=(emb_rref_list, trainer_rank),
            )
            for trainer_rank in range(NUM_TRAINERS)
        ]

        _print_header()

        measurements_all_trainers = []
        batch_size_all_trainers = 0
        # Wait for all training to finish; rows are printed in trainer order.
        for fut in futs:
            trainer_rank, measurements, batch_size = fut.wait()
            _print_benchmark("Trainer{}".format(trainer_rank), batch_size, measurements)
            batch_size_all_trainers += batch_size
            measurements_all_trainers.append(measurements)

        _print_benchmark("All", batch_size_all_trainers, measurements_all_trainers)

    elif 0 <= rank < NUM_TRAINERS:
        # Trainers. Initialize process group for DistributedDataParallel
        # among the trainers only.
        dist.init_process_group(
            backend=dist.Backend.GLOO,
            rank=rank,
            world_size=NUM_TRAINERS,
            init_method="tcp://{}:{}".format(master_addr, master_port + 1),
        )

        # Initialize RPC. Trainer just waits for RPCs from master.
        rpc.init_rpc(
            "trainer{}".format(rank),
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

    elif NUM_TRAINERS <= rank < master_rank:
        # Parameter servers only host the tables and serve RPCs.
        rpc.init_rpc(
            "ps{}".format(rank - NUM_TRAINERS),
            rank=rank,
            world_size=world_size,
            backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
            rpc_backend_options=rpc_backend_options,
        )

    # block until all rpcs finish
    rpc.shutdown()


def _configure_and_run_worker(rank, world_size, config):
    r"""
    Install the command-line configuration in this worker, then run it.

    ``mp.spawn`` starts every worker in a fresh interpreter that re-imports
    this module without executing the ``__main__`` block, so values parsed
    there would never reach the workers. This spawn entry point copies them
    into the module-level config globals read by the other functions in this
    module, then delegates to ``run_worker``.

    Args:
        rank: global rank supplied by ``mp.spawn``.
        world_size: total number of processes.
        config: dict with keys ``"NUM_TRAINERS"``, ``"NUM_PS"``,
            ``"NUM_EMBEDDINGS"``, ``"EMBEDDING_DIM"`` and ``"WARMUP_CYCLES"``.
    """
    global NUM_TRAINERS, NUM_PS, NUM_EMBEDDINGS, EMBEDDING_DIM, WARMUP_CYCLES
    NUM_TRAINERS = config["NUM_TRAINERS"]
    NUM_PS = config["NUM_PS"]
    NUM_EMBEDDINGS = config["NUM_EMBEDDINGS"]
    EMBEDDING_DIM = config["EMBEDDING_DIM"]
    WARMUP_CYCLES = config["WARMUP_CYCLES"]
    run_worker(rank, world_size)


if __name__ == "__main__":
    # Initializing the distributed environment: print system info, parse the
    # command line, then spawn one process per trainer, PS and the master.

    output = _run_printable("nvidia-smi topo -m")
    print("-------------------------------------------")
    print("                  Info                     ")
    print("-------------------------------------------")
    print("")
    print("* PyTorch version: {}".format(torch.__version__))
    print("* CUDA version: {}".format(torch.version.cuda))
    print("")
    print("------------ nvidia-smi topo -m -----------")
    print("")
    print(output[0])
    print("-------------------------------------------")
    print("PyTorch Distributed Benchmark (DDP and RPC)")
    print("-------------------------------------------")

    # Cmd arguments to enable automated runs (e.g. Chronos, SSH, etc).
    parser = argparse.ArgumentParser(description="PyTorch DDP and RPC Benchmark")
    parser.add_argument(
        "--master-addr", type=str, default="localhost", help="Address of master node."
    )
    parser.add_argument("--master-port", type=str, default="29500", help="Master port.")

    parser.add_argument(
        "--number-trainers",
        type=int,
        default=NUM_TRAINERS,
        help="Number of Trainer Nodes.",
    )
    parser.add_argument(
        "--number-ps", type=int, default=NUM_PS, help="Number of Parameter Servers."
    )
    parser.add_argument(
        "--number-embeddings",
        type=int,
        default=NUM_EMBEDDINGS,
        help="Number of test embeddings to be generated.",
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=EMBEDDING_DIM,
        help="Number of embedding dimensions.",
    )
    parser.add_argument(
        "--warmup-cycles",
        type=int,
        default=WARMUP_CYCLES,
        help="Number of cycles to warm-up each process before running the benchmark.",
    )

    args = parser.parse_args()

    if args.number_trainers < 1:
        parser.error("--number-trainers must be at least 1")
    # HybridModel folds the concatenated PS outputs into rows of 512
    # features, so their combined width must be a positive multiple of 512.
    combined_dim = args.number_ps * args.embedding_dim
    if combined_dim < 512 or combined_dim % 512:
        parser.error(
            "--number-ps * --embedding-dim must be a positive multiple of 512 "
            "(got {})".format(combined_dim)
        )

    os.environ["MASTER_ADDR"] = args.master_addr
    os.environ["MASTER_PORT"] = args.master_port

    # Forwarded to every spawned worker; see _configure_and_run_worker.
    config = {
        "NUM_TRAINERS": args.number_trainers,
        "NUM_PS": args.number_ps,
        "NUM_EMBEDDINGS": args.number_embeddings,
        "EMBEDDING_DIM": args.embedding_dim,
        "WARMUP_CYCLES": args.warmup_cycles,
    }

    # Defaults:
    #  8 trainers (rank 0-7),
    #  8 parameter servers (rank 8-15),
    #  1 master (rank 16).
    world_size = config["NUM_TRAINERS"] + config["NUM_PS"] + 1  # Trainers + PS + Master
    mp.spawn(
        _configure_and_run_worker,
        args=(world_size, config),
        nprocs=world_size,
        join=True,
    )