1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Wrapper layers: layers that augment the functionality of another layer. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import copy 23 24from tensorflow.python.eager import context 25from tensorflow.python.framework import tensor_shape 26from tensorflow.python.keras import backend as K 27from tensorflow.python.keras.engine.base_layer import Layer 28from tensorflow.python.keras.engine.input_spec import InputSpec 29from tensorflow.python.keras.layers.recurrent import _standardize_args 30from tensorflow.python.keras.utils import generic_utils 31from tensorflow.python.keras.utils import layer_utils 32from tensorflow.python.keras.utils import tf_inspect 33from tensorflow.python.keras.utils import tf_utils 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops.ragged import ragged_tensor 36from tensorflow.python.util import nest 37from tensorflow.python.util.tf_export import keras_export 38 39 40@keras_export('keras.layers.Wrapper') 41class Wrapper(Layer): 42 """Abstract wrapper base class. 43 44 Wrappers take another layer and augment it in various ways. 45 Do not use this class as a layer, it is only an abstract base class. 46 Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers. 47 48 Args: 49 layer: The layer to be wrapped. 50 """ 51 52 def __init__(self, layer, **kwargs): 53 assert isinstance(layer, Layer) 54 self.layer = layer 55 super(Wrapper, self).__init__(**kwargs) 56 57 def build(self, input_shape=None): 58 if not self.layer.built: 59 self.layer.build(input_shape) 60 self.layer.built = True 61 self.built = True 62 63 @property 64 def activity_regularizer(self): 65 if hasattr(self.layer, 'activity_regularizer'): 66 return self.layer.activity_regularizer 67 else: 68 return None 69 70 def get_config(self): 71 config = {'layer': generic_utils.serialize_keras_object(self.layer)} 72 base_config = super(Wrapper, self).get_config() 73 return dict(list(base_config.items()) + list(config.items())) 74 75 @classmethod 76 def from_config(cls, config, custom_objects=None): 77 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 78 # Avoid mutating the input dict 79 config = copy.deepcopy(config) 80 layer = deserialize_layer( 81 config.pop('layer'), custom_objects=custom_objects) 82 return cls(layer, **config) 83 84 85@keras_export('keras.layers.TimeDistributed') 86class TimeDistributed(Wrapper): 87 """This wrapper allows to apply a layer to every temporal slice of an input. 88 89 Every input should be at least 3D, and the dimension of index one of the 90 first input will be considered to be the temporal dimension. 91 92 Consider a batch of 32 video samples, where each sample is a 128x128 RGB image 93 with `channels_last` data format, across 10 timesteps. 94 The batch input shape is `(32, 10, 128, 128, 3)`. 95 96 You can then use `TimeDistributed` to apply the same `Conv2D` layer to each 97 of the 10 timesteps, independently: 98 99 >>> inputs = tf.keras.Input(shape=(10, 128, 128, 3)) 100 >>> conv_2d_layer = tf.keras.layers.Conv2D(64, (3, 3)) 101 >>> outputs = tf.keras.layers.TimeDistributed(conv_2d_layer)(inputs) 102 >>> outputs.shape 103 TensorShape([None, 10, 126, 126, 64]) 104 105 Because `TimeDistributed` applies the same instance of `Conv2D` to each of the 106 timestamps, the same set of weights are used at each timestamp. 107 108 Args: 109 layer: a `tf.keras.layers.Layer` instance. 110 111 Call arguments: 112 inputs: Input tensor of shape (batch, time, ...) or nested tensors, 113 and each of which has shape (batch, time, ...). 114 training: Python boolean indicating whether the layer should behave in 115 training mode or in inference mode. This argument is passed to the 116 wrapped layer (only if the layer supports this argument). 117 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 118 a given timestep should be masked. This argument is passed to the 119 wrapped layer (only if the layer supports this argument). 120 121 Raises: 122 ValueError: If not initialized with a `tf.keras.layers.Layer` instance. 123 """ 124 125 def __init__(self, layer, **kwargs): 126 if not isinstance(layer, Layer): 127 raise ValueError( 128 'Please initialize `TimeDistributed` layer with a ' 129 '`tf.keras.layers.Layer` instance. You passed: {input}'.format( 130 input=layer)) 131 super(TimeDistributed, self).__init__(layer, **kwargs) 132 self.supports_masking = True 133 134 # It is safe to use the fast, reshape-based approach with all of our 135 # built-in Layers. 136 self._always_use_reshape = ( 137 layer_utils.is_builtin_layer(layer) and 138 not getattr(layer, 'stateful', False)) 139 140 def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None): 141 """Finds non-specific dimensions in the static shapes. 142 143 The static shapes are replaced with the corresponding dynamic shapes of the 144 tensor. 145 Args: 146 init_tuple: a tuple, the first part of the output shape 147 tensor: the tensor from which to get the (static and dynamic) shapes 148 as the last part of the output shape 149 start_idx: int, which indicate the first dimension to take from 150 the static shape of the tensor 151 int_shape: an alternative static shape to take as the last part 152 of the output shape 153 Returns: 154 The new int_shape with the first part from init_tuple 155 and the last part from either `int_shape` (if provided) 156 or `tensor.shape`, where every `None` is replaced by 157 the corresponding dimension from `tf.shape(tensor)`. 158 """ 159 # replace all None in int_shape by K.shape 160 if int_shape is None: 161 int_shape = K.int_shape(tensor)[start_idx:] 162 if isinstance(int_shape, tensor_shape.TensorShape): 163 int_shape = int_shape.as_list() 164 if not any(not s for s in int_shape): 165 return init_tuple + tuple(int_shape) 166 shape = K.shape(tensor) 167 int_shape = list(int_shape) 168 for i, s in enumerate(int_shape): 169 if not s: 170 int_shape[i] = shape[start_idx + i] 171 return init_tuple + tuple(int_shape) 172 173 def _remove_timesteps(self, dims): 174 dims = dims.as_list() 175 return tensor_shape.TensorShape([dims[0]] + dims[2:]) 176 177 def build(self, input_shape): 178 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) 179 input_dims = nest.flatten( 180 nest.map_structure(lambda x: x.ndims, input_shape)) 181 if any(dim < 3 for dim in input_dims): 182 raise ValueError( 183 '`TimeDistributed` Layer should be passed an `input_shape ` ' 184 'with at least 3 dimensions, received: ' + str(input_shape)) 185 # Don't enforce the batch or time dimension. 186 self.input_spec = nest.map_structure( 187 lambda x: InputSpec(shape=[None, None] + x.as_list()[2:]), input_shape) 188 child_input_shape = nest.map_structure(self._remove_timesteps, input_shape) 189 child_input_shape = tf_utils.convert_shapes(child_input_shape) 190 super(TimeDistributed, self).build(tuple(child_input_shape)) 191 self.built = True 192 193 def compute_output_shape(self, input_shape): 194 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) 195 196 child_input_shape = nest.map_structure(self._remove_timesteps, input_shape) 197 child_output_shape = self.layer.compute_output_shape(child_input_shape) 198 child_output_shape = tf_utils.convert_shapes( 199 child_output_shape, to_tuples=False) 200 timesteps = tf_utils.convert_shapes(input_shape) 201 timesteps = nest.flatten(timesteps)[1] 202 203 def insert_timesteps(dims): 204 dims = dims.as_list() 205 return tensor_shape.TensorShape([dims[0], timesteps] + dims[1:]) 206 207 return nest.map_structure(insert_timesteps, child_output_shape) 208 209 def call(self, inputs, training=None, mask=None): 210 kwargs = {} 211 if generic_utils.has_arg(self.layer.call, 'training'): 212 kwargs['training'] = training 213 214 input_shape = nest.map_structure( 215 lambda x: tensor_shape.TensorShape(K.int_shape(x)), inputs) 216 batch_size = tf_utils.convert_shapes(input_shape) 217 batch_size = nest.flatten(batch_size)[0] 218 if batch_size and not self._always_use_reshape: 219 inputs, row_lengths = K.convert_inputs_if_ragged(inputs) 220 is_ragged_input = row_lengths is not None 221 input_length = tf_utils.convert_shapes(input_shape) 222 input_length = nest.flatten(input_length)[1] 223 224 # batch size matters, use rnn-based implementation 225 def step(x, _): 226 output = self.layer(x, **kwargs) 227 return output, [] 228 229 _, outputs, _ = K.rnn( 230 step, 231 inputs, 232 initial_states=[], 233 input_length=row_lengths[0] if is_ragged_input else input_length, 234 mask=mask, 235 unroll=False) 236 # pylint: disable=g-long-lambda 237 y = nest.map_structure( 238 lambda output: K.maybe_convert_to_ragged(is_ragged_input, output, 239 row_lengths), outputs) 240 else: 241 # No batch size specified, therefore the layer will be able 242 # to process batches of any size. 243 # We can go with reshape-based implementation for performance. 244 is_ragged_input = nest.map_structure( 245 lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs) 246 is_ragged_input = nest.flatten(is_ragged_input) 247 if all(is_ragged_input): 248 input_values = nest.map_structure(lambda x: x.values, inputs) 249 input_row_lenghts = nest.map_structure( 250 lambda x: x.nested_row_lengths()[0], inputs) 251 y = self.layer(input_values, **kwargs) 252 y = nest.map_structure(ragged_tensor.RaggedTensor.from_row_lengths, y, 253 input_row_lenghts) 254 elif any(is_ragged_input): 255 raise ValueError('All inputs has to be either ragged or not, ' 256 'but not mixed. You passed: {}'.format(inputs)) 257 else: 258 input_length = tf_utils.convert_shapes(input_shape) 259 input_length = nest.flatten(input_length)[1] 260 if not input_length: 261 input_length = nest.map_structure(lambda x: array_ops.shape(x)[1], 262 inputs) 263 input_length = generic_utils.to_list(nest.flatten(input_length))[0] 264 265 inner_input_shape = nest.map_structure( 266 lambda x: self._get_shape_tuple((-1,), x, 2), inputs) 267 # Shape: (num_samples * timesteps, ...). And track the 268 # transformation in self._input_map. 269 inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs, 270 inner_input_shape) 271 # (num_samples * timesteps, ...) 272 if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None: 273 inner_mask_shape = self._get_shape_tuple((-1,), mask, 2) 274 kwargs['mask'] = K.reshape(mask, inner_mask_shape) 275 276 y = self.layer(inputs, **kwargs) 277 278 # Shape: (num_samples, timesteps, ...) 279 output_shape = self.compute_output_shape(input_shape) 280 # pylint: disable=g-long-lambda 281 output_shape = nest.map_structure( 282 lambda tensor, int_shape: self._get_shape_tuple( 283 (-1, input_length), tensor, 1, int_shape[2:]), y, output_shape) 284 y = nest.map_structure_up_to(y, array_ops.reshape, y, output_shape) 285 if not context.executing_eagerly(): 286 # Set the static shape for the result since it might be lost during 287 # array_ops reshape, eg, some `None` dim in the result could be 288 # inferred. 289 nest.map_structure_up_to( 290 y, lambda tensor, shape: tensor.set_shape(shape), y, 291 self.compute_output_shape(input_shape)) 292 293 return y 294 295 def compute_mask(self, inputs, mask=None): 296 """Computes an output mask tensor for Embedding layer. 297 298 This is based on the inputs, mask, and the inner layer. 299 If batch size is specified: 300 Simply return the input `mask`. (An rnn-based implementation with 301 more than one rnn inputs is required but not supported in tf.keras yet.) 302 Otherwise we call `compute_mask` of the inner layer at each time step. 303 If the output mask at each time step is not `None`: 304 (E.g., inner layer is Masking or RNN) 305 Concatenate all of them and return the concatenation. 306 If the output mask at each time step is `None` and the input mask is not 307 `None`:(E.g., inner layer is Dense) 308 Reduce the input_mask to 2 dimensions and return it. 309 Otherwise (both the output mask and the input mask are `None`): 310 (E.g., `mask` is not used at all) 311 Return `None`. 312 313 Args: 314 inputs: Tensor with shape [batch size, timesteps, ...] indicating the 315 input to TimeDistributed. If static shape information is available for 316 "batch size", `mask` is returned unmodified. 317 mask: Either None (indicating no masking) or a Tensor indicating the 318 input mask for TimeDistributed. The shape can be static or dynamic. 319 320 Returns: 321 Either None (no masking), or a [batch size, timesteps, ...] Tensor with 322 an output mask for the TimeDistributed layer with the shape beyond the 323 second dimension being the value of the input mask shape(if the computed 324 output mask is none), an output mask with the shape beyond the first 325 dimension being the value of the mask shape(if mask is not None) or 326 output mask with the shape beyond the first dimension being the 327 value of the computed output shape. 328 329 """ 330 # cases need to call the layer.compute_mask when input_mask is None: 331 # Masking layer and Embedding layer with mask_zero 332 input_shape = nest.map_structure( 333 lambda x: tensor_shape.TensorShape(K.int_shape(x)), inputs) 334 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) 335 batch_size = tf_utils.convert_shapes(input_shape) 336 batch_size = nest.flatten(batch_size)[0] 337 is_ragged_input = nest.map_structure( 338 lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs) 339 is_ragged_input = generic_utils.to_list(nest.flatten(is_ragged_input)) 340 if batch_size and not self._always_use_reshape or any(is_ragged_input): 341 # batch size matters, we currently do not handle mask explicitly, or if 342 # the layer always uses reshape approach, or the input is a ragged tensor. 343 return mask 344 inner_mask = mask 345 if inner_mask is not None: 346 inner_mask_shape = self._get_shape_tuple((-1,), mask, 2) 347 inner_mask = K.reshape(inner_mask, inner_mask_shape) 348 inner_input_shape = nest.map_structure( 349 lambda tensor: self._get_shape_tuple((-1,), tensor, 2), inputs) 350 inner_inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs, 351 inner_input_shape) 352 output_mask = self.layer.compute_mask(inner_inputs, inner_mask) 353 if output_mask is None: 354 if mask is None: 355 return None 356 # input_mask is not None, and output_mask is None: 357 # we should return a not-None mask 358 output_mask = mask 359 for _ in range(2, len(K.int_shape(mask))): 360 output_mask = K.any(output_mask, axis=-1) 361 else: 362 # output_mask is not None. We need to reshape it 363 input_length = tf_utils.convert_shapes(input_shape) 364 input_length = nest.flatten(input_length)[1] 365 if not input_length: 366 input_length = nest.map_structure(lambda x: K.shape(x)[1], inputs) 367 input_length = nest.flatten(input_length)[0] 368 output_mask_int_shape = K.int_shape(output_mask) 369 if output_mask_int_shape is None: 370 # if the output_mask does not have a static shape, 371 # its shape must be the same as mask's 372 if mask is not None: 373 output_mask_int_shape = K.int_shape(mask) 374 else: 375 input_shape = generic_utils.to_list(nest.flatten(input_shape))[0] 376 output_mask_int_shape = K.compute_output_shape(input_shape)[:-1] 377 output_mask_shape = self._get_shape_tuple( 378 (-1, input_length), output_mask, 1, output_mask_int_shape[1:]) 379 output_mask = K.reshape(output_mask, output_mask_shape) 380 return output_mask 381 382 383@keras_export('keras.layers.Bidirectional') 384class Bidirectional(Wrapper): 385 """Bidirectional wrapper for RNNs. 386 387 Args: 388 layer: `keras.layers.RNN` instance, such as `keras.layers.LSTM` or 389 `keras.layers.GRU`. It could also be a `keras.layers.Layer` instance 390 that meets the following criteria: 391 1. Be a sequence-processing layer (accepts 3D+ inputs). 392 2. Have a `go_backwards`, `return_sequences` and `return_state` 393 attribute (with the same semantics as for the `RNN` class). 394 3. Have an `input_spec` attribute. 395 4. Implement serialization via `get_config()` and `from_config()`. 396 Note that the recommended way to create new RNN layers is to write a 397 custom RNN cell and use it with `keras.layers.RNN`, instead of 398 subclassing `keras.layers.Layer` directly. 399 merge_mode: Mode by which outputs of the forward and backward RNNs will be 400 combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the 401 outputs will not be combined, they will be returned as a list. Default 402 value is 'concat'. 403 backward_layer: Optional `keras.layers.RNN`, or `keras.layers.Layer` 404 instance to be used to handle backwards input processing. 405 If `backward_layer` is not provided, the layer instance passed as the 406 `layer` argument will be used to generate the backward layer 407 automatically. 408 Note that the provided `backward_layer` layer should have properties 409 matching those of the `layer` argument, in particular it should have the 410 same values for `stateful`, `return_states`, `return_sequence`, etc. 411 In addition, `backward_layer` and `layer` should have different 412 `go_backwards` argument values. 413 A `ValueError` will be raised if these requirements are not met. 414 415 Call arguments: 416 The call arguments for this layer are the same as those of the wrapped RNN 417 layer. 418 Beware that when passing the `initial_state` argument during the call of 419 this layer, the first half in the list of elements in the `initial_state` 420 list will be passed to the forward RNN call and the last half in the list 421 of elements will be passed to the backward RNN call. 422 423 Raises: 424 ValueError: 425 1. If `layer` or `backward_layer` is not a `Layer` instance. 426 2. In case of invalid `merge_mode` argument. 427 3. If `backward_layer` has mismatched properties compared to `layer`. 428 429 Examples: 430 431 ```python 432 model = Sequential() 433 model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 10))) 434 model.add(Bidirectional(LSTM(10))) 435 model.add(Dense(5)) 436 model.add(Activation('softmax')) 437 model.compile(loss='categorical_crossentropy', optimizer='rmsprop') 438 439 # With custom backward layer 440 model = Sequential() 441 forward_layer = LSTM(10, return_sequences=True) 442 backward_layer = LSTM(10, activation='relu', return_sequences=True, 443 go_backwards=True) 444 model.add(Bidirectional(forward_layer, backward_layer=backward_layer, 445 input_shape=(5, 10))) 446 model.add(Dense(5)) 447 model.add(Activation('softmax')) 448 model.compile(loss='categorical_crossentropy', optimizer='rmsprop') 449 ``` 450 """ 451 452 def __init__(self, 453 layer, 454 merge_mode='concat', 455 weights=None, 456 backward_layer=None, 457 **kwargs): 458 if not isinstance(layer, Layer): 459 raise ValueError( 460 'Please initialize `Bidirectional` layer with a ' 461 '`Layer` instance. You passed: {input}'.format(input=layer)) 462 if backward_layer is not None and not isinstance(backward_layer, Layer): 463 raise ValueError('`backward_layer` need to be a `Layer` instance. ' 464 'You passed: {input}'.format(input=backward_layer)) 465 if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]: 466 raise ValueError('Invalid merge mode. ' 467 'Merge mode should be one of ' 468 '{"sum", "mul", "ave", "concat", None}') 469 # We don't want to track `layer` since we're already tracking the two copies 470 # of it we actually run. 471 self._setattr_tracking = False 472 super(Bidirectional, self).__init__(layer, **kwargs) 473 self._setattr_tracking = True 474 475 # Recreate the forward layer from the original layer config, so that it will 476 # not carry over any state from the layer. 477 self.forward_layer = self._recreate_layer_from_config(layer) 478 479 if backward_layer is None: 480 self.backward_layer = self._recreate_layer_from_config( 481 layer, go_backwards=True) 482 else: 483 self.backward_layer = backward_layer 484 # Keep the custom backward layer config, so that we can save it later. The 485 # layer's name might be updated below with prefix 'backward_', and we want 486 # to preserve the original config. 487 self._backward_layer_config = generic_utils.serialize_keras_object( 488 backward_layer) 489 490 self.forward_layer._name = 'forward_' + self.forward_layer.name 491 self.backward_layer._name = 'backward_' + self.backward_layer.name 492 493 self._verify_layer_config() 494 495 def force_zero_output_for_mask(layer): 496 # Force the zero_output_for_mask to be True if returning sequences. 497 if getattr(layer, 'zero_output_for_mask', None) is not None: 498 layer.zero_output_for_mask = layer.return_sequences 499 500 force_zero_output_for_mask(self.forward_layer) 501 force_zero_output_for_mask(self.backward_layer) 502 503 self.merge_mode = merge_mode 504 if weights: 505 nw = len(weights) 506 self.forward_layer.initial_weights = weights[:nw // 2] 507 self.backward_layer.initial_weights = weights[nw // 2:] 508 self.stateful = layer.stateful 509 self.return_sequences = layer.return_sequences 510 self.return_state = layer.return_state 511 self.supports_masking = True 512 self._trainable = True 513 self._num_constants = 0 514 self.input_spec = layer.input_spec 515 516 def _verify_layer_config(self): 517 """Ensure the forward and backward layers have valid common property.""" 518 if self.forward_layer.go_backwards == self.backward_layer.go_backwards: 519 raise ValueError('Forward layer and backward layer should have different ' 520 '`go_backwards` value.') 521 522 common_attributes = ('stateful', 'return_sequences', 'return_state') 523 for a in common_attributes: 524 forward_value = getattr(self.forward_layer, a) 525 backward_value = getattr(self.backward_layer, a) 526 if forward_value != backward_value: 527 raise ValueError( 528 'Forward layer and backward layer are expected to have the same ' 529 'value for attribute {attr}, got {forward} and {backward}'.format( 530 attr=a, forward=forward_value, backward=backward_value)) 531 532 def _recreate_layer_from_config(self, layer, go_backwards=False): 533 # When recreating the layer from its config, it is possible that the layer 534 # is a RNN layer that contains custom cells. In this case we inspect the 535 # layer and pass the custom cell class as part of the `custom_objects` 536 # argument when calling `from_config`. 537 # See https://github.com/tensorflow/tensorflow/issues/26581 for more detail. 538 config = layer.get_config() 539 if go_backwards: 540 config['go_backwards'] = not config['go_backwards'] 541 if 'custom_objects' in tf_inspect.getfullargspec( 542 layer.__class__.from_config).args: 543 custom_objects = {} 544 cell = getattr(layer, 'cell', None) 545 if cell is not None: 546 custom_objects[cell.__class__.__name__] = cell.__class__ 547 # For StackedRNNCells 548 stacked_cells = getattr(cell, 'cells', []) 549 for c in stacked_cells: 550 custom_objects[c.__class__.__name__] = c.__class__ 551 return layer.__class__.from_config(config, custom_objects=custom_objects) 552 else: 553 return layer.__class__.from_config(config) 554 555 @tf_utils.shape_type_conversion 556 def compute_output_shape(self, input_shape): 557 output_shape = self.forward_layer.compute_output_shape(input_shape) 558 if self.return_state: 559 state_shape = tf_utils.convert_shapes(output_shape[1:], to_tuples=False) 560 output_shape = tf_utils.convert_shapes(output_shape[0], to_tuples=False) 561 else: 562 output_shape = tf_utils.convert_shapes(output_shape, to_tuples=False) 563 564 if self.merge_mode == 'concat': 565 output_shape = output_shape.as_list() 566 output_shape[-1] *= 2 567 output_shape = tensor_shape.TensorShape(output_shape) 568 elif self.merge_mode is None: 569 output_shape = [output_shape, copy.copy(output_shape)] 570 571 if self.return_state: 572 if self.merge_mode is None: 573 return output_shape + state_shape + copy.copy(state_shape) 574 return [output_shape] + state_shape + copy.copy(state_shape) 575 return output_shape 576 577 def __call__(self, inputs, initial_state=None, constants=None, **kwargs): 578 """`Bidirectional.__call__` implements the same API as the wrapped `RNN`.""" 579 inputs, initial_state, constants = _standardize_args( 580 inputs, initial_state, constants, self._num_constants) 581 582 if isinstance(inputs, list): 583 if len(inputs) > 1: 584 initial_state = inputs[1:] 585 inputs = inputs[0] 586 587 if initial_state is None and constants is None: 588 return super(Bidirectional, self).__call__(inputs, **kwargs) 589 590 # Applies the same workaround as in `RNN.__call__` 591 additional_inputs = [] 592 additional_specs = [] 593 if initial_state is not None: 594 # Check if `initial_state` can be splitted into half 595 num_states = len(initial_state) 596 if num_states % 2 > 0: 597 raise ValueError( 598 'When passing `initial_state` to a Bidirectional RNN, ' 599 'the state should be a list containing the states of ' 600 'the underlying RNNs. ' 601 'Found: ' + str(initial_state)) 602 603 kwargs['initial_state'] = initial_state 604 additional_inputs += initial_state 605 state_specs = [InputSpec(shape=K.int_shape(state)) 606 for state in initial_state] 607 self.forward_layer.state_spec = state_specs[:num_states // 2] 608 self.backward_layer.state_spec = state_specs[num_states // 2:] 609 additional_specs += state_specs 610 if constants is not None: 611 kwargs['constants'] = constants 612 additional_inputs += constants 613 constants_spec = [InputSpec(shape=K.int_shape(constant)) 614 for constant in constants] 615 self.forward_layer.constants_spec = constants_spec 616 self.backward_layer.constants_spec = constants_spec 617 additional_specs += constants_spec 618 619 self._num_constants = len(constants) 620 self.forward_layer._num_constants = self._num_constants 621 self.backward_layer._num_constants = self._num_constants 622 623 is_keras_tensor = K.is_keras_tensor(additional_inputs[0]) 624 for tensor in additional_inputs: 625 if K.is_keras_tensor(tensor) != is_keras_tensor: 626 raise ValueError('The initial state of a Bidirectional' 627 ' layer cannot be specified with a mix of' 628 ' Keras tensors and non-Keras tensors' 629 ' (a "Keras tensor" is a tensor that was' 630 ' returned by a Keras layer, or by `Input`)') 631 632 if is_keras_tensor: 633 # Compute the full input spec, including state 634 full_input = [inputs] + additional_inputs 635 # The original input_spec is None since there could be a nested tensor 636 # input. Update the input_spec to match the inputs. 637 full_input_spec = [None for _ in range(len(nest.flatten(inputs))) 638 ] + additional_specs 639 # Removing kwargs since the value are passed with input list. 640 kwargs['initial_state'] = None 641 kwargs['constants'] = None 642 643 # Perform the call with temporarily replaced input_spec 644 original_input_spec = self.input_spec 645 self.input_spec = full_input_spec 646 output = super(Bidirectional, self).__call__(full_input, **kwargs) 647 self.input_spec = original_input_spec 648 return output 649 else: 650 return super(Bidirectional, self).__call__(inputs, **kwargs) 651 652 def call(self, 653 inputs, 654 training=None, 655 mask=None, 656 initial_state=None, 657 constants=None): 658 """`Bidirectional.call` implements the same API as the wrapped `RNN`.""" 659 kwargs = {} 660 if generic_utils.has_arg(self.layer.call, 'training'): 661 kwargs['training'] = training 662 if generic_utils.has_arg(self.layer.call, 'mask'): 663 kwargs['mask'] = mask 664 if generic_utils.has_arg(self.layer.call, 'constants'): 665 kwargs['constants'] = constants 666 667 if generic_utils.has_arg(self.layer.call, 'initial_state'): 668 if isinstance(inputs, list) and len(inputs) > 1: 669 # initial_states are keras tensors, which means they are passed in 670 # together with inputs as list. The initial_states need to be split into 671 # forward and backward section, and be feed to layers accordingly. 672 forward_inputs = [inputs[0]] 673 backward_inputs = [inputs[0]] 674 pivot = (len(inputs) - self._num_constants) // 2 + 1 675 # add forward initial state 676 forward_inputs += inputs[1:pivot] 677 if not self._num_constants: 678 # add backward initial state 679 backward_inputs += inputs[pivot:] 680 else: 681 # add backward initial state 682 backward_inputs += inputs[pivot:-self._num_constants] 683 # add constants for forward and backward layers 684 forward_inputs += inputs[-self._num_constants:] 685 backward_inputs += inputs[-self._num_constants:] 686 forward_state, backward_state = None, None 687 if 'constants' in kwargs: 688 kwargs['constants'] = None 689 elif initial_state is not None: 690 # initial_states are not keras tensors, eg eager tensor from np array. 691 # They are only passed in from kwarg initial_state, and should be passed 692 # to forward/backward layer via kwarg initial_state as well. 693 forward_inputs, backward_inputs = inputs, inputs 694 half = len(initial_state) // 2 695 forward_state = initial_state[:half] 696 backward_state = initial_state[half:] 697 else: 698 forward_inputs, backward_inputs = inputs, inputs 699 forward_state, backward_state = None, None 700 701 y = self.forward_layer(forward_inputs, 702 initial_state=forward_state, **kwargs) 703 y_rev = self.backward_layer(backward_inputs, 704 initial_state=backward_state, **kwargs) 705 else: 706 y = self.forward_layer(inputs, **kwargs) 707 y_rev = self.backward_layer(inputs, **kwargs) 708 709 if self.return_state: 710 states = y[1:] + y_rev[1:] 711 y = y[0] 712 y_rev = y_rev[0] 713 714 if self.return_sequences: 715 time_dim = 0 if getattr(self.forward_layer, 'time_major', False) else 1 716 y_rev = K.reverse(y_rev, time_dim) 717 if self.merge_mode == 'concat': 718 output = K.concatenate([y, y_rev]) 719 elif self.merge_mode == 'sum': 720 output = y + y_rev 721 elif self.merge_mode == 'ave': 722 output = (y + y_rev) / 2 723 elif self.merge_mode == 'mul': 724 output = y * y_rev 725 elif self.merge_mode is None: 726 output = [y, y_rev] 727 else: 728 raise ValueError( 729 'Unrecognized value for `merge_mode`: %s' % (self.merge_mode)) 730 731 if self.return_state: 732 if self.merge_mode is None: 733 return output + states 734 return [output] + states 735 return output 736 737 def reset_states(self): 738 self.forward_layer.reset_states() 739 self.backward_layer.reset_states() 740 741 def build(self, input_shape): 742 with K.name_scope(self.forward_layer.name): 743 self.forward_layer.build(input_shape) 744 with K.name_scope(self.backward_layer.name): 745 self.backward_layer.build(input_shape) 746 self.built = True 747 748 def compute_mask(self, inputs, mask): 749 if isinstance(mask, list): 750 mask = mask[0] 751 if self.return_sequences: 752 if not self.merge_mode: 753 output_mask = [mask, mask] 754 else: 755 output_mask = mask 756 else: 757 output_mask = [None, None] if not self.merge_mode else None 758 759 if self.return_state: 760 states = self.forward_layer.states 761 state_mask = [None for _ in states] 762 if isinstance(output_mask, list): 763 return output_mask + state_mask * 2 764 return [output_mask] + state_mask * 2 765 return output_mask 766 767 @property 768 def constraints(self): 769 constraints = {} 770 if hasattr(self.forward_layer, 'constraints'): 771 constraints.update(self.forward_layer.constraints) 772 constraints.update(self.backward_layer.constraints) 773 return constraints 774 775 def get_config(self): 776 config = {'merge_mode': self.merge_mode} 777 if self._num_constants: 778 config['num_constants'] = self._num_constants 779 780 if hasattr(self, '_backward_layer_config'): 781 config['backward_layer'] = self._backward_layer_config 782 base_config = super(Bidirectional, self).get_config() 783 return dict(list(base_config.items()) + list(config.items())) 784 785 @classmethod 786 def from_config(cls, config, custom_objects=None): 787 # Instead of updating the input, create a copy and use that. 788 config = copy.deepcopy(config) 789 num_constants = config.pop('num_constants', 0) 790 # Handle forward layer instantiation (as would parent class). 791 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 792 config['layer'] = deserialize_layer( 793 config['layer'], custom_objects=custom_objects) 794 # Handle (optional) backward layer instantiation. 795 backward_layer_config = config.pop('backward_layer', None) 796 if backward_layer_config is not None: 797 backward_layer = deserialize_layer( 798 backward_layer_config, custom_objects=custom_objects) 799 config['backward_layer'] = backward_layer 800 # Instantiate the wrapper, adjust it and return it. 801 layer = cls(**config) 802 layer._num_constants = num_constants 803 return layer 804