File: DifferentiableArray-for-JAX.py

package info (click to toggle)
python-awkward 2.6.5-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 23,088 kB
  • sloc: python: 148,689; cpp: 33,562; sh: 432; makefile: 21; javascript: 8
file content (193 lines) | stat: -rw-r--r-- 6,803 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import json

import jax
import jax.tree_util

import awkward as ak
import numpy as np

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

class AuxData(object):
    def __init__(self, form, length, indexes, datakeys):
        self.form = form
        self.length = length
        self.indexes = indexes
        self.datakeys = datakeys

    def __eq__(self, other):
        # AuxData is an object so that JAX can naively call __eq__ on it
        return (
            self.form == other.form
            and self.length == other.length
            and self.indexes.keys() == other.indexes.keys()
            and all(
                # normally, array equality would be a problem for __eq__ (in an if-statement)
                np.array_equal(self.indexes[k], other.indexes[k])
                for k in self.indexes.keys()
            )
            and self.datakeys == other.datakeys
        )

class DifferentiableArray(ak.Array):
    def __init__(self, aux_data, tracers):
        self.aux_data = aux_data
        self.tracers = tracers

    @property
    def layout(self):
        buffers = dict(self.aux_data.indexes)
        for key, tracer in zip(self.aux_data.datakeys, self.tracers):
            if hasattr(tracer, "primal"):
                buffers[key] = tracer.primal
        return ak.from_buffers(
            self.aux_data.form, self.aux_data.length, buffers, highlevel=False
        )

    @layout.setter
    def layout(self, layout):
        raise ValueError(
            "this operation cannot be performed in a JAX-compiled or JAX-differentiated function"
        )

    def __getitem__(self, where):
        out = self.layout[where]
        if isinstance(out, ak.layout.Content):
            form, length, indexes = ak.to_buffers(
                out, form_key="getitem_node{id}", virtual="pass"
            )
            aux_data = AuxData(form, length, indexes, self.aux_data.datakeys)
            return DifferentiableArray(aux_data, self.tracers)
        else:
            return out

    def __setitem__(self, where, what):
        raise ValueError(
            "this operation cannot be performed in a JAX-compiled or JAX-differentiated function"
        )

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # optional sanity-check (i.e. sanity is optional)
        for x in inputs:
            if isinstance(x, DifferentiableArray):
                assert x.aux_data == self.aux_data
                assert len(x.tracers) == len(self.tracers)

        # ak.Array __add__, etc. map to the NumPy functions, switch to JAX
        for name, np_ufunc in np.core.umath.__dict__.items():
            if ufunc is np_ufunc:
                ufunc = getattr(jax.numpy, name)

        # need to apply the ufunc to the same argument list for each tracer separately
        nexttracers = []
        for i in range(len(self.tracers)):
            nextinputs = [
                x.tracers[i] if isinstance(x, DifferentiableArray) else x
                for x in inputs
            ]
            nexttracers.append(getattr(ufunc, method)(*nextinputs, **kwargs))

        # and return a new DifferentiableArray (keep it wrapped!)
        return DifferentiableArray(self.aux_data, nexttracers)

def find_datanode(formjson, form_key):
    if isinstance(formjson, dict):
        if formjson.get("form_key") == form_key:
            return formjson
        for k, v in formjson.items():
            out = find_datanode(v, form_key)
            if out is not None:
                if out == formjson[k]:
                    formjson[k] = {
                        "class": "VirtualArray",
                        "form": out,
                        "has_length": True,
                        "has_identities": False,
                        "parameters": {},
                        "form_key": None,
                    }
                return out
        else:
            return None
    elif isinstance(formjson, list):
        for i, v in enumerate(formjson):
            out = find_datanode(v, form_key)
            if out is not None:
                if out == formjson[i]:
                    formjson[i] = {
                        "class": "VirtualArray",
                        "form": out,
                        "has_length": True,
                        "has_identities": False,
                        "parameters": {},
                        "form_key": None,
                    }
                return out
        else:
            return None
    else:
        return None

def special_flatten(array):
    if isinstance(array, DifferentiableArray):
        aux_data, children = array.aux_data, array.tracers
    else:
        form, length, buffers = ak.to_buffers(array)
        formjson = json.loads(form.tojson())
        indexes = {k: v for k, v in buffers.items() if not k.endswith("-data")}
        datakeys = []
        for key in buffers:
            partition, form_key, role = key.split("-")
            if role == "data":
                nodejson = find_datanode(formjson, form_key)
                assert nodejson is not None
                node = ak.forms.Form.fromjson(json.dumps(nodejson))
                datakeys.append(key)
        nextform = ak.forms.Form.fromjson(json.dumps(formjson))
        aux_data = AuxData(nextform, length, indexes, datakeys)
        children = [jax.numpy.asarray(buffers[x], buffers[x].dtype) for x in datakeys]
    return children, aux_data

def special_unflatten(aux_data, children):
    if any(isinstance(x, jax.core.Tracer) for x in children):
        return DifferentiableArray(aux_data, children)
    else:
        buffers = dict(aux_data.indexes)
        buffers.update(zip(aux_data.datakeys, children))
        return ak.from_buffers(aux_data.form, aux_data.length, buffers)

jax.tree_util.register_pytree_node(ak.Array, special_flatten, special_unflatten)
jax.tree_util.register_pytree_node(DifferentiableArray, special_flatten, special_unflatten)

###############################################################################
#  TESTING
###############################################################################

def func(array):
    return 2*array.y[0] + 10

primal = ak.Array([
    [{"x": 1.1, "y": [1.0]}, {"x": 2.2, "y": [1.0, 2.2]}],
    [],
    [{"x": 3.3, "y": [1.0, 2.0, 3.0]}]
])

tangent = ak.Array([
    [{"x": 0.0, "y": [1.0]}, {"x": 2.0, "y": [1.5, 0.0]}],
    [],
    [{"x": 1.5, "y": [2.0, 0.5, 1.0]}]
])

primal_result, tangent_result = jax.jvp(func, (primal,), (tangent,))
print("resulting types", type(primal_result), type(tangent_result))
print(primal_result)
print(tangent_result)

jit_result = jax.jit(func)(primal)
print("resulting type", type(jit_result))
print(jit_result)

val, func = jax.vjp(func, primal)

print(func(val))