File: train_torch.py

Package: python-einx 0.3.0-1
import ssl

# Workaround for SSL certificate verification errors when downloading the
# CIFAR10 dataset: use an unverified HTTPS context for urllib downloads.
ssl._create_default_https_context = ssl._create_unverified_context

import torch
import einx
import os
import torchvision
import time
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import einx.nn.torch as einn

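# Convert CIFAR10 images to tensors and normalize each channel from [0, 1] to [-1, 1].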
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 256

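# Download CIFAR10 next to this script and build the train/test data loaders.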
cifar10_path = os.path.join(os.path.dirname(__file__), "cifar10")
trainset = torchvision.datasets.CIFAR10(
    root=cifar10_path, train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root=cifar10_path, train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)


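# A simple MLP over flattened CIFAR10 images, built from einx layers.
# The bracketed part of each einx expression marks the axes a layer acts on:
#   - Linear "b [...->c]" maps all non-batch axes of the input to c features,
#   - Norm "[b] c" normalizes over the batch axis (batch norm, with running
#     statistics decayed at 0.99),
#   - Dropout "[...]" drops elements independently across all axes.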
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        blocks = []
        for c in [1024, 512, 256]:
            blocks.append(einn.Linear("b [...->c]", c=c))
            blocks.append(einn.Norm("[b] c", decay_rate=0.99))
            blocks.append(nn.GELU())
            blocks.append(einn.Dropout("[...]", drop_rate=0.2))
        blocks.append(einn.Linear("b [...->c]", c=10))
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)


net = Net()

# Call on dummy batch to initialize parameters (before torch.compile!)
inputs, _ = next(iter(trainloader))
net(inputs)

net = net.cuda()
net = torch.compile(net)

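# Standard classification setup: Adam optimizer (lr 3e-4) and cross-entropy loss.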
optimizer = optim.Adam(net.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()


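# Compiled evaluation step: returns a boolean mask of which predictions match the labels.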
@torch.compile
def test_step(inputs, labels):
    outputs = net(inputs)
    _, predicted = torch.max(outputs, dim=1)
    return predicted == labels


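# Train for 100 epochs; after each epoch, evaluate test accuracy and report
# the per-epoch wall-clock time (training + evaluation).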
print("Starting training")
for epoch in range(100):
    t0 = time.time()

    # Train
    net.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.cuda(), labels.cuda()

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Test
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            inputs, labels = inputs.cuda(), labels.cuda()

            accurate = test_step(inputs, labels)
            total += accurate.size(0)
            correct += int(torch.count_nonzero(accurate))

    print(
        f"Test accuracy after {epoch + 1:5d} epochs: {correct / total:.4f} "
        f"({time.time() - t0:.2f} sec)"
    )