1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
|
import numpy as np
from . import tracer, tensor
import einx
import inspect
# Define classes for different types of inputs that act as cache keys and will
# be converted into the corresponding tracer objects when a graph is constructed
class CacheKey:
    """Marker base class for cache-key objects.

    ``key_to_tracer`` walks an input tree and replaces every ``CacheKey``
    instance it finds by calling the instance's
    ``to_tracer(backend, virtual_arg)`` method, which subclasses are expected
    to implement.
    """
class Scalar(CacheKey):
    """Cache key for a scalar input.

    Every ``Scalar`` compares equal to every other ``Scalar``, so the cache
    does not distinguish between scalar values or scalar types.
    """

    def __eq__(self, other):
        return isinstance(other, Scalar)

    def __hash__(self):
        # Constant hash: consistent with __eq__ treating all Scalars as equal.
        return 1

    def to_tracer(self, backend, virtual_arg):
        """Create a fresh scalar tracer.

        The same object is returned twice: once as the graph argument and once
        as the value substituted into the input tree.
        """
        scalar_tracer = tensor.Scalar()
        return scalar_tracer, scalar_tracer
class Tensor(CacheKey):
    """Cache key for a tensor-like input, identified by shape and Python type.

    Args:
        shape: Shape of the input. It is hashed in ``__hash__``, so it must be
            hashable (e.g. a tuple).
        type: The input's Python class. It decides whether the backend can
            consume the input directly.
    """

    def __init__(self, shape, type):
        self.shape, self.type = shape, type

    def __eq__(self, other):
        if not isinstance(other, Tensor):
            return False
        return other.shape == self.shape and other.type == self.type

    def __hash__(self):
        # Per-class offset (2) plus the hashes of both fields compared in __eq__.
        return 2 + hash(self.shape) + hash(self.type)

    def to_tracer(self, backend, virtual_arg):
        """Create a tensor tracer with this key's shape.

        If ``self.type`` is a subclass of any class in
        ``backend.tensor_types``, a ``tensor.Tensor`` is created. Otherwise a
        ``tensor.TensorRequiringConversion`` is created (presumably the value
        must first be converted to a backend tensor; confirm in ``tensor``).
        The same tracer is returned as both the argument and the substituted
        value.
        """
        native = any(issubclass(self.type, cls) for cls in backend.tensor_types)
        tracer_cls = tensor.Tensor if native else tensor.TensorRequiringConversion
        traced = tracer_cls(self.shape)
        return traced, traced
class TensorFactory(CacheKey):
    """Cache key for a callable tensor factory, identified by its parameter names."""

    def __init__(self, params):
        # Stored as a tuple so the key is hashable.
        self.params = tuple(params)

    def __eq__(self, other):
        if not isinstance(other, TensorFactory):
            return False
        return other.params == self.params

    def __hash__(self):
        # Per-class offset (3) plus the hash of the parameter-name tuple.
        return 3 + hash(self.params)

    def to_tracer(self, backend, virtual_arg):
        """Create a factory tracer carrying the parameter names; returned twice (argument and value)."""
        factory_tracer = tensor.TensorFactory(self.params)
        return factory_tracer, factory_tracer
class Input:
    """Base class for custom input objects.

    ``concrete_to_value_and_key`` delegates to an instance's
    ``to_value_and_key()`` method, which subclasses are expected to provide.
    Presumably it returns a ``(value, cache_key)`` pair like the other
    branches of that function do.
    """
# Registry of tensor-factory hooks, appended to by register_tensor_factory()
# and tried in registration order by apply_registered_tensor_factory().
tensor_factories: list = []
def register_tensor_factory(factory):
    """Append ``factory`` to the hook registry and return it unchanged.

    Returning the argument lets this be used as a decorator. Registered hooks
    are tried in registration order by ``apply_registered_tensor_factory``,
    which treats a ``None`` result as "not handled".
    """
    tensor_factories.append(factory)
    return factory
