1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Recurrent layers and their base classes. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import warnings 24 25import numpy as np 26 27from tensorflow.python.distribute import distribution_strategy_context as ds_context 28from tensorflow.python.eager import context 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.keras import activations 32from tensorflow.python.keras import backend as K 33from tensorflow.python.keras import constraints 34from tensorflow.python.keras import initializers 35from tensorflow.python.keras import regularizers 36from tensorflow.python.keras.engine.base_layer import Layer 37from tensorflow.python.keras.engine.input_spec import InputSpec 38from tensorflow.python.keras.saving.saved_model import layer_serialization 39from tensorflow.python.keras.utils import control_flow_util 40from tensorflow.python.keras.utils import generic_utils 41from tensorflow.python.keras.utils import tf_utils 42from tensorflow.python.ops import array_ops 43from tensorflow.python.ops import control_flow_ops 44from tensorflow.python.ops import math_ops 45from tensorflow.python.ops import state_ops 46from tensorflow.python.platform import tf_logging as logging 47from tensorflow.python.training.tracking import base as trackable 48from tensorflow.python.training.tracking import data_structures 49from tensorflow.python.util import nest 50from tensorflow.python.util.tf_export import keras_export 51from tensorflow.tools.docs import doc_controls 52 53 54RECURRENT_DROPOUT_WARNING_MSG = ( 55 'RNN `implementation=2` is not supported when `recurrent_dropout` is set. ' 56 'Using `implementation=1`.') 57 58 59@keras_export('keras.layers.StackedRNNCells') 60class StackedRNNCells(Layer): 61 """Wrapper allowing a stack of RNN cells to behave as a single cell. 62 63 Used to implement efficient stacked RNNs. 64 65 Args: 66 cells: List of RNN cell instances. 67 68 Examples: 69 70 ```python 71 batch_size = 3 72 sentence_max_length = 5 73 n_features = 2 74 new_shape = (batch_size, sentence_max_length, n_features) 75 x = tf.constant(np.reshape(np.arange(30), new_shape), dtype = tf.float32) 76 77 rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)] 78 stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells) 79 lstm_layer = tf.keras.layers.RNN(stacked_lstm) 80 81 result = lstm_layer(x) 82 ``` 83 """ 84 85 def __init__(self, cells, **kwargs): 86 for cell in cells: 87 if not 'call' in dir(cell): 88 raise ValueError('All cells must have a `call` method. ' 89 'received cells:', cells) 90 if not 'state_size' in dir(cell): 91 raise ValueError('All cells must have a ' 92 '`state_size` attribute. ' 93 'received cells:', cells) 94 self.cells = cells 95 # reverse_state_order determines whether the state size will be in a reverse 96 # order of the cells' state. User might want to set this to True to keep the 97 # existing behavior. This is only useful when use RNN(return_state=True) 98 # since the state will be returned as the same order of state_size. 99 self.reverse_state_order = kwargs.pop('reverse_state_order', False) 100 if self.reverse_state_order: 101 logging.warning('reverse_state_order=True in StackedRNNCells will soon ' 102 'be deprecated. Please update the code to work with the ' 103 'natural order of states if you rely on the RNN states, ' 104 'eg RNN(return_state=True).') 105 super(StackedRNNCells, self).__init__(**kwargs) 106 107 @property 108 def state_size(self): 109 return tuple(c.state_size for c in 110 (self.cells[::-1] if self.reverse_state_order else self.cells)) 111 112 @property 113 def output_size(self): 114 if getattr(self.cells[-1], 'output_size', None) is not None: 115 return self.cells[-1].output_size 116 elif _is_multiple_state(self.cells[-1].state_size): 117 return self.cells[-1].state_size[0] 118 else: 119 return self.cells[-1].state_size 120 121 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 122 initial_states = [] 123 for cell in self.cells[::-1] if self.reverse_state_order else self.cells: 124 get_initial_state_fn = getattr(cell, 'get_initial_state', None) 125 if get_initial_state_fn: 126 initial_states.append(get_initial_state_fn( 127 inputs=inputs, batch_size=batch_size, dtype=dtype)) 128 else: 129 initial_states.append(_generate_zero_filled_state_for_cell( 130 cell, inputs, batch_size, dtype)) 131 132 return tuple(initial_states) 133 134 def call(self, inputs, states, constants=None, training=None, **kwargs): 135 # Recover per-cell states. 136 state_size = (self.state_size[::-1] 137 if self.reverse_state_order else self.state_size) 138 nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) 139 140 # Call the cells in order and store the returned states. 141 new_nested_states = [] 142 for cell, states in zip(self.cells, nested_states): 143 states = states if nest.is_nested(states) else [states] 144 # TF cell does not wrap the state into list when there is only one state. 145 is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None 146 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 147 if generic_utils.has_arg(cell.call, 'training'): 148 kwargs['training'] = training 149 else: 150 kwargs.pop('training', None) 151 # Use the __call__ function for callable objects, eg layers, so that it 152 # will have the proper name scopes for the ops, etc. 153 cell_call_fn = cell.__call__ if callable(cell) else cell.call 154 if generic_utils.has_arg(cell.call, 'constants'): 155 inputs, states = cell_call_fn(inputs, states, 156 constants=constants, **kwargs) 157 else: 158 inputs, states = cell_call_fn(inputs, states, **kwargs) 159 new_nested_states.append(states) 160 161 return inputs, nest.pack_sequence_as(state_size, 162 nest.flatten(new_nested_states)) 163 164 @tf_utils.shape_type_conversion 165 def build(self, input_shape): 166 if isinstance(input_shape, list): 167 input_shape = input_shape[0] 168 for cell in self.cells: 169 if isinstance(cell, Layer) and not cell.built: 170 with K.name_scope(cell.name): 171 cell.build(input_shape) 172 cell.built = True 173 if getattr(cell, 'output_size', None) is not None: 174 output_dim = cell.output_size 175 elif _is_multiple_state(cell.state_size): 176 output_dim = cell.state_size[0] 177 else: 178 output_dim = cell.state_size 179 input_shape = tuple([input_shape[0]] + 180 tensor_shape.TensorShape(output_dim).as_list()) 181 self.built = True 182 183 def get_config(self): 184 cells = [] 185 for cell in self.cells: 186 cells.append(generic_utils.serialize_keras_object(cell)) 187 config = {'cells': cells} 188 base_config = super(StackedRNNCells, self).get_config() 189 return dict(list(base_config.items()) + list(config.items())) 190 191 @classmethod 192 def from_config(cls, config, custom_objects=None): 193 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 194 cells = [] 195 for cell_config in config.pop('cells'): 196 cells.append( 197 deserialize_layer(cell_config, custom_objects=custom_objects)) 198 return cls(cells, **config) 199 200 201@keras_export('keras.layers.RNN') 202class RNN(Layer): 203 """Base class for recurrent layers. 204 205 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 206 for details about the usage of RNN API. 207 208 Args: 209 cell: A RNN cell instance or a list of RNN cell instances. 210 A RNN cell is a class that has: 211 - A `call(input_at_t, states_at_t)` method, returning 212 `(output_at_t, states_at_t_plus_1)`. The call method of the 213 cell can also take the optional argument `constants`, see 214 section "Note on passing external constants" below. 215 - A `state_size` attribute. This can be a single integer 216 (single state) in which case it is the size of the recurrent 217 state. This can also be a list/tuple of integers (one size per state). 218 The `state_size` can also be TensorShape or tuple/list of 219 TensorShape, to represent high dimension state. 220 - A `output_size` attribute. This can be a single integer or a 221 TensorShape, which represent the shape of the output. For backward 222 compatible reason, if this attribute is not available for the 223 cell, the value will be inferred by the first element of the 224 `state_size`. 225 - A `get_initial_state(inputs=None, batch_size=None, dtype=None)` 226 method that creates a tensor meant to be fed to `call()` as the 227 initial state, if the user didn't specify any initial state via other 228 means. The returned initial state should have a shape of 229 [batch_size, cell.state_size]. The cell might choose to create a 230 tensor full of zeros, or full of other values based on the cell's 231 implementation. 232 `inputs` is the input tensor to the RNN layer, which should 233 contain the batch size as its shape[0], and also dtype. Note that 234 the shape[0] might be `None` during the graph construction. Either 235 the `inputs` or the pair of `batch_size` and `dtype` are provided. 236 `batch_size` is a scalar tensor that represents the batch size 237 of the inputs. `dtype` is `tf.DType` that represents the dtype of 238 the inputs. 239 For backward compatibility, if this method is not implemented 240 by the cell, the RNN layer will create a zero filled tensor with the 241 size of [batch_size, cell.state_size]. 242 In the case that `cell` is a list of RNN cell instances, the cells 243 will be stacked on top of each other in the RNN, resulting in an 244 efficient stacked RNN. 245 return_sequences: Boolean (default `False`). Whether to return the last 246 output in the output sequence, or the full sequence. 247 return_state: Boolean (default `False`). Whether to return the last state 248 in addition to the output. 249 go_backwards: Boolean (default `False`). 250 If True, process the input sequence backwards and return the 251 reversed sequence. 252 stateful: Boolean (default `False`). If True, the last state 253 for each sample at index i in a batch will be used as initial 254 state for the sample of index i in the following batch. 255 unroll: Boolean (default `False`). 256 If True, the network will be unrolled, else a symbolic loop will be used. 257 Unrolling can speed-up a RNN, although it tends to be more 258 memory-intensive. Unrolling is only suitable for short sequences. 259 time_major: The shape format of the `inputs` and `outputs` tensors. 260 If True, the inputs and outputs will be in shape 261 `(timesteps, batch, ...)`, whereas in the False case, it will be 262 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 263 efficient because it avoids transposes at the beginning and end of the 264 RNN calculation. However, most TensorFlow data is batch-major, so by 265 default this function accepts input and emits output in batch-major 266 form. 267 zero_output_for_mask: Boolean (default `False`). 268 Whether the output should use zeros for the masked timesteps. Note that 269 this field is only used when `return_sequences` is True and mask is 270 provided. It can useful if you want to reuse the raw output sequence of 271 the RNN without interference from the masked timesteps, eg, merging 272 bidirectional RNNs. 273 274 Call arguments: 275 inputs: Input tensor. 276 mask: Binary tensor of shape `[batch_size, timesteps]` indicating whether 277 a given timestep should be masked. An individual `True` entry indicates 278 that the corresponding timestep should be utilized, while a `False` 279 entry indicates that the corresponding timestep should be ignored. 280 training: Python boolean indicating whether the layer should behave in 281 training mode or in inference mode. This argument is passed to the cell 282 when calling it. This is for use with cells that use dropout. 283 initial_state: List of initial state tensors to be passed to the first 284 call of the cell. 285 constants: List of constant tensors to be passed to the cell at each 286 timestep. 287 288 Input shape: 289 N-D tensor with shape `[batch_size, timesteps, ...]` or 290 `[timesteps, batch_size, ...]` when time_major is True. 291 292 Output shape: 293 - If `return_state`: a list of tensors. The first tensor is 294 the output. The remaining tensors are the last states, 295 each with shape `[batch_size, state_size]`, where `state_size` could 296 be a high dimension tensor shape. 297 - If `return_sequences`: N-D tensor with shape 298 `[batch_size, timesteps, output_size]`, where `output_size` could 299 be a high dimension tensor shape, or 300 `[timesteps, batch_size, output_size]` when `time_major` is True. 301 - Else, N-D tensor with shape `[batch_size, output_size]`, where 302 `output_size` could be a high dimension tensor shape. 303 304 Masking: 305 This layer supports masking for input data with a variable number 306 of timesteps. To introduce masks to your data, 307 use an [tf.keras.layers.Embedding] layer with the `mask_zero` parameter 308 set to `True`. 309 310 Note on using statefulness in RNNs: 311 You can set RNN layers to be 'stateful', which means that the states 312 computed for the samples in one batch will be reused as initial states 313 for the samples in the next batch. This assumes a one-to-one mapping 314 between samples in different successive batches. 315 316 To enable statefulness: 317 - Specify `stateful=True` in the layer constructor. 318 - Specify a fixed batch size for your model, by passing 319 If sequential model: 320 `batch_input_shape=(...)` to the first layer in your model. 321 Else for functional model with 1 or more Input layers: 322 `batch_shape=(...)` to all the first layers in your model. 323 This is the expected shape of your inputs 324 *including the batch size*. 325 It should be a tuple of integers, e.g. `(32, 10, 100)`. 326 - Specify `shuffle=False` when calling `fit()`. 327 328 To reset the states of your model, call `.reset_states()` on either 329 a specific layer, or on your entire model. 330 331 Note on specifying the initial state of RNNs: 332 You can specify the initial state of RNN layers symbolically by 333 calling them with the keyword argument `initial_state`. The value of 334 `initial_state` should be a tensor or list of tensors representing 335 the initial state of the RNN layer. 336 337 You can specify the initial state of RNN layers numerically by 338 calling `reset_states` with the keyword argument `states`. The value of 339 `states` should be a numpy array or list of numpy arrays representing 340 the initial state of the RNN layer. 341 342 Note on passing external constants to RNNs: 343 You can pass "external" constants to the cell using the `constants` 344 keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This 345 requires that the `cell.call` method accepts the same keyword argument 346 `constants`. Such constants can be used to condition the cell 347 transformation on additional static inputs (not changing over time), 348 a.k.a. an attention mechanism. 349 350 Examples: 351 352 ```python 353 # First, let's define a RNN Cell, as a layer subclass. 354 355 class MinimalRNNCell(keras.layers.Layer): 356 357 def __init__(self, units, **kwargs): 358 self.units = units 359 self.state_size = units 360 super(MinimalRNNCell, self).__init__(**kwargs) 361 362 def build(self, input_shape): 363 self.kernel = self.add_weight(shape=(input_shape[-1], self.units), 364 initializer='uniform', 365 name='kernel') 366 self.recurrent_kernel = self.add_weight( 367 shape=(self.units, self.units), 368 initializer='uniform', 369 name='recurrent_kernel') 370 self.built = True 371 372 def call(self, inputs, states): 373 prev_output = states[0] 374 h = K.dot(inputs, self.kernel) 375 output = h + K.dot(prev_output, self.recurrent_kernel) 376 return output, [output] 377 378 # Let's use this cell in a RNN layer: 379 380 cell = MinimalRNNCell(32) 381 x = keras.Input((None, 5)) 382 layer = RNN(cell) 383 y = layer(x) 384 385 # Here's how to use the cell to build a stacked RNN: 386 387 cells = [MinimalRNNCell(32), MinimalRNNCell(64)] 388 x = keras.Input((None, 5)) 389 layer = RNN(cells) 390 y = layer(x) 391 ``` 392 """ 393 394 def __init__(self, 395 cell, 396 return_sequences=False, 397 return_state=False, 398 go_backwards=False, 399 stateful=False, 400 unroll=False, 401 time_major=False, 402 **kwargs): 403 if isinstance(cell, (list, tuple)): 404 cell = StackedRNNCells(cell) 405 if not 'call' in dir(cell): 406 raise ValueError('`cell` should have a `call` method. ' 407 'The RNN was passed:', cell) 408 if not 'state_size' in dir(cell): 409 raise ValueError('The RNN cell should have ' 410 'an attribute `state_size` ' 411 '(tuple of integers, ' 412 'one integer per RNN state).') 413 # If True, the output for masked timestep will be zeros, whereas in the 414 # False case, output from previous timestep is returned for masked timestep. 415 self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False) 416 417 if 'input_shape' not in kwargs and ( 418 'input_dim' in kwargs or 'input_length' in kwargs): 419 input_shape = (kwargs.pop('input_length', None), 420 kwargs.pop('input_dim', None)) 421 kwargs['input_shape'] = input_shape 422 423 super(RNN, self).__init__(**kwargs) 424 self.cell = cell 425 self.return_sequences = return_sequences 426 self.return_state = return_state 427 self.go_backwards = go_backwards 428 self.stateful = stateful 429 self.unroll = unroll 430 self.time_major = time_major 431 432 self.supports_masking = True 433 # The input shape is unknown yet, it could have nested tensor inputs, and 434 # the input spec will be the list of specs for nested inputs, the structure 435 # of the input_spec will be the same as the input. 436 self.input_spec = None 437 self.state_spec = None 438 self._states = None 439 self.constants_spec = None 440 self._num_constants = 0 441 442 if stateful: 443 if ds_context.has_strategy(): 444 raise ValueError('RNNs with stateful=True not yet supported with ' 445 'tf.distribute.Strategy.') 446 447 @property 448 def _use_input_spec_as_call_signature(self): 449 if self.unroll: 450 # When the RNN layer is unrolled, the time step shape cannot be unknown. 451 # The input spec does not define the time step (because this layer can be 452 # called with any time step value, as long as it is not None), so it 453 # cannot be used as the call function signature when saving to SavedModel. 454 return False 455 return super(RNN, self)._use_input_spec_as_call_signature 456 457 @property 458 def states(self): 459 if self._states is None: 460 state = nest.map_structure(lambda _: None, self.cell.state_size) 461 return state if nest.is_nested(self.cell.state_size) else [state] 462 return self._states 463 464 @states.setter 465 # Automatic tracking catches "self._states" which adds an extra weight and 466 # breaks HDF5 checkpoints. 467 @trackable.no_automatic_dependency_tracking 468 def states(self, states): 469 self._states = states 470 471 def compute_output_shape(self, input_shape): 472 if isinstance(input_shape, list): 473 input_shape = input_shape[0] 474 # Check whether the input shape contains any nested shapes. It could be 475 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy 476 # inputs. 477 try: 478 input_shape = tensor_shape.TensorShape(input_shape) 479 except (ValueError, TypeError): 480 # A nested tensor input 481 input_shape = nest.flatten(input_shape)[0] 482 483 batch = input_shape[0] 484 time_step = input_shape[1] 485 if self.time_major: 486 batch, time_step = time_step, batch 487 488 if _is_multiple_state(self.cell.state_size): 489 state_size = self.cell.state_size 490 else: 491 state_size = [self.cell.state_size] 492 493 def _get_output_shape(flat_output_size): 494 output_dim = tensor_shape.TensorShape(flat_output_size).as_list() 495 if self.return_sequences: 496 if self.time_major: 497 output_shape = tensor_shape.TensorShape( 498 [time_step, batch] + output_dim) 499 else: 500 output_shape = tensor_shape.TensorShape( 501 [batch, time_step] + output_dim) 502 else: 503 output_shape = tensor_shape.TensorShape([batch] + output_dim) 504 return output_shape 505 506 if getattr(self.cell, 'output_size', None) is not None: 507 # cell.output_size could be nested structure. 508 output_shape = nest.flatten(nest.map_structure( 509 _get_output_shape, self.cell.output_size)) 510 output_shape = output_shape[0] if len(output_shape) == 1 else output_shape 511 else: 512 # Note that state_size[0] could be a tensor_shape or int. 513 output_shape = _get_output_shape(state_size[0]) 514 515 if self.return_state: 516 def _get_state_shape(flat_state): 517 state_shape = [batch] + tensor_shape.TensorShape(flat_state).as_list() 518 return tensor_shape.TensorShape(state_shape) 519 state_shape = nest.map_structure(_get_state_shape, state_size) 520 return generic_utils.to_list(output_shape) + nest.flatten(state_shape) 521 else: 522 return output_shape 523 524 def compute_mask(self, inputs, mask): 525 # Time step masks must be the same for each input. 526 # This is because the mask for an RNN is of size [batch, time_steps, 1], 527 # and specifies which time steps should be skipped, and a time step 528 # must be skipped for all inputs. 529 # TODO(scottzhu): Should we accept multiple different masks? 530 mask = nest.flatten(mask)[0] 531 output_mask = mask if self.return_sequences else None 532 if self.return_state: 533 state_mask = [None for _ in self.states] 534 return [output_mask] + state_mask 535 else: 536 return output_mask 537 538 def build(self, input_shape): 539 if isinstance(input_shape, list): 540 input_shape = input_shape[0] 541 # The input_shape here could be a nest structure. 542 543 # do the tensor_shape to shapes here. The input could be single tensor, or a 544 # nested structure of tensors. 545 def get_input_spec(shape): 546 """Convert input shape to InputSpec.""" 547 if isinstance(shape, tensor_shape.TensorShape): 548 input_spec_shape = shape.as_list() 549 else: 550 input_spec_shape = list(shape) 551 batch_index, time_step_index = (1, 0) if self.time_major else (0, 1) 552 if not self.stateful: 553 input_spec_shape[batch_index] = None 554 input_spec_shape[time_step_index] = None 555 return InputSpec(shape=tuple(input_spec_shape)) 556 557 def get_step_input_shape(shape): 558 if isinstance(shape, tensor_shape.TensorShape): 559 shape = tuple(shape.as_list()) 560 # remove the timestep from the input_shape 561 return shape[1:] if self.time_major else (shape[0],) + shape[2:] 562 563 # Check whether the input shape contains any nested shapes. It could be 564 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy 565 # inputs. 566 try: 567 input_shape = tensor_shape.TensorShape(input_shape) 568 except (ValueError, TypeError): 569 # A nested tensor input 570 pass 571 572 if not nest.is_nested(input_shape): 573 # This indicates the there is only one input. 574 if self.input_spec is not None: 575 self.input_spec[0] = get_input_spec(input_shape) 576 else: 577 self.input_spec = [get_input_spec(input_shape)] 578 step_input_shape = get_step_input_shape(input_shape) 579 else: 580 if self.input_spec is not None: 581 self.input_spec[0] = nest.map_structure(get_input_spec, input_shape) 582 else: 583 self.input_spec = generic_utils.to_list( 584 nest.map_structure(get_input_spec, input_shape)) 585 step_input_shape = nest.map_structure(get_step_input_shape, input_shape) 586 587 # allow cell (if layer) to build before we set or validate state_spec. 588 if isinstance(self.cell, Layer) and not self.cell.built: 589 with K.name_scope(self.cell.name): 590 self.cell.build(step_input_shape) 591 self.cell.built = True 592 593 # set or validate state_spec 594 if _is_multiple_state(self.cell.state_size): 595 state_size = list(self.cell.state_size) 596 else: 597 state_size = [self.cell.state_size] 598 599 if self.state_spec is not None: 600 # initial_state was passed in call, check compatibility 601 self._validate_state_spec(state_size, self.state_spec) 602 else: 603 self.state_spec = [ 604 InputSpec(shape=[None] + tensor_shape.TensorShape(dim).as_list()) 605 for dim in state_size 606 ] 607 if self.stateful: 608 self.reset_states() 609 self.built = True 610 611 @staticmethod 612 def _validate_state_spec(cell_state_sizes, init_state_specs): 613 """Validate the state spec between the initial_state and the state_size. 614 615 Args: 616 cell_state_sizes: list, the `state_size` attribute from the cell. 617 init_state_specs: list, the `state_spec` from the initial_state that is 618 passed in `call()`. 619 620 Raises: 621 ValueError: When initial state spec is not compatible with the state size. 622 """ 623 validation_error = ValueError( 624 'An `initial_state` was passed that is not compatible with ' 625 '`cell.state_size`. Received `state_spec`={}; ' 626 'however `cell.state_size` is ' 627 '{}'.format(init_state_specs, cell_state_sizes)) 628 flat_cell_state_sizes = nest.flatten(cell_state_sizes) 629 flat_state_specs = nest.flatten(init_state_specs) 630 631 if len(flat_cell_state_sizes) != len(flat_state_specs): 632 raise validation_error 633 for cell_state_spec, cell_state_size in zip(flat_state_specs, 634 flat_cell_state_sizes): 635 if not tensor_shape.TensorShape( 636 # Ignore the first axis for init_state which is for batch 637 cell_state_spec.shape[1:]).is_compatible_with( 638 tensor_shape.TensorShape(cell_state_size)): 639 raise validation_error 640 641 @doc_controls.do_not_doc_inheritable 642 def get_initial_state(self, inputs): 643 get_initial_state_fn = getattr(self.cell, 'get_initial_state', None) 644 645 if nest.is_nested(inputs): 646 # The input are nested sequences. Use the first element in the seq to get 647 # batch size and dtype. 648 inputs = nest.flatten(inputs)[0] 649 650 input_shape = array_ops.shape(inputs) 651 batch_size = input_shape[1] if self.time_major else input_shape[0] 652 dtype = inputs.dtype 653 if get_initial_state_fn: 654 init_state = get_initial_state_fn( 655 inputs=None, batch_size=batch_size, dtype=dtype) 656 else: 657 init_state = _generate_zero_filled_state(batch_size, self.cell.state_size, 658 dtype) 659 # Keras RNN expect the states in a list, even if it's a single state tensor. 660 if not nest.is_nested(init_state): 661 init_state = [init_state] 662 # Force the state to be a list in case it is a namedtuple eg LSTMStateTuple. 663 return list(init_state) 664 665 def __call__(self, inputs, initial_state=None, constants=None, **kwargs): 666 inputs, initial_state, constants = _standardize_args(inputs, 667 initial_state, 668 constants, 669 self._num_constants) 670 671 if initial_state is None and constants is None: 672 return super(RNN, self).__call__(inputs, **kwargs) 673 674 # If any of `initial_state` or `constants` are specified and are Keras 675 # tensors, then add them to the inputs and temporarily modify the 676 # input_spec to include them. 677 678 additional_inputs = [] 679 additional_specs = [] 680 if initial_state is not None: 681 additional_inputs += initial_state 682 self.state_spec = nest.map_structure( 683 lambda s: InputSpec(shape=K.int_shape(s)), initial_state) 684 additional_specs += self.state_spec 685 if constants is not None: 686 additional_inputs += constants 687 self.constants_spec = [ 688 InputSpec(shape=K.int_shape(constant)) for constant in constants 689 ] 690 self._num_constants = len(constants) 691 additional_specs += self.constants_spec 692 # additional_inputs can be empty if initial_state or constants are provided 693 # but empty (e.g. the cell is stateless). 694 flat_additional_inputs = nest.flatten(additional_inputs) 695 is_keras_tensor = K.is_keras_tensor( 696 flat_additional_inputs[0]) if flat_additional_inputs else True 697 for tensor in flat_additional_inputs: 698 if K.is_keras_tensor(tensor) != is_keras_tensor: 699 raise ValueError('The initial state or constants of an RNN' 700 ' layer cannot be specified with a mix of' 701 ' Keras tensors and non-Keras tensors' 702 ' (a "Keras tensor" is a tensor that was' 703 ' returned by a Keras layer, or by `Input`)') 704 705 if is_keras_tensor: 706 # Compute the full input spec, including state and constants 707 full_input = [inputs] + additional_inputs 708 if self.built: 709 # Keep the input_spec since it has been populated in build() method. 710 full_input_spec = self.input_spec + additional_specs 711 else: 712 # The original input_spec is None since there could be a nested tensor 713 # input. Update the input_spec to match the inputs. 714 full_input_spec = generic_utils.to_list( 715 nest.map_structure(lambda _: None, inputs)) + additional_specs 716 # Perform the call with temporarily replaced input_spec 717 self.input_spec = full_input_spec 718 output = super(RNN, self).__call__(full_input, **kwargs) 719 # Remove the additional_specs from input spec and keep the rest. It is 720 # important to keep since the input spec was populated by build(), and 721 # will be reused in the stateful=True. 722 self.input_spec = self.input_spec[:-len(additional_specs)] 723 return output 724 else: 725 if initial_state is not None: 726 kwargs['initial_state'] = initial_state 727 if constants is not None: 728 kwargs['constants'] = constants 729 return super(RNN, self).__call__(inputs, **kwargs) 730 731 def call(self, 732 inputs, 733 mask=None, 734 training=None, 735 initial_state=None, 736 constants=None): 737 # The input should be dense, padded with zeros. If a ragged input is fed 738 # into the layer, it is padded and the row lengths are used for masking. 739 inputs, row_lengths = K.convert_inputs_if_ragged(inputs) 740 is_ragged_input = (row_lengths is not None) 741 self._validate_args_if_ragged(is_ragged_input, mask) 742 743 inputs, initial_state, constants = self._process_inputs( 744 inputs, initial_state, constants) 745 746 self._maybe_reset_cell_dropout_mask(self.cell) 747 if isinstance(self.cell, StackedRNNCells): 748 for cell in self.cell.cells: 749 self._maybe_reset_cell_dropout_mask(cell) 750 751 if mask is not None: 752 # Time step masks must be the same for each input. 753 # TODO(scottzhu): Should we accept multiple different masks? 754 mask = nest.flatten(mask)[0] 755 756 if nest.is_nested(inputs): 757 # In the case of nested input, use the first element for shape check. 758 input_shape = K.int_shape(nest.flatten(inputs)[0]) 759 else: 760 input_shape = K.int_shape(inputs) 761 timesteps = input_shape[0] if self.time_major else input_shape[1] 762 if self.unroll and timesteps is None: 763 raise ValueError('Cannot unroll a RNN if the ' 764 'time dimension is undefined. \n' 765 '- If using a Sequential model, ' 766 'specify the time dimension by passing ' 767 'an `input_shape` or `batch_input_shape` ' 768 'argument to your first layer. If your ' 769 'first layer is an Embedding, you can ' 770 'also use the `input_length` argument.\n' 771 '- If using the functional API, specify ' 772 'the time dimension by passing a `shape` ' 773 'or `batch_shape` argument to your Input layer.') 774 775 kwargs = {} 776 if generic_utils.has_arg(self.cell.call, 'training'): 777 kwargs['training'] = training 778 779 # TF RNN cells expect single tensor as state instead of list wrapped tensor. 780 is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None 781 # Use the __call__ function for callable objects, eg layers, so that it 782 # will have the proper name scopes for the ops, etc. 783 cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call 784 if constants: 785 if not generic_utils.has_arg(self.cell.call, 'constants'): 786 raise ValueError('RNN cell does not support constants') 787 788 def step(inputs, states): 789 constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type 790 states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type 791 792 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 793 output, new_states = cell_call_fn( 794 inputs, states, constants=constants, **kwargs) 795 if not nest.is_nested(new_states): 796 new_states = [new_states] 797 return output, new_states 798 else: 799 800 def step(inputs, states): 801 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 802 output, new_states = cell_call_fn(inputs, states, **kwargs) 803 if not nest.is_nested(new_states): 804 new_states = [new_states] 805 return output, new_states 806 last_output, outputs, states = K.rnn( 807 step, 808 inputs, 809 initial_state, 810 constants=constants, 811 go_backwards=self.go_backwards, 812 mask=mask, 813 unroll=self.unroll, 814 input_length=row_lengths if row_lengths is not None else timesteps, 815 time_major=self.time_major, 816 zero_output_for_mask=self.zero_output_for_mask) 817 818 if self.stateful: 819 updates = [ 820 state_ops.assign(self_state, state) for self_state, state in zip( 821 nest.flatten(self.states), nest.flatten(states)) 822 ] 823 self.add_update(updates) 824 825 if self.return_sequences: 826 output = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths) 827 else: 828 output = last_output 829 830 if self.return_state: 831 if not isinstance(states, (list, tuple)): 832 states = [states] 833 else: 834 states = list(states) 835 return generic_utils.to_list(output) + states 836 else: 837 return output 838 839 def _process_inputs(self, inputs, initial_state, constants): 840 # input shape: `(samples, time (padded with zeros), input_dim)` 841 # note that the .build() method of subclasses MUST define 842 # self.input_spec and self.state_spec with complete input shapes. 843 if (isinstance(inputs, collections.abc.Sequence) 844 and not isinstance(inputs, tuple)): 845 # get initial_state from full input spec 846 # as they could be copied to multiple GPU. 847 if not self._num_constants: 848 initial_state = inputs[1:] 849 else: 850 initial_state = inputs[1:-self._num_constants] 851 constants = inputs[-self._num_constants:] 852 if len(initial_state) == 0: 853 initial_state = None 854 inputs = inputs[0] 855 856 if self.stateful: 857 if initial_state is not None: 858 # When layer is stateful and initial_state is provided, check if the 859 # recorded state is same as the default value (zeros). Use the recorded 860 # state if it is not same as the default. 861 non_zero_count = math_ops.add_n([math_ops.count_nonzero_v2(s) 862 for s in nest.flatten(self.states)]) 863 # Set strict = True to keep the original structure of the state. 864 initial_state = control_flow_ops.cond(non_zero_count > 0, 865 true_fn=lambda: self.states, 866 false_fn=lambda: initial_state, 867 strict=True) 868 else: 869 initial_state = self.states 870 elif initial_state is None: 871 initial_state = self.get_initial_state(inputs) 872 873 if len(initial_state) != len(self.states): 874 raise ValueError('Layer has ' + str(len(self.states)) + 875 ' states but was passed ' + str(len(initial_state)) + 876 ' initial states.') 877 return inputs, initial_state, constants 878 879 def _validate_args_if_ragged(self, is_ragged_input, mask): 880 if not is_ragged_input: 881 return 882 883 if mask is not None: 884 raise ValueError('The mask that was passed in was ' + str(mask) + 885 ' and cannot be applied to RaggedTensor inputs. Please ' 886 'make sure that there is no mask passed in by upstream ' 887 'layers.') 888 if self.unroll: 889 raise ValueError('The input received contains RaggedTensors and does ' 890 'not support unrolling. Disable unrolling by passing ' 891 '`unroll=False` in the RNN Layer constructor.') 892 893 def _maybe_reset_cell_dropout_mask(self, cell): 894 if isinstance(cell, DropoutRNNCellMixin): 895 cell.reset_dropout_mask() 896 cell.reset_recurrent_dropout_mask() 897 898 def reset_states(self, states=None): 899 """Reset the recorded states for the stateful RNN layer. 900 901 Can only be used when RNN layer is constructed with `stateful` = `True`. 902 Args: 903 states: Numpy arrays that contains the value for the initial state, which 904 will be feed to cell at the first time step. When the value is None, 905 zero filled numpy array will be created based on the cell state size. 906 907 Raises: 908 AttributeError: When the RNN layer is not stateful. 909 ValueError: When the batch size of the RNN layer is unknown. 910 ValueError: When the input numpy array is not compatible with the RNN 911 layer state, either size wise or dtype wise. 912 """ 913 if not self.stateful: 914 raise AttributeError('Layer must be stateful.') 915 spec_shape = None 916 if self.input_spec is not None: 917 spec_shape = nest.flatten(self.input_spec[0])[0].shape 918 if spec_shape is None: 919 # It is possible to have spec shape to be None, eg when construct a RNN 920 # with a custom cell, or standard RNN layers (LSTM/GRU) which we only know 921 # it has 3 dim input, but not its full shape spec before build(). 922 batch_size = None 923 else: 924 batch_size = spec_shape[1] if self.time_major else spec_shape[0] 925 if not batch_size: 926 raise ValueError('If a RNN is stateful, it needs to know ' 927 'its batch size. Specify the batch size ' 928 'of your input tensors: \n' 929 '- If using a Sequential model, ' 930 'specify the batch size by passing ' 931 'a `batch_input_shape` ' 932 'argument to your first layer.\n' 933 '- If using the functional API, specify ' 934 'the batch size by passing a ' 935 '`batch_shape` argument to your Input layer.') 936 # initialize state if None 937 if nest.flatten(self.states)[0] is None: 938 if getattr(self.cell, 'get_initial_state', None): 939 flat_init_state_values = nest.flatten(self.cell.get_initial_state( 940 inputs=None, batch_size=batch_size, 941 dtype=self.dtype or K.floatx())) 942 else: 943 flat_init_state_values = nest.flatten(_generate_zero_filled_state( 944 batch_size, self.cell.state_size, self.dtype or K.floatx())) 945 flat_states_variables = nest.map_structure( 946 K.variable, flat_init_state_values) 947 self.states = nest.pack_sequence_as(self.cell.state_size, 948 flat_states_variables) 949 if not nest.is_nested(self.states): 950 self.states = [self.states] 951 elif states is None: 952 for state, size in zip(nest.flatten(self.states), 953 nest.flatten(self.cell.state_size)): 954 K.set_value(state, np.zeros([batch_size] + 955 tensor_shape.TensorShape(size).as_list())) 956 else: 957 flat_states = nest.flatten(self.states) 958 flat_input_states = nest.flatten(states) 959 if len(flat_input_states) != len(flat_states): 960 raise ValueError('Layer ' + self.name + ' expects ' + 961 str(len(flat_states)) + ' states, ' 962 'but it received ' + str(len(flat_input_states)) + 963 ' state values. Input received: ' + str(states)) 964 set_value_tuples = [] 965 for i, (value, state) in enumerate(zip(flat_input_states, 966 flat_states)): 967 if value.shape != state.shape: 968 raise ValueError( 969 'State ' + str(i) + ' is incompatible with layer ' + 970 self.name + ': expected shape=' + str( 971 (batch_size, state)) + ', found shape=' + str(value.shape)) 972 set_value_tuples.append((state, value)) 973 K.batch_set_value(set_value_tuples) 974 975 def get_config(self): 976 config = { 977 'return_sequences': self.return_sequences, 978 'return_state': self.return_state, 979 'go_backwards': self.go_backwards, 980 'stateful': self.stateful, 981 'unroll': self.unroll, 982 'time_major': self.time_major 983 } 984 if self._num_constants: 985 config['num_constants'] = self._num_constants 986 if self.zero_output_for_mask: 987 config['zero_output_for_mask'] = self.zero_output_for_mask 988 989 config['cell'] = generic_utils.serialize_keras_object(self.cell) 990 base_config = super(RNN, self).get_config() 991 return dict(list(base_config.items()) + list(config.items())) 992 993 @classmethod 994 def from_config(cls, config, custom_objects=None): 995 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 996 cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects) 997 num_constants = config.pop('num_constants', 0) 998 layer = cls(cell, **config) 999 layer._num_constants = num_constants 1000 return layer 1001 1002 @property 1003 def _trackable_saved_model_saver(self): 1004 return layer_serialization.RNNSavedModelSaver(self) 1005 1006 1007@keras_export('keras.layers.AbstractRNNCell') 1008class AbstractRNNCell(Layer): 1009 """Abstract object representing an RNN cell. 1010 1011 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 1012 for details about the usage of RNN API. 1013 1014 This is the base class for implementing RNN cells with custom behavior. 1015 1016 Every `RNNCell` must have the properties below and implement `call` with 1017 the signature `(output, next_state) = call(input, state)`. 1018 1019 Examples: 1020 1021 ```python 1022 class MinimalRNNCell(AbstractRNNCell): 1023 1024 def __init__(self, units, **kwargs): 1025 self.units = units 1026 super(MinimalRNNCell, self).__init__(**kwargs) 1027 1028 @property 1029 def state_size(self): 1030 return self.units 1031 1032 def build(self, input_shape): 1033 self.kernel = self.add_weight(shape=(input_shape[-1], self.units), 1034 initializer='uniform', 1035 name='kernel') 1036 self.recurrent_kernel = self.add_weight( 1037 shape=(self.units, self.units), 1038 initializer='uniform', 1039 name='recurrent_kernel') 1040 self.built = True 1041 1042 def call(self, inputs, states): 1043 prev_output = states[0] 1044 h = K.dot(inputs, self.kernel) 1045 output = h + K.dot(prev_output, self.recurrent_kernel) 1046 return output, output 1047 ``` 1048 1049 This definition of cell differs from the definition used in the literature. 1050 In the literature, 'cell' refers to an object with a single scalar output. 1051 This definition refers to a horizontal array of such units. 1052 1053 An RNN cell, in the most abstract setting, is anything that has 1054 a state and performs some operation that takes a matrix of inputs. 1055 This operation results in an output matrix with `self.output_size` columns. 1056 If `self.state_size` is an integer, this operation also results in a new 1057 state matrix with `self.state_size` columns. If `self.state_size` is a 1058 (possibly nested tuple of) TensorShape object(s), then it should return a 1059 matching structure of Tensors having shape `[batch_size].concatenate(s)` 1060 for each `s` in `self.batch_size`. 1061 """ 1062 1063 def call(self, inputs, states): 1064 """The function that contains the logic for one RNN step calculation. 1065 1066 Args: 1067 inputs: the input tensor, which is a slide from the overall RNN input by 1068 the time dimension (usually the second dimension). 1069 states: the state tensor from previous step, which has the same shape 1070 as `(batch, state_size)`. In the case of timestep 0, it will be the 1071 initial state user specified, or zero filled tensor otherwise. 1072 1073 Returns: 1074 A tuple of two tensors: 1075 1. output tensor for the current timestep, with size `output_size`. 1076 2. state tensor for next step, which has the shape of `state_size`. 1077 """ 1078 raise NotImplementedError('Abstract method') 1079 1080 @property 1081 def state_size(self): 1082 """size(s) of state(s) used by this cell. 1083 1084 It can be represented by an Integer, a TensorShape or a tuple of Integers 1085 or TensorShapes. 1086 """ 1087 raise NotImplementedError('Abstract method') 1088 1089 @property 1090 def output_size(self): 1091 """Integer or TensorShape: size of outputs produced by this cell.""" 1092 raise NotImplementedError('Abstract method') 1093 1094 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 1095 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) 1096 1097 1098@doc_controls.do_not_generate_docs 1099class DropoutRNNCellMixin(object): 1100 """Object that hold dropout related fields for RNN Cell. 1101 1102 This class is not a standalone RNN cell. It suppose to be used with a RNN cell 1103 by multiple inheritance. Any cell that mix with class should have following 1104 fields: 1105 dropout: a float number within range [0, 1). The ratio that the input 1106 tensor need to dropout. 1107 recurrent_dropout: a float number within range [0, 1). The ratio that the 1108 recurrent state weights need to dropout. 1109 This object will create and cache created dropout masks, and reuse them for 1110 the incoming data, so that the same mask is used for every batch input. 1111 """ 1112 1113 def __init__(self, *args, **kwargs): 1114 self._create_non_trackable_mask_cache() 1115 super(DropoutRNNCellMixin, self).__init__(*args, **kwargs) 1116 1117 @trackable.no_automatic_dependency_tracking 1118 def _create_non_trackable_mask_cache(self): 1119 """Create the cache for dropout and recurrent dropout mask. 1120 1121 Note that the following two masks will be used in "graph function" mode, 1122 e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask` 1123 tensors will be generated differently than in the "graph function" case, 1124 and they will be cached. 1125 1126 Also note that in graph mode, we still cache those masks only because the 1127 RNN could be created with `unroll=True`. In that case, the `cell.call()` 1128 function will be invoked multiple times, and we want to ensure same mask 1129 is used every time. 1130 1131 Also the caches are created without tracking. Since they are not picklable 1132 by python when deepcopy, we don't want `layer._obj_reference_counts_dict` 1133 to track it by default. 1134 """ 1135 self._dropout_mask_cache = K.ContextValueCache(self._create_dropout_mask) 1136 self._recurrent_dropout_mask_cache = K.ContextValueCache( 1137 self._create_recurrent_dropout_mask) 1138 1139 def reset_dropout_mask(self): 1140 """Reset the cached dropout masks if any. 1141 1142 This is important for the RNN layer to invoke this in it `call()` method so 1143 that the cached mask is cleared before calling the `cell.call()`. The mask 1144 should be cached across the timestep within the same batch, but shouldn't 1145 be cached between batches. Otherwise it will introduce unreasonable bias 1146 against certain index of data within the batch. 1147 """ 1148 self._dropout_mask_cache.clear() 1149 1150 def reset_recurrent_dropout_mask(self): 1151 """Reset the cached recurrent dropout masks if any. 1152 1153 This is important for the RNN layer to invoke this in it call() method so 1154 that the cached mask is cleared before calling the cell.call(). The mask 1155 should be cached across the timestep within the same batch, but shouldn't 1156 be cached between batches. Otherwise it will introduce unreasonable bias 1157 against certain index of data within the batch. 1158 """ 1159 self._recurrent_dropout_mask_cache.clear() 1160 1161 def _create_dropout_mask(self, inputs, training, count=1): 1162 return _generate_dropout_mask( 1163 array_ops.ones_like(inputs), 1164 self.dropout, 1165 training=training, 1166 count=count) 1167 1168 def _create_recurrent_dropout_mask(self, inputs, training, count=1): 1169 return _generate_dropout_mask( 1170 array_ops.ones_like(inputs), 1171 self.recurrent_dropout, 1172 training=training, 1173 count=count) 1174 1175 def get_dropout_mask_for_cell(self, inputs, training, count=1): 1176 """Get the dropout mask for RNN cell's input. 1177 1178 It will create mask based on context if there isn't any existing cached 1179 mask. If a new mask is generated, it will update the cache in the cell. 1180 1181 Args: 1182 inputs: The input tensor whose shape will be used to generate dropout 1183 mask. 1184 training: Boolean tensor, whether its in training mode, dropout will be 1185 ignored in non-training mode. 1186 count: Int, how many dropout mask will be generated. It is useful for cell 1187 that has internal weights fused together. 1188 Returns: 1189 List of mask tensor, generated or cached mask based on context. 1190 """ 1191 if self.dropout == 0: 1192 return None 1193 init_kwargs = dict(inputs=inputs, training=training, count=count) 1194 return self._dropout_mask_cache.setdefault(kwargs=init_kwargs) 1195 1196 def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1): 1197 """Get the recurrent dropout mask for RNN cell. 1198 1199 It will create mask based on context if there isn't any existing cached 1200 mask. If a new mask is generated, it will update the cache in the cell. 1201 1202 Args: 1203 inputs: The input tensor whose shape will be used to generate dropout 1204 mask. 1205 training: Boolean tensor, whether its in training mode, dropout will be 1206 ignored in non-training mode. 1207 count: Int, how many dropout mask will be generated. It is useful for cell 1208 that has internal weights fused together. 1209 Returns: 1210 List of mask tensor, generated or cached mask based on context. 1211 """ 1212 if self.recurrent_dropout == 0: 1213 return None 1214 init_kwargs = dict(inputs=inputs, training=training, count=count) 1215 return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs) 1216 1217 def __getstate__(self): 1218 # Used for deepcopy. The caching can't be pickled by python, since it will 1219 # contain tensor and graph. 1220 state = super(DropoutRNNCellMixin, self).__getstate__() 1221 state.pop('_dropout_mask_cache', None) 1222 state.pop('_recurrent_dropout_mask_cache', None) 1223 return state 1224 1225 def __setstate__(self, state): 1226 state['_dropout_mask_cache'] = K.ContextValueCache( 1227 self._create_dropout_mask) 1228 state['_recurrent_dropout_mask_cache'] = K.ContextValueCache( 1229 self._create_recurrent_dropout_mask) 1230 super(DropoutRNNCellMixin, self).__setstate__(state) 1231 1232 1233@keras_export('keras.layers.SimpleRNNCell') 1234class SimpleRNNCell(DropoutRNNCellMixin, Layer): 1235 """Cell class for SimpleRNN. 1236 1237 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 1238 for details about the usage of RNN API. 1239 1240 This class processes one step within the whole time sequence input, whereas 1241 `tf.keras.layer.SimpleRNN` processes the whole sequence. 1242 1243 Args: 1244 units: Positive integer, dimensionality of the output space. 1245 activation: Activation function to use. 1246 Default: hyperbolic tangent (`tanh`). 1247 If you pass `None`, no activation is applied 1248 (ie. "linear" activation: `a(x) = x`). 1249 use_bias: Boolean, (default `True`), whether the layer uses a bias vector. 1250 kernel_initializer: Initializer for the `kernel` weights matrix, 1251 used for the linear transformation of the inputs. Default: 1252 `glorot_uniform`. 1253 recurrent_initializer: Initializer for the `recurrent_kernel` 1254 weights matrix, used for the linear transformation of the recurrent state. 1255 Default: `orthogonal`. 1256 bias_initializer: Initializer for the bias vector. Default: `zeros`. 1257 kernel_regularizer: Regularizer function applied to the `kernel` weights 1258 matrix. Default: `None`. 1259 recurrent_regularizer: Regularizer function applied to the 1260 `recurrent_kernel` weights matrix. Default: `None`. 1261 bias_regularizer: Regularizer function applied to the bias vector. Default: 1262 `None`. 1263 kernel_constraint: Constraint function applied to the `kernel` weights 1264 matrix. Default: `None`. 1265 recurrent_constraint: Constraint function applied to the `recurrent_kernel` 1266 weights matrix. Default: `None`. 1267 bias_constraint: Constraint function applied to the bias vector. Default: 1268 `None`. 1269 dropout: Float between 0 and 1. Fraction of the units to drop for the linear 1270 transformation of the inputs. Default: 0. 1271 recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for 1272 the linear transformation of the recurrent state. Default: 0. 1273 1274 Call arguments: 1275 inputs: A 2D tensor, with shape of `[batch, feature]`. 1276 states: A 2D tensor with shape of `[batch, units]`, which is the state from 1277 the previous time step. For timestep 0, the initial state provided by user 1278 will be feed to cell. 1279 training: Python boolean indicating whether the layer should behave in 1280 training mode or in inference mode. Only relevant when `dropout` or 1281 `recurrent_dropout` is used. 1282 1283 Examples: 1284 1285 ```python 1286 inputs = np.random.random([32, 10, 8]).astype(np.float32) 1287 rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4)) 1288 1289 output = rnn(inputs) # The output has shape `[32, 4]`. 1290 1291 rnn = tf.keras.layers.RNN( 1292 tf.keras.layers.SimpleRNNCell(4), 1293 return_sequences=True, 1294 return_state=True) 1295 1296 # whole_sequence_output has shape `[32, 10, 4]`. 1297 # final_state has shape `[32, 4]`. 1298 whole_sequence_output, final_state = rnn(inputs) 1299 ``` 1300 """ 1301 1302 def __init__(self, 1303 units, 1304 activation='tanh', 1305 use_bias=True, 1306 kernel_initializer='glorot_uniform', 1307 recurrent_initializer='orthogonal', 1308 bias_initializer='zeros', 1309 kernel_regularizer=None, 1310 recurrent_regularizer=None, 1311 bias_regularizer=None, 1312 kernel_constraint=None, 1313 recurrent_constraint=None, 1314 bias_constraint=None, 1315 dropout=0., 1316 recurrent_dropout=0., 1317 **kwargs): 1318 # By default use cached variable under v2 mode, see b/143699808. 1319 if ops.executing_eagerly_outside_functions(): 1320 self._enable_caching_device = kwargs.pop('enable_caching_device', True) 1321 else: 1322 self._enable_caching_device = kwargs.pop('enable_caching_device', False) 1323 super(SimpleRNNCell, self).__init__(**kwargs) 1324 self.units = units 1325 self.activation = activations.get(activation) 1326 self.use_bias = use_bias 1327 1328 self.kernel_initializer = initializers.get(kernel_initializer) 1329 self.recurrent_initializer = initializers.get(recurrent_initializer) 1330 self.bias_initializer = initializers.get(bias_initializer) 1331 1332 self.kernel_regularizer = regularizers.get(kernel_regularizer) 1333 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 1334 self.bias_regularizer = regularizers.get(bias_regularizer) 1335 1336 self.kernel_constraint = constraints.get(kernel_constraint) 1337 self.recurrent_constraint = constraints.get(recurrent_constraint) 1338 self.bias_constraint = constraints.get(bias_constraint) 1339 1340 self.dropout = min(1., max(0., dropout)) 1341 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 1342 self.state_size = self.units 1343 self.output_size = self.units 1344 1345 @tf_utils.shape_type_conversion 1346 def build(self, input_shape): 1347 default_caching_device = _caching_device(self) 1348 self.kernel = self.add_weight( 1349 shape=(input_shape[-1], self.units), 1350 name='kernel', 1351 initializer=self.kernel_initializer, 1352 regularizer=self.kernel_regularizer, 1353 constraint=self.kernel_constraint, 1354 caching_device=default_caching_device) 1355 self.recurrent_kernel = self.add_weight( 1356 shape=(self.units, self.units), 1357 name='recurrent_kernel', 1358 initializer=self.recurrent_initializer, 1359 regularizer=self.recurrent_regularizer, 1360 constraint=self.recurrent_constraint, 1361 caching_device=default_caching_device) 1362 if self.use_bias: 1363 self.bias = self.add_weight( 1364 shape=(self.units,), 1365 name='bias', 1366 initializer=self.bias_initializer, 1367 regularizer=self.bias_regularizer, 1368 constraint=self.bias_constraint, 1369 caching_device=default_caching_device) 1370 else: 1371 self.bias = None 1372 self.built = True 1373 1374 def call(self, inputs, states, training=None): 1375 prev_output = states[0] if nest.is_nested(states) else states 1376 dp_mask = self.get_dropout_mask_for_cell(inputs, training) 1377 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 1378 prev_output, training) 1379 1380 if dp_mask is not None: 1381 h = K.dot(inputs * dp_mask, self.kernel) 1382 else: 1383 h = K.dot(inputs, self.kernel) 1384 if self.bias is not None: 1385 h = K.bias_add(h, self.bias) 1386 1387 if rec_dp_mask is not None: 1388 prev_output = prev_output * rec_dp_mask 1389 output = h + K.dot(prev_output, self.recurrent_kernel) 1390 if self.activation is not None: 1391 output = self.activation(output) 1392 1393 new_state = [output] if nest.is_nested(states) else output 1394 return output, new_state 1395 1396 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 1397 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) 1398 1399 def get_config(self): 1400 config = { 1401 'units': 1402 self.units, 1403 'activation': 1404 activations.serialize(self.activation), 1405 'use_bias': 1406 self.use_bias, 1407 'kernel_initializer': 1408 initializers.serialize(self.kernel_initializer), 1409 'recurrent_initializer': 1410 initializers.serialize(self.recurrent_initializer), 1411 'bias_initializer': 1412 initializers.serialize(self.bias_initializer), 1413 'kernel_regularizer': 1414 regularizers.serialize(self.kernel_regularizer), 1415 'recurrent_regularizer': 1416 regularizers.serialize(self.recurrent_regularizer), 1417 'bias_regularizer': 1418 regularizers.serialize(self.bias_regularizer), 1419 'kernel_constraint': 1420 constraints.serialize(self.kernel_constraint), 1421 'recurrent_constraint': 1422 constraints.serialize(self.recurrent_constraint), 1423 'bias_constraint': 1424 constraints.serialize(self.bias_constraint), 1425 'dropout': 1426 self.dropout, 1427 'recurrent_dropout': 1428 self.recurrent_dropout 1429 } 1430 config.update(_config_for_enable_caching_device(self)) 1431 base_config = super(SimpleRNNCell, self).get_config() 1432 return dict(list(base_config.items()) + list(config.items())) 1433 1434 1435@keras_export('keras.layers.SimpleRNN') 1436class SimpleRNN(RNN): 1437 """Fully-connected RNN where the output is to be fed back to input. 1438 1439 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 1440 for details about the usage of RNN API. 1441 1442 Args: 1443 units: Positive integer, dimensionality of the output space. 1444 activation: Activation function to use. 1445 Default: hyperbolic tangent (`tanh`). 1446 If you pass None, no activation is applied 1447 (ie. "linear" activation: `a(x) = x`). 1448 use_bias: Boolean, (default `True`), whether the layer uses a bias vector. 1449 kernel_initializer: Initializer for the `kernel` weights matrix, 1450 used for the linear transformation of the inputs. Default: 1451 `glorot_uniform`. 1452 recurrent_initializer: Initializer for the `recurrent_kernel` 1453 weights matrix, used for the linear transformation of the recurrent state. 1454 Default: `orthogonal`. 1455 bias_initializer: Initializer for the bias vector. Default: `zeros`. 1456 kernel_regularizer: Regularizer function applied to the `kernel` weights 1457 matrix. Default: `None`. 1458 recurrent_regularizer: Regularizer function applied to the 1459 `recurrent_kernel` weights matrix. Default: `None`. 1460 bias_regularizer: Regularizer function applied to the bias vector. Default: 1461 `None`. 1462 activity_regularizer: Regularizer function applied to the output of the 1463 layer (its "activation"). Default: `None`. 1464 kernel_constraint: Constraint function applied to the `kernel` weights 1465 matrix. Default: `None`. 1466 recurrent_constraint: Constraint function applied to the `recurrent_kernel` 1467 weights matrix. Default: `None`. 1468 bias_constraint: Constraint function applied to the bias vector. Default: 1469 `None`. 1470 dropout: Float between 0 and 1. 1471 Fraction of the units to drop for the linear transformation of the inputs. 1472 Default: 0. 1473 recurrent_dropout: Float between 0 and 1. 1474 Fraction of the units to drop for the linear transformation of the 1475 recurrent state. Default: 0. 1476 return_sequences: Boolean. Whether to return the last output 1477 in the output sequence, or the full sequence. Default: `False`. 1478 return_state: Boolean. Whether to return the last state 1479 in addition to the output. Default: `False` 1480 go_backwards: Boolean (default False). 1481 If True, process the input sequence backwards and return the 1482 reversed sequence. 1483 stateful: Boolean (default False). If True, the last state 1484 for each sample at index i in a batch will be used as initial 1485 state for the sample of index i in the following batch. 1486 unroll: Boolean (default False). 1487 If True, the network will be unrolled, 1488 else a symbolic loop will be used. 1489 Unrolling can speed-up a RNN, 1490 although it tends to be more memory-intensive. 1491 Unrolling is only suitable for short sequences. 1492 1493 Call arguments: 1494 inputs: A 3D tensor, with shape `[batch, timesteps, feature]`. 1495 mask: Binary tensor of shape `[batch, timesteps]` indicating whether 1496 a given timestep should be masked. An individual `True` entry indicates 1497 that the corresponding timestep should be utilized, while a `False` entry 1498 indicates that the corresponding timestep should be ignored. 1499 training: Python boolean indicating whether the layer should behave in 1500 training mode or in inference mode. This argument is passed to the cell 1501 when calling it. This is only relevant if `dropout` or 1502 `recurrent_dropout` is used. 1503 initial_state: List of initial state tensors to be passed to the first 1504 call of the cell. 1505 1506 Examples: 1507 1508 ```python 1509 inputs = np.random.random([32, 10, 8]).astype(np.float32) 1510 simple_rnn = tf.keras.layers.SimpleRNN(4) 1511 1512 output = simple_rnn(inputs) # The output has shape `[32, 4]`. 1513 1514 simple_rnn = tf.keras.layers.SimpleRNN( 1515 4, return_sequences=True, return_state=True) 1516 1517 # whole_sequence_output has shape `[32, 10, 4]`. 1518 # final_state has shape `[32, 4]`. 1519 whole_sequence_output, final_state = simple_rnn(inputs) 1520 ``` 1521 """ 1522 1523 def __init__(self, 1524 units, 1525 activation='tanh', 1526 use_bias=True, 1527 kernel_initializer='glorot_uniform', 1528 recurrent_initializer='orthogonal', 1529 bias_initializer='zeros', 1530 kernel_regularizer=None, 1531 recurrent_regularizer=None, 1532 bias_regularizer=None, 1533 activity_regularizer=None, 1534 kernel_constraint=None, 1535 recurrent_constraint=None, 1536 bias_constraint=None, 1537 dropout=0., 1538 recurrent_dropout=0., 1539 return_sequences=False, 1540 return_state=False, 1541 go_backwards=False, 1542 stateful=False, 1543 unroll=False, 1544 **kwargs): 1545 if 'implementation' in kwargs: 1546 kwargs.pop('implementation') 1547 logging.warning('The `implementation` argument ' 1548 'in `SimpleRNN` has been deprecated. ' 1549 'Please remove it from your layer call.') 1550 if 'enable_caching_device' in kwargs: 1551 cell_kwargs = {'enable_caching_device': 1552 kwargs.pop('enable_caching_device')} 1553 else: 1554 cell_kwargs = {} 1555 cell = SimpleRNNCell( 1556 units, 1557 activation=activation, 1558 use_bias=use_bias, 1559 kernel_initializer=kernel_initializer, 1560 recurrent_initializer=recurrent_initializer, 1561 bias_initializer=bias_initializer, 1562 kernel_regularizer=kernel_regularizer, 1563 recurrent_regularizer=recurrent_regularizer, 1564 bias_regularizer=bias_regularizer, 1565 kernel_constraint=kernel_constraint, 1566 recurrent_constraint=recurrent_constraint, 1567 bias_constraint=bias_constraint, 1568 dropout=dropout, 1569 recurrent_dropout=recurrent_dropout, 1570 dtype=kwargs.get('dtype'), 1571 trainable=kwargs.get('trainable', True), 1572 **cell_kwargs) 1573 super(SimpleRNN, self).__init__( 1574 cell, 1575 return_sequences=return_sequences, 1576 return_state=return_state, 1577 go_backwards=go_backwards, 1578 stateful=stateful, 1579 unroll=unroll, 1580 **kwargs) 1581 self.activity_regularizer = regularizers.get(activity_regularizer) 1582 self.input_spec = [InputSpec(ndim=3)] 1583 1584 def call(self, inputs, mask=None, training=None, initial_state=None): 1585 return super(SimpleRNN, self).call( 1586 inputs, mask=mask, training=training, initial_state=initial_state) 1587 1588 @property 1589 def units(self): 1590 return self.cell.units 1591 1592 @property 1593 def activation(self): 1594 return self.cell.activation 1595 1596 @property 1597 def use_bias(self): 1598 return self.cell.use_bias 1599 1600 @property 1601 def kernel_initializer(self): 1602 return self.cell.kernel_initializer 1603 1604 @property 1605 def recurrent_initializer(self): 1606 return self.cell.recurrent_initializer 1607 1608 @property 1609 def bias_initializer(self): 1610 return self.cell.bias_initializer 1611 1612 @property 1613 def kernel_regularizer(self): 1614 return self.cell.kernel_regularizer 1615 1616 @property 1617 def recurrent_regularizer(self): 1618 return self.cell.recurrent_regularizer 1619 1620 @property 1621 def bias_regularizer(self): 1622 return self.cell.bias_regularizer 1623 1624 @property 1625 def kernel_constraint(self): 1626 return self.cell.kernel_constraint 1627 1628 @property 1629 def recurrent_constraint(self): 1630 return self.cell.recurrent_constraint 1631 1632 @property 1633 def bias_constraint(self): 1634 return self.cell.bias_constraint 1635 1636 @property 1637 def dropout(self): 1638 return self.cell.dropout 1639 1640 @property 1641 def recurrent_dropout(self): 1642 return self.cell.recurrent_dropout 1643 1644 def get_config(self): 1645 config = { 1646 'units': 1647 self.units, 1648 'activation': 1649 activations.serialize(self.activation), 1650 'use_bias': 1651 self.use_bias, 1652 'kernel_initializer': 1653 initializers.serialize(self.kernel_initializer), 1654 'recurrent_initializer': 1655 initializers.serialize(self.recurrent_initializer), 1656 'bias_initializer': 1657 initializers.serialize(self.bias_initializer), 1658 'kernel_regularizer': 1659 regularizers.serialize(self.kernel_regularizer), 1660 'recurrent_regularizer': 1661 regularizers.serialize(self.recurrent_regularizer), 1662 'bias_regularizer': 1663 regularizers.serialize(self.bias_regularizer), 1664 'activity_regularizer': 1665 regularizers.serialize(self.activity_regularizer), 1666 'kernel_constraint': 1667 constraints.serialize(self.kernel_constraint), 1668 'recurrent_constraint': 1669 constraints.serialize(self.recurrent_constraint), 1670 'bias_constraint': 1671 constraints.serialize(self.bias_constraint), 1672 'dropout': 1673 self.dropout, 1674 'recurrent_dropout': 1675 self.recurrent_dropout 1676 } 1677 base_config = super(SimpleRNN, self).get_config() 1678 config.update(_config_for_enable_caching_device(self.cell)) 1679 del base_config['cell'] 1680 return dict(list(base_config.items()) + list(config.items())) 1681 1682 @classmethod 1683 def from_config(cls, config): 1684 if 'implementation' in config: 1685 config.pop('implementation') 1686 return cls(**config) 1687 1688 1689@keras_export(v1=['keras.layers.GRUCell']) 1690class GRUCell(DropoutRNNCellMixin, Layer): 1691 """Cell class for the GRU layer. 1692 1693 Args: 1694 units: Positive integer, dimensionality of the output space. 1695 activation: Activation function to use. 1696 Default: hyperbolic tangent (`tanh`). 1697 If you pass None, no activation is applied 1698 (ie. "linear" activation: `a(x) = x`). 1699 recurrent_activation: Activation function to use 1700 for the recurrent step. 1701 Default: hard sigmoid (`hard_sigmoid`). 1702 If you pass `None`, no activation is applied 1703 (ie. "linear" activation: `a(x) = x`). 1704 use_bias: Boolean, whether the layer uses a bias vector. 1705 kernel_initializer: Initializer for the `kernel` weights matrix, 1706 used for the linear transformation of the inputs. 1707 recurrent_initializer: Initializer for the `recurrent_kernel` 1708 weights matrix, 1709 used for the linear transformation of the recurrent state. 1710 bias_initializer: Initializer for the bias vector. 1711 kernel_regularizer: Regularizer function applied to 1712 the `kernel` weights matrix. 1713 recurrent_regularizer: Regularizer function applied to 1714 the `recurrent_kernel` weights matrix. 1715 bias_regularizer: Regularizer function applied to the bias vector. 1716 kernel_constraint: Constraint function applied to 1717 the `kernel` weights matrix. 1718 recurrent_constraint: Constraint function applied to 1719 the `recurrent_kernel` weights matrix. 1720 bias_constraint: Constraint function applied to the bias vector. 1721 dropout: Float between 0 and 1. 1722 Fraction of the units to drop for the linear transformation of the inputs. 1723 recurrent_dropout: Float between 0 and 1. 1724 Fraction of the units to drop for 1725 the linear transformation of the recurrent state. 1726 reset_after: GRU convention (whether to apply reset gate after or 1727 before matrix multiplication). False = "before" (default), 1728 True = "after" (CuDNN compatible). 1729 1730 Call arguments: 1731 inputs: A 2D tensor. 1732 states: List of state tensors corresponding to the previous timestep. 1733 training: Python boolean indicating whether the layer should behave in 1734 training mode or in inference mode. Only relevant when `dropout` or 1735 `recurrent_dropout` is used. 1736 """ 1737 1738 def __init__(self, 1739 units, 1740 activation='tanh', 1741 recurrent_activation='hard_sigmoid', 1742 use_bias=True, 1743 kernel_initializer='glorot_uniform', 1744 recurrent_initializer='orthogonal', 1745 bias_initializer='zeros', 1746 kernel_regularizer=None, 1747 recurrent_regularizer=None, 1748 bias_regularizer=None, 1749 kernel_constraint=None, 1750 recurrent_constraint=None, 1751 bias_constraint=None, 1752 dropout=0., 1753 recurrent_dropout=0., 1754 reset_after=False, 1755 **kwargs): 1756 # By default use cached variable under v2 mode, see b/143699808. 1757 if ops.executing_eagerly_outside_functions(): 1758 self._enable_caching_device = kwargs.pop('enable_caching_device', True) 1759 else: 1760 self._enable_caching_device = kwargs.pop('enable_caching_device', False) 1761 super(GRUCell, self).__init__(**kwargs) 1762 self.units = units 1763 self.activation = activations.get(activation) 1764 self.recurrent_activation = activations.get(recurrent_activation) 1765 self.use_bias = use_bias 1766 1767 self.kernel_initializer = initializers.get(kernel_initializer) 1768 self.recurrent_initializer = initializers.get(recurrent_initializer) 1769 self.bias_initializer = initializers.get(bias_initializer) 1770 1771 self.kernel_regularizer = regularizers.get(kernel_regularizer) 1772 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 1773 self.bias_regularizer = regularizers.get(bias_regularizer) 1774 1775 self.kernel_constraint = constraints.get(kernel_constraint) 1776 self.recurrent_constraint = constraints.get(recurrent_constraint) 1777 self.bias_constraint = constraints.get(bias_constraint) 1778 1779 self.dropout = min(1., max(0., dropout)) 1780 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 1781 1782 implementation = kwargs.pop('implementation', 1) 1783 if self.recurrent_dropout != 0 and implementation != 1: 1784 logging.debug(RECURRENT_DROPOUT_WARNING_MSG) 1785 self.implementation = 1 1786 else: 1787 self.implementation = implementation 1788 self.reset_after = reset_after 1789 self.state_size = self.units 1790 self.output_size = self.units 1791 1792 @tf_utils.shape_type_conversion 1793 def build(self, input_shape): 1794 input_dim = input_shape[-1] 1795 default_caching_device = _caching_device(self) 1796 self.kernel = self.add_weight( 1797 shape=(input_dim, self.units * 3), 1798 name='kernel', 1799 initializer=self.kernel_initializer, 1800 regularizer=self.kernel_regularizer, 1801 constraint=self.kernel_constraint, 1802 caching_device=default_caching_device) 1803 self.recurrent_kernel = self.add_weight( 1804 shape=(self.units, self.units * 3), 1805 name='recurrent_kernel', 1806 initializer=self.recurrent_initializer, 1807 regularizer=self.recurrent_regularizer, 1808 constraint=self.recurrent_constraint, 1809 caching_device=default_caching_device) 1810 1811 if self.use_bias: 1812 if not self.reset_after: 1813 bias_shape = (3 * self.units,) 1814 else: 1815 # separate biases for input and recurrent kernels 1816 # Note: the shape is intentionally different from CuDNNGRU biases 1817 # `(2 * 3 * self.units,)`, so that we can distinguish the classes 1818 # when loading and converting saved weights. 1819 bias_shape = (2, 3 * self.units) 1820 self.bias = self.add_weight(shape=bias_shape, 1821 name='bias', 1822 initializer=self.bias_initializer, 1823 regularizer=self.bias_regularizer, 1824 constraint=self.bias_constraint, 1825 caching_device=default_caching_device) 1826 else: 1827 self.bias = None 1828 self.built = True 1829 1830 def call(self, inputs, states, training=None): 1831 h_tm1 = states[0] if nest.is_nested(states) else states # previous memory 1832 1833 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3) 1834 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 1835 h_tm1, training, count=3) 1836 1837 if self.use_bias: 1838 if not self.reset_after: 1839 input_bias, recurrent_bias = self.bias, None 1840 else: 1841 input_bias, recurrent_bias = array_ops.unstack(self.bias) 1842 1843 if self.implementation == 1: 1844 if 0. < self.dropout < 1.: 1845 inputs_z = inputs * dp_mask[0] 1846 inputs_r = inputs * dp_mask[1] 1847 inputs_h = inputs * dp_mask[2] 1848 else: 1849 inputs_z = inputs 1850 inputs_r = inputs 1851 inputs_h = inputs 1852 1853 x_z = K.dot(inputs_z, self.kernel[:, :self.units]) 1854 x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2]) 1855 x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:]) 1856 1857 if self.use_bias: 1858 x_z = K.bias_add(x_z, input_bias[:self.units]) 1859 x_r = K.bias_add(x_r, input_bias[self.units: self.units * 2]) 1860 x_h = K.bias_add(x_h, input_bias[self.units * 2:]) 1861 1862 if 0. < self.recurrent_dropout < 1.: 1863 h_tm1_z = h_tm1 * rec_dp_mask[0] 1864 h_tm1_r = h_tm1 * rec_dp_mask[1] 1865 h_tm1_h = h_tm1 * rec_dp_mask[2] 1866 else: 1867 h_tm1_z = h_tm1 1868 h_tm1_r = h_tm1 1869 h_tm1_h = h_tm1 1870 1871 recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units]) 1872 recurrent_r = K.dot(h_tm1_r, 1873 self.recurrent_kernel[:, self.units:self.units * 2]) 1874 if self.reset_after and self.use_bias: 1875 recurrent_z = K.bias_add(recurrent_z, recurrent_bias[:self.units]) 1876 recurrent_r = K.bias_add(recurrent_r, 1877 recurrent_bias[self.units:self.units * 2]) 1878 1879 z = self.recurrent_activation(x_z + recurrent_z) 1880 r = self.recurrent_activation(x_r + recurrent_r) 1881 1882 # reset gate applied after/before matrix multiplication 1883 if self.reset_after: 1884 recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel[:, self.units * 2:]) 1885 if self.use_bias: 1886 recurrent_h = K.bias_add(recurrent_h, recurrent_bias[self.units * 2:]) 1887 recurrent_h = r * recurrent_h 1888 else: 1889 recurrent_h = K.dot(r * h_tm1_h, 1890 self.recurrent_kernel[:, self.units * 2:]) 1891 1892 hh = self.activation(x_h + recurrent_h) 1893 else: 1894 if 0. < self.dropout < 1.: 1895 inputs = inputs * dp_mask[0] 1896 1897 # inputs projected by all gate matrices at once 1898 matrix_x = K.dot(inputs, self.kernel) 1899 if self.use_bias: 1900 # biases: bias_z_i, bias_r_i, bias_h_i 1901 matrix_x = K.bias_add(matrix_x, input_bias) 1902 1903 x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=-1) 1904 1905 if self.reset_after: 1906 # hidden state projected by all gate matrices at once 1907 matrix_inner = K.dot(h_tm1, self.recurrent_kernel) 1908 if self.use_bias: 1909 matrix_inner = K.bias_add(matrix_inner, recurrent_bias) 1910 else: 1911 # hidden state projected separately for update/reset and new 1912 matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units]) 1913 1914 recurrent_z, recurrent_r, recurrent_h = array_ops.split( 1915 matrix_inner, [self.units, self.units, -1], axis=-1) 1916 1917 z = self.recurrent_activation(x_z + recurrent_z) 1918 r = self.recurrent_activation(x_r + recurrent_r) 1919 1920 if self.reset_after: 1921 recurrent_h = r * recurrent_h 1922 else: 1923 recurrent_h = K.dot(r * h_tm1, 1924 self.recurrent_kernel[:, 2 * self.units:]) 1925 1926 hh = self.activation(x_h + recurrent_h) 1927 # previous and candidate state mixed by update gate 1928 h = z * h_tm1 + (1 - z) * hh 1929 new_state = [h] if nest.is_nested(states) else h 1930 return h, new_state 1931 1932 def get_config(self): 1933 config = { 1934 'units': self.units, 1935 'activation': activations.serialize(self.activation), 1936 'recurrent_activation': 1937 activations.serialize(self.recurrent_activation), 1938 'use_bias': self.use_bias, 1939 'kernel_initializer': initializers.serialize(self.kernel_initializer), 1940 'recurrent_initializer': 1941 initializers.serialize(self.recurrent_initializer), 1942 'bias_initializer': initializers.serialize(self.bias_initializer), 1943 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 1944 'recurrent_regularizer': 1945 regularizers.serialize(self.recurrent_regularizer), 1946 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 1947 'kernel_constraint': constraints.serialize(self.kernel_constraint), 1948 'recurrent_constraint': 1949 constraints.serialize(self.recurrent_constraint), 1950 'bias_constraint': constraints.serialize(self.bias_constraint), 1951 'dropout': self.dropout, 1952 'recurrent_dropout': self.recurrent_dropout, 1953 'implementation': self.implementation, 1954 'reset_after': self.reset_after 1955 } 1956 config.update(_config_for_enable_caching_device(self)) 1957 base_config = super(GRUCell, self).get_config() 1958 return dict(list(base_config.items()) + list(config.items())) 1959 1960 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 1961 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) 1962 1963 1964@keras_export(v1=['keras.layers.GRU']) 1965class GRU(RNN): 1966 """Gated Recurrent Unit - Cho et al. 2014. 1967 1968 There are two variants. The default one is based on 1406.1078v3 and 1969 has reset gate applied to hidden state before matrix multiplication. The 1970 other one is based on original 1406.1078v1 and has the order reversed. 1971 1972 The second variant is compatible with CuDNNGRU (GPU-only) and allows 1973 inference on CPU. Thus it has separate biases for `kernel` and 1974 `recurrent_kernel`. Use `'reset_after'=True` and 1975 `recurrent_activation='sigmoid'`. 1976 1977 Args: 1978 units: Positive integer, dimensionality of the output space. 1979 activation: Activation function to use. 1980 Default: hyperbolic tangent (`tanh`). 1981 If you pass `None`, no activation is applied 1982 (ie. "linear" activation: `a(x) = x`). 1983 recurrent_activation: Activation function to use 1984 for the recurrent step. 1985 Default: hard sigmoid (`hard_sigmoid`). 1986 If you pass `None`, no activation is applied 1987 (ie. "linear" activation: `a(x) = x`). 1988 use_bias: Boolean, whether the layer uses a bias vector. 1989 kernel_initializer: Initializer for the `kernel` weights matrix, 1990 used for the linear transformation of the inputs. 1991 recurrent_initializer: Initializer for the `recurrent_kernel` 1992 weights matrix, used for the linear transformation of the recurrent state. 1993 bias_initializer: Initializer for the bias vector. 1994 kernel_regularizer: Regularizer function applied to 1995 the `kernel` weights matrix. 1996 recurrent_regularizer: Regularizer function applied to 1997 the `recurrent_kernel` weights matrix. 1998 bias_regularizer: Regularizer function applied to the bias vector. 1999 activity_regularizer: Regularizer function applied to 2000 the output of the layer (its "activation").. 2001 kernel_constraint: Constraint function applied to 2002 the `kernel` weights matrix. 2003 recurrent_constraint: Constraint function applied to 2004 the `recurrent_kernel` weights matrix. 2005 bias_constraint: Constraint function applied to the bias vector. 2006 dropout: Float between 0 and 1. 2007 Fraction of the units to drop for 2008 the linear transformation of the inputs. 2009 recurrent_dropout: Float between 0 and 1. 2010 Fraction of the units to drop for 2011 the linear transformation of the recurrent state. 2012 return_sequences: Boolean. Whether to return the last output 2013 in the output sequence, or the full sequence. 2014 return_state: Boolean. Whether to return the last state 2015 in addition to the output. 2016 go_backwards: Boolean (default False). 2017 If True, process the input sequence backwards and return the 2018 reversed sequence. 2019 stateful: Boolean (default False). If True, the last state 2020 for each sample at index i in a batch will be used as initial 2021 state for the sample of index i in the following batch. 2022 unroll: Boolean (default False). 2023 If True, the network will be unrolled, 2024 else a symbolic loop will be used. 2025 Unrolling can speed-up a RNN, 2026 although it tends to be more memory-intensive. 2027 Unrolling is only suitable for short sequences. 2028 time_major: The shape format of the `inputs` and `outputs` tensors. 2029 If True, the inputs and outputs will be in shape 2030 `(timesteps, batch, ...)`, whereas in the False case, it will be 2031 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 2032 efficient because it avoids transposes at the beginning and end of the 2033 RNN calculation. However, most TensorFlow data is batch-major, so by 2034 default this function accepts input and emits output in batch-major 2035 form. 2036 reset_after: GRU convention (whether to apply reset gate after or 2037 before matrix multiplication). False = "before" (default), 2038 True = "after" (CuDNN compatible). 2039 2040 Call arguments: 2041 inputs: A 3D tensor. 2042 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 2043 a given timestep should be masked. An individual `True` entry indicates 2044 that the corresponding timestep should be utilized, while a `False` 2045 entry indicates that the corresponding timestep should be ignored. 2046 training: Python boolean indicating whether the layer should behave in 2047 training mode or in inference mode. This argument is passed to the cell 2048 when calling it. This is only relevant if `dropout` or 2049 `recurrent_dropout` is used. 2050 initial_state: List of initial state tensors to be passed to the first 2051 call of the cell. 2052 """ 2053 2054 def __init__(self, 2055 units, 2056 activation='tanh', 2057 recurrent_activation='hard_sigmoid', 2058 use_bias=True, 2059 kernel_initializer='glorot_uniform', 2060 recurrent_initializer='orthogonal', 2061 bias_initializer='zeros', 2062 kernel_regularizer=None, 2063 recurrent_regularizer=None, 2064 bias_regularizer=None, 2065 activity_regularizer=None, 2066 kernel_constraint=None, 2067 recurrent_constraint=None, 2068 bias_constraint=None, 2069 dropout=0., 2070 recurrent_dropout=0., 2071 return_sequences=False, 2072 return_state=False, 2073 go_backwards=False, 2074 stateful=False, 2075 unroll=False, 2076 reset_after=False, 2077 **kwargs): 2078 implementation = kwargs.pop('implementation', 1) 2079 if implementation == 0: 2080 logging.warning('`implementation=0` has been deprecated, ' 2081 'and now defaults to `implementation=1`.' 2082 'Please update your layer call.') 2083 if 'enable_caching_device' in kwargs: 2084 cell_kwargs = {'enable_caching_device': 2085 kwargs.pop('enable_caching_device')} 2086 else: 2087 cell_kwargs = {} 2088 cell = GRUCell( 2089 units, 2090 activation=activation, 2091 recurrent_activation=recurrent_activation, 2092 use_bias=use_bias, 2093 kernel_initializer=kernel_initializer, 2094 recurrent_initializer=recurrent_initializer, 2095 bias_initializer=bias_initializer, 2096 kernel_regularizer=kernel_regularizer, 2097 recurrent_regularizer=recurrent_regularizer, 2098 bias_regularizer=bias_regularizer, 2099 kernel_constraint=kernel_constraint, 2100 recurrent_constraint=recurrent_constraint, 2101 bias_constraint=bias_constraint, 2102 dropout=dropout, 2103 recurrent_dropout=recurrent_dropout, 2104 implementation=implementation, 2105 reset_after=reset_after, 2106 dtype=kwargs.get('dtype'), 2107 trainable=kwargs.get('trainable', True), 2108 **cell_kwargs) 2109 super(GRU, self).__init__( 2110 cell, 2111 return_sequences=return_sequences, 2112 return_state=return_state, 2113 go_backwards=go_backwards, 2114 stateful=stateful, 2115 unroll=unroll, 2116 **kwargs) 2117 self.activity_regularizer = regularizers.get(activity_regularizer) 2118 self.input_spec = [InputSpec(ndim=3)] 2119 2120 def call(self, inputs, mask=None, training=None, initial_state=None): 2121 return super(GRU, self).call( 2122 inputs, mask=mask, training=training, initial_state=initial_state) 2123 2124 @property 2125 def units(self): 2126 return self.cell.units 2127 2128 @property 2129 def activation(self): 2130 return self.cell.activation 2131 2132 @property 2133 def recurrent_activation(self): 2134 return self.cell.recurrent_activation 2135 2136 @property 2137 def use_bias(self): 2138 return self.cell.use_bias 2139 2140 @property 2141 def kernel_initializer(self): 2142 return self.cell.kernel_initializer 2143 2144 @property 2145 def recurrent_initializer(self): 2146 return self.cell.recurrent_initializer 2147 2148 @property 2149 def bias_initializer(self): 2150 return self.cell.bias_initializer 2151 2152 @property 2153 def kernel_regularizer(self): 2154 return self.cell.kernel_regularizer 2155 2156 @property 2157 def recurrent_regularizer(self): 2158 return self.cell.recurrent_regularizer 2159 2160 @property 2161 def bias_regularizer(self): 2162 return self.cell.bias_regularizer 2163 2164 @property 2165 def kernel_constraint(self): 2166 return self.cell.kernel_constraint 2167 2168 @property 2169 def recurrent_constraint(self): 2170 return self.cell.recurrent_constraint 2171 2172 @property 2173 def bias_constraint(self): 2174 return self.cell.bias_constraint 2175 2176 @property 2177 def dropout(self): 2178 return self.cell.dropout 2179 2180 @property 2181 def recurrent_dropout(self): 2182 return self.cell.recurrent_dropout 2183 2184 @property 2185 def implementation(self): 2186 return self.cell.implementation 2187 2188 @property 2189 def reset_after(self): 2190 return self.cell.reset_after 2191 2192 def get_config(self): 2193 config = { 2194 'units': 2195 self.units, 2196 'activation': 2197 activations.serialize(self.activation), 2198 'recurrent_activation': 2199 activations.serialize(self.recurrent_activation), 2200 'use_bias': 2201 self.use_bias, 2202 'kernel_initializer': 2203 initializers.serialize(self.kernel_initializer), 2204 'recurrent_initializer': 2205 initializers.serialize(self.recurrent_initializer), 2206 'bias_initializer': 2207 initializers.serialize(self.bias_initializer), 2208 'kernel_regularizer': 2209 regularizers.serialize(self.kernel_regularizer), 2210 'recurrent_regularizer': 2211 regularizers.serialize(self.recurrent_regularizer), 2212 'bias_regularizer': 2213 regularizers.serialize(self.bias_regularizer), 2214 'activity_regularizer': 2215 regularizers.serialize(self.activity_regularizer), 2216 'kernel_constraint': 2217 constraints.serialize(self.kernel_constraint), 2218 'recurrent_constraint': 2219 constraints.serialize(self.recurrent_constraint), 2220 'bias_constraint': 2221 constraints.serialize(self.bias_constraint), 2222 'dropout': 2223 self.dropout, 2224 'recurrent_dropout': 2225 self.recurrent_dropout, 2226 'implementation': 2227 self.implementation, 2228 'reset_after': 2229 self.reset_after 2230 } 2231 config.update(_config_for_enable_caching_device(self.cell)) 2232 base_config = super(GRU, self).get_config() 2233 del base_config['cell'] 2234 return dict(list(base_config.items()) + list(config.items())) 2235 2236 @classmethod 2237 def from_config(cls, config): 2238 if 'implementation' in config and config['implementation'] == 0: 2239 config['implementation'] = 1 2240 return cls(**config) 2241 2242 2243@keras_export(v1=['keras.layers.LSTMCell']) 2244class LSTMCell(DropoutRNNCellMixin, Layer): 2245 """Cell class for the LSTM layer. 2246 2247 Args: 2248 units: Positive integer, dimensionality of the output space. 2249 activation: Activation function to use. 2250 Default: hyperbolic tangent (`tanh`). 2251 If you pass `None`, no activation is applied 2252 (ie. "linear" activation: `a(x) = x`). 2253 recurrent_activation: Activation function to use 2254 for the recurrent step. 2255 Default: hard sigmoid (`hard_sigmoid`). 2256 If you pass `None`, no activation is applied 2257 (ie. "linear" activation: `a(x) = x`). 2258 use_bias: Boolean, whether the layer uses a bias vector. 2259 kernel_initializer: Initializer for the `kernel` weights matrix, 2260 used for the linear transformation of the inputs. 2261 recurrent_initializer: Initializer for the `recurrent_kernel` 2262 weights matrix, 2263 used for the linear transformation of the recurrent state. 2264 bias_initializer: Initializer for the bias vector. 2265 unit_forget_bias: Boolean. 2266 If True, add 1 to the bias of the forget gate at initialization. 2267 Setting it to true will also force `bias_initializer="zeros"`. 2268 This is recommended in [Jozefowicz et al., 2015]( 2269 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) 2270 kernel_regularizer: Regularizer function applied to 2271 the `kernel` weights matrix. 2272 recurrent_regularizer: Regularizer function applied to 2273 the `recurrent_kernel` weights matrix. 2274 bias_regularizer: Regularizer function applied to the bias vector. 2275 kernel_constraint: Constraint function applied to 2276 the `kernel` weights matrix. 2277 recurrent_constraint: Constraint function applied to 2278 the `recurrent_kernel` weights matrix. 2279 bias_constraint: Constraint function applied to the bias vector. 2280 dropout: Float between 0 and 1. 2281 Fraction of the units to drop for 2282 the linear transformation of the inputs. 2283 recurrent_dropout: Float between 0 and 1. 2284 Fraction of the units to drop for 2285 the linear transformation of the recurrent state. 2286 2287 Call arguments: 2288 inputs: A 2D tensor. 2289 states: List of state tensors corresponding to the previous timestep. 2290 training: Python boolean indicating whether the layer should behave in 2291 training mode or in inference mode. Only relevant when `dropout` or 2292 `recurrent_dropout` is used. 2293 """ 2294 2295 def __init__(self, 2296 units, 2297 activation='tanh', 2298 recurrent_activation='hard_sigmoid', 2299 use_bias=True, 2300 kernel_initializer='glorot_uniform', 2301 recurrent_initializer='orthogonal', 2302 bias_initializer='zeros', 2303 unit_forget_bias=True, 2304 kernel_regularizer=None, 2305 recurrent_regularizer=None, 2306 bias_regularizer=None, 2307 kernel_constraint=None, 2308 recurrent_constraint=None, 2309 bias_constraint=None, 2310 dropout=0., 2311 recurrent_dropout=0., 2312 **kwargs): 2313 # By default use cached variable under v2 mode, see b/143699808. 2314 if ops.executing_eagerly_outside_functions(): 2315 self._enable_caching_device = kwargs.pop('enable_caching_device', True) 2316 else: 2317 self._enable_caching_device = kwargs.pop('enable_caching_device', False) 2318 super(LSTMCell, self).__init__(**kwargs) 2319 self.units = units 2320 self.activation = activations.get(activation) 2321 self.recurrent_activation = activations.get(recurrent_activation) 2322 self.use_bias = use_bias 2323 2324 self.kernel_initializer = initializers.get(kernel_initializer) 2325 self.recurrent_initializer = initializers.get(recurrent_initializer) 2326 self.bias_initializer = initializers.get(bias_initializer) 2327 self.unit_forget_bias = unit_forget_bias 2328 2329 self.kernel_regularizer = regularizers.get(kernel_regularizer) 2330 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 2331 self.bias_regularizer = regularizers.get(bias_regularizer) 2332 2333 self.kernel_constraint = constraints.get(kernel_constraint) 2334 self.recurrent_constraint = constraints.get(recurrent_constraint) 2335 self.bias_constraint = constraints.get(bias_constraint) 2336 2337 self.dropout = min(1., max(0., dropout)) 2338 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 2339 implementation = kwargs.pop('implementation', 1) 2340 if self.recurrent_dropout != 0 and implementation != 1: 2341 logging.debug(RECURRENT_DROPOUT_WARNING_MSG) 2342 self.implementation = 1 2343 else: 2344 self.implementation = implementation 2345 # tuple(_ListWrapper) was silently dropping list content in at least 2.7.10, 2346 # and fixed after 2.7.16. Converting the state_size to wrapper around 2347 # NoDependency(), so that the base_layer.__setattr__ will not convert it to 2348 # ListWrapper. Down the stream, self.states will be a list since it is 2349 # generated from nest.map_structure with list, and tuple(list) will work 2350 # properly. 2351 self.state_size = data_structures.NoDependency([self.units, self.units]) 2352 self.output_size = self.units 2353 2354 @tf_utils.shape_type_conversion 2355 def build(self, input_shape): 2356 default_caching_device = _caching_device(self) 2357 input_dim = input_shape[-1] 2358 self.kernel = self.add_weight( 2359 shape=(input_dim, self.units * 4), 2360 name='kernel', 2361 initializer=self.kernel_initializer, 2362 regularizer=self.kernel_regularizer, 2363 constraint=self.kernel_constraint, 2364 caching_device=default_caching_device) 2365 self.recurrent_kernel = self.add_weight( 2366 shape=(self.units, self.units * 4), 2367 name='recurrent_kernel', 2368 initializer=self.recurrent_initializer, 2369 regularizer=self.recurrent_regularizer, 2370 constraint=self.recurrent_constraint, 2371 caching_device=default_caching_device) 2372 2373 if self.use_bias: 2374 if self.unit_forget_bias: 2375 2376 def bias_initializer(_, *args, **kwargs): 2377 return K.concatenate([ 2378 self.bias_initializer((self.units,), *args, **kwargs), 2379 initializers.get('ones')((self.units,), *args, **kwargs), 2380 self.bias_initializer((self.units * 2,), *args, **kwargs), 2381 ]) 2382 else: 2383 bias_initializer = self.bias_initializer 2384 self.bias = self.add_weight( 2385 shape=(self.units * 4,), 2386 name='bias', 2387 initializer=bias_initializer, 2388 regularizer=self.bias_regularizer, 2389 constraint=self.bias_constraint, 2390 caching_device=default_caching_device) 2391 else: 2392 self.bias = None 2393 self.built = True 2394 2395 def _compute_carry_and_output(self, x, h_tm1, c_tm1): 2396 """Computes carry and output using split kernels.""" 2397 x_i, x_f, x_c, x_o = x 2398 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 2399 i = self.recurrent_activation( 2400 x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])) 2401 f = self.recurrent_activation(x_f + K.dot( 2402 h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])) 2403 c = f * c_tm1 + i * self.activation(x_c + K.dot( 2404 h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) 2405 o = self.recurrent_activation( 2406 x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])) 2407 return c, o 2408 2409 def _compute_carry_and_output_fused(self, z, c_tm1): 2410 """Computes carry and output using fused kernels.""" 2411 z0, z1, z2, z3 = z 2412 i = self.recurrent_activation(z0) 2413 f = self.recurrent_activation(z1) 2414 c = f * c_tm1 + i * self.activation(z2) 2415 o = self.recurrent_activation(z3) 2416 return c, o 2417 2418 def call(self, inputs, states, training=None): 2419 h_tm1 = states[0] # previous memory state 2420 c_tm1 = states[1] # previous carry state 2421 2422 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) 2423 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 2424 h_tm1, training, count=4) 2425 2426 if self.implementation == 1: 2427 if 0 < self.dropout < 1.: 2428 inputs_i = inputs * dp_mask[0] 2429 inputs_f = inputs * dp_mask[1] 2430 inputs_c = inputs * dp_mask[2] 2431 inputs_o = inputs * dp_mask[3] 2432 else: 2433 inputs_i = inputs 2434 inputs_f = inputs 2435 inputs_c = inputs 2436 inputs_o = inputs 2437 k_i, k_f, k_c, k_o = array_ops.split( 2438 self.kernel, num_or_size_splits=4, axis=1) 2439 x_i = K.dot(inputs_i, k_i) 2440 x_f = K.dot(inputs_f, k_f) 2441 x_c = K.dot(inputs_c, k_c) 2442 x_o = K.dot(inputs_o, k_o) 2443 if self.use_bias: 2444 b_i, b_f, b_c, b_o = array_ops.split( 2445 self.bias, num_or_size_splits=4, axis=0) 2446 x_i = K.bias_add(x_i, b_i) 2447 x_f = K.bias_add(x_f, b_f) 2448 x_c = K.bias_add(x_c, b_c) 2449 x_o = K.bias_add(x_o, b_o) 2450 2451 if 0 < self.recurrent_dropout < 1.: 2452 h_tm1_i = h_tm1 * rec_dp_mask[0] 2453 h_tm1_f = h_tm1 * rec_dp_mask[1] 2454 h_tm1_c = h_tm1 * rec_dp_mask[2] 2455 h_tm1_o = h_tm1 * rec_dp_mask[3] 2456 else: 2457 h_tm1_i = h_tm1 2458 h_tm1_f = h_tm1 2459 h_tm1_c = h_tm1 2460 h_tm1_o = h_tm1 2461 x = (x_i, x_f, x_c, x_o) 2462 h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o) 2463 c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) 2464 else: 2465 if 0. < self.dropout < 1.: 2466 inputs = inputs * dp_mask[0] 2467 z = K.dot(inputs, self.kernel) 2468 z += K.dot(h_tm1, self.recurrent_kernel) 2469 if self.use_bias: 2470 z = K.bias_add(z, self.bias) 2471 2472 z = array_ops.split(z, num_or_size_splits=4, axis=1) 2473 c, o = self._compute_carry_and_output_fused(z, c_tm1) 2474 2475 h = o * self.activation(c) 2476 return h, [h, c] 2477 2478 def get_config(self): 2479 config = { 2480 'units': 2481 self.units, 2482 'activation': 2483 activations.serialize(self.activation), 2484 'recurrent_activation': 2485 activations.serialize(self.recurrent_activation), 2486 'use_bias': 2487 self.use_bias, 2488 'kernel_initializer': 2489 initializers.serialize(self.kernel_initializer), 2490 'recurrent_initializer': 2491 initializers.serialize(self.recurrent_initializer), 2492 'bias_initializer': 2493 initializers.serialize(self.bias_initializer), 2494 'unit_forget_bias': 2495 self.unit_forget_bias, 2496 'kernel_regularizer': 2497 regularizers.serialize(self.kernel_regularizer), 2498 'recurrent_regularizer': 2499 regularizers.serialize(self.recurrent_regularizer), 2500 'bias_regularizer': 2501 regularizers.serialize(self.bias_regularizer), 2502 'kernel_constraint': 2503 constraints.serialize(self.kernel_constraint), 2504 'recurrent_constraint': 2505 constraints.serialize(self.recurrent_constraint), 2506 'bias_constraint': 2507 constraints.serialize(self.bias_constraint), 2508 'dropout': 2509 self.dropout, 2510 'recurrent_dropout': 2511 self.recurrent_dropout, 2512 'implementation': 2513 self.implementation 2514 } 2515 config.update(_config_for_enable_caching_device(self)) 2516 base_config = super(LSTMCell, self).get_config() 2517 return dict(list(base_config.items()) + list(config.items())) 2518 2519 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 2520 return list(_generate_zero_filled_state_for_cell( 2521 self, inputs, batch_size, dtype)) 2522 2523 2524@keras_export('keras.experimental.PeepholeLSTMCell') 2525class PeepholeLSTMCell(LSTMCell): 2526 """Equivalent to LSTMCell class but adds peephole connections. 2527 2528 Peephole connections allow the gates to utilize the previous internal state as 2529 well as the previous hidden state (which is what LSTMCell is limited to). 2530 This allows PeepholeLSTMCell to better learn precise timings over LSTMCell. 2531 2532 From [Gers et al., 2002]( 2533 http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf): 2534 2535 "We find that LSTM augmented by 'peephole connections' from its internal 2536 cells to its multiplicative gates can learn the fine distinction between 2537 sequences of spikes spaced either 50 or 49 time steps apart without the help 2538 of any short training exemplars." 2539 2540 The peephole implementation is based on: 2541 2542 [Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf) 2543 2544 Example: 2545 2546 ```python 2547 # Create 2 PeepholeLSTMCells 2548 peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]] 2549 # Create a layer composed sequentially of the peephole LSTM cells. 2550 layer = RNN(peephole_lstm_cells) 2551 input = keras.Input((timesteps, input_dim)) 2552 output = layer(input) 2553 ``` 2554 """ 2555 2556 def __init__(self, 2557 units, 2558 activation='tanh', 2559 recurrent_activation='hard_sigmoid', 2560 use_bias=True, 2561 kernel_initializer='glorot_uniform', 2562 recurrent_initializer='orthogonal', 2563 bias_initializer='zeros', 2564 unit_forget_bias=True, 2565 kernel_regularizer=None, 2566 recurrent_regularizer=None, 2567 bias_regularizer=None, 2568 kernel_constraint=None, 2569 recurrent_constraint=None, 2570 bias_constraint=None, 2571 dropout=0., 2572 recurrent_dropout=0., 2573 **kwargs): 2574 warnings.warn('`tf.keras.experimental.PeepholeLSTMCell` is deprecated ' 2575 'and will be removed in a future version. ' 2576 'Please use tensorflow_addons.rnn.PeepholeLSTMCell ' 2577 'instead.') 2578 super(PeepholeLSTMCell, self).__init__( 2579 units=units, 2580 activation=activation, 2581 recurrent_activation=recurrent_activation, 2582 use_bias=use_bias, 2583 kernel_initializer=kernel_initializer, 2584 recurrent_initializer=recurrent_initializer, 2585 bias_initializer=bias_initializer, 2586 unit_forget_bias=unit_forget_bias, 2587 kernel_regularizer=kernel_regularizer, 2588 recurrent_regularizer=recurrent_regularizer, 2589 bias_regularizer=bias_regularizer, 2590 kernel_constraint=kernel_constraint, 2591 recurrent_constraint=recurrent_constraint, 2592 bias_constraint=bias_constraint, 2593 dropout=dropout, 2594 recurrent_dropout=recurrent_dropout, 2595 implementation=kwargs.pop('implementation', 1), 2596 **kwargs) 2597 2598 def build(self, input_shape): 2599 super(PeepholeLSTMCell, self).build(input_shape) 2600 # The following are the weight matrices for the peephole connections. These 2601 # are multiplied with the previous internal state during the computation of 2602 # carry and output. 2603 self.input_gate_peephole_weights = self.add_weight( 2604 shape=(self.units,), 2605 name='input_gate_peephole_weights', 2606 initializer=self.kernel_initializer) 2607 self.forget_gate_peephole_weights = self.add_weight( 2608 shape=(self.units,), 2609 name='forget_gate_peephole_weights', 2610 initializer=self.kernel_initializer) 2611 self.output_gate_peephole_weights = self.add_weight( 2612 shape=(self.units,), 2613 name='output_gate_peephole_weights', 2614 initializer=self.kernel_initializer) 2615 2616 def _compute_carry_and_output(self, x, h_tm1, c_tm1): 2617 x_i, x_f, x_c, x_o = x 2618 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 2619 i = self.recurrent_activation( 2620 x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) + 2621 self.input_gate_peephole_weights * c_tm1) 2622 f = self.recurrent_activation(x_f + K.dot( 2623 h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) + 2624 self.forget_gate_peephole_weights * c_tm1) 2625 c = f * c_tm1 + i * self.activation(x_c + K.dot( 2626 h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) 2627 o = self.recurrent_activation( 2628 x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) + 2629 self.output_gate_peephole_weights * c) 2630 return c, o 2631 2632 def _compute_carry_and_output_fused(self, z, c_tm1): 2633 z0, z1, z2, z3 = z 2634 i = self.recurrent_activation(z0 + 2635 self.input_gate_peephole_weights * c_tm1) 2636 f = self.recurrent_activation(z1 + 2637 self.forget_gate_peephole_weights * c_tm1) 2638 c = f * c_tm1 + i * self.activation(z2) 2639 o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c) 2640 return c, o 2641 2642 2643@keras_export(v1=['keras.layers.LSTM']) 2644class LSTM(RNN): 2645 """Long Short-Term Memory layer - Hochreiter 1997. 2646 2647 Note that this cell is not optimized for performance on GPU. Please use 2648 `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU. 2649 2650 Args: 2651 units: Positive integer, dimensionality of the output space. 2652 activation: Activation function to use. 2653 Default: hyperbolic tangent (`tanh`). 2654 If you pass `None`, no activation is applied 2655 (ie. "linear" activation: `a(x) = x`). 2656 recurrent_activation: Activation function to use 2657 for the recurrent step. 2658 Default: hard sigmoid (`hard_sigmoid`). 2659 If you pass `None`, no activation is applied 2660 (ie. "linear" activation: `a(x) = x`). 2661 use_bias: Boolean, whether the layer uses a bias vector. 2662 kernel_initializer: Initializer for the `kernel` weights matrix, 2663 used for the linear transformation of the inputs.. 2664 recurrent_initializer: Initializer for the `recurrent_kernel` 2665 weights matrix, 2666 used for the linear transformation of the recurrent state. 2667 bias_initializer: Initializer for the bias vector. 2668 unit_forget_bias: Boolean. 2669 If True, add 1 to the bias of the forget gate at initialization. 2670 Setting it to true will also force `bias_initializer="zeros"`. 2671 This is recommended in [Jozefowicz et al., 2015]( 2672 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf). 2673 kernel_regularizer: Regularizer function applied to 2674 the `kernel` weights matrix. 2675 recurrent_regularizer: Regularizer function applied to 2676 the `recurrent_kernel` weights matrix. 2677 bias_regularizer: Regularizer function applied to the bias vector. 2678 activity_regularizer: Regularizer function applied to 2679 the output of the layer (its "activation"). 2680 kernel_constraint: Constraint function applied to 2681 the `kernel` weights matrix. 2682 recurrent_constraint: Constraint function applied to 2683 the `recurrent_kernel` weights matrix. 2684 bias_constraint: Constraint function applied to the bias vector. 2685 dropout: Float between 0 and 1. 2686 Fraction of the units to drop for 2687 the linear transformation of the inputs. 2688 recurrent_dropout: Float between 0 and 1. 2689 Fraction of the units to drop for 2690 the linear transformation of the recurrent state. 2691 return_sequences: Boolean. Whether to return the last output. 2692 in the output sequence, or the full sequence. 2693 return_state: Boolean. Whether to return the last state 2694 in addition to the output. 2695 go_backwards: Boolean (default False). 2696 If True, process the input sequence backwards and return the 2697 reversed sequence. 2698 stateful: Boolean (default False). If True, the last state 2699 for each sample at index i in a batch will be used as initial 2700 state for the sample of index i in the following batch. 2701 unroll: Boolean (default False). 2702 If True, the network will be unrolled, 2703 else a symbolic loop will be used. 2704 Unrolling can speed-up a RNN, 2705 although it tends to be more memory-intensive. 2706 Unrolling is only suitable for short sequences. 2707 time_major: The shape format of the `inputs` and `outputs` tensors. 2708 If True, the inputs and outputs will be in shape 2709 `(timesteps, batch, ...)`, whereas in the False case, it will be 2710 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 2711 efficient because it avoids transposes at the beginning and end of the 2712 RNN calculation. However, most TensorFlow data is batch-major, so by 2713 default this function accepts input and emits output in batch-major 2714 form. 2715 2716 Call arguments: 2717 inputs: A 3D tensor. 2718 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 2719 a given timestep should be masked. An individual `True` entry indicates 2720 that the corresponding timestep should be utilized, while a `False` 2721 entry indicates that the corresponding timestep should be ignored. 2722 training: Python boolean indicating whether the layer should behave in 2723 training mode or in inference mode. This argument is passed to the cell 2724 when calling it. This is only relevant if `dropout` or 2725 `recurrent_dropout` is used. 2726 initial_state: List of initial state tensors to be passed to the first 2727 call of the cell. 2728 """ 2729 2730 def __init__(self, 2731 units, 2732 activation='tanh', 2733 recurrent_activation='hard_sigmoid', 2734 use_bias=True, 2735 kernel_initializer='glorot_uniform', 2736 recurrent_initializer='orthogonal', 2737 bias_initializer='zeros', 2738 unit_forget_bias=True, 2739 kernel_regularizer=None, 2740 recurrent_regularizer=None, 2741 bias_regularizer=None, 2742 activity_regularizer=None, 2743 kernel_constraint=None, 2744 recurrent_constraint=None, 2745 bias_constraint=None, 2746 dropout=0., 2747 recurrent_dropout=0., 2748 return_sequences=False, 2749 return_state=False, 2750 go_backwards=False, 2751 stateful=False, 2752 unroll=False, 2753 **kwargs): 2754 implementation = kwargs.pop('implementation', 1) 2755 if implementation == 0: 2756 logging.warning('`implementation=0` has been deprecated, ' 2757 'and now defaults to `implementation=1`.' 2758 'Please update your layer call.') 2759 if 'enable_caching_device' in kwargs: 2760 cell_kwargs = {'enable_caching_device': 2761 kwargs.pop('enable_caching_device')} 2762 else: 2763 cell_kwargs = {} 2764 cell = LSTMCell( 2765 units, 2766 activation=activation, 2767 recurrent_activation=recurrent_activation, 2768 use_bias=use_bias, 2769 kernel_initializer=kernel_initializer, 2770 recurrent_initializer=recurrent_initializer, 2771 unit_forget_bias=unit_forget_bias, 2772 bias_initializer=bias_initializer, 2773 kernel_regularizer=kernel_regularizer, 2774 recurrent_regularizer=recurrent_regularizer, 2775 bias_regularizer=bias_regularizer, 2776 kernel_constraint=kernel_constraint, 2777 recurrent_constraint=recurrent_constraint, 2778 bias_constraint=bias_constraint, 2779 dropout=dropout, 2780 recurrent_dropout=recurrent_dropout, 2781 implementation=implementation, 2782 dtype=kwargs.get('dtype'), 2783 trainable=kwargs.get('trainable', True), 2784 **cell_kwargs) 2785 super(LSTM, self).__init__( 2786 cell, 2787 return_sequences=return_sequences, 2788 return_state=return_state, 2789 go_backwards=go_backwards, 2790 stateful=stateful, 2791 unroll=unroll, 2792 **kwargs) 2793 self.activity_regularizer = regularizers.get(activity_regularizer) 2794 self.input_spec = [InputSpec(ndim=3)] 2795 2796 def call(self, inputs, mask=None, training=None, initial_state=None): 2797 return super(LSTM, self).call( 2798 inputs, mask=mask, training=training, initial_state=initial_state) 2799 2800 @property 2801 def units(self): 2802 return self.cell.units 2803 2804 @property 2805 def activation(self): 2806 return self.cell.activation 2807 2808 @property 2809 def recurrent_activation(self): 2810 return self.cell.recurrent_activation 2811 2812 @property 2813 def use_bias(self): 2814 return self.cell.use_bias 2815 2816 @property 2817 def kernel_initializer(self): 2818 return self.cell.kernel_initializer 2819 2820 @property 2821 def recurrent_initializer(self): 2822 return self.cell.recurrent_initializer 2823 2824 @property 2825 def bias_initializer(self): 2826 return self.cell.bias_initializer 2827 2828 @property 2829 def unit_forget_bias(self): 2830 return self.cell.unit_forget_bias 2831 2832 @property 2833 def kernel_regularizer(self): 2834 return self.cell.kernel_regularizer 2835 2836 @property 2837 def recurrent_regularizer(self): 2838 return self.cell.recurrent_regularizer 2839 2840 @property 2841 def bias_regularizer(self): 2842 return self.cell.bias_regularizer 2843 2844 @property 2845 def kernel_constraint(self): 2846 return self.cell.kernel_constraint 2847 2848 @property 2849 def recurrent_constraint(self): 2850 return self.cell.recurrent_constraint 2851 2852 @property 2853 def bias_constraint(self): 2854 return self.cell.bias_constraint 2855 2856 @property 2857 def dropout(self): 2858 return self.cell.dropout 2859 2860 @property 2861 def recurrent_dropout(self): 2862 return self.cell.recurrent_dropout 2863 2864 @property 2865 def implementation(self): 2866 return self.cell.implementation 2867 2868 def get_config(self): 2869 config = { 2870 'units': 2871 self.units, 2872 'activation': 2873 activations.serialize(self.activation), 2874 'recurrent_activation': 2875 activations.serialize(self.recurrent_activation), 2876 'use_bias': 2877 self.use_bias, 2878 'kernel_initializer': 2879 initializers.serialize(self.kernel_initializer), 2880 'recurrent_initializer': 2881 initializers.serialize(self.recurrent_initializer), 2882 'bias_initializer': 2883 initializers.serialize(self.bias_initializer), 2884 'unit_forget_bias': 2885 self.unit_forget_bias, 2886 'kernel_regularizer': 2887 regularizers.serialize(self.kernel_regularizer), 2888 'recurrent_regularizer': 2889 regularizers.serialize(self.recurrent_regularizer), 2890 'bias_regularizer': 2891 regularizers.serialize(self.bias_regularizer), 2892 'activity_regularizer': 2893 regularizers.serialize(self.activity_regularizer), 2894 'kernel_constraint': 2895 constraints.serialize(self.kernel_constraint), 2896 'recurrent_constraint': 2897 constraints.serialize(self.recurrent_constraint), 2898 'bias_constraint': 2899 constraints.serialize(self.bias_constraint), 2900 'dropout': 2901 self.dropout, 2902 'recurrent_dropout': 2903 self.recurrent_dropout, 2904 'implementation': 2905 self.implementation 2906 } 2907 config.update(_config_for_enable_caching_device(self.cell)) 2908 base_config = super(LSTM, self).get_config() 2909 del base_config['cell'] 2910 return dict(list(base_config.items()) + list(config.items())) 2911 2912 @classmethod 2913 def from_config(cls, config): 2914 if 'implementation' in config and config['implementation'] == 0: 2915 config['implementation'] = 1 2916 return cls(**config) 2917 2918 2919def _generate_dropout_mask(ones, rate, training=None, count=1): 2920 def dropped_inputs(): 2921 return K.dropout(ones, rate) 2922 2923 if count > 1: 2924 return [ 2925 K.in_train_phase(dropped_inputs, ones, training=training) 2926 for _ in range(count) 2927 ] 2928 return K.in_train_phase(dropped_inputs, ones, training=training) 2929 2930 2931def _standardize_args(inputs, initial_state, constants, num_constants): 2932 """Standardizes `__call__` to a single list of tensor inputs. 2933 2934 When running a model loaded from a file, the input tensors 2935 `initial_state` and `constants` can be passed to `RNN.__call__()` as part 2936 of `inputs` instead of by the dedicated keyword arguments. This method 2937 makes sure the arguments are separated and that `initial_state` and 2938 `constants` are lists of tensors (or None). 2939 2940 Args: 2941 inputs: Tensor or list/tuple of tensors. which may include constants 2942 and initial states. In that case `num_constant` must be specified. 2943 initial_state: Tensor or list of tensors or None, initial states. 2944 constants: Tensor or list of tensors or None, constant tensors. 2945 num_constants: Expected number of constants (if constants are passed as 2946 part of the `inputs` list. 2947 2948 Returns: 2949 inputs: Single tensor or tuple of tensors. 2950 initial_state: List of tensors or None. 2951 constants: List of tensors or None. 2952 """ 2953 if isinstance(inputs, list): 2954 # There are several situations here: 2955 # In the graph mode, __call__ will be only called once. The initial_state 2956 # and constants could be in inputs (from file loading). 2957 # In the eager mode, __call__ will be called twice, once during 2958 # rnn_layer(inputs=input_t, constants=c_t, ...), and second time will be 2959 # model.fit/train_on_batch/predict with real np data. In the second case, 2960 # the inputs will contain initial_state and constants as eager tensor. 2961 # 2962 # For either case, the real input is the first item in the list, which 2963 # could be a nested structure itself. Then followed by initial_states, which 2964 # could be a list of items, or list of list if the initial_state is complex 2965 # structure, and finally followed by constants which is a flat list. 2966 assert initial_state is None and constants is None 2967 if num_constants: 2968 constants = inputs[-num_constants:] 2969 inputs = inputs[:-num_constants] 2970 if len(inputs) > 1: 2971 initial_state = inputs[1:] 2972 inputs = inputs[:1] 2973 2974 if len(inputs) > 1: 2975 inputs = tuple(inputs) 2976 else: 2977 inputs = inputs[0] 2978 2979 def to_list_or_none(x): 2980 if x is None or isinstance(x, list): 2981 return x 2982 if isinstance(x, tuple): 2983 return list(x) 2984 return [x] 2985 2986 initial_state = to_list_or_none(initial_state) 2987 constants = to_list_or_none(constants) 2988 2989 return inputs, initial_state, constants 2990 2991 2992def _is_multiple_state(state_size): 2993 """Check whether the state_size contains multiple states.""" 2994 return (hasattr(state_size, '__len__') and 2995 not isinstance(state_size, tensor_shape.TensorShape)) 2996 2997 2998def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype): 2999 if inputs is not None: 3000 batch_size = array_ops.shape(inputs)[0] 3001 dtype = inputs.dtype 3002 return _generate_zero_filled_state(batch_size, cell.state_size, dtype) 3003 3004 3005def _generate_zero_filled_state(batch_size_tensor, state_size, dtype): 3006 """Generate a zero filled tensor with shape [batch_size, state_size].""" 3007 if batch_size_tensor is None or dtype is None: 3008 raise ValueError( 3009 'batch_size and dtype cannot be None while constructing initial state: ' 3010 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype)) 3011 3012 def create_zeros(unnested_state_size): 3013 flat_dims = tensor_shape.TensorShape(unnested_state_size).as_list() 3014 init_state_size = [batch_size_tensor] + flat_dims 3015 return array_ops.zeros(init_state_size, dtype=dtype) 3016 3017 if nest.is_nested(state_size): 3018 return nest.map_structure(create_zeros, state_size) 3019 else: 3020 return create_zeros(state_size) 3021 3022 3023def _caching_device(rnn_cell): 3024 """Returns the caching device for the RNN variable. 3025 3026 This is useful for distributed training, when variable is not located as same 3027 device as the training worker. By enabling the device cache, this allows 3028 worker to read the variable once and cache locally, rather than read it every 3029 time step from remote when it is needed. 3030 3031 Note that this is assuming the variable that cell needs for each time step is 3032 having the same value in the forward path, and only gets updated in the 3033 backprop. It is true for all the default cells (SimpleRNN, GRU, LSTM). If the 3034 cell body relies on any variable that gets updated every time step, then 3035 caching device will cause it to read the stall value. 3036 3037 Args: 3038 rnn_cell: the rnn cell instance. 3039 """ 3040 if context.executing_eagerly(): 3041 # caching_device is not supported in eager mode. 3042 return None 3043 if not getattr(rnn_cell, '_enable_caching_device', False): 3044 return None 3045 # Don't set a caching device when running in a loop, since it is possible that 3046 # train steps could be wrapped in a tf.while_loop. In that scenario caching 3047 # prevents forward computations in loop iterations from re-reading the 3048 # updated weights. 3049 if control_flow_util.IsInWhileLoop(ops.get_default_graph()): 3050 logging.warn('Variable read device caching has been disabled because the ' 3051 'RNN is in tf.while_loop loop context, which will cause ' 3052 'reading stalled value in forward path. This could slow down ' 3053 'the training due to duplicated variable reads. Please ' 3054 'consider updating your code to remove tf.while_loop if ' 3055 'possible.') 3056 return None 3057 if (rnn_cell._dtype_policy.compute_dtype != 3058 rnn_cell._dtype_policy.variable_dtype): 3059 logging.warn('Variable read device caching has been disabled since it ' 3060 'doesn\'t work with the mixed precision API. This is ' 3061 'likely to cause a slowdown for RNN training due to ' 3062 'duplicated read of variable for each timestep, which ' 3063 'will be significant in a multi remote worker setting. ' 3064 'Please consider disabling mixed precision API if ' 3065 'the performance has been affected.') 3066 return None 3067 # Cache the value on the device that access the variable. 3068 return lambda op: op.device 3069 3070 3071def _config_for_enable_caching_device(rnn_cell): 3072 """Return the dict config for RNN cell wrt to enable_caching_device field. 3073 3074 Since enable_caching_device is a internal implementation detail for speed up 3075 the RNN variable read when running on the multi remote worker setting, we 3076 don't want this config to be serialized constantly in the JSON. We will only 3077 serialize this field when a none default value is used to create the cell. 3078 Args: 3079 rnn_cell: the RNN cell for serialize. 3080 3081 Returns: 3082 A dict which contains the JSON config for enable_caching_device value or 3083 empty dict if the enable_caching_device value is same as the default value. 3084 """ 3085 default_enable_caching_device = ops.executing_eagerly_outside_functions() 3086 if rnn_cell._enable_caching_device != default_enable_caching_device: 3087 return {'enable_caching_device': rnn_cell._enable_caching_device} 3088 return {} 3089