1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
|
# Owner(s): ["module: fx"]
import unittest
from typing import Mapping
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.testing._internal.common_utils import TestCase
class DummyDevOperatorSupport(OperatorSupport):
    """Operator-support policy that accepts every node unconditionally.

    Because nothing is ever rejected, every node in the traced graph is
    eligible for partitioning.
    """

    def is_node_supported(
        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        # Neither the submodule table nor the node itself is inspected.
        del submodules, node
        return True
class DummyPartitioner(CapabilityBasedPartitioner):
    """Capability-based partitioner driven by ``DummyDevOperatorSupport``.

    Single-node partitions are permitted, so a partition is proposed even
    when it would contain only one node.
    """

    def __init__(self, graph_module: torch.fx.GraphModule):
        accept_all = DummyDevOperatorSupport()
        super().__init__(graph_module, accept_all, allows_single_node_partition=True)
class AddModule(torch.nn.Module):
    """Toy module computing ``(x + x) + x`` via two ``torch.add`` calls."""

    def forward(self, x):
        # The outer add consumes the inner add's result, so the ops are chained.
        return torch.add(torch.add(x, x), x)
class TestPartitionerOrder(TestCase):
    """Checks that partitioning a freshly traced module yields a stable node order."""

    @staticmethod
    def _partition_node_names(module):
        """Trace ``module`` afresh and partition it with ``DummyPartitioner``.

        Returns a list with one entry per proposed partition, each entry being
        the names of that partition's nodes in iteration order. A new trace is
        made on every call, so comparing results across calls detects any
        ordering that is not reproducible between independent traces.
        """
        traced = torch.fx.symbolic_trace(module)
        partitions = DummyPartitioner(traced).propose_partitions()
        return [[node.name for node in partition.nodes] for partition in partitions]

    # partitioner test to check graph node order
    def test_partitioner_order(self):
        m = AddModule()
        node_order = self._partition_node_names(m)
        # Every node is supported, so at least one partition must be proposed;
        # fail with a clear message rather than an IndexError if not.
        self.assertTrue(node_order, msg="partitioner proposed no partitions")
        for _ in range(10):
            # assertEqual (unlike assertTrue(a == b)) reports both orders on failure.
            self.assertEqual(node_order, self._partition_node_names(m))
if __name__ == "__main__":
    # Allow running this test file directly with the standard unittest runner.
    unittest.main()
|