1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
|
import json
import jax
import jax.tree_util
import awkward as ak
import numpy as np
# Pin JAX to the CPU backend and turn on 64-bit types (presumably so the
# float64 awkward buffers are not downcast to float32).
for _option, _value in (("jax_platform_name", "cpu"), ("jax_enable_x64", True)):
    jax.config.update(_option, _value)
del _option, _value
class AuxData:
    """Static (non-traced) half of a flattened awkward array.

    JAX keeps this object in the pytree definition and compares it with
    ``==``, so it implements value equality and a consistent ``__hash__``.
    It holds everything needed to rebuild the array except the data buffers,
    which travel separately as JAX leaves.
    """

    def __init__(self, form, length, indexes, datakeys):
        self.form = form  # layout description (an awkward Form in this file)
        self.length = length  # length of the array
        self.indexes = indexes  # dict: buffer key -> non-data (index) buffer
        self.datakeys = datakeys  # list of buffer keys supplied as JAX leaves

    def __eq__(self, other):
        # Comparing against anything that is not AuxData is "not equal",
        # rather than an AttributeError on other.form.
        if not isinstance(other, AuxData):
            return NotImplemented
        return (
            self.form == other.form
            and self.length == other.length
            and self.indexes.keys() == other.indexes.keys()
            and all(
                # plain `==` on arrays is elementwise (ambiguous truth value),
                # so index buffers are compared with np.array_equal
                np.array_equal(self.indexes[k], other.indexes[k])
                for k in self.indexes
            )
            and self.datakeys == other.datakeys
        )

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable. Hash only
        # cheap, hashable fields that __eq__ also compares, so equal objects
        # always hash equally (the form and the index arrays are left out).
        return hash((self.length, frozenset(self.indexes), tuple(self.datakeys)))
