File: test_common.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (40 lines) | stat: -rw-r--r-- 1,192 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn

from torch.distributed._shard.sharded_tensor import ShardedTensor


class SimpleMegatronLM(nn.Module):
    """Small two-layer model: ``fc1`` -> GELU -> ``fc2``.

    The accessor methods hand back weights, biases and their gradients as
    ``(fc1, fc2)`` pairs so callers can compare the two layers side by side.

    Args:
        linear_size: Indexable whose entries ``[0]`` and ``[1]`` are unpacked
            as positional arguments to ``nn.Linear`` for ``fc1`` and ``fc2``
            respectively, e.g. ``((in1, out1), (in2, out2))``.
        rank: When not ``None``, both linear layers are moved to CUDA device
            ``rank``; otherwise they stay where ``nn.Linear`` created them.
        dtype: Parameter dtype forwarded to both ``nn.Linear`` layers.
    """

    def __init__(self, linear_size, rank=None, dtype=torch.float32):
        super().__init__()
        # Submodules are registered in the order fc1, gelu, fc2.
        self.fc1 = nn.Linear(*linear_size[0], dtype=dtype)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(*linear_size[1], dtype=dtype)
        if rank is None:
            return
        # Only the two linear layers are moved; GELU holds no parameters.
        for layer in (self.fc1, self.fc2):
            layer.cuda(rank)

    def forward(self, inp):
        """Apply ``fc1``, then GELU, then ``fc2`` to ``inp``."""
        hidden = self.gelu(self.fc1(inp))
        return self.fc2(hidden)

    def get_weights(self):
        """Return ``(fc1_weight, fc2_weight)``.

        A weight that is a ``ShardedTensor`` is replaced by its
        ``local_tensor()``; any other weight is returned unchanged.
        """
        return tuple(
            weight.local_tensor()
            if isinstance(weight, ShardedTensor)
            else weight
            for weight in (self.fc1.weight, self.fc2.weight)
        )

    def get_biases(self):
        """Return ``(fc1.bias, fc2.bias)`` exactly as stored on the layers."""
        first, second = self.fc1, self.fc2
        return first.bias, second.bias

    def get_weight_grads(self):
        """Return the ``.grad`` of ``fc1.weight`` and ``fc2.weight``.

        Unlike ``get_weights``, no ``ShardedTensor`` unwrapping is done here;
        the ``.grad`` attribute is read directly (typically ``None`` before
        any backward pass).
        """
        return tuple(layer.weight.grad for layer in (self.fc1, self.fc2))

    def get_bias_grads(self):
        """Return the ``.grad`` of ``fc1.bias`` and ``fc2.bias``."""
        first, second = self.fc1.bias, self.fc2.bias
        return first.grad, second.grad