1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Recurrent layers and their base classes. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23 24import numpy as np 25 26from tensorflow.python.eager import context 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.keras import activations 29from tensorflow.python.keras import backend as K 30from tensorflow.python.keras import constraints 31from tensorflow.python.keras import initializers 32from tensorflow.python.keras import regularizers 33from tensorflow.python.keras.engine.base_layer import Layer 34from tensorflow.python.keras.engine.input_spec import InputSpec 35from tensorflow.python.keras.utils import generic_utils 36from tensorflow.python.keras.utils import tf_utils 37from tensorflow.python.ops import array_ops 38from tensorflow.python.ops import state_ops 39from tensorflow.python.platform import tf_logging as logging 40from tensorflow.python.training.tracking import base as trackable 41from tensorflow.python.util import nest 42from tensorflow.python.util.tf_export import keras_export 43 44 45@keras_export('keras.layers.StackedRNNCells') 46class StackedRNNCells(Layer): 47 """Wrapper allowing a stack of RNN cells to behave as a single cell. 48 49 Used to implement efficient stacked RNNs. 50 51 Arguments: 52 cells: List of RNN cell instances. 53 54 Examples: 55 56 ```python 57 cells = [ 58 keras.layers.LSTMCell(output_dim), 59 keras.layers.LSTMCell(output_dim), 60 keras.layers.LSTMCell(output_dim), 61 ] 62 63 inputs = keras.Input((timesteps, input_dim)) 64 x = keras.layers.RNN(cells)(inputs) 65 ``` 66 """ 67 68 def __init__(self, cells, **kwargs): 69 for cell in cells: 70 if not hasattr(cell, 'call'): 71 raise ValueError('All cells must have a `call` method. ' 72 'received cells:', cells) 73 if not hasattr(cell, 'state_size'): 74 raise ValueError('All cells must have a ' 75 '`state_size` attribute. ' 76 'received cells:', cells) 77 self.cells = cells 78 # reverse_state_order determines whether the state size will be in a reverse 79 # order of the cells' state. User might want to set this to True to keep the 80 # existing behavior. This is only useful when use RNN(return_state=True) 81 # since the state will be returned as the same order of state_size. 82 self.reverse_state_order = kwargs.pop('reverse_state_order', False) 83 if self.reverse_state_order: 84 logging.warning('reverse_state_order=True in StackedRNNCells will soon ' 85 'be deprecated. 
Please update the code to work with the ' 86 'natural order of states if you reply on the RNN states, ' 87 'eg RNN(return_state=True).') 88 super(StackedRNNCells, self).__init__(**kwargs) 89 90 @property 91 def state_size(self): 92 return tuple(c.state_size for c in 93 (self.cells[::-1] if self.reverse_state_order else self.cells)) 94 95 @property 96 def output_size(self): 97 if getattr(self.cells[-1], 'output_size', None) is not None: 98 return self.cells[-1].output_size 99 elif _is_multiple_state(self.cells[-1].state_size): 100 return self.cells[-1].state_size[0] 101 else: 102 return self.cells[-1].state_size 103 104 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 105 initial_states = [] 106 for cell in self.cells[::-1] if self.reverse_state_order else self.cells: 107 get_initial_state_fn = getattr(cell, 'get_initial_state', None) 108 if get_initial_state_fn: 109 initial_states.append(get_initial_state_fn( 110 inputs=inputs, batch_size=batch_size, dtype=dtype)) 111 else: 112 initial_states.append(_generate_zero_filled_state_for_cell( 113 cell, inputs, batch_size, dtype)) 114 115 return tuple(initial_states) 116 117 def call(self, inputs, states, constants=None, **kwargs): 118 # Recover per-cell states. 119 state_size = (self.state_size[::-1] 120 if self.reverse_state_order else self.state_size) 121 nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) 122 123 # Call the cells in order and store the returned states. 124 new_nested_states = [] 125 for cell, states in zip(self.cells, nested_states): 126 states = states if nest.is_sequence(states) else [states] 127 # TF cell does not wrap the state into list when there is only one state. 128 is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None 129 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 130 if generic_utils.has_arg(cell.call, 'constants'): 131 inputs, states = cell.call(inputs, states, constants=constants, 132 **kwargs) 133 else: 134 inputs, states = cell.call(inputs, states, **kwargs) 135 new_nested_states.append(states) 136 137 return inputs, nest.pack_sequence_as(state_size, 138 nest.flatten(new_nested_states)) 139 140 @tf_utils.shape_type_conversion 141 def build(self, input_shape): 142 if isinstance(input_shape, list): 143 constants_shape = input_shape[1:] 144 input_shape = input_shape[0] 145 for cell in self.cells: 146 if isinstance(cell, Layer): 147 if generic_utils.has_arg(cell.call, 'constants'): 148 cell.build([input_shape] + constants_shape) 149 else: 150 cell.build(input_shape) 151 if getattr(cell, 'output_size', None) is not None: 152 output_dim = cell.output_size 153 elif _is_multiple_state(cell.state_size): 154 output_dim = cell.state_size[0] 155 else: 156 output_dim = cell.state_size 157 input_shape = tuple([input_shape[0]] + 158 tensor_shape.as_shape(output_dim).as_list()) 159 self.built = True 160 161 def get_config(self): 162 cells = [] 163 for cell in self.cells: 164 cells.append({ 165 'class_name': cell.__class__.__name__, 166 'config': cell.get_config() 167 }) 168 config = {'cells': cells} 169 base_config = super(StackedRNNCells, self).get_config() 170 return dict(list(base_config.items()) + list(config.items())) 171 172 @classmethod 173 def from_config(cls, config, custom_objects=None): 174 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 175 cells = [] 176 for cell_config in config.pop('cells'): 177 cells.append( 178 deserialize_layer(cell_config, 
custom_objects=custom_objects))
    return cls(cells, **config)


@keras_export('keras.layers.RNN')
class RNN(Layer):
  """Base class for recurrent layers.

  Arguments:
    cell: An RNN cell instance or a list of RNN cell instances.
      An RNN cell is a class that has:
      - A `call(input_at_t, states_at_t)` method, returning
        `(output_at_t, states_at_t_plus_1)`. The call method of the
        cell can also take the optional argument `constants`, see
        section "Note on passing external constants" below.
      - A `state_size` attribute. This can be a single integer
        (single state) in which case it is the size of the recurrent
        state. This can also be a list/tuple of integers (one size per
        state).
        The `state_size` can also be a TensorShape or a tuple/list of
        TensorShape, to represent a high-dimensional state.
      - An `output_size` attribute. This can be a single integer or a
        TensorShape, which represents the shape of the output. For backwards
        compatibility, if this attribute is not available for the
        cell, the value will be inferred from the first element of the
        `state_size`.
      - A `get_initial_state(inputs=None, batch_size=None, dtype=None)`
        method that creates a tensor meant to be fed to `call()` as the
        initial state, if the user didn't specify any initial state via other
        means. The returned initial state should have shape
        [batch, cell.state_size]. The cell might choose to create a
        zero-filled tensor, or a tensor with other values, depending on the
        cell implementation.
        `inputs` is the input tensor to the RNN layer, which should
        contain the batch size as its shape[0], and also its dtype. Note that
        the shape[0] might be None during graph construction. Either
        the `inputs` or the pair of `batch_size` and `dtype` is provided.
        `batch_size` is a scalar tensor that represents the batch size
        of the input. `dtype` is a `tf.dtype` that represents the dtype of
        the input.
        For backwards compatibility, if this method is not implemented
        by the cell, the RNN layer will create a zero-filled tensor with the
        size of [batch, cell.state_size].
      In the case that `cell` is a list of RNN cell instances, the cells
      will be stacked one after the other in the RNN, implementing an
      efficient stacked RNN.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    unroll: Boolean (default False).
      If True, the network will be unrolled, else a symbolic loop will be
      used. Unrolling can speed up an RNN, although it tends to be more
      memory-intensive. Unrolling is only suitable for short sequences.
    time_major: The shape format of the `inputs` and `outputs` tensors.
      If True, the inputs and outputs will be in shape
      `(timesteps, batch, ...)`, whereas in the False case, it will be
      `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
      efficient because it avoids transposes at the beginning and end of the
      RNN calculation.
However, most TensorFlow data is batch-major, so by 244 default this function accepts input and emits output in batch-major 245 form. 246 247 Call arguments: 248 inputs: Input tensor. 249 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 250 a given timestep should be masked. 251 training: Python boolean indicating whether the layer should behave in 252 training mode or in inference mode. This argument is passed to the cell 253 when calling it. This is for use with cells that use dropout. 254 initial_state: List of initial state tensors to be passed to the first 255 call of the cell. 256 constants: List of constant tensors to be passed to the cell at each 257 timestep. 258 259 Input shape: 260 N-D tensor with shape `(batch_size, timesteps, ...)` or 261 `(timesteps, batch_size, ...)` when time_major is True. 262 263 Output shape: 264 - If `return_state`: a list of tensors. The first tensor is 265 the output. The remaining tensors are the last states, 266 each with shape `(batch_size, state_size)`, where `state_size` could 267 be a high dimension tensor shape. 268 - If `return_sequences`: N-D tensor with shape 269 `(batch_size, timesteps, output_size)`, where `output_size` could 270 be a high dimension tensor shape, or 271 `(timesteps, batch_size, output_size)` when `time_major` is True. 272 - Else, N-D tensor with shape `(batch_size, output_size)`, where 273 `output_size` could be a high dimension tensor shape. 274 275 Masking: 276 This layer supports masking for input data with a variable number 277 of timesteps. To introduce masks to your data, 278 use an [Embedding](embeddings.md) layer with the `mask_zero` parameter 279 set to `True`. 280 281 Note on using statefulness in RNNs: 282 You can set RNN layers to be 'stateful', which means that the states 283 computed for the samples in one batch will be reused as initial states 284 for the samples in the next batch. This assumes a one-to-one mapping 285 between samples in different successive batches. 286 287 To enable statefulness: 288 - Specify `stateful=True` in the layer constructor. 289 - Specify a fixed batch size for your model, by passing 290 If sequential model: 291 `batch_input_shape=(...)` to the first layer in your model. 292 Else for functional model with 1 or more Input layers: 293 `batch_shape=(...)` to all the first layers in your model. 294 This is the expected shape of your inputs 295 *including the batch size*. 296 It should be a tuple of integers, e.g. `(32, 10, 100)`. 297 - Specify `shuffle=False` when calling fit(). 298 299 To reset the states of your model, call `.reset_states()` on either 300 a specific layer, or on your entire model. 301 302 Note on specifying the initial state of RNNs: 303 You can specify the initial state of RNN layers symbolically by 304 calling them with the keyword argument `initial_state`. The value of 305 `initial_state` should be a tensor or list of tensors representing 306 the initial state of the RNN layer. 307 308 You can specify the initial state of RNN layers numerically by 309 calling `reset_states` with the keyword argument `states`. The value of 310 `states` should be a numpy array or list of numpy arrays representing 311 the initial state of the RNN layer. 312 313 Note on passing external constants to RNNs: 314 You can pass "external" constants to the cell using the `constants` 315 keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This 316 requires that the `cell.call` method accepts the same keyword argument 317 `constants`. 
Such constants can be used to condition the cell 318 transformation on additional static inputs (not changing over time), 319 a.k.a. an attention mechanism. 320 321 Examples: 322 323 ```python 324 # First, let's define a RNN Cell, as a layer subclass. 325 326 class MinimalRNNCell(keras.layers.Layer): 327 328 def __init__(self, units, **kwargs): 329 self.units = units 330 self.state_size = units 331 super(MinimalRNNCell, self).__init__(**kwargs) 332 333 def build(self, input_shape): 334 self.kernel = self.add_weight(shape=(input_shape[-1], self.units), 335 initializer='uniform', 336 name='kernel') 337 self.recurrent_kernel = self.add_weight( 338 shape=(self.units, self.units), 339 initializer='uniform', 340 name='recurrent_kernel') 341 self.built = True 342 343 def call(self, inputs, states): 344 prev_output = states[0] 345 h = K.dot(inputs, self.kernel) 346 output = h + K.dot(prev_output, self.recurrent_kernel) 347 return output, [output] 348 349 # Let's use this cell in a RNN layer: 350 351 cell = MinimalRNNCell(32) 352 x = keras.Input((None, 5)) 353 layer = RNN(cell) 354 y = layer(x) 355 356 # Here's how to use the cell to build a stacked RNN: 357 358 cells = [MinimalRNNCell(32), MinimalRNNCell(64)] 359 x = keras.Input((None, 5)) 360 layer = RNN(cells) 361 y = layer(x) 362 ``` 363 """ 364 365 def __init__(self, 366 cell, 367 return_sequences=False, 368 return_state=False, 369 go_backwards=False, 370 stateful=False, 371 unroll=False, 372 time_major=False, 373 **kwargs): 374 if isinstance(cell, (list, tuple)): 375 cell = StackedRNNCells(cell) 376 if not hasattr(cell, 'call'): 377 raise ValueError('`cell` should have a `call` method. ' 378 'The RNN was passed:', cell) 379 if not hasattr(cell, 'state_size'): 380 raise ValueError('The RNN cell should have ' 381 'an attribute `state_size` ' 382 '(tuple of integers, ' 383 'one integer per RNN state).') 384 # If True, the output for masked timestep will be zeros, whereas in the 385 # False case, output from previous timestep is returned for masked timestep. 386 self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False) 387 388 if 'input_shape' not in kwargs and ( 389 'input_dim' in kwargs or 'input_length' in kwargs): 390 input_shape = (kwargs.pop('input_length', None), 391 kwargs.pop('input_dim', None)) 392 kwargs['input_shape'] = input_shape 393 394 super(RNN, self).__init__(**kwargs) 395 self.cell = cell 396 self.return_sequences = return_sequences 397 self.return_state = return_state 398 self.go_backwards = go_backwards 399 self.stateful = stateful 400 self.unroll = unroll 401 self.time_major = time_major 402 403 self.supports_masking = True 404 # The input shape is unknown yet, it could have nested tensor inputs, and 405 # the input spec will be the list of specs for flattened inputs. 406 self.input_spec = None 407 self.state_spec = None 408 self._states = None 409 self.constants_spec = None 410 self._num_constants = None 411 self._num_inputs = None 412 413 @property 414 def states(self): 415 if self._states is None: 416 state = nest.map_structure(lambda _: None, self.cell.state_size) 417 return state if nest.is_sequence(self.cell.state_size) else [state] 418 return self._states 419 420 @states.setter 421 # Automatic tracking catches "self._states" which adds an extra weight and 422 # breaks HDF5 checkpoints. 
423 @trackable.no_automatic_dependency_tracking 424 def states(self, states): 425 self._states = states 426 427 def compute_output_shape(self, input_shape): 428 if isinstance(input_shape, list): 429 input_shape = input_shape[0] 430 # Check whether the input shape contains any nested shapes. It could be 431 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy 432 # inputs. 433 try: 434 input_shape = tensor_shape.as_shape(input_shape) 435 except (ValueError, TypeError): 436 # A nested tensor input 437 input_shape = nest.flatten(input_shape)[0] 438 439 batch = input_shape[0] 440 time_step = input_shape[1] 441 if self.time_major: 442 batch, time_step = time_step, batch 443 444 if _is_multiple_state(self.cell.state_size): 445 state_size = self.cell.state_size 446 else: 447 state_size = [self.cell.state_size] 448 449 def _get_output_shape(flat_output_size): 450 output_dim = tensor_shape.as_shape(flat_output_size).as_list() 451 if self.return_sequences: 452 if self.time_major: 453 output_shape = tensor_shape.as_shape([time_step, batch] + output_dim) 454 else: 455 output_shape = tensor_shape.as_shape([batch, time_step] + output_dim) 456 else: 457 output_shape = tensor_shape.as_shape([batch] + output_dim) 458 return output_shape 459 460 if getattr(self.cell, 'output_size', None) is not None: 461 # cell.output_size could be nested structure. 462 output_shape = nest.flatten(nest.map_structure( 463 _get_output_shape, self.cell.output_size)) 464 output_shape = output_shape[0] if len(output_shape) == 1 else output_shape 465 else: 466 # Note that state_size[0] could be a tensor_shape or int. 467 output_shape = _get_output_shape(state_size[0]) 468 469 if self.return_state: 470 def _get_state_shape(flat_state): 471 state_shape = [batch] + tensor_shape.as_shape(flat_state).as_list() 472 return tensor_shape.as_shape(state_shape) 473 state_shape = nest.map_structure(_get_state_shape, state_size) 474 return generic_utils.to_list(output_shape) + nest.flatten(state_shape) 475 else: 476 return output_shape 477 478 def compute_mask(self, inputs, mask): 479 # Time step masks must be the same for each input. 480 # This is because the mask for an RNN is of size [batch, time_steps, 1], 481 # and specifies which time steps should be skipped, and a time step 482 # must be skipped for all inputs. 483 # TODO(scottzhu): Should we accept multiple different masks? 484 mask = nest.flatten(mask)[0] 485 output_mask = mask if self.return_sequences else None 486 if self.return_state: 487 state_mask = [None for _ in self.states] 488 return [output_mask] + state_mask 489 else: 490 return output_mask 491 492 def build(self, input_shape): 493 # Note input_shape will be list of shapes of initial states and 494 # constants if these are passed in __call__. 495 if self._num_constants is not None: 496 constants_shape = input_shape[-self._num_constants:] # pylint: disable=invalid-unary-operand-type 497 constants_shape = nest.map_structure( 498 lambda s: tuple(tensor_shape.TensorShape(s).as_list()), 499 constants_shape) 500 else: 501 constants_shape = None 502 503 if isinstance(input_shape, list): 504 input_shape = input_shape[0] 505 # The input_shape here could be a nest structure. 506 507 # do the tensor_shape to shapes here. The input could be single tensor, or a 508 # nested structure of tensors. 
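    # The two helpers below normalize each (possibly nested) input shape:
    # `get_input_spec` converts a shape into an InputSpec with the batch
    # dimension (when not stateful) and the time dimension relaxed to None,
    # while `get_step_input_shape` drops the time dimension to produce the
    # per-timestep shape that is passed to the cell.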
509 def get_input_spec(shape): 510 if isinstance(shape, tensor_shape.TensorShape): 511 input_spec_shape = shape.as_list() 512 else: 513 input_spec_shape = list(shape) 514 batch_index, time_step_index = (1, 0) if self.time_major else (0, 1) 515 if not self.stateful: 516 input_spec_shape[batch_index] = None 517 input_spec_shape[time_step_index] = None 518 return InputSpec(shape=tuple(input_spec_shape)) 519 520 def get_step_input_shape(shape): 521 if isinstance(shape, tensor_shape.TensorShape): 522 shape = tuple(shape.as_list()) 523 # remove the timestep from the input_shape 524 return shape[1:] if self.time_major else (shape[0],) + shape[2:] 525 526 # Check whether the input shape contains any nested shapes. It could be 527 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy 528 # inputs. 529 try: 530 input_shape = tensor_shape.as_shape(input_shape) 531 except (ValueError, TypeError): 532 # A nested tensor input 533 pass 534 535 if not nest.is_sequence(input_shape): 536 # This indicates the there is only one input. 537 if self.input_spec is not None: 538 self.input_spec[0] = get_input_spec(input_shape) 539 else: 540 self.input_spec = [get_input_spec(input_shape)] 541 step_input_shape = get_step_input_shape(input_shape) 542 else: 543 flat_input_shapes = nest.flatten(input_shape) 544 flat_input_shapes = nest.map_structure(get_input_spec, flat_input_shapes) 545 assert len(flat_input_shapes) == self._num_inputs 546 if self.input_spec is not None: 547 self.input_spec[:self._num_inputs] = flat_input_shapes 548 else: 549 self.input_spec = flat_input_shapes 550 step_input_shape = nest.map_structure(get_step_input_shape, input_shape) 551 552 # allow cell (if layer) to build before we set or validate state_spec 553 if isinstance(self.cell, Layer): 554 if constants_shape is not None: 555 self.cell.build([step_input_shape] + constants_shape) 556 else: 557 self.cell.build(step_input_shape) 558 559 # set or validate state_spec 560 if _is_multiple_state(self.cell.state_size): 561 state_size = list(self.cell.state_size) 562 else: 563 state_size = [self.cell.state_size] 564 565 if self.state_spec is not None: 566 # initial_state was passed in call, check compatibility 567 self._validate_state_spec(state_size, self.state_spec) 568 else: 569 self.state_spec = [ 570 InputSpec(shape=[None] + tensor_shape.as_shape(dim).as_list()) 571 for dim in state_size 572 ] 573 if self.stateful: 574 self.reset_states() 575 self.built = True 576 577 @staticmethod 578 def _validate_state_spec(cell_state_sizes, init_state_specs): 579 """Validate the state spec between the initial_state and the state_size. 580 581 Args: 582 cell_state_sizes: list, the `state_size` attribute from the cell. 583 init_state_specs: list, the `state_spec` from the initial_state that is 584 passed in `call()`. 585 586 Raises: 587 ValueError: When initial state spec is not compatible with the state size. 588 """ 589 validation_error = ValueError( 590 'An `initial_state` was passed that is not compatible with ' 591 '`cell.state_size`. 
Received `state_spec`={}; ' 592 'however `cell.state_size` is ' 593 '{}'.format(init_state_specs, cell_state_sizes)) 594 if len(cell_state_sizes) == len(init_state_specs): 595 for i in range(len(cell_state_sizes)): 596 if not tensor_shape.TensorShape( 597 # Ignore the first axis for init_state which is for batch 598 init_state_specs[i].shape[1:]).is_compatible_with( 599 tensor_shape.TensorShape(cell_state_sizes[i])): 600 raise validation_error 601 else: 602 raise validation_error 603 604 def get_initial_state(self, inputs): 605 get_initial_state_fn = getattr(self.cell, 'get_initial_state', None) 606 607 if nest.is_sequence(inputs): 608 # The input are nested sequences. Use the first element in the seq to get 609 # batch size and dtype. 610 inputs = nest.flatten(inputs)[0] 611 612 input_shape = array_ops.shape(inputs) 613 batch_size = input_shape[1] if self.time_major else input_shape[0] 614 dtype = inputs.dtype 615 if get_initial_state_fn: 616 init_state = get_initial_state_fn( 617 inputs=None, batch_size=batch_size, dtype=dtype) 618 else: 619 init_state = _generate_zero_filled_state(batch_size, self.cell.state_size, 620 dtype) 621 # Keras RNN expect the states in a list, even if it's a single state tensor. 622 if not nest.is_sequence(init_state): 623 init_state = [init_state] 624 # Force the state to be a list in case it is a namedtuple eg LSTMStateTuple. 625 return list(init_state) 626 627 def __call__(self, inputs, initial_state=None, constants=None, **kwargs): 628 inputs, initial_state, constants = _standardize_args(inputs, 629 initial_state, 630 constants, 631 self._num_constants, 632 self._num_inputs) 633 # in case the real inputs is a nested structure, set the size of flatten 634 # input so that we can distinguish between real inputs, initial_state and 635 # constants. 636 self._num_inputs = len(nest.flatten(inputs)) 637 638 if initial_state is None and constants is None: 639 return super(RNN, self).__call__(inputs, **kwargs) 640 641 # If any of `initial_state` or `constants` are specified and are Keras 642 # tensors, then add them to the inputs and temporarily modify the 643 # input_spec to include them. 644 645 additional_inputs = [] 646 additional_specs = [] 647 if initial_state is not None: 648 additional_inputs += initial_state 649 self.state_spec = [ 650 InputSpec(shape=K.int_shape(state)) for state in initial_state 651 ] 652 additional_specs += self.state_spec 653 if constants is not None: 654 additional_inputs += constants 655 self.constants_spec = [ 656 InputSpec(shape=K.int_shape(constant)) for constant in constants 657 ] 658 self._num_constants = len(constants) 659 additional_specs += self.constants_spec 660 # at this point additional_inputs cannot be empty 661 is_keras_tensor = K.is_keras_tensor(additional_inputs[0]) 662 for tensor in additional_inputs: 663 if K.is_keras_tensor(tensor) != is_keras_tensor: 664 raise ValueError('The initial state or constants of an RNN' 665 ' layer cannot be specified with a mix of' 666 ' Keras tensors and non-Keras tensors' 667 ' (a "Keras tensor" is a tensor that was' 668 ' returned by a Keras layer, or by `Input`)') 669 670 if is_keras_tensor: 671 # Compute the full input spec, including state and constants 672 full_input = [inputs] + additional_inputs 673 # The original input_spec is None since there could be a nested tensor 674 # input. Update the input_spec to match the inputs. 
675 full_input_spec = [None for _ in range(len(nest.flatten(inputs))) 676 ] + additional_specs 677 # Perform the call with temporarily replaced input_spec 678 self.input_spec = full_input_spec 679 output = super(RNN, self).__call__(full_input, **kwargs) 680 # Remove the additional_specs from input spec and keep the rest. It is 681 # important to keep since the input spec was populated by build(), and 682 # will be reused in the stateful=True. 683 self.input_spec = self.input_spec[:-len(additional_specs)] 684 return output 685 else: 686 if initial_state is not None: 687 kwargs['initial_state'] = initial_state 688 if constants is not None: 689 kwargs['constants'] = constants 690 return super(RNN, self).__call__(inputs, **kwargs) 691 692 def call(self, 693 inputs, 694 mask=None, 695 training=None, 696 initial_state=None, 697 constants=None): 698 inputs, initial_state, constants = self._process_inputs( 699 inputs, initial_state, constants) 700 701 if mask is not None: 702 # Time step masks must be the same for each input. 703 # TODO(scottzhu): Should we accept multiple different masks? 704 mask = nest.flatten(mask)[0] 705 706 if nest.is_sequence(inputs): 707 # In the case of nested input, use the first element for shape check. 708 input_shape = K.int_shape(nest.flatten(inputs)[0]) 709 else: 710 input_shape = K.int_shape(inputs) 711 timesteps = input_shape[0] if self.time_major else input_shape[1] 712 if self.unroll and timesteps is None: 713 raise ValueError('Cannot unroll a RNN if the ' 714 'time dimension is undefined. \n' 715 '- If using a Sequential model, ' 716 'specify the time dimension by passing ' 717 'an `input_shape` or `batch_input_shape` ' 718 'argument to your first layer. If your ' 719 'first layer is an Embedding, you can ' 720 'also use the `input_length` argument.\n' 721 '- If using the functional API, specify ' 722 'the time dimension by passing a `shape` ' 723 'or `batch_shape` argument to your Input layer.') 724 725 kwargs = {} 726 if generic_utils.has_arg(self.cell.call, 'training'): 727 kwargs['training'] = training 728 729 # TF RNN cells expect single tensor as state instead of list wrapped tensor. 
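    # For illustration (an assumption about the wrapped cell, not checked
    # here): a Keras-style cell with a single state receives `states` as a
    # one-element list `[h]`, while a `tf.nn.rnn_cell`-style cell expects the
    # bare tensor `h`, so the single-element list is unwrapped below.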
730 is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None 731 if constants: 732 if not generic_utils.has_arg(self.cell.call, 'constants'): 733 raise ValueError('RNN cell does not support constants') 734 735 def step(inputs, states): 736 constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type 737 states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type 738 739 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 740 output, new_states = self.cell.call( 741 inputs, states, constants=constants, **kwargs) 742 if not nest.is_sequence(new_states): 743 new_states = [new_states] 744 return output, new_states 745 else: 746 747 def step(inputs, states): 748 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 749 output, new_states = self.cell.call(inputs, states, **kwargs) 750 if not nest.is_sequence(new_states): 751 new_states = [new_states] 752 return output, new_states 753 754 last_output, outputs, states = K.rnn( 755 step, 756 inputs, 757 initial_state, 758 constants=constants, 759 go_backwards=self.go_backwards, 760 mask=mask, 761 unroll=self.unroll, 762 input_length=timesteps, 763 time_major=self.time_major, 764 zero_output_for_mask=self.zero_output_for_mask) 765 if self.stateful: 766 updates = [] 767 for i in range(len(states)): 768 updates.append(state_ops.assign(self.states[i], states[i])) 769 self.add_update(updates, inputs) 770 771 if self.return_sequences: 772 output = outputs 773 else: 774 output = last_output 775 776 if self.return_state: 777 if not isinstance(states, (list, tuple)): 778 states = [states] 779 else: 780 states = list(states) 781 return generic_utils.to_list(output) + states 782 else: 783 return output 784 785 def _process_inputs(self, inputs, initial_state, constants): 786 # input shape: `(samples, time (padded with zeros), input_dim)` 787 # note that the .build() method of subclasses MUST define 788 # self.input_spec and self.state_spec with complete input shapes. 789 if (isinstance(inputs, collections.Sequence) 790 and not isinstance(inputs, tuple)): 791 # get initial_state from full input spec 792 # as they could be copied to multiple GPU. 793 if self._num_constants is None: 794 initial_state = inputs[1:] 795 else: 796 initial_state = inputs[1:-self._num_constants] 797 constants = inputs[-self._num_constants:] 798 if len(initial_state) == 0: 799 initial_state = None 800 inputs = inputs[0] 801 if initial_state is not None: 802 pass 803 elif self.stateful: 804 initial_state = self.states 805 else: 806 initial_state = self.get_initial_state(inputs) 807 808 if len(initial_state) != len(self.states): 809 raise ValueError('Layer has ' + str(len(self.states)) + 810 ' states but was passed ' + str(len(initial_state)) + 811 ' initial states.') 812 return inputs, initial_state, constants 813 814 def reset_states(self, states=None): 815 if not self.stateful: 816 raise AttributeError('Layer must be stateful.') 817 spec_shape = None if self.input_spec is None else self.input_spec[0].shape 818 if spec_shape is None: 819 # It is possible to have spec shape to be None, eg when construct a RNN 820 # with a custom cell, or standard RNN layers (LSTM/GRU) which we only know 821 # it has 3 dim input, but not its full shape spec before build(). 822 batch_size = None 823 else: 824 batch_size = spec_shape[1] if self.time_major else spec_shape[0] 825 if not batch_size: 826 raise ValueError('If a RNN is stateful, it needs to know ' 827 'its batch size. 
Specify the batch size ' 828 'of your input tensors: \n' 829 '- If using a Sequential model, ' 830 'specify the batch size by passing ' 831 'a `batch_input_shape` ' 832 'argument to your first layer.\n' 833 '- If using the functional API, specify ' 834 'the batch size by passing a ' 835 '`batch_shape` argument to your Input layer.') 836 # initialize state if None 837 if self.states[0] is None: 838 if _is_multiple_state(self.cell.state_size): 839 self.states = [ 840 K.zeros([batch_size] + tensor_shape.as_shape(dim).as_list()) 841 for dim in self.cell.state_size 842 ] 843 else: 844 self.states = [ 845 K.zeros([batch_size] + 846 tensor_shape.as_shape(self.cell.state_size).as_list()) 847 ] 848 elif states is None: 849 if _is_multiple_state(self.cell.state_size): 850 for state, dim in zip(self.states, self.cell.state_size): 851 K.set_value(state, 852 np.zeros([batch_size] + 853 tensor_shape.as_shape(dim).as_list())) 854 else: 855 K.set_value(self.states[0], np.zeros( 856 [batch_size] + 857 tensor_shape.as_shape(self.cell.state_size).as_list())) 858 else: 859 if not isinstance(states, (list, tuple)): 860 states = [states] 861 if len(states) != len(self.states): 862 raise ValueError('Layer ' + self.name + ' expects ' + 863 str(len(self.states)) + ' states, ' 864 'but it received ' + str(len(states)) + 865 ' state values. Input received: ' + str(states)) 866 for index, (value, state) in enumerate(zip(states, self.states)): 867 if _is_multiple_state(self.cell.state_size): 868 dim = self.cell.state_size[index] 869 else: 870 dim = self.cell.state_size 871 if value.shape != tuple([batch_size] + 872 tensor_shape.as_shape(dim).as_list()): 873 raise ValueError( 874 'State ' + str(index) + ' is incompatible with layer ' + 875 self.name + ': expected shape=' + str( 876 (batch_size, dim)) + ', found shape=' + str(value.shape)) 877 # TODO(fchollet): consider batch calls to `set_value`. 878 K.set_value(state, value) 879 880 def get_config(self): 881 config = { 882 'return_sequences': self.return_sequences, 883 'return_state': self.return_state, 884 'go_backwards': self.go_backwards, 885 'stateful': self.stateful, 886 'unroll': self.unroll, 887 'time_major': self.time_major 888 } 889 if self._num_constants is not None: 890 config['num_constants'] = self._num_constants 891 if self.zero_output_for_mask: 892 config['zero_output_for_mask'] = self.zero_output_for_mask 893 894 cell_config = self.cell.get_config() 895 config['cell'] = { 896 'class_name': self.cell.__class__.__name__, 897 'config': cell_config 898 } 899 base_config = super(RNN, self).get_config() 900 return dict(list(base_config.items()) + list(config.items())) 901 902 @classmethod 903 def from_config(cls, config, custom_objects=None): 904 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 905 cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects) 906 num_constants = config.pop('num_constants', None) 907 layer = cls(cell, **config) 908 layer._num_constants = num_constants 909 return layer 910 911 912@keras_export('keras.layers.AbstractRNNCell') 913class AbstractRNNCell(Layer): 914 """Abstract object representing an RNN cell. 915 916 This is the base class for implementing RNN cells with custom behavior. 917 918 Every `RNNCell` must have the properties below and implement `call` with 919 the signature `(output, next_state) = call(input, state)`. 

  Examples:

  ```python
  class MinimalRNNCell(AbstractRNNCell):

    def __init__(self, units, **kwargs):
      self.units = units
      super(MinimalRNNCell, self).__init__(**kwargs)

    @property
    def state_size(self):
      return self.units

    def build(self, input_shape):
      self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                    initializer='uniform',
                                    name='kernel')
      self.recurrent_kernel = self.add_weight(
          shape=(self.units, self.units),
          initializer='uniform',
          name='recurrent_kernel')
      self.built = True

    def call(self, inputs, states):
      prev_output = states[0]
      h = K.dot(inputs, self.kernel)
      output = h + K.dot(prev_output, self.recurrent_kernel)
      return output, output
  ```

  This definition of cell differs from the definition used in the literature.
  In the literature, 'cell' refers to an object with a single scalar output.
  This definition refers to a horizontal array of such units.

  An RNN cell, in the most abstract setting, is anything that has
  a state and performs some operation that takes a matrix of inputs.
  This operation results in an output matrix with `self.output_size` columns.
  If `self.state_size` is an integer, this operation also results in a new
  state matrix with `self.state_size` columns. If `self.state_size` is a
  (possibly nested tuple of) TensorShape object(s), then it should return a
  matching structure of Tensors having shape `[batch_size].concatenate(s)`
  for each `s` in `self.state_size`.
  """

  def call(self, inputs, states):
    """The function that contains the logic for one RNN step calculation.

    Args:
      inputs: the input tensor, which is a slice from the overall RNN input by
        the time dimension (usually the second dimension).
      states: the state tensor from the previous step, which has the same shape
        as `(batch, state_size)`. In the case of timestep 0, it will be the
        initial state the user specified, or a zero-filled tensor otherwise.

    Returns:
      A tuple of two tensors:
        1. output tensor for the current timestep, with size `output_size`.
        2. state tensor for the next step, which has the shape of `state_size`.
    """
    raise NotImplementedError('Abstract method')

  @property
  def state_size(self):
    """Size(s) of state(s) used by this cell.

    It can be represented by an Integer, a TensorShape or a tuple of Integers
    or TensorShapes.
    """
    raise NotImplementedError('Abstract method')

  @property
  def output_size(self):
    """Integer or TensorShape: size of outputs produced by this cell."""
    raise NotImplementedError('Abstract method')

  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)


class DropoutRNNCellMixin(object):
  """Object that holds dropout-related fields for an RNN cell.

  This class is not a standalone RNN cell. It is supposed to be used with an
  RNN cell by multiple inheritance. Any cell that mixes in this class should
  have the following fields:
    dropout: a float number within range [0, 1). The fraction of the input
      tensor to drop out.
    recurrent_dropout: a float number within range [0, 1). The fraction of the
      recurrent state weights to drop out.
  This object will create and cache dropout masks, and reuse them for
  the incoming data, so that the same mask is used for every batch input.
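
  Example (a minimal sketch of a cell that uses this mixin; `MyRNNCell` and
  `my_step_fn` below are illustrative and not part of this module):

  ```python
  class MyRNNCell(DropoutRNNCellMixin, keras.layers.Layer):

    def __init__(self, units, dropout=0., recurrent_dropout=0., **kwargs):
      self.units = units
      self.state_size = units
      self.dropout = dropout
      self.recurrent_dropout = recurrent_dropout
      super(MyRNNCell, self).__init__(**kwargs)

    def call(self, inputs, states, training=None):
      prev_output = states[0]
      # Masks are created once per batch and reused for every timestep.
      dp_mask = self.get_dropout_mask_for_cell(inputs, training)
      rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
          prev_output, training)
      if dp_mask is not None:
        inputs = inputs * dp_mask
      if rec_dp_mask is not None:
        prev_output = prev_output * rec_dp_mask
      output = my_step_fn(inputs, prev_output)  # hypothetical step function
      return output, [output]
  ```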
  """

  def __init__(self, *args, **kwargs):
    # Note that the following two masks will be used in "graph function" mode,
    # e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
    # tensors will be generated differently than in the "graph function" case,
    # and they will be cached.
    # Also note that in graph mode, we still cache those masks only because the
    # RNN could be created with `unroll=True`. In that case, the `cell.call()`
    # function will be invoked multiple times, and we want to ensure the same
    # mask is used every time.
    self._dropout_mask = None
    self._recurrent_dropout_mask = None
    self._eager_dropout_mask = None
    self._eager_recurrent_dropout_mask = None
    super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)

  def reset_dropout_mask(self):
    """Reset the cached dropout masks, if any.

    It is important for the RNN layer to invoke this in its call() method so
    that the cached mask is cleared before calling the cell.call(). The mask
    should be cached across all timesteps within the same batch, but shouldn't
    be cached between batches. Otherwise it will introduce unreasonable bias
    against certain indices of data within the batch.
    """
    self._dropout_mask = None
    self._eager_dropout_mask = None

  def reset_recurrent_dropout_mask(self):
    """Reset the cached recurrent dropout masks, if any.

    It is important for the RNN layer to invoke this in its call() method so
    that the cached mask is cleared before calling the cell.call(). The mask
    should be cached across all timesteps within the same batch, but shouldn't
    be cached between batches. Otherwise it will introduce unreasonable bias
    against certain indices of data within the batch.
    """
    self._recurrent_dropout_mask = None
    self._eager_recurrent_dropout_mask = None

  def get_dropout_mask_for_cell(self, inputs, training, count=1):
    """Get the dropout mask for the RNN cell's input.

    It will create a mask based on context if there isn't an existing cached
    mask. If a new mask is generated, it will update the cache in the cell.

    Args:
      inputs: the input tensor whose shape will be used to generate the
        dropout mask.
      training: boolean tensor, whether in training mode; dropout will be
        ignored in non-training mode.
      count: int, how many dropout masks will be generated. Useful for cells
        that have internal weights fused together.
    Returns:
      List of mask tensors, generated or cached based on context.
    """
    if self.dropout == 0:
      return None
    if (not context.executing_eagerly() and self._dropout_mask is None
        or context.executing_eagerly() and self._eager_dropout_mask is None):
      # Generate new mask and cache it based on context.
      dp_mask = _generate_dropout_mask(
          array_ops.ones_like(inputs),
          self.dropout,
          training=training,
          count=count)
      if context.executing_eagerly():
        self._eager_dropout_mask = dp_mask
      else:
        self._dropout_mask = dp_mask
    else:
      # Reuse the existing mask.
      dp_mask = (self._eager_dropout_mask
                 if context.executing_eagerly() else self._dropout_mask)
    return dp_mask

  def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
    """Get the recurrent dropout mask for the RNN cell.

    It will create a mask based on context if there isn't an existing cached
    mask. If a new mask is generated, it will update the cache in the cell.

    Args:
      inputs: the input tensor whose shape will be used to generate the
        dropout mask.
      training: boolean tensor, whether in training mode; dropout will be
        ignored in non-training mode.
      count: int, how many dropout masks will be generated. Useful for cells
        that have internal weights fused together.
    Returns:
      List of mask tensors, generated or cached based on context.
    """
    if self.recurrent_dropout == 0:
      return None
    if (not context.executing_eagerly() and self._recurrent_dropout_mask is None
        or context.executing_eagerly()
        and self._eager_recurrent_dropout_mask is None):
      # Generate new mask and cache it based on context.
      rec_dp_mask = _generate_dropout_mask(
          array_ops.ones_like(inputs),
          self.recurrent_dropout,
          training=training,
          count=count)
      if context.executing_eagerly():
        self._eager_recurrent_dropout_mask = rec_dp_mask
      else:
        self._recurrent_dropout_mask = rec_dp_mask
    else:
      # Reuse the existing mask.
      rec_dp_mask = (self._eager_recurrent_dropout_mask
                     if context.executing_eagerly()
                     else self._recurrent_dropout_mask)
    return rec_dp_mask


@keras_export('keras.layers.SimpleRNNCell')
class SimpleRNNCell(DropoutRNNCellMixin, Layer):
  """Cell class for SimpleRNN.

  Arguments:
    units: Positive integer, dimensionality of the output space.
    activation: Activation function to use.
      Default: hyperbolic tangent (`tanh`).
      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix, used for the linear transformation of the recurrent
      state.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 2D tensor.
    states: List of state tensors corresponding to the previous timestep.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. Only relevant when `dropout` or
      `recurrent_dropout` is used.
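
  Example (a brief usage sketch; the input feature size of 5 is arbitrary):

  ```python
  cell = keras.layers.SimpleRNNCell(32)
  x = keras.Input((None, 5))  # (batch, timesteps, features)
  layer = keras.layers.RNN(cell)
  y = layer(x)  # y has shape (batch, 32)
  ```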
1167 """ 1168 1169 def __init__(self, 1170 units, 1171 activation='tanh', 1172 use_bias=True, 1173 kernel_initializer='glorot_uniform', 1174 recurrent_initializer='orthogonal', 1175 bias_initializer='zeros', 1176 kernel_regularizer=None, 1177 recurrent_regularizer=None, 1178 bias_regularizer=None, 1179 kernel_constraint=None, 1180 recurrent_constraint=None, 1181 bias_constraint=None, 1182 dropout=0., 1183 recurrent_dropout=0., 1184 **kwargs): 1185 super(SimpleRNNCell, self).__init__(**kwargs) 1186 self.units = units 1187 self.activation = activations.get(activation) 1188 self.use_bias = use_bias 1189 1190 self.kernel_initializer = initializers.get(kernel_initializer) 1191 self.recurrent_initializer = initializers.get(recurrent_initializer) 1192 self.bias_initializer = initializers.get(bias_initializer) 1193 1194 self.kernel_regularizer = regularizers.get(kernel_regularizer) 1195 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 1196 self.bias_regularizer = regularizers.get(bias_regularizer) 1197 1198 self.kernel_constraint = constraints.get(kernel_constraint) 1199 self.recurrent_constraint = constraints.get(recurrent_constraint) 1200 self.bias_constraint = constraints.get(bias_constraint) 1201 1202 self.dropout = min(1., max(0., dropout)) 1203 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 1204 self.state_size = self.units 1205 self.output_size = self.units 1206 1207 @tf_utils.shape_type_conversion 1208 def build(self, input_shape): 1209 self.kernel = self.add_weight( 1210 shape=(input_shape[-1], self.units), 1211 name='kernel', 1212 initializer=self.kernel_initializer, 1213 regularizer=self.kernel_regularizer, 1214 constraint=self.kernel_constraint) 1215 self.recurrent_kernel = self.add_weight( 1216 shape=(self.units, self.units), 1217 name='recurrent_kernel', 1218 initializer=self.recurrent_initializer, 1219 regularizer=self.recurrent_regularizer, 1220 constraint=self.recurrent_constraint) 1221 if self.use_bias: 1222 self.bias = self.add_weight( 1223 shape=(self.units,), 1224 name='bias', 1225 initializer=self.bias_initializer, 1226 regularizer=self.bias_regularizer, 1227 constraint=self.bias_constraint) 1228 else: 1229 self.bias = None 1230 self.built = True 1231 1232 def call(self, inputs, states, training=None): 1233 prev_output = states[0] 1234 dp_mask = self.get_dropout_mask_for_cell(inputs, training) 1235 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 1236 prev_output, training) 1237 1238 if dp_mask is not None: 1239 h = K.dot(inputs * dp_mask, self.kernel) 1240 else: 1241 h = K.dot(inputs, self.kernel) 1242 if self.bias is not None: 1243 h = K.bias_add(h, self.bias) 1244 1245 if rec_dp_mask is not None: 1246 prev_output *= rec_dp_mask 1247 output = h + K.dot(prev_output, self.recurrent_kernel) 1248 if self.activation is not None: 1249 output = self.activation(output) 1250 1251 return output, [output] 1252 1253 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 1254 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) 1255 1256 def get_config(self): 1257 config = { 1258 'units': 1259 self.units, 1260 'activation': 1261 activations.serialize(self.activation), 1262 'use_bias': 1263 self.use_bias, 1264 'kernel_initializer': 1265 initializers.serialize(self.kernel_initializer), 1266 'recurrent_initializer': 1267 initializers.serialize(self.recurrent_initializer), 1268 'bias_initializer': 1269 initializers.serialize(self.bias_initializer), 1270 'kernel_regularizer': 1271 
regularizers.serialize(self.kernel_regularizer), 1272 'recurrent_regularizer': 1273 regularizers.serialize(self.recurrent_regularizer), 1274 'bias_regularizer': 1275 regularizers.serialize(self.bias_regularizer), 1276 'kernel_constraint': 1277 constraints.serialize(self.kernel_constraint), 1278 'recurrent_constraint': 1279 constraints.serialize(self.recurrent_constraint), 1280 'bias_constraint': 1281 constraints.serialize(self.bias_constraint), 1282 'dropout': 1283 self.dropout, 1284 'recurrent_dropout': 1285 self.recurrent_dropout 1286 } 1287 base_config = super(SimpleRNNCell, self).get_config() 1288 return dict(list(base_config.items()) + list(config.items())) 1289 1290 1291@keras_export('keras.layers.SimpleRNN') 1292class SimpleRNN(RNN): 1293 """Fully-connected RNN where the output is to be fed back to input. 1294 1295 Arguments: 1296 units: Positive integer, dimensionality of the output space. 1297 activation: Activation function to use. 1298 Default: hyperbolic tangent (`tanh`). 1299 If you pass None, no activation is applied 1300 (ie. "linear" activation: `a(x) = x`). 1301 use_bias: Boolean, whether the layer uses a bias vector. 1302 kernel_initializer: Initializer for the `kernel` weights matrix, 1303 used for the linear transformation of the inputs. 1304 recurrent_initializer: Initializer for the `recurrent_kernel` 1305 weights matrix, 1306 used for the linear transformation of the recurrent state. 1307 bias_initializer: Initializer for the bias vector. 1308 kernel_regularizer: Regularizer function applied to 1309 the `kernel` weights matrix. 1310 recurrent_regularizer: Regularizer function applied to 1311 the `recurrent_kernel` weights matrix. 1312 bias_regularizer: Regularizer function applied to the bias vector. 1313 activity_regularizer: Regularizer function applied to 1314 the output of the layer (its "activation").. 1315 kernel_constraint: Constraint function applied to 1316 the `kernel` weights matrix. 1317 recurrent_constraint: Constraint function applied to 1318 the `recurrent_kernel` weights matrix. 1319 bias_constraint: Constraint function applied to the bias vector. 1320 dropout: Float between 0 and 1. 1321 Fraction of the units to drop for 1322 the linear transformation of the inputs. 1323 recurrent_dropout: Float between 0 and 1. 1324 Fraction of the units to drop for 1325 the linear transformation of the recurrent state. 1326 return_sequences: Boolean. Whether to return the last output 1327 in the output sequence, or the full sequence. 1328 return_state: Boolean. Whether to return the last state 1329 in addition to the output. 1330 go_backwards: Boolean (default False). 1331 If True, process the input sequence backwards and return the 1332 reversed sequence. 1333 stateful: Boolean (default False). If True, the last state 1334 for each sample at index i in a batch will be used as initial 1335 state for the sample of index i in the following batch. 1336 unroll: Boolean (default False). 1337 If True, the network will be unrolled, 1338 else a symbolic loop will be used. 1339 Unrolling can speed-up a RNN, 1340 although it tends to be more memory-intensive. 1341 Unrolling is only suitable for short sequences. 1342 1343 Call arguments: 1344 inputs: A 3D tensor. 1345 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 1346 a given timestep should be masked. 1347 training: Python boolean indicating whether the layer should behave in 1348 training mode or in inference mode. This argument is passed to the cell 1349 when calling it. 
This is only relevant if `dropout` or 1350 `recurrent_dropout` is used. 1351 initial_state: List of initial state tensors to be passed to the first 1352 call of the cell. 1353 """ 1354 1355 def __init__(self, 1356 units, 1357 activation='tanh', 1358 use_bias=True, 1359 kernel_initializer='glorot_uniform', 1360 recurrent_initializer='orthogonal', 1361 bias_initializer='zeros', 1362 kernel_regularizer=None, 1363 recurrent_regularizer=None, 1364 bias_regularizer=None, 1365 activity_regularizer=None, 1366 kernel_constraint=None, 1367 recurrent_constraint=None, 1368 bias_constraint=None, 1369 dropout=0., 1370 recurrent_dropout=0., 1371 return_sequences=False, 1372 return_state=False, 1373 go_backwards=False, 1374 stateful=False, 1375 unroll=False, 1376 **kwargs): 1377 if 'implementation' in kwargs: 1378 kwargs.pop('implementation') 1379 logging.warning('The `implementation` argument ' 1380 'in `SimpleRNN` has been deprecated. ' 1381 'Please remove it from your layer call.') 1382 cell = SimpleRNNCell( 1383 units, 1384 activation=activation, 1385 use_bias=use_bias, 1386 kernel_initializer=kernel_initializer, 1387 recurrent_initializer=recurrent_initializer, 1388 bias_initializer=bias_initializer, 1389 kernel_regularizer=kernel_regularizer, 1390 recurrent_regularizer=recurrent_regularizer, 1391 bias_regularizer=bias_regularizer, 1392 kernel_constraint=kernel_constraint, 1393 recurrent_constraint=recurrent_constraint, 1394 bias_constraint=bias_constraint, 1395 dropout=dropout, 1396 recurrent_dropout=recurrent_dropout) 1397 super(SimpleRNN, self).__init__( 1398 cell, 1399 return_sequences=return_sequences, 1400 return_state=return_state, 1401 go_backwards=go_backwards, 1402 stateful=stateful, 1403 unroll=unroll, 1404 **kwargs) 1405 self.activity_regularizer = regularizers.get(activity_regularizer) 1406 self.input_spec = [InputSpec(ndim=3)] 1407 1408 def call(self, inputs, mask=None, training=None, initial_state=None): 1409 self.cell.reset_dropout_mask() 1410 self.cell.reset_recurrent_dropout_mask() 1411 return super(SimpleRNN, self).call( 1412 inputs, mask=mask, training=training, initial_state=initial_state) 1413 1414 @property 1415 def units(self): 1416 return self.cell.units 1417 1418 @property 1419 def activation(self): 1420 return self.cell.activation 1421 1422 @property 1423 def use_bias(self): 1424 return self.cell.use_bias 1425 1426 @property 1427 def kernel_initializer(self): 1428 return self.cell.kernel_initializer 1429 1430 @property 1431 def recurrent_initializer(self): 1432 return self.cell.recurrent_initializer 1433 1434 @property 1435 def bias_initializer(self): 1436 return self.cell.bias_initializer 1437 1438 @property 1439 def kernel_regularizer(self): 1440 return self.cell.kernel_regularizer 1441 1442 @property 1443 def recurrent_regularizer(self): 1444 return self.cell.recurrent_regularizer 1445 1446 @property 1447 def bias_regularizer(self): 1448 return self.cell.bias_regularizer 1449 1450 @property 1451 def kernel_constraint(self): 1452 return self.cell.kernel_constraint 1453 1454 @property 1455 def recurrent_constraint(self): 1456 return self.cell.recurrent_constraint 1457 1458 @property 1459 def bias_constraint(self): 1460 return self.cell.bias_constraint 1461 1462 @property 1463 def dropout(self): 1464 return self.cell.dropout 1465 1466 @property 1467 def recurrent_dropout(self): 1468 return self.cell.recurrent_dropout 1469 1470 def get_config(self): 1471 config = { 1472 'units': 1473 self.units, 1474 'activation': 1475 activations.serialize(self.activation), 1476 'use_bias': 
1477 self.use_bias, 1478 'kernel_initializer': 1479 initializers.serialize(self.kernel_initializer), 1480 'recurrent_initializer': 1481 initializers.serialize(self.recurrent_initializer), 1482 'bias_initializer': 1483 initializers.serialize(self.bias_initializer), 1484 'kernel_regularizer': 1485 regularizers.serialize(self.kernel_regularizer), 1486 'recurrent_regularizer': 1487 regularizers.serialize(self.recurrent_regularizer), 1488 'bias_regularizer': 1489 regularizers.serialize(self.bias_regularizer), 1490 'activity_regularizer': 1491 regularizers.serialize(self.activity_regularizer), 1492 'kernel_constraint': 1493 constraints.serialize(self.kernel_constraint), 1494 'recurrent_constraint': 1495 constraints.serialize(self.recurrent_constraint), 1496 'bias_constraint': 1497 constraints.serialize(self.bias_constraint), 1498 'dropout': 1499 self.dropout, 1500 'recurrent_dropout': 1501 self.recurrent_dropout 1502 } 1503 base_config = super(SimpleRNN, self).get_config() 1504 del base_config['cell'] 1505 return dict(list(base_config.items()) + list(config.items())) 1506 1507 @classmethod 1508 def from_config(cls, config): 1509 if 'implementation' in config: 1510 config.pop('implementation') 1511 return cls(**config) 1512 1513 1514@keras_export('keras.layers.GRUCell') 1515class GRUCell(DropoutRNNCellMixin, Layer): 1516 """Cell class for the GRU layer. 1517 1518 Arguments: 1519 units: Positive integer, dimensionality of the output space. 1520 activation: Activation function to use. 1521 Default: hyperbolic tangent (`tanh`). 1522 If you pass None, no activation is applied 1523 (ie. "linear" activation: `a(x) = x`). 1524 recurrent_activation: Activation function to use 1525 for the recurrent step. 1526 Default: hard sigmoid (`hard_sigmoid`). 1527 If you pass `None`, no activation is applied 1528 (ie. "linear" activation: `a(x) = x`). 1529 use_bias: Boolean, whether the layer uses a bias vector. 1530 kernel_initializer: Initializer for the `kernel` weights matrix, 1531 used for the linear transformation of the inputs. 1532 recurrent_initializer: Initializer for the `recurrent_kernel` 1533 weights matrix, 1534 used for the linear transformation of the recurrent state. 1535 bias_initializer: Initializer for the bias vector. 1536 kernel_regularizer: Regularizer function applied to 1537 the `kernel` weights matrix. 1538 recurrent_regularizer: Regularizer function applied to 1539 the `recurrent_kernel` weights matrix. 1540 bias_regularizer: Regularizer function applied to the bias vector. 1541 kernel_constraint: Constraint function applied to 1542 the `kernel` weights matrix. 1543 recurrent_constraint: Constraint function applied to 1544 the `recurrent_kernel` weights matrix. 1545 bias_constraint: Constraint function applied to the bias vector. 1546 dropout: Float between 0 and 1. 1547 Fraction of the units to drop for the linear transformation of the inputs. 1548 recurrent_dropout: Float between 0 and 1. 1549 Fraction of the units to drop for 1550 the linear transformation of the recurrent state. 1551 implementation: Implementation mode, either 1 or 2. 1552 Mode 1 will structure its operations as a larger number of 1553 smaller dot products and additions, whereas mode 2 will 1554 batch them into fewer, larger operations. These modes will 1555 have different performance profiles on different hardware and 1556 for different applications. 1557 reset_after: GRU convention (whether to apply reset gate after or 1558 before matrix multiplication). 
False = "before" (default), 1559 True = "after" (CuDNN compatible). 1560 1561 Call arguments: 1562 inputs: A 2D tensor. 1563 states: List of state tensors corresponding to the previous timestep. 1564 training: Python boolean indicating whether the layer should behave in 1565 training mode or in inference mode. Only relevant when `dropout` or 1566 `recurrent_dropout` is used. 1567 """ 1568 1569 def __init__(self, 1570 units, 1571 activation='tanh', 1572 recurrent_activation='hard_sigmoid', 1573 use_bias=True, 1574 kernel_initializer='glorot_uniform', 1575 recurrent_initializer='orthogonal', 1576 bias_initializer='zeros', 1577 kernel_regularizer=None, 1578 recurrent_regularizer=None, 1579 bias_regularizer=None, 1580 kernel_constraint=None, 1581 recurrent_constraint=None, 1582 bias_constraint=None, 1583 dropout=0., 1584 recurrent_dropout=0., 1585 implementation=1, 1586 reset_after=False, 1587 **kwargs): 1588 super(GRUCell, self).__init__(**kwargs) 1589 self.units = units 1590 self.activation = activations.get(activation) 1591 self.recurrent_activation = activations.get(recurrent_activation) 1592 self.use_bias = use_bias 1593 1594 self.kernel_initializer = initializers.get(kernel_initializer) 1595 self.recurrent_initializer = initializers.get(recurrent_initializer) 1596 self.bias_initializer = initializers.get(bias_initializer) 1597 1598 self.kernel_regularizer = regularizers.get(kernel_regularizer) 1599 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 1600 self.bias_regularizer = regularizers.get(bias_regularizer) 1601 1602 self.kernel_constraint = constraints.get(kernel_constraint) 1603 self.recurrent_constraint = constraints.get(recurrent_constraint) 1604 self.bias_constraint = constraints.get(bias_constraint) 1605 1606 self.dropout = min(1., max(0., dropout)) 1607 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 1608 self.implementation = implementation 1609 self.reset_after = reset_after 1610 self.state_size = self.units 1611 self.output_size = self.units 1612 1613 @tf_utils.shape_type_conversion 1614 def build(self, input_shape): 1615 input_dim = input_shape[-1] 1616 self.kernel = self.add_weight( 1617 shape=(input_dim, self.units * 3), 1618 name='kernel', 1619 initializer=self.kernel_initializer, 1620 regularizer=self.kernel_regularizer, 1621 constraint=self.kernel_constraint) 1622 self.recurrent_kernel = self.add_weight( 1623 shape=(self.units, self.units * 3), 1624 name='recurrent_kernel', 1625 initializer=self.recurrent_initializer, 1626 regularizer=self.recurrent_regularizer, 1627 constraint=self.recurrent_constraint) 1628 1629 if self.use_bias: 1630 if not self.reset_after: 1631 bias_shape = (3 * self.units,) 1632 else: 1633 # separate biases for input and recurrent kernels 1634 # Note: the shape is intentionally different from CuDNNGRU biases 1635 # `(2 * 3 * self.units,)`, so that we can distinguish the classes 1636 # when loading and converting saved weights. 
        bias_shape = (2, 3 * self.units)
      self.bias = self.add_weight(shape=bias_shape,
                                  name='bias',
                                  initializer=self.bias_initializer,
                                  regularizer=self.bias_regularizer,
                                  constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory

    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=3)

    if self.use_bias:
      if not self.reset_after:
        input_bias, recurrent_bias = self.bias, None
      else:
        input_bias, recurrent_bias = array_ops.unstack(self.bias)

    if self.implementation == 1:
      if 0. < self.dropout < 1.:
        inputs_z = inputs * dp_mask[0]
        inputs_r = inputs * dp_mask[1]
        inputs_h = inputs * dp_mask[2]
      else:
        inputs_z = inputs
        inputs_r = inputs
        inputs_h = inputs

      x_z = K.dot(inputs_z, self.kernel[:, :self.units])
      x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2])
      x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:])

      if self.use_bias:
        x_z = K.bias_add(x_z, input_bias[:self.units])
        x_r = K.bias_add(x_r, input_bias[self.units: self.units * 2])
        x_h = K.bias_add(x_h, input_bias[self.units * 2:])

      if 0. < self.recurrent_dropout < 1.:
        h_tm1_z = h_tm1 * rec_dp_mask[0]
        h_tm1_r = h_tm1 * rec_dp_mask[1]
        h_tm1_h = h_tm1 * rec_dp_mask[2]
      else:
        h_tm1_z = h_tm1
        h_tm1_r = h_tm1
        h_tm1_h = h_tm1

      recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])
      recurrent_r = K.dot(h_tm1_r,
                          self.recurrent_kernel[:, self.units:self.units * 2])
      if self.reset_after and self.use_bias:
        recurrent_z = K.bias_add(recurrent_z, recurrent_bias[:self.units])
        recurrent_r = K.bias_add(recurrent_r,
                                 recurrent_bias[self.units:self.units * 2])

      z = self.recurrent_activation(x_z + recurrent_z)
      r = self.recurrent_activation(x_r + recurrent_r)

      # reset gate applied after/before matrix multiplication
      if self.reset_after:
        recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
        if self.use_bias:
          recurrent_h = K.bias_add(recurrent_h, recurrent_bias[self.units * 2:])
        recurrent_h = r * recurrent_h
      else:
        recurrent_h = K.dot(r * h_tm1_h,
                            self.recurrent_kernel[:, self.units * 2:])

      hh = self.activation(x_h + recurrent_h)
    else:
      if 0. < self.dropout < 1.:
        inputs *= dp_mask[0]

      # inputs projected by all gate matrices at once
      matrix_x = K.dot(inputs, self.kernel)
      if self.use_bias:
        # biases: bias_z_i, bias_r_i, bias_h_i
        matrix_x = K.bias_add(matrix_x, input_bias)

      x_z = matrix_x[:, :self.units]
      x_r = matrix_x[:, self.units: 2 * self.units]
      x_h = matrix_x[:, 2 * self.units:]

      if 0. < self.recurrent_dropout < 1.:
        h_tm1 *= rec_dp_mask[0]

      if self.reset_after:
        # hidden state projected by all gate matrices at once
        matrix_inner = K.dot(h_tm1, self.recurrent_kernel)
        if self.use_bias:
          matrix_inner = K.bias_add(matrix_inner, recurrent_bias)
      else:
        # hidden state projected separately for update/reset and new
        matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])

      recurrent_z = matrix_inner[:, :self.units]
      recurrent_r = matrix_inner[:, self.units:2 * self.units]

      z = self.recurrent_activation(x_z + recurrent_z)
      r = self.recurrent_activation(x_r + recurrent_r)

      if self.reset_after:
        recurrent_h = r * matrix_inner[:, 2 * self.units:]
      else:
        recurrent_h = K.dot(r * h_tm1,
                            self.recurrent_kernel[:, 2 * self.units:])

      hh = self.activation(x_h + recurrent_h)
    # previous and candidate state mixed by update gate
    h = z * h_tm1 + (1 - z) * hh
    return h, [h]

  def get_config(self):
    config = {
        'units': self.units,
        'activation': activations.serialize(self.activation),
        'recurrent_activation':
            activations.serialize(self.recurrent_activation),
        'use_bias': self.use_bias,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'recurrent_initializer':
            initializers.serialize(self.recurrent_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'recurrent_regularizer':
            regularizers.serialize(self.recurrent_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'recurrent_constraint':
            constraints.serialize(self.recurrent_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint),
        'dropout': self.dropout,
        'recurrent_dropout': self.recurrent_dropout,
        'implementation': self.implementation,
        'reset_after': self.reset_after
    }
    base_config = super(GRUCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)


@keras_export(v1=['keras.layers.GRU'])
class GRU(RNN):
  """Gated Recurrent Unit - Cho et al. 2014.

  There are two variants. The default one is based on 1406.1078v3 and
  has reset gate applied to hidden state before matrix multiplication. The
  other one is based on original 1406.1078v1 and has the order reversed.

  The second variant is compatible with CuDNNGRU (GPU-only) and allows
  inference on CPU. Thus it has separate biases for `kernel` and
  `recurrent_kernel`. To use this variant, set `reset_after=True` and
  `recurrent_activation='sigmoid'`.

  Arguments:
    units: Positive integer, dimensionality of the output space.
    activation: Activation function to use.
      Default: hyperbolic tangent (`tanh`).
      If you pass `None`, no activation is applied
      (ie. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
      Default: hard sigmoid (`hard_sigmoid`).
      If you pass `None`, no activation is applied
      (ie. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix, used for the linear transformation of the recurrent
      state.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.
    implementation: Implementation mode, either 1 or 2.
      Mode 1 will structure its operations as a larger number of
      smaller dot products and additions, whereas mode 2 will
      batch them into fewer, larger operations. These modes will
      have different performance profiles on different hardware and
      for different applications.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    unroll: Boolean (default False).
      If True, the network will be unrolled,
      else a symbolic loop will be used.
      Unrolling can speed up an RNN,
      although it tends to be more memory-intensive.
      Unrolling is only suitable for short sequences.
    reset_after: GRU convention (whether to apply reset gate after or
      before matrix multiplication). False = "before" (default),
      True = "after" (CuDNN compatible).

  Call arguments:
    inputs: A 3D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or
      `recurrent_dropout` is used.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
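
  Example:

  ```python
  # A minimal usage sketch; assumes `timesteps` and `input_dim` are defined.
  inputs = keras.Input((timesteps, input_dim))
  outputs = keras.layers.GRU(32, return_sequences=True)(inputs)
  ```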
1868 """ 1869 1870 def __init__(self, 1871 units, 1872 activation='tanh', 1873 recurrent_activation='hard_sigmoid', 1874 use_bias=True, 1875 kernel_initializer='glorot_uniform', 1876 recurrent_initializer='orthogonal', 1877 bias_initializer='zeros', 1878 kernel_regularizer=None, 1879 recurrent_regularizer=None, 1880 bias_regularizer=None, 1881 activity_regularizer=None, 1882 kernel_constraint=None, 1883 recurrent_constraint=None, 1884 bias_constraint=None, 1885 dropout=0., 1886 recurrent_dropout=0., 1887 implementation=1, 1888 return_sequences=False, 1889 return_state=False, 1890 go_backwards=False, 1891 stateful=False, 1892 unroll=False, 1893 reset_after=False, 1894 **kwargs): 1895 if implementation == 0: 1896 logging.warning('`implementation=0` has been deprecated, ' 1897 'and now defaults to `implementation=1`.' 1898 'Please update your layer call.') 1899 cell = GRUCell( 1900 units, 1901 activation=activation, 1902 recurrent_activation=recurrent_activation, 1903 use_bias=use_bias, 1904 kernel_initializer=kernel_initializer, 1905 recurrent_initializer=recurrent_initializer, 1906 bias_initializer=bias_initializer, 1907 kernel_regularizer=kernel_regularizer, 1908 recurrent_regularizer=recurrent_regularizer, 1909 bias_regularizer=bias_regularizer, 1910 kernel_constraint=kernel_constraint, 1911 recurrent_constraint=recurrent_constraint, 1912 bias_constraint=bias_constraint, 1913 dropout=dropout, 1914 recurrent_dropout=recurrent_dropout, 1915 implementation=implementation, 1916 reset_after=reset_after) 1917 super(GRU, self).__init__( 1918 cell, 1919 return_sequences=return_sequences, 1920 return_state=return_state, 1921 go_backwards=go_backwards, 1922 stateful=stateful, 1923 unroll=unroll, 1924 **kwargs) 1925 self.activity_regularizer = regularizers.get(activity_regularizer) 1926 self.input_spec = [InputSpec(ndim=3)] 1927 1928 def call(self, inputs, mask=None, training=None, initial_state=None): 1929 self.cell.reset_dropout_mask() 1930 self.cell.reset_recurrent_dropout_mask() 1931 return super(GRU, self).call( 1932 inputs, mask=mask, training=training, initial_state=initial_state) 1933 1934 @property 1935 def units(self): 1936 return self.cell.units 1937 1938 @property 1939 def activation(self): 1940 return self.cell.activation 1941 1942 @property 1943 def recurrent_activation(self): 1944 return self.cell.recurrent_activation 1945 1946 @property 1947 def use_bias(self): 1948 return self.cell.use_bias 1949 1950 @property 1951 def kernel_initializer(self): 1952 return self.cell.kernel_initializer 1953 1954 @property 1955 def recurrent_initializer(self): 1956 return self.cell.recurrent_initializer 1957 1958 @property 1959 def bias_initializer(self): 1960 return self.cell.bias_initializer 1961 1962 @property 1963 def kernel_regularizer(self): 1964 return self.cell.kernel_regularizer 1965 1966 @property 1967 def recurrent_regularizer(self): 1968 return self.cell.recurrent_regularizer 1969 1970 @property 1971 def bias_regularizer(self): 1972 return self.cell.bias_regularizer 1973 1974 @property 1975 def kernel_constraint(self): 1976 return self.cell.kernel_constraint 1977 1978 @property 1979 def recurrent_constraint(self): 1980 return self.cell.recurrent_constraint 1981 1982 @property 1983 def bias_constraint(self): 1984 return self.cell.bias_constraint 1985 1986 @property 1987 def dropout(self): 1988 return self.cell.dropout 1989 1990 @property 1991 def recurrent_dropout(self): 1992 return self.cell.recurrent_dropout 1993 1994 @property 1995 def implementation(self): 1996 return 
self.cell.implementation 1997 1998 @property 1999 def reset_after(self): 2000 return self.cell.reset_after 2001 2002 def get_config(self): 2003 config = { 2004 'units': 2005 self.units, 2006 'activation': 2007 activations.serialize(self.activation), 2008 'recurrent_activation': 2009 activations.serialize(self.recurrent_activation), 2010 'use_bias': 2011 self.use_bias, 2012 'kernel_initializer': 2013 initializers.serialize(self.kernel_initializer), 2014 'recurrent_initializer': 2015 initializers.serialize(self.recurrent_initializer), 2016 'bias_initializer': 2017 initializers.serialize(self.bias_initializer), 2018 'kernel_regularizer': 2019 regularizers.serialize(self.kernel_regularizer), 2020 'recurrent_regularizer': 2021 regularizers.serialize(self.recurrent_regularizer), 2022 'bias_regularizer': 2023 regularizers.serialize(self.bias_regularizer), 2024 'activity_regularizer': 2025 regularizers.serialize(self.activity_regularizer), 2026 'kernel_constraint': 2027 constraints.serialize(self.kernel_constraint), 2028 'recurrent_constraint': 2029 constraints.serialize(self.recurrent_constraint), 2030 'bias_constraint': 2031 constraints.serialize(self.bias_constraint), 2032 'dropout': 2033 self.dropout, 2034 'recurrent_dropout': 2035 self.recurrent_dropout, 2036 'implementation': 2037 self.implementation, 2038 'reset_after': 2039 self.reset_after 2040 } 2041 base_config = super(GRU, self).get_config() 2042 del base_config['cell'] 2043 return dict(list(base_config.items()) + list(config.items())) 2044 2045 @classmethod 2046 def from_config(cls, config): 2047 if 'implementation' in config and config['implementation'] == 0: 2048 config['implementation'] = 1 2049 return cls(**config) 2050 2051 2052@keras_export('keras.layers.LSTMCell') 2053class LSTMCell(DropoutRNNCellMixin, Layer): 2054 """Cell class for the LSTM layer. 2055 2056 Arguments: 2057 units: Positive integer, dimensionality of the output space. 2058 activation: Activation function to use. 2059 Default: hyperbolic tangent (`tanh`). 2060 If you pass `None`, no activation is applied 2061 (ie. "linear" activation: `a(x) = x`). 2062 recurrent_activation: Activation function to use 2063 for the recurrent step. 2064 Default: hard sigmoid (`hard_sigmoid`). 2065 If you pass `None`, no activation is applied 2066 (ie. "linear" activation: `a(x) = x`). 2067 use_bias: Boolean, whether the layer uses a bias vector. 2068 kernel_initializer: Initializer for the `kernel` weights matrix, 2069 used for the linear transformation of the inputs. 2070 recurrent_initializer: Initializer for the `recurrent_kernel` 2071 weights matrix, 2072 used for the linear transformation of the recurrent state. 2073 bias_initializer: Initializer for the bias vector. 2074 unit_forget_bias: Boolean. 2075 If True, add 1 to the bias of the forget gate at initialization. 2076 Setting it to true will also force `bias_initializer="zeros"`. 2077 This is recommended in [Jozefowicz et 2078 al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) 2079 kernel_regularizer: Regularizer function applied to 2080 the `kernel` weights matrix. 2081 recurrent_regularizer: Regularizer function applied to 2082 the `recurrent_kernel` weights matrix. 2083 bias_regularizer: Regularizer function applied to the bias vector. 2084 kernel_constraint: Constraint function applied to 2085 the `kernel` weights matrix. 2086 recurrent_constraint: Constraint function applied to 2087 the `recurrent_kernel` weights matrix. 2088 bias_constraint: Constraint function applied to the bias vector. 
2089 dropout: Float between 0 and 1. 2090 Fraction of the units to drop for 2091 the linear transformation of the inputs. 2092 recurrent_dropout: Float between 0 and 1. 2093 Fraction of the units to drop for 2094 the linear transformation of the recurrent state. 2095 implementation: Implementation mode, either 1 or 2. 2096 Mode 1 will structure its operations as a larger number of 2097 smaller dot products and additions, whereas mode 2 will 2098 batch them into fewer, larger operations. These modes will 2099 have different performance profiles on different hardware and 2100 for different applications. 2101 2102 Call arguments: 2103 inputs: A 2D tensor. 2104 states: List of state tensors corresponding to the previous timestep. 2105 training: Python boolean indicating whether the layer should behave in 2106 training mode or in inference mode. Only relevant when `dropout` or 2107 `recurrent_dropout` is used. 2108 """ 2109 2110 def __init__(self, 2111 units, 2112 activation='tanh', 2113 recurrent_activation='hard_sigmoid', 2114 use_bias=True, 2115 kernel_initializer='glorot_uniform', 2116 recurrent_initializer='orthogonal', 2117 bias_initializer='zeros', 2118 unit_forget_bias=True, 2119 kernel_regularizer=None, 2120 recurrent_regularizer=None, 2121 bias_regularizer=None, 2122 kernel_constraint=None, 2123 recurrent_constraint=None, 2124 bias_constraint=None, 2125 dropout=0., 2126 recurrent_dropout=0., 2127 implementation=1, 2128 **kwargs): 2129 super(LSTMCell, self).__init__(**kwargs) 2130 self.units = units 2131 self.activation = activations.get(activation) 2132 self.recurrent_activation = activations.get(recurrent_activation) 2133 self.use_bias = use_bias 2134 2135 self.kernel_initializer = initializers.get(kernel_initializer) 2136 self.recurrent_initializer = initializers.get(recurrent_initializer) 2137 self.bias_initializer = initializers.get(bias_initializer) 2138 self.unit_forget_bias = unit_forget_bias 2139 2140 self.kernel_regularizer = regularizers.get(kernel_regularizer) 2141 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 2142 self.bias_regularizer = regularizers.get(bias_regularizer) 2143 2144 self.kernel_constraint = constraints.get(kernel_constraint) 2145 self.recurrent_constraint = constraints.get(recurrent_constraint) 2146 self.bias_constraint = constraints.get(bias_constraint) 2147 2148 self.dropout = min(1., max(0., dropout)) 2149 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 2150 self.implementation = implementation 2151 self.state_size = [self.units, self.units] 2152 self.output_size = self.units 2153 2154 @tf_utils.shape_type_conversion 2155 def build(self, input_shape): 2156 input_dim = input_shape[-1] 2157 self.kernel = self.add_weight( 2158 shape=(input_dim, self.units * 4), 2159 name='kernel', 2160 initializer=self.kernel_initializer, 2161 regularizer=self.kernel_regularizer, 2162 constraint=self.kernel_constraint) 2163 self.recurrent_kernel = self.add_weight( 2164 shape=(self.units, self.units * 4), 2165 name='recurrent_kernel', 2166 initializer=self.recurrent_initializer, 2167 regularizer=self.recurrent_regularizer, 2168 constraint=self.recurrent_constraint) 2169 2170 if self.use_bias: 2171 if self.unit_forget_bias: 2172 2173 def bias_initializer(_, *args, **kwargs): 2174 return K.concatenate([ 2175 self.bias_initializer((self.units,), *args, **kwargs), 2176 initializers.Ones()((self.units,), *args, **kwargs), 2177 self.bias_initializer((self.units * 2,), *args, **kwargs), 2178 ]) 2179 else: 2180 bias_initializer = 
self.bias_initializer 2181 self.bias = self.add_weight( 2182 shape=(self.units * 4,), 2183 name='bias', 2184 initializer=bias_initializer, 2185 regularizer=self.bias_regularizer, 2186 constraint=self.bias_constraint) 2187 else: 2188 self.bias = None 2189 self.built = True 2190 2191 def _compute_carry_and_output(self, x, h_tm1, c_tm1): 2192 """Computes carry and output using split kernels.""" 2193 x_i, x_f, x_c, x_o = x 2194 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 2195 i = self.recurrent_activation( 2196 x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])) 2197 f = self.recurrent_activation(x_f + K.dot( 2198 h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])) 2199 c = f * c_tm1 + i * self.activation(x_c + K.dot( 2200 h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) 2201 o = self.recurrent_activation( 2202 x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])) 2203 return c, o 2204 2205 def _compute_carry_and_output_fused(self, z, c_tm1): 2206 """Computes carry and output using fused kernels.""" 2207 z0, z1, z2, z3 = z 2208 i = self.recurrent_activation(z0) 2209 f = self.recurrent_activation(z1) 2210 c = f * c_tm1 + i * self.activation(z2) 2211 o = self.recurrent_activation(z3) 2212 return c, o 2213 2214 def call(self, inputs, states, training=None): 2215 h_tm1 = states[0] # previous memory state 2216 c_tm1 = states[1] # previous carry state 2217 2218 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) 2219 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 2220 h_tm1, training, count=4) 2221 2222 if self.implementation == 1: 2223 if 0 < self.dropout < 1.: 2224 inputs_i = inputs * dp_mask[0] 2225 inputs_f = inputs * dp_mask[1] 2226 inputs_c = inputs * dp_mask[2] 2227 inputs_o = inputs * dp_mask[3] 2228 else: 2229 inputs_i = inputs 2230 inputs_f = inputs 2231 inputs_c = inputs 2232 inputs_o = inputs 2233 k_i, k_f, k_c, k_o = array_ops.split( 2234 self.kernel, num_or_size_splits=4, axis=1) 2235 x_i = K.dot(inputs_i, k_i) 2236 x_f = K.dot(inputs_f, k_f) 2237 x_c = K.dot(inputs_c, k_c) 2238 x_o = K.dot(inputs_o, k_o) 2239 if self.use_bias: 2240 b_i, b_f, b_c, b_o = array_ops.split( 2241 self.bias, num_or_size_splits=4, axis=0) 2242 x_i = K.bias_add(x_i, b_i) 2243 x_f = K.bias_add(x_f, b_f) 2244 x_c = K.bias_add(x_c, b_c) 2245 x_o = K.bias_add(x_o, b_o) 2246 2247 if 0 < self.recurrent_dropout < 1.: 2248 h_tm1_i = h_tm1 * rec_dp_mask[0] 2249 h_tm1_f = h_tm1 * rec_dp_mask[1] 2250 h_tm1_c = h_tm1 * rec_dp_mask[2] 2251 h_tm1_o = h_tm1 * rec_dp_mask[3] 2252 else: 2253 h_tm1_i = h_tm1 2254 h_tm1_f = h_tm1 2255 h_tm1_c = h_tm1 2256 h_tm1_o = h_tm1 2257 x = (x_i, x_f, x_c, x_o) 2258 h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o) 2259 c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) 2260 else: 2261 if 0. < self.dropout < 1.: 2262 inputs *= dp_mask[0] 2263 z = K.dot(inputs, self.kernel) 2264 if 0. 
< self.recurrent_dropout < 1.: 2265 h_tm1 *= rec_dp_mask[0] 2266 z += K.dot(h_tm1, self.recurrent_kernel) 2267 if self.use_bias: 2268 z = K.bias_add(z, self.bias) 2269 2270 z = array_ops.split(z, num_or_size_splits=4, axis=1) 2271 c, o = self._compute_carry_and_output_fused(z, c_tm1) 2272 2273 h = o * self.activation(c) 2274 return h, [h, c] 2275 2276 def get_config(self): 2277 config = { 2278 'units': 2279 self.units, 2280 'activation': 2281 activations.serialize(self.activation), 2282 'recurrent_activation': 2283 activations.serialize(self.recurrent_activation), 2284 'use_bias': 2285 self.use_bias, 2286 'kernel_initializer': 2287 initializers.serialize(self.kernel_initializer), 2288 'recurrent_initializer': 2289 initializers.serialize(self.recurrent_initializer), 2290 'bias_initializer': 2291 initializers.serialize(self.bias_initializer), 2292 'unit_forget_bias': 2293 self.unit_forget_bias, 2294 'kernel_regularizer': 2295 regularizers.serialize(self.kernel_regularizer), 2296 'recurrent_regularizer': 2297 regularizers.serialize(self.recurrent_regularizer), 2298 'bias_regularizer': 2299 regularizers.serialize(self.bias_regularizer), 2300 'kernel_constraint': 2301 constraints.serialize(self.kernel_constraint), 2302 'recurrent_constraint': 2303 constraints.serialize(self.recurrent_constraint), 2304 'bias_constraint': 2305 constraints.serialize(self.bias_constraint), 2306 'dropout': 2307 self.dropout, 2308 'recurrent_dropout': 2309 self.recurrent_dropout, 2310 'implementation': 2311 self.implementation 2312 } 2313 base_config = super(LSTMCell, self).get_config() 2314 return dict(list(base_config.items()) + list(config.items())) 2315 2316 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 2317 return list(_generate_zero_filled_state_for_cell( 2318 self, inputs, batch_size, dtype)) 2319 2320 2321@keras_export('keras.experimental.PeepholeLSTMCell') 2322class PeepholeLSTMCell(LSTMCell): 2323 """Equivalent to LSTMCell class but adds peephole connections. 2324 2325 Peephole connections allow the gates to utilize the previous internal state as 2326 well as the previous hidden state (which is what LSTMCell is limited to). 2327 This allows PeepholeLSTMCell to better learn precise timings over LSTMCell. 2328 2329 From [Gers et al.](http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf): 2330 2331 "We find that LSTM augmented by 'peephole connections' from its internal 2332 cells to its multiplicative gates can learn the fine distinction between 2333 sequences of spikes spaced either 50 or 49 time steps apart without the help 2334 of any short training exemplars." 2335 2336 The peephole implementation is based on: 2337 2338 [Long short-term memory recurrent neural network architectures for 2339 large scale acoustic modeling. 2340 ](https://research.google.com/pubs/archive/43905.pdf) 2341 2342 Example: 2343 2344 ```python 2345 # Create 2 PeepholeLSTMCells 2346 peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]] 2347 # Create a layer composed sequentially of the peephole LSTM cells. 2348 layer = RNN(peephole_lstm_cells) 2349 input = keras.Input((timesteps, input_dim)) 2350 output = layer(input) 2351 ``` 2352 """ 2353 2354 def build(self, input_shape): 2355 super(PeepholeLSTMCell, self).build(input_shape) 2356 # The following are the weight matrices for the peephole connections. These 2357 # are multiplied with the previous internal state during the computation of 2358 # carry and output. 
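    # Each peephole weight is a per-unit vector (a diagonal connection),
    # applied elementwise to the carry state rather than through a dense
    # matrix multiplication.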
    self.input_gate_peephole_weights = self.add_weight(
        shape=(self.units,),
        name='input_gate_peephole_weights',
        initializer=self.kernel_initializer)
    self.forget_gate_peephole_weights = self.add_weight(
        shape=(self.units,),
        name='forget_gate_peephole_weights',
        initializer=self.kernel_initializer)
    self.output_gate_peephole_weights = self.add_weight(
        shape=(self.units,),
        name='output_gate_peephole_weights',
        initializer=self.kernel_initializer)

  def _compute_carry_and_output(self, x, h_tm1, c_tm1):
    x_i, x_f, x_c, x_o = x
    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
    i = self.recurrent_activation(
        x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) +
        self.input_gate_peephole_weights * c_tm1)
    f = self.recurrent_activation(
        x_f + K.dot(h_tm1_f,
                    self.recurrent_kernel[:, self.units:self.units * 2]) +
        self.forget_gate_peephole_weights * c_tm1)
    c = f * c_tm1 + i * self.activation(x_c + K.dot(
        h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
    o = self.recurrent_activation(
        x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) +
        self.output_gate_peephole_weights * c)
    return c, o

  def _compute_carry_and_output_fused(self, z, c_tm1):
    z0, z1, z2, z3 = z
    i = self.recurrent_activation(z0 +
                                  self.input_gate_peephole_weights * c_tm1)
    f = self.recurrent_activation(z1 +
                                  self.forget_gate_peephole_weights * c_tm1)
    c = f * c_tm1 + i * self.activation(z2)
    o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c)
    return c, o


@keras_export(v1=['keras.layers.LSTM'])
class LSTM(RNN):
  """Long Short-Term Memory layer - Hochreiter 1997.

  Note that this layer is not optimized for performance on GPU. Please use
  `tf.keras.layers.CuDNNLSTM` for better performance on GPU.

  Arguments:
    units: Positive integer, dimensionality of the output space.
    activation: Activation function to use.
      Default: hyperbolic tangent (`tanh`).
      If you pass `None`, no activation is applied
      (ie. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
      Default: hard sigmoid (`hard_sigmoid`).
      If you pass `None`, no activation is applied
      (ie. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, add 1 to the bias of the forget gate at initialization.
      Setting it to true will also force `bias_initializer="zeros"`.
      This is recommended in [Jozefowicz et
        al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.
    implementation: Implementation mode, either 1 or 2.
      Mode 1 will structure its operations as a larger number of
      smaller dot products and additions, whereas mode 2 will
      batch them into fewer, larger operations. These modes will
      have different performance profiles on different hardware and
      for different applications.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    unroll: Boolean (default False).
      If True, the network will be unrolled,
      else a symbolic loop will be used.
      Unrolling can speed up an RNN,
      although it tends to be more memory-intensive.
      Unrolling is only suitable for short sequences.

  Call arguments:
    inputs: A 3D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or
      `recurrent_dropout` is used.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
  """

  def __init__(self,
               units,
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               implementation=1,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               unroll=False,
               **kwargs):
    if implementation == 0:
      logging.warning('`implementation=0` has been deprecated, '
                      'and now defaults to `implementation=1`. '
2510 'Please update your layer call.') 2511 cell = LSTMCell( 2512 units, 2513 activation=activation, 2514 recurrent_activation=recurrent_activation, 2515 use_bias=use_bias, 2516 kernel_initializer=kernel_initializer, 2517 recurrent_initializer=recurrent_initializer, 2518 unit_forget_bias=unit_forget_bias, 2519 bias_initializer=bias_initializer, 2520 kernel_regularizer=kernel_regularizer, 2521 recurrent_regularizer=recurrent_regularizer, 2522 bias_regularizer=bias_regularizer, 2523 kernel_constraint=kernel_constraint, 2524 recurrent_constraint=recurrent_constraint, 2525 bias_constraint=bias_constraint, 2526 dropout=dropout, 2527 recurrent_dropout=recurrent_dropout, 2528 implementation=implementation) 2529 super(LSTM, self).__init__( 2530 cell, 2531 return_sequences=return_sequences, 2532 return_state=return_state, 2533 go_backwards=go_backwards, 2534 stateful=stateful, 2535 unroll=unroll, 2536 **kwargs) 2537 self.activity_regularizer = regularizers.get(activity_regularizer) 2538 self.input_spec = [InputSpec(ndim=3)] 2539 2540 def call(self, inputs, mask=None, training=None, initial_state=None): 2541 self.cell.reset_dropout_mask() 2542 self.cell.reset_recurrent_dropout_mask() 2543 return super(LSTM, self).call( 2544 inputs, mask=mask, training=training, initial_state=initial_state) 2545 2546 @property 2547 def units(self): 2548 return self.cell.units 2549 2550 @property 2551 def activation(self): 2552 return self.cell.activation 2553 2554 @property 2555 def recurrent_activation(self): 2556 return self.cell.recurrent_activation 2557 2558 @property 2559 def use_bias(self): 2560 return self.cell.use_bias 2561 2562 @property 2563 def kernel_initializer(self): 2564 return self.cell.kernel_initializer 2565 2566 @property 2567 def recurrent_initializer(self): 2568 return self.cell.recurrent_initializer 2569 2570 @property 2571 def bias_initializer(self): 2572 return self.cell.bias_initializer 2573 2574 @property 2575 def unit_forget_bias(self): 2576 return self.cell.unit_forget_bias 2577 2578 @property 2579 def kernel_regularizer(self): 2580 return self.cell.kernel_regularizer 2581 2582 @property 2583 def recurrent_regularizer(self): 2584 return self.cell.recurrent_regularizer 2585 2586 @property 2587 def bias_regularizer(self): 2588 return self.cell.bias_regularizer 2589 2590 @property 2591 def kernel_constraint(self): 2592 return self.cell.kernel_constraint 2593 2594 @property 2595 def recurrent_constraint(self): 2596 return self.cell.recurrent_constraint 2597 2598 @property 2599 def bias_constraint(self): 2600 return self.cell.bias_constraint 2601 2602 @property 2603 def dropout(self): 2604 return self.cell.dropout 2605 2606 @property 2607 def recurrent_dropout(self): 2608 return self.cell.recurrent_dropout 2609 2610 @property 2611 def implementation(self): 2612 return self.cell.implementation 2613 2614 def get_config(self): 2615 config = { 2616 'units': 2617 self.units, 2618 'activation': 2619 activations.serialize(self.activation), 2620 'recurrent_activation': 2621 activations.serialize(self.recurrent_activation), 2622 'use_bias': 2623 self.use_bias, 2624 'kernel_initializer': 2625 initializers.serialize(self.kernel_initializer), 2626 'recurrent_initializer': 2627 initializers.serialize(self.recurrent_initializer), 2628 'bias_initializer': 2629 initializers.serialize(self.bias_initializer), 2630 'unit_forget_bias': 2631 self.unit_forget_bias, 2632 'kernel_regularizer': 2633 regularizers.serialize(self.kernel_regularizer), 2634 'recurrent_regularizer': 2635 
            regularizers.serialize(self.recurrent_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'recurrent_constraint':
            constraints.serialize(self.recurrent_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint),
        'dropout':
            self.dropout,
        'recurrent_dropout':
            self.recurrent_dropout,
        'implementation':
            self.implementation
    }
    base_config = super(LSTM, self).get_config()
    del base_config['cell']
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    if 'implementation' in config and config['implementation'] == 0:
      config['implementation'] = 1
    return cls(**config)


def _generate_dropout_mask(ones, rate, training=None, count=1):
  def dropped_inputs():
    return K.dropout(ones, rate)

  if count > 1:
    return [
        K.in_train_phase(dropped_inputs, ones, training=training)
        for _ in range(count)
    ]
  return K.in_train_phase(dropped_inputs, ones, training=training)


def _standardize_args(
    inputs, initial_state, constants, num_constants, num_inputs=1):
  """Standardizes `__call__` to a single list of tensor inputs.

  When running a model loaded from a file, the input tensors
  `initial_state` and `constants` can be passed to `RNN.__call__()` as part
  of `inputs` instead of by the dedicated keyword arguments. This method
  makes sure the arguments are separated and that `initial_state` and
  `constants` are lists of tensors (or None).

  Arguments:
    inputs: Tensor or list/tuple of tensors, which may include constants
      and initial states. In that case `num_constants` must be specified.
    initial_state: Tensor or list of tensors or None, initial states.
    constants: Tensor or list of tensors or None, constant tensors.
    num_constants: Expected number of constants (if constants are passed as
      part of the `inputs` list).
    num_inputs: Expected number of real input tensors (excluding
      initial_states and constants).

  Returns:
    inputs: Single tensor or tuple of tensors.
    initial_state: List of tensors or None.
    constants: List of tensors or None.
  """
  if isinstance(inputs, list):
    # There are several situations here:
    # In graph mode, __call__ will only be called once. The initial_state
    # and constants could be in inputs (from file loading).
    # In eager mode, __call__ will be called twice: once during
    # rnn_layer(inputs=input_t, constants=c_t, ...), and a second time during
    # model.fit/train_on_batch/predict with real np data. In the second case,
    # the inputs will contain initial_state and constants, and more
    # importantly, the real inputs will be in a flat list instead of a nested
    # tuple.
    #
    # For either case, we will use num_inputs to split the input list, and
    # restructure the real input into a tuple.
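    # When states/constants are already packed into `inputs`, they must not
    # also be passed through the dedicated keyword arguments.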
2713 assert initial_state is None and constants is None 2714 inputs = nest.flatten(inputs) 2715 if num_constants is not None: 2716 constants = inputs[-num_constants:] 2717 inputs = inputs[:-num_constants] 2718 if num_inputs is None: 2719 num_inputs = 1 2720 if len(inputs) > num_inputs: 2721 initial_state = inputs[num_inputs:] 2722 inputs = inputs[:num_inputs] 2723 2724 if len(inputs) > 1: 2725 inputs = tuple(inputs) 2726 else: 2727 inputs = inputs[0] 2728 2729 def to_list_or_none(x): 2730 if x is None or isinstance(x, list): 2731 return x 2732 if isinstance(x, tuple): 2733 return list(x) 2734 return [x] 2735 2736 initial_state = to_list_or_none(initial_state) 2737 constants = to_list_or_none(constants) 2738 2739 return inputs, initial_state, constants 2740 2741 2742def _is_multiple_state(state_size): 2743 """Check whether the state_size contains multiple states.""" 2744 return (hasattr(state_size, '__len__') and 2745 not isinstance(state_size, tensor_shape.TensorShape)) 2746 2747 2748def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype): 2749 if inputs is not None: 2750 batch_size = array_ops.shape(inputs)[0] 2751 dtype = inputs.dtype 2752 return _generate_zero_filled_state(batch_size, cell.state_size, dtype) 2753 2754 2755def _generate_zero_filled_state(batch_size_tensor, state_size, dtype): 2756 """Generate a zero filled tensor with shape [batch_size, state_size].""" 2757 if batch_size_tensor is None or dtype is None: 2758 raise ValueError( 2759 'batch_size and dtype cannot be None while constructing initial state: ' 2760 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype)) 2761 2762 def create_zeros(unnested_state_size): 2763 flat_dims = tensor_shape.as_shape(unnested_state_size).as_list() 2764 init_state_size = [batch_size_tensor] + flat_dims 2765 return array_ops.zeros(init_state_size, dtype=dtype) 2766 2767 if nest.is_sequence(state_size): 2768 return nest.map_structure(create_zeros, state_size) 2769 else: 2770 return create_zeros(state_size) 2771
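

# Note: a minimal, illustrative sketch of how the private helpers above are
# typically exercised (names and values here are examples only, not part of
# the public API):
#
#   cell = LSTMCell(4)
#   # For an LSTMCell, `state_size` is `[units, units]`, so this returns a
#   # list of two zero tensors, each of shape `(batch_size, 4)`.
#   initial_state = _generate_zero_filled_state_for_cell(
#       cell, inputs=None, batch_size=8, dtype='float32')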