def apply_registered_tensor_factory(x):
    """Return the first non-None result of the registered hooks applied to ``x``.

    Hooks are called lazily in registration order. No further hooks are
    called once one returns a value other than None. Returns None when no
    hook handles ``x``, including when the registry is empty.
    """
    results = (factory(x) for factory in tensor_factories)
    return next((result for result in results if result is not None), None)
def _signature_param_names(fn):
    """Return ``fn``'s parameter names, prefixing variadics with ``*`` / ``**``.

    Returns an empty list when ``inspect.signature`` cannot introspect ``fn``.
    It documents raising ValueError (no signature available) or TypeError
    (unsupported object) in that case.
    """
    try:
        parameters = inspect.signature(fn).parameters
    except (ValueError, TypeError):
        # Best effort: an un-introspectable callable gets an empty parameter list.
        return []
    names = []
    for name, param in parameters.items():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            name = f"**{name}"
        elif param.kind == inspect.Parameter.VAR_POSITIONAL:
            name = f"*{name}"
        names.append(name)
    return names


def concrete_to_value_and_key(x):
    """Convert a concrete input into a ``(value, cache_key)`` pair.

    Checks are applied in this order:

    * Python/NumPy numbers and bools -> ``(x, Scalar())``.
    * tuple/list -> ``(x, Tensor(shape, type(x)))``, where the shape comes
      from ``einx.tracer.get_shape``.
    * ``Input`` instances -> ``x.to_value_and_key()``.
    * inputs handled by a registered tensor-factory hook -> that hook's result.
    * other callables -> ``(x, TensorFactory(parameter names))``.
    * anything else is treated as a tensor exposing ``.shape`` ->
      ``(x, Tensor(tuple of int dims, type(x)))``.

    Raises:
        ValueError: if the shape of a tuple/list input cannot be determined.
        AttributeError: if a fallthrough input has no ``shape`` attribute.
    """
    if isinstance(x, (float, int, np.floating, np.integer, bool, np.bool_)):
        # Scalar
        return x, Scalar()
    if isinstance(x, (tuple, list)):
        # Nested list/tuple of scalars
        shape = einx.tracer.get_shape(x)
        if shape is None:
            raise ValueError("Failed to determine shape of input tensor")
        return x, Tensor(shape, type(x))
    if isinstance(x, Input):
        # Custom input
        return x.to_value_and_key()
    # Registered hooks are consulted only after the built-in cases above.
    x2 = apply_registered_tensor_factory(x)
    if x2 is not None:
        return x2
    if callable(x):
        # Simple callable tensor factory, keyed by its parameter names
        return x, TensorFactory(_signature_param_names(x))
    # Tensor: dims are coerced to int so the key is a plain hashable tuple
    return x, Tensor(tuple(int(i) for i in x.shape), type(x))
def key_to_tracer(x, backend, virtual_arg):
    """Replace every ``CacheKey`` leaf in ``x`` with a tracer.

    ``x`` is traversed with ``einx.tree_util.tree_map``. Each ``CacheKey``
    leaf is converted via ``leaf.to_tracer(backend, virtual_arg)``, which
    returns ``(arg, value)``. ``value`` replaces the leaf, and ``arg`` (when
    not None) is appended to the returned argument list in the order
    ``tree_map`` visits leaves. Non-``CacheKey`` leaves are left unchanged.

    Returns:
        ``(args, mapped)``: the collected tracer arguments and the structure
        with keys replaced.
    """
    args = []

    # Named to avoid shadowing the builtin ``map``.
    def _leaf_to_tracer(leaf):
        if not isinstance(leaf, CacheKey):
            return leaf
        arg, traced = leaf.to_tracer(backend, virtual_arg)
        if arg is not None:
            args.append(arg)
        return traced

    x = einx.tree_util.tree_map(_leaf_to_tracer, x)
    return args, x
|