# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# pylint: disable=g-classes-have-attributes
"""Contains the base Layer class, from which all layers inherit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import warnings

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.legacy_tf_layers import variable_scope_shim
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export

# Avoid breaking users who directly import this symbol from this file.
# TODO(fchollet): remove this.
InputSpec = base_layer.InputSpec  # pylint: disable=invalid-name

_KERAS_STYLE_SCOPE = False


@keras_export(
    v1=['keras.__internal__.legacy.layers.experimental.keras_style_scope'])
@tf_export(v1=['layers.experimental.keras_style_scope'])
@tf_contextlib.contextmanager
def keras_style_scope():
  """Use Keras-style variable management.

  All tf.layers and tf RNN cells created in this scope use Keras-style
  variable management.  Creating such layers with a scope= argument is
  disallowed, and reuse=True is disallowed.

  The purpose of this scope is to allow users of existing layers to
  slowly transition to a Keras layers API without breaking existing
  functionality.

  One example of this is when using TensorFlow's RNN classes with Keras
  Models or Networks.  Because Keras models do not properly set variable
  scopes, users of RNNs may either accidentally share scopes between two
  different models, or get errors about variables that already exist.

  Example:

  ```python
  class RNNModel(tf.keras.Model):

    def __init__(self, name):
      super(RNNModel, self).__init__(name=name)
      self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
        [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)])

    def call(self, input, state):
      return self.rnn(input, state)

  model_1 = RNNModel("model_1")
  model_2 = RNNModel("model_2")

  # OK
  output_1, next_state_1 = model_1(input, state)
  # Raises an error about trying to create an already existing variable.
  output_2, next_state_2 = model_2(input, state)
  ```

  The solution is to wrap the model construction and execution in a keras-style
  scope:

  ```python
  with keras_style_scope():
    model_1 = RNNModel("model_1")
    model_2 = RNNModel("model_2")

    # model_1 and model_2 are guaranteed to create their own variables.
    output_1, next_state_1 = model_1(input, state)
    output_2, next_state_2 = model_2(input, state)

    assert len(model_1.weights) > 0
    assert len(model_2.weights) > 0
    assert(model_1.weights != model_2.weights)
  ```

  Yields:
    A keras layer style scope.
  """
  global _KERAS_STYLE_SCOPE
  stack = _KERAS_STYLE_SCOPE
  _KERAS_STYLE_SCOPE = True
  try:
    yield
  finally:
    _KERAS_STYLE_SCOPE = stack


@keras_export(
    v1=['keras.__internal__.legacy.layers.experimental.set_keras_style'])
@tf_export(v1=['layers.experimental.set_keras_style'])
def set_keras_style():
  """Use Keras-style variable management.

  All tf.layers and tf RNN cells created after keras style has been enabled
  use Keras-style variable management.  Creating such layers with a
  scope= argument is disallowed, and reuse=True is disallowed.

  The purpose of this function is to allow users of existing layers to
  slowly transition to Keras layers API without breaking existing
  functionality.

  For more details, see the documentation for `keras_style_scope`.

  Note, once keras style has been set, it is set globally for the entire
  program and cannot be unset.

  Example:

  ```python
  set_keras_style()

  model_1 = RNNModel(name="model_1")
  model_2 = RNNModel(name="model_2")

  # model_1 and model_2 are guaranteed to create their own variables.
  output_1, next_state_1 = model_1(input, state)
  output_2, next_state_2 = model_2(input, state)

  assert len(model_1.weights) > 0
  assert len(model_2.weights) > 0
  assert(model_1.weights != model_2.weights)
  ```
  """
  global _KERAS_STYLE_SCOPE
  _KERAS_STYLE_SCOPE = True


def _is_in_keras_style_scope():
  global _KERAS_STYLE_SCOPE
  return _KERAS_STYLE_SCOPE


@keras_export(v1=['keras.__internal__.legacy.layers.Layer'])
@tf_export(v1=['layers.Layer'])
class Layer(base_layer.Layer):
  """Base layer class.

  It is considered legacy, and we recommend the use of `tf.keras.layers.Layer`
  instead.

  Args:
    trainable: Boolean, whether the layer's variables should be trainable.
    name: String name of the layer.
    dtype: Default dtype of the layer's weights (default of `None` means use the
      type of the first input).

  Read-only properties:
    name: The name of the layer (string).
    dtype: Default dtype of the layer's weights (default of `None` means use the
      type of the first input).
    trainable_variables: List of trainable variables.
    non_trainable_variables: List of non-trainable variables.
    variables: List of all variables of this layer, trainable and
      non-trainable.
    updates: List of update ops of this layer.
    losses: List of losses added by this layer.
    trainable_weights: List of variables to be included in backprop.
    non_trainable_weights: List of variables that should not be
      included in backprop.
    weights: The concatenation of the lists trainable_weights and
      non_trainable_weights (in this order).

  Mutable properties:
    trainable: Whether the layer should be trained (boolean).
    input_spec: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.
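
  A minimal usage sketch (the `MyDense` layer, its names, and shapes below
  are illustrative only, not part of this API):

  ```python
  class MyDense(tf.compat.v1.layers.Layer):

    def __init__(self, units, **kwargs):
      super(MyDense, self).__init__(**kwargs)
      self.units = units

    def build(self, input_shape):
      # Variables created here are scoped to this layer instance and are
      # reused on subsequent calls.
      self.kernel = self.add_weight(
          'kernel', shape=[int(input_shape[-1]), self.units])

    def call(self, inputs):
      return tf.matmul(inputs, self.kernel)

  x = tf.compat.v1.placeholder(tf.float32, shape=[None, 8])
  y = MyDense(4)(x)  # Variables are created on the first call.
  ```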
  """

  def __init__(self, trainable=True, name=None, dtype=None,
               **kwargs):
    # For backwards compatibility, legacy layers do not use `ResourceVariable`
    # by default.
    self._use_resource_variables = False
    scope = kwargs.pop('_scope', None)
    self._reuse = kwargs.pop('_reuse', None)

    # Avoid an incorrect lint error
    self._trainable_weights = []
    self.built = False

    if dtype is None:
      # Indicates to infer dtype from inputs. When the V2 dtype behavior is
      # enabled, Keras layers default their dtype to floatx instead, so we pass
      # an "_infer" policy to keep the old V1 behavior.
      dtype = policy.Policy('_infer')

    if 'autocast' not in kwargs:
      kwargs['autocast'] = False

    # Mark that legacy layers should not be instrumented as Keras usage
    self._disable_keras_instrumentation = True

    super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
                                **kwargs)

    if _is_in_keras_style_scope():
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      if self._reuse is not None:
        raise ValueError(
            'reuse argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(self._reuse))
      self._keras_style = True
    else:
      self._keras_style = False

    self._call_has_scope_arg = 'scope' in self._call_fn_args
    if scope:
      with vs.variable_scope(scope) as captured_scope:
        self._scope = captured_scope
    else:
      self._scope = None
    self._current_scope = None

  # We no longer track graph in tf.layers layers. This property is only kept to
  # maintain API backward compatibility.
  @property
  def graph(self):
    warnings.warn('`Layer.graph` is deprecated and '
                  'will be removed in a future version. '
                  'Please stop using this property because tf.layers layers no '
                  'longer track their graph.')
    if context.executing_eagerly():
      raise RuntimeError('Layer.graph not supported when executing eagerly.')
    return None

  def _init_set_name(self, name):
    # Determine layer name (non-unique).
    if isinstance(name, vs.VariableScope):
      base_name = name.name
      self._name, _ = self._make_unique_name()
    else:
      base_name = name
      self._name = name
    if not name:
      self._name, base_name = self._make_unique_name()
    self._base_name = base_name

  def _make_unique_name(self, name_uid_map=None, avoid_names=None,
                        namespace='', zero_based=False):
    base_name = base_layer.to_snake_case(self.__class__.__name__)
    name = backend.unique_object_name(
        base_name,
        name_uid_map=name_uid_map,
        avoid_names=avoid_names,
        namespace=namespace,
        zero_based=zero_based)
    return (name, base_name)

  @property
  def scope_name(self):
    if not self._scope:
      raise ValueError('No name available for layer scope because the layer "' +
                       self._name + '" has not been used yet. The scope name ' +
                       'is determined the first time the layer instance is ' +
                       'called. You must therefore call the layer before ' +
                       'querying `scope_name`.')
    return self._scope.name

  def add_loss(self, losses, inputs=None):
    previous_losses_length = len(self._losses)
    previous_callable_losses_length = len(self._callable_losses)
    super(Layer, self).add_loss(losses, inputs=inputs)
    if not context.executing_eagerly():
      # TODO(fchollet): deprecate collection below.
      new_losses = self._losses[previous_losses_length:]
      new_callable_losses = self._callable_losses[
          previous_callable_losses_length:]
      for regularizer in new_callable_losses:
        loss_tensor = regularizer()
        if loss_tensor is not None:
          new_losses.append(loss_tensor)
      _add_elements_to_collection(
          new_losses,
          ops.GraphKeys.REGULARIZATION_LOSSES)

  def _name_scope(self):  # pylint: disable=method-hidden
    """Determines op naming for the Layer."""
    if self._keras_style:
      return super(Layer, self)._name_scope()
    return self._current_scope.original_name_scope

  def _set_scope(self, scope=None):
    if self._scope is None:
      # If constructed with _scope=None, lazy setting of scope.
      if self._reuse:
        with vs.variable_scope(
            scope if scope is not None else self._base_name) as captured_scope:
          self._scope = captured_scope
      else:
        with vs.variable_scope(
            scope, default_name=self._base_name) as captured_scope:
          self._scope = captured_scope

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 constraint=None,
                 use_resource=None,
                 synchronization=vs.VariableSynchronization.AUTO,
                 aggregation=vs.VariableAggregation.NONE,
                 partitioner=None,
                 **kwargs):
    """Adds a new variable to the layer, or gets an existing one; returns it.

    Args:
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: initializer instance (callable).
      regularizer: regularizer instance (callable).
      trainable: whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable. `trainable` defaults to `True` unless
        `synchronization` is set to `ON_READ`.
      constraint: constraint instance (callable).
      use_resource: Whether to use `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: (optional) partitioner instance (callable).  If
        provided, when the requested variable is created it will be split
        into multiple partitions according to `partitioner`.  In this case,
        an instance of `PartitionedVariable` is returned.  Available
        partitioners include `tf.compat.v1.fixed_size_partitioner` and
        `tf.compat.v1.variable_axis_size_partitioner`.  For more details, see
        the documentation of `tf.compat.v1.get_variable` and the "Variable
        Partitioners and Sharding" section of the API guide.
      **kwargs: Additional keyword arguments.

    Returns:
      The created variable.  Usually either a `Variable` or `ResourceVariable`
      instance.  If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called with partitioned variable regularization and
        eager execution is enabled.
      ValueError: When trainable has been set to True with synchronization
        set as `ON_READ`.
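
    For example, a `build` method might create a trainable kernel and a
    non-trainable counter like this (a sketch; the variable names, shapes,
    and hyperparameters are illustrative only):

    ```python
    def build(self, input_shape):
      self.kernel = self.add_weight(
          'kernel',
          shape=[int(input_shape[-1]), self.units],
          initializer=tf.compat.v1.glorot_uniform_initializer(),
          regularizer=tf.keras.regularizers.l2(1e-4))
      self.counter = self.add_weight(
          'counter',
          shape=[],
          initializer=tf.compat.v1.zeros_initializer(),
          trainable=False)
    ```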
    """
    for kwarg in kwargs:
      if kwarg != 'experimental_autocast':
        raise TypeError('Unknown keyword argument:', kwarg)
    if self._keras_style:
      return super(Layer, self).add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          trainable=trainable and self.trainable,
          constraint=constraint,
          use_resource=use_resource,
          synchronization=vs.VariableSynchronization.AUTO,
          aggregation=vs.VariableAggregation.NONE,
          partitioner=partitioner,
          **kwargs)

    if synchronization == vs.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            'Synchronization value can be set to '
            'VariableSynchronization.ON_READ only for non-trainable variables. '
            'You have specified trainable=True and '
            'synchronization=VariableSynchronization.ON_READ.')
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    def _should_add_regularizer(variable, existing_variable_set):
      if base_layer_utils.is_split_variable(variable):
        for var in variable:
          if var in existing_variable_set:
            return False
        return True
      else:
        return variable not in existing_variable_set

    init_graph = None
    if not context.executing_eagerly():
      default_graph = ops.get_default_graph()
      if default_graph.building_function:
        with ops.init_scope():
          # Retrieve the variables from the graph into which variables
          # will be lifted; if initialization ops will be lifted into
          # the eager context, then there is nothing to retrieve, since variable
          # collections are not supported when eager execution is enabled.
          if not context.executing_eagerly():
            init_graph = ops.get_default_graph()
            existing_variables = set(tf_variables.global_variables())
      else:
        # Initialization ops will not be lifted out of the default graph.
        init_graph = default_graph
        existing_variables = set(tf_variables.global_variables())

    if dtype is None:
      dtype = self.dtype or dtypes.float32

    self._set_scope(None)
    reuse = self.built or self._reuse
    prev_len_trainable = len(self._trainable_weights)
    with vs.variable_scope(
        self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
      self._current_scope = scope
      with backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
        use_resource = (use_resource or
                        self._use_resource_variables or
                        scope.use_resource)
        if initializer is None:
          initializer = scope.initializer
        variable = super(Layer, self).add_weight(
            name,
            shape,
            dtype=dtypes.as_dtype(dtype),
            initializer=initializer,
            trainable=trainable and self.trainable,
            constraint=constraint,
            partitioner=partitioner,
            use_resource=use_resource,
            synchronization=synchronization,
            aggregation=aggregation,
            getter=vs.get_variable,
            **kwargs)

        if regularizer:
          if (ops.executing_eagerly_outside_functions()
              or _should_add_regularizer(variable, existing_variables)):
            self._handle_weight_regularization(name, variable, regularizer)
            var_store = vs._get_default_variable_store()  # pylint: disable=protected-access
            # When the shim to get variable scope working in TF2 is used,
            # we need to explicitly make the shim track the regularization
            # losses as the collections will not be accessible.
            if hasattr(var_store, 'add_regularizer'):
              var_store.add_regularizer(variable, regularizer)

        if init_graph is not None:
          # Handle edge case where a custom getter has overridden `trainable`.
          # There is one known occurrence of this, in unit test
          # testBasicRNNCellNotTrainable in
          # contrib.rnn.python.kernel_tests.core_rnn_cell_test
          with init_graph.as_default():
            trainable_variables = tf_variables.trainable_variables()
          if (trainable and self.trainable and
              variable not in trainable_variables):
            # A custom getter / variable scope overrode the trainable flag.
            extra_trainable_vars = self._trainable_weights[prev_len_trainable:]
            self._trainable_weights = self._trainable_weights[
                :prev_len_trainable]
            self._non_trainable_weights += extra_trainable_vars
    return variable

  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Args:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did indeed come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
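
    For example, in graph mode a single layer instance reuses its variables
    across calls (a sketch; `MyDense` is an illustrative legacy `Layer`
    subclass, not part of this module):

    ```python
    layer = MyDense(4)
    y1 = layer(x1)  # First call creates the layer's variables.
    y2 = layer(x2)  # Subsequent calls reuse the same variables.
    print(layer.scope_name)  # Available once the layer has been called.
    ```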
    """
    scope = kwargs.pop('scope', None)

    if self._keras_style:
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope  # pylint: disable=access-member-before-definition
      except AttributeError:
        scope_context_manager = None

      if scope_context_manager is None:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        scope_context_manager = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)

        # Do not cache variable scopes if Eager mode is enabled, because the
        # cached scope might be from a FuncGraph or Eager scope we are no
        # longer in.
        if not ops.executing_eagerly_outside_functions():
          self._always_reuse_variable_scope = scope_context_manager
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
      self._current_scope = scope

      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = variable_scope_shim.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs

  def __deepcopy__(self, memo):
    no_copy = set(['_graph', '_thread_local', '_metrics_lock'])
    shallow_copy = set(['_scope', '_always_reuse_variable_scope'])
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      if k in no_copy:
        setattr(result, k, v)
      elif k in shallow_copy:
        setattr(result, k, copy.copy(v))
      elif base_layer.is_tensor_or_tensor_list(v):
        setattr(result, k, v)
      else:
        setattr(result, k, copy.deepcopy(v, memo))
    return result

  def __setattr__(self, value, name):
    # By-pass the automatic dependency tracking performed by the parent Layer.
    super(trackable.Trackable, self).__setattr__(value, name)  # pylint: disable=bad-super-call

  @property
  def _is_legacy_layer(self):
    """Used by keras to check compatibility. This should not be overridden."""
    return True


def _add_elements_to_collection(elements, collection_list):
  if context.executing_eagerly():
    raise RuntimeError('Using collections from Layers not supported in Eager '
                       'mode. Tried to add %s to %s' % (elements,
                                                        collection_list))
  elements = nest.flatten(elements)
  collection_list = nest.flatten(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    collection_set = {id(e) for e in collection}
    for element in elements:
      if id(element) not in collection_set:
        collection.append(element)