# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Wrapper layers: layers that augment the functionality of another layer."""

import copy

from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.Wrapper')
class Wrapper(Layer):
  """Abstract wrapper base class.

  Wrappers take another layer and augment it in various ways.
  Do not use this class as a layer; it is only an abstract base class.
  Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.

  Args:
    layer: The layer to be wrapped.
  """

  def __init__(self, layer, **kwargs):
    assert isinstance(layer, Layer)
    self.layer = layer
    super(Wrapper, self).__init__(**kwargs)

  def build(self, input_shape=None):
    if not self.layer.built:
      self.layer.build(input_shape)
      self.layer.built = True
    self.built = True

  @property
  def activity_regularizer(self):
    if hasattr(self.layer, 'activity_regularizer'):
      return self.layer.activity_regularizer
    else:
      return None

  def get_config(self):
    config = {'layer': generic_utils.serialize_keras_object(self.layer)}
    base_config = super(Wrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    # Avoid mutating the input dict.
    config = copy.deepcopy(config)
    layer = deserialize_layer(
        config.pop('layer'), custom_objects=custom_objects)
    return cls(layer, **config)


@keras_export('keras.layers.TimeDistributed')
class TimeDistributed(Wrapper):
  """This wrapper allows applying a layer to every temporal slice of an input.

  Every input should be at least 3D, and the dimension of index one of the
  first input will be considered to be the temporal dimension.

  Consider a batch of 32 video samples, where each sample is a 128x128 RGB
  image with `channels_last` data format, across 10 timesteps.
  The batch input shape is `(32, 10, 128, 128, 3)`.

  You can then use `TimeDistributed` to apply the same `Conv2D` layer to each
  of the 10 timesteps, independently:

  >>> inputs = tf.keras.Input(shape=(10, 128, 128, 3))
  >>> conv_2d_layer = tf.keras.layers.Conv2D(64, (3, 3))
  >>> outputs = tf.keras.layers.TimeDistributed(conv_2d_layer)(inputs)
  >>> outputs.shape
  TensorShape([None, 10, 126, 126, 64])

  Because `TimeDistributed` applies the same instance of `Conv2D` to each of
  the timesteps, the same set of weights is used at each timestep.

  Args:
    layer: a `tf.keras.layers.Layer` instance.

  Call arguments:
    inputs: Input tensor of shape (batch, time, ...) or nested tensors,
      each of which has shape (batch, time, ...).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the
      wrapped layer (only if the layer supports this argument).
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked. This argument is passed to the
      wrapped layer (only if the layer supports this argument).

  Raises:
    ValueError: If not initialized with a `tf.keras.layers.Layer` instance.
  """

  def __init__(self, layer, **kwargs):
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `TimeDistributed` layer with a '
          '`tf.keras.layers.Layer` instance. You passed: {input}'.format(
              input=layer))
    super(TimeDistributed, self).__init__(layer, **kwargs)
    self.supports_masking = True

    # It is safe to use the fast, reshape-based approach with all of our
    # built-in Layers.
    self._always_use_reshape = (
        layer_utils.is_builtin_layer(layer) and
        not getattr(layer, 'stateful', False))

  def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
    """Finds non-specific dimensions in the static shapes.

    The static shapes are replaced with the corresponding dynamic shapes of
    the tensor.

    Args:
      init_tuple: a tuple, the first part of the output shape.
      tensor: the tensor from which to get the (static and dynamic) shapes
        as the last part of the output shape.
      start_idx: int, which indicates the first dimension to take from
        the static shape of the tensor.
      int_shape: an alternative static shape to take as the last part
        of the output shape.

    Returns:
      The new shape tuple, with the first part from `init_tuple` and the last
      part from either `int_shape` (if provided) or `tensor.shape`, where
      every `None` is replaced by the corresponding dimension from
      `tf.shape(tensor)`.
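
    For example, if `tensor` has static shape `(None, 10, None, 3)`, then
    `_get_shape_tuple((-1,), tensor, 2)` returns `(-1, tf.shape(tensor)[2], 3)`:
    the known trailing dimensions stay static and the unknown one is replaced
    by its dynamic value.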
    """
    # replace all None in int_shape by backend.shape
    if int_shape is None:
      int_shape = backend.int_shape(tensor)[start_idx:]
    if isinstance(int_shape, tensor_shape.TensorShape):
      int_shape = int_shape.as_list()
    if not any(not s for s in int_shape):
      return init_tuple + tuple(int_shape)
    shape = backend.shape(tensor)
    int_shape = list(int_shape)
    for i, s in enumerate(int_shape):
      if not s:
        int_shape[i] = shape[start_idx + i]
    return init_tuple + tuple(int_shape)

  def _remove_timesteps(self, dims):
    dims = dims.as_list()
    return tensor_shape.TensorShape([dims[0]] + dims[2:])

  def build(self, input_shape):
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
    input_dims = nest.flatten(
        nest.map_structure(lambda x: x.ndims, input_shape))
    if any(dim < 3 for dim in input_dims):
      raise ValueError(
          '`TimeDistributed` Layer should be passed an `input_shape` '
          'with at least 3 dimensions, received: ' + str(input_shape))
    # Don't enforce the batch or time dimension.
    self.input_spec = nest.map_structure(
        lambda x: InputSpec(shape=[None, None] + x.as_list()[2:]), input_shape)
    child_input_shape = nest.map_structure(self._remove_timesteps, input_shape)
    child_input_shape = tf_utils.convert_shapes(child_input_shape)
    super(TimeDistributed, self).build(tuple(child_input_shape))
    self.built = True

  def compute_output_shape(self, input_shape):
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

    child_input_shape = nest.map_structure(self._remove_timesteps, input_shape)
    child_output_shape = self.layer.compute_output_shape(child_input_shape)
    child_output_shape = tf_utils.convert_shapes(
        child_output_shape, to_tuples=False)
    timesteps = tf_utils.convert_shapes(input_shape)
    timesteps = nest.flatten(timesteps)[1]

    def insert_timesteps(dims):
      dims = dims.as_list()
      return tensor_shape.TensorShape([dims[0], timesteps] + dims[1:])

    return nest.map_structure(insert_timesteps, child_output_shape)

  def call(self, inputs, training=None, mask=None):
    kwargs = {}
    if generic_utils.has_arg(self.layer.call, 'training'):
      kwargs['training'] = training

    input_shape = nest.map_structure(
        lambda x: tensor_shape.TensorShape(backend.int_shape(x)), inputs)
    batch_size = tf_utils.convert_shapes(input_shape)
    batch_size = nest.flatten(batch_size)[0]
    if batch_size and not self._always_use_reshape:
      inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
      is_ragged_input = row_lengths is not None
      input_length = tf_utils.convert_shapes(input_shape)
      input_length = nest.flatten(input_length)[1]

      # batch size matters, use rnn-based implementation
      def step(x, _):
        output = self.layer(x, **kwargs)
        return output, []

      _, outputs, _ = backend.rnn(
          step,
          inputs,
          initial_states=[],
          input_length=row_lengths[0] if is_ragged_input else input_length,
          mask=mask,
          unroll=False)
      # pylint: disable=g-long-lambda
      y = nest.map_structure(
          lambda output: backend.maybe_convert_to_ragged(
              is_ragged_input, output, row_lengths), outputs)
    else:
      # No batch size specified, therefore the layer will be able
      # to process batches of any size.
      # We can go with reshape-based implementation for performance.
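      # The reshape-based implementation folds the time dimension into the
      # batch dimension, applies the wrapped layer once on inputs of shape
      # (num_samples * timesteps, ...), and reshapes the result back to
      # (num_samples, timesteps, ...). For ragged inputs, the wrapped layer is
      # applied to the flat `.values` of the ragged tensor instead.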
      is_ragged_input = nest.map_structure(
          lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs)
      is_ragged_input = nest.flatten(is_ragged_input)
      if all(is_ragged_input):
        input_values = nest.map_structure(lambda x: x.values, inputs)
        input_row_lengths = nest.map_structure(
            lambda x: x.nested_row_lengths()[0], inputs)
        y = self.layer(input_values, **kwargs)
        y = nest.map_structure(ragged_tensor.RaggedTensor.from_row_lengths, y,
                               input_row_lengths)
      elif any(is_ragged_input):
        raise ValueError('All inputs have to be either ragged or not, '
                         'but not mixed. You passed: {}'.format(inputs))
      else:
        input_length = tf_utils.convert_shapes(input_shape)
        input_length = nest.flatten(input_length)[1]
        if not input_length:
          input_length = nest.map_structure(lambda x: array_ops.shape(x)[1],
                                            inputs)
          input_length = generic_utils.to_list(nest.flatten(input_length))[0]

        inner_input_shape = nest.map_structure(
            lambda x: self._get_shape_tuple((-1,), x, 2), inputs)
        # Shape: (num_samples * timesteps, ...). And track the
        # transformation in self._input_map.
        inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs,
                                          inner_input_shape)
        # (num_samples * timesteps, ...)
        if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
          inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
          kwargs['mask'] = backend.reshape(mask, inner_mask_shape)

        y = self.layer(inputs, **kwargs)

        # Shape: (num_samples, timesteps, ...)
        output_shape = self.compute_output_shape(input_shape)
        # pylint: disable=g-long-lambda
        output_shape = nest.map_structure(
            lambda tensor, int_shape: self._get_shape_tuple(
                (-1, input_length), tensor, 1, int_shape[2:]), y, output_shape)
        y = nest.map_structure_up_to(y, array_ops.reshape, y, output_shape)
        if not context.executing_eagerly():
          # Set the static shape for the result since it might be lost during
          # array_ops reshape, eg, some `None` dim in the result could be
          # inferred.
          nest.map_structure_up_to(
              y, lambda tensor, shape: tensor.set_shape(shape), y,
              self.compute_output_shape(input_shape))

    return y

  def compute_mask(self, inputs, mask=None):
    """Computes an output mask tensor for the TimeDistributed layer.

    This is based on the inputs, mask, and the inner layer.
    If batch size is specified:
      Simply return the input `mask`. (An rnn-based implementation with
      more than one rnn input is required but not supported in tf.keras yet.)
    Otherwise we call `compute_mask` of the inner layer at each time step.
      If the output mask at each time step is not `None`:
        (E.g., inner layer is Masking or RNN)
        Concatenate all of them and return the concatenation.
      If the output mask at each time step is `None` and the input mask is not
      `None`: (E.g., inner layer is Dense)
        Reduce the input_mask to 2 dimensions and return it.
      Otherwise (both the output mask and the input mask are `None`):
        (E.g., `mask` is not used at all)
        Return `None`.

    Args:
      inputs: Tensor with shape [batch size, timesteps, ...] indicating the
        input to TimeDistributed. If static shape information is available for
        "batch size", `mask` is returned unmodified.
      mask: Either None (indicating no masking) or a Tensor indicating the
        input mask for TimeDistributed. The shape can be static or dynamic.

    Returns:
      Either None (no masking), or a [batch size, timesteps, ...] Tensor with
      an output mask for the TimeDistributed layer, with the shape beyond the
      second dimension being the value of the input mask shape (if the
      computed output mask is None), an output mask with the shape beyond the
      first dimension being the value of the mask shape (if mask is not None),
      or an output mask with the shape beyond the first dimension being the
      value of the computed output shape.
    """
    # Cases that need to call layer.compute_mask when input_mask is None:
    # Masking layer and Embedding layer with mask_zero.
    input_shape = nest.map_structure(
        lambda x: tensor_shape.TensorShape(backend.int_shape(x)), inputs)
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
    batch_size = tf_utils.convert_shapes(input_shape)
    batch_size = nest.flatten(batch_size)[0]
    is_ragged_input = nest.map_structure(
        lambda x: isinstance(x, ragged_tensor.RaggedTensor), inputs)
    is_ragged_input = generic_utils.to_list(nest.flatten(is_ragged_input))
    if batch_size and not self._always_use_reshape or any(is_ragged_input):
      # batch size matters, we currently do not handle mask explicitly, or if
      # the layer always uses reshape approach, or the input is a ragged tensor.
      return mask
    inner_mask = mask
    if inner_mask is not None:
      inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
      inner_mask = backend.reshape(inner_mask, inner_mask_shape)
    inner_input_shape = nest.map_structure(
        lambda tensor: self._get_shape_tuple((-1,), tensor, 2), inputs)
    inner_inputs = nest.map_structure_up_to(inputs, array_ops.reshape, inputs,
                                            inner_input_shape)
    output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
    if output_mask is None:
      if mask is None:
        return None
      # input_mask is not None, and output_mask is None:
      # we should return a not-None mask
      output_mask = mask
      for _ in range(2, len(backend.int_shape(mask))):
        output_mask = backend.any(output_mask, axis=-1)
    else:
      # output_mask is not None. We need to reshape it
      input_length = tf_utils.convert_shapes(input_shape)
      input_length = nest.flatten(input_length)[1]
      if not input_length:
        input_length = nest.map_structure(lambda x: backend.shape(x)[1],
                                          inputs)
        input_length = nest.flatten(input_length)[0]
      output_mask_int_shape = backend.int_shape(output_mask)
      if output_mask_int_shape is None:
        # if the output_mask does not have a static shape,
        # its shape must be the same as mask's
        if mask is not None:
          output_mask_int_shape = backend.int_shape(mask)
        else:
          input_shape = generic_utils.to_list(nest.flatten(input_shape))[0]
          output_mask_int_shape = backend.compute_output_shape(input_shape)[:-1]
      output_mask_shape = self._get_shape_tuple(
          (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
      output_mask = backend.reshape(output_mask, output_mask_shape)
    return output_mask


@keras_export('keras.layers.Bidirectional')
class Bidirectional(Wrapper):
  """Bidirectional wrapper for RNNs.

  Args:
    layer: `keras.layers.RNN` instance, such as `keras.layers.LSTM` or
      `keras.layers.GRU`. It could also be a `keras.layers.Layer` instance
      that meets the following criteria:
      1. Be a sequence-processing layer (accepts 3D+ inputs).
      2. Have a `go_backwards`, `return_sequences` and `return_state`
        attribute (with the same semantics as for the `RNN` class).
      3. Have an `input_spec` attribute.
      4. Implement serialization via `get_config()` and `from_config()`.
      Note that the recommended way to create new RNN layers is to write a
      custom RNN cell and use it with `keras.layers.RNN`, instead of
      subclassing `keras.layers.Layer` directly.
      - When `return_sequences` is True, the output of the masked timestep
        will be zero regardless of the layer's original `zero_output_for_mask`
        value.
    merge_mode: Mode by which outputs of the forward and backward RNNs will be
      combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the
      outputs will not be combined; they will be returned as a list. Default
      value is 'concat'.
    backward_layer: Optional `keras.layers.RNN`, or `keras.layers.Layer`
      instance to be used to handle backwards input processing.
      If `backward_layer` is not provided, the layer instance passed as the
      `layer` argument will be used to generate the backward layer
      automatically.
      Note that the provided `backward_layer` layer should have properties
      matching those of the `layer` argument, in particular it should have the
      same values for `stateful`, `return_state`, `return_sequences`, etc.
      In addition, `backward_layer` and `layer` should have different
      `go_backwards` argument values.
      A `ValueError` will be raised if these requirements are not met.

  Call arguments:
    The call arguments for this layer are the same as those of the wrapped RNN
    layer.
    Beware that when passing the `initial_state` argument during the call of
    this layer, the first half of the elements in the `initial_state` list
    will be passed to the forward RNN call and the last half will be passed
    to the backward RNN call.

  Raises:
    ValueError:
      1. If `layer` or `backward_layer` is not a `Layer` instance.
      2. In case of invalid `merge_mode` argument.
      3. If `backward_layer` has mismatched properties compared to `layer`.

  Examples:

  ```python
  model = Sequential()
  model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 10)))
  model.add(Bidirectional(LSTM(10)))
  model.add(Dense(5))
  model.add(Activation('softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

  # With custom backward layer
  model = Sequential()
  forward_layer = LSTM(10, return_sequences=True)
  backward_layer = LSTM(10, activation='relu', return_sequences=True,
                        go_backwards=True)
  model.add(Bidirectional(forward_layer, backward_layer=backward_layer,
                          input_shape=(5, 10)))
  model.add(Dense(5))
  model.add(Activation('softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
  ```
  """

  def __init__(self,
               layer,
               merge_mode='concat',
               weights=None,
               backward_layer=None,
               **kwargs):
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `Bidirectional` layer with a '
          '`Layer` instance. You passed: {input}'.format(input=layer))
    if backward_layer is not None and not isinstance(backward_layer, Layer):
      raise ValueError('`backward_layer` needs to be a `Layer` instance. '
                       'You passed: {input}'.format(input=backward_layer))
    if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
      raise ValueError('Invalid merge mode. '
                       'Merge mode should be one of '
                       '{"sum", "mul", "ave", "concat", None}')
    # We don't want to track `layer` since we're already tracking the two
    # copies of it we actually run.
    self._setattr_tracking = False
    super(Bidirectional, self).__init__(layer, **kwargs)
    self._setattr_tracking = True

    # Recreate the forward layer from the original layer config, so that it
    # will not carry over any state from the layer.
    self.forward_layer = self._recreate_layer_from_config(layer)

    if backward_layer is None:
      self.backward_layer = self._recreate_layer_from_config(
          layer, go_backwards=True)
    else:
      self.backward_layer = backward_layer
      # Keep the custom backward layer config, so that we can save it later.
      # The layer's name might be updated below with prefix 'backward_', and
      # we want to preserve the original config.
      self._backward_layer_config = generic_utils.serialize_keras_object(
          backward_layer)

    self.forward_layer._name = 'forward_' + self.forward_layer.name
    self.backward_layer._name = 'backward_' + self.backward_layer.name

    self._verify_layer_config()

    def force_zero_output_for_mask(layer):
      # Force the zero_output_for_mask to be True if returning sequences.
      if getattr(layer, 'zero_output_for_mask', None) is not None:
        layer.zero_output_for_mask = layer.return_sequences

    force_zero_output_for_mask(self.forward_layer)
    force_zero_output_for_mask(self.backward_layer)

    self.merge_mode = merge_mode
    if weights:
      nw = len(weights)
      self.forward_layer.initial_weights = weights[:nw // 2]
      self.backward_layer.initial_weights = weights[nw // 2:]
    self.stateful = layer.stateful
    self.return_sequences = layer.return_sequences
    self.return_state = layer.return_state
    self.supports_masking = True
    self._trainable = True
    self._num_constants = 0
    self.input_spec = layer.input_spec

  def _verify_layer_config(self):
    """Ensures the forward and backward layers have valid common properties."""
    if self.forward_layer.go_backwards == self.backward_layer.go_backwards:
      raise ValueError('Forward layer and backward layer should have different '
                       '`go_backwards` value.')

    common_attributes = ('stateful', 'return_sequences', 'return_state')
    for a in common_attributes:
      forward_value = getattr(self.forward_layer, a)
      backward_value = getattr(self.backward_layer, a)
      if forward_value != backward_value:
        raise ValueError(
            'Forward layer and backward layer are expected to have the same '
            'value for attribute {attr}, got {forward} and {backward}'.format(
                attr=a, forward=forward_value, backward=backward_value))

  def _recreate_layer_from_config(self, layer, go_backwards=False):
    # When recreating the layer from its config, it is possible that the layer
    # is an RNN layer that contains custom cells. In this case we inspect the
    # layer and pass the custom cell class as part of the `custom_objects`
    # argument when calling `from_config`.
    # See https://github.com/tensorflow/tensorflow/issues/26581 for more detail.
    config = layer.get_config()
    if go_backwards:
      config['go_backwards'] = not config['go_backwards']
    if 'custom_objects' in tf_inspect.getfullargspec(
        layer.__class__.from_config).args:
      custom_objects = {}
      cell = getattr(layer, 'cell', None)
      if cell is not None:
        custom_objects[cell.__class__.__name__] = cell.__class__
        # For StackedRNNCells
        stacked_cells = getattr(cell, 'cells', [])
        for c in stacked_cells:
          custom_objects[c.__class__.__name__] = c.__class__
      return layer.__class__.from_config(config, custom_objects=custom_objects)
    else:
      return layer.__class__.from_config(config)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    output_shape = self.forward_layer.compute_output_shape(input_shape)
    if self.return_state:
      state_shape = tf_utils.convert_shapes(output_shape[1:], to_tuples=False)
      output_shape = tf_utils.convert_shapes(output_shape[0], to_tuples=False)
    else:
      output_shape = tf_utils.convert_shapes(output_shape, to_tuples=False)

    if self.merge_mode == 'concat':
      output_shape = output_shape.as_list()
      output_shape[-1] *= 2
      output_shape = tensor_shape.TensorShape(output_shape)
    elif self.merge_mode is None:
      output_shape = [output_shape, copy.copy(output_shape)]

    if self.return_state:
      if self.merge_mode is None:
        return output_shape + state_shape + copy.copy(state_shape)
      return [output_shape] + state_shape + copy.copy(state_shape)
    return output_shape

  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
    """`Bidirectional.__call__` implements the same API as the wrapped `RNN`."""
    inputs, initial_state, constants = _standardize_args(
        inputs, initial_state, constants, self._num_constants)

    if isinstance(inputs, list):
      if len(inputs) > 1:
        initial_state = inputs[1:]
      inputs = inputs[0]

    if initial_state is None and constants is None:
      return super(Bidirectional, self).__call__(inputs, **kwargs)

    # Applies the same workaround as in `RNN.__call__`
    additional_inputs = []
    additional_specs = []
    if initial_state is not None:
      # Check if `initial_state` can be split in half
      num_states = len(initial_state)
      if num_states % 2 > 0:
        raise ValueError(
            'When passing `initial_state` to a Bidirectional RNN, '
            'the state should be a list containing the states of '
            'the underlying RNNs. '
            'Found: ' + str(initial_state))

      kwargs['initial_state'] = initial_state
      additional_inputs += initial_state
      state_specs = [InputSpec(shape=backend.int_shape(state))
                     for state in initial_state]
      self.forward_layer.state_spec = state_specs[:num_states // 2]
      self.backward_layer.state_spec = state_specs[num_states // 2:]
      additional_specs += state_specs
    if constants is not None:
      kwargs['constants'] = constants
      additional_inputs += constants
      constants_spec = [InputSpec(shape=backend.int_shape(constant))
                        for constant in constants]
      self.forward_layer.constants_spec = constants_spec
      self.backward_layer.constants_spec = constants_spec
      additional_specs += constants_spec

      self._num_constants = len(constants)
      self.forward_layer._num_constants = self._num_constants
      self.backward_layer._num_constants = self._num_constants

    is_keras_tensor = backend.is_keras_tensor(additional_inputs[0])
    for tensor in additional_inputs:
      if backend.is_keras_tensor(tensor) != is_keras_tensor:
        raise ValueError('The initial state of a Bidirectional'
                         ' layer cannot be specified with a mix of'
                         ' Keras tensors and non-Keras tensors'
                         ' (a "Keras tensor" is a tensor that was'
                         ' returned by a Keras layer, or by `Input`)')

    if is_keras_tensor:
      # Compute the full input spec, including state
      full_input = [inputs] + additional_inputs
      # The original input_spec is None since there could be a nested tensor
      # input. Update the input_spec to match the inputs.
      full_input_spec = [None for _ in range(len(nest.flatten(inputs)))
                        ] + additional_specs
      # Remove the kwargs since the values are passed with the input list.
      kwargs['initial_state'] = None
      kwargs['constants'] = None

      # Perform the call with temporarily replaced input_spec
      original_input_spec = self.input_spec
      self.input_spec = full_input_spec
      output = super(Bidirectional, self).__call__(full_input, **kwargs)
      self.input_spec = original_input_spec
      return output
    else:
      return super(Bidirectional, self).__call__(inputs, **kwargs)

  def call(self,
           inputs,
           training=None,
           mask=None,
           initial_state=None,
           constants=None):
    """`Bidirectional.call` implements the same API as the wrapped `RNN`."""
    kwargs = {}
    if generic_utils.has_arg(self.layer.call, 'training'):
      kwargs['training'] = training
    if generic_utils.has_arg(self.layer.call, 'mask'):
      kwargs['mask'] = mask
    if generic_utils.has_arg(self.layer.call, 'constants'):
      kwargs['constants'] = constants

    if generic_utils.has_arg(self.layer.call, 'initial_state'):
      if isinstance(inputs, list) and len(inputs) > 1:
        # initial_states are Keras tensors, which means they are passed in
        # together with inputs as a list. The initial_states need to be split
        # into forward and backward sections, and fed to the layers accordingly.
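        # Expected layout of `inputs`: [data, forward initial states...,
        # backward initial states..., constants...], where `pivot` below marks
        # the index at which the backward initial states begin.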
        forward_inputs = [inputs[0]]
        backward_inputs = [inputs[0]]
        pivot = (len(inputs) - self._num_constants) // 2 + 1
        # add forward initial state
        forward_inputs += inputs[1:pivot]
        if not self._num_constants:
          # add backward initial state
          backward_inputs += inputs[pivot:]
        else:
          # add backward initial state
          backward_inputs += inputs[pivot:-self._num_constants]
          # add constants for forward and backward layers
          forward_inputs += inputs[-self._num_constants:]
          backward_inputs += inputs[-self._num_constants:]
        forward_state, backward_state = None, None
        if 'constants' in kwargs:
          kwargs['constants'] = None
      elif initial_state is not None:
        # initial_states are not keras tensors, eg eager tensor from np array.
        # They are only passed in from kwarg initial_state, and should be
        # passed to forward/backward layer via kwarg initial_state as well.
        forward_inputs, backward_inputs = inputs, inputs
        half = len(initial_state) // 2
        forward_state = initial_state[:half]
        backward_state = initial_state[half:]
      else:
        forward_inputs, backward_inputs = inputs, inputs
        forward_state, backward_state = None, None

      y = self.forward_layer(forward_inputs,
                             initial_state=forward_state, **kwargs)
      y_rev = self.backward_layer(backward_inputs,
                                  initial_state=backward_state, **kwargs)
    else:
      y = self.forward_layer(inputs, **kwargs)
      y_rev = self.backward_layer(inputs, **kwargs)

    if self.return_state:
      states = y[1:] + y_rev[1:]
      y = y[0]
      y_rev = y_rev[0]

    if self.return_sequences:
      time_dim = 0 if getattr(self.forward_layer, 'time_major', False) else 1
      y_rev = backend.reverse(y_rev, time_dim)
    if self.merge_mode == 'concat':
      output = backend.concatenate([y, y_rev])
    elif self.merge_mode == 'sum':
      output = y + y_rev
    elif self.merge_mode == 'ave':
      output = (y + y_rev) / 2
    elif self.merge_mode == 'mul':
      output = y * y_rev
    elif self.merge_mode is None:
      output = [y, y_rev]
    else:
      raise ValueError(
          'Unrecognized value for `merge_mode`: %s' % (self.merge_mode))

    if self.return_state:
      if self.merge_mode is None:
        return output + states
      return [output] + states
    return output

  def reset_states(self):
    self.forward_layer.reset_states()
    self.backward_layer.reset_states()

  def build(self, input_shape):
    with backend.name_scope(self.forward_layer.name):
      self.forward_layer.build(input_shape)
    with backend.name_scope(self.backward_layer.name):
      self.backward_layer.build(input_shape)
    self.built = True

  def compute_mask(self, inputs, mask):
    if isinstance(mask, list):
      mask = mask[0]
    if self.return_sequences:
      if not self.merge_mode:
        output_mask = [mask, mask]
      else:
        output_mask = mask
    else:
      output_mask = [None, None] if not self.merge_mode else None

    if self.return_state:
      states = self.forward_layer.states
      state_mask = [None for _ in states]
      if isinstance(output_mask, list):
        return output_mask + state_mask * 2
      return [output_mask] + state_mask * 2
    return output_mask

  @property
  def constraints(self):
    constraints = {}
    if hasattr(self.forward_layer, 'constraints'):
      constraints.update(self.forward_layer.constraints)
      constraints.update(self.backward_layer.constraints)
    return constraints

  def get_config(self):
    config = {'merge_mode': self.merge_mode}
    if self._num_constants:
      config['num_constants'] = self._num_constants

    if hasattr(self, '_backward_layer_config'):
      config['backward_layer'] = self._backward_layer_config
    base_config = super(Bidirectional, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    # Instead of updating the input, create a copy and use that.
    config = copy.deepcopy(config)
    num_constants = config.pop('num_constants', 0)
    # Handle forward layer instantiation (as would parent class).
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    config['layer'] = deserialize_layer(
        config['layer'], custom_objects=custom_objects)
    # Handle (optional) backward layer instantiation.
    backward_layer_config = config.pop('backward_layer', None)
    if backward_layer_config is not None:
      backward_layer = deserialize_layer(
          backward_layer_config, custom_objects=custom_objects)
      config['backward_layer'] = backward_layer
    # Instantiate the wrapper, adjust it and return it.
    layer = cls(**config)
    layer._num_constants = num_constants
    return layer