File: linear_train.py

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
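"""Compare three ways of training a small stack of bias-free linear layers:
an eager functorch step, the same step compiled with NNC via nnc_jit, and a
TorchScript module trained with torch.optim.SGD."""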

from functorch import make_functional
from functorch.compile import nnc_jit
import torch
import torch.nn as nn
import time
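
# Allow the TorchScript fuser to fuse elementwise ops on CPU.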
torch._C._jit_override_can_fuse_on_cpu(True)


def bench(f, iters=100, warmup=10):
    """Time `iters` calls of `f` after `warmup` untimed calls (standalone helper)."""
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(time.time() - begin)


class Foo(nn.Module):
    """A stack of bias-free Linear layers whose forward returns a scalar loss."""

    def __init__(self, num_layers=3, features=100):
        super().__init__()
        mods = []
        for _ in range(num_layers):
            mods.append(nn.Linear(features, features, bias=False))
        self.mod = nn.Sequential(*mods)

    def forward(self, x):
        return (self.mod(x)**2).sum()


batch_size = 16
features = 64
num_layers = 8
inp = torch.randn((batch_size, features))
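# Fixed batch used only for the warmup step; training draws a fresh batch each iteration.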

mod = Foo(num_layers, features)

# TorchScript baseline of the same module.
jit_mod = torch.jit.script(mod)

# Split the module into a stateless callable plus its parameter list.
func_model, weights = make_functional(mod)
lr = 1.0
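# The same learning rate drives both the manual update and torch.optim.SGD below.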


def functional_step(x, weights):
    # Detach from the previous iteration's graph and re-enable gradients.
    weights = [weight.detach().requires_grad_() for weight in weights]
    out = func_model(weights, x)
    out.backward()
    # Out-of-place SGD update: build new weight tensors instead of mutating.
    new_weights = [weight - lr * weight.grad for weight in weights]
    return out, new_weights


# Plain SGD (no momentum/dampening/weight decay) to match functional_step.
optim = torch.optim.SGD(jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0)


def jit_step(x, weights):
    # `weights` is ignored: the optimizer updates jit_mod's parameters in place.
    optim.zero_grad()
    loss = jit_mod(x)
    loss.backward()
    optim.step()
    return loss, None


def train(train_step, weights):
    torch.manual_seed(16)
    # Untimed warmup step; this also triggers compilation for the NNC variant.
    train_step(inp, weights)
    begin = time.time()
    for itr in range(1000):
        loss, weights = train_step(torch.randn(batch_size, features), weights)
        if itr % 200 == 0:
            print(f"Loss at {itr}: {loss}")
    print("Time taken: ", time.time() - begin)
    print()


grad_pt = functional_step
grad_nnc = nnc_jit(functional_step)  # compile the functional training step with NNC

print("Starting PT training")
train(grad_pt, weights)

print("Starting NNC training")
train(grad_nnc, weights)

print("Starting JIT training")
train(jit_step, None)