# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Wrapper layers: layers that augment the functionality of another layer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.Wrapper')
class Wrapper(Layer):
42  """Abstract wrapper base class.
43
44  Wrappers take another layer and augment it in various ways.
45  Do not use this class as a layer, it is only an abstract base class.
46  Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
47
48  Args:
49    layer: The layer to be wrapped.
50  """
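  # A minimal usage sketch (illustrative, assuming the public `tf.keras` API):
  # `Wrapper` is used through its concrete subclasses, e.g.
  #
  #   wrapped = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(8))
  #   bidir = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(16))
  #
  # Both constructors route through `Wrapper.__init__` below, which stores the
  # wrapped layer as `self.layer`.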

  def __init__(self, layer, **kwargs):
    assert isinstance(layer, Layer)
    self.layer = layer
    super(Wrapper, self).__init__(**kwargs)

  def build(self, input_shape=None):
    if not self.layer.built:
      self.layer.build(input_shape)
      self.layer.built = True
    self.built = True

  @property
  def activity_regularizer(self):
    if hasattr(self.layer, 'activity_regularizer'):
      return self.layer.activity_regularizer
    else:
      return None

  def get_config(self):
    config = {'layer': generic_utils.serialize_keras_object(self.layer)}
    base_config = super(Wrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    # Avoid mutating the input dict
    config = copy.deepcopy(config)
    layer = deserialize_layer(
        config.pop('layer'), custom_objects=custom_objects)
    return cls(layer, **config)

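  # A minimal sketch of the serialization round trip provided by `get_config`
  # and `from_config` above (illustrative, assuming the public `tf.keras` API):
  #
  #   td = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(8))
  #   config = td.get_config()      # includes a serialized 'layer' entry
  #   clone = tf.keras.layers.TimeDistributed.from_config(config)
  #
  # `from_config` deserializes the wrapped layer first, then forwards the
  # remaining config entries to the constructor as keyword arguments.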

@keras_export('keras.layers.TimeDistributed')
class TimeDistributed(Wrapper):
  """This wrapper allows applying a layer to every temporal slice of an input.

  Every input should be at least 3D, and the dimension of index one of the
  first input will be considered to be the temporal dimension.

  Consider a batch of 32 video samples, where each sample is a 128x128 RGB
  image with `channels_last` data format, across 10 timesteps.
  The batch input shape is `(32, 10, 128, 128, 3)`.

  You can then use `TimeDistributed` to apply the same `Conv2D` layer to each
  of the 10 timesteps, independently:

  >>> inputs = tf.keras.Input(shape=(10, 128, 128, 3))
  >>> conv_2d_layer = tf.keras.layers.Conv2D(64, (3, 3))
  >>> outputs = tf.keras.layers.TimeDistributed(conv_2d_layer)(inputs)
  >>> outputs.shape
  TensorShape([None, 10, 126, 126, 64])

  Because `TimeDistributed` applies the same instance of `Conv2D` to each of
  the timesteps, the same set of weights is used at each timestep.

  Args:
    layer: a `tf.keras.layers.Layer` instance.

  Call arguments:
    inputs: Input tensor of shape (batch, time, ...), or nested tensors,
      each of which has shape (batch, time, ...).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the
      wrapped layer (only if the layer supports this argument).
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked. This argument is passed to the
      wrapped layer (only if the layer supports this argument).

  Raises:
    ValueError: If not initialized with a `tf.keras.layers.Layer` instance.
  """
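  # A further usage sketch (illustrative, assumed shapes): the wrapper works
  # the same way for non-convolutional layers, e.g.
  #
  #   inputs = tf.keras.Input(shape=(10, 16))
  #   outputs = tf.keras.layers.TimeDistributed(
  #       tf.keras.layers.Dense(8))(inputs)
  #   # outputs.shape == (None, 10, 8); one Dense kernel is shared across all
  #   # 10 timesteps.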

  def __init__(self, layer, **kwargs):
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `TimeDistributed` layer with a '
          '`tf.keras.layers.Layer` instance. You passed: {input}'.format(
              input=layer))
    super(TimeDistributed, self).__init__(layer, **kwargs)
    self.supports_masking = True

    # It is safe to use the fast, reshape-based approach with all of our
    # built-in Layers.
    self._always_use_reshape = (
        layer_utils.is_builtin_layer(layer) and
        not getattr(layer, 'stateful', False))

  def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
    """Finds non-specific dimensions in the static shapes.

    The static shapes are replaced with the corresponding dynamic shapes of the
    tensor.

    Args:
      init_tuple: a tuple, the first part of the output shape
      tensor: the tensor from which to get the (static and dynamic) shapes
        as the last part of the output shape
      start_idx: int, which indicates the first dimension to take from
        the static shape of the tensor
      int_shape: an alternative static shape to take as the last part
        of the output shape

    Returns:
      The new int_shape with the first part from init_tuple
      and the last part from either `int_shape` (if provided)
      or `tensor.shape`, where every `None` is replaced by
      the corresponding dimension from `tf.shape(tensor)`.
    """
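    # Illustrative sketch (hypothetical shapes): for a tensor `t` with static
    # shape (None, 10, 16), `self._get_shape_tuple((-1,), t, 2)` returns
    # (-1, 16) and `self._get_shape_tuple((-1,), t, 1)` returns (-1, 10, 16);
    # any `None` entries are replaced with the matching entries of
    # `tf.shape(t)`.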
    # Replace every None in int_shape with the corresponding dynamic dimension.
    if int_shape is None:
      int_shape = K.int_shape(tensor)[start_idx:]
    if isinstance(int_shape, tensor_shape.TensorShape):
      int_shape = int_shape.as_list()
    if not any(not s for s in int_shape):
      return init_tuple + tuple(int_shape)
    shape = K.shape(tensor)
    int_shape = list(int_shape)
    for i, s in enumerate(int_shape):
      if not s:
        int_shape[i] = shape[start_idx + i]
    return init_tuple + tuple(int_shape)

  def _remove_timesteps(self, dims):
    dims = dims.as_list()
    return tensor_shape.TensorShape([dims[0]] + dims[2:])

  def build(self, input_shape):
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
    input_dims = nest.flatten(
        nest.map_structure(lambda x: x.ndims, input_shape))
    if any(dim < 3 for dim in input_dims):
      raise ValueError(
          '`TimeDistributed` Layer should be passed an `input_shape` '
          'with at least 3 dimensions, received: ' + str(input_shape))
    # Don't enforce the batch or time dimension.
    self.input_spec = nest.map_structure(
        lambda x: InputSpec(shape=[None, None] + x.as_list()[2:]), input_shape)
    child_input_shape = nest.map_structure(self._remove_timesteps, input_shape)
    child_input_shape = tf_utils.convert_shapes(child_input_shape)
    super(TimeDistributed, self).build(tuple(child_input_shape))
    self.built = True

  def compute_output_shape(self, input_shape):
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

    child_input_shape = nest.map_structure(self._remove_timesteps, input_shape)
    child_output_shape = self.layer.compute_output_shape(child_input_shape)
    child_output_shape = tf_utils.convert_shapes(
        child_output_shape, to_tuples=False)
    timesteps = tf_utils.convert_shapes(input_shape)
    timesteps = nest.flatten(timesteps)[1]

    def insert_timesteps(dims):
      dims = dims.as_list()
      return tensor_shape.TensorShape([dims[0], timesteps] + dims[1:])

    return nest.map_structure(insert_timesteps, child_output_shape)

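  # Illustrative walk-through of `compute_output_shape` above, using the shapes
  # from the class docstring example: for input shape (None, 10, 128, 128, 3)
  # and a wrapped Conv2D(64, (3, 3)), the timestep dimension is removed to give
  # (None, 128, 128, 3), the child layer reports (None, 126, 126, 64), and the
  # timesteps are re-inserted to yield (None, 10, 126, 126, 64).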
  def call(self, inputs, training=None, mask=None):
    kwargs = {}
    if generic_utils.has_arg(self.layer.call, 'training'):
      kwargs['training'] = training

    input_shape = nest.map_structure(
        lambda x: tensor_shape.TensorShape(K.int_shape(x)), inputs)
    batch_size = tf_utils.convert_shapes(input_shape)
    batch_size = nest.flatten(batch_size)[0]
    if batch_size and not self._always_use_reshape:
      inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
      is_ragged_input = row_lengths is not None
      input_length = tf_utils.convert_shapes(input_shape)
      input_length = nest.flatten(input_length)[1]

      # The batch size matters, so use the rnn-based implementation.
      def step(x, _):
        output = self.layer(x, **kwargs)
        return output, []

      _, outputs, _ = K.rnn(
          step,
          inputs,
          initial_states=[],
          input_length=row_lengths[0] if is_ragged_input else input_length,
          mask=mask,
          unroll=False)
      # pylint: disable=g-long-lambda
      y = nest.map_structure(
          lambda output: K.maybe_convert_to_ragged(is_ragged_input, output,
                                                   row_lengths), outputs)
    else:
      # No batch size specified, therefore the layer will be able
      # to process batches of any size.
      # We can go with reshape-based implementation for performance.
      is_ragged_input = nest.map_structure(
          lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs)
      is_ragged_input = nest.flatten(is_ragged_input)
      if all(is_ragged_input):
        input_values = nest.map_structure(lambda x: x.values, inputs)
        input_row_lengths = nest.map_structure(
            lambda x: x.nested_row_lengths()[0], inputs)
        y = self.layer(input_values, **kwargs)
        y = nest.map_structure(ragged_tensor.RaggedTensor.from_row_lengths, y,
                               input_row_lengths)
      elif any(is_ragged_input):
        raise ValueError('All inputs have to be either ragged or not, '
                         'but not mixed. You passed: {}'.format(inputs))
      else:
        input_length = tf_utils.convert_shapes(input_shape)
        input_length = nest.flatten(input_length)[1]
        if not input_length:
          input_length = nest.map_structure(lambda x: array_ops.shape(x)[1],
                                            inputs)
          input_length = generic_utils.to_list(nest.flatten(input_length))[0]

        inner_input_shape = nest.map_structure(
            lambda x: self._get_shape_tuple((-1,), x, 2), inputs)
        # Shape: (num_samples * timesteps, ...).
        inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs,
                                          inner_input_shape)
        # (num_samples * timesteps, ...)
        if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
          inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
          kwargs['mask'] = K.reshape(mask, inner_mask_shape)

        y = self.layer(inputs, **kwargs)

        # Shape: (num_samples, timesteps, ...)
        output_shape = self.compute_output_shape(input_shape)
        # pylint: disable=g-long-lambda
        output_shape = nest.map_structure(
            lambda tensor, int_shape: self._get_shape_tuple(
                (-1, input_length), tensor, 1, int_shape[2:]), y, output_shape)
        y = nest.map_structure_up_to(y, array_ops.reshape, y, output_shape)
        if not context.executing_eagerly():
          # Set the static shape for the result since it might be lost during
          # array_ops reshape, e.g., some `None` dim in the result could be
          # inferred.
          nest.map_structure_up_to(
              y, lambda tensor, shape: tensor.set_shape(shape), y,
              self.compute_output_shape(input_shape))

    return y

  def compute_mask(self, inputs, mask=None):
    """Computes an output mask tensor for the TimeDistributed layer.

    This is based on the inputs, mask, and the inner layer.
    If batch size is specified:
    Simply return the input `mask`. (An rnn-based implementation with
    more than one rnn input is required but is not supported in tf.keras yet.)
    Otherwise, we call `compute_mask` of the inner layer at each time step.
    If the output mask at each time step is not `None`
    (e.g., the inner layer is Masking or an RNN):
    Concatenate all of them and return the concatenation.
    If the output mask at each time step is `None` and the input mask is not
    `None` (e.g., the inner layer is Dense):
    Reduce the input mask to 2 dimensions and return it.
    Otherwise (both the output mask and the input mask are `None`,
    e.g., `mask` is not used at all):
    Return `None`.

    Args:
      inputs: Tensor with shape [batch size, timesteps, ...] indicating the
        input to TimeDistributed. If static shape information is available for
        "batch size", `mask` is returned unmodified.
      mask: Either None (indicating no masking) or a Tensor indicating the
        input mask for TimeDistributed. The shape can be static or dynamic.

    Returns:
      Either None (no masking), or a [batch size, timesteps, ...] Tensor with
      an output mask for the TimeDistributed layer: the shape beyond the
      second dimension follows the input mask shape (if the computed output
      mask is None), the shape beyond the first dimension follows the mask
      shape (if the mask is not None), or the shape beyond the first dimension
      follows the computed output shape.
    """
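    # Illustrative sketch of the cases above (assumed shapes): for a
    # (batch, timesteps) input mask and an inner layer whose `compute_mask`
    # returns None (the Dense case), trailing mask dimensions are collapsed
    # with `K.any` and a (batch, timesteps) mask is returned. For an inner
    # Masking/Embedding layer, the inner mask of shape (batch * timesteps, ...)
    # is reshaped back to (batch, timesteps, ...).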
    # Cases that need to call layer.compute_mask when input_mask is None:
    # a Masking layer, or an Embedding layer with mask_zero=True.
    input_shape = nest.map_structure(
        lambda x: tensor_shape.TensorShape(K.int_shape(x)), inputs)
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
    batch_size = tf_utils.convert_shapes(input_shape)
    batch_size = nest.flatten(batch_size)[0]
    is_ragged_input = nest.map_structure(
        lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs)
    is_ragged_input = generic_utils.to_list(nest.flatten(is_ragged_input))
    if batch_size and not self._always_use_reshape or any(is_ragged_input):
      # If the batch size is specified and the reshape approach is not used,
      # or the input is ragged, we do not handle the mask explicitly and
      # simply pass it through.
      return mask
    inner_mask = mask
    if inner_mask is not None:
      inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
      inner_mask = K.reshape(inner_mask, inner_mask_shape)
    inner_input_shape = nest.map_structure(
        lambda tensor: self._get_shape_tuple((-1,), tensor, 2), inputs)
    inner_inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs,
                                            inner_input_shape)
    output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
    if output_mask is None:
      if mask is None:
        return None
      # input_mask is not None, and output_mask is None:
      # we should return a not-None mask.
      output_mask = mask
      for _ in range(2, len(K.int_shape(mask))):
        output_mask = K.any(output_mask, axis=-1)
    else:
      # output_mask is not None. We need to reshape it.
      input_length = tf_utils.convert_shapes(input_shape)
      input_length = nest.flatten(input_length)[1]
      if not input_length:
        input_length = nest.map_structure(lambda x: K.shape(x)[1], inputs)
        input_length = nest.flatten(input_length)[0]
      output_mask_int_shape = K.int_shape(output_mask)
      if output_mask_int_shape is None:
        # If the output_mask does not have a static shape,
        # its shape must be the same as the input mask's.
        if mask is not None:
          output_mask_int_shape = K.int_shape(mask)
        else:
          input_shape = generic_utils.to_list(nest.flatten(input_shape))[0]
          output_mask_int_shape = self.layer.compute_output_shape(
              input_shape)[:-1]
      output_mask_shape = self._get_shape_tuple(
          (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
      output_mask = K.reshape(output_mask, output_mask_shape)
    return output_mask


@keras_export('keras.layers.Bidirectional')
class Bidirectional(Wrapper):
  """Bidirectional wrapper for RNNs.

  Args:
    layer: `keras.layers.RNN` instance, such as `keras.layers.LSTM` or
      `keras.layers.GRU`. It could also be a `keras.layers.Layer` instance
      that meets the following criteria:
      1. Be a sequence-processing layer (accepts 3D+ inputs).
      2. Have a `go_backwards`, `return_sequences` and `return_state`
        attribute (with the same semantics as for the `RNN` class).
      3. Have an `input_spec` attribute.
      4. Implement serialization via `get_config()` and `from_config()`.
      Note that the recommended way to create new RNN layers is to write a
      custom RNN cell and use it with `keras.layers.RNN`, instead of
      subclassing `keras.layers.Layer` directly.
    merge_mode: Mode by which outputs of the forward and backward RNNs will be
      combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the
      outputs will not be combined; they will be returned as a list. Default
      value is 'concat'.
    backward_layer: Optional `keras.layers.RNN`, or `keras.layers.Layer`
      instance to be used to handle backwards input processing.
      If `backward_layer` is not provided, the layer instance passed as the
      `layer` argument will be used to generate the backward layer
      automatically.
      Note that the provided `backward_layer` should have properties
      matching those of the `layer` argument, in particular it should have the
      same values for `stateful`, `return_state`, `return_sequences`, etc.
      In addition, `backward_layer` and `layer` should have different
      `go_backwards` argument values.
      A `ValueError` will be raised if these requirements are not met.

  Call arguments:
    The call arguments for this layer are the same as those of the wrapped RNN
      layer.
    Beware that when passing the `initial_state` argument during the call of
    this layer, the first half in the list of elements in the `initial_state`
    list will be passed to the forward RNN call and the last half in the list
    of elements will be passed to the backward RNN call.

  Raises:
    ValueError:
      1. If `layer` or `backward_layer` is not a `Layer` instance.
      2. In case of invalid `merge_mode` argument.
      3. If `backward_layer` has mismatched properties compared to `layer`.

  Examples:

  ```python
  model = Sequential()
  model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 10)))
  model.add(Bidirectional(LSTM(10)))
  model.add(Dense(5))
  model.add(Activation('softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

  # With a custom backward layer
  model = Sequential()
  forward_layer = LSTM(10, return_sequences=True)
  backward_layer = LSTM(10, activation='relu', return_sequences=True,
                        go_backwards=True)
  model.add(Bidirectional(forward_layer, backward_layer=backward_layer,
                          input_shape=(5, 10)))
  model.add(Dense(5))
  model.add(Activation('softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
  ```
  """
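  # Illustrative sketch of the `initial_state` convention described above
  # (assumed sizes): for a wrapped LSTM(10), calling
  # `bidirectional_layer(inputs, initial_state=states)` expects `states` to be
  # a list of four tensors; states[:2] seeds the forward LSTM and states[2:]
  # seeds the backward LSTM.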

  def __init__(self,
               layer,
               merge_mode='concat',
               weights=None,
               backward_layer=None,
               **kwargs):
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `Bidirectional` layer with a '
          '`Layer` instance. You passed: {input}'.format(input=layer))
    if backward_layer is not None and not isinstance(backward_layer, Layer):
      raise ValueError('`backward_layer` needs to be a `Layer` instance. '
                       'You passed: {input}'.format(input=backward_layer))
    if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
      raise ValueError('Invalid merge mode. '
                       'Merge mode should be one of '
                       '{"sum", "mul", "ave", "concat", None}')
    # We don't want to track `layer` since we're already tracking the two copies
    # of it we actually run.
    self._setattr_tracking = False
    super(Bidirectional, self).__init__(layer, **kwargs)
    self._setattr_tracking = True

    # Recreate the forward layer from the original layer config, so that it
    # will not carry over any state from the layer.
    self.forward_layer = self._recreate_layer_from_config(layer)

    if backward_layer is None:
      self.backward_layer = self._recreate_layer_from_config(
          layer, go_backwards=True)
    else:
      self.backward_layer = backward_layer
      # Keep the custom backward layer config, so that we can save it later.
      # The layer's name might be updated below with the prefix 'backward_',
      # and we want to preserve the original config.
      self._backward_layer_config = generic_utils.serialize_keras_object(
          backward_layer)

    self.forward_layer._name = 'forward_' + self.forward_layer.name
    self.backward_layer._name = 'backward_' + self.backward_layer.name

    self._verify_layer_config()

    def force_zero_output_for_mask(layer):
      # Force zero_output_for_mask to be True if returning sequences.
      if getattr(layer, 'zero_output_for_mask', None) is not None:
        layer.zero_output_for_mask = layer.return_sequences

    force_zero_output_for_mask(self.forward_layer)
    force_zero_output_for_mask(self.backward_layer)

    self.merge_mode = merge_mode
    if weights:
      nw = len(weights)
      self.forward_layer.initial_weights = weights[:nw // 2]
      self.backward_layer.initial_weights = weights[nw // 2:]
    self.stateful = layer.stateful
    self.return_sequences = layer.return_sequences
    self.return_state = layer.return_state
    self.supports_masking = True
    self._trainable = True
    self._num_constants = 0
    self.input_spec = layer.input_spec

  def _verify_layer_config(self):
    """Ensure the forward and backward layers have valid common properties."""
    if self.forward_layer.go_backwards == self.backward_layer.go_backwards:
      raise ValueError('Forward layer and backward layer should have different '
                       '`go_backwards` value.')

    common_attributes = ('stateful', 'return_sequences', 'return_state')
    for a in common_attributes:
      forward_value = getattr(self.forward_layer, a)
      backward_value = getattr(self.backward_layer, a)
      if forward_value != backward_value:
        raise ValueError(
            'Forward layer and backward layer are expected to have the same '
            'value for attribute {attr}, got {forward} and {backward}'.format(
                attr=a, forward=forward_value, backward=backward_value))

  def _recreate_layer_from_config(self, layer, go_backwards=False):
    # When recreating the layer from its config, it is possible that the layer
    # is an RNN layer that contains custom cells. In this case we inspect the
    # layer and pass the custom cell class as part of the `custom_objects`
    # argument when calling `from_config`.
    # See https://github.com/tensorflow/tensorflow/issues/26581 for more detail.
    config = layer.get_config()
    if go_backwards:
      config['go_backwards'] = not config['go_backwards']
    if 'custom_objects' in tf_inspect.getfullargspec(
        layer.__class__.from_config).args:
      custom_objects = {}
      cell = getattr(layer, 'cell', None)
      if cell is not None:
        custom_objects[cell.__class__.__name__] = cell.__class__
        # For StackedRNNCells
        stacked_cells = getattr(cell, 'cells', [])
        for c in stacked_cells:
          custom_objects[c.__class__.__name__] = c.__class__
      return layer.__class__.from_config(config, custom_objects=custom_objects)
    else:
      return layer.__class__.from_config(config)

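  # Illustrative sketch (hypothetical cell class): if `layer` is
  # `RNN(MyCustomCell(32))`, the recreation above passes
  # {'MyCustomCell': MyCustomCell} as `custom_objects` so that
  # `from_config` can rebuild the custom cell without a global registration.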
  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    output_shape = self.forward_layer.compute_output_shape(input_shape)
    if self.return_state:
      state_shape = tf_utils.convert_shapes(output_shape[1:], to_tuples=False)
      output_shape = tf_utils.convert_shapes(output_shape[0], to_tuples=False)
    else:
      output_shape = tf_utils.convert_shapes(output_shape, to_tuples=False)

    if self.merge_mode == 'concat':
      output_shape = output_shape.as_list()
      output_shape[-1] *= 2
      output_shape = tensor_shape.TensorShape(output_shape)
    elif self.merge_mode is None:
      output_shape = [output_shape, copy.copy(output_shape)]

    if self.return_state:
      if self.merge_mode is None:
        return output_shape + state_shape + copy.copy(state_shape)
      return [output_shape] + state_shape + copy.copy(state_shape)
    return output_shape

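  # Illustrative sketch of the resulting structure (assumed sizes): for a
  # wrapped LSTM(10) with return_state=True and the default
  # merge_mode='concat', the layer returns
  # [output, forward_h, forward_c, backward_h, backward_c], where `output` has
  # feature size 20 and each state tensor has feature size 10; with
  # merge_mode=None the merged output is replaced by the separate forward and
  # backward outputs.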
  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
    """`Bidirectional.__call__` implements the same API as the wrapped `RNN`."""
    inputs, initial_state, constants = _standardize_args(
        inputs, initial_state, constants, self._num_constants)

    if isinstance(inputs, list):
      if len(inputs) > 1:
        initial_state = inputs[1:]
      inputs = inputs[0]

    if initial_state is None and constants is None:
      return super(Bidirectional, self).__call__(inputs, **kwargs)

    # Applies the same workaround as in `RNN.__call__`
    additional_inputs = []
    additional_specs = []
    if initial_state is not None:
      # Check if `initial_state` can be split in half
      num_states = len(initial_state)
      if num_states % 2 > 0:
        raise ValueError(
            'When passing `initial_state` to a Bidirectional RNN, '
            'the state should be a list containing the states of '
            'the underlying RNNs. '
            'Found: ' + str(initial_state))

      kwargs['initial_state'] = initial_state
      additional_inputs += initial_state
      state_specs = [InputSpec(shape=K.int_shape(state))
                     for state in initial_state]
      self.forward_layer.state_spec = state_specs[:num_states // 2]
      self.backward_layer.state_spec = state_specs[num_states // 2:]
      additional_specs += state_specs
    if constants is not None:
      kwargs['constants'] = constants
      additional_inputs += constants
      constants_spec = [InputSpec(shape=K.int_shape(constant))
                        for constant in constants]
      self.forward_layer.constants_spec = constants_spec
      self.backward_layer.constants_spec = constants_spec
      additional_specs += constants_spec

      self._num_constants = len(constants)
      self.forward_layer._num_constants = self._num_constants
      self.backward_layer._num_constants = self._num_constants

    is_keras_tensor = K.is_keras_tensor(additional_inputs[0])
    for tensor in additional_inputs:
      if K.is_keras_tensor(tensor) != is_keras_tensor:
        raise ValueError('The initial state of a Bidirectional'
                         ' layer cannot be specified with a mix of'
                         ' Keras tensors and non-Keras tensors'
                         ' (a "Keras tensor" is a tensor that was'
                         ' returned by a Keras layer, or by `Input`)')

    if is_keras_tensor:
      # Compute the full input spec, including state
      full_input = [inputs] + additional_inputs
      # The original input_spec is None since there could be a nested tensor
      # input. Update the input_spec to match the inputs.
      full_input_spec = [None for _ in range(len(nest.flatten(inputs)))
                        ] + additional_specs
      # Remove `initial_state` and `constants` from kwargs since their values
      # are passed with the input list.
      kwargs['initial_state'] = None
      kwargs['constants'] = None

      # Perform the call with temporarily replaced input_spec
      original_input_spec = self.input_spec
      self.input_spec = full_input_spec
      output = super(Bidirectional, self).__call__(full_input, **kwargs)
      self.input_spec = original_input_spec
      return output
    else:
      return super(Bidirectional, self).__call__(inputs, **kwargs)

  def call(self,
           inputs,
           training=None,
           mask=None,
           initial_state=None,
           constants=None):
    """`Bidirectional.call` implements the same API as the wrapped `RNN`."""
    kwargs = {}
    if generic_utils.has_arg(self.layer.call, 'training'):
      kwargs['training'] = training
    if generic_utils.has_arg(self.layer.call, 'mask'):
      kwargs['mask'] = mask
    if generic_utils.has_arg(self.layer.call, 'constants'):
      kwargs['constants'] = constants

    if generic_utils.has_arg(self.layer.call, 'initial_state'):
      if isinstance(inputs, list) and len(inputs) > 1:
        # initial_states are keras tensors, which means they are passed in
        # together with the inputs as a list. The initial_states need to be
        # split into forward and backward sections and fed to the respective
        # layers.
        forward_inputs = [inputs[0]]
        backward_inputs = [inputs[0]]
        pivot = (len(inputs) - self._num_constants) // 2 + 1
        # add forward initial state
        forward_inputs += inputs[1:pivot]
        if not self._num_constants:
          # add backward initial state
          backward_inputs += inputs[pivot:]
        else:
          # add backward initial state
          backward_inputs += inputs[pivot:-self._num_constants]
          # add constants for forward and backward layers
          forward_inputs += inputs[-self._num_constants:]
          backward_inputs += inputs[-self._num_constants:]
        forward_state, backward_state = None, None
        if 'constants' in kwargs:
          kwargs['constants'] = None
      elif initial_state is not None:
        # initial_states are not keras tensors, e.g. eager tensors created from
        # numpy arrays. They are only passed in via the kwarg initial_state,
        # and should be passed to the forward/backward layers via the kwarg
        # initial_state as well.
        forward_inputs, backward_inputs = inputs, inputs
        half = len(initial_state) // 2
        forward_state = initial_state[:half]
        backward_state = initial_state[half:]
      else:
        forward_inputs, backward_inputs = inputs, inputs
        forward_state, backward_state = None, None

      y = self.forward_layer(forward_inputs,
                             initial_state=forward_state, **kwargs)
      y_rev = self.backward_layer(backward_inputs,
                                  initial_state=backward_state, **kwargs)
    else:
      y = self.forward_layer(inputs, **kwargs)
      y_rev = self.backward_layer(inputs, **kwargs)

    if self.return_state:
      states = y[1:] + y_rev[1:]
      y = y[0]
      y_rev = y_rev[0]

    if self.return_sequences:
      time_dim = 0 if getattr(self.forward_layer, 'time_major', False) else 1
      y_rev = K.reverse(y_rev, time_dim)
    if self.merge_mode == 'concat':
      output = K.concatenate([y, y_rev])
    elif self.merge_mode == 'sum':
      output = y + y_rev
    elif self.merge_mode == 'ave':
      output = (y + y_rev) / 2
    elif self.merge_mode == 'mul':
      output = y * y_rev
    elif self.merge_mode is None:
      output = [y, y_rev]
    else:
      raise ValueError(
          'Unrecognized value for `merge_mode`: %s' % (self.merge_mode))

    if self.return_state:
      if self.merge_mode is None:
        return output + states
      return [output] + states
    return output

  def reset_states(self):
    self.forward_layer.reset_states()
    self.backward_layer.reset_states()

  def build(self, input_shape):
    with K.name_scope(self.forward_layer.name):
      self.forward_layer.build(input_shape)
    with K.name_scope(self.backward_layer.name):
      self.backward_layer.build(input_shape)
    self.built = True

  def compute_mask(self, inputs, mask):
    if isinstance(mask, list):
      mask = mask[0]
    if self.return_sequences:
      if not self.merge_mode:
        output_mask = [mask, mask]
      else:
        output_mask = mask
    else:
      output_mask = [None, None] if not self.merge_mode else None

    if self.return_state:
      states = self.forward_layer.states
      state_mask = [None for _ in states]
      if isinstance(output_mask, list):
        return output_mask + state_mask * 2
      return [output_mask] + state_mask * 2
    return output_mask

  @property
  def constraints(self):
    constraints = {}
    if hasattr(self.forward_layer, 'constraints'):
      constraints.update(self.forward_layer.constraints)
      constraints.update(self.backward_layer.constraints)
    return constraints

  def get_config(self):
    config = {'merge_mode': self.merge_mode}
    if self._num_constants:
      config['num_constants'] = self._num_constants

    if hasattr(self, '_backward_layer_config'):
      config['backward_layer'] = self._backward_layer_config
    base_config = super(Bidirectional, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    # Instead of updating the input, create a copy and use that.
    config = copy.deepcopy(config)
    num_constants = config.pop('num_constants', 0)
    # Handle forward layer instantiation (as the parent class would).
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    config['layer'] = deserialize_layer(
        config['layer'], custom_objects=custom_objects)
    # Handle (optional) backward layer instantiation.
    backward_layer_config = config.pop('backward_layer', None)
    if backward_layer_config is not None:
      backward_layer = deserialize_layer(
          backward_layer_config, custom_objects=custom_objects)
      config['backward_layer'] = backward_layer
    # Instantiate the wrapper, adjust it and return it.
    layer = cls(**config)
    layer._num_constants = num_constants
    return layer
