1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Contains the base Layer class, from which all layers inherit.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import copy 21 22from tensorflow.python.eager import context 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.keras.engine import base_layer 26from tensorflow.python.keras.engine import base_layer_utils 27from tensorflow.python.ops import variable_scope as vs 28from tensorflow.python.ops import variables as tf_variables 29from tensorflow.python.training.tracking import base as trackable 30from tensorflow.python.util import function_utils 31from tensorflow.python.util import nest 32from tensorflow.python.util import tf_contextlib 33from tensorflow.python.util.tf_export import tf_export 34 35# Avoid breaking users who directly import this symbol from this file. 36# TODO(fchollet): remove this. 37InputSpec = base_layer.InputSpec # pylint: disable=invalid-name 38 39_KERAS_STYLE_SCOPE = False 40 41 42@tf_export(v1=['layers.experimental.keras_style_scope']) 43@tf_contextlib.contextmanager 44def keras_style_scope(): 45 """Use Keras-style variable management. 46 47 All tf.layers and tf RNN cells created in this scope use Keras-style 48 variable management. Creating such layers with a scope= argument is 49 disallowed, and reuse=True is disallowed. 50 51 The purpose of this scope is to allow users of existing layers to 52 slowly transition to a Keras layers API without breaking existing 53 functionality. 54 55 One example of this is when using TensorFlow's RNN classes with Keras 56 Models or Networks. Because Keras models do not properly set variable 57 scopes, users of RNNs may either accidentally share scopes between two 58 different models, or get errors about variables that already exist. 59 60 Example: 61 62 ```python 63 class RNNModel(tf.keras.Model): 64 65 def __init__(self, name): 66 super(RNNModel, self.).__init__(name=name) 67 self.rnn = tf.nn.rnn_cell.MultiRNNCell( 68 [tf.nn.rnn_cell.LSTMCell(64) for _ in range(2)]) 69 70 def call(self, input, state): 71 return self.rnn(input, state) 72 73 model_1 = RNNModel("model_1") 74 model_2 = RNNModel("model_2") 75 76 # OK 77 output_1, next_state_1 = model_1(input, state) 78 # Raises an error about trying to create an already existing variable. 79 output_2, next_state_2 = model_2(input, state) 80 ``` 81 82 The solution is to wrap the model construction and execution in a keras-style 83 scope: 84 85 ```python 86 with keras_style_scope(): 87 model_1 = RNNModel("model_1") 88 model_2 = RNNModel("model_2") 89 90 # model_1 and model_2 are guaranteed to create their own variables. 91 output_1, next_state_1 = model_1(input, state) 92 output_2, next_state_2 = model_2(input, state) 93 94 assert len(model_1.weights) > 0 95 assert len(model_2.weights) > 0 96 assert(model_1.weights != model_2.weights) 97 ``` 98 99 Yields: 100 A keras layer style scope. 101 """ 102 global _KERAS_STYLE_SCOPE 103 stack = _KERAS_STYLE_SCOPE 104 _KERAS_STYLE_SCOPE = True 105 try: 106 yield 107 finally: 108 _KERAS_STYLE_SCOPE = stack 109 110 111@tf_export(v1=['layers.experimental.set_keras_style']) 112def set_keras_style(): 113 """Use Keras-style variable management. 114 115 All tf.layers and tf RNN cells created after keras style ha been enabled 116 use Keras-style variable management. Creating such layers with a 117 scope= argument is disallowed, and reuse=True is disallowed. 118 119 The purpose of this function is to allow users of existing layers to 120 slowly transition to Keras layers API without breaking existing 121 functionality. 122 123 For more details, see the documentation for `keras_style_scope`. 124 125 Note, once keras style has been set, it is set globally for the entire 126 program and cannot be unset. 127 128 Example: 129 130 ```python 131 set_keras_style() 132 133 model_1 = RNNModel(name="model_1") 134 model_2 = RNNModel(name="model_2") 135 136 # model_1 and model_2 are guaranteed to create their own variables. 137 output_1, next_state_1 = model_1(input, state) 138 output_2, next_state_2 = model_2(input, state) 139 140 assert len(model_1.weights) > 0 141 assert len(model_2.weights) > 0 142 assert(model_1.weights != model_2.weights) 143 ``` 144 """ 145 global _KERAS_STYLE_SCOPE 146 _KERAS_STYLE_SCOPE = True 147 148 149def _is_in_keras_style_scope(): 150 global _KERAS_STYLE_SCOPE 151 return _KERAS_STYLE_SCOPE 152 153 154@tf_export(v1=['layers.Layer']) 155class Layer(base_layer.Layer): 156 """Base layer class. 157 158 It is considered legacy, and we recommend the use of `tf.keras.layers.Layer` 159 instead. 160 161 Arguments: 162 trainable: Boolean, whether the layer's variables should be trainable. 163 name: String name of the layer. 164 dtype: Default dtype of the layer's weights (default of `None` means use the 165 type of the first input). 166 167 Read-only properties: 168 name: The name of the layer (string). 169 dtype: Default dtype of the layer's weights (default of `None` means use the 170 type of the first input). 171 trainable_variables: List of trainable variables. 172 non_trainable_variables: List of non-trainable variables. 173 variables: List of all variables of this layer, trainable and 174 non-trainable. 175 updates: List of update ops of this layer. 176 losses: List of losses added by this layer. 177 trainable_weights: List of variables to be included in backprop. 178 non_trainable_weights: List of variables that should not be 179 included in backprop. 180 weights: The concatenation of the lists trainable_weights and 181 non_trainable_weights (in this order). 182 183 Mutable properties: 184 trainable: Whether the layer should be trained (boolean). 185 input_spec: Optional (list of) `InputSpec` object(s) specifying the 186 constraints on inputs that can be accepted by the layer. 187 """ 188 189 def __init__(self, trainable=True, name=None, dtype=None, 190 **kwargs): 191 # For backwards compatibility, legacy layers do not use `ResourceVariable` 192 # by default. 193 self._use_resource_variables = False 194 scope = kwargs.pop('_scope', None) 195 self._reuse = kwargs.pop('_reuse', None) 196 197 # Avoid an incorrect lint error 198 self._trainable_weights = [] 199 self.built = False 200 201 super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype, 202 **kwargs) 203 204 if _is_in_keras_style_scope(): 205 if scope is not None: 206 raise ValueError( 207 'scope argument not allowed when keras style layers are enabled, ' 208 'but saw: {}'.format(scope)) 209 if self._reuse is not None: 210 raise ValueError( 211 'reuse argument not allowed when keras style layers are enabled, ' 212 'but saw: {}'.format(self._reuse)) 213 self._keras_style = True 214 else: 215 self._keras_style = False 216 217 self._graph = None 218 self._call_has_scope_arg = 'scope' in self._call_fn_args 219 if scope: 220 with vs.variable_scope(scope) as captured_scope: 221 self._scope = captured_scope 222 else: 223 self._scope = None 224 self._current_scope = None 225 226 @property 227 def graph(self): 228 if context.executing_eagerly(): 229 raise RuntimeError('Layer.graph not supported when executing eagerly.') 230 return self._graph 231 232 def _init_set_name(self, name): 233 # Determine layer name (non-unique). 234 if isinstance(name, vs.VariableScope): 235 base_name = name.name 236 self._name, _ = self._make_unique_name() 237 else: 238 base_name = name 239 self._name = name 240 if not name: 241 self._name, base_name = self._make_unique_name() 242 self._base_name = base_name 243 244 def _make_unique_name(self, name_uid_map=None, avoid_names=None, 245 namespace='', zero_based=False): 246 base_name = base_layer.to_snake_case(self.__class__.__name__) 247 name = base_layer_utils.unique_layer_name(base_name, 248 name_uid_map=name_uid_map, 249 avoid_names=avoid_names, 250 namespace=namespace, 251 zero_based=zero_based) 252 return (name, base_name) 253 254 @property 255 def scope_name(self): 256 if not self._scope: 257 raise ValueError('No name available for layer scope because the layer "' + 258 self._name + '" has not been used yet. The scope name ' + 259 ' is determined the first time the layer instance is ' + 260 'called. You must therefore call the layer before ' + 261 'querying `scope_name`.') 262 return self._scope.name 263 264 def add_loss(self, losses, inputs=None): 265 previous_losses_length = len(self._losses) 266 previous_callable_losses_length = len(self._callable_losses) 267 super(Layer, self).add_loss(losses, inputs=inputs) 268 if not context.executing_eagerly(): 269 # TODO(fchollet): deprecate collection below. 270 new_losses = self._losses[previous_losses_length:] 271 new_callable_losses = self._callable_losses[ 272 previous_callable_losses_length:] 273 for regularizer in new_callable_losses: 274 loss_tensor = regularizer() 275 if loss_tensor is not None: 276 new_losses.append(loss_tensor) 277 _add_elements_to_collection( 278 new_losses, 279 ops.GraphKeys.REGULARIZATION_LOSSES) 280 281 def _name_scope(self): 282 """Determines op naming for the Layer.""" 283 if self._keras_style: 284 return super(Layer, self)._name_scope() 285 return self._current_scope.original_name_scope 286 287 def _set_scope(self, scope=None): 288 if self._scope is None: 289 # If constructed with _scope=None, lazy setting of scope. 290 if self._reuse: 291 with vs.variable_scope( 292 scope if scope is not None else self._base_name) as captured_scope: 293 self._scope = captured_scope 294 else: 295 with vs.variable_scope( 296 scope, default_name=self._base_name) as captured_scope: 297 self._scope = captured_scope 298 299 def add_weight(self, 300 name, 301 shape, 302 dtype=None, 303 initializer=None, 304 regularizer=None, 305 trainable=None, 306 constraint=None, 307 use_resource=None, 308 synchronization=vs.VariableSynchronization.AUTO, 309 aggregation=vs.VariableAggregation.NONE, 310 partitioner=None, 311 **kwargs): 312 """Adds a new variable to the layer, or gets an existing one; returns it. 313 314 Arguments: 315 name: variable name. 316 shape: variable shape. 317 dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 318 initializer: initializer instance (callable). 319 regularizer: regularizer instance (callable). 320 trainable: whether the variable should be part of the layer's 321 "trainable_variables" (e.g. variables, biases) 322 or "non_trainable_variables" (e.g. BatchNorm mean, stddev). 323 Note, if the current variable scope is marked as non-trainable 324 then this parameter is ignored and any added variables are also 325 marked as non-trainable. `trainable` defaults to `True` unless 326 `synchronization` is set to `ON_READ`. 327 constraint: constraint instance (callable). 328 use_resource: Whether to use `ResourceVariable`. 329 synchronization: Indicates when a distributed a variable will be 330 aggregated. Accepted values are constants defined in the class 331 `tf.VariableSynchronization`. By default the synchronization is set to 332 `AUTO` and the current `DistributionStrategy` chooses 333 when to synchronize. If `synchronization` is set to `ON_READ`, 334 `trainable` must not be set to `True`. 335 aggregation: Indicates how a distributed variable will be aggregated. 336 Accepted values are constants defined in the class 337 `tf.VariableAggregation`. 338 partitioner: (optional) partitioner instance (callable). If 339 provided, when the requested variable is created it will be split 340 into multiple partitions according to `partitioner`. In this case, 341 an instance of `PartitionedVariable` is returned. Available 342 partitioners include `tf.fixed_size_partitioner` and 343 `tf.variable_axis_size_partitioner`. For more details, see the 344 documentation of `tf.get_variable` and the "Variable Partitioners 345 and Sharding" section of the API guide. 346 **kwargs: Additional keyword arguments. 347 348 Returns: 349 The created variable. Usually either a `Variable` or `ResourceVariable` 350 instance. If `partitioner` is not `None`, a `PartitionedVariable` 351 instance is returned. 352 353 Raises: 354 RuntimeError: If called with partioned variable regularization and 355 eager execution is enabled. 356 ValueError: When trainable has been set to True with synchronization 357 set as `ON_READ`. 358 """ 359 for kwarg in kwargs: 360 if kwarg != 'experimental_autocast': 361 raise TypeError('Unknown keyword argument:', kwarg) 362 if self._keras_style: 363 return super(Layer, self).add_weight( 364 name=name, 365 shape=shape, 366 dtype=dtype, 367 initializer=initializer, 368 regularizer=regularizer, 369 trainable=trainable, 370 constraint=constraint, 371 use_resource=use_resource, 372 synchronization=vs.VariableSynchronization.AUTO, 373 aggregation=vs.VariableAggregation.NONE, 374 partitioner=partitioner, 375 **kwargs) 376 377 if synchronization == vs.VariableSynchronization.ON_READ: 378 if trainable: 379 raise ValueError( 380 'Synchronization value can be set to ' 381 'VariableSynchronization.ON_READ only for non-trainable variables. ' 382 'You have specified trainable=True and ' 383 'synchronization=VariableSynchronization.ON_READ.') 384 else: 385 # Set trainable to be false when variable is to be synced on read. 386 trainable = False 387 elif trainable is None: 388 trainable = True 389 390 def _should_add_regularizer(variable, existing_variable_set): 391 if isinstance(variable, tf_variables.PartitionedVariable): 392 for var in variable: 393 if var in existing_variable_set: 394 return False 395 return True 396 else: 397 return variable not in existing_variable_set 398 399 init_graph = None 400 if not context.executing_eagerly(): 401 default_graph = ops.get_default_graph() 402 if default_graph.building_function: 403 with ops.init_scope(): 404 # Retrieve the variables from the graph into which variables 405 # will be lifted; if initialization ops will be lifted into 406 # the eager context, then there is nothing to retrieve, since variable 407 # collections are not supported when eager execution is enabled. 408 if not context.executing_eagerly(): 409 init_graph = ops.get_default_graph() 410 existing_variables = set(tf_variables.global_variables()) 411 else: 412 # Initialization ops will not be lifted out of the default graph. 413 init_graph = default_graph 414 existing_variables = set(tf_variables.global_variables()) 415 416 if dtype is None: 417 dtype = self.dtype or dtypes.float32 418 419 self._set_scope(None) 420 reuse = self.built or self._reuse 421 prev_len_trainable = len(self._trainable_weights) 422 with vs.variable_scope( 423 self._scope, reuse=reuse, auxiliary_name_scope=False) as scope: 424 self._current_scope = scope 425 with ops.name_scope(self._name_scope()): 426 use_resource = (use_resource or 427 self._use_resource_variables or 428 scope.use_resource) 429 if initializer is None: 430 initializer = scope.initializer 431 variable = super(Layer, self).add_weight( 432 name, 433 shape, 434 dtype=dtypes.as_dtype(dtype), 435 initializer=initializer, 436 trainable=trainable, 437 constraint=constraint, 438 partitioner=partitioner, 439 use_resource=use_resource, 440 synchronization=synchronization, 441 aggregation=aggregation, 442 getter=vs.get_variable, 443 **kwargs) 444 445 if regularizer: 446 if (ops.executing_eagerly_outside_functions() 447 or _should_add_regularizer(variable, existing_variables)): 448 self._handle_weight_regularization(name, variable, regularizer) 449 450 if init_graph is not None: 451 # Handle edge case where a custom getter has overridden `trainable`. 452 # There is one known occurrence of this, in unit test 453 # testBasicRNNCellNotTrainable in 454 # contrib.rnn.python.kernel_tests.core_rnn_cell_test 455 with init_graph.as_default(): 456 trainable_variables = tf_variables.trainable_variables() 457 if (trainable and self.trainable and 458 variable not in trainable_variables): 459 # A custom getter / variable scope overrode the trainable flag. 460 extra_trainable_vars = self._trainable_weights[prev_len_trainable:] 461 self._trainable_weights = self._trainable_weights[ 462 :prev_len_trainable] 463 self._non_trainable_weights += extra_trainable_vars 464 return variable 465 466 def __call__(self, inputs, *args, **kwargs): 467 """Wraps `call`, applying pre- and post-processing steps. 468 469 Arguments: 470 inputs: input tensor(s). 471 *args: additional positional arguments to be passed to `self.call`. 472 **kwargs: additional keyword arguments to be passed to `self.call`. 473 **Note**: kwarg `scope` is reserved for use by the layer. 474 475 Returns: 476 Output tensor(s). 477 478 Note: 479 - If the layer's `call` method takes a `scope` keyword argument, 480 this argument will be automatically set to the current variable scope. 481 - If the layer's `call` method takes a `mask` argument (as some Keras 482 layers do), its default value will be set to the mask generated 483 for `inputs` by the previous layer (if `input` did come from 484 a layer that generated a corresponding mask, i.e. if it came from 485 a Keras layer with masking support. 486 487 Raises: 488 ValueError: if the layer's `call` method returns None (an invalid value). 489 """ 490 scope = kwargs.pop('scope', None) 491 492 if self._keras_style: 493 if scope is not None: 494 raise ValueError( 495 'scope argument not allowed when keras style layers are enabled, ' 496 'but saw: {}'.format(scope)) 497 return super(Layer, self).__call__(inputs, *args, **kwargs) 498 499 self._set_scope(scope) 500 501 if not context.executing_eagerly(): 502 try: 503 # Set layer's "graph" at build time 504 self._graph = ops._get_graph_from_inputs(nest.flatten(inputs), # pylint: disable=protected-access 505 graph=self._graph) 506 except ValueError as e: 507 raise ValueError('Input graph and Layer graph are not the same: %s' % e) 508 509 if self.built: 510 try: 511 # Some classes which inherit from Layer do not use its constructor, so 512 # rather than initializing to None we check for an AttributeError. 513 scope_context_manager = self._always_reuse_variable_scope 514 except AttributeError: 515 # From this point we will always set reuse=True, so create a "final" 516 # variable scope with this setting. We avoid re-creating variable scopes 517 # after this point as an optimization. 518 self._always_reuse_variable_scope = vs.variable_scope( 519 self._scope, reuse=True, auxiliary_name_scope=False) 520 scope_context_manager = self._always_reuse_variable_scope 521 else: 522 scope_context_manager = vs.variable_scope( 523 self._scope, reuse=self._reuse, auxiliary_name_scope=False) 524 525 with scope_context_manager as scope: 526 self._current_scope = scope 527 528 try: 529 call_has_scope_arg = self._call_has_scope_arg 530 except AttributeError: 531 self._call_fn_args = function_utils.fn_args(self.call) 532 self._call_has_scope_arg = 'scope' in self._call_fn_args 533 call_has_scope_arg = self._call_has_scope_arg 534 if call_has_scope_arg: 535 kwargs['scope'] = scope 536 537 # Actually call layer 538 outputs = super(Layer, self).__call__(inputs, *args, **kwargs) 539 540 if not context.executing_eagerly(): 541 # Update global default collections. 542 _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) 543 return outputs 544 545 def __deepcopy__(self, memo): 546 no_copy = set(['_graph']) 547 shallow_copy = set(['_scope', '_always_reuse_variable_scope']) 548 cls = self.__class__ 549 result = cls.__new__(cls) 550 memo[id(self)] = result 551 for k, v in self.__dict__.items(): 552 if k in no_copy: 553 setattr(result, k, v) 554 elif k in shallow_copy: 555 setattr(result, k, copy.copy(v)) 556 elif base_layer.is_tensor_or_tensor_list(v): 557 setattr(result, k, v) 558 else: 559 setattr(result, k, copy.deepcopy(v, memo)) 560 return result 561 562 def __setattr__(self, value, name): 563 # By-pass the automatic dependency tracking performed by the parent Layer. 564 super(trackable.Trackable, self).__setattr__(value, name) 565 566 567def _add_elements_to_collection(elements, collection_list): 568 if context.executing_eagerly(): 569 raise RuntimeError('Using collections from Layers not supported in Eager ' 570 'mode. Tried to add %s to %s' % (elements, 571 collection_list)) 572 elements = nest.flatten(elements) 573 collection_list = nest.flatten(collection_list) 574 for name in collection_list: 575 collection = ops.get_collection_ref(name) 576 collection_set = set(collection) 577 for element in elements: 578 if element not in collection_set: 579 collection.append(element) 580