File: input.py

package info (click to toggle)
python-einx 0.3.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,112 kB
  • sloc: python: 11,619; makefile: 13
file content (128 lines) | stat: -rw-r--r-- 3,339 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import numpy as np
from . import tracer, tensor
import einx
import inspect

# Define classes for different types of inputs that act as cache keys and will
# be converted into the corresponding tracer objects when a graph is constructed


class CacheKey:
    """Marker base class for hashable cache keys describing graph inputs."""


class Scalar(CacheKey):
    """Cache key for a plain scalar input (all scalars are interchangeable)."""

    def __eq__(self, other):
        # Every Scalar key compares equal to every other Scalar key.
        return isinstance(other, Scalar)

    def __hash__(self):
        # Constant hash is consistent with the all-equal __eq__ above.
        return 1

    def to_tracer(self, backend, virtual_arg):
        # The scalar tracer serves both as the argument and as the value.
        scalar_tracer = tensor.Scalar()
        return scalar_tracer, scalar_tracer


class Tensor(CacheKey):
    """Cache key for a tensor input, identified by its shape and Python type.

    Args:
        shape: Tuple of ints describing the tensor's shape.
        type: The Python type of the concrete tensor object.
    """

    def __init__(self, shape, type):
        self.shape = shape
        self.type = type

    def __eq__(self, other):
        return isinstance(other, Tensor) and other.shape == self.shape and other.type == self.type

    def __hash__(self):
        # Offset distinguishes this key class from Scalar (1) / TensorFactory (3).
        return 2 + hash(self.shape) + hash(self.type)

    def to_tracer(self, backend, virtual_arg):
        # Fix: the loop variable was named "type", shadowing the builtin and
        # reading confusingly against self.type on the same line.
        if any(issubclass(self.type, tensor_type) for tensor_type in backend.tensor_types):
            # Natively supported by the backend: trace as a plain tensor.
            x = tensor.Tensor(self.shape)
        else:
            # Needs conversion into a backend tensor type before use.
            x = tensor.TensorRequiringConversion(self.shape)
        return x, x


class TensorFactory(CacheKey):
    """Cache key for a callable tensor factory, identified by its parameter names."""

    def __init__(self, params):
        # Stored as a tuple so the key is hashable.
        self.params = tuple(params)

    def __eq__(self, other):
        if not isinstance(other, TensorFactory):
            return False
        return other.params == self.params

    def __hash__(self):
        # Offset distinguishes this key class from Scalar (1) / Tensor (2).
        return 3 + hash(self.params)

    def to_tracer(self, backend, virtual_arg):
        factory_tracer = tensor.TensorFactory(self.params)
        return factory_tracer, factory_tracer


class Input:
    """Base class for custom inputs providing their own to_value_and_key()."""


# Registry of callables that recognize concrete objects and wrap them as
# tensor factories. Each callable returns a wrapped result, or None when it
# does not recognize the object.
tensor_factories = []


def register_tensor_factory(factory):
    """Add *factory* to the registry; returns it so it works as a decorator."""
    tensor_factories.append(factory)
    return factory


def apply_registered_tensor_factory(x):
    """Return the first non-None result of a registered factory on x, else None."""
    results = (factory(x) for factory in tensor_factories)
    return next((r for r in results if r is not None), None)


def concrete_to_value_and_key(x):
    """Convert a concrete input into a ``(value, cache_key)`` pair.

    The key identifies the kind/shape of the input for graph caching; the
    value is the object actually passed when the traced graph is executed.

    Raises:
        ValueError: If a nested list/tuple input has an irregular shape.
    """
    if isinstance(x, (float, int, np.floating, np.integer, bool, np.bool_)):
        # Scalar
        return x, Scalar()
    elif isinstance(x, (tuple, list)):
        # Nested list/ tuple of scalars
        shape = einx.tracer.get_shape(x)
        if shape is None:
            raise ValueError("Failed to determine shape of input tensor")
        return x, Tensor(shape, type(x))
    elif isinstance(x, Input):
        # Custom input
        return x.to_value_and_key()
    elif (x2 := apply_registered_tensor_factory(x)) is not None:
        # Registered tensor factory
        return x2
    elif callable(x):
        # Simple callable tensor factory: record its parameter names so it can
        # later be invoked with only the keyword arguments it accepts.
        params = []
        try:
            for name, param in inspect.signature(x).parameters.items():
                if param.kind == inspect.Parameter.VAR_KEYWORD:
                    name = f"**{name}"
                elif param.kind == inspect.Parameter.VAR_POSITIONAL:
                    name = f"*{name}"
                params.append(name)
        except (ValueError, TypeError):
            # inspect.signature raises ValueError/TypeError for callables with
            # no retrievable signature (e.g. some builtins); fall back to an
            # empty parameter list instead of failing. (Was a bare except,
            # which also swallowed KeyboardInterrupt/SystemExit.)
            pass
        return x, TensorFactory(params)
    else:
        # Tensor-like object exposing a .shape attribute
        return x, Tensor(tuple(int(i) for i in x.shape), type(x))


def key_to_tracer(x, backend, virtual_arg):
    """Replace every CacheKey in the (possibly nested) structure ``x`` with
    its tracer value.

    Returns:
        ``(args, x)`` where ``args`` collects the tracer argument objects
        produced by each key and ``x`` is the structure with every CacheKey
        replaced by its tracer value.
    """
    args = []

    # Fix: the inner helper was named "map", shadowing the builtin.
    def to_tracer(node):
        if not isinstance(node, CacheKey):
            # Non-key leaves pass through unchanged.
            return node
        arg, value = node.to_tracer(backend, virtual_arg)
        if arg is not None:
            args.append(arg)
        return value

    x = einx.tree_util.tree_map(to_tracer, x)

    return args, x