# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the arg_scope used for scoping layers' arguments.

  Allows one to define models much more compactly by eliminating boilerplate
  code. This is accomplished through the use of argument scoping (arg_scope).

  Example of how to use tf.contrib.framework.arg_scope:

  ```
  from third_party.tensorflow.contrib.layers.python import layers

  arg_scope = tf.contrib.framework.arg_scope

  with arg_scope([layers.conv2d], padding='SAME',
                 initializer=layers.variance_scaling_initializer(),
                 regularizer=layers.l2_regularizer(0.05)):
    net = layers.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
    net = layers.conv2d(net, 256, [5, 5], scope='conv2')
  ```
  The first call to conv2d will behave as follows:
    layers.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
                  initializer=layers.variance_scaling_initializer(),
                  regularizer=layers.l2_regularizer(0.05), scope='conv1')

  The second call to conv2d will also use the arg_scope's default for padding:
    layers.conv2d(net, 256, [5, 5], padding='SAME',
                  initializer=layers.variance_scaling_initializer(),
                  regularizer=layers.l2_regularizer(0.05), scope='conv2')

  Example of how to reuse an arg_scope:

  ```
  with arg_scope([layers.conv2d], padding='SAME',
                 initializer=layers.variance_scaling_initializer(),
                 regularizer=layers.l2_regularizer(0.05)) as sc:
    net = layers.conv2d(net, 256, [5, 5], scope='conv1')
    ....

  with arg_scope(sc):
    net = layers.conv2d(net, 256, [5, 5], scope='conv2')
  ```

  Example of how to use tf.contrib.framework.add_arg_scope to enable your
  function to be called within an arg_scope later:

  @tf.contrib.framework.add_arg_scope
  def conv2d(*args, **kwargs):
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator

__all__ = [
    'arg_scope', 'add_arg_scope', 'current_arg_scope', 'has_arg_scope',
    'arg_scoped_arguments', 'arg_scope_func_key'
]

# Stack of active argument scopes, innermost last. Each entry maps an op key
# (see arg_scope_func_key) to a dict of {kwarg_name: default_value}. The bottom
# entry is the empty, outermost scope.
_ARGSTACK = [{}]

# Maps the key of every function registered through add_arg_scope to the names
# of its positional parameters that have default values (see _kwarg_names).
_DECORATED_OPS = {}


def _get_arg_stack():
  """Returns the scope stack, re-seeding it with an empty scope if empty."""
  if _ARGSTACK:
    return _ARGSTACK
  else:
    _ARGSTACK.append({})
    return _ARGSTACK


def current_arg_scope():
  """Returns the innermost (currently active) scope dictionary."""
  stack = _get_arg_stack()
  return stack[-1]


def arg_scope_func_key(op):
  """Returns the key under which `op`'s defaults are stored in a scope.

  The wrapper built by `add_arg_scope` is tagged with a `_key_op` attribute
  holding the key of the function it wraps; any object without that attribute
  falls back to `str(op)`.
  """
  return getattr(op, '_key_op', str(op))


def _name_op(op):
  """Returns a `(module, name)` tuple identifying `op` in error messages."""
  return (op.__module__, op.__name__)


def _kwarg_names(func):
  """Returns the names of the positional parameters of `func` with defaults.

  `co_varnames` lists the positional parameters first, followed by any other
  parameter names (`*args`, `**kwargs`, ...) and then local variables, so the
  defaulted positional parameters are the last `len(__defaults__)` entries
  before index `co_argcount`.

  Args:
    func: a plain Python function.

  Returns:
    A (possibly empty) tuple of parameter names.
  """
  code = func.__code__
  num_defaults = len(func.__defaults__) if func.__defaults__ else 0
  # Index relative to co_argcount, not to the end of co_varnames: a negative
  # start would count back from the locals, and `-0` would select every
  # positional parameter instead of none.
  return code.co_varnames[code.co_argcount - num_defaults:code.co_argcount]


def _add_op(op):
  """Registers `op` as arg_scope-aware, recording its defaulted kwarg names."""
  key_op = arg_scope_func_key(op)
  _DECORATED_OPS[key_op] = _kwarg_names(op)


@tf_contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
  """Stores the default arguments for the given set of list_ops.

  For usage, please see examples at top of the file.

  Args:
    list_ops_or_scope: List or tuple of operations to set argument scope for or
      a dictionary containing the current scope. When list_ops_or_scope is a
      dict, kwargs must be empty. When list_ops_or_scope is a list or tuple,
      then every op in it needs to be decorated with @add_arg_scope to work.
    **kwargs: keyword=value that will define the defaults for each op in
      list_ops. All the ops need to accept the given set of arguments.

  Yields:
    the current_scope, which is a dictionary of {op: {arg: value}}
  Raises:
    TypeError: if list_ops is not a list or a tuple.
    ValueError: if any op in list_ops has not been decorated with
      @add_arg_scope, or if kwargs are passed together with a reused scope
      (dict).
  """
  if isinstance(list_ops_or_scope, dict):
    # Assumes that list_ops_or_scope is a scope that is being reused.
    if kwargs:
      raise ValueError('When attempting to re-use a scope by supplying a '
                       'dictionary, kwargs must be empty.')
    current_scope = list_ops_or_scope.copy()
  else:
    # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
    if not isinstance(list_ops_or_scope, (list, tuple)):
      raise TypeError('list_ops_or_scope must either be a list/tuple or reused '
                      'scope (i.e. dict)')
    # Shallow copy of the enclosing scope; each per-op dict touched below is
    # itself copied before being updated, so the enclosing scope is unchanged.
    current_scope = current_arg_scope().copy()
    for op in list_ops_or_scope:
      key = arg_scope_func_key(op)
      if not has_arg_scope(op):
        raise ValueError('%s.%s is not decorated with @add_arg_scope' %
                         _name_op(op))
      if key in current_scope:
        current_kwargs = current_scope[key].copy()
        current_kwargs.update(kwargs)
        current_scope[key] = current_kwargs
      else:
        current_scope[key] = kwargs.copy()
  # Push only after all validation has succeeded, so the `finally` below pops
  # exactly the scope pushed here and never an enclosing one.
  _get_arg_stack().append(current_scope)
  try:
    yield current_scope
  finally:
    _get_arg_stack().pop()


def add_arg_scope(func):
  """Decorates a function with args so it can be used within an arg_scope.

  Args:
    func: function to decorate.

  Returns:
    The decorated function func_with_args(), which merges the active scope's
    defaults for `func` into each call.
  """
  # `func` never changes, so its key can be computed once at decoration time.
  key_func = arg_scope_func_key(func)

  def func_with_args(*args, **kwargs):
    current_scope = current_arg_scope()
    current_args = kwargs
    if key_func in current_scope:
      # Scope defaults first, then explicit kwargs, so call-site values win.
      current_args = current_scope[key_func].copy()
      current_args.update(kwargs)
    return func(*args, **current_args)

  _add_op(func)
  setattr(func_with_args, '_key_op', key_func)
  return tf_decorator.make_decorator(func, func_with_args)


def has_arg_scope(func):
  """Checks whether a func has been decorated with @add_arg_scope or not.

  Args:
    func: function to check.

  Returns:
    a boolean.
  """
  return arg_scope_func_key(func) in _DECORATED_OPS


def arg_scoped_arguments(func):
  """Returns the names of the kwargs that arg_scope can set for a func.

  Args:
    func: function which has been decorated with @add_arg_scope.

  Returns:
    a tuple of kwarg names (the positional parameters of `func` that have
    default values).
  """
  assert has_arg_scope(func)
  return _DECORATED_OPS[arg_scope_func_key(func)]