File: parameter_sharing.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (118 lines) | stat: -rw-r--r-- 3,848 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118





from caffe2.python import scope

import contextlib
import logging

logger = logging.getLogger(__name__)


class ParameterSharingContext(object):
    """
    This class manages scope driven way of parameter sharing across different
    NameScopes.

    Overrides are kept as a stack of dicts (``_contexts``), one per
    ``add_scope_overrides`` call, plus a merged view (``_scope_overrides``)
    in which later (inner) dicts take precedence over earlier (outer) ones.
    """

    def __init__(self):
        # Merged view of all active overrides: scope prefix -> target scope.
        self._scope_overrides = {}
        # Stack of override dicts, in the order they were added.
        self._contexts = []

    def _resolve_scope_overrides(self, candidate_scope):
        """
        Resolves all scope overrides, i.e. multiple steps of override can be
        used.

        For example, if one provides following scope overrides:
        {'scope_b': 'scope_a'} and within 'scope_b' - {'shared_child': ''},
        then name 'w' will get resolved to the following blobs depending on the
        namescope:
          a. 'scope_a' -> 'scope_a/w'
          b. 'scope_b' -> 'scope_a/w'
          c. 'scope_c' -> 'scope_c/w'
          d. 'scope_b/shared_child' -> 'scope_a/w'
          e. 'scope_b/unshared_child' -> 'scope_a/unshared_child/w'

        At each step the deepest prefix of the scope that has an override is
        replaced by its target, and the target is resolved again, until no
        override applies.

        Args:
            candidate_scope: the name scope to resolve.

        Returns:
            The fully resolved name scope.

        Raises:
            ValueError: if the active overrides form a cycle (e.g. 'a' -> 'b'
                in an outer context and 'b' -> 'a' in an inner one), which
                would otherwise never terminate.
        """
        separator = scope._NAMESCOPE_SEPARATOR
        current = candidate_scope
        # Path components peeled off by earlier override steps; they are
        # re-appended once a scope with no applicable override is reached.
        suffix = ''
        visited = set()
        while True:
            # Resolution is deterministic, so seeing a scope twice means the
            # overrides loop and resolution could never finish.
            if current in visited:
                raise ValueError(
                    "Cyclic parameter sharing override detected while "
                    "resolving scope {}".format(candidate_scope))
            visited.add(current)

            sub_scopes = current.split(separator)
            best_scope = current
            best_scope_idx = 0
            prefix = ''
            # Walk prefixes from shortest to longest; the last (deepest)
            # prefix with an override wins.
            for idx, sub_scope in enumerate(sub_scopes):
                prefix = prefix + sub_scope + separator
                if prefix in self._scope_overrides:
                    best_scope = self._scope_overrides[prefix]
                    best_scope_idx = idx
            if best_scope == current:
                return current + suffix
            suffix = separator.join(sub_scopes[best_scope_idx + 1:]) + suffix
            current = best_scope

    def get_parameter_name(self, name):
        """
        Return ``name`` prefixed with the current name scope after all
        active scope overrides have been applied.
        """
        candidate_scope = scope.CurrentNameScope()
        best_scope = self._resolve_scope_overrides(candidate_scope)
        if best_scope != candidate_scope:
            logger.info("Overwriting scope %s with scope %s",
                        candidate_scope, best_scope)

        return best_scope + name

    def add_scope_overrides(self, shared_scopes):
        """
        Push a dict of scope overrides; its entries take precedence over
        any previously added ones until the matching ``pop``.
        """
        self._contexts.append(shared_scopes)
        self._scope_overrides.update(shared_scopes)

    def pop(self):
        """
        Remove the most recently added overrides and rebuild the merged
        view from the remaining stack.
        """
        assert len(self._contexts) > 0
        self._contexts.pop()
        self._scope_overrides = {}
        for x in self._contexts:
            self._scope_overrides.update(x)


# Module-level ParameterSharingContext onto which ParameterSharing() pushes
# its overrides (and from which it pops them on exit). Presumably also read by
# parameter-creation code elsewhere via get_parameter_name() — verify callers.
parameter_sharing_context = ParameterSharingContext()


def _normalize_namescope(namescope):
    if namescope and namescope[-1] != scope._NAMESCOPE_SEPARATOR:
        return namescope + scope._NAMESCOPE_SEPARATOR
    else:
        return namescope


@contextlib.contextmanager
def ParameterSharing(shared_scopes):
    """
    Helper function for sharing scopes.
    All the parameters within the shared_scopes will be remapped with
    respect to CurrentNameScope().

    I.e. if one calls ParameterSharing with {'scope_b': 'scope_a'} from the
    scope 'some_global_scope', it'll effectively mean that all parameters from
    'some_global_scope/scope_b' will be shared with the parameters from
    'some_global_scope/scope_a'.

    Args:
        shared_scopes: dict mapping a scope name (relative to the current
            name scope) to the scope whose parameters it should reuse.

    Raises:
        AssertionError: (when assertions are enabled) if shared_scopes is not
            a dict, or if a target scope equals or is nested inside the scope
            it overrides.
    """
    assert isinstance(shared_scopes, dict)

    shared_scope_overrides = {}
    current_scope = scope.CurrentNameScope()
    for k, v in shared_scopes.items():
        # Normalize all the scopes, so scope_a and scope_a/ are equivalent
        full_k = _normalize_namescope(current_scope + k)
        full_v = _normalize_namescope(current_scope + v)
        # Compare normalized scopes: a raw string check would wrongly reject
        # e.g. {'scope': 'scope_shared'}, since 'scope_shared' starts with
        # 'scope' although 'scope_shared/' is not nested in 'scope/'.
        assert not full_v.startswith(full_k), (
            "Illegal override for parameter sharing. {} is prefix of {}".
            format(k, v))
        shared_scope_overrides[full_k] = full_v

    # Push outside the try: if the push itself failed, the finally clause
    # would otherwise pop a context that was never added.
    parameter_sharing_context.add_scope_overrides(shared_scope_overrides)
    try:
        yield
    finally:
        parameter_sharing_context.pop()