File: __init__.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (70 lines) | stat: -rw-r--r-- 2,268 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Dict, Union

import torch
import torch.utils._pytree as pytree
from torch.export.exported_program import ExportedProgram


# Public API of this module: only the device-moving pass is exported.
__all__ = ["move_to_device_pass"]


def move_to_device_pass(
    ep: ExportedProgram, location: Union[torch.device, str, Dict[str, str]]
) -> ExportedProgram:
    """
    Move the exported program to the given device.

    Args:
        ep (ExportedProgram): The exported program to move.
        location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to.
            If a string, it is interpreted as a device name.
            If a dict, it is interpreted as a mapping from
            the existing device to the intended one

    Returns:
        ExportedProgram: The moved exported program.
    """

    def _get_new_device(
        curr_device: torch.device,
        location: Union[torch.device, str, Dict[str, str]],
    ) -> str:
        if isinstance(location, dict):
            if str(curr_device) in location.keys():
                return location[str(curr_device)]
            else:
                return str(curr_device)
        else:
            return str(location)

    # move all the state_dict
    for k, v in ep.state_dict.items():
        if isinstance(v, torch.nn.Parameter):
            ep._state_dict[k] = torch.nn.Parameter(
                v.to(_get_new_device(v.device, location)),
                v.requires_grad,
            )
        else:
            ep._state_dict[k] = v.to(_get_new_device(v.device, location))

    # move all the constants
    for k, v in ep.constants.items():
        if isinstance(v, torch.Tensor):
            ep._constants[k] = v.to(_get_new_device(v.device, location))

    for node in ep.graph.nodes:
        # move all the nodes kwargs with burnt-in device
        if "device" in node.kwargs:
            kwargs = node.kwargs.copy()
            kwargs["device"] = _get_new_device(kwargs["device"], location)
            node.kwargs = kwargs
        # move all the tensor metadata
        node.meta["val"] = pytree.tree_map(
            lambda v: v.to(_get_new_device(v.device, location))
            if isinstance(v, torch.Tensor)
            else v,
            node.meta.get("val"),
        )

    ep.validate()
    return ep