1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
## @package model_helper_api
# Module caffe2.python.model_helper_api
import sys
import copy
import inspect
from past.builtins import basestring
from caffe2.python.model_helper import ModelHelper
# flake8: noqa
from caffe2.python.helpers.algebra import *
from caffe2.python.helpers.arg_scope import *
from caffe2.python.helpers.array_helpers import *
from caffe2.python.helpers.control_ops import *
from caffe2.python.helpers.conv import *
from caffe2.python.helpers.db_input import *
from caffe2.python.helpers.dropout import *
from caffe2.python.helpers.elementwise_linear import *
from caffe2.python.helpers.fc import *
from caffe2.python.helpers.nonlinearity import *
from caffe2.python.helpers.normalization import *
from caffe2.python.helpers.pooling import *
from caffe2.python.helpers.quantization import *
from caffe2.python.helpers.tools import *
from caffe2.python.helpers.train import *
class HelperWrapper(object):
    """Stand-in object for this module that exposes the registered helpers.

    Attribute lookups that are not found on the instance (e.g. ``.fc``,
    ``.conv``) go through ``__getattr__``. It returns a wrapper around the
    registered helper function of that name. Before calling the helper, the
    wrapper merges keyword arguments from, in increasing priority:

    1. the model's ``arg_scope`` (filtered to the names the helper accepts,
       unless the helper takes ``**kwargs``),
    2. ``get_current_scope()[helper_name]``,
    3. the keyword arguments passed explicitly by the caller.
    """

    # Helper name -> helper function. This is a class attribute, so helpers
    # added through Register() are shared by every HelperWrapper instance.
    _registry = {
        'arg_scope': arg_scope,
        'fc': fc,
        'packed_fc': packed_fc,
        'fc_decomp': fc_decomp,
        'fc_sparse': fc_sparse,
        'fc_prune': fc_prune,
        'dropout': dropout,
        'max_pool': max_pool,
        'average_pool': average_pool,
        'max_pool_with_index' : max_pool_with_index,
        'lrn': lrn,
        'softmax': softmax,
        'instance_norm': instance_norm,
        'spatial_bn': spatial_bn,
        'spatial_gn': spatial_gn,
        'moments_with_running_stats': moments_with_running_stats,
        'relu': relu,
        'prelu': prelu,
        'tanh': tanh,
        'concat': concat,
        'depth_concat': depth_concat,
        'sum': sum,
        'reduce_sum': reduce_sum,
        'sub': sub,
        'arg_min': arg_min,
        'transpose': transpose,
        'iter': iter,
        'accuracy': accuracy,
        'conv': conv,
        'conv_nd': conv_nd,
        'conv_transpose': conv_transpose,
        'group_conv': group_conv,
        'group_conv_deprecated': group_conv_deprecated,
        'image_input': image_input,
        'video_input': video_input,
        'add_weight_decay': add_weight_decay,
        'elementwise_linear': elementwise_linear,
        'layer_norm': layer_norm,
        'mat_mul' : mat_mul,
        'batch_mat_mul' : batch_mat_mul,
        'cond' : cond,
        'loop' : loop,
        'db_input' : db_input,
        'fused_8bit_rowwise_quantized_to_float' : fused_8bit_rowwise_quantized_to_float,
        'sparse_lengths_sum_4bit_rowwise_sparse': sparse_lengths_sum_4bit_rowwise_sparse,
    }

    def __init__(self, wrapped):
        # The module object this wrapper was constructed around.
        self.wrapped = wrapped

    def __getattr__(self, helper_name):
        """Return a scope-aware callable for the registered ``helper_name``.

        Raises:
            AttributeError: if no helper named ``helper_name`` is registered.
        """
        if helper_name not in self._registry:
            raise AttributeError(
                "Helper function {} not "
                "registered.".format(helper_name)
            )

        def scope_wrapper(*args, **kwargs):
            new_kwargs = {}
            if helper_name != 'arg_scope':
                # Every helper except arg_scope needs the model, either as the
                # first positional argument or as the `model` keyword.
                if len(args) > 0 and isinstance(args[0], ModelHelper):
                    model = args[0]
                elif 'model' in kwargs:
                    model = kwargs['model']
                else:
                    raise RuntimeError(
                        "The first input of helper function should be model. "
                        "Or you can provide it in kwargs as model=<your_model>.")
                # Deep copy so later updates never mutate the model's scope.
                new_kwargs = copy.deepcopy(model.arg_scope)
            func = self._registry[helper_name]
            # inspect.getargspec() was deprecated in Python 3 and removed in
            # 3.11, and it raises ValueError for functions with keyword-only
            # arguments or annotations. Use getfullargspec() when available;
            # fall back to getargspec() only where it is missing (Python 2).
            if hasattr(inspect, 'getfullargspec'):
                spec = inspect.getfullargspec(func)
                var_names = spec.args + spec.kwonlyargs
                varkw = spec.varkw
            else:
                var_names, _, varkw, _ = inspect.getargspec(func)
            if varkw is None:
                # This helper does not take arbitrary **kwargs, so keep only
                # the model-scope entries it can actually accept.
                new_kwargs = {
                    var_name: new_kwargs[var_name]
                    for var_name in var_names if var_name in new_kwargs
                }

            cur_scope = get_current_scope()
            new_kwargs.update(cur_scope.get(helper_name, {}))
            # Explicit keyword arguments from the caller take precedence.
            new_kwargs.update(kwargs)
            return func(*args, **new_kwargs)

        scope_wrapper.__name__ = helper_name
        return scope_wrapper

    def Register(self, helper):
        """Register ``helper`` under its ``__name__``.

        Raises:
            AttributeError: if a helper with the same name is already registered.
        """
        name = helper.__name__
        if name in self._registry:
            raise AttributeError(
                "Helper {} already exists. Please change your "
                "helper name.".format(name)
            )
        self._registry[name] = helper

    def has_helper(self, helper_or_helper_name):
        """Return True if a helper (given by name or by function) is registered."""
        helper_name = (
            helper_or_helper_name
            if isinstance(helper_or_helper_name, basestring) else
            helper_or_helper_name.__name__
        )
        return helper_name in self._registry
# Swap this module's entry in sys.modules for a HelperWrapper around it.
# Anything that later fetches this module through sys.modules gets the
# wrapper. Attribute lookups not defined on HelperWrapper then go through
# HelperWrapper.__getattr__, which resolves registered helper names and raises
# AttributeError for any other name. The original module object remains
# reachable as the wrapper's `.wrapped` attribute.
# The pyre-fixme directive must stay immediately above the assignment it
# suppresses.
# pyre-fixme[6]: incompatible parameter type: expected ModuleType, got HelperWrapper
sys.modules[__name__] = HelperWrapper(sys.modules[__name__])
|