1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
|
import os
import ssl
import time

import einx
import einx.nn.torch as einn
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Preprocessing: ToTensor() yields float tensors in [0, 1]; Normalize with
# mean=std=0.5 per RGB channel maps them to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 256
# The dataset is cached in a "cifar10" directory next to this script.
cifar10_path = os.path.join(os.path.dirname(__file__), "cifar10")

# Workaround for certificate-verification failures when downloading CIFAR10.
# The bypass is applied ONLY while the datasets are being fetched and the
# default HTTPS context is restored afterwards, instead of disabling TLS
# verification for the whole process.
# NOTE(review): the download itself is still unverified; fixing the local CA
# bundle is preferable where possible.
_default_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
try:
    trainset = torchvision.datasets.CIFAR10(
        root=cifar10_path, train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root=cifar10_path, train=False, download=True, transform=transform
    )
finally:
    ssl._create_default_https_context = _default_https_context

# Training batches are shuffled; test batches keep dataset order.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)
class Net(nn.Module):
    """MLP classifier built from einx layers.

    Each hidden stage is Linear -> Norm -> GELU -> Dropout. A final Linear
    produces ``num_classes`` logits. Every Linear uses ``"b [...->c]"``, i.e. it
    maps all non-batch axes of its input to ``c`` output features.

    Args:
        hidden_sizes: Output width of each hidden stage, in order.
        num_classes: Number of output logits.
        drop_rate: Dropout probability used after every hidden stage.
        norm_decay_rate: ``decay_rate`` passed to each ``einn.Norm`` layer.

    The defaults reproduce the original fixed architecture, and the module order
    is unchanged, so ``state_dict`` keys ("blocks.<i>...") stay the same.
    """

    def __init__(
        self,
        hidden_sizes=(1024, 512, 256),
        num_classes=10,
        drop_rate=0.2,
        norm_decay_rate=0.99,
    ):
        super().__init__()
        blocks = []
        for c in hidden_sizes:
            blocks.append(einn.Linear("b [...->c]", c=c))
            # Statistics are taken over the bracketed batch axis, per channel c.
            blocks.append(einn.Norm("[b] c", decay_rate=norm_decay_rate))
            blocks.append(nn.GELU())
            blocks.append(einn.Dropout("[...]", drop_rate=drop_rate))
        blocks.append(einn.Linear("b [...->c]", c=num_classes))
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.blocks(x)
# Guard the training script: with the "spawn"/"forkserver" start methods
# (Windows, macOS, newer Linux Pythons) every DataLoader worker re-imports this
# module, and unguarded top-level code would try to start workers again.
if __name__ == "__main__":
    # Fall back to CPU instead of crashing when no CUDA device is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = Net()
    # Call on dummy batch to initialize parameters (before torch.compile!)
    inputs, _ = next(iter(trainloader))
    net(inputs)
    net = net.to(device)
    net = torch.compile(net)

    optimizer = optim.Adam(net.parameters(), lr=3e-4)
    criterion = nn.CrossEntropyLoss()

    @torch.compile
    def test_step(inputs, labels):
        """Return a boolean tensor marking which predictions equal ``labels``."""
        outputs = net(inputs)
        # Index of the highest logit per row (first one on ties, like torch.max).
        predicted = outputs.argmax(dim=1)
        return predicted == labels

    print("Starting training")
    for epoch in range(100):
        t0 = time.time()

        # Train
        net.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Test
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                accurate = test_step(inputs, labels)
                total += accurate.size(0)
                correct += int(torch.count_nonzero(accurate))

        # Elapsed time covers both the training and the test pass of this epoch.
        print(
            f"Test accuracy after {epoch + 1:5d} epochs {float(correct) / total} "
            f"({time.time() - t0:.2f}sec)"
        )
|