---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.14.1
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---
Differentiation using JAX
=========================
JAX, among other things, is a powerful tool for computing derivatives of native Python and NumPy code. Awkward Array supports the {func}`jax.jvp` and {func}`jax.vjp` functions, which compute forward-mode Jacobian-vector products (JVPs) and reverse-mode vector-Jacobian products (VJPs), respectively, of functions that operate on Awkward Arrays. Only a subset of Awkward Array operations can be differentiated through, including:
- ufunc operations like `x + y`
- reducers like {func}`ak.sum`
- slices like `x[1:]`
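
To make the two products named above concrete, here is a minimal pure-Python sketch that does not use JAX or Awkward Array at all. The function `f` and the helpers `jacobian`, `jvp`, and `vjp` are hypothetical names chosen for illustration. The sketch builds the Jacobian explicitly, which automatic differentiation avoids doing; it only shows what each product means.

```python
# Illustrative function: f(x0, x1) = (x0 * x1, x0 + x1).
# Its Jacobian is J = [[x1, x0], [1, 1]] (rows: outputs, columns: inputs).
def jacobian(x0, x1):
    return [[x1, x0], [1.0, 1.0]]


def jvp(J, v):
    # Jacobian-vector product J @ v: pushes an input perturbation forward.
    return [sum(J[i][j] * v[j] for j in range(len(v))) for i in range(len(J))]


def vjp(u, J):
    # Vector-Jacobian product u @ J: pulls an output sensitivity backward.
    return [sum(u[i] * J[i][j] for i in range(len(u))) for j in range(len(J[0]))]


J = jacobian(2.0, 3.0)

# Forward mode: how both outputs change when only x0 is perturbed.
forward = jvp(J, [1.0, 0.0])  # [3.0, 1.0]

# Reverse mode: how the first output depends on each input.
reverse = vjp([1.0, 0.0], J)  # [3.0, 2.0]
```

A JVP takes a tangent shaped like the input and returns one shaped like the output, whereas a VJP takes a cotangent shaped like the output and returns one shaped like the input.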
+++
How to differentiate Awkward Arrays?
------------------------------------
For this notebook (which is evaluated on a CPU), we need to configure JAX to use only the CPU.
```{code-cell}
import jax
jax.config.update("jax_platform_name", "cpu")
```
Next, we must call {func}`ak.jax.register_and_check()` to register Awkward's JAX integration.
```{code-cell}
import awkward as ak
ak.jax.register_and_check()
```
Let's define a simple function that accepts an Awkward Array.
```{code-cell}
def reverse_sum(array):
    # Reverse the outermost dimension, then sum along it
    return ak.sum(array[::-1], axis=0)
```
We can then create an array with which to evaluate `reverse_sum`. The `backend="jax"` argument ensures that we build an Awkward Array that is backed by {class}`jaxlib.xla_extension.DeviceArray` buffers, which power JAX's automatic differentiation and JIT-compilation features.
```{code-cell}
array = ak.Array([[1.0, 2.0, 3.0], [], [4.0, 5.0]], backend="jax")
```
```{code-cell}
reverse_sum(array)
```
Computing the JVP of `reverse_sum` requires a _tangent_ vector, which can also be defined as an Awkward Array:
```{code-cell}
tangent = ak.Array([[0.0, 0.0, 0.0], [], [0.0, 1.0]], backend="jax")
```
```{code-cell}
value_jvp, jvp_grad = jax.jvp(reverse_sum, (array,), (tangent,))
```
{func}`jax.jvp` returns both the value of `reverse_sum` evaluated at `array`:
```{code-cell}
value_jvp
```
```{code-cell}
assert value_jvp.to_list() == reverse_sum(array).to_list()
```
and the JVP evaluated at `array` for the given `tangent`:
```{code-cell}
jvp_grad
```
JAX's own documentation encourages the user to use {mod}`jax.numpy` instead of the canonical {mod}`numpy` module when operating upon JAX arrays. However, {mod}`jax.numpy` does not understand Awkward Arrays, so for {class}`ak.Array`s you should use the normal {mod}`ak` and {mod}`numpy` functions instead.