# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Contains the base Layer class, from which all layers inherit."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import copy
21
22from tensorflow.python.eager import context
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.keras.engine import base_layer
26from tensorflow.python.keras.engine import base_layer_utils
27from tensorflow.python.ops import variable_scope as vs
28from tensorflow.python.ops import variables as tf_variables
29from tensorflow.python.training.tracking import base as trackable
30from tensorflow.python.util import function_utils
31from tensorflow.python.util import nest
32from tensorflow.python.util import tf_contextlib
33from tensorflow.python.util.tf_export import tf_export
34
35# Avoid breaking users who directly import this symbol from this file.
36# TODO(fchollet): remove this.
37InputSpec = base_layer.InputSpec  # pylint: disable=invalid-name
38
39_KERAS_STYLE_SCOPE = False
40
41
@tf_export(v1=['layers.experimental.keras_style_scope'])
@tf_contextlib.contextmanager
def keras_style_scope():
  """Use Keras-style variable management.

  All tf.layers and tf RNN cells created in this scope use Keras-style
  variable management.  Creating such layers with a scope= argument is
  disallowed, and reuse=True is disallowed.

  The purpose of this scope is to allow users of existing layers to
  slowly transition to a Keras layers API without breaking existing
  functionality.

  One example of this is when using TensorFlow's RNN classes with Keras
  Models or Networks.  Because Keras models do not properly set variable
  scopes, users of RNNs may either accidentally share scopes between two
  different models, or get errors about variables that already exist.

  Example:

  ```python
  class RNNModel(tf.keras.Model):

    def __init__(self, name):
      super(RNNModel, self).__init__(name=name)
      self.rnn = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.LSTMCell(64) for _ in range(2)])

    def call(self, input, state):
      return self.rnn(input, state)

  model_1 = RNNModel("model_1")
  model_2 = RNNModel("model_2")

  # OK
  output_1, next_state_1 = model_1(input, state)
  # Raises an error about trying to create an already existing variable.
  output_2, next_state_2 = model_2(input, state)
  ```

  The solution is to wrap the model construction and execution in a keras-style
  scope:

  ```python
  with keras_style_scope():
    model_1 = RNNModel("model_1")
    model_2 = RNNModel("model_2")

    # model_1 and model_2 are guaranteed to create their own variables.
    output_1, next_state_1 = model_1(input, state)
    output_2, next_state_2 = model_2(input, state)

    assert len(model_1.weights) > 0
    assert len(model_2.weights) > 0
    assert model_1.weights != model_2.weights
  ```

  Yields:
    A keras layer style scope.
  """
  global _KERAS_STYLE_SCOPE
  stack = _KERAS_STYLE_SCOPE
  _KERAS_STYLE_SCOPE = True
  try:
    yield
  finally:
    _KERAS_STYLE_SCOPE = stack


@tf_export(v1=['layers.experimental.set_keras_style'])
def set_keras_style():
  """Use Keras-style variable management.

  All tf.layers and tf RNN cells created after keras style has been enabled
  use Keras-style variable management.  Creating such layers with a
  scope= argument is disallowed, and reuse=True is disallowed.

  The purpose of this function is to allow users of existing layers to
  slowly transition to the Keras layers API without breaking existing
  functionality.

  For more details, see the documentation for `keras_style_scope`.

  Note, once keras style has been set, it is set globally for the entire
  program and cannot be unset.

  Example:

  ```python
  set_keras_style()

  model_1 = RNNModel(name="model_1")
  model_2 = RNNModel(name="model_2")

  # model_1 and model_2 are guaranteed to create their own variables.
  output_1, next_state_1 = model_1(input, state)
  output_2, next_state_2 = model_2(input, state)

  assert len(model_1.weights) > 0
  assert len(model_2.weights) > 0
  assert model_1.weights != model_2.weights
  ```
  """
  global _KERAS_STYLE_SCOPE
  _KERAS_STYLE_SCOPE = True


def _is_in_keras_style_scope():
  global _KERAS_STYLE_SCOPE
  return _KERAS_STYLE_SCOPE


@tf_export(v1=['layers.Layer'])
class Layer(base_layer.Layer):
  """Base layer class.

  It is considered legacy, and we recommend the use of `tf.keras.layers.Layer`
  instead.

  Arguments:
    trainable: Boolean, whether the layer's variables should be trainable.
    name: String name of the layer.
    dtype: Default dtype of the layer's weights (default of `None` means use the
      type of the first input).

  Read-only properties:
    name: The name of the layer (string).
    dtype: Default dtype of the layer's weights (default of `None` means use the
      type of the first input).
    trainable_variables: List of trainable variables.
    non_trainable_variables: List of non-trainable variables.
    variables: List of all variables of this layer, trainable and
      non-trainable.
    updates: List of update ops of this layer.
    losses: List of losses added by this layer.
    trainable_weights: List of variables to be included in backprop.
    non_trainable_weights: List of variables that should not be
      included in backprop.
    weights: The concatenation of the lists trainable_weights and
      non_trainable_weights (in this order).

  Mutable properties:
    trainable: Whether the layer should be trained (boolean).
    input_spec: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.
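
  A minimal usage sketch (illustrative only; `MyDense`, its shapes, and the
  graph-mode placeholder are assumptions, not part of this API):

  ```python
  class MyDense(Layer):

    def build(self, input_shape):
      # Variables created here land in the layer's variable scope.
      self.kernel = self.add_weight('kernel', [int(input_shape[-1]), 4])
      super(MyDense, self).build(input_shape)

    def call(self, inputs):
      return tf.matmul(inputs, self.kernel)

  # The layer is built (and its variables created) on the first call.
  layer = MyDense(name='my_dense')
  outputs = layer(tf.placeholder(tf.float32, [None, 8]))
  ```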
  """

  def __init__(self, trainable=True, name=None, dtype=None,
               **kwargs):
    # For backwards compatibility, legacy layers do not use `ResourceVariable`
    # by default.
    self._use_resource_variables = False
    scope = kwargs.pop('_scope', None)
    self._reuse = kwargs.pop('_reuse', None)

    # Avoid an incorrect lint error
    self._trainable_weights = []
    self.built = False

    super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
                                **kwargs)

    if _is_in_keras_style_scope():
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      if self._reuse is not None:
        raise ValueError(
            'reuse argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(self._reuse))
      self._keras_style = True
    else:
      self._keras_style = False

    self._graph = None
    self._call_has_scope_arg = 'scope' in self._call_fn_args
    if scope:
      with vs.variable_scope(scope) as captured_scope:
        self._scope = captured_scope
    else:
      self._scope = None
    self._current_scope = None

  @property
  def graph(self):
    if context.executing_eagerly():
      raise RuntimeError('Layer.graph not supported when executing eagerly.')
    return self._graph

  def _init_set_name(self, name):
    # Determine layer name (non-unique).
    if isinstance(name, vs.VariableScope):
      base_name = name.name
      self._name, _ = self._make_unique_name()
    else:
      base_name = name
      self._name = name
    if not name:
      self._name, base_name = self._make_unique_name()
    self._base_name = base_name

  def _make_unique_name(self, name_uid_map=None, avoid_names=None,
                        namespace='', zero_based=False):
    base_name = base_layer.to_snake_case(self.__class__.__name__)
    name = base_layer_utils.unique_layer_name(base_name,
                                              name_uid_map=name_uid_map,
                                              avoid_names=avoid_names,
                                              namespace=namespace,
                                              zero_based=zero_based)
    return (name, base_name)

  @property
  def scope_name(self):
    if not self._scope:
      raise ValueError('No name available for layer scope because the layer "' +
                       self._name + '" has not been used yet. The scope name ' +
                       'is determined the first time the layer instance is ' +
                       'called. You must therefore call the layer before ' +
                       'querying `scope_name`.')
    return self._scope.name

  def add_loss(self, losses, inputs=None):
    previous_losses_length = len(self._losses)
    previous_callable_losses_length = len(self._callable_losses)
    super(Layer, self).add_loss(losses, inputs=inputs)
    if not context.executing_eagerly():
      # TODO(fchollet): deprecate collection below.
      new_losses = self._losses[previous_losses_length:]
      new_callable_losses = self._callable_losses[
          previous_callable_losses_length:]
      for regularizer in new_callable_losses:
        loss_tensor = regularizer()
        if loss_tensor is not None:
          new_losses.append(loss_tensor)
      _add_elements_to_collection(
          new_losses,
          ops.GraphKeys.REGULARIZATION_LOSSES)

  def _name_scope(self):
    """Determines op naming for the Layer."""
    if self._keras_style:
      return super(Layer, self)._name_scope()
    return self._current_scope.original_name_scope

  def _set_scope(self, scope=None):
    if self._scope is None:
      # If constructed with _scope=None, lazy setting of scope.
      if self._reuse:
        with vs.variable_scope(
            scope if scope is not None else self._base_name) as captured_scope:
          self._scope = captured_scope
      else:
        with vs.variable_scope(
            scope, default_name=self._base_name) as captured_scope:
          self._scope = captured_scope

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 constraint=None,
                 use_resource=None,
                 synchronization=vs.VariableSynchronization.AUTO,
                 aggregation=vs.VariableAggregation.NONE,
                 partitioner=None,
                 **kwargs):
    """Adds a new variable to the layer, or gets an existing one; returns it.

    Arguments:
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: initializer instance (callable).
      regularizer: regularizer instance (callable).
      trainable: whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable. `trainable` defaults to `True` unless
        `synchronization` is set to `ON_READ`.
      constraint: constraint instance (callable).
      use_resource: Whether to use `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: (optional) partitioner instance (callable).  If
        provided, when the requested variable is created it will be split
        into multiple partitions according to `partitioner`.  In this case,
        an instance of `PartitionedVariable` is returned.  Available
        partitioners include `tf.fixed_size_partitioner` and
        `tf.variable_axis_size_partitioner`.  For more details, see the
        documentation of `tf.get_variable` and the "Variable Partitioners
        and Sharding" section of the API guide.
      **kwargs: Additional keyword arguments.

    Returns:
      The created variable.  Usually either a `Variable` or `ResourceVariable`
      instance.  If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called with partitioned variable regularization and
        eager execution is enabled.
      ValueError: When trainable has been set to True with synchronization
        set as `ON_READ`.
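
    For example, a minimal sketch (illustrative only; the `build` method and
    the shape below are assumptions, not part of this API):

    ```python
    def build(self, input_shape):
      # Creates (or, under scope reuse, retrieves) a trainable variable
      # in the layer's variable scope.
      self.kernel = self.add_weight(
          name='kernel',
          shape=[int(input_shape[-1]), 4],
          initializer=tf.glorot_uniform_initializer())
    ```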
    """
    for kwarg in kwargs:
      if kwarg != 'experimental_autocast':
        raise TypeError('Unknown keyword argument:', kwarg)
    if self._keras_style:
      return super(Layer, self).add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          trainable=trainable,
          constraint=constraint,
          use_resource=use_resource,
          synchronization=vs.VariableSynchronization.AUTO,
          aggregation=vs.VariableAggregation.NONE,
          partitioner=partitioner,
          **kwargs)

    if synchronization == vs.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            'Synchronization value can be set to '
            'VariableSynchronization.ON_READ only for non-trainable variables. '
            'You have specified trainable=True and '
            'synchronization=VariableSynchronization.ON_READ.')
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    def _should_add_regularizer(variable, existing_variable_set):
      if isinstance(variable, tf_variables.PartitionedVariable):
        for var in variable:
          if var in existing_variable_set:
            return False
        return True
      else:
        return variable not in existing_variable_set

    init_graph = None
    if not context.executing_eagerly():
      default_graph = ops.get_default_graph()
      if default_graph.building_function:
        with ops.init_scope():
          # Retrieve the variables from the graph into which variables
          # will be lifted; if initialization ops will be lifted into
          # the eager context, then there is nothing to retrieve, since variable
          # collections are not supported when eager execution is enabled.
          if not context.executing_eagerly():
            init_graph = ops.get_default_graph()
            existing_variables = set(tf_variables.global_variables())
      else:
        # Initialization ops will not be lifted out of the default graph.
        init_graph = default_graph
        existing_variables = set(tf_variables.global_variables())

    if dtype is None:
      dtype = self.dtype or dtypes.float32

    self._set_scope(None)
    reuse = self.built or self._reuse
    prev_len_trainable = len(self._trainable_weights)
    with vs.variable_scope(
        self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
      self._current_scope = scope
      with ops.name_scope(self._name_scope()):
        use_resource = (use_resource or
                        self._use_resource_variables or
                        scope.use_resource)
        if initializer is None:
          initializer = scope.initializer
        variable = super(Layer, self).add_weight(
            name,
            shape,
            dtype=dtypes.as_dtype(dtype),
            initializer=initializer,
            trainable=trainable,
            constraint=constraint,
            partitioner=partitioner,
            use_resource=use_resource,
            synchronization=synchronization,
            aggregation=aggregation,
            getter=vs.get_variable,
            **kwargs)

        if regularizer:
          if (ops.executing_eagerly_outside_functions()
              or _should_add_regularizer(variable, existing_variables)):
            self._handle_weight_regularization(name, variable, regularizer)

        if init_graph is not None:
          # Handle edge case where a custom getter has overridden `trainable`.
          # There is one known occurrence of this, in unit test
          # testBasicRNNCellNotTrainable in
          # contrib.rnn.python.kernel_tests.core_rnn_cell_test
          with init_graph.as_default():
            trainable_variables = tf_variables.trainable_variables()
          if (trainable and self.trainable and
              variable not in trainable_variables):
            # A custom getter / variable scope overrode the trainable flag.
            extra_trainable_vars = self._trainable_weights[prev_len_trainable:]
            self._trainable_weights = self._trainable_weights[
                :prev_len_trainable]
            self._non_trainable_weights += extra_trainable_vars
    return variable

  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
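
    For example, a minimal sketch (illustrative only; `MyLayer` and `inputs`
    are assumptions, not part of this API):

    ```python
    layer = MyLayer()
    # The reserved `scope` kwarg selects the variable scope in which the
    # layer creates its variables; a plain scope name string is accepted.
    outputs = layer(inputs, scope='my_scope')
    ```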
    """
    scope = kwargs.pop('scope', None)

    if self._keras_style:
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if not context.executing_eagerly():
      try:
        # Set layer's "graph" at build time
        self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                 graph=self._graph)
      except ValueError as e:
        raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope
      except AttributeError:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        self._always_reuse_variable_scope = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)
        scope_context_manager = self._always_reuse_variable_scope
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
      self._current_scope = scope

      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = function_utils.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs

  def __deepcopy__(self, memo):
    # Graph and scope objects must not be deep-copied; they are shared (or
    # shallow-copied) so the copy stays attached to the same graph state.
    no_copy = set(['_graph'])
    shallow_copy = set(['_scope', '_always_reuse_variable_scope'])
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      if k in no_copy:
        setattr(result, k, v)
      elif k in shallow_copy:
        setattr(result, k, copy.copy(v))
      elif base_layer.is_tensor_or_tensor_list(v):
        setattr(result, k, v)
      else:
        setattr(result, k, copy.deepcopy(v, memo))
    return result

  def __setattr__(self, name, value):
    # By-pass the automatic dependency tracking performed by the parent Layer.
    super(trackable.Trackable, self).__setattr__(name, value)


def _add_elements_to_collection(elements, collection_list):
  if context.executing_eagerly():
    raise RuntimeError('Using collections from Layers not supported in Eager '
                       'mode. Tried to add %s to %s' % (elements,
                                                        collection_list))
  elements = nest.flatten(elements)
  collection_list = nest.flatten(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    collection_set = set(collection)
    for element in elements:
      if element not in collection_set:
        collection.append(element)