File: initializers.py

from caffe2.python.core import DataType, BlobReference, ScopedBlobReference
from caffe2.python.modeling.parameter_info import ParameterInfo


class Initializer(object):
    '''
    This class abstracts out parameter creation. Subclass it to implement
    more complex parameter initialization logic.
    '''

    def __init__(self, operator_name=None, **kwargs):
        self.operator_name = operator_name
        self.operator_kwargs = kwargs

    def update(self, operator_name, kwargs):
        if self.operator_name is not None:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name = operator_name
        self.operator_kwargs = kwargs

    def create_param(self, param_name, init_net, shape):
        param = init_net.__getattr__(self.operator_name)(
            [], param_name, shape=shape, **self.operator_kwargs)
        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
        )
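
# Example usage (a minimal sketch; the net, blob name, operator, and shape
# below are illustrative, not part of this module):
#
#   from caffe2.python import core
#   init_net = core.Net("init")
#   initializer = Initializer("XavierFill")
#   w_info = initializer.create_param("fc_w", init_net, shape=[10, 4])
#   # init_net now holds an XavierFill op that fills "fc_w" when it runs.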


class ExternalInitializer(object):
    '''
    This class is used in cases where the parameter should not be initialized
    by the initializer, but rather provided in the workspace when
    param_init_net is executed.

    The current version does not do any real sanity checks on the parameter.
    '''

    def create_param(self, param_name, init_net, shape):
        if isinstance(param_name, BlobReference):
            param = BlobReference(str(param_name), init_net)
        elif isinstance(param_name, str):
            param = ScopedBlobReference(param_name, init_net)
        else:
            raise TypeError("Unsupported type for param_name")
        # TODO(amalevich): Add operator that will check param in the workspace
        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
        )
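
# Example usage (a minimal sketch; the blob name, values, and net are
# illustrative):
#
#   import numpy as np
#   from caffe2.python import core, workspace
#   init_net = core.Net("init")
#   workspace.FeedBlob("pretrained_w", np.ones((10, 4), dtype=np.float32))
#   w_info = ExternalInitializer().create_param(
#       "pretrained_w", init_net, [10, 4])
#   # No fill op is added to init_net; the blob must already exist in the
#   # workspace when param_init_net is executed.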


class PseudoFP16Initializer(Initializer):
    '''
    Used in cases when the parameter should be used at half (16-bit) precision
    for compute purposes (i.e. on the forward and backward pass) but
    needs to be stored and optimized at single (32-bit) precision so tiny
    gradients with small learning rates don't underflow FP16 precision.
    A 32-bit copy of the 16-bit blob is stored in the ParameterInfo.
    This is helpful for mixed-precision training, see
    https://arxiv.org/abs/1710.03740 for details.
    '''
    update = Initializer.update

    def create_param(self, param_name, init_net, shape):
        # create master fp32 copy
        param_fp32 = init_net.__getattr__(self.operator_name)(
            [], param_name + "_fp32", shape=shape,
            **self.operator_kwargs)
        # cast to fp16 copy
        param = init_net.FloatToHalf(
            param_fp32, param_name)

        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
            blob_copy={DataType.FLOAT: param_fp32}
        )
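
# Example usage (a minimal sketch reusing init_net from the sketch above;
# the operator, kwargs, and names are illustrative):
#
#   initializer = PseudoFP16Initializer("GaussianFill", mean=0.0, std=0.01)
#   w_info = initializer.create_param("fc_w", init_net, shape=[64, 64])
#   # w_info.param is the fp16 blob "fc_w"; the fp32 master copy lives at
#   # w_info.blob_copy[DataType.FLOAT] (blob "fc_w_fp32").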


class ReversePseudoFP16Initializer(Initializer):
    '''
    Like PseudoFP16Initializer above, except the primary blob is taken to
    be the 32-bit precision parameter, and the 16-bit version of the blob
    is stored in blob_copy instead.
    '''
    update = Initializer.update

    def create_param(self, param_name, init_net, shape):
        # create the fp32 blob, which is the primary parameter here
        param_fp32 = init_net.__getattr__(self.operator_name)(
            [], param_name, shape=shape,
            **self.operator_kwargs)
        # cast out a secondary fp16 copy
        param_fp16 = init_net.FloatToHalf(
            param_fp32, param_name + "_fp16")

        return ParameterInfo(
            param_id=None,
            param=param_fp32,
            shape=shape,
            blob_copy={DataType.FLOAT16: param_fp16}
        )
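
# Example usage (a minimal sketch, mirroring the example above):
#
#   initializer = ReversePseudoFP16Initializer("GaussianFill", std=0.01)
#   w_info = initializer.create_param("fc_w", init_net, shape=[64, 64])
#   # Here w_info.param is the fp32 blob "fc_w"; the fp16 copy lives at
#   # w_info.blob_copy[DataType.FLOAT16] (blob "fc_w_fp16").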


def update_initializer(initializer_class,
                       operator_name_and_kwargs,
                       default_operator_name_and_kwargs):
    '''
    A helper function to convert operator_name_and_kwargs into a new object
    of type initializer_class. This function serves two purposes:

    1. Support for custom initialization operators being passed in
    2. Allow the user to specify a custom Initializer without overwriting
       the default operators used for initialization

    If initializer_class is None, creates a default initializer using the
    Initializer class and the operator_name_and_kwargs provided.

    If operator_name_and_kwargs is None, uses default_operator_name_and_kwargs.

    Returns an instantiated Initializer object.
    '''
    def get_initializer_args():
        return (
            operator_name_and_kwargs or
            default_operator_name_and_kwargs
        )

    if initializer_class is not None:
        init = initializer_class(get_initializer_args()[0],
                                 **get_initializer_args()[1])
    else:
        init = Initializer(
            get_initializer_args()[0],
            **get_initializer_args()[1]
        )
    return init
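
# Example usage (a minimal sketch; the operator names and kwargs are
# illustrative):
#
#   # A caller-provided operator wins over the default:
#   init = update_initializer(
#       None,
#       ("UniformFill", {"min": -0.1, "max": 0.1}),
#       ("ConstantFill", {}),
#   )
#   # -> equivalent to Initializer("UniformFill", min=-0.1, max=0.1)
#
#   # Fall back to the default operator when none is provided:
#   init = update_initializer(PseudoFP16Initializer, None, ("XavierFill", {}))
#   # -> equivalent to PseudoFP16Initializer("XavierFill")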