1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
from caffe2.python import brew, model_helper, scope
from caffe2.python.modeling.parameter_sharing import (
ParameterSharing,
parameter_sharing_context,
)
from caffe2.python.modeling.initializers import (
Initializer
)
import unittest
class ParameterSharingTest(unittest.TestCase):
    """Tests for parameter-name remapping via ``ParameterSharing``.

    The first three tests query
    ``parameter_sharing_context.get_parameter_name`` under combinations of
    ``scope.NameScope`` and ``ParameterSharing`` and check the resolved
    name. The remaining tests create parameters through
    ``ModelHelper.create_param`` / ``brew.fc`` and check what ends up
    registered on the model.
    """

    def test_parameter_sharing_default_scopes(self):
        """Without overrides, names are only prefixed by the NameScope."""
        get_name = parameter_sharing_context.get_parameter_name
        self.assertEqual(get_name('w'), 'w')
        with scope.NameScope('scope'):
            self.assertEqual(get_name('w'), 'scope/w')
            with scope.NameScope('scope_2'):
                self.assertEqual(get_name('w'), 'scope/scope_2/w')

    def test_parameter_sharing_nested_scopes(self):
        """Overrides apply relative to the scope they are declared in.

        Inside 'global_scope', {'model_b': 'model_a'} redirects
        'global_scope/model_b' to 'global_scope/model_a', and a nested
        {'shared_scope': ''} folds 'shared_scope' into its parent scope.
        'model_c' has no redirect of its own and keeps its name.
        """
        get_name = parameter_sharing_context.get_parameter_name
        with scope.NameScope('global_scope'):
            with ParameterSharing({'model_b': 'model_a'}):
                # The override does not rename the enclosing scope itself.
                self.assertEqual(get_name('w'), 'global_scope/w')
                # This scope is overridden to match 'model_a'.
                with scope.NameScope('model_b'):
                    with ParameterSharing({'shared_scope': ''}):
                        self.assertEqual(get_name('w'),
                                         'global_scope/model_a/w')
                        # 'shared_scope' collapses into 'model_b', which
                        # in turn resolves to 'model_a'.
                        with scope.NameScope('shared_scope'):
                            self.assertEqual(get_name('w'),
                                             'global_scope/model_a/w')
                # This scope is not redirected, so only its own
                # 'shared_scope' override applies.
                with scope.NameScope('model_c'):
                    with ParameterSharing({'shared_scope': ''}):
                        self.assertEqual(get_name('w'),
                                         'global_scope/model_c/w')
                        with scope.NameScope('shared_scope'):
                            self.assertEqual(get_name('w'),
                                             'global_scope/model_c/w')

    def test_parameter_sharing_subscopes(self):
        """An override may name a nested scope by its full path.

        Only 'global_scope/b' is redirected; the sibling scopes 'a' and 'c'
        and the parent 'global_scope' keep their own names.
        """
        get_name = parameter_sharing_context.get_parameter_name
        with ParameterSharing({'global_scope/b': 'global_scope/a'}):
            with scope.NameScope('global_scope'):
                self.assertEqual(get_name('w'), 'global_scope/w')
                with scope.NameScope('a'):
                    self.assertEqual(get_name('w'), 'global_scope/a/w')
                with scope.NameScope('b'):
                    self.assertEqual(get_name('w'), 'global_scope/a/w')
                with scope.NameScope('c'):
                    self.assertEqual(get_name('w'), 'global_scope/c/w')

    def test_create_param(self):
        """Same-named params in different NameScopes get distinct info."""
        model = model_helper.ModelHelper(name="test")
        # No sharing is configured, so only the NameScope differs.
        p1 = model.create_param(
            'w',
            shape=[2],
            initializer=Initializer("ConstantFill"),
        )
        with scope.NameScope('some_global_scope'):
            p2 = model.create_param(
                'w',
                shape=[2],
                initializer=Initializer("ConstantFill"),
            )
        self.assertIsNotNone(model.get_param_info(p1))
        self.assertIsNotNone(model.get_param_info(p2))
        self.assertNotEqual(model.get_param_info(p1),
                            model.get_param_info(p2))
        model.Validate()

    def test_deep_hierarchy(self):
        """create_param succeeds under three levels of nested overrides."""
        model = model_helper.ModelHelper(name="test")
        with ParameterSharing({'a': 'b'}):
            with scope.NameScope('a'):
                with ParameterSharing({'c': 'd'}):
                    with scope.NameScope('c'):
                        with ParameterSharing({'e': 'f'}):
                            with scope.NameScope('e'):
                                p = model.create_param(
                                    'w',
                                    shape=[2],
                                    initializer=Initializer("ConstantFill"),
                                )
        self.assertIsNotNone(model.get_param_info(p))

    def test_parameter_sharing_brew(self):
        """brew.fc reuses (and shape-checks) params under shared scopes."""
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=16, dim_out=16)
        # Shared params are expected to keep the same shape; creating 'fc1'
        # again with a different shape must fail.
        with self.assertRaises(AssertionError):
            brew.fc(model, data, "fc1", dim_in=2, dim_out=2)

        output_blobs = set()
        with scope.NameScope('some_global_scope'):
            with scope.NameScope('model_a'):
                output_blobs.add(str(brew.fc(model, fc1, 'output', 16, 16)))
            with ParameterSharing({'model_b': 'model_a'}), \
                    scope.NameScope('model_b'):
                with ParameterSharing({'shared_1': '', 'shared_2': ''}):
                    # All params of the FC layers in shared_1, shared_2 and
                    # model_a are shared and point to:
                    #   [some_global_scope/model_a/output_w,
                    #    some_global_scope/model_a/output_b]
                    with scope.NameScope('shared_1'):
                        output_blobs.add(
                            str(brew.fc(model, fc1, 'output', 16, 16)))
                    with scope.NameScope('shared_2'):
                        output_blobs.add(
                            str(brew.fc(model, fc1, 'output', 16, 16)))
                    # Params of this layer are not shared with anyone unless
                    # there is an explicit override for model_a/unshared
                    # (not in this example). 'model_b' still resolves to
                    # 'model_a', so the param blobs are
                    #   [some_global_scope/model_a/unshared/output_w,
                    #    some_global_scope/model_a/unshared/output_b]
                    with scope.NameScope('unshared'):
                        output_blobs.add(
                            str(brew.fc(model, fc1, 'output', 16, 16)))

        # fc1 plus two FC param pairs are registered. The four FC calls
        # still produce four distinct output blobs.
        self.assertEqual(len(model._parameters_info), 6)
        self.assertEqual(len(output_blobs), 4)
        self.assertEqual(sorted(model._parameters_info), [
            'fc1_b',
            'fc1_w',
            'some_global_scope/model_a/output_b',
            'some_global_scope/model_a/output_w',
            'some_global_scope/model_a/unshared/output_b',
            'some_global_scope/model_a/unshared/output_w',
        ])
        model.Validate()
|