1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
|
# Owner(s): ["oncall: package/deploy"]
from io import BytesIO
import torch
from torch.fx import Graph, GraphModule, symbolic_trace
from torch.package import (
ObjMismatchError,
PackageExporter,
PackageImporter,
sys_importer,
)
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase
class TestPackageFX(PackageTestCase):
    """Tests for compatibility with FX."""

    def test_package_fx_simple(self):
        """A symbolically traced module survives a package save/load round trip."""

        class SimpleTest(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x + 3.0)

        st = SimpleTest()
        traced = symbolic_trace(st)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.save_pickle("model", "model.pkl", traced)

        # Rewind so the importer reads the archive from its beginning.
        f.seek(0)
        pi = PackageImporter(f)
        loaded_traced = pi.load_pickle("model", "model.pkl")

        inp = torch.rand(2, 3)
        self.assertEqual(loaded_traced(inp), traced(inp))

    def test_package_then_fx(self):
        """A module loaded from a package can be symbolically traced afterwards."""
        from package_a.test_module import SimpleTest

        model = SimpleTest()
        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", model)

        f.seek(0)
        pi = PackageImporter(f)
        loaded = pi.load_pickle("model", "model.pkl")
        traced = symbolic_trace(loaded)

        inp = torch.rand(2, 3)
        self.assertEqual(loaded(inp), traced(inp))

    def test_package_fx_package(self):
        """A trace of a packaged module can be re-packaged only with its importer.

        Exporting the traced module with the default importer is expected to
        raise ``ObjMismatchError``; exporting with ``(pi, sys_importer)`` is
        expected to succeed and produce a module with matching output.
        """
        from package_a.test_module import SimpleTest

        model = SimpleTest()
        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", model)

        f.seek(0)
        pi = PackageImporter(f)
        loaded = pi.load_pickle("model", "model.pkl")
        traced = symbolic_trace(loaded)

        # Re-save the traced module into a second package.
        f2 = BytesIO()
        # This should fail, because we are referencing some globals that are
        # only in the package.
        with self.assertRaises(ObjMismatchError):
            with PackageExporter(f2) as pe:
                pe.intern("**")
                pe.save_pickle("model", "model.pkl", traced)

        # Reuse the same buffer for the second attempt, starting at offset 0.
        f2.seek(0)
        with PackageExporter(f2, importer=(pi, sys_importer)) as pe:
            # Make the package available to the exporter's environment.
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", traced)

        f2.seek(0)
        pi2 = PackageImporter(f2)
        loaded2 = pi2.load_pickle("model", "model.pkl")

        inp = torch.rand(2, 3)
        self.assertEqual(loaded(inp), loaded2(inp))

    def test_package_fx_with_imports(self):
        """A hand-built graph that calls an imported leaf function round-trips."""
        import package_a.subpackage

        # Manually construct a graph that invokes a leaf function
        graph = Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)

        f.seek(0)
        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")

        input_x = torch.rand(2, 3)
        input_y = torch.rand(2, 3)
        self.assertTrue(
            torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y))
        )

        # Check that the packaged version of the leaf_function dependency is
        # not the same as in the outer env.
        packaged_dependency = pi.import_module("package_a.subpackage")
        self.assertIsNot(packaged_dependency, package_a.subpackage)

    def test_package_fx_custom_tracer(self):
        """A GraphModule subclass built with a custom tracer round-trips.

        Checks the tracer class recorded on the graph and module before
        packaging, then the class name, the extra ``info`` attribute and the
        forward output of the module loaded from the package.
        """
        from package_a.test_all_leaf_modules_tracer import TestAllLeafModulesTracer
        from package_a.test_module import ModWithTwoSubmodsAndTensor, SimpleTest

        class SpecialGraphModule(torch.fx.GraphModule):
            def __init__(self, root, graph, info):
                super().__init__(root, graph)
                self.info = info

        sub_module = SimpleTest()
        module = ModWithTwoSubmodsAndTensor(
            torch.ones(3),
            sub_module,
            sub_module,
        )
        tracer = TestAllLeafModulesTracer()
        graph = tracer.trace(module)

        self.assertEqual(graph._tracer_cls, TestAllLeafModulesTracer)

        gm = SpecialGraphModule(module, graph, "secret")
        self.assertEqual(gm._tracer_cls, TestAllLeafModulesTracer)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)

        f.seek(0)
        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")

        # Compare the names of the instance classes themselves. Going through
        # `type(x).__class__` would yield the metaclass, whose name is the same
        # for any two ordinary classes, so such a check could never fail.
        # NOTE(review): compared against type(gm) rather than SpecialGraphModule
        # on the assumption that fx may give each instance a class whose name
        # differs from the user-defined subclass -- confirm.
        self.assertEqual(type(loaded_gm).__name__, type(gm).__name__)
        self.assertEqual(loaded_gm.info, "secret")

        input_x = torch.randn(3)
        self.assertEqual(loaded_gm(input_x), gm(input_x))
# Run this file's tests when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()
|