1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
# Owner(s): ["oncall: jit"]
import os
import sys
from typing import Any, Dict, List
import torch
from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable. The directory added is the
# parent of the directory that contains this file (symlinks resolved first).
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)

# This module is meant to be driven through test/test_jit.py, so running it
# directly as a script fails loudly with instructions instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
class TestModuleAPIs(JitTestCase):
    """State-dict API tests for modules compiled with torch.jit.script."""

    def test_default_state_dict_methods(self):
        """Tests that default state dict methods are automatically available"""

        class DefaultStateDictModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(6, 16, 5)
                self.fc = torch.nn.Linear(16 * 5 * 5, 120)

            # Only compiled by torch.jit.script here; never executed.
            def forward(self, x):
                x = self.conv(x)
                x = self.fc(x)
                return x

        m1 = torch.jit.script(DefaultStateDictModule())
        m2 = torch.jit.script(DefaultStateDictModule())
        state_dict = m1.state_dict()
        m2.load_state_dict(state_dict)

        # Check that the round trip actually transferred m1's values into m2,
        # not merely that the two calls completed without raising.
        loaded = m2.state_dict()
        self.assertEqual(set(loaded), set(state_dict))
        for name, tensor in state_dict.items():
            self.assertEqual(loaded[name], tensor)

    def test_customized_state_dict_methods(self):
        """Tests that customized state dict methods are in effect"""

        class CustomStateDictModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(6, 16, 5)
                self.fc = torch.nn.Linear(16 * 5 * 5, 120)
                # Flipped by the overrides below so the test can observe
                # whether the scripted overrides were invoked.
                self.customized_save_state_dict_called: bool = False
                self.customized_load_state_dict_called: bool = False

            def forward(self, x):
                x = self.conv(x)
                x = self.fc(x)
                return x

            # Exported so TorchScript compiles this method even though
            # forward() never calls it.
            @torch.jit.export
            def _save_to_state_dict(
                self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool
            ):
                # Only records the call; nothing is written into `destination`,
                # and the test never inspects the returned dict.
                self.customized_save_state_dict_called = True
                return {"dummy": torch.ones(1)}

            @torch.jit.export
            def _load_from_state_dict(
                self,
                state_dict: Dict[str, torch.Tensor],
                prefix: str,
                local_metadata: Any,
                strict: bool,
                missing_keys: List[str],
                unexpected_keys: List[str],
                error_msgs: List[str],
            ):
                # Only records the call; the incoming entries are ignored.
                self.customized_load_state_dict_called = True
                return

        m1 = torch.jit.script(CustomStateDictModule())
        self.assertFalse(m1.customized_save_state_dict_called)
        state_dict = m1.state_dict()
        self.assertTrue(m1.customized_save_state_dict_called)

        m2 = torch.jit.script(CustomStateDictModule())
        self.assertFalse(m2.customized_load_state_dict_called)
        m2.load_state_dict(state_dict)
        self.assertTrue(m2.customized_load_state_dict_called)

    def test_submodule_customized_state_dict_methods(self):
        """Tests that customized state dict methods on submodules are in effect"""

        class CustomStateDictModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(6, 16, 5)
                self.fc = torch.nn.Linear(16 * 5 * 5, 120)
                # Flipped by the overrides below so the test can observe
                # whether the scripted overrides were invoked.
                self.customized_save_state_dict_called: bool = False
                self.customized_load_state_dict_called: bool = False

            def forward(self, x):
                x = self.conv(x)
                x = self.fc(x)
                return x

            # Exported so TorchScript compiles this method even though
            # forward() never calls it.
            @torch.jit.export
            def _save_to_state_dict(
                self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool
            ):
                # Only records the call; nothing is written into `destination`,
                # and the test never inspects the returned dict.
                self.customized_save_state_dict_called = True
                return {"dummy": torch.ones(1)}

            @torch.jit.export
            def _load_from_state_dict(
                self,
                state_dict: Dict[str, torch.Tensor],
                prefix: str,
                local_metadata: Any,
                strict: bool,
                missing_keys: List[str],
                unexpected_keys: List[str],
                error_msgs: List[str],
            ):
                # Only records the call; the incoming entries are ignored.
                self.customized_load_state_dict_called = True
                return

        class ParentModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = CustomStateDictModule()

            def forward(self, x):
                return self.sub(x)

        # The overrides live on the child module, so the flags are read
        # through `.sub` while the API calls are made on the parent.
        m1 = torch.jit.script(ParentModule())
        self.assertFalse(m1.sub.customized_save_state_dict_called)
        state_dict = m1.state_dict()
        self.assertTrue(m1.sub.customized_save_state_dict_called)

        m2 = torch.jit.script(ParentModule())
        self.assertFalse(m2.sub.customized_load_state_dict_called)
        m2.load_state_dict(state_dict)
        self.assertTrue(m2.sub.customized_load_state_dict_called)
|