# mypy: allow-untyped-defs
import collections
import contextlib
import copy
import functools
import itertools
import logging
import operator
import re
import sys
import traceback
import weakref
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy

import torch._guards
import torch._logging
import torch.distributed as dist
import torch.nn
import torch.utils._pytree as pytree
from torch import fx
from torch._dynamo.exc import TensorifyScalarRestartAnalysis
from torch._guards import (
    CompileContext,
    CompileId,
    GlobalContextCheckpointState,
    Source,
    TracingContext,
)
from torch._utils_internal import signpost_event
from torch.fx._lazy_graph_module import _make_graph_module  # type: ignore[attr-defined]
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from . import config, exc, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import (
    create_call_function,
    create_instruction,
    create_load_const,
    Instruction,
    unique_id,
)
from .code_context import code_context
from .codegen import PyCodegen
from .current_scope_id import enter_new_scope
from .exc import (
    BackendCompilerFailed,
    exceptions_allowed_to_be_fallback,
    SkipFrame,
    unimplemented,
    unimplemented_with_warning,
)
from .graph_deduplication import apply_graph_deduplication
from .graph_region_tracker import GraphRegionTracker
from .guards import GuardBuilder, install_guard
from .mutation_guard import is_dynamic_nn_module
from .side_effects import AttributeMutationExisting, SideEffects
from .source import (
    AttrSource,
    BackwardStateSource,
    ConstantSource,
    GetItemSource,
    GlobalStateSource,
    is_constant_source,
    is_from_local_source,
    LocalSource,
    ParamBufferSource,
    ShapeEnvSource,
    SyntheticLocalSource,
    TensorProperty,
    TensorPropertySource,
)
from .utils import (
    _extract_tensor_dict,
    checkpoint_params,
    CleanupHook,
    clone_inputs,
    count_calls,
    counters,
    dynamo_timed,
    get_instruction_source_311,
    get_locals_to_steal,
    get_static_address_type,
    graph_break_reasons,
    increment_op_count,
    lazy_format_graph_code,
    LazyString,
    nn_module_proxy,
    same,
    set_example_value,
)
from .variables.base import VariableTracker
from .variables.builder import (
    BackwardStateGraphArg,
    GraphArg,
    TrackedFake,
    wrap_fx_proxy,
)
from .variables.lists import BaseListVariable
from .variables.misc import CellVariable, NullVariable
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslatorBase


log = logging.getLogger(__name__)
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")


@dataclass(frozen=True)
class VariableTrackerCacheKey:
    vt_id: int
    # Two different sources can point to the same object. However, Dynamo handles
    # global and local sources differently when it comes to guards and possibly
    # some other parts as well. So, the cache also relies on the source.
    source: Source


class VariableTrackerCache:
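    """
    Cache from (id(value), source) to the VariableTracker previously built for
    that object, so repeated symbolic analysis of the same Python object (e.g.
    LOAD_GLOBAL / LOAD_ATTR of the same global) can reuse one tracker instead
    of rebuilding it.
    """
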
    def __init__(self):
        self.cache = {}

    def lookup(self, value, source):
        key = VariableTrackerCacheKey(id(value), source)
        if key not in self.cache:
            return None
        return self.cache[key]

    def add(self, value, source, vt):
        key = VariableTrackerCacheKey(id(value), source)
        self.cache[key] = vt

    def clone(self):
        # Needed for copying and restoring graph state
        new_cache = VariableTrackerCache()
        new_cache.cache.update(self.cache)
        return new_cache

    def clear(self):
        self.cache.clear()


@functools.lru_cache(None)
def _step_logger():
    return torchdynamo_logging.get_step_logger(log)


@dataclass
class GraphCompileReason:
    """Stores why a given output graph was compiled; i.e. what caused the graph break."""

    reason: str
    user_stack: List[traceback.FrameSummary]

    # Indicates if this was a graph compile reason due to graph break.
    graph_break: bool = True

    def __post_init__(self):
        if self.graph_break:
            graph_break_reasons.append(self)


def _get_gen_rand_values_fn(random_calls):
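    # random_calls is a list of (fn, args, kwargs) triples recorded during
    # tracing (see OutputGraph.random_calls); the returned closure replays
    # them all and returns the values in order. An illustrative entry would
    # be (random.random, (), {}).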
    def _gen_rand_values():
        return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]

    return _gen_rand_values


class FakeRootModule(torch.nn.Module):
    """Trick the constructor of fx.GraphModule"""

    def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
        super().__init__()
        for k, v in nn_modules.items():
            setattr(self, k, v)

    def __repr__(self) -> str:
        return "FakeRootModule(...)"


class WrapperBackend:
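    """
    Wraps a user-supplied backend compiler_fn. When config.verify_correctness
    is set, the compiled candidate is run on cloned example inputs and its
    results are compared against eager execution of the original GraphModule
    before the candidate is returned.
    """
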
    def __init__(self, backend: CompilerFn):
        self.backend: CompilerFn = backend

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.restore = checkpoint_params(gm)
        self.gm = gm
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, example_inputs)

        if self.candidate is None or self.candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*clone_inputs(example_inputs))
            result = self.candidate(*clone_inputs(example_inputs))

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            raise RuntimeError(f"incorrect results of backend {self}")
            return self.gm.forward

        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            self.restore()


Scope = Dict[str, object]


class OutputGraph:
    """
    Wrapper class to hold outputs of InstructionTranslator.  Mainly the
    generated fx.Graph.

    OutputGraph is 1:1 with a frame being processed. Each frame is associated
    with some root InstructionTranslator. When user code calls a function,
    we construct an InliningInstructionTranslator that continues to write into
    the root InstructionTranslator's OutputGraph.
    """

    def __init__(
        self,
        code_options: Dict[str, Any],
        compiler_fn: Optional[CompilerFn],
        root_tx,
        export: bool,
        export_constraints,
        frame_state,
        local_scope: Scope,
        global_scope: Scope,
        f_code,
        torch_function_mode_stack,
    ):
        super().__init__()
        self.tracers = [SubgraphTracer(self, is_export=export)]
        # Map from graph input's `Source` to its `VariableTracker` to
        # de-duplicate graph inputs by source and reuse the tracker
        self.input_source_to_var: Dict[Source, VariableTracker] = {}
        self.export = export
        self.export_constraints = export_constraints
        self.frame_state = frame_state
        # Map from graph input's `Source` to sizes / strides metadata
        self.input_source_to_sizes_strides: Dict[Source, Dict[str, Any]] = {}
        self.cleanup_hooks: List[Callable[[], Any]] = []
        # compile_id is an id number for the current torch.compile
        self.compile_id: int = next(_compile_id_counter)
        # Set of globals installed via install_global* APIs
        self.installed_globals: Set[str] = set()

        # TODO: maybe should just pass the entire f_code in here?  Not
        # sure...
        self.co_fields = {
            "co_name": f_code.co_name,
            "co_filename": f_code.co_filename,
            "co_firstlineno": f_code.co_firstlineno,
        }

        self.region_tracker = GraphRegionTracker()

        # tracked_fakes says where any tensor that was wrapped to fake came
        # from.  It is similar to GraphArg, in that all GraphArgs will get
        # added to TrackedFakes, but TrackedFakes also contains GraphArgs that
        # got pruned, and things like Tensor attributes which aren't explicit
        # graph inputs.  Used by shape guards.
        self.tracked_fakes: List[TrackedFake] = []

        shape_env = ShapeEnv(
            # Reference Cycle!
            # Share a reference to the list of TrackedFake.
            #
            # ShapeEnv needs this in order to be able to reproduce the call
            # to produce_guards at an arbitrary time point. That is because
            # TrackedFake instances may have their metadata changed throughout
            # the program execution.
            tracked_fakes=self.tracked_fakes,
            allow_scalar_outputs=config.capture_scalar_outputs,
            allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
            prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
            allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts,
            co_fields=self.co_fields,
        )

        # In export mode, we force the shape_env to strictly disallow any constraining
        # of the user-marked dynamic dims
        import torch._functorch.config as _config

        with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
            fake_mode = torch._subclasses.FakeTensorMode(
                shape_env=shape_env,
                # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
                allow_non_fake_inputs=self.export,
                export=self.export,
            )
        self.tracing_context: TracingContext = TracingContext(fake_mode)
        self.dynamo_compile_id: Optional[
            CompileId
        ] = CompileContext.current_compile_id()
        self.init_ambient_guards()

        # Map each tensor id to a list of sources. This is necessary because
        # tensor ids cannot be recovered from tracked fakes (in general).
        # We use this map to interpret (i.e., check for violations of) constraints,
        # specifically equality constraints, which have shared tensor ids in them.
        # This map should also be generally useful, e.g., for (de)serialization.
        self.tracked_fakes_id_to_source: Dict[
            int, List[Source]
        ] = collections.defaultdict(list)
        # Stores the full fqn of a param or buffer to the relevant source.
        self.param_name_to_source: Optional[Dict[str, Source]] = {}
        self.side_effects = SideEffects(self)
        # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL
        # and LOAD_ATTR for the same Python objects free.
        self.variable_tracker_cache = VariableTrackerCache()
        self.unique_var_id = itertools.count()
        self.code_options = dict(code_options)
        self.output_instructions: List[Instruction] = []
        # used to track nodes that are added between calls of copy_graphstate
        # and restore_graphstate
        self.timestamp = 0

        # A list of register_finalizer_fns to apply to the output graph module
        self.register_finalizer_fns: List[Callable[[fx.GraphModule], None]] = []

        # Not checkpointed
        self.compiler_fn: Optional[CompilerFn] = compiler_fn
        self.global_scope = global_scope
        self.local_scope = local_scope
        self.root_tx = root_tx

        # Given a source, what are the user stacks of all locations that
        # accessed it?
        #
        # For efficiency, we only populate this:
        #   - During export, and
        #   - If the source could potentially lead to a spurious export input
        #
        # Feel free to populate this more frequently if other use-cases arise,
        # but be aware that we have to generate full stacks for each
        # recording!
        self.source_to_user_stacks: Dict[Source, List[traceback.StackSummary]] = {}

        self._current_tx: List[InstructionTranslatorBase] = []
        self.cleanups: List[CleanupHook] = []
        self.should_exit = False
        self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}

        # Note this returns true iff TF Mode and TF Subclasses are enabled
        self.torch_function_enabled = torch._C._is_torch_function_enabled()
        # This returns false if TF overall (both mode and subclass) is disabled OR the TF Mode stack is empty
        self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
        # This records the initial torch function mode stack for guarding
        self.torch_function_mode_stack = torch_function_mode_stack

        # Tracks if the output graph has a user defined allowed function in the
        # graph. This is used later to determine if we should fall back to eager
        # for certain exceptions. The idea is that if the user has applied
        # allow_in_graph, they would like to see the error instead of falling
        # back for backend errors.
        self.has_user_defined_allowed_in_graph = False

        # Tracks the set of called ops that were not tagged with "pt2_compliant_tag".
        # This information is useful for logging.
        self.non_compliant_ops: Set[torch._ops.OpOverload] = set()

        # Tracks the set of called custom ops that were tagged with "pt2_compliant_tag".
        # This information is useful for logging.
        self.compliant_custom_ops: Set[torch._ops.OpOverload] = set()

        # We save the global torch state here to be restored in case of graph
        # breaks. The relevant issue is seen here
        # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086
        # where inlining of a function changes the global state (because of the
        # presence of torch.no_grad) and there is a graph break.
        self.save_global_state()

        # Tracks the original FQNs of the constant tensors from the original graph,
        # i.e. buffers and parameters.
        self.dynamo_flat_name_to_original_fqn: Dict[str, str] = {}

        # All calls to random() are replaced with a single call to a __gen_rand_values
        # function that returns a tuple of random values for each original call.
        # random_calls tracks calls to random() and random_values_var stores the name of
        # the variable that stores the __gen_rand_values results.
        self.random_calls: List[
            Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
        ] = []
        self.random_values_var = None

        # Bytecode to insert right before we call the graph
        self.pregraph_bytecode: List[Instruction] = []

        # Used to pass values to backward hooks when using compiled autograd
        self.backward_state: Dict[str, VariableTracker] = {}
        self.backward_state_proxy: Optional[torch.fx.Proxy] = None
        self.backward_state_var: Optional[str] = None

        self.name_of_builtins_dict_key_in_fglobals: str = (
            self.install_builtins_dict_in_fglobals()
        )

        self.guard_on_key_order: Set[str] = set()

    def install_builtins_dict_in_fglobals(self):
        # f_globals["__builtins__"] can be a dict or a module. This is an
        # implementation detail -
        # https://docs.python.org/3/library/builtins.html.

        # This makes guarding on any builtin messy because the guard check_fn
        # has to check if the __builtins__ is a module or dict, and then access
        # by either using getattr or getitem respectively.

        # To solve this problem, we insert a new entry in f_globals which points
        # to the builtins __dict__ and then we guard any builtin on this dict.
        # To avoid any collision with the pre-existing keys, we use
        # install_global to give us a unique dict key.
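        # Illustrative sketch (the key name is hypothetical): after this, a
        # guard on a builtin is always a plain dict lookup, e.g. on
        # G['__builtins_dict___0']['print'], instead of branching on whether
        # __builtins__ is a module or a dict.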

        f_builtins = self.global_scope["__builtins__"]
        if not isinstance(f_builtins, dict):
            f_builtins = f_builtins.__dict__
        return self.install_global("__builtins_dict__", f_builtins)

    def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"):
        name = f"{prefix}{len(self.backward_state)}"
        assert name not in self.backward_state
        self.backward_state[name] = hook
        return name, self.get_backward_state_proxy()

    def get_backward_state_proxy(self):
        if self.backward_state_proxy is None:
            if self.export:
                unimplemented("backward_state does not support export")
            example_value = BackwardState()
            self.backward_state_proxy = self.root_tracer.create_graph_input(
                "dynamo_backward_state",
                type(example_value),
                example_value,
                source=BackwardStateSource(),
            )
            self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg()
            self.backward_state_var = self.new_var()
        return self.backward_state_proxy

    # This gets its own helper function so guards DEBUG logs are more informative
    def init_ambient_guards(self):
        # Register a SHAPE_ENV guard to make sure we set up shape guards
        # that show up in ShapeEnv
        self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

        self.guards.add(
            GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS)
        )

        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE))

        self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))

        self.guards.add(
            GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
        )

        ci = torch._C._functorch.peek_interpreter_stack()
        if ci is not None:
            self.guards.add(
                GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
            )

    def synthetic_graph_input(self, fn, args):
        """
        call fn(*args) before the graph runs and turn the result into a fake input.
        """
        example_value = fn(*args)
        varname = self.new_var()
        cg = PyCodegen(self.root_tx)
        cg.add_push_null(
            lambda: cg.load_import_from(
                fn.__module__,
                fn.__name__,
            )
        )
        cg.foreach(map(variables.ConstantVariable.create, args))
        cg.call_function(len(args), False)
        cg.store(varname)
        self.pregraph_bytecode.extend(cg.get_instructions())
        source = SyntheticLocalSource(varname)
        result = VariableTracker.build(self.root_tx, example_value, source)
        TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
            source
        )
        return result

    def add_cleanup_hook(self, fn: Callable[[], Any]):
        self.cleanup_hooks.append(fn)

    def call_cleanup_hooks(self):
        for hook in reversed(self.cleanup_hooks):
            hook()
        self.cleanup_hooks.clear()

    @property
    def root_tracer(self):
        return self.tracers[0]

    @property
    def current_tracer(self):
        return self.tracers[-1]

    def is_root_tracer(self):
        # Helper to tell if we are inside higher order operator tracing.
        return len(self.tracers) == 1

    @property
    def graph(self):
        return self.current_tracer.graph

    # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer.
    @graph.setter
    def graph(self, value):
        self.current_tracer.graph = value

    @property
    def input_name_to_proxy(self):
        return self.current_tracer.input_name_to_proxy

    @property
    def real_value_cache(self):
        return self.current_tracer.real_value_cache

    @property
    def bound_symbols(self):
        return self.current_tracer.bound_symbols

    # If you are here, and you're looking for create_graph_input,
    # to avoid ambiguity, please call one of the following:
    # - self.current_tracer.create_graph_input
    # - self.root_tracer.create_graph_input
    # See NOTE [HigherOrderOperator tracing design] for more context.

    def create_proxy(self, *args, **kwargs):
        return self.current_tracer.create_proxy(*args, **kwargs)

    def create_node(self, *args, **kwargs):
        return self.current_tracer.create_node(*args, **kwargs)

    def remove_node(self, *args, **kwargs):
        return self.current_tracer.remove_node(*args, **kwargs)

    @contextlib.contextmanager
    def subtracer(self, source_target, prior_tracer):
        new_scope_ctx = enter_new_scope()
        try:
            if prior_tracer:
                # Lineage MUST stay preserved
                assert prior_tracer.parent is self.current_tracer
            new_scope_ctx.__enter__()
            tracer = (
                prior_tracer
                if prior_tracer
                else SubgraphTracer(
                    self,
                    parent=self.current_tracer,
                    source_target=source_target,
                    is_export=self.current_tracer.is_export,
                )
            )
            self.tracers.append(tracer)
            yield tracer
        finally:
            new_scope_ctx.__exit__(None, None, None)
            self.tracers.pop()

    @property
    def output(self):
        return self

    @property
    def fake_mode(self):
        return self.tracing_context.fake_mode

    @property
    def shape_env(self):
        return self.tracing_context.fake_mode.shape_env

    @property
    def guards(self) -> torch._guards.GuardsSet:
        return self.tracing_context.guards_context.dynamo_guards

    @property
    def nn_modules(self) -> Dict[str, Any]:
        return self.tracing_context.module_context.nn_modules

    def save_global_state(self, out=None):
        """
        Saves to out if it is provided. Else saves to the tracing context's global_state.
        """
        global_state = cast(
            Dict[str, Tuple[Callable[..., Any], bool]],
            out
            if out is not None
            else self.tracing_context.global_context.global_state,
        )
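        # Each entry maps a state name to a (setter_fn, saved_value) pair, so
        # the saved state can later be re-applied by calling
        # setter_fn(saved_value).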

        # TODO - Consider having a torch level API for torch_function_state. As
        # of now, we create a ref cycle by passing the
        # output.set_torch_function_state to
        # output.tracing_context.global_context.global_state. In the interim,
        # the problem can be solved by manually setting
        # output.tracing_context.global_context.global_state to None at cleanup.
        global_state["torch_function_enabled"] = (
            self.set_torch_function_state,
            self.torch_function_enabled,
        )
        global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

        global_state["autocast_enabled"] = (
            functools.partial(torch.set_autocast_enabled, "cuda"),
            torch.is_autocast_enabled("cuda"),
        )
        global_state["autocast_cpu_enabled"] = (
            functools.partial(torch.set_autocast_enabled, "cpu"),
            torch.is_autocast_enabled("cpu"),
        )
        global_state["autocast_gpu_dtype"] = (  # type:ignore[assignment]
            functools.partial(torch.set_autocast_dtype, "cuda"),
            torch.get_autocast_dtype("cuda"),
        )
        global_state["autocast_cpu_dtype"] = (  # type:ignore[assignment]
            functools.partial(torch.set_autocast_dtype, "cpu"),
            torch.get_autocast_dtype("cpu"),
        )
        global_state["autocast_cache_enabled"] = (
            torch.set_autocast_cache_enabled,
            torch.is_autocast_cache_enabled(),
        )

    def push_tx(self, tx):
        self._current_tx.append(tx)

    def pop_tx(self):
        return self._current_tx.pop()

    @property
    def current_tx(self):
        return self.root_tx if not self._current_tx else self._current_tx[-1]

    def count_calls(self):
        return count_calls(self.graph)

    def is_empty_graph(self):
        return len(list(self.graph.nodes)) == 0

    def get_submodule(self, keys):
        assert keys
        obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules
        for k in keys.split("."):
            if isinstance(obj, dict):
                obj = obj[k]
            else:
                obj = getattr(obj, k)
        return obj

    def new_var(self, name="tmp"):
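        # e.g. new_var("graph_out") might return "graph_out_0"; the numeric
        # suffix comes from self.unique_var_id, so exact values depend on how
        # many variables were created earlier.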
        existing = set(self.code_options["co_varnames"])
        # In the common case, this will be O(1)
        while True:
            var = f"{name}_{next(self.unique_var_id)}"
            if var not in existing:
                self.code_options["co_varnames"] += (var,)
                return var

    def update_co_names(self, name):
        """Ensure self.code_options.co_names contains name"""
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)

    @staticmethod
    def module_key_name(*names):
        # create a new unique name
        name = "_".join(map(str, names))
        # Strip the guard lookup L/G access
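        # e.g. replace L['self'] with self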
        name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name)
        # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
        name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
        # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
        name = re.sub(r"[^a-zA-Z0-9]", "_", name)

        if not name or not name[0].isalpha():
            name = "sub" + name

        return name

    def register_attr_or_module(
        self,
        target: Union[torch.nn.Module, torch.Tensor, Any],
        *names,
        **options,
    ):
        if is_dynamic_nn_module(target, self.root_tx.export):
            # Instead of returning UnspecializedNNModuleVariable, call
            # VariableTracker.build so that it is tracked for mutation.
            return VariableTracker.build(self.current_tx, target, **options)

        options = dict(options)
        assert "source" in options
        source = options["source"]
        assert not isinstance(source, ParamBufferSource)

        if isinstance(target, torch.Tensor):
            tracer = self.current_tracer
            if not self.is_root_tracer():
                # For higher order ops, we don't want to insert the get_attr in
                # the innermost graph. Instead, we want to raise the params/buffers
                # as inputs to the higher-order graph, and register them as
                # get_attrs in the root tracer.

                # Note that Dynamo will still call lift_tracked_freevar_to_input
                # when these inputs are encountered for the inner graph. The
                # only difference is what happens at the root tracer for
                # nn.Parameters vs free inputs. The free inputs are registered
                # as placeholders in the root graph, whereas the nn.Parameters
                # are registered as get_attr nodes in the root graph.
                tracer = self.root_tracer

            def wrap_name(module_key):
                assert self.param_name_to_source is not None
                self.param_name_to_source[module_key] = source

                # Check if the attr has already been registered. This can happen
                # when two different sources point to the same tensor.
                if target in self.root_tx.output.side_effects:
                    return self.root_tx.output.side_effects[target]

                if get_static_address_type(target) == "guarded":
                    install_guard(source.make_guard(GuardBuilder.ID_MATCH))
                elif not is_constant_source(source):
                    install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))

                vt = wrap_fx_proxy(
                    self.root_tx,
                    tracer.create_proxy("get_attr", module_key, (), {}),
                    example_value=target,
                    **options,
                )

                # Track the object to avoid duplicate registration in case of
                # different sources pointing to the same tensor object.
                vt = self.root_tx.output.side_effects.track_object_existing(target, vt)

                assert "tensor_dict" not in vt.proxy.node.meta
                vt.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(target)

                return vt

        elif isinstance(target, torch.nn.Module):
            assert isinstance(target, torch.nn.Module)

            if source:
                install_guard(source.make_guard(GuardBuilder.NN_MODULE))

                def wrap_name(module_key):
                    return NNModuleVariable(type(target), module_key, target, **options)

            else:
                # This is a Dynamo-created graph module, e.g., a graph module
                # coming from higher order ops. NNModuleVariable tracker can't be
                # sourceless, so let's return an UnspecializedNNModuleVariable
                # tracker.
                def wrap_name(module_key):
                    return variables.UnspecializedNNModuleVariable(target, **options)

        elif isinstance(target, (torch.SymInt, torch.SymFloat)):
            # HACKY CODE REGION BEGIN
            # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
            # This ultimately gets written to self.nn_modules, which is unfortunate
            # Attrs that are tensors and symints and such need to be migrated to have their
            # own storage
            # alas, this is like this for now

            def wrap_name(module_key):
                return SymNodeVariable.create(
                    self,
                    self.create_proxy("get_attr", module_key, (), {}),
                    sym_num=target,
                    **options,
                )

            # HACKY CODE REGION END
        else:

            def wrap_name(module_key):
                self.output.update_co_names(module_key)
                self.global_scope[module_key] = target
                return VariableTracker.build(
                    self, target, ConstantSource(source_name=module_key)
                )

        for k, v in self.nn_modules.items():
            if v is target:
                # it already exists
                return wrap_name(k)

        name = OutputGraph.module_key_name(*names)

        base = name
        for i in itertools.count():
            if name not in self.nn_modules and name not in self.global_scope:
                self.nn_modules[name] = target
                if isinstance(target, torch.nn.Module):

                    def register_leaf_name(leaf_name):
                        assert self.param_name_to_source is not None
                        new_source = ParamBufferSource(source, leaf_name)
                        new_name = f"{name}.{leaf_name}"
                        self.param_name_to_source[new_name] = new_source
                        if isinstance(source, LocalSource):
                            self.dynamo_flat_name_to_original_fqn[
                                OutputGraph.module_key_name(new_source.name())
                            ] = leaf_name

                    # annoying, but there are cases when we do not have parameters
                    # see test_nn_moduledict_contains
                    if hasattr(target, "_parameters"):
                        for leaf_name, _ in target.named_parameters():
                            register_leaf_name(leaf_name)
                    if hasattr(target, "_buffers"):
                        for leaf_name, _ in target.named_buffers():
                            register_leaf_name(leaf_name)

                return wrap_name(name)
            name = f"{base}_{i}"

        raise AssertionError("unreachable")

    def handle_aliases_for_stolen_lists(self, tx):
        # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive
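        # e.g. if the compiled graph steals `inputs` but `inputs[3]` is still
        # live afterwards, we emit the equivalent of `inputs_ref_N = inputs[3]`
        # (names illustrative) before the graph call and point the affected
        # VariableTrackers' sources at the alias.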
        maybe_gm = self.local_scope.get("self")
        stolen_list_names = get_locals_to_steal(maybe_gm)
        if not stolen_list_names:
            return [], {}

        alias_insts = []
        needs_alias: Dict[str, List[VariableTracker]] = {}

        queue = [
            *tx.stack,
            *tx.symbolic_locals.values(),
            *self.side_effects.store_attr_mutations.keys(),
        ]

        while queue:
            x = queue.pop()
            if isinstance(x, BaseListVariable):
                assert isinstance(x.items, List)
                queue += x.items
                continue

            if not (
                (
                    x not in self.side_effects.store_attr_mutations
                    or isinstance(x.mutation_type, AttributeMutationExisting)
                )
                and isinstance(x.source, GetItemSource)
                and isinstance(x.source.base, LocalSource)
                and x.source.base.local_name in stolen_list_names
            ):
                continue

            stolen_name = x.source.base.local_name
            if stolen_name not in needs_alias:
                needs_alias[stolen_name] = []
            needs_alias[stolen_name].append(x)

        visited = {}
        overridden_sources: Dict[Source, Source] = {}
        for arg in self.graphargs:
            if not (
                isinstance(arg._example, list)
                and isinstance(arg.source, LocalSource)
                and arg.source.local_name in needs_alias
            ):
                continue

            # arg is a list that will be cleared by the compiled function
            list_name = arg.source.local_name
            assert list_name in self.code_options["co_varnames"]
            for x in needs_alias[list_name]:
                # Skip if already handled.
                if x.source in overridden_sources:
                    continue

                # A small codegen optimization because we might have different
                # VariableTrackers that share the same source.
                list_idx = x.source.index
                if list_idx not in visited:
                    alias_name = self.new_var(
                        f"{list_name}_ref"
                    )  # self.new_var already adds unique id suffix

                    visited[list_idx] = alias_name
                    # bytecode of `alias_name = list_name[list_idx]`
                    alias_insts.extend(
                        [
                            create_instruction("LOAD_FAST", argval=list_name),
                            create_load_const(list_idx),
                            create_instruction("BINARY_SUBSCR"),
                            create_instruction("STORE_FAST", argval=alias_name),
                        ]
                    )

                # operate on alias, handled by suffix codegen
                old_source = x.source
                overridden_sources[old_source] = LocalSource(visited[list_idx])

        # NOTE: we need `overridden_sources` because (1) we want to codegen for
        # these list items to use the new local source, but (2) we want to avoid
        # updating `source` in place because that might break invariants in
        # other parts of Dynamo like guards.
        return alias_insts, overridden_sources

    def compile_subgraph(
        self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
    ):
        """
        Generate a subgraph to continue execution on user code.
        Automatically restore live variables.
        """
        assert reason is not None

        from .decorators import disable

        self.partial_convert = partial_convert
        self.compile_subgraph_reason = reason
        self.should_exit = True

        log.debug("COMPILING GRAPH due to %s", reason)

        if not all(block.can_restore() for block in tx.block_stack):
            unimplemented("compile_subgraph with block_depth != 0")

        prefix_insts: List[Instruction] = []
        if sys.version_info >= (3, 11):
            # prefix instructions (Python 3.11+)
            for inst in tx.prefix_insts:
                if inst.opname == "MAKE_CELL":
                    prefix_insts.append(
                        create_instruction("MAKE_CELL", argval=inst.argval)
                    )
                elif inst.opname == "COPY_FREE_VARS":
                    prefix_insts.append(
                        create_instruction(
                            "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"])
                        )
                    )
                else:
                    prefix_insts.append(copy.copy(inst))
        assert not (
            self.pregraph_bytecode and self.export
        ), "export does not support pregraph_bytecode"
        prefix_insts.extend(self.pregraph_bytecode)
        alias_insts, overridden_sources = self.handle_aliases_for_stolen_lists(tx)
        prefix_insts.extend(alias_insts)

        def append_prefix_insts():
            self.add_output_instructions(prefix_insts)
            prefix_insts.clear()

        for block in reversed(tx.block_stack):
            block.exit(tx, is_graph_break=reason.graph_break)

        self.cleanup_graph()
        tx.prune_dead_locals()
        stack_values = list(tx.stack)

        # realize any unrealized tensor VTs in case they
        # need to be added to self.nn_modules as attributes
        for value in stack_values:
            value.realize()

        output_replacements = self.dedup_pass()

        # Use nn.Module "proxies" in the constructed GraphModule so that
        # the resulting GM does not hold additional strong references to the original modules.
        # This prevents a strong ref cycle where Dynamo created code holds on to references
        # to modules that also have Dynamo code cache invalidation checks.
        # When cache invalidation runs, the generated GM will be invalidated, which also deletes
        # the proxies.
        nn_modules_proxies = {
            name: nn_module_proxy(mod) for name, mod in self.nn_modules.items()
        }
        root = FakeRootModule(nn_modules_proxies)
        # Add all the local vars to the "stack" so we can restore them at the end
        restore_vars: List[str] = []
        val_to_names: Dict[VariableTracker, List[str]] = {}
        # NB: Typically (i.e., for graph compile from RETURN_VALUE),
        # symbolic_locals will be empty at this point, as prune_dead_locals
        # will clear out all of symbolic_locals because RETURN_VALUE is the
        # last instruction and no more locals are used.  The fanciness here
        # is only needed for partial graphs.
        for k, v in tx.symbolic_locals.items():
            # Note! this explicitly uses .local_name for matching
            # Failure to do so will cause spurious registrations in val_to_names.
            # This will in turn result in spurious variables showing up in the graph.
            # This was very tricky to debug. For an example, dump the graph at call_user_compiler
            # while running test_subgraphs.py
            if isinstance(v.source, LocalSource) and v.source.local_name == k:
                continue  # no need to restore initial state
            if isinstance(v, CellVariable) and v.local_name == k:
                continue  # no need to restore initial state
            # Do not load variable if it is NULL.
            if sys.version_info >= (3, 12):
                # Continuation function will load the NULL for v.
                if type.__instancecheck__(NullVariable, v):
                    continue
            else:
                # A variable should never be NULL in < 3.12
                assert not type.__instancecheck__(NullVariable, v)
            if v not in val_to_names:
                val_to_names[v] = []
            val_to_names[v].append(k)
        for v in val_to_names.keys():
            restore_vars.extend(val_to_names[v])
            stack_values.extend([v] * len(val_to_names[v]))

        # to handle random calls
        if len(self.random_calls) > 0:
            append_prefix_insts()
            random_calls_instructions = []
            self.random_values_var = self.new_var("random_values")
            rand_fn = disable(_get_gen_rand_values_fn(self.random_calls))
            rand_fn_name = self.install_global("__gen_rand_values", rand_fn)
            codegen = PyCodegen(tx, root, overridden_sources=overridden_sources)
            random_calls_instructions.extend(
                codegen.load_function_name(rand_fn_name, True)
            )
            random_calls_instructions.extend(create_call_function(0, False))
            random_calls_instructions.append(
                codegen.create_store(tx.output.random_values_var),
            )
            self.add_output_instructions(random_calls_instructions)

        if (
            stack_values
            and all(
                not isinstance(
                    v,
                    (
                        UnspecializedPythonVariable,
                        NumpyNdarrayVariable,
                        TensorWithTFOverrideVariable,
                    ),
                )
                and not (isinstance(v, SymNodeVariable) and v.python_type() is float)
                for v in stack_values
            )
            and all(isinstance(x, TensorVariable) for x in stack_values)
            and len(set(stack_values)) == len(stack_values)
            and self.side_effects.is_empty()
            and not tx.debug_locals
            and not self.backward_state
        ):
            append_prefix_insts()
            # optimization to generate better code in a common case
            self.add_output_instructions(
                self.compile_and_call_fx_graph(
                    tx, list(reversed(stack_values)), root, output_replacements
                )
                + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
            )
            # restore all the live local vars
            self.add_output_instructions(
                [
                    PyCodegen(tx, overridden_sources=overridden_sources).create_store(
                        var
                    )
                    for var in reversed(restore_vars)
                ]
            )
        else:
            graph_output_var = self.new_var("graph_out")
            pass1 = PyCodegen(
                tx, root, graph_output_var, overridden_sources=overridden_sources
            )
            self.codegen_suffix(tx, stack_values, pass1)

            # one more time now that we have established tempvars
            pass2 = PyCodegen(
                tx,
                root,
                graph_output_var,
                tempvars={val: None for val, count in pass1.uses.items() if count > 1},
                overridden_sources=overridden_sources,
            )
            self.codegen_suffix(tx, stack_values, pass2)

            stored_graph_output_var = False
            output = []
            if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
                output.extend(
                    self.compile_and_call_fx_graph(
                        tx, pass2.graph_output_vars(), root, output_replacements
                    )
                )

                if len(pass2.graph_outputs) != 0:
                    output.append(pass2.create_store(graph_output_var))
                    stored_graph_output_var = True
                else:
                    output.append(create_instruction("POP_TOP"))
            else:
                # NB: Important to run compiler collective even when there is
                # a graph break
                self.run_compiler_collective(tx)
            append_prefix_insts()
            self.add_output_instructions(output + pass2.get_instructions())

            # restore all the live local vars
            self.add_output_instructions(
                [
                    PyCodegen(tx, overridden_sources=overridden_sources).create_store(
                        var
                    )
                    for var in reversed(restore_vars)
                ]
            )

            if stored_graph_output_var:
                self.add_output_instructions(
                    [
                        PyCodegen(
                            tx, overridden_sources=overridden_sources
                        ).create_delete(graph_output_var)
                    ]
                )

    def codegen_suffix(self, tx, stack_values, cg):
        # NOTE: `codegen_save_tempvars` must run first to update `source` fields
        # for variables with `AttributeMutationNew`, as they don't implement
        # `reconstruct` themselves.
        self.side_effects.codegen_save_tempvars(cg)
        if self.backward_state:
            assert not self.export
            for name, val in self.backward_state.items():
                cg(val)
                cg.append_output(cg.create_load(self.backward_state_var))
                cg.store_attr(name)
        self.side_effects.codegen_hooks(cg)

        # Return variables used for logging at the end
        for debug_var, args in tx.debug_locals:
            cg.add_push_null(lambda: cg(debug_var))
            for arg in args:
                cg(arg)
            cg.extend_output(create_call_function(len(args), False))
            cg.extend_output([create_instruction("POP_TOP")])

        cg.restore_stack(stack_values, value_from_source=not tx.export)
        self.side_effects.codegen_update_mutated(cg)

    def cleanup_graph(self):
        """
        Remove "creation_timestamp" from node meta

        Remove this pattern from the graph:
            torch._C._set_grad_enabled(False)
            torch._C._set_grad_enabled(True)
        """
        assert self.should_exit
        nodes = list(self.graph.nodes)
        for node in nodes:
            node.meta.pop("creation_timestamp", None)

        grad_enabled = torch.is_grad_enabled()
        for node1, node2 in zip(nodes, nodes[1:]):
            if (
                node1.target is torch._C._set_grad_enabled
                and tuple(node1.args) == (not grad_enabled,)
                and not node1._erased
            ):
                grad_enabled = node1.args[0]
                if (
                    node2.target is torch._C._set_grad_enabled
                    and tuple(node2.args) == (not grad_enabled,)
                    and not node2._erased
                ):
                    grad_enabled = node2.args[0]
                    self.graph.erase_node(node1)
                    self.graph.erase_node(node2)

    def get_graph_sizes_structured(self):
        ret = {}
        for node in self.graph.nodes:
            example_value = node.meta.get("example_value", None)
            if isinstance(example_value, torch._subclasses.FakeTensor):
                size = example_value.size()
                ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size]
        return ret

    def get_graph_sizes(self, name: str):
        graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n"
        graph_sizes_str += f"===== {name} =====\n"
        for node in self.graph.nodes:
            example_value = node.meta.get("example_value", None)
            if isinstance(example_value, torch._subclasses.FakeTensor):
                size = example_value.size()
                graph_sizes_str += f"{node.name}: {tuple(size)}\n"
                concrete_size = []
                has_symint = False
                for sz in size:
                    if isinstance(sz, int):
                        concrete_size.append(sz)
                    elif isinstance(sz, torch.SymInt):
                        has_symint = True
                        concrete_size.append(sz.node.hint)
                    else:
                        break
                else:
                    if has_symint:
                        graph_sizes_str += (
                            f"{node.name} (concrete): {tuple(concrete_size)}\n"
                        )
        return graph_sizes_str

    @contextlib.contextmanager
    def restore_global_state(self):
        """
        Momentarily restores the global state to what it was prior to tracing the current output
        """
        prior_global_state = self.tracing_context.global_context.copy_graphstate()
        current_global_state: Dict[str, Tuple[Any, bool]] = {}
        self.save_global_state(out=current_global_state)
        try:
            # Set to state prior to tracing the graph
            self.tracing_context.global_context.restore_graphstate(prior_global_state)
            yield
        finally:
            # Reset to state at the current time (e.g. before calling the user compiler)
            self.tracing_context.global_context.restore_graphstate(
                GlobalContextCheckpointState(current_global_state)
            )

    def run_compiler_collective(self, tx):
        if (ds := tx.distributed_state) is not None and ds.all_states is None:
            compile_pg = ds.compile_pg
            log.info("compiler_collective %s", ds.local_state)
            torch._logging.trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "compiler_collective",
                    "encoding": "string",
                },
                payload_fn=lambda: ds.local_state.render(),
            )
            with torch.cuda.device(compile_pg.rank() % torch.cuda.device_count()):
                all_states = [None] * compile_pg.size()
                dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
                ds.all_states = all_states
            # Clear the speculation log, because our tracing may diverge due to
            # the information from the compiler collective
            tx.speculation_log.clear()
            raise exc.CompileCollectiveRestartAnalysis

    def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs):
        """
        Generate code from self.graph and return the Instruction()s to
        call that generated code.
        """
        with torch._guards.TracingContext.clear_frame():
            from .decorators import disable

            assert self.should_exit

            self.run_compiler_collective(tx)

            name = unique_id("__compiled_fn")

            assert isinstance(rv, list)
            assert isinstance(root, FakeRootModule)

            output_node = self.create_node(
                "output",
                "output",
                (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
                {},
            )

            for old_node, new_node in replaced_outputs.items():
                old_node.replace_all_uses_with(new_node)

            tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node)
            if not config.do_not_emit_runtime_asserts:
                insert_deferred_runtime_asserts(
                    fx.GraphModule(root, self.graph),
                    self.shape_env,
                    name,
                    export=self.export,
                )
            # NB: deferred runtime asserts can keep graphargs live, so make sure
            # those are inserted before pruning
            self.remove_unused_graphargs()
            ncalls = count_calls(self.graph)
            counters["stats"]["calls_captured"] += ncalls

            # free a bit of memory
            self.real_value_cache.clear()

            gm = _make_graph_module(root, self.graph)
            for register_finalizer in self.register_finalizer_fns:
                register_finalizer(gm)

            gm.compile_subgraph_reason = self.compile_subgraph_reason
            gm.meta[
                "dynamo_flat_name_to_original_fqn"
            ] = self.dynamo_flat_name_to_original_fqn.copy()
            gm.meta["dynamo_compile_id"] = self.dynamo_compile_id

            graph_code_log.debug(
                "%s",
                lazy_format_graph_code(
                    name, gm, include_stride=True, include_device=True, colored=True
                ),
            )
            torch._logging.trace_structured(
                "dynamo_output_graph",
                lambda: {"sizes": self.get_graph_sizes_structured()},
                payload_fn=lambda: gm.print_readable(
                    print_output=False, include_stride=True, include_device=True
                ),
            )
            self.call_cleanup_hooks()
            old_fake_mode = self.tracing_context.fake_mode
            if not self.export:
                import torch._functorch.config as _config

                with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
                    # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
                    backend_fake_mode = torch._subclasses.FakeTensorMode(
                        shape_env=old_fake_mode.shape_env,
                    )
                # TODO(voz): Ostensibly, this should be scoped and
                # restore back to old_fake_mode, but doing so currently violates
                # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
                self.tracing_context.fake_mode = backend_fake_mode

            with self.restore_global_state():
                compiled_fn = self.call_user_compiler(gm)

            from torch.fx._lazy_graph_module import _LazyGraphModule

            if isinstance(compiled_fn, _LazyGraphModule) or (
                isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
                and compiled_fn.__name__ == "_lazy_forward"  # type: ignore[attr-defined]
            ):
                # Since dynamo will run the forward method for the GraphModule shortly
                # anyway, it does not hurt to do the real recompilation here if
                # this is a _LazyGraphModule. This makes it easier for dynamo to
                # optimize a _LazyGraphModule.

                lazy_gm = (
                    compiled_fn
                    if isinstance(compiled_fn, _LazyGraphModule)
                    else compiled_fn.__self__  # type: ignore[attr-defined]
                )

                _LazyGraphModule.force_recompile(lazy_gm)

                if not isinstance(compiled_fn, _LazyGraphModule):
                    # replace compiled_fn with the real forward method
                    compiled_fn = lazy_gm.forward

            compiled_fn = disable(compiled_fn)

            counters["stats"]["unique_graphs"] += 1
            # This is safe because we pre-process name to be unique
            self.install_global_unsafe(name, compiled_fn)

            cg = PyCodegen(tx)
            cg.make_call_generated_code(name)
            return cg.get_instructions()

    @property
    def placeholders(self) -> List[fx.Node]:
        return self.graph.find_nodes(op="placeholder")

    @property
    def graphargs(self) -> List[GraphArg]:
        return [node.meta["grapharg"] for node in self.placeholders]

    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
        with dynamo_timed(
            "OutputGraph.call_user_compiler",
            phase_name="backend_compile",
            log_pt2_compile_event=True,
            dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
        ):
            return self._call_user_compiler(gm)

    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
        assert self.compiler_fn is not None
        tot = 0
        placeholders = []
        for node in gm.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                tot += 1
            if node.op == "placeholder":
                placeholders.append(node)
        increment_op_count(tot)
        for pl in placeholders:
            arg = pl.meta["grapharg"]
            # TODO: Why isn't this stored in meta :think:
            pl._dynamo_source = arg.source

        gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
        gm._source_to_user_stacks = self.source_to_user_stacks  # type: ignore[assignment]

        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiler_fn = self.compiler_fn
            if config.verify_correctness:
                compiler_fn = WrapperBackend(compiler_fn)
            compiled_fn = compiler_fn(gm, self.example_inputs())
            _step_logger()(logging.INFO, f"done compiler function {name}")
            assert callable(compiled_fn), "compiler_fn did not return callable"
        except TensorifyScalarRestartAnalysis:
            raise
        except exceptions_allowed_to_be_fallback as e:
            if self.has_user_defined_allowed_in_graph:
                raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
                    e.__traceback__
                ) from None
            msg = (
                "Backend compiler failed with a fake tensor exception at \n"
                f"{self.root_tx.format_frame_summary()}"
                "Adding a graph break."
            )
            unimplemented_with_warning(e, self.root_tx.f_code, msg)
        except SkipFrame as e:
            # The backend compiler has requested that we skip the frame, instead of
            # aborting execution.
            raise e
        except Exception as e:
            raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
                e.__traceback__
            ) from None

        signpost_event(
            "dynamo",
            "OutputGraph.call_user_compiler",
            {
                **self.co_fields,
                "op_count": tot,
                "node_count": len(gm.graph.nodes),
                "input_count": len(placeholders),
            },
        )

        return compiled_fn

    def dedup_pass(self):
        if torch._dynamo.config.use_graph_deduplication:
            return apply_graph_deduplication(self)
        else:
            return dict()

    def install_subgraph(self, name, sub_gm):
        next_name = None
        i = 0
        while not next_name:
            candidate = f"{name}_{i}"
            if candidate in self.nn_modules:
                i += 1
            else:
                next_name = candidate

        sub_gm.__name__ = next_name
        sub_gm.torchdynamo_force_dynamic = False
        # This graph module is not present in the user space, so it can't be
        # accessed by a source. Set source=None.
        self.register_attr_or_module(sub_gm, next_name, source=None)
        return next_name

    def example_inputs(self) -> List[torch.Tensor]:
        result = [arg.example for arg in self.graphargs]
        return result

    def remove_unused_graphargs(self) -> None:
        # NB: It's always OK to drop GraphArg for symbols that ended up being
        # specialized.  You don't even have to make a guard for it, because
        # ShapeEnv produce_guards operates on tracked_fakes, which never gets
        # pruned.  That being said, you'll get marginally better generated
        # guard code if you promote the guard into a Dynamo guard (since that
        # allows for the guard to be done using C++ guards.)  If we get
        # ShapeEnv guards to go into C++ guards, this will stop being a thing
        # though!
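        #
        # Illustrative sketch (hypothetical scenario): if an input tensor was traced
        # with a dynamic dim s0 that later got specialized to a constant, the
        # placeholder binding s0 typically ends up with no users and is pruned below;
        # the specialization guard still comes from ShapeEnv's produce_guards.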

        assert self.should_exit

        # Miniature DCE pass, but only for obviously trivial operations
        def is_static_true(b_node: fx.node.Argument):
            if b_node is True:
                return True
            if not isinstance(b_node, fx.Node):
                return False
            b = b_node.meta.get("example_value")
            if b is None:
                return False
            if b is True:
                return True
            if (
                isinstance(b, torch.SymBool)
                and (r := b.node.maybe_as_bool()) is not None
            ):
                return r
            # TODO: We can also technically remove all cases when the input
            # doesn't have unbacked inputs, since it's all in the ShapeEnv
            return False

        def is_symnode_arg(a: fx.node.Argument):
            from torch.fx.experimental.sym_node import SymTypes

            if isinstance(a, (int, float, bool)):
                return True
            if isinstance(a, fx.Node):
                return isinstance(a.meta.get("example_value"), SymTypes)
            return False

        # NB: We assume that you cannot mutate int/float/bool values,
        # because they are immutable types, so it is always safe to
        # DCE them.
        def is_symnode_compute_node(node):
            from torch.fx.experimental.sym_node import SymTypes

            if node.op != "call_function":
                return False
            # TODO: I don't think it's possible to have a bare int/float here?
            if not isinstance(node.meta.get("example_value"), SymTypes):
                return False
            # TODO: This will bail here if you ever end up with a more complicated
            # computation function, like sum(list_of_ints), even though it
            # should be DCE'able
            if not all(is_symnode_arg(a) for a in node.args):
                return False
            if not all(is_symnode_arg(a) for a in node.kwargs.values()):
                return False
            return True

        from torch.fx.experimental.symbolic_shapes import is_accessor_node

        for node in reversed(list(self.graph.nodes)):
            if len(list(node.users)) == 0:
                if (
                    node.op == "get_attr"
                    or (node.op == "call_function" and node.target is operator.getitem)
                    or (
                        node.op == "call_function"
                        and node.target is torch._check
                        and is_static_true(node.args[0])
                    )
                    or is_symnode_compute_node(node)
                    or is_accessor_node(node)
                ):
                    self.remove_node(node)

        def placeholder_binds_symbol(node):
            arg = node.meta["grapharg"]
            example = arg.example
            if isinstance(example, torch.SymInt) and isinstance(
                example.node.expr, sympy.Symbol
            ):
                return example.node.expr
            return None

        def remove_unused(node):
            log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name())
            # I'm not really sure why you need to delete these from the
            # node since the node is going to get removed
            del node.meta["grapharg"]
            self.remove_node(node)
            self.real_value_cache.pop(node, None)

        used_symbols: Set[sympy.Symbol] = set()

        def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
            used_symbols |= free_symbols(fake)

        recheck_placeholders = []
        for node in self.placeholders:
            binds_symbol = placeholder_binds_symbol(node) is not None
            # Don't delete symbol bindings yet
            if binds_symbol:
                if not node.users:
                    recheck_placeholders.append(node)
            else:
                if not node.users and not isinstance(
                    node.meta["grapharg"], BackwardStateGraphArg
                ):
                    remove_unused(node)
                else:
                    # Register the free symbols as uses
                    arg = node.meta["grapharg"]
                    if isinstance(arg, BackwardStateGraphArg):
                        continue
                    if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
                        real_script_obj = node.meta["grapharg"].example
                        fake_script_obj = node.meta["grapharg"].example_strong_ref
                        if not torch._library.fake_class_registry.tracing_with_real(
                            real_script_obj
                        ):
                            flat_dict = dict(real_script_obj.__obj_flatten__())  # type: ignore[attr-defined]
                            for attr in flat_dict.keys():
                                fake_attr_val = getattr(
                                    fake_script_obj.wrapped_obj, attr
                                )
                                pytree.tree_map_only(
                                    (torch.SymInt, torch.Tensor),
                                    lambda t: update_used_symbols(used_symbols, t),
                                    fake_attr_val,
                                )
                        continue
                    fake = (
                        arg.fake_tensor if arg.fake_tensor is not None else arg.example
                    )
                    update_used_symbols(used_symbols, fake)

        # After removing unused graphargs, prune unused binds_symbol
        for node in recheck_placeholders:
            symbol = placeholder_binds_symbol(node)
            if symbol is not None:
                if symbol not in used_symbols:
                    remove_unused(node)
                else:
                    # Make sure we delete later occurrences of the same symbol
                    used_symbols.remove(symbol)

    def add_output_instructions(self, prefix: List[Instruction]) -> None:
        """
        We call this on the creation of a new compiled subgraph that is inserted
        before user code.
        """
        self.output_instructions.extend(prefix)
        self.should_exit = True

    def install_global_unsafe(self, name, value) -> None:
        """
        WARNING: prefer the safer `install_global_by_id/install_global`.
        torch.compile instances should be independent of each other;
        one footgun is to have one instance depend on the existence of
        a global installed by another instance. This can happen if we mangle
        a global the same way across both instances.
        """
        assert name not in self.installed_globals
        self.installed_globals.add(name)
        self.cleanups.append(CleanupHook.create(self.global_scope, name, value))

    def install_global_by_id(self, prefix, value) -> str:
        """
        Installs a global if it hasn't been installed already.
        This is determined by (prefix, id(value)) pair.

        Returns the name of the newly installed global.
        """
        # NB: need self.compile_id to distinguish this global
        # from another global created in a different torch.compile instance
        name = f"{prefix}_{id(value)}_c{self.compile_id}"
        if name in self.installed_globals:
            return name
        self.install_global_unsafe(name, value)
        return name

    def install_global(self, prefix, value) -> str:
        """
        Installs a global, generating a unique name for it.

        Returns the name of the newly installed global.
        """
        # NB: unique_id is unique, even across torch.compile instances
        name = unique_id(prefix)
        self.install_global_unsafe(name, value)
        return name
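
    # Usage sketch for the two safe helpers above (names here are hypothetical):
    #
    #   name1 = self.install_global_by_id("__helper", some_fn)
    #   name2 = self.install_global_by_id("__helper", some_fn)
    #   assert name1 == name2   # deduplicated by (prefix, id(value)) within this compile
    #
    #   name3 = self.install_global("__helper", some_fn)
    #   name4 = self.install_global("__helper", some_fn)
    #   assert name3 != name4   # unique_id mints a fresh name every time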

    def cleanup(self) -> None:
        # There is a reference cycle between tracer and OutputGraph, causing
        # some of the tensor objects to be held alive for longer than necessary.
        self.root_tx = None
        self.nn_modules.clear()
        self.param_name_to_source = None

        for node in self.graph.nodes:
            if "grapharg" in node.meta:
                del node.meta["grapharg"]
        self.real_value_cache.clear()
        self.input_name_to_proxy.clear()
        self.side_effects.clear()
        self.variable_tracker_cache.clear()
        self.register_finalizer_fns.clear()
        self.dynamo_flat_name_to_original_fqn.clear()
        self.tracing_context.clear()
        self.input_source_to_var.clear()
        self.unspec_variable_map.clear()
        self.backward_state.clear()

    def set_torch_function_state(self, enabled: bool) -> None:
        self.torch_function_enabled = enabled

    def add_graph_finalizer(
        self, register_finalizer: Callable[[fx.GraphModule], None]
    ) -> None:
        self.register_finalizer_fns.append(register_finalizer)

    def example_value_from_input_node(self, node: torch.fx.Node):
        """Extract the non-fake example tensor"""
        if node.op == "placeholder":
            return node.meta["grapharg"].example
        assert node.op == "get_attr"
        return self.nn_modules[node.target]  # type: ignore[index]


err_epilogue = (
    "With the current config, we will graph break "
    "(and fall back to eager-mode PyTorch) on all ops "
    "that have do not have the 'pt2_compliant_tag'. "
    "Please see the following doc for how to mark this op as PT2 compliant "
    "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html"
)
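
# Minimal sketch of the tag check used below (the specific op here is arbitrary
# and only for illustration):
#
#   import torch
#   op = torch.ops.aten.sin.default  # an OpOverload
#   is_compliant = torch.Tag.pt2_compliant_tag in op.tags
#
# check_pt2_compliant_op applies this membership test to every call_function
# target, first resolving an OpOverloadPacket to a concrete overload if needed.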


def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
    if kind != "call_function":
        return

    def encountered_compliant_op(target):
        if target.namespace in {"prim", "prims", "aten"}:
            return
        output_graph.compliant_custom_ops.add(target)

    def encountered_non_compliant_op(target, msg):
        output_graph.non_compliant_ops.add(target)
        if config.only_allow_pt2_compliant_ops:
            unimplemented(msg + " " + err_epilogue)

    if isinstance(target, torch._ops.OpOverload):
        if torch.Tag.pt2_compliant_tag in target.tags:
            encountered_compliant_op(target)
            return
        encountered_non_compliant_op(
            target,
            f"Encountered the torch.ops.OpOverload {target} "
            f"that is not PT2 compliant.",
        )
        return

    if isinstance(target, torch._ops.OpOverloadPacket):
        overloads = tuple(target.overloads())
        # Optimization: Overload resolution is expensive.
        # If there's only one overload, we know what it will resolve to.
        if len(overloads) == 1:
            op = getattr(target, overloads[0])
            if torch.Tag.pt2_compliant_tag in op.tags:
                encountered_compliant_op(op)
                return
            encountered_non_compliant_op(
                op,
                f"Encountered the non-overloaded "
                f"torch.ops.OpOverloadPacket {target} "
                f"that is not PT2 compliant. ",
            )
            return

        args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes(
            output_graph.current_tx, (args, kwargs), False
        )
        try:
            overload = torch._C._jit_resolve_packet(
                target._qualified_op_name, *args, **kwargs
            )
        except RuntimeError as e:
            unimplemented(str(e))

        op = getattr(target, overload)
        if torch.Tag.pt2_compliant_tag in op.tags:
            encountered_compliant_op(op)
        else:
            encountered_non_compliant_op(
                op,
                f"Encountered the torch.ops.OpOverloadPacket {target} "
                f"which resolves to the overload ({overload}) that is "
                f"not PT2 compliant.",
            )


_compile_id_counter = itertools.count()


class LazyProxy:
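    """
    A deferred proxy: holds a tracer plus a proxy-creating callable and its
    arguments, and only materializes the real fx.Proxy when called. Used by
    SubgraphTracer.track_unbacked_symbols so that proxies for unbacked shape
    symbols are created only if the symbol is actually looked up later (see
    lookup_unbound_symbols).
    """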
    def __init__(self, tracer, fn, *args, **kwargs):
        self.tracer = tracer
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def __call__(self):
        return self.fn(*self.args, **self.kwargs)


class SubgraphTracer(fx.Tracer):
    """
    Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer
    and the separation of responsibilities is that SubgraphTracer is
    responsible for building the graph while OutputGraph is responsible for
    compiling and executing the graph.
    """

    def __init__(self, output_graph, parent=None, is_export=False, source_target=None):
        super().__init__()
        self.output_graph = weakref.proxy(output_graph)
        self.graph = torch.fx.Graph()

        # See note [Export inputs must be explicitly passed in]
        self.is_export = is_export
        # Map from graph input name to its placeholder proxy object, where the
        # map's keys give all current placeholder node names and can be used to
        # create unique node names
        self.input_name_to_proxy: Dict[str, fx.Proxy] = {}
        # Node => computed real value (see utils.get_real_value)
        self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}

        # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
        self.parent = parent
        # A dict mapping previously free variables (Proxy objects)
        # to new Proxy objects that wrap inputs to this subgraph.
        #
        # This dict maps proxies in outer graphs to placeholders in current graph.
        # It serves two purposes:
        # - Proxies are associated with VariableTrackers. If we see
        # the same VariableTracker twice (and it is a free variable),
        # then we want to use the same Proxy in the current subgraph to
        # record the tracing.
        # - If we are tracing a HigherOrderOperator's body_fn, then we
        # need to keep track of what free variables were lifted so we can
        # rewrite the HigherOrderOperator call using the traced body_fn.
        # Dicts maintain the order of args for the HigherOrderOperator call.
        self.lifted_freevars = {}

        # Map basic symbols (backed and unbacked) to their bound proxies.
        # There are only two cases where bound_symbols will be recorded:
        # 1. when we create_graph_input for a backed SymInt that is a basic symbol
        # 2. when we track_unbacked_symbols for intermediate results that contain unbacked symints.
        self.bound_symbols: Dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {}

        self.prev_inst = None
        # True if this tracer is currently tracing into torch.utils.checkpoint
        # as part of speculate_subgraph.
        self.under_activation_checkpoint = False
        # True if we want to allow side-effects (doesn't throw error on their existence)
        # during this tracer's tracing of torch.utils.checkpoint (via speculate_subgraph).
        # Only safe if we know for sure that *NOT* replaying these side-effects during
        # backward recomputation of the checkpoint region doesn't affect its correctness.
        self.allow_side_effects_under_checkpoint = False

        self.debug_level: int = parent.debug_level + 1 if parent is not None else 0

        self._cur_code = None
        self._orig_gm_meta = None
        self._orig_gm_lineno_map = None
        self._orig_gm_firstlineno = None
        # Each SubgraphTracer is associated with a source target, which indicates
        # which operator this subgraph is attached to. We compute a source_fn_stack
        # based on the source target. For the root tracer, it's set to [].
        # This is useful for debugging and transforming the exported graph.
        if self.parent is None:
            self.source_fn_stack = []
        else:
            self.source_fn_stack = self.parent.source_fn_stack + [
                (self.graph._target_to_str(source_target), source_target)
            ]

    # preserve original meta if it is available
    def _maybe_preserve_original_meta(self, tx, node):
        if (
            self._orig_gm_meta
            and self._orig_gm_lineno_map
            and self._orig_gm_firstlineno
        ):
            lineno = tx.current_instruction.starts_line
            node_idx = None
            if lineno is not None:
                node_idx = self._orig_gm_lineno_map.get(
                    lineno - self._orig_gm_firstlineno, None
                )
            if node_idx is not None:
                meta = self._orig_gm_meta[node_idx]
                for field in fx.proxy._COPY_META_FIELDS:
                    if field in meta:
                        node.meta[field] = meta[field]
                if "stack_trace" in meta:
                    node.meta["stack_trace"] = meta["stack_trace"]

    def create_proxy(
        self,
        kind,
        target,
        args,
        kwargs,
        name=None,
        type_expr=None,
        proxy_factory_fn=None,
    ):
        # NOTE: [Nested SubgraphTracer and free_variable handling]
        # --------------------------------------------------------
        # Read NOTE [HigherOrderOperator tracing design] first.
        #
        # Let's say we're in the middle of introspecting the body of a possibly
        # nested HigherOrderOperator, and we see a free variable.
        #
        # There are two cases:
        # 1. We see a free variable that is already tracked by Dynamo.
        # 2. We see a free variable that has not been tracked by Dynamo
        #
        # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below)
        # which will lift the freevar to be an input of this subgraph
        # and also recursively lift it to be an input on the parent(s).
        #
        # In case 2, before the call to `create_proxy`, the InstructionTranslator
        # will see the freevar when it gets loaded by Python bytecode.
        # E.g. for Python 3.11 the bytecodes that may do this are LOAD_DEREF or
        # LOAD_GLOBAL.
        # There, the InstructionTranslator asks Dynamo to begin tracking the
        # freevar by building a new Variable.
        # Building a new Variable automatically lifts the freevar to be an
        # input of the root SubgraphTracer.
        #
        # The implications for the code below are:
        # - We will always be in Case 1 when we get to this code.
        # - Any "free variable" we encounter here is guaranteed to already be
        #   bound, that is, it is either a graph input of the root graph, or
        #   some local variable of the root graph or a subgraph.
        # - The additional work we need to do here is *only* that we need to
        #   lift this free variable into inputs (recursively) of each nested
        #   higher-order-op subgraph until we hit the subgraph where the free
        #   variable is bound
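        #
        # Rough example of Case 1 (hypothetical code): while tracing the body of a
        # HigherOrderOperator (e.g. torch.cond's true_fn), the body reads a tensor
        # `y` that already has a proxy in an outer graph. The loop below lifts `y`
        # into a placeholder of this subgraph (and, recursively, of every
        # intermediate subgraph) so the body only consumes its own inputs.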
        if self.parent is not None:
            flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
            new_flat_args = []
            for arg in flat_args:
                maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg)
                new_flat_args.append(maybe_new_arg)

            args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)

        rv = super().create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )

        # append stack trace to fx node
        tx = self.output_graph.current_tx

        # log detailed location of line of code in 3.11
        if sys.version_info >= (3, 11) and kind in (
            "call_function",
            "call_method",
            "call_module",
        ):
            cur_inst = tx.current_instruction
            if (
                cur_inst is not self.prev_inst
                and cur_inst.positions is not None
                and cur_inst.positions.lineno is not None
            ):
                tx_code = tx.f_code
                header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno)

                def get_trace_call_log_str():
                    line = get_instruction_source_311(tx_code, cur_inst).rstrip()
                    return f"TRACE FX call {rv.node.name} from {header}\n{line}"

                trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
                self.prev_inst = cur_inst

        # update reference to original meta if we're tracing a new code object
        is_retracing = False
        if tx.f_code is not self._cur_code:
            orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
                "orig_graphmodule", lambda: None
            )()
            if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
                is_retracing = True
                self._orig_gm_meta = [
                    nd.meta for nd in orig_graphmodule_maybe.graph.nodes
                ]
                self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map
                self._orig_gm_firstlineno = (
                    orig_graphmodule_maybe.forward.__code__.co_firstlineno
                )
            else:
                self._orig_gm_meta = None
                self._orig_gm_lineno_map = None
                self._orig_gm_firstlineno = None
        nn_module_stack = tx.nn_module_stack
        if nn_module_stack:
            rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

        if kind in {"call_function", "call_method"}:
            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                (rv.node.name, target)
            ]
        elif kind == "call_module":
            if self.parent is not None:
                unimplemented("Invoking an nn.Module inside HigherOrderOperator")
            # For modules we store the class
            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                (
                    rv.node.name,
                    next(
                        ty
                        for k, (_, ty) in rv.node.meta["nn_module_stack"].items()
                        if k.split("@")[0] == target
                    ),
                )
            ]

        self._maybe_preserve_original_meta(tx, rv.node)

        if not is_retracing:
            if "nn_module_stack" not in rv.node.meta:
                nn_module_stack = tx.nn_module_stack
                if nn_module_stack:
                    rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

            if "source_fn_stack" not in rv.node.meta:
                if kind in {"call_function", "call_method"}:
                    rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                        (rv.node.name, target)
                    ]
                elif kind == "call_module":
                    if self.parent is not None:
                        unimplemented(
                            "Invoking an nn.Module inside HigherOrderOperator"
                        )
                    # For modules we store the class
                    rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                        (
                            rv.node.name,
                            rv.node.meta["nn_module_stack"][target][1],
                        )
                    ]

        if "stack_trace" not in rv.node.meta:
            frame_summaries: List[traceback.FrameSummary] = []
            while tx:
                # Avoid frame summaries from inside the torch/nn/modules. This ensures that we keep the stack trace of
                # the user code.
                if not tx.is_co_filename_from_nn_modules():
                    frame_summaries.append(tx.frame_summary())
                tx = getattr(tx, "parent", None)
            # Reverse the frame_summaries, such that the innermost frame is last
            frame_summaries.reverse()

            # official from_list stub doesn't have new-style type
            msgs = traceback.StackSummary.from_list(frame_summaries).format()
            rv.node.stack_trace = "".join(msgs)

        if (
            torch._dynamo.config.use_graph_deduplication
            or torch._dynamo.config.track_nodes_for_deduplication
        ):
            self.output_graph.region_tracker.track_node(
                self.output_graph.current_tx, rv.node
            )
        return rv

    def create_node(
        self, op, target, args=None, kwargs=None, name=None, type_expr=None
    ):
        check_pt2_compliant_op(self.output_graph, op, target, args, kwargs)
        if self.parent is not None:
            flat_args = pytree.arg_tree_leaves(*args, **kwargs)
            for arg in flat_args:
                if not isinstance(arg, torch.fx.Node):
                    continue
                assert (
                    arg.graph == self.graph
                ), "create_node using arg not from this SubgraphTracer"

        node = super().create_node(op, target, args, kwargs, name, type_expr)
        node.meta["creation_timestamp"] = self.output_graph.timestamp
        return node

    # Note: we did not override erase_node since
    # we call self.graph.erase_node elsewhere
    def remove_node(self, node):
        if len(node.users) > 0:
            user_graph_nodes: List[torch.fx.Node] = []
            for user in node.users.keys():
                # For the case where user.graph == self.graph, that is a real bug and will raise
                # properly.
                if user.graph != self.graph:
                    # This is a nested graph, which needs to be deleted.
                    # If we do not do this, we will raise on attempting to remove this.
                    # As we only get here during restoration cleanup, this is sound.
                    user_graph_nodes.extend(reversed(list(user.graph.nodes)))
            for other_graph_node in user_graph_nodes:
                other_graph_node.graph.erase_node(other_graph_node)
        self.graph.erase_node(node)
        self.input_name_to_proxy.pop(node.name, None)

    # when before=True, we will insert this input before the most recent
    # inserted proxy.  This is a hack to get around an ordering problem,
    # where we first insert a tensor argument, and then insert bindings
    # for SymInts that may occur in the tensor argument.
    # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets
    # fixed.
    def create_graph_input(
        self, name, type_expr, example_value, before=False, source=None
    ):
        log.debug(
            "create_graph_input %s %s %s at debug_level %s before=%s",
            name,
            source.name() if source is not None else "(none)",
            example_value,
            self.debug_level,
            before,
        )
        if source is None:
            assert (
                self.parent is not None
            ), f"you are required to provide a source for inputs {name} example_val {example_value} on the root tracer"

        # Note [Export inputs must be explicitly passed in]
        # In eager, we are generally OK with adding graph inputs whenever we
        # want, because we take care of writing the bytecode that knows how
        # to source all the inputs.
        #
        # In export, this is bad, because you want a self-contained export
        # object which only depends on the inputs you explicitly passed to it.
        # So we are a bit more strict about what sources can become inputs
        # in export
        if self.is_export and self.parent is None:
            if not is_from_local_source(source, only_allow_input=True):
                self.output_graph.source_to_user_stacks.setdefault(source, []).append(
                    TracingContext.extract_stack()
                )

        # ensure the placeholder name is unique
        if name in self.input_name_to_proxy:
            for i in itertools.count():
                candidate_name = f"{name}_{i}"
                if candidate_name not in self.input_name_to_proxy:
                    name = candidate_name
                    break

        if self.input_name_to_proxy:
            prev_name = next(reversed(self.input_name_to_proxy))
            node = self.input_name_to_proxy[prev_name].node
            if before:
                ctx = self.graph.inserting_before(node)
            else:
                ctx = self.graph.inserting_after(node)
        else:
            ctx = self.graph.inserting_before(None)
        with ctx:
            proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
            set_example_value(proxy.node, example_value)
            if self.input_name_to_proxy and before:
                k, v = self.input_name_to_proxy.popitem()
                self.input_name_to_proxy[name] = proxy
                self.input_name_to_proxy[k] = v
            else:
                self.input_name_to_proxy[name] = proxy

            # NOTE: [Auto lift basic free symbols when create_graph_input]
            # Whenever we call create_graph_input, we try to also lift the basic symbols in example values
            # as graph input.
            # This applies to both top-level graph and subgraphs in higher order ops.
            # It has several cases:
            #  1. When we create_graph_input for a tensor that has symbolic shapes,
            #     we look for basic symbols in its size and stride and check whether each
            #     symbol is bound in the current graph (i.e. in bound_symbols); if it's not
            #     bound, we create a placeholder for it, then recursively check its parent,
            #     creating placeholders as needed until the symbol is bound.
            #     Every tracer maintains a mapping (i.e. lifted_freevars)
            #     that maps from the parent proxy to the proxy in the current tracer for the symbol.
            #  2. When we create_graph_input for a tensor with unbacked symbolic shapes:
            #     backed symbols all come from the inputs' symbolic shapes, but unbacked symbols
            #     can be created while tracing. So track_unbacked_symbols will intercept
            #     at wrap_fx_proxy and try to bind the unbacked symbols immediately after they're
            #     created.
            #  3. Subgraphs will also lift basic symbols appearing in compound exprs of a tensor's
            #     shape. For example, if an input to a subgraph has size [s1+s2//8], we look for
            #     the free symbols in the sizes and lift them as inputs, similar to 1 (see
            #     _lift_symbols_in_symint).
            #  4. When we create_graph_input for a SymInt, if the symint is a basic symbol, we track
            #     it in bound_symbols so that we don't lift the same basic symbol twice. When the
            #     symint is a compound expr, we just create the proxy for the compound expr but do
            #     not lift its basic symbols.
            # Also see NOTE: [Export inputs must be explicitly passed in]
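            #
            # Rough example (hypothetical shapes): creating a graph input for a fake
            # tensor of size (s0, 3) also lifts s0 as its own SymInt placeholder via
            # _lift_basic_symbols below (inserted before the tensor node, case 1) and
            # records it in bound_symbols, so in the final graph the s0 input precedes
            # the tensor that mentions it.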
            is_strict_export = self.is_export
            is_non_strict_export = torch.compiler.is_compiling()
            if (
                not is_strict_export
                and not is_non_strict_export
                and isinstance(example_value, torch.Tensor)
            ):
                self._lift_basic_symbols(example_value, source)

            # Bind the symbol to the placeholder if example_value is a SymInt backed by a basic symbol.
            if isinstance(example_value, torch.SymInt) and isinstance(
                example_value.node.expr, sympy.Symbol
            ):
                self.bound_symbols[example_value.node.expr] = proxy
            return proxy

    # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
    def lift_tracked_freevar_to_input(self, proxy):
        # You're doing something wrong if we are the root SubgraphTracer because
        # Dynamo adds tensors to graph inputs before creating a proxy for them.
        assert (
            self.parent is not None
        ), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"

        example_value = proxy.node.meta["example_value"]

        # To avoid lifting the same symbol twice, we check whether its basic symbols have been tracked.
        # For example, the basic symbols may have already been lifted for the current subgraph when
        # we automatically lift basic symbols in the sizes/strides of a tensor t.
        # Suppose the parent graph calls sz = t.size()[0]: this creates
        # a proxy in the parent, and the subgraph accesses sz via a closure. sz's proxy is not tracked
        # in the current sub-tracer, so without this check we may lift the same symbol twice.
        if (
            isinstance(example_value, torch.SymInt)
            and example_value.node.expr in self.bound_symbols
        ):
            return self.bound_symbols[example_value.node.expr]

        # Proxies are associated with VariableTrackers.
        # It is possible that we've already lifted the Proxy to be an input.
        # If that is the case, just return the already lifted Proxy.
        if proxy in self.lifted_freevars:
            return self.lifted_freevars[proxy]

        # We first lift the proxy to the parent's graph, then lift it to the current graph's input,
        # so that when we bind the symints of the sizes in the current graph, those symints
        # have already been lifted as inputs to the parent graph.
        if proxy.tracer != self.parent:
            self.parent.lift_tracked_freevar_to_input(proxy)

        example_value = proxy.node.meta["example_value"]
        new_proxy = self.create_graph_input(
            proxy.node.name, type(example_value), example_value
        )
        self.lifted_freevars[proxy] = new_proxy
        return new_proxy

    def maybe_lift_tracked_freevar_to_input(self, arg):
        """
        If arg is a free variable, then lift it to be an input.
        Returns the new lifted arg (if arg was a freevar), else the
        original arg.
        """
        if not isinstance(arg, torch.fx.Proxy):
            # Note: arg can be a Python built-in slice type, e.g.
            # x[:max_seq] is represented as get_item(t, (slice(None, max_seq, None))).
            # We also need to look inside the slice object itself to lift any
            # proxies it contains.
            if isinstance(arg, slice):
                return slice(
                    *(
                        self.maybe_lift_tracked_freevar_to_input(sub_arg)
                        for sub_arg in (arg.start, arg.stop, arg.step)
                    )
                )
            else:
                return arg
        elif arg.tracer == self:
            return arg
        return self.lift_tracked_freevar_to_input(arg)

    # See NOTE: [Auto lift basic free symbols when create_graph_input] for the overall design.
    # You MUST call this API every time you create a proxy in wrap_fx_proxy for a call
    # that produced unbacked symints or tensors with unbacked symint shapes.
    # This function tracks the unbacked symints together with the proxies created during
    # dynamo tracing, so that a subgraph knows how to bind a symbol input to its parent's proxy.
    # LazyProxys are created for unbacked tensor shapes so that we don't create proxies
    # for symbols that are never going to be used.
    def track_unbacked_symbols(
        self, example_value, e_proxy: Union[LazyProxy, torch.fx.Proxy]
    ):
        # When binding the symbols in an example_value, we bind the symbols
        # to the proxy's associated Tracer instead of the current tracer.
        # This is because:
        # 1. We may be calling wrap_tensors during speculate_subgraph because
        # the variables are lazily realized. The proxies are top-level placeholders
        # but the current tracer is a subtracer.
        # 2. For autograd.Function, we trace the backward graph with a new tracer
        # whose parent is the forward tracer, but we're using all the proxies created
        # in the forward tracer to trace the backward.
        # For example, forward calls save_for_backward on an input tensor t,
        # and backward calls t.tolist(). In this case, all the proxies that the
        # backward tracer sees (e.g. t[0].item()) come from the parent tracer
        # (i.e. the forward tracer).
        # See test_validate_outputs_unbacked for a repro of 2.
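        #
        # Rough example (hypothetical code): if example_value came from
        # `nz = x.nonzero()`, its first dim is an unbacked SymInt u0. The code below
        # records u0 -> LazyProxy(aten.sym_size.int(nz_proxy, 0)) in the owning
        # tracer's bound_symbols, deferring creation of the sym_size node until u0
        # is actually needed (see lookup_unbound_symbols).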
        tracer = e_proxy.tracer
        assert isinstance(tracer, SubgraphTracer)

        def need_bind(s) -> bool:
            from torch.fx.experimental.symbolic_shapes import is_symbolic

            return (
                is_symbolic(s)
                and isinstance(s.node.expr, sympy.Symbol)
                and s.node.shape_env.is_unbacked_symint(s.node.expr)
                and s.node.expr not in self.bound_symbols
            )

        def _proxy_with_example_value(example_value, *args, **kwargs):
            proxy = tracer.create_proxy(*args, **kwargs)
            set_example_value(proxy.node, example_value)
            return proxy

        if isinstance(example_value, torch.Tensor):
            for i, s in enumerate(example_value.size()):
                if need_bind(s):
                    log.debug(
                        "_track_unbacked_symbols %s for %s.size()[%s] at debug_level %s",
                        s,
                        e_proxy,
                        i,
                        tracer.debug_level,
                    )
                    lazy_proxy = LazyProxy(
                        tracer,
                        _proxy_with_example_value,
                        s,
                        "call_function",
                        torch.ops.aten.sym_size.int,
                        (e_proxy, i),
                        {},
                        type_expr=type(s),
                    )
                    self.track_unbacked_symbols(s, lazy_proxy)

            if example_value.layout is torch.strided:
                for i, s in enumerate(example_value.stride()):
                    if need_bind(s):
                        log.debug(
                            "_track_unbacked_symbols %s for %s.stride()[%s] at debug_level %s",
                            s,
                            e_proxy,
                            i,
                            tracer.debug_level,
                        )
                        lazy_proxy = LazyProxy(
                            tracer,
                            _proxy_with_example_value,
                            s,
                            "call_function",
                            torch.ops.aten.sym_stride.int,
                            (e_proxy, i),
                            {},
                            type_expr=type(s),
                        )
                        self.track_unbacked_symbols(s, lazy_proxy)

            elif example_value.layout is torch.sparse_coo:
                self.track_unbacked_symbols(example_value._indices(), e_proxy)
                self.track_unbacked_symbols(example_value._values(), e_proxy)
            elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
                self.track_unbacked_symbols(example_value.crow_indices(), e_proxy)
                self.track_unbacked_symbols(example_value.col_indices(), e_proxy)
            elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
                self.track_unbacked_symbols(example_value.ccol_indices(), e_proxy)
                self.track_unbacked_symbols(example_value.row_indices(), e_proxy)
            if is_traceable_wrapper_subclass(example_value):
                attrs, ctx = example_value.__tensor_flatten__()
                for attr in attrs:
                    inner_t = getattr(example_value, attr)
                    self.track_unbacked_symbols(inner_t, getattr(e_proxy, attr))
        elif isinstance(example_value, torch.SymInt):
            # Only bind unbacked symbols. backed symbols are lifted as inputs.
            if need_bind(example_value):
                expr = example_value.node.expr
                tracer.bound_symbols[expr] = e_proxy

    # See Note [Auto lift basic free symbols when create_graph_input]
    def _lift_basic_symbols(
        self, example_value: Union[torch.SymInt, torch.Tensor], src: Optional[Source]
    ):
        # The before arg is for inserting symints in the sizes/strides of a tensor
        # before the tensor. This ordering ensures that when we look at the tensor's
        # symbols, they're already lifted/tracked. E.g. this assumption is used
        # in insert_deferred_runtime_asserts.
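        #
        # Rough sketch (hypothetical shapes): for a tensor input of size (s0, s1), the
        # helper below creates SymInt placeholders for s0 and s1 with before=True so
        # they precede the tensor placeholder; in a nested tracer it also records the
        # parent-proxy -> child-placeholder mapping in lifted_freevars.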
        def _lift_symbols_in_symint(
            s: Union[int, torch.SymInt],
            source: Optional[Source],
            before: bool = False,
        ) -> None:
            if not is_symbolic(s):
                return

            assert isinstance(s, torch.SymInt)
            self_to_be_bound = self.lookup_unbound_symbols(s)
            if len(self_to_be_bound) == 0:
                return

            # For subgraph
            if self.parent is not None:
                # Recursively lift symbols in symint until top-level.
                self.parent._lift_basic_symbols(s, source)
                for s0 in self_to_be_bound:
                    parent_proxy = self.parent.bound_symbols[s0]
                    example_val = parent_proxy.node.meta["example_value"]
                    assert isinstance(example_val, torch.SymInt)
                    ph = self.create_graph_input(
                        str(s0),
                        type(example_val),
                        example_val,
                        before=before,
                        source=source,
                    )
                    log.debug(
                        "_lift_symbols_in_symint %s from %s at debug_level %s",
                        s0,
                        source.name() if source is not None else "subgraph inputs",
                        self.debug_level,
                    )
                    self.lifted_freevars[parent_proxy] = ph
            # For root_tracer:
            else:
                assert len(self_to_be_bound) == 1, (
                    f"For root tracer, we only expect to bind basic symbols (compound symbols "
                    f"should be cached before) but got unbound symbols {self_to_be_bound} in {s}"
                )
                assert source is not None, (
                    f"Source of '{s}' is None when lifting it to input of top-level. If it's an unbacked symbol, "
                    "this could be because it's not tracked with lazy_bind_unbacked_symbols. "
                    f"Otherwise, should provide a source when create_graph_input for `{s}` at root tracer."
                )
                s0 = next(iter(self_to_be_bound))
                ph = self.create_graph_input(
                    str(s0),
                    type(s),
                    s,
                    before=before,
                    source=source,
                )
                log.debug(
                    "_lift_symbols_in_symint %s from %s at debug_level %s",
                    s,
                    source.name() if source is not None else "subgraph inputs",
                    self.debug_level,
                )
                ph.node.meta["grapharg"] = GraphArg(
                    source,
                    s,
                    pass_arg_as_tensor=False,
                    fake_tensor=None,
                    is_tensor=False,
                )

        if isinstance(example_value, torch.Tensor):
            for i, s in enumerate(example_value.size()):
                _lift_symbols_in_symint(
                    s,
                    (
                        TensorPropertySource(src, TensorProperty.SIZE, i)
                        if src is not None
                        else None
                    ),
                    before=True,
                )
            if example_value.layout is torch.strided:
                for i, s in enumerate(example_value.stride()):
                    _lift_symbols_in_symint(
                        s,
                        (
                            TensorPropertySource(src, TensorProperty.STRIDE, i)
                            if src is not None
                            else None
                        ),
                        before=True,
                    )
                _lift_symbols_in_symint(
                    example_value.storage_offset(),
                    (
                        TensorPropertySource(src, TensorProperty.STORAGE_OFFSET)
                        if src is not None
                        else None
                    ),
                    before=True,
                )
            elif example_value.layout is torch.sparse_coo:
                self._lift_basic_symbols(example_value._indices(), src)
                self._lift_basic_symbols(example_value._values(), src)
            elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
                self._lift_basic_symbols(example_value.crow_indices(), src)
                self._lift_basic_symbols(example_value.col_indices(), src)
            elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
                self._lift_basic_symbols(example_value.ccol_indices(), src)
                self._lift_basic_symbols(example_value.row_indices(), src)
            if is_traceable_wrapper_subclass(example_value):
                attrs, ctx = example_value.__tensor_flatten__()
                for attr in attrs:
                    inner_t = getattr(example_value, attr)
                    self._lift_basic_symbols(
                        inner_t, AttrSource(src, attr) if src is not None else None
                    )
        elif isinstance(example_value, torch.SymInt):
            _lift_symbols_in_symint(
                example_value,
                src,
            )

    # Look up the proxy in the current tracer for each symbol in the expression of s.
    # See Note [Auto lift basic free symbols when create_graph_input]
    def lookup_unbound_symbols(self, s: torch.SymInt) -> List[sympy.Symbol]:
        free_symbols = s.node.expr.free_symbols
        if len(free_symbols) == 0:
            return []

        to_be_bound = []
        for s0 in free_symbols:
            if s0 not in self.bound_symbols:
                to_be_bound.append(s0)
                continue

            proxy = self.bound_symbols[s0]
            if isinstance(proxy, LazyProxy):
                proxy = proxy()
                self.bound_symbols[s0] = proxy
            assert (
                isinstance(proxy, torch.fx.Proxy) and proxy.tracer is self
            ), f"The proxy of symbol {s0} doesn't belong to current tracer."
        # Sort the symbols so that we can have a deterministic lifting order
        return sorted(to_be_bound, key=lambda s: s.name)


# NOTE: [HigherOrderOperator tracing design]
# Ignoring HigherOrderOperators for a moment,
# OutputGraph represents the graph being built by Dynamo that may be compiled
# and executed. It holds a root SubgraphTracer where the FX graph is built.
#
# HigherOrderOperators are operators that take functions as their arguments.
# When Dynamo encounters a HigherOrderOperator, it attempts to introspect
# the function passed to it (call this the "body function"), capture it into a
# GraphModule, and rewrite the call to the HigherOrderOperator to use the
# GraphModule.
#
# The way we handle the capture of body functions is through having
# (possibly nested) SubgraphTracers, one per body function.
#
# Mechanically, we do the introspection by:
# - Creating a new SubgraphTracer via OutputGraph.subtracer
# - Executing the body function.
# This constructs the graph of the body function in the new SubgraphTracer
# while modifying the state of the OutputGraph. For example:
# - the OutputGraph can receive new GraphArgs (if we discover any new
#   untracked Tensors)
# - side effects from the body function get accumulated into
#   OutputGraph.side_effects
# - guards produced by the body function get accumulated into OutputGraph.guards
#
# The traced function has some special properties that make it easier for us
# to transform later down the line:
# - we lift all free variables to being inputs.
#
# If the introspection fails (due to the existence of graph breaks), then
# we roll back the current OutputGraph state and graph break on the
# HigherOrderOperator.