class DifferentiableArray(ak.Array):
    """Stand-in for ``ak.Array`` whose data buffers are JAX tracers.

    ``special_unflatten`` builds one whenever a pytree leaf is a
    ``jax.core.Tracer``. The structure (form, length, index buffers, data
    keys) lives in ``aux_data``. ``tracers`` holds one JAX value per entry of
    ``aux_data.datakeys``, in the same order. Mutation is rejected.
    """

    def __init__(self, aux_data, tracers):
        # ak.Array.__init__ is not called: ``layout`` below is a read-only
        # property rebuilt from aux_data and tracers on every access.
        self.aux_data = aux_data  # AuxData: everything except the data buffers
        self.tracers = tracers  # one JAX value per aux_data.datakeys entry

    @property
    def layout(self):
        """Rebuild the low-level layout from the index buffers and tracer primals."""
        buffers = dict(self.aux_data.indexes)
        for key, tracer in zip(self.aux_data.datakeys, self.tracers):
            # Only tracers that expose a concrete ``primal`` can fill their
            # buffer; the rest are left out.
            # NOTE(review): confirm that the VirtualArray wrappers in the
            # form keep a missing buffer from being read unless needed.
            if hasattr(tracer, "primal"):
                buffers[key] = tracer.primal
        return ak.from_buffers(
            self.aux_data.form, self.aux_data.length, buffers, highlevel=False
        )

    @layout.setter
    def layout(self, layout):
        raise ValueError(
            "this operation cannot be performed in a JAX-compiled or JAX-differentiated function"
        )

    def __getitem__(self, where):
        """Index the rebuilt layout, keeping layout-valued results wrapped.

        If the result is an ``ak.layout.Content``, it is re-serialized into a
        new AuxData that reuses this array's data keys and tracers. Anything
        else (e.g. a scalar) is returned directly.
        """
        out = self.layout[where]
        if not isinstance(out, ak.layout.Content):
            return out
        # NOTE(review): virtual="pass" presumably leaves VirtualArray nodes
        # (and their references to the original data keys) unmaterialized;
        # confirm against the awkward 1.x ak.to_buffers docs.
        form, length, indexes = ak.to_buffers(
            out, form_key="getitem_node{id}", virtual="pass"
        )
        aux_data = AuxData(form, length, indexes, self.aux_data.datakeys)
        return DifferentiableArray(aux_data, self.tracers)

    def __setitem__(self, where, what):
        raise ValueError(
            "this operation cannot be performed in a JAX-compiled or JAX-differentiated function"
        )

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply a NumPy ufunc by running its jax.numpy twin once per tracer.

        Returns ``NotImplemented`` (so NumPy raises TypeError) when the ufunc
        has no same-named ``jax.numpy`` function or that function lacks the
        requested ``method``.
        """
        # optional sanity-check (i.e. sanity is optional)
        for x in inputs:
            if isinstance(x, DifferentiableArray):
                assert x.aux_data == self.aux_data
                assert len(x.tracers) == len(self.tracers)

        # ak.Array __add__, etc. map to the NumPy ufuncs. Switch to the
        # jax.numpy function registered under the same name, stopping at the
        # first match instead of scanning the rest of the module dict.
        jax_func = None
        for name, np_ufunc in np.core.umath.__dict__.items():
            if ufunc is np_ufunc:
                jax_func = getattr(jax.numpy, name, None)
                break
        if jax_func is None or not hasattr(jax_func, method):
            # Don't fall back to the NumPy ufunc on tracers; let NumPy
            # report the operation as unsupported.
            return NotImplemented
        apply = getattr(jax_func, method)

        # apply the function to the same argument list for each tracer separately
        nexttracers = []
        for i in range(len(self.tracers)):
            nextinputs = [
                x.tracers[i] if isinstance(x, DifferentiableArray) else x
                for x in inputs
            ]
            nexttracers.append(apply(*nextinputs, **kwargs))
        # and return a new DifferentiableArray (keep it wrapped!)
        return DifferentiableArray(self.aux_data, nexttracers)
def find_datanode(formjson, form_key):
    """Find the node whose ``"form_key"`` equals *form_key* in a JSON form tree.

    *formjson* is a form decoded with ``json.loads`` (nested dicts/lists).
    The search is depth-first. When the match is found as the direct child
    of some dict or list, that slot is replaced **in place** with a
    ``VirtualArray`` form wrapping the node. A match at the root has no
    parent slot and is returned unwrapped.

    Returns the matching node, or ``None`` if nothing matches.
    """
    if isinstance(formjson, dict):
        if formjson.get("form_key") == form_key:
            return formjson
        children = formjson.items()
    elif isinstance(formjson, list):
        children = enumerate(formjson)
    else:
        return None

    for slot, child in children:
        out = find_datanode(child, form_key)
        if out is None:
            continue
        # Only the direct parent of the match wraps it; ancestors further up
        # just pass the node through. Identity replaces the deep dict `==`.
        if out is child:
            formjson[slot] = {
                "class": "VirtualArray",
                "form": out,
                "has_length": True,
                "has_identities": False,
                "parameters": {},
                "form_key": None,
            }
        # return right away, so mutating formjson[slot] never disturbs the loop
        return out
    return None
def special_flatten(array):
    """Flatten an awkward array into ``(leaves, AuxData)`` for JAX's pytree registry.

    A DifferentiableArray is already flat: its tracers are the leaves and its
    aux_data is reused. A plain ak.Array is split with ak.to_buffers:
      * "data" buffers become jax.numpy leaves, with their dtype kept;
      * every other buffer is stored unchanged in AuxData.indexes.
    Each data node of the form is wrapped in a VirtualArray (via
    find_datanode), presumably so the layout can be rebuilt even when some
    data buffers have no concrete value.
    """
    if isinstance(array, DifferentiableArray):
        return array.tracers, array.aux_data

    form, length, buffers = ak.to_buffers(array)
    formjson = json.loads(form.tojson())
    indexes = {}
    datakeys = []
    for key, buffer in buffers.items():
        # Keys split into exactly (partition, form_key, role). This assumes
        # the default ak.to_buffers key naming, with no "-" inside form keys.
        _partition, form_key, role = key.split("-")
        if role == "data":
            # rewrites formjson in place, wrapping this node as a VirtualArray
            nodejson = find_datanode(formjson, form_key)
            assert nodejson is not None
            datakeys.append(key)
        else:
            indexes[key] = buffer
    nextform = ak.forms.Form.fromjson(json.dumps(formjson))
    aux_data = AuxData(nextform, length, indexes, datakeys)
    children = [jax.numpy.asarray(buffers[key], buffers[key].dtype) for key in datakeys]
    return children, aux_data
def special_unflatten(aux_data, children):
    """Rebuild an array from AuxData and leaves (inverse of special_flatten).

    If any leaf is a ``jax.core.Tracer`` (JAX is tracing), the leaves are
    wrapped in a DifferentiableArray. Otherwise they are merged with the
    index buffers, and the array is rebuilt with ak.from_buffers at its
    default high-level setting.
    """
    tracing = any(isinstance(child, jax.core.Tracer) for child in children)
    if tracing:
        return DifferentiableArray(aux_data, children)
    buffers = {**aux_data.indexes, **dict(zip(aux_data.datakeys, children))}
    return ak.from_buffers(aux_data.form, aux_data.length, buffers)
# The plain and the traced array types share one flatten/unflatten pair, so
# JAX can move between them while tracing.
for _array_type in (ak.Array, DifferentiableArray):
    jax.tree_util.register_pytree_node(_array_type, special_flatten, special_unflatten)
del _array_type
###############################################################################
# TESTING
###############################################################################
def func(array):
    """Function under test: twice the first element of ``array.y``, plus 10."""
    first_y = array.y[0]
    return 2 * first_y + 10
primal = ak.Array([
    [{"x": 1.1, "y": [1.0]}, {"x": 2.2, "y": [1.0, 2.2]}],
    [],
    [{"x": 3.3, "y": [1.0, 2.0, 3.0]}]
])
tangent = ak.Array([
    [{"x": 0.0, "y": [1.0]}, {"x": 2.0, "y": [1.5, 0.0]}],
    [],
    [{"x": 1.5, "y": [2.0, 0.5, 1.0]}]
])

# forward mode: push `tangent` through `func` at `primal`
primal_result, tangent_result = jax.jvp(func, (primal,), (tangent,))
print("resulting types", type(primal_result), type(tangent_result))
print(primal_result)
print(tangent_result)

jit_result = jax.jit(func)(primal)
print("resulting type", type(jit_result))
print(jit_result)

# reverse mode: bind the VJP function to its own name so `func` keeps
# referring to the function under test
val, vjp_func = jax.vjp(func, primal)
# the primal output is reused as the cotangent; it has the output's pytree structure
print(vjp_func(val))
|