File: utils.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (34 lines) | stat: -rw-r--r-- 1,147 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import List, Optional

from torch import nn

def partition_model(
        module: nn.Sequential,
        balance: List[int],
        devices: Optional[List[int]] = None):
    """
    Given an :class:`nn.Sequential <torch.nn.Sequential>` module, partitions
    the model across multiple GPU devices according to the provided
    ``balance`` and ``devices``.

    Layers are consumed from ``module`` in order: the first ``balance[0]``
    layers form the first partition, the next ``balance[1]`` layers the
    second, and so on. Each partition is wrapped in its own
    ``nn.Sequential`` and moved with ``.to(device)``. The partitions hold the
    same layer objects as ``module`` (no copies are made). Layers beyond
    ``sum(balance)`` are not included in the result.

    Args:
        module (:class:`nn.Sequential <torch.nn.Sequential>`):
            Sequential model representing the pipe.
        balance (List[int]):
            List indicating the number of layers in each partition.
        devices (List[int], optional):
            List indicating the device to use for each partition. Each entry
            is passed unchanged to ``.to()``. Defaults to
            ``range(len(balance))``

    Returns:
        :class:`nn.Sequential <torch.nn.Sequential>` whose children are the
        per-partition ``nn.Sequential`` modules, in ``balance`` order.

    Raises:
        IndexError: If ``balance`` requests more layers than ``module``
            contains, or if ``devices`` has fewer entries than ``balance``.
            Both are checked before any layer is moved, so a failed call
            leaves every layer on its current device.
    """
    layers = list(module)
    num_partitions = len(balance)

    # Validate up front: ``.to()`` moves layers in place, so failing halfway
    # through the loop would leave the model scattered across devices.
    if devices is not None and len(devices) < num_partitions:
        raise IndexError(
            f"devices has {len(devices)} entries but balance defines "
            f"{num_partitions} partitions")

    # A non-positive entry yields an empty partition, as ``range()`` would.
    sizes = [max(num_layers, 0) for num_layers in balance]
    requested = sum(sizes)
    if requested > len(layers):
        raise IndexError(
            f"balance requests {requested} layers but module only has "
            f"{len(layers)}")

    balanced_pipe = []
    start = 0
    for partition_idx, size in enumerate(sizes):
        device = partition_idx if devices is None else devices[partition_idx]
        stop = start + size
        # Slice the plain list (not the Sequential) so the child names inside
        # each partition are renumbered from 0.
        balanced_pipe.append(nn.Sequential(*layers[start:stop]).to(device))
        start = stop

    return nn.Sequential(*balanced_pipe)