1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
import ssl
ssl._create_default_https_context = (
ssl._create_unverified_context
) # Fixed problem with downloading CIFAR10 dataset
import torch
import einx
import os
import torchvision
import time
import jax
import optax
import torchvision.transforms as transforms
import einx.nn.equinox as einn
import equinox as eqx
from functools import partial
import jax.numpy as jnp
from typing import List
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
batch_size = 256
rng = jax.random.PRNGKey(42)
def next_rng():
global rng
rng, x = jax.random.split(rng)
return x
cifar10_path = os.path.join(os.path.dirname(__file__), "cifar10")
trainset = torchvision.datasets.CIFAR10(
root=cifar10_path, train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=batch_size, shuffle=True, num_workers=2
)
testset = torchvision.datasets.CIFAR10(
root=cifar10_path, train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=batch_size, shuffle=False, num_workers=2
)
class Block(eqx.Module):
linear: einn.Linear
norm: einn.Norm
dropout: einn.Dropout
def __init__(self, c):
self.linear = einn.Linear("b [...->c]", c=c)
self.norm = einn.Norm("b [c]")
self.dropout = einn.Dropout("[...]", drop_rate=0.2)
def __call__(self, x, rng):
x = self.linear(x, rng=rng)
x = self.norm(x, rng=rng)
x = jax.nn.gelu(x)
x = self.dropout(x, rng=rng)
return x
class Net(eqx.Module):
blocks: List[Block]
classifier: einn.Linear
def __init__(self):
self.blocks = [Block(c) for c in [1024, 512, 256]]
self.classifier = einn.Linear("b [...->c]", c=10)
def __call__(self, x, rng):
for block in self.blocks:
x = block(x, rng=rng)
return self.classifier(x, rng=rng)
train_net = Net()
inputs, _ = next(iter(trainloader))
train_net(jnp.asarray(inputs), rng=next_rng()) # Run on dummy batch
optimizer = optax.adam(3e-4)
opt_state = optimizer.init(eqx.filter(train_net, eqx.is_array))
@partial(eqx.filter_jit, donate="all")
def update_step(opt_state, net, images, labels, rng):
def loss_fn(net):
logits = net(images, rng=rng)
one_hot = jax.nn.one_hot(labels, 10)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss
_loss, grads = eqx.filter_value_and_grad(loss_fn)(net)
updates, new_opt_state = optimizer.update(grads, opt_state, net)
new_net = eqx.apply_updates(net, updates)
return new_opt_state, new_net
@partial(eqx.filter_jit, donate="all")
def test_step(net, images, labels):
logits = net(images, rng=rng)
accurate = jnp.argmax(logits, axis=1) == jnp.asarray(labels)
return accurate
print("Starting training")
for epoch in range(100):
t0 = time.time()
# Train
for _i, data in enumerate(trainloader):
inputs, labels = data
opt_state, train_net = update_step(
opt_state, train_net, jnp.asarray(inputs), jnp.asarray(labels), next_rng()
)
# Test
correct = 0
total = 0
infer_net = eqx.nn.inference_mode(train_net)
for data in testloader:
images, labels = data
accurate = test_step(infer_net, jnp.asarray(images), jnp.asarray(labels))
total += accurate.shape[0]
correct += jnp.sum(accurate)
print(
f"Test accuracy after {epoch + 1:5d} epochs: {float(correct) / total} "
f"({time.time() - t0:.2f}sec)"
)
|