#!/usr/bin/python3
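"""Instantiates remote module templates.

A template is rendered into Python source for a given module interface,
written to a temporary directory that is on sys.path, and then imported
as a regular module.
"""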
import importlib
import logging
import os
import sys
import tempfile
from typing import Optional
import torch
from torch.distributed.nn.jit.templates.remote_module_template import (
    get_remote_module_template,
)

logger = logging.getLogger(__name__)

_FILE_PREFIX = "_remote_module_"

# Generated template instantiations are written to a per-process temporary
# directory that is appended to sys.path, so they can be imported like any
# other module.
_TEMP_DIR = tempfile.TemporaryDirectory()
INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name
logger.info("Created a temporary directory at %s", INSTANTIATED_TEMPLATE_DIR_PATH)
sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)


def get_arg_return_types_from_interface(module_interface):
    """Parse the ``forward`` method of a ``@torch.jit.interface`` class and
    return ``(args_str, arg_types_str, return_type_str)`` for template
    substitution."""
    assert getattr(
        module_interface, "__torch_script_interface__", False
    ), "Expected a TorchScript class interface decorated with @torch.jit.interface."
    qualified_name = torch._jit_internal._qualified_name(module_interface)
    cu = torch.jit._state._python_cu
    module_interface_c = cu.get_interface(qualified_name)
    assert (
        "forward" in module_interface_c.getMethodNames()
    ), "Expected 'forward' among the interface methods, but found {}".format(
        module_interface_c.getMethodNames()
    )
    method_schema = module_interface_c.getMethod("forward")

    arg_str_list = []
    arg_type_str_list = []
    assert method_schema is not None
    for argument in method_schema.arguments:
        arg_str_list.append(argument.name)

        if argument.has_default_value():
            default_value_str = " = {}".format(argument.default_value)
        else:
            default_value_str = ""
        arg_type_str = "{name}: {type}{default_value}".format(
            name=argument.name, type=argument.type, default_value=default_value_str
        )
        arg_type_str_list.append(arg_type_str)

    arg_str_list = arg_str_list[1:]  # Remove "self".
    args_str = ", ".join(arg_str_list)

    arg_type_str_list = arg_type_str_list[1:]  # Remove "self".
    arg_types_str = ", ".join(arg_type_str_list)

    assert len(method_schema.returns) == 1
    argument = method_schema.returns[0]
    return_type_str = str(argument.type)

    return args_str, arg_types_str, return_type_str
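
# For illustration only, given a hypothetical interface (not part of this module):
#
#     @torch.jit.interface
#     class MyModuleInterface:
#         def forward(self, tensor: Tensor, scale: float = 2.0) -> Tensor:
#             ...
#
# get_arg_return_types_from_interface(MyModuleInterface) would return roughly:
#
#     ("tensor, scale", "tensor: Tensor, scale: float = 2.0", "Tensor")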


def _write(out_path, text):
    """Write ``text`` to ``out_path``, skipping the write when the file already
    holds identical content, to avoid unnecessary rewrites."""
    old_text: Optional[str]
    try:
        with open(out_path, "r") as f:
            old_text = f.read()
    except OSError:
        old_text = None
    if old_text != text:
        with open(out_path, "w") as f:
            logger.info("Writing %s", out_path)
            f.write(text)
    else:
        logger.info("Skipped writing %s", out_path)


def _do_instantiate_remote_module_template(
    generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
):
    """Render the remote module template with ``str_dict``, write it to the
    temporary template directory, and import the resulting module."""
    generated_code_text = get_remote_module_template(
        enable_moving_cpu_tensors_to_cuda
    ).format(**str_dict)
    out_path = os.path.join(
        INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py"
    )
    _write(out_path, generated_code_text)

    # From the importlib docs:
    # > If you are dynamically importing a module that was created since
    # > the interpreter began execution (e.g., created a Python source file),
    # > you may need to call invalidate_caches() in order for the new module
    # > to be noticed by the import system.
    importlib.invalidate_caches()
    generated_module = importlib.import_module(generated_module_name)
    return generated_module
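
# For example, with generated_module_name == "_remote_module_non_scriptable" (see
# below), the rendered source is written to
# <temp dir>/_remote_module_non_scriptable.py and then imported under that same
# module name.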


def instantiate_scriptable_remote_module_template(
    module_interface_cls, enable_moving_cpu_tensors_to_cuda=True
):
    """Instantiate a TorchScript-compatible remote module template for the given
    ``@torch.jit.interface`` class and return the generated module."""
    if not getattr(module_interface_cls, "__torch_script_interface__", False):
        raise ValueError(
            f"module_interface_cls {module_interface_cls} must be a type object decorated with "
            "@torch.jit.interface"
        )

    # Generate the template instance name.
    module_interface_cls_name = torch._jit_internal._qualified_name(
        module_interface_cls
    ).replace(".", "_")
    generated_module_name = f"{_FILE_PREFIX}{module_interface_cls_name}"

    # Generate the type annotation strings.
    assign_module_interface_cls_str = (
        f"from {module_interface_cls.__module__} import "
        f"{module_interface_cls.__name__} as module_interface_cls"
    )
    args_str, arg_types_str, return_type_str = get_arg_return_types_from_interface(
        module_interface_cls
    )
    kwargs_str = ""
    arrow_and_return_type_str = f" -> {return_type_str}"
    arrow_and_future_return_type_str = f" -> Future[{return_type_str}]"

    str_dict = dict(
        assign_module_interface_cls=assign_module_interface_cls_str,
        arg_types=arg_types_str,
        arrow_and_return_type=arrow_and_return_type_str,
        arrow_and_future_return_type=arrow_and_future_return_type_str,
        args=args_str,
        kwargs=kwargs_str,
        jit_script_decorator="@torch.jit.script",
    )
    return _do_instantiate_remote_module_template(
        generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
    )
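
# A minimal usage sketch (MyModuleInterface is illustrative, not part of this
# module):
#
#     @torch.jit.interface
#     class MyModuleInterface:
#         def forward(self, tensor: Tensor) -> Tensor:
#             ...
#
#     generated = instantiate_scriptable_remote_module_template(MyModuleInterface)
#
# `generated` is the freshly imported module rendered from the remote module
# template, with its forward-related signatures typed to match the interface.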


def instantiate_non_scriptable_remote_module_template():
    """Instantiate the generic, non-TorchScript remote module template and
    return the generated module."""
    generated_module_name = f"{_FILE_PREFIX}non_scriptable"
    str_dict = dict(
        assign_module_interface_cls="module_interface_cls = None",
        args="*args",
        kwargs="**kwargs",
        arg_types="*args, **kwargs",
        arrow_and_return_type="",
        arrow_and_future_return_type="",
        jit_script_decorator="",
    )
    # For a non-scriptable template, always enable moving CPU tensors to a CUDA
    # device: since the generated code is not scripted, the extra device-handling
    # logic is not subject to TorchScript's syntax limitations.
    return _do_instantiate_remote_module_template(generated_module_name, str_dict, True)