# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
17"""Wrapper layers: layers that augment the functionality of another layer."""
18
19import copy
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.keras import backend
24from tensorflow.python.keras.engine.base_layer import Layer
25from tensorflow.python.keras.engine.input_spec import InputSpec
26from tensorflow.python.keras.layers.recurrent import _standardize_args
27from tensorflow.python.keras.utils import generic_utils
28from tensorflow.python.keras.utils import layer_utils
29from tensorflow.python.keras.utils import tf_inspect
30from tensorflow.python.keras.utils import tf_utils
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops.ragged import ragged_tensor
33from tensorflow.python.util import nest
34from tensorflow.python.util.tf_export import keras_export
35
36
@keras_export('keras.layers.Wrapper')
class Wrapper(Layer):
  """Abstract wrapper base class.

  Wrappers take another layer and augment it in various ways.
  Do not use this class as a layer; it is only an abstract base class.
  Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.

  Args:
    layer: The layer to be wrapped.
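
  Example:

  A minimal sketch of a custom wrapper subclass (illustrative only; the
  `VerboseWrapper` name and behavior below are hypothetical and not part of
  the Keras API):

  ```python
  class VerboseWrapper(Wrapper):
    # Delegates computation to the wrapped layer and prints its output shape.
    # `build` is inherited from `Wrapper` and builds `self.layer`.

    def call(self, inputs):
      outputs = self.layer(inputs)
      print('Wrapped layer output shape:', outputs.shape)
      return outputs

    def compute_output_shape(self, input_shape):
      return self.layer.compute_output_shape(input_shape)
  ```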
47  """
48
49  def __init__(self, layer, **kwargs):
50    assert isinstance(layer, Layer)
51    self.layer = layer
52    super(Wrapper, self).__init__(**kwargs)
53
54  def build(self, input_shape=None):
55    if not self.layer.built:
56      self.layer.build(input_shape)
57      self.layer.built = True
58    self.built = True
59
60  @property
61  def activity_regularizer(self):
62    if hasattr(self.layer, 'activity_regularizer'):
63      return self.layer.activity_regularizer
64    else:
65      return None
66
67  def get_config(self):
68    config = {'layer': generic_utils.serialize_keras_object(self.layer)}
69    base_config = super(Wrapper, self).get_config()
70    return dict(list(base_config.items()) + list(config.items()))
71
72  @classmethod
73  def from_config(cls, config, custom_objects=None):
74    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
75    # Avoid mutating the input dict
76    config = copy.deepcopy(config)
77    layer = deserialize_layer(
78        config.pop('layer'), custom_objects=custom_objects)
79    return cls(layer, **config)
80
81
@keras_export('keras.layers.TimeDistributed')
class TimeDistributed(Wrapper):
  """This wrapper allows applying a layer to every temporal slice of an input.

  Every input should be at least 3D, and the dimension of index one of the
  first input will be considered to be the temporal dimension.

  Consider a batch of 32 video samples, where each sample is a 128x128 RGB image
  with `channels_last` data format, across 10 timesteps.
  The batch input shape is `(32, 10, 128, 128, 3)`.

  You can then use `TimeDistributed` to apply the same `Conv2D` layer to each
  of the 10 timesteps, independently:

  >>> inputs = tf.keras.Input(shape=(10, 128, 128, 3))
  >>> conv_2d_layer = tf.keras.layers.Conv2D(64, (3, 3))
  >>> outputs = tf.keras.layers.TimeDistributed(conv_2d_layer)(inputs)
  >>> outputs.shape
  TensorShape([None, 10, 126, 126, 64])

  Because `TimeDistributed` applies the same instance of `Conv2D` to each of the
  timesteps, the same set of weights is used at each timestep.
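
  The wrapper is not limited to convolutional layers. For instance, applying
  a `Dense` layer to every timestep of a batch of feature-vector sequences
  (a short doctest-style sketch, analogous to the example above):

  >>> seq_inputs = tf.keras.Input(shape=(10, 16))
  >>> dense_layer = tf.keras.layers.Dense(8)
  >>> seq_outputs = tf.keras.layers.TimeDistributed(dense_layer)(seq_inputs)
  >>> seq_outputs.shape
  TensorShape([None, 10, 8])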

  Args:
    layer: a `tf.keras.layers.Layer` instance.

  Call arguments:
    inputs: Input tensor of shape (batch, time, ...) or nested tensors,
      each of which has shape (batch, time, ...).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the
      wrapped layer (only if the layer supports this argument).
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked. This argument is passed to the
      wrapped layer (only if the layer supports this argument).

  Raises:
    ValueError: If not initialized with a `tf.keras.layers.Layer` instance.
  """

  def __init__(self, layer, **kwargs):
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `TimeDistributed` layer with a '
          '`tf.keras.layers.Layer` instance. You passed: {input}'.format(
              input=layer))
    super(TimeDistributed, self).__init__(layer, **kwargs)
    self.supports_masking = True

    # It is safe to use the fast, reshape-based approach with all of our
    # built-in Layers.
    self._always_use_reshape = (
        layer_utils.is_builtin_layer(layer) and
        not getattr(layer, 'stateful', False))

  def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
    """Finds non-specific dimensions in the static shapes.

    The static shapes are replaced with the corresponding dynamic shapes of the
    tensor.

    Args:
      init_tuple: a tuple, the first part of the output shape.
      tensor: the tensor from which to get the (static and dynamic) shapes
        as the last part of the output shape.
      start_idx: int, which indicates the first dimension to take from
        the static shape of the tensor.
      int_shape: an alternative static shape to take as the last part
        of the output shape.

    Returns:
      The new shape with the first part from `init_tuple` and the last part
      from either `int_shape` (if provided) or `tensor.shape`, where every
      `None` is replaced by the corresponding dimension from `tf.shape(tensor)`.
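
      For example (illustrative): if `tensor` has static shape
      `(None, 10, 16)`, then `_get_shape_tuple((-1,), tensor, 1)` returns
      `(-1, 10, 16)`; if the static shape is `(None, None, 16)`, the unknown
      middle dimension is replaced by the corresponding entry of
      `tf.shape(tensor)`, giving `(-1, tf.shape(tensor)[1], 16)`.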
155    """
156    # replace all None in int_shape by backend.shape
157    if int_shape is None:
158      int_shape = backend.int_shape(tensor)[start_idx:]
159    if isinstance(int_shape, tensor_shape.TensorShape):
160      int_shape = int_shape.as_list()
161    if not any(not s for s in int_shape):
162      return init_tuple + tuple(int_shape)
163    shape = backend.shape(tensor)
164    int_shape = list(int_shape)
165    for i, s in enumerate(int_shape):
166      if not s:
167        int_shape[i] = shape[start_idx + i]
168    return init_tuple + tuple(int_shape)
169
170  def _remove_timesteps(self, dims):
171    dims = dims.as_list()
172    return tensor_shape.TensorShape([dims[0]] + dims[2:])
173
  def build(self, input_shape):
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
    input_dims = nest.flatten(
        nest.map_structure(lambda x: x.ndims, input_shape))
    if any(dim < 3 for dim in input_dims):
      raise ValueError(
          '`TimeDistributed` Layer should be passed an `input_shape` '
          'with at least 3 dimensions, received: ' + str(input_shape))
    # Don't enforce the batch or time dimension.
    self.input_spec = nest.map_structure(
        lambda x: InputSpec(shape=[None, None] + x.as_list()[2:]), input_shape)
    child_input_shape = nest.map_structure(self._remove_timesteps, input_shape)
    child_input_shape = tf_utils.convert_shapes(child_input_shape)
    super(TimeDistributed, self).build(tuple(child_input_shape))
    self.built = True

  def compute_output_shape(self, input_shape):
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

    child_input_shape = nest.map_structure(self._remove_timesteps, input_shape)
    child_output_shape = self.layer.compute_output_shape(child_input_shape)
    child_output_shape = tf_utils.convert_shapes(
        child_output_shape, to_tuples=False)
    timesteps = tf_utils.convert_shapes(input_shape)
    timesteps = nest.flatten(timesteps)[1]

    def insert_timesteps(dims):
      dims = dims.as_list()
      return tensor_shape.TensorShape([dims[0], timesteps] + dims[1:])

    return nest.map_structure(insert_timesteps, child_output_shape)

  def call(self, inputs, training=None, mask=None):
    kwargs = {}
    if generic_utils.has_arg(self.layer.call, 'training'):
      kwargs['training'] = training

    input_shape = nest.map_structure(
        lambda x: tensor_shape.TensorShape(backend.int_shape(x)), inputs)
    batch_size = tf_utils.convert_shapes(input_shape)
    batch_size = nest.flatten(batch_size)[0]
    if batch_size and not self._always_use_reshape:
      inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
      is_ragged_input = row_lengths is not None
      input_length = tf_utils.convert_shapes(input_shape)
      input_length = nest.flatten(input_length)[1]

      # The batch size matters, so use the RNN-based implementation.
      def step(x, _):
        output = self.layer(x, **kwargs)
        return output, []

      _, outputs, _ = backend.rnn(
          step,
          inputs,
          initial_states=[],
          input_length=row_lengths[0] if is_ragged_input else input_length,
          mask=mask,
          unroll=False)
      # pylint: disable=g-long-lambda
      y = nest.map_structure(
          lambda output: backend.maybe_convert_to_ragged(
              is_ragged_input, output, row_lengths), outputs)
    else:
      # No batch size specified, therefore the layer will be able
      # to process batches of any size.
      # We can go with reshape-based implementation for performance.
      is_ragged_input = nest.map_structure(
          lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs)
      is_ragged_input = nest.flatten(is_ragged_input)
      if all(is_ragged_input):
        input_values = nest.map_structure(lambda x: x.values, inputs)
        input_row_lengths = nest.map_structure(
            lambda x: x.nested_row_lengths()[0], inputs)
        y = self.layer(input_values, **kwargs)
        y = nest.map_structure(ragged_tensor.RaggedTensor.from_row_lengths, y,
                               input_row_lengths)
      elif any(is_ragged_input):
        raise ValueError('All inputs have to be either ragged or not, '
                         'but not mixed. You passed: {}'.format(inputs))
      else:
        input_length = tf_utils.convert_shapes(input_shape)
        input_length = nest.flatten(input_length)[1]
        if not input_length:
          input_length = nest.map_structure(lambda x: array_ops.shape(x)[1],
                                            inputs)
          input_length = generic_utils.to_list(nest.flatten(input_length))[0]

        inner_input_shape = nest.map_structure(
            lambda x: self._get_shape_tuple((-1,), x, 2), inputs)
        # Shape: (num_samples * timesteps, ...).
        inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs,
                                          inner_input_shape)
        # (num_samples * timesteps, ...)
        if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
          inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
          kwargs['mask'] = backend.reshape(mask, inner_mask_shape)

        y = self.layer(inputs, **kwargs)

        # Shape: (num_samples, timesteps, ...)
        output_shape = self.compute_output_shape(input_shape)
        # pylint: disable=g-long-lambda
        output_shape = nest.map_structure(
            lambda tensor, int_shape: self._get_shape_tuple(
                (-1, input_length), tensor, 1, int_shape[2:]), y, output_shape)
        y = nest.map_structure_up_to(y, array_ops.reshape, y, output_shape)
        if not context.executing_eagerly():
          # Set the static shape for the result since it might be lost during
          # array_ops reshape, e.g. some `None` dims in the result could be
          # inferred.
          nest.map_structure_up_to(
              y, lambda tensor, shape: tensor.set_shape(shape), y,
              self.compute_output_shape(input_shape))

    return y

  def compute_mask(self, inputs, mask=None):
    """Computes an output mask tensor for the TimeDistributed layer.

    The output mask is based on the inputs, the input mask, and the inner
    layer:
    * If a static batch size is specified, simply return the input `mask`
      unchanged. (An RNN-based implementation with more than one RNN input
      would be required but is not yet supported in tf.keras.)
    * Otherwise, call `compute_mask` of the inner layer at each time step:
      * If the output mask at each time step is not `None` (e.g. the inner
        layer is `Masking` or an RNN), concatenate the per-step masks and
        return the concatenation.
      * If the output mask at each time step is `None` but the input mask is
        not (e.g. the inner layer is `Dense`), reduce the input mask to 2
        dimensions and return it.
      * Otherwise (both the output mask and the input mask are `None`, i.e.
        `mask` is not used at all), return `None`.

    Args:
      inputs: Tensor with shape [batch size, timesteps, ...] indicating the
        input to TimeDistributed. If static shape information is available for
        "batch size", `mask` is returned unmodified.
      mask: Either None (indicating no masking) or a Tensor indicating the
        input mask for TimeDistributed. The shape can be static or dynamic.

    Returns:
      Either None (no masking), or a [batch size, timesteps, ...] Tensor with
      an output mask for the TimeDistributed layer. The shape beyond the first
      dimension comes from the computed output mask if it is available, from
      the input mask if the computed output mask is None, or from the computed
      output shape otherwise.
    """
    # We need to call `compute_mask` on the inner layer even when `mask` is
    # None, e.g. for a `Masking` layer or an `Embedding` layer with
    # `mask_zero=True`.
    input_shape = nest.map_structure(
        lambda x: tensor_shape.TensorShape(backend.int_shape(x)), inputs)
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
    batch_size = tf_utils.convert_shapes(input_shape)
    batch_size = nest.flatten(batch_size)[0]
    is_ragged_input = nest.map_structure(
        lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs)
    is_ragged_input = generic_utils.to_list(nest.flatten(is_ragged_input))
    if batch_size and not self._always_use_reshape or any(is_ragged_input):
      # If the batch size is static and the RNN-based implementation is used
      # (i.e. not the reshape approach), or the input is ragged, the mask is
      # returned unmodified.
      return mask
    inner_mask = mask
    if inner_mask is not None:
      inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
      inner_mask = backend.reshape(inner_mask, inner_mask_shape)
    inner_input_shape = nest.map_structure(
        lambda tensor: self._get_shape_tuple((-1,), tensor, 2), inputs)
    inner_inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs,
                                            inner_input_shape)
    output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
    if output_mask is None:
      if mask is None:
        return None
      # input_mask is not None and output_mask is None:
      # we should return a non-None mask.
      output_mask = mask
      for _ in range(2, len(backend.int_shape(mask))):
        output_mask = backend.any(output_mask, axis=-1)
    else:
      # output_mask is not None. We need to reshape it.
      input_length = tf_utils.convert_shapes(input_shape)
      input_length = nest.flatten(input_length)[1]
      if not input_length:
        input_length = nest.map_structure(lambda x: backend.shape(x)[1], inputs)
        input_length = nest.flatten(input_length)[0]
      output_mask_int_shape = backend.int_shape(output_mask)
      if output_mask_int_shape is None:
        # if the output_mask does not have a static shape,
        # its shape must be the same as mask's
        if mask is not None:
          output_mask_int_shape = backend.int_shape(mask)
        else:
          input_shape = generic_utils.to_list(nest.flatten(input_shape))[0]
          output_mask_int_shape = backend.compute_output_shape(input_shape)[:-1]
      output_mask_shape = self._get_shape_tuple(
          (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
      output_mask = backend.reshape(output_mask, output_mask_shape)
    return output_mask


@keras_export('keras.layers.Bidirectional')
class Bidirectional(Wrapper):
  """Bidirectional wrapper for RNNs.

  Args:
    layer: `keras.layers.RNN` instance, such as `keras.layers.LSTM` or
      `keras.layers.GRU`. It could also be a `keras.layers.Layer` instance
      that meets the following criteria:
      1. Be a sequence-processing layer (accepts 3D+ inputs).
      2. Have a `go_backwards`, `return_sequences` and `return_state`
        attribute (with the same semantics as for the `RNN` class).
      3. Have an `input_spec` attribute.
      4. Implement serialization via `get_config()` and `from_config()`.
      Note that the recommended way to create new RNN layers is to write a
      custom RNN cell and use it with `keras.layers.RNN`, instead of
      subclassing `keras.layers.Layer` directly.
      - When `return_sequences` is true, the output for masked timesteps
      will be zero regardless of the layer's original `zero_output_for_mask`
      value.
    merge_mode: Mode by which outputs of the forward and backward RNNs will be
      combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the
      outputs will not be combined, they will be returned as a list. Default
      value is 'concat'.
    backward_layer: Optional `keras.layers.RNN`, or `keras.layers.Layer`
      instance to be used to handle backwards input processing.
      If `backward_layer` is not provided, the layer instance passed as the
      `layer` argument will be used to generate the backward layer
      automatically.
      Note that the provided `backward_layer` layer should have properties
      matching those of the `layer` argument, in particular it should have the
      same values for `stateful`, `return_state`, `return_sequences`, etc.
      In addition, `backward_layer` and `layer` should have different
      `go_backwards` argument values.
      A `ValueError` will be raised if these requirements are not met.

  Call arguments:
    The call arguments for this layer are the same as those of the wrapped RNN
      layer.
    Beware that when passing the `initial_state` argument during the call of
    this layer, the first half of the elements in the `initial_state` list
    will be passed to the forward RNN call and the last half will be passed
    to the backward RNN call (see the last example below).

  Raises:
    ValueError:
      1. If `layer` or `backward_layer` is not a `Layer` instance.
      2. In case of invalid `merge_mode` argument.
      3. If `backward_layer` has mismatched properties compared to `layer`.

  Examples:

  ```python
  model = Sequential()
  model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 10)))
  model.add(Bidirectional(LSTM(10)))
  model.add(Dense(5))
  model.add(Activation('softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

  # With custom backward layer
  model = Sequential()
  forward_layer = LSTM(10, return_sequences=True)
  backward_layer = LSTM(10, activation='relu', return_sequences=True,
                        go_backwards=True)
  model.add(Bidirectional(forward_layer, backward_layer=backward_layer,
                          input_shape=(5, 10)))
  model.add(Dense(5))
  model.add(Activation('softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
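
  # Passing `initial_state` (an illustrative sketch using the functional API):
  # the first half of the state list is fed to the forward layer and the
  # second half to the backward layer.
  inputs = Input(shape=(5, 10))
  forward_h = Input(shape=(10,))
  forward_c = Input(shape=(10,))
  backward_h = Input(shape=(10,))
  backward_c = Input(shape=(10,))
  outputs = Bidirectional(LSTM(10))(
      inputs,
      initial_state=[forward_h, forward_c, backward_h, backward_c])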
  ```
  """

  def __init__(self,
               layer,
               merge_mode='concat',
               weights=None,
               backward_layer=None,
               **kwargs):
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `Bidirectional` layer with a '
          '`Layer` instance. You passed: {input}'.format(input=layer))
    if backward_layer is not None and not isinstance(backward_layer, Layer):
      raise ValueError('`backward_layer` needs to be a `Layer` instance. '
                       'You passed: {input}'.format(input=backward_layer))
    if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
      raise ValueError('Invalid merge mode. '
                       'Merge mode should be one of '
                       '{"sum", "mul", "ave", "concat", None}')
    # We don't want to track `layer` since we're already tracking the two copies
    # of it we actually run.
    self._setattr_tracking = False
    super(Bidirectional, self).__init__(layer, **kwargs)
    self._setattr_tracking = True

    # Recreate the forward layer from the original layer config, so that it will
    # not carry over any state from the layer.
    self.forward_layer = self._recreate_layer_from_config(layer)

    if backward_layer is None:
      self.backward_layer = self._recreate_layer_from_config(
          layer, go_backwards=True)
    else:
      self.backward_layer = backward_layer
      # Keep the custom backward layer config, so that we can save it later. The
      # layer's name might be updated below with prefix 'backward_', and we want
      # to preserve the original config.
      self._backward_layer_config = generic_utils.serialize_keras_object(
          backward_layer)

    self.forward_layer._name = 'forward_' + self.forward_layer.name
    self.backward_layer._name = 'backward_' + self.backward_layer.name

    self._verify_layer_config()

    def force_zero_output_for_mask(layer):
      # Force the zero_output_for_mask to be True if returning sequences.
      if getattr(layer, 'zero_output_for_mask', None) is not None:
        layer.zero_output_for_mask = layer.return_sequences

    force_zero_output_for_mask(self.forward_layer)
    force_zero_output_for_mask(self.backward_layer)

    self.merge_mode = merge_mode
    if weights:
      nw = len(weights)
      self.forward_layer.initial_weights = weights[:nw // 2]
      self.backward_layer.initial_weights = weights[nw // 2:]
    self.stateful = layer.stateful
    self.return_sequences = layer.return_sequences
    self.return_state = layer.return_state
    self.supports_masking = True
    self._trainable = True
    self._num_constants = 0
    self.input_spec = layer.input_spec

  def _verify_layer_config(self):
    """Ensures the forward and backward layers have valid common properties."""
    if self.forward_layer.go_backwards == self.backward_layer.go_backwards:
      raise ValueError('Forward layer and backward layer should have different '
                       '`go_backwards` values.')

    common_attributes = ('stateful', 'return_sequences', 'return_state')
    for a in common_attributes:
      forward_value = getattr(self.forward_layer, a)
      backward_value = getattr(self.backward_layer, a)
      if forward_value != backward_value:
        raise ValueError(
            'Forward layer and backward layer are expected to have the same '
            'value for attribute {attr}, got {forward} and {backward}'.format(
                attr=a, forward=forward_value, backward=backward_value))

  def _recreate_layer_from_config(self, layer, go_backwards=False):
    # When recreating the layer from its config, it is possible that the layer
    # is an RNN layer that contains custom cells. In this case we inspect the
    # layer and pass the custom cell class as part of the `custom_objects`
    # argument when calling `from_config`.
    # See https://github.com/tensorflow/tensorflow/issues/26581 for more detail.
    config = layer.get_config()
    if go_backwards:
      config['go_backwards'] = not config['go_backwards']
    if 'custom_objects' in tf_inspect.getfullargspec(
        layer.__class__.from_config).args:
      custom_objects = {}
      cell = getattr(layer, 'cell', None)
      if cell is not None:
        custom_objects[cell.__class__.__name__] = cell.__class__
        # For StackedRNNCells
        stacked_cells = getattr(cell, 'cells', [])
        for c in stacked_cells:
          custom_objects[c.__class__.__name__] = c.__class__
      return layer.__class__.from_config(config, custom_objects=custom_objects)
    else:
      return layer.__class__.from_config(config)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    output_shape = self.forward_layer.compute_output_shape(input_shape)
    if self.return_state:
      state_shape = tf_utils.convert_shapes(output_shape[1:], to_tuples=False)
      output_shape = tf_utils.convert_shapes(output_shape[0], to_tuples=False)
    else:
      output_shape = tf_utils.convert_shapes(output_shape, to_tuples=False)

    if self.merge_mode == 'concat':
      output_shape = output_shape.as_list()
      output_shape[-1] *= 2
      output_shape = tensor_shape.TensorShape(output_shape)
    elif self.merge_mode is None:
      output_shape = [output_shape, copy.copy(output_shape)]

    if self.return_state:
      if self.merge_mode is None:
        return output_shape + state_shape + copy.copy(state_shape)
      return [output_shape] + state_shape + copy.copy(state_shape)
    return output_shape

  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
    """`Bidirectional.__call__` implements the same API as the wrapped `RNN`."""
    inputs, initial_state, constants = _standardize_args(
        inputs, initial_state, constants, self._num_constants)

    if isinstance(inputs, list):
      if len(inputs) > 1:
        initial_state = inputs[1:]
      inputs = inputs[0]

    if initial_state is None and constants is None:
      return super(Bidirectional, self).__call__(inputs, **kwargs)

    # Applies the same workaround as in `RNN.__call__`
    additional_inputs = []
    additional_specs = []
    if initial_state is not None:
      # Check if `initial_state` can be split in half.
      num_states = len(initial_state)
      if num_states % 2 > 0:
        raise ValueError(
            'When passing `initial_state` to a Bidirectional RNN, '
            'the state should be a list containing the states of '
            'the underlying RNNs. '
            'Found: ' + str(initial_state))

      kwargs['initial_state'] = initial_state
      additional_inputs += initial_state
      state_specs = [InputSpec(shape=backend.int_shape(state))
                     for state in initial_state]
      self.forward_layer.state_spec = state_specs[:num_states // 2]
      self.backward_layer.state_spec = state_specs[num_states // 2:]
      additional_specs += state_specs
    if constants is not None:
      kwargs['constants'] = constants
      additional_inputs += constants
      constants_spec = [InputSpec(shape=backend.int_shape(constant))
                        for constant in constants]
      self.forward_layer.constants_spec = constants_spec
      self.backward_layer.constants_spec = constants_spec
      additional_specs += constants_spec

      self._num_constants = len(constants)
      self.forward_layer._num_constants = self._num_constants
      self.backward_layer._num_constants = self._num_constants

    is_keras_tensor = backend.is_keras_tensor(additional_inputs[0])
    for tensor in additional_inputs:
      if backend.is_keras_tensor(tensor) != is_keras_tensor:
        raise ValueError('The initial state of a Bidirectional'
                         ' layer cannot be specified with a mix of'
                         ' Keras tensors and non-Keras tensors'
                         ' (a "Keras tensor" is a tensor that was'
                         ' returned by a Keras layer, or by `Input`)')

    if is_keras_tensor:
      # Compute the full input spec, including state
      full_input = [inputs] + additional_inputs
      # The original input_spec is None since there could be a nested tensor
      # input. Update the input_spec to match the inputs.
      full_input_spec = [None for _ in range(len(nest.flatten(inputs)))
                        ] + additional_specs
      # Remove the state and constants from kwargs since their values are
      # passed with the input list instead.
      kwargs['initial_state'] = None
      kwargs['constants'] = None

      # Perform the call with temporarily replaced input_spec
      original_input_spec = self.input_spec
      self.input_spec = full_input_spec
      output = super(Bidirectional, self).__call__(full_input, **kwargs)
      self.input_spec = original_input_spec
      return output
    else:
      return super(Bidirectional, self).__call__(inputs, **kwargs)

  def call(self,
           inputs,
           training=None,
           mask=None,
           initial_state=None,
           constants=None):
    """`Bidirectional.call` implements the same API as the wrapped `RNN`."""
    kwargs = {}
    if generic_utils.has_arg(self.layer.call, 'training'):
      kwargs['training'] = training
    if generic_utils.has_arg(self.layer.call, 'mask'):
      kwargs['mask'] = mask
    if generic_utils.has_arg(self.layer.call, 'constants'):
      kwargs['constants'] = constants

    if generic_utils.has_arg(self.layer.call, 'initial_state'):
      if isinstance(inputs, list) and len(inputs) > 1:
        # initial_states are keras tensors, which means they are passed in
        # together with inputs as a list. The initial_states need to be split
        # into forward and backward sections and fed to the respective layers.
        forward_inputs = [inputs[0]]
        backward_inputs = [inputs[0]]
        pivot = (len(inputs) - self._num_constants) // 2 + 1
        # add forward initial state
        forward_inputs += inputs[1:pivot]
        if not self._num_constants:
          # add backward initial state
          backward_inputs += inputs[pivot:]
        else:
          # add backward initial state
          backward_inputs += inputs[pivot:-self._num_constants]
          # add constants for forward and backward layers
          forward_inputs += inputs[-self._num_constants:]
          backward_inputs += inputs[-self._num_constants:]
        forward_state, backward_state = None, None
        if 'constants' in kwargs:
          kwargs['constants'] = None
      elif initial_state is not None:
        # initial_states are not keras tensors, e.g. eager tensors created from
        # numpy arrays. They are only passed in via the `initial_state` kwarg,
        # and should be passed to the forward/backward layers the same way.
        forward_inputs, backward_inputs = inputs, inputs
        half = len(initial_state) // 2
        forward_state = initial_state[:half]
        backward_state = initial_state[half:]
      else:
        forward_inputs, backward_inputs = inputs, inputs
        forward_state, backward_state = None, None

      y = self.forward_layer(forward_inputs,
                             initial_state=forward_state, **kwargs)
      y_rev = self.backward_layer(backward_inputs,
                                  initial_state=backward_state, **kwargs)
    else:
      y = self.forward_layer(inputs, **kwargs)
      y_rev = self.backward_layer(inputs, **kwargs)

    if self.return_state:
      states = y[1:] + y_rev[1:]
      y = y[0]
      y_rev = y_rev[0]

    if self.return_sequences:
      time_dim = 0 if getattr(self.forward_layer, 'time_major', False) else 1
      y_rev = backend.reverse(y_rev, time_dim)
    if self.merge_mode == 'concat':
      output = backend.concatenate([y, y_rev])
    elif self.merge_mode == 'sum':
      output = y + y_rev
    elif self.merge_mode == 'ave':
      output = (y + y_rev) / 2
    elif self.merge_mode == 'mul':
      output = y * y_rev
    elif self.merge_mode is None:
      output = [y, y_rev]
    else:
      raise ValueError(
          'Unrecognized value for `merge_mode`: %s' % (self.merge_mode))

    if self.return_state:
      if self.merge_mode is None:
        return output + states
      return [output] + states
    return output

  def reset_states(self):
    self.forward_layer.reset_states()
    self.backward_layer.reset_states()

  def build(self, input_shape):
    with backend.name_scope(self.forward_layer.name):
      self.forward_layer.build(input_shape)
    with backend.name_scope(self.backward_layer.name):
      self.backward_layer.build(input_shape)
    self.built = True

  def compute_mask(self, inputs, mask):
    if isinstance(mask, list):
      mask = mask[0]
    if self.return_sequences:
      if not self.merge_mode:
        output_mask = [mask, mask]
      else:
        output_mask = mask
    else:
      output_mask = [None, None] if not self.merge_mode else None

    if self.return_state:
      states = self.forward_layer.states
      state_mask = [None for _ in states]
      if isinstance(output_mask, list):
        return output_mask + state_mask * 2
      return [output_mask] + state_mask * 2
    return output_mask

  @property
  def constraints(self):
    constraints = {}
    if hasattr(self.forward_layer, 'constraints'):
      constraints.update(self.forward_layer.constraints)
      constraints.update(self.backward_layer.constraints)
    return constraints

  def get_config(self):
    config = {'merge_mode': self.merge_mode}
    if self._num_constants:
      config['num_constants'] = self._num_constants

    if hasattr(self, '_backward_layer_config'):
      config['backward_layer'] = self._backward_layer_config
    base_config = super(Bidirectional, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    # Instead of updating the input, create a copy and use that.
    config = copy.deepcopy(config)
    num_constants = config.pop('num_constants', 0)
    # Handle forward layer instantiation (as the parent class would).
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    config['layer'] = deserialize_layer(
        config['layer'], custom_objects=custom_objects)
    # Handle (optional) backward layer instantiation.
    backward_layer_config = config.pop('backward_layer', None)
    if backward_layer_config is not None:
      backward_layer = deserialize_layer(
          backward_layer_config, custom_objects=custom_objects)
      config['backward_layer'] = backward_layer
    # Instantiate the wrapper, adjust it and return it.
    layer = cls(**config)
    layer._num_constants = num_constants
    return layer