File: layer_normalization.py

from caffe2.python import schema
from caffe2.python.layers.layers import ModelLayer

import numpy as np


class LayerNormalization(ModelLayer):
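    """Applies layer normalization to the input record.

    The output is ``(x - mean) / sqrt(var + epsilon) * scale + bias``,
    where the statistics are computed over the normalized axes and
    ``scale`` and ``bias`` are learned parameters. With
    ``use_layer_norm_op=True`` the fused LayerNorm operator is emitted;
    otherwise the same computation is assembled from primitive ops.
    """
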
    def __init__(
        self,
        model,
        input_record,
        name='layer_normalization',
        scale_optim=None,
        bias_optim=None,
        epsilon=1e-4,
        axis=1,
        use_layer_norm_op=True,
        scale_init_value=1.0,
        **kwargs
    ):
        super(LayerNormalization, self).__init__(
            model, name, input_record, **kwargs)

        assert isinstance(input_record, schema.Scalar), (
            "Incorrect input type: {}".format(input_record))

        self.input_shape = input_record.field_type().shape
        self.axis = axis

        # Note: the schema shape excludes the batch dimension, so a shape
        # of length >= 1 corresponds to a tensor of rank >= 2.
        assert len(self.input_shape) >= 1, (
            "This layer supports only >= 2D tensors")
        input_dims = self.input_shape[0]

        self.output_schema = schema.Scalar(
            (np.float32, self.input_shape),
            self.get_next_blob_reference('output')
        )

        self.scale = self.create_param(
            param_name='scale',
            shape=[input_dims],
            initializer=('ConstantFill', {'value': scale_init_value}),
            optimizer=scale_optim,
        )
        self.bias = self.create_param(
            param_name='bias',
            shape=[input_dims],
            initializer=('ConstantFill', {'value': 0.0}),
            optimizer=bias_optim,
        )
        self.use_layer_norm_op = use_layer_norm_op

        if self.use_layer_norm_op:
            self.epsilon = epsilon
        else:
            # The schema shape excludes the batch dimension, so a
            # length-1 shape corresponds to a 2D tensor.
            assert len(self.input_shape) == 1, (
                "When using the alternative implementation, "
                "the input data can only be 2D"
            )
            self.epsilon = model.maybe_add_global_constant(
                "%s_epsilon" % self.name, float(epsilon)
            )

    def add_ops_with_layer_norm_op(self, net):
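        """Emit the fused LayerNorm op, then Mul/Add for scale and bias."""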
        input_blob = self.input_record.field_blobs()
        ln_output = self.output_schema.field_blobs()

        output_blobs = [
            net.NextScopedBlob('ln_output'),
            net.NextScopedBlob('ln_mean'),
            net.NextScopedBlob('ln_stdev'),
        ]

        normalized, mean, stdev = net.LayerNorm(
            input_blob,
            output_blobs,
            axis=self.axis,
            epsilon=self.epsilon,
        )

        scaled = net.Mul(
            [normalized, self.scale],
            [net.NextScopedBlob('ln_scaled')],
            broadcast=1,
            axis=self.axis,
        )

        net.Add(
            [scaled, self.bias],
            ln_output,
            broadcast=1,
            axis=self.axis,
        )

    def add_ops_without_layer_norm_op(self, net):
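        """Assemble the LayerNorm computation from primitive ops."""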
        # Two differences from the fused-op path:
        #  1. multiple primitive ops replace the fused LayerNorm op
        #  2. legacy broadcasting is not used
        ln_output = net.NextScopedBlob("ln_output")
        ln_mean = net.NextScopedBlob("ln_mean")
        ln_stdev = net.NextScopedBlob("ln_stdev")
        ln_mean_arr = net.NextScopedBlob("ln_mean_arr")
        net.ReduceBackMean(self.input_record.field_blobs(), [ln_mean_arr])
        net.ExpandDims([ln_mean_arr], [ln_mean], dims=[1])
        ln_centered = net.NextScopedBlob("ln_centered")
        net.Sub(self.input_record.field_blobs() + [ln_mean], [ln_centered])
        ln_sqr = net.NextScopedBlob("ln_sqr")
        net.Sqr([ln_centered], [ln_sqr])
        ln_sqr_mean = net.NextScopedBlob("ln_sqr_mean")
        net.ReduceBackMean([ln_sqr], [ln_sqr_mean])
        ln_var = net.NextScopedBlob("ln_var")
        net.Add([ln_sqr_mean, self.epsilon], ln_var)
        ln_std_arr = net.NextScopedBlob("ln_std_arr")
        net.Pow([ln_var], [ln_std_arr], exponent=0.5)
        net.ExpandDims([ln_std_arr], [ln_stdev], dims=[1])
        net.Div([ln_centered, ln_stdev], [ln_output])
        ln_scaled = net.NextScopedBlob("ln_scaled")
        net.Mul([ln_output, self.scale], [ln_scaled])
        net.Add([ln_scaled, self.bias], self.output_schema.field_blobs())

    def add_ops(self, net):
        if self.use_layer_norm_op:
            self.add_ops_with_layer_norm_op(net)
        else:
            self.add_ops_without_layer_norm_op(net)
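

# A minimal NumPy sketch, not part of the original layer: a reference
# for what add_ops_without_layer_norm_op computes on a 2D input. The
# function name is illustrative, not an existing caffe2 helper.
def _layer_norm_reference(x, scale, bias, epsilon=1e-4):
    # Statistics over the last axis, kept as columns for broadcasting.
    mean = x.mean(axis=-1, keepdims=True)
    centered = x - mean
    var = (centered ** 2).mean(axis=-1, keepdims=True)
    # Normalize, then apply the learned affine transform.
    return centered / np.sqrt(var + epsilon) * scale + bias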