1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
|
# mypy: allow-untyped-defs
# pyre-strict
from typing import Union
import torch
class ProxyValue:
    """Pair an arbitrary value with the ``torch.fx`` Proxy or Node tied to it.

    Attributes:
        data: The wrapped value; iteration and truthiness delegate to it.
        proxy_or_node: The ``torch.fx.Proxy`` or ``torch.fx.Node`` passed to
            the constructor, stored unchanged.
    """

    # pyre-ignore
    def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]):
        self.proxy_or_node = proxy
        # pyre-ignore
        self.data = data

    @property
    def node(self) -> torch.fx.Node:
        """Return the underlying ``torch.fx.Node``.

        A stored Node is returned as-is; a stored Proxy is unwrapped through
        its ``.node`` attribute. Any other stored object trips the assertion.
        """
        target = self.proxy_or_node
        if not isinstance(target, torch.fx.Node):
            # Only a Proxy carries a ``.node`` to unwrap; the assert also
            # narrows the type for static checkers.
            assert isinstance(target, torch.fx.Proxy)
            target = target.node
        return target

    @property
    def proxy(self) -> torch.fx.Proxy:
        """Return the stored ``torch.fx.Proxy``.

        Raises:
            RuntimeError: if the stored object is a bare ``torch.fx.Node``
                rather than a Proxy; the message includes the formatted node.
        """
        wrapped = self.proxy_or_node
        if isinstance(wrapped, torch.fx.Proxy):
            return wrapped
        raise RuntimeError(
            f"ProxyValue doesn't have attached Proxy object. Node: {wrapped.format_node()}"
        )

    def is_tensor(self) -> bool:
        """Report whether ``data`` is a ``torch.Tensor``."""
        return isinstance(self.data, torch.Tensor)

    def to_tensor(self) -> torch.Tensor:
        """Return ``data``, asserting that it is a ``torch.Tensor``.

        The check is an ``assert``, so it is skipped under ``python -O``.
        """
        value = self.data
        assert isinstance(value, torch.Tensor)
        return value

    # pyre-ignore
    def __iter__(self):
        """Lazily yield the items of ``data`` (this is a generator)."""
        yield from self.data

    def __bool__(self) -> bool:
        """Truthiness is exactly that of the wrapped ``data``."""
        return bool(self.data)
|