1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
# Owner(s): ["oncall: package/deploy"]
from torch.package._digraph import DiGraph
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase
class TestDiGraph(PackageTestCase):
    """Test the DiGraph structure we use to represent dependencies in PackageExporter.

    Assertions use the specific unittest helpers (assertEqual, assertIn,
    assertNotIn, assertCountEqual) rather than assertTrue/assertFalse around a
    comparison, so a failure reports the actual and expected values.
    """

    def test_successors(self):
        """successors() yields exactly the direct out-neighbors of a node."""
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        g.add_node("qux")

        self.assertEqual(set(g.successors("foo")), {"bar", "baz"})
        # A node added without edges has no successors.
        self.assertEqual(list(g.successors("qux")), [])

    def test_predecessors(self):
        """predecessors() yields exactly the direct in-neighbors of a node."""
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        g.add_node("qux")

        self.assertEqual(set(g.predecessors("bar")), {"foo"})
        self.assertEqual(set(g.predecessors("baz")), {"foo"})
        # A node added without edges has no predecessors.
        self.assertEqual(list(g.predecessors("qux")), [])

    def test_successor_not_in_graph(self):
        """successors() raises ValueError for a node that was never added."""
        g = DiGraph()
        with self.assertRaises(ValueError):
            g.successors("not in graph")

    def test_predecessor_not_in_graph(self):
        """predecessors() raises ValueError for a node that was never added."""
        g = DiGraph()
        with self.assertRaises(ValueError):
            g.predecessors("not in graph")

    def test_node_attrs(self):
        """Keyword arguments to add_node() are stored as node attributes."""
        g = DiGraph()
        g.add_node("foo", my_attr=1, other_attr=2)
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)
        self.assertEqual(g.nodes["foo"]["other_attr"], 2)

    def test_node_attr_update(self):
        """Re-adding an existing node overwrites the attributes it is given."""
        g = DiGraph()
        g.add_node("foo", my_attr=1)
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)

        g.add_node("foo", my_attr="different")
        self.assertEqual(g.nodes["foo"]["my_attr"], "different")

    def test_edges(self):
        """edges lists every added edge exactly once, in any order."""
        g = DiGraph()
        g.add_edge(1, 2)
        g.add_edge(2, 3)
        g.add_edge(1, 3)
        g.add_edge(4, 5)

        # assertCountEqual checks both the count and the membership of each
        # edge, independent of iteration order.
        self.assertCountEqual(list(g.edges), [(1, 2), (2, 3), (1, 3), (4, 5)])

    def test_iter(self):
        """Iterating the graph yields every node."""
        g = DiGraph()
        g.add_node(1)
        g.add_node(2)
        g.add_node(3)

        # set(g) drives DiGraph.__iter__.
        self.assertEqual(set(g), {1, 2, 3})

    def test_contains(self):
        """`in` reports membership of nodes."""
        g = DiGraph()
        g.add_node("yup")
        self.assertIn("yup", g)
        self.assertNotIn("nup", g)

    def test_contains_non_hashable(self):
        """`in` with an unhashable value reports False rather than raising."""
        g = DiGraph()
        self.assertNotIn([1, 2, 3], g)

    def test_forward_closure(self):
        """forward_transitive_closure() includes the start node and everything reachable from it."""
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("2", "3")
        g.add_edge("5", "4")
        g.add_edge("4", "3")

        self.assertEqual(g.forward_transitive_closure("1"), {"1", "2", "3"})
        self.assertEqual(g.forward_transitive_closure("4"), {"4", "3"})

    def test_all_paths(self):
        """all_paths() emits exactly the edges lying on some path from src to dst."""
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("1", "7")
        g.add_edge("7", "8")
        g.add_edge("8", "3")
        g.add_edge("2", "3")
        g.add_edge("5", "4")
        g.add_edge("4", "3")

        result = g.all_paths("1", "3")

        # The result is a ';'-separated DOT-style string. Keep only the edge
        # statements (those containing "->") rather than slicing by position,
        # so the check does not depend on how many header statements precede
        # the edges. A set removes any dependence on edge ordering.
        actual = {stmt.strip() for stmt in result.split(";") if "->" in stmt}
        expected = {
            '"2" -> "3"',
            '"1" -> "7"',
            '"7" -> "8"',
            '"1" -> "2"',
            '"8" -> "3"',
        }
        self.assertEqual(actual, expected)
if __name__ == "__main__":
    # When this file is executed as a script (rather than collected by a test
    # runner), hand off to torch's shared test entry point, run_tests.
    run_tests()
|