1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15# pylint: disable=g-classes-have-attributes 16"""Contains the base Layer class, from which all layers inherit.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22import warnings 23 24from tensorflow.python.eager import context 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.keras import backend 28from tensorflow.python.keras.engine import base_layer 29from tensorflow.python.keras.engine import base_layer_utils 30from tensorflow.python.keras.legacy_tf_layers import variable_scope_shim 31from tensorflow.python.keras.mixed_precision import policy 32from tensorflow.python.keras.utils import tf_contextlib 33from tensorflow.python.ops import variable_scope as vs 34from tensorflow.python.ops import variables as tf_variables 35from tensorflow.python.training.tracking import base as trackable 36from tensorflow.python.util import nest 37from tensorflow.python.util.tf_export import keras_export 38from tensorflow.python.util.tf_export import tf_export 39 40# Avoid breaking users who directly import this symbol from this file. 41# TODO(fchollet): remove this. 42InputSpec = base_layer.InputSpec # pylint: disable=invalid-name 43 44_KERAS_STYLE_SCOPE = False 45 46 47@keras_export( 48 v1=['keras.__internal__.legacy.layers.experimental.keras_style_scope']) 49@tf_export(v1=['layers.experimental.keras_style_scope']) 50@tf_contextlib.contextmanager 51def keras_style_scope(): 52 """Use Keras-style variable management. 53 54 All tf.layers and tf RNN cells created in this scope use Keras-style 55 variable management. Creating such layers with a scope= argument is 56 disallowed, and reuse=True is disallowed. 57 58 The purpose of this scope is to allow users of existing layers to 59 slowly transition to a Keras layers API without breaking existing 60 functionality. 61 62 One example of this is when using TensorFlow's RNN classes with Keras 63 Models or Networks. Because Keras models do not properly set variable 64 scopes, users of RNNs may either accidentally share scopes between two 65 different models, or get errors about variables that already exist. 66 67 Example: 68 69 ```python 70 class RNNModel(tf.keras.Model): 71 72 def __init__(self, name): 73 super(RNNModel, self).__init__(name=name) 74 self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell( 75 [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)]) 76 77 def call(self, input, state): 78 return self.rnn(input, state) 79 80 model_1 = RNNModel("model_1") 81 model_2 = RNNModel("model_2") 82 83 # OK 84 output_1, next_state_1 = model_1(input, state) 85 # Raises an error about trying to create an already existing variable. 86 output_2, next_state_2 = model_2(input, state) 87 ``` 88 89 The solution is to wrap the model construction and execution in a keras-style 90 scope: 91 92 ```python 93 with keras_style_scope(): 94 model_1 = RNNModel("model_1") 95 model_2 = RNNModel("model_2") 96 97 # model_1 and model_2 are guaranteed to create their own variables. 98 output_1, next_state_1 = model_1(input, state) 99 output_2, next_state_2 = model_2(input, state) 100 101 assert len(model_1.weights) > 0 102 assert len(model_2.weights) > 0 103 assert(model_1.weights != model_2.weights) 104 ``` 105 106 Yields: 107 A keras layer style scope. 108 """ 109 global _KERAS_STYLE_SCOPE 110 stack = _KERAS_STYLE_SCOPE 111 _KERAS_STYLE_SCOPE = True 112 try: 113 yield 114 finally: 115 _KERAS_STYLE_SCOPE = stack 116 117 118@keras_export( 119 v1=['keras.__internal__.legacy.layers.experimental.set_keras_style']) 120@tf_export(v1=['layers.experimental.set_keras_style']) 121def set_keras_style(): 122 """Use Keras-style variable management. 123 124 All tf.layers and tf RNN cells created after keras style ha been enabled 125 use Keras-style variable management. Creating such layers with a 126 scope= argument is disallowed, and reuse=True is disallowed. 127 128 The purpose of this function is to allow users of existing layers to 129 slowly transition to Keras layers API without breaking existing 130 functionality. 131 132 For more details, see the documentation for `keras_style_scope`. 133 134 Note, once keras style has been set, it is set globally for the entire 135 program and cannot be unset. 136 137 Example: 138 139 ```python 140 set_keras_style() 141 142 model_1 = RNNModel(name="model_1") 143 model_2 = RNNModel(name="model_2") 144 145 # model_1 and model_2 are guaranteed to create their own variables. 146 output_1, next_state_1 = model_1(input, state) 147 output_2, next_state_2 = model_2(input, state) 148 149 assert len(model_1.weights) > 0 150 assert len(model_2.weights) > 0 151 assert(model_1.weights != model_2.weights) 152 ``` 153 """ 154 global _KERAS_STYLE_SCOPE 155 _KERAS_STYLE_SCOPE = True 156 157 158def _is_in_keras_style_scope(): 159 global _KERAS_STYLE_SCOPE 160 return _KERAS_STYLE_SCOPE 161 162 163@keras_export(v1=['keras.__internal__.legacy.layers.Layer']) 164@tf_export(v1=['layers.Layer']) 165class Layer(base_layer.Layer): 166 """Base layer class. 167 168 It is considered legacy, and we recommend the use of `tf.keras.layers.Layer` 169 instead. 170 171 Args: 172 trainable: Boolean, whether the layer's variables should be trainable. 173 name: String name of the layer. 174 dtype: Default dtype of the layer's weights (default of `None` means use the 175 type of the first input). 176 177 Read-only properties: 178 name: The name of the layer (string). 179 dtype: Default dtype of the layer's weights (default of `None` means use the 180 type of the first input). 181 trainable_variables: List of trainable variables. 182 non_trainable_variables: List of non-trainable variables. 183 variables: List of all variables of this layer, trainable and 184 non-trainable. 185 updates: List of update ops of this layer. 186 losses: List of losses added by this layer. 187 trainable_weights: List of variables to be included in backprop. 188 non_trainable_weights: List of variables that should not be 189 included in backprop. 190 weights: The concatenation of the lists trainable_weights and 191 non_trainable_weights (in this order). 192 193 Mutable properties: 194 trainable: Whether the layer should be trained (boolean). 195 input_spec: Optional (list of) `InputSpec` object(s) specifying the 196 constraints on inputs that can be accepted by the layer. 197 """ 198 199 def __init__(self, trainable=True, name=None, dtype=None, 200 **kwargs): 201 # For backwards compatibility, legacy layers do not use `ResourceVariable` 202 # by default. 203 self._use_resource_variables = False 204 scope = kwargs.pop('_scope', None) 205 self._reuse = kwargs.pop('_reuse', None) 206 207 # Avoid an incorrect lint error 208 self._trainable_weights = [] 209 self.built = False 210 211 if dtype is None: 212 # Indicates to infer dtype from inputs. When the V2 dtype behavior is 213 # enabled, Keras layers default their dtype to floatx instead, so we pass 214 # an "_infer" policy to keep the old V1 behavior. 215 dtype = policy.Policy('_infer') 216 217 if 'autocast' not in kwargs: 218 kwargs['autocast'] = False 219 220 # Mark that legacy layers should not be instrumented as Keras usage 221 self._disable_keras_instrumentation = True 222 223 super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype, 224 **kwargs) 225 226 if _is_in_keras_style_scope(): 227 if scope is not None: 228 raise ValueError( 229 'scope argument not allowed when keras style layers are enabled, ' 230 'but saw: {}'.format(scope)) 231 if self._reuse is not None: 232 raise ValueError( 233 'reuse argument not allowed when keras style layers are enabled, ' 234 'but saw: {}'.format(self._reuse)) 235 self._keras_style = True 236 else: 237 self._keras_style = False 238 239 self._call_has_scope_arg = 'scope' in self._call_fn_args 240 if scope: 241 with vs.variable_scope(scope) as captured_scope: 242 self._scope = captured_scope 243 else: 244 self._scope = None 245 self._current_scope = None 246 247 # We no longer track graph in tf.layers layers. This property is only kept to 248 # maintain API backward compatibility. 249 @property 250 def graph(self): 251 warnings.warn('`Layer.graph` is deprecated and ' 252 'will be removed in a future version. ' 253 'Please stop using this property because tf.layers layers no ' 254 'longer track their graph.') 255 if context.executing_eagerly(): 256 raise RuntimeError('Layer.graph not supported when executing eagerly.') 257 return None 258 259 def _init_set_name(self, name): 260 # Determine layer name (non-unique). 261 if isinstance(name, vs.VariableScope): 262 base_name = name.name 263 self._name, _ = self._make_unique_name() 264 else: 265 base_name = name 266 self._name = name 267 if not name: 268 self._name, base_name = self._make_unique_name() 269 self._base_name = base_name 270 271 def _make_unique_name(self, name_uid_map=None, avoid_names=None, 272 namespace='', zero_based=False): 273 base_name = base_layer.to_snake_case(self.__class__.__name__) 274 name = backend.unique_object_name( 275 base_name, 276 name_uid_map=name_uid_map, 277 avoid_names=avoid_names, 278 namespace=namespace, 279 zero_based=zero_based) 280 return (name, base_name) 281 282 @property 283 def scope_name(self): 284 if not self._scope: 285 raise ValueError('No name available for layer scope because the layer "' + 286 self._name + '" has not been used yet. The scope name ' + 287 ' is determined the first time the layer instance is ' + 288 'called. You must therefore call the layer before ' + 289 'querying `scope_name`.') 290 return self._scope.name 291 292 def add_loss(self, losses, inputs=None): 293 previous_losses_length = len(self._losses) 294 previous_callable_losses_length = len(self._callable_losses) 295 super(Layer, self).add_loss(losses, inputs=inputs) 296 if not context.executing_eagerly(): 297 # TODO(fchollet): deprecate collection below. 298 new_losses = self._losses[previous_losses_length:] 299 new_callable_losses = self._callable_losses[ 300 previous_callable_losses_length:] 301 for regularizer in new_callable_losses: 302 loss_tensor = regularizer() 303 if loss_tensor is not None: 304 new_losses.append(loss_tensor) 305 _add_elements_to_collection( 306 new_losses, 307 ops.GraphKeys.REGULARIZATION_LOSSES) 308 309 def _name_scope(self): # pylint: disable=method-hidden 310 """Determines op naming for the Layer.""" 311 if self._keras_style: 312 return super(Layer, self)._name_scope() 313 return self._current_scope.original_name_scope 314 315 def _set_scope(self, scope=None): 316 if self._scope is None: 317 # If constructed with _scope=None, lazy setting of scope. 318 if self._reuse: 319 with vs.variable_scope( 320 scope if scope is not None else self._base_name) as captured_scope: 321 self._scope = captured_scope 322 else: 323 with vs.variable_scope( 324 scope, default_name=self._base_name) as captured_scope: 325 self._scope = captured_scope 326 327 def add_weight(self, 328 name, 329 shape, 330 dtype=None, 331 initializer=None, 332 regularizer=None, 333 trainable=None, 334 constraint=None, 335 use_resource=None, 336 synchronization=vs.VariableSynchronization.AUTO, 337 aggregation=vs.VariableAggregation.NONE, 338 partitioner=None, 339 **kwargs): 340 """Adds a new variable to the layer, or gets an existing one; returns it. 341 342 Args: 343 name: variable name. 344 shape: variable shape. 345 dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 346 initializer: initializer instance (callable). 347 regularizer: regularizer instance (callable). 348 trainable: whether the variable should be part of the layer's 349 "trainable_variables" (e.g. variables, biases) 350 or "non_trainable_variables" (e.g. BatchNorm mean, stddev). 351 Note, if the current variable scope is marked as non-trainable 352 then this parameter is ignored and any added variables are also 353 marked as non-trainable. `trainable` defaults to `True` unless 354 `synchronization` is set to `ON_READ`. 355 constraint: constraint instance (callable). 356 use_resource: Whether to use `ResourceVariable`. 357 synchronization: Indicates when a distributed a variable will be 358 aggregated. Accepted values are constants defined in the class 359 `tf.VariableSynchronization`. By default the synchronization is set to 360 `AUTO` and the current `DistributionStrategy` chooses 361 when to synchronize. If `synchronization` is set to `ON_READ`, 362 `trainable` must not be set to `True`. 363 aggregation: Indicates how a distributed variable will be aggregated. 364 Accepted values are constants defined in the class 365 `tf.VariableAggregation`. 366 partitioner: (optional) partitioner instance (callable). If 367 provided, when the requested variable is created it will be split 368 into multiple partitions according to `partitioner`. In this case, 369 an instance of `PartitionedVariable` is returned. Available 370 partitioners include `tf.compat.v1.fixed_size_partitioner` and 371 `tf.compat.v1.variable_axis_size_partitioner`. For more details, see 372 the documentation of `tf.compat.v1.get_variable` and the "Variable 373 Partitioners and Sharding" section of the API guide. 374 **kwargs: Additional keyword arguments. 375 376 Returns: 377 The created variable. Usually either a `Variable` or `ResourceVariable` 378 instance. If `partitioner` is not `None`, a `PartitionedVariable` 379 instance is returned. 380 381 Raises: 382 RuntimeError: If called with partitioned variable regularization and 383 eager execution is enabled. 384 ValueError: When trainable has been set to True with synchronization 385 set as `ON_READ`. 386 """ 387 for kwarg in kwargs: 388 if kwarg != 'experimental_autocast': 389 raise TypeError('Unknown keyword argument:', kwarg) 390 if self._keras_style: 391 return super(Layer, self).add_weight( 392 name=name, 393 shape=shape, 394 dtype=dtype, 395 initializer=initializer, 396 regularizer=regularizer, 397 trainable=trainable and self.trainable, 398 constraint=constraint, 399 use_resource=use_resource, 400 synchronization=vs.VariableSynchronization.AUTO, 401 aggregation=vs.VariableAggregation.NONE, 402 partitioner=partitioner, 403 **kwargs) 404 405 if synchronization == vs.VariableSynchronization.ON_READ: 406 if trainable: 407 raise ValueError( 408 'Synchronization value can be set to ' 409 'VariableSynchronization.ON_READ only for non-trainable variables. ' 410 'You have specified trainable=True and ' 411 'synchronization=VariableSynchronization.ON_READ.') 412 else: 413 # Set trainable to be false when variable is to be synced on read. 414 trainable = False 415 elif trainable is None: 416 trainable = True 417 418 def _should_add_regularizer(variable, existing_variable_set): 419 if base_layer_utils.is_split_variable(variable): 420 for var in variable: 421 if var in existing_variable_set: 422 return False 423 return True 424 else: 425 return variable not in existing_variable_set 426 427 init_graph = None 428 if not context.executing_eagerly(): 429 default_graph = ops.get_default_graph() 430 if default_graph.building_function: 431 with ops.init_scope(): 432 # Retrieve the variables from the graph into which variables 433 # will be lifted; if initialization ops will be lifted into 434 # the eager context, then there is nothing to retrieve, since variable 435 # collections are not supported when eager execution is enabled. 436 if not context.executing_eagerly(): 437 init_graph = ops.get_default_graph() 438 existing_variables = set(tf_variables.global_variables()) 439 else: 440 # Initialization ops will not be lifted out of the default graph. 441 init_graph = default_graph 442 existing_variables = set(tf_variables.global_variables()) 443 444 if dtype is None: 445 dtype = self.dtype or dtypes.float32 446 447 self._set_scope(None) 448 reuse = self.built or self._reuse 449 prev_len_trainable = len(self._trainable_weights) 450 with vs.variable_scope( 451 self._scope, reuse=reuse, auxiliary_name_scope=False) as scope: 452 self._current_scope = scope 453 with backend.name_scope(self._name_scope()): # pylint: disable=not-callable 454 use_resource = (use_resource or 455 self._use_resource_variables or 456 scope.use_resource) 457 if initializer is None: 458 initializer = scope.initializer 459 variable = super(Layer, self).add_weight( 460 name, 461 shape, 462 dtype=dtypes.as_dtype(dtype), 463 initializer=initializer, 464 trainable=trainable and self.trainable, 465 constraint=constraint, 466 partitioner=partitioner, 467 use_resource=use_resource, 468 synchronization=synchronization, 469 aggregation=aggregation, 470 getter=vs.get_variable, 471 **kwargs) 472 473 if regularizer: 474 if (ops.executing_eagerly_outside_functions() 475 or _should_add_regularizer(variable, existing_variables)): 476 self._handle_weight_regularization(name, variable, regularizer) 477 var_store = vs._get_default_variable_store() # pylint: disable=protected-access 478 # When the shim to get variable scope working in TF2 is used, 479 # We need to explicitly make the shim track the regularization 480 # losses as the collections will not be accessible. 481 if hasattr(var_store, 'add_regularizer'): 482 var_store.add_regularizer(variable, regularizer) 483 484 if init_graph is not None: 485 # Handle edge case where a custom getter has overridden `trainable`. 486 # There is one known occurrence of this, in unit test 487 # testBasicRNNCellNotTrainable in 488 # contrib.rnn.python.kernel_tests.core_rnn_cell_test 489 with init_graph.as_default(): 490 trainable_variables = tf_variables.trainable_variables() 491 if (trainable and self.trainable and 492 variable not in trainable_variables): 493 # A custom getter / variable scope overrode the trainable flag. 494 extra_trainable_vars = self._trainable_weights[prev_len_trainable:] 495 self._trainable_weights = self._trainable_weights[ 496 :prev_len_trainable] 497 self._non_trainable_weights += extra_trainable_vars 498 return variable 499 500 def __call__(self, inputs, *args, **kwargs): 501 """Wraps `call`, applying pre- and post-processing steps. 502 503 Args: 504 inputs: input tensor(s). 505 *args: additional positional arguments to be passed to `self.call`. 506 **kwargs: additional keyword arguments to be passed to `self.call`. 507 **Note**: kwarg `scope` is reserved for use by the layer. 508 509 Returns: 510 Output tensor(s). 511 512 Note: 513 - If the layer's `call` method takes a `scope` keyword argument, 514 this argument will be automatically set to the current variable scope. 515 - If the layer's `call` method takes a `mask` argument (as some Keras 516 layers do), its default value will be set to the mask generated 517 for `inputs` by the previous layer (if `input` did come from 518 a layer that generated a corresponding mask, i.e. if it came from 519 a Keras layer with masking support. 520 521 Raises: 522 ValueError: if the layer's `call` method returns None (an invalid value). 523 """ 524 scope = kwargs.pop('scope', None) 525 526 if self._keras_style: 527 if scope is not None: 528 raise ValueError( 529 'scope argument not allowed when keras style layers are enabled, ' 530 'but saw: {}'.format(scope)) 531 return super(Layer, self).__call__(inputs, *args, **kwargs) 532 533 self._set_scope(scope) 534 535 if self.built: 536 try: 537 # Some classes which inherit from Layer do not use its constructor, so 538 # rather than initializing to None we check for an AttributeError. 539 scope_context_manager = self._always_reuse_variable_scope # pylint: disable=access-member-before-definition 540 except AttributeError: 541 scope_context_manager = None 542 543 if scope_context_manager is None: 544 # From this point we will always set reuse=True, so create a "final" 545 # variable scope with this setting. We avoid re-creating variable scopes 546 # after this point as an optimization. 547 scope_context_manager = vs.variable_scope( 548 self._scope, reuse=True, auxiliary_name_scope=False) 549 550 # Do not cache variable scopes if Eager mode is enabled. If Eager mode 551 # is enabled then we don't want to reuse scopes because the cached scope 552 # might be from a FuncGraph or Eager scope we are no longer in. 553 if not ops.executing_eagerly_outside_functions(): 554 self._always_reuse_variable_scope = scope_context_manager 555 else: 556 scope_context_manager = vs.variable_scope( 557 self._scope, reuse=self._reuse, auxiliary_name_scope=False) 558 559 with scope_context_manager as scope: 560 self._current_scope = scope 561 562 try: 563 call_has_scope_arg = self._call_has_scope_arg 564 except AttributeError: 565 self._call_fn_args = variable_scope_shim.fn_args(self.call) 566 self._call_has_scope_arg = 'scope' in self._call_fn_args 567 call_has_scope_arg = self._call_has_scope_arg 568 if call_has_scope_arg: 569 kwargs['scope'] = scope 570 571 # Actually call layer 572 outputs = super(Layer, self).__call__(inputs, *args, **kwargs) 573 574 if not context.executing_eagerly(): 575 # Update global default collections. 576 _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) 577 return outputs 578 579 def __deepcopy__(self, memo): 580 no_copy = set(['_graph', '_thread_local', '_metrics_lock']) 581 shallow_copy = set(['_scope', '_always_reuse_variable_scope']) 582 cls = self.__class__ 583 result = cls.__new__(cls) 584 memo[id(self)] = result 585 for k, v in self.__dict__.items(): 586 if k in no_copy: 587 setattr(result, k, v) 588 elif k in shallow_copy: 589 setattr(result, k, copy.copy(v)) 590 elif base_layer.is_tensor_or_tensor_list(v): 591 setattr(result, k, v) 592 else: 593 setattr(result, k, copy.deepcopy(v, memo)) 594 return result 595 596 def __setattr__(self, value, name): 597 # By-pass the automatic dependency tracking performed by the parent Layer. 598 super(trackable.Trackable, self).__setattr__(value, name) # pylint: disable=bad-super-call 599 600 @property 601 def _is_legacy_layer(self): 602 """Used by keras to check compatibility. This should not be overridden.""" 603 return True 604 605 606def _add_elements_to_collection(elements, collection_list): 607 if context.executing_eagerly(): 608 raise RuntimeError('Using collections from Layers not supported in Eager ' 609 'mode. Tried to add %s to %s' % (elements, 610 collection_list)) 611 elements = nest.flatten(elements) 612 collection_list = nest.flatten(collection_list) 613 for name in collection_list: 614 collection = ops.get_collection_ref(name) 615 collection_set = {id(e) for e in collection} 616 for element in elements: 617 if id(element) not in collection_set: 618 collection.append(element) 619