File: rnn.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (58 lines) | stat: -rw-r--r-- 1,952 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch.cuda

try:
    from torch._C import _cudnn
except ImportError:
    # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
    # so it's safe to not emit any checks here.
    _cudnn = None  # type: ignore[assignment]


def get_cudnn_mode(mode):
    """Translate an RNN mode name into the integer value of its cuDNN enum member.

    Args:
        mode: One of ``'RNN_RELU'``, ``'RNN_TANH'``, ``'LSTM'`` or ``'GRU'``.

    Returns:
        ``int(...)`` applied to the matching member of ``_cudnn.RNNMode``.

    Raises:
        Exception: If ``mode`` does not compare equal to any of the names above.
    """
    # (accepted mode name, attribute name on _cudnn.RNNMode), tested in order.
    # Compared with ``==`` rather than looked up in a dict so that any value,
    # hashable or not, falls through to the "Unknown mode" error below.
    known_modes = (
        ('RNN_RELU', 'rnn_relu'),
        ('RNN_TANH', 'rnn_tanh'),
        ('LSTM', 'lstm'),
        ('GRU', 'gru'),
    )
    for mode_name, enum_attr in known_modes:
        if mode == mode_name:
            # _cudnn is only dereferenced once a name has matched.
            return int(getattr(_cudnn.RNNMode, enum_attr))
    raise Exception("Unknown mode: {}".format(mode))


# NB: We don't actually need this class anymore (in fact, we could serialize the
# dropout state for even better reproducibility), but it is kept for backwards
# compatibility for old models.
class Unserializable:
    """Holder whose wrapped value is discarded when the holder is pickled.

    Serializing an instance stores only a fixed placeholder string; restoring
    it yields a holder whose wrapped value is ``None``.
    """

    def __init__(self, inner):
        """Wrap ``inner`` so it can be fetched later through :meth:`get`."""
        self.inner = inner

    def __getstate__(self):
        """Return a placeholder in place of the real state."""
        # Deliberately a non-empty string instead of {}: python2 does not call
        # __setstate__ when the returned state evaluates to False.
        return "<unserializable>"

    def __setstate__(self, state):
        """Ignore ``state`` and reset the wrapped value to ``None``."""
        self.inner = None

    def get(self):
        """Return the wrapped value (``None`` after a restore)."""
        return self.inner


def init_dropout_state(dropout, train, dropout_seed, dropout_state):
    """Return the dropout state cached for the current CUDA device, creating it if needed.

    The cache entry is keyed by ``'desc_<current device index>'`` and stored
    wrapped in ``Unserializable``.

    Args:
        dropout: Dropout probability; only used when ``train`` is truthy.
        train: When falsy, the effective dropout probability is 0.
        dropout_seed: Seed forwarded to ``torch._cudnn_init_dropout_state``.
        dropout_state: Mapping used as the per-device cache; mutated in place.

    Returns:
        The value held by the cache entry for the current device: whatever
        ``torch._cudnn_init_dropout_state`` produced, or ``None`` when the
        entry was created with an effective dropout probability of 0.
    """
    cache_key = 'desc_{}'.format(torch.cuda.current_device())
    effective_p = dropout if train else 0

    # An entry counts as usable only if it exists and still holds a value
    # (an entry restored from a pickle holds None).
    has_usable_entry = (
        cache_key in dropout_state
        and dropout_state[cache_key].get() is not None
    )
    if not has_usable_entry:
        if effective_p == 0:
            # Nothing to initialise; the entry just holds None.
            fresh_state = None
        else:
            fresh_state = torch._cudnn_init_dropout_state(  # type: ignore[call-arg]
                effective_p,
                train,
                dropout_seed,
                self_ty=torch.uint8,
                device=torch.device('cuda'))
        dropout_state[cache_key] = Unserializable(fresh_state)

    return dropout_state[cache_key].get()