1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Module implementing RNN Cells. 16 17This module provides a number of basic commonly used RNN cells, such as LSTM 18(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of 19operators that allow adding dropouts, projections, or embeddings for inputs. 20Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by 21calling the `rnn` ops several times. 22""" 23from __future__ import absolute_import 24from __future__ import division 25from __future__ import print_function 26 27import collections 28import hashlib 29import numbers 30 31from tensorflow.python.eager import context 32from tensorflow.python.framework import constant_op 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import tensor_shape 36from tensorflow.python.framework import tensor_util 37from tensorflow.python.keras import activations 38from tensorflow.python.keras import initializers 39from tensorflow.python.keras import layers as keras_layer 40from tensorflow.python.keras.engine import input_spec 41from tensorflow.python.keras.utils import tf_utils 42from tensorflow.python.layers import base as base_layer 43from tensorflow.python.ops import array_ops 44from tensorflow.python.ops import clip_ops 45from tensorflow.python.ops import init_ops 46from tensorflow.python.ops import math_ops 47from tensorflow.python.ops import nn_ops 48from tensorflow.python.ops import partitioned_variables 49from tensorflow.python.ops import random_ops 50from tensorflow.python.ops import tensor_array_ops 51from tensorflow.python.ops import variable_scope as vs 52from tensorflow.python.ops import variables as tf_variables 53from tensorflow.python.platform import tf_logging as logging 54from tensorflow.python.training.tracking import base as trackable 55from tensorflow.python.util import nest 56from tensorflow.python.util.deprecation import deprecated 57from tensorflow.python.util.tf_export import tf_export 58 59 60_BIAS_VARIABLE_NAME = "bias" 61_WEIGHTS_VARIABLE_NAME = "kernel" 62 63# This can be used with self.assertRaisesRegexp for assert_like_rnncell. 64ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell" 65 66 67def _hasattr(obj, attr_name): 68 try: 69 getattr(obj, attr_name) 70 except AttributeError: 71 return False 72 else: 73 return True 74 75 76def assert_like_rnncell(cell_name, cell): 77 """Raises a TypeError if cell is not like an RNNCell. 78 79 NOTE: Do not rely on the error message (in particular in tests) which can be 80 subject to change to increase readability. Use 81 ASSERT_LIKE_RNNCELL_ERROR_REGEXP. 82 83 Args: 84 cell_name: A string to give a meaningful error referencing to the name 85 of the functionargument. 86 cell: The object which should behave like an RNNCell. 87 88 Raises: 89 TypeError: A human-friendly exception. 
90 """ 91 conditions = [ 92 _hasattr(cell, "output_size"), 93 _hasattr(cell, "state_size"), 94 _hasattr(cell, "get_initial_state") or _hasattr(cell, "zero_state"), 95 callable(cell), 96 ] 97 errors = [ 98 "'output_size' property is missing", 99 "'state_size' property is missing", 100 "either 'zero_state' or 'get_initial_state' method is required", 101 "is not callable" 102 ] 103 104 if not all(conditions): 105 106 errors = [error for error, cond in zip(errors, conditions) if not cond] 107 raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format( 108 cell_name, cell, ", ".join(errors))) 109 110 111def _concat(prefix, suffix, static=False): 112 """Concat that enables int, Tensor, or TensorShape values. 113 114 This function takes a size specification, which can be an integer, a 115 TensorShape, or a Tensor, and converts it into a concatenated Tensor 116 (if static = False) or a list of integers (if static = True). 117 118 Args: 119 prefix: The prefix; usually the batch size (and/or time step size). 120 (TensorShape, int, or Tensor.) 121 suffix: TensorShape, int, or Tensor. 122 static: If `True`, return a python list with possibly unknown dimensions. 123 Otherwise return a `Tensor`. 124 125 Returns: 126 shape: the concatenation of prefix and suffix. 127 128 Raises: 129 ValueError: if `suffix` is not a scalar or vector (or TensorShape). 130 ValueError: if prefix or suffix was `None` and asked for dynamic 131 Tensors out. 132 """ 133 if isinstance(prefix, ops.Tensor): 134 p = prefix 135 p_static = tensor_util.constant_value(prefix) 136 if p.shape.ndims == 0: 137 p = array_ops.expand_dims(p, 0) 138 elif p.shape.ndims != 1: 139 raise ValueError("prefix tensor must be either a scalar or vector, " 140 "but saw tensor: %s" % p) 141 else: 142 p = tensor_shape.as_shape(prefix) 143 p_static = p.as_list() if p.ndims is not None else None 144 p = (constant_op.constant(p.as_list(), dtype=dtypes.int32) 145 if p.is_fully_defined() else None) 146 if isinstance(suffix, ops.Tensor): 147 s = suffix 148 s_static = tensor_util.constant_value(suffix) 149 if s.shape.ndims == 0: 150 s = array_ops.expand_dims(s, 0) 151 elif s.shape.ndims != 1: 152 raise ValueError("suffix tensor must be either a scalar or vector, " 153 "but saw tensor: %s" % s) 154 else: 155 s = tensor_shape.as_shape(suffix) 156 s_static = s.as_list() if s.ndims is not None else None 157 s = (constant_op.constant(s.as_list(), dtype=dtypes.int32) 158 if s.is_fully_defined() else None) 159 160 if static: 161 shape = tensor_shape.as_shape(p_static).concatenate(s_static) 162 shape = shape.as_list() if shape.ndims is not None else None 163 else: 164 if p is None or s is None: 165 raise ValueError("Provided a prefix or suffix of None: %s and %s" 166 % (prefix, suffix)) 167 shape = array_ops.concat((p, s), 0) 168 return shape 169 170 171def _zero_state_tensors(state_size, batch_size, dtype): 172 """Create tensors of zeros based on state_size, batch_size, and dtype.""" 173 def get_state_shape(s): 174 """Combine s with batch_size to get a proper tensor shape.""" 175 c = _concat(batch_size, s) 176 size = array_ops.zeros(c, dtype=dtype) 177 if not context.executing_eagerly(): 178 c_static = _concat(batch_size, s, static=True) 179 size.set_shape(c_static) 180 return size 181 return nest.map_structure(get_state_shape, state_size) 182 183 184@tf_export(v1=["nn.rnn_cell.RNNCell"]) 185class RNNCell(base_layer.Layer): 186 """Abstract object representing an RNN cell. 
187 188 Every `RNNCell` must have the properties below and implement `call` with 189 the signature `(output, next_state) = call(input, state)`. The optional 190 third input argument, `scope`, is allowed for backwards compatibility 191 purposes; but should be left off for new subclasses. 192 193 This definition of cell differs from the definition used in the literature. 194 In the literature, 'cell' refers to an object with a single scalar output. 195 This definition refers to a horizontal array of such units. 196 197 An RNN cell, in the most abstract setting, is anything that has 198 a state and performs some operation that takes a matrix of inputs. 199 This operation results in an output matrix with `self.output_size` columns. 200 If `self.state_size` is an integer, this operation also results in a new 201 state matrix with `self.state_size` columns. If `self.state_size` is a 202 (possibly nested tuple of) TensorShape object(s), then it should return a 203 matching structure of Tensors having shape `[batch_size].concatenate(s)` 204 for each `s` in `self.batch_size`. 205 """ 206 207 def __init__(self, trainable=True, name=None, dtype=None, **kwargs): 208 super(RNNCell, self).__init__( 209 trainable=trainable, name=name, dtype=dtype, **kwargs) 210 # Attribute that indicates whether the cell is a TF RNN cell, due the slight 211 # difference between TF and Keras RNN cell. Notably the state is not wrapped 212 # in a list for TF cell where they are single tensor state, whereas keras 213 # cell will wrap the state into a list, and call() will have to unwrap them. 214 self._is_tf_rnn_cell = True 215 216 def __call__(self, inputs, state, scope=None): 217 """Run this RNN cell on inputs, starting from the given state. 218 219 Args: 220 inputs: `2-D` tensor with shape `[batch_size, input_size]`. 221 state: if `self.state_size` is an integer, this should be a `2-D Tensor` 222 with shape `[batch_size, self.state_size]`. Otherwise, if 223 `self.state_size` is a tuple of integers, this should be a tuple 224 with shapes `[batch_size, s] for s in self.state_size`. 225 scope: VariableScope for the created subgraph; defaults to class name. 226 227 Returns: 228 A pair containing: 229 230 - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`. 231 - New state: Either a single `2-D` tensor, or a tuple of tensors matching 232 the arity and shapes of `state`. 
233 """ 234 if scope is not None: 235 with vs.variable_scope(scope, 236 custom_getter=self._rnn_get_variable) as scope: 237 return super(RNNCell, self).__call__(inputs, state, scope=scope) 238 else: 239 scope_attrname = "rnncell_scope" 240 scope = getattr(self, scope_attrname, None) 241 if scope is None: 242 scope = vs.variable_scope(vs.get_variable_scope(), 243 custom_getter=self._rnn_get_variable) 244 setattr(self, scope_attrname, scope) 245 with scope: 246 return super(RNNCell, self).__call__(inputs, state) 247 248 def _rnn_get_variable(self, getter, *args, **kwargs): 249 variable = getter(*args, **kwargs) 250 if context.executing_eagerly(): 251 trainable = variable._trainable # pylint: disable=protected-access 252 else: 253 trainable = ( 254 variable in tf_variables.trainable_variables() or 255 (isinstance(variable, tf_variables.PartitionedVariable) and 256 list(variable)[0] in tf_variables.trainable_variables())) 257 if trainable and variable not in self._trainable_weights: 258 self._trainable_weights.append(variable) 259 elif not trainable and variable not in self._non_trainable_weights: 260 self._non_trainable_weights.append(variable) 261 return variable 262 263 @property 264 def state_size(self): 265 """size(s) of state(s) used by this cell. 266 267 It can be represented by an Integer, a TensorShape or a tuple of Integers 268 or TensorShapes. 269 """ 270 raise NotImplementedError("Abstract method") 271 272 @property 273 def output_size(self): 274 """Integer or TensorShape: size of outputs produced by this cell.""" 275 raise NotImplementedError("Abstract method") 276 277 def build(self, _): 278 # This tells the parent Layer object that it's OK to call 279 # self.add_variable() inside the call() method. 280 pass 281 282 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 283 if inputs is not None: 284 # Validate the given batch_size and dtype against inputs if provided. 285 inputs = ops.convert_to_tensor(inputs, name="inputs") 286 if batch_size is not None: 287 if tensor_util.is_tensor(batch_size): 288 static_batch_size = tensor_util.constant_value( 289 batch_size, partial=True) 290 else: 291 static_batch_size = batch_size 292 if inputs.shape.dims[0].value != static_batch_size: 293 raise ValueError( 294 "batch size from input tensor is different from the " 295 "input param. Input tensor batch: {}, batch_size: {}".format( 296 inputs.shape.dims[0].value, batch_size)) 297 298 if dtype is not None and inputs.dtype != dtype: 299 raise ValueError( 300 "dtype from input tensor is different from the " 301 "input param. Input tensor dtype: {}, dtype: {}".format( 302 inputs.dtype, dtype)) 303 304 batch_size = inputs.shape.dims[0].value or array_ops.shape(inputs)[0] 305 dtype = inputs.dtype 306 if None in [batch_size, dtype]: 307 raise ValueError( 308 "batch_size and dtype cannot be None while constructing initial " 309 "state: batch_size={}, dtype={}".format(batch_size, dtype)) 310 return self.zero_state(batch_size, dtype) 311 312 def zero_state(self, batch_size, dtype): 313 """Return zero-filled state tensor(s). 314 315 Args: 316 batch_size: int, float, or unit Tensor representing the batch size. 317 dtype: the data type to use for the state. 318 319 Returns: 320 If `state_size` is an int or TensorShape, then the return value is a 321 `N-D` tensor of shape `[batch_size, state_size]` filled with zeros. 
322 323 If `state_size` is a nested list or tuple, then the return value is 324 a nested list or tuple (of the same structure) of `2-D` tensors with 325 the shapes `[batch_size, s]` for each s in `state_size`. 326 """ 327 # Try to use the last cached zero_state. This is done to avoid recreating 328 # zeros, especially when eager execution is enabled. 329 state_size = self.state_size 330 is_eager = context.executing_eagerly() 331 if is_eager and _hasattr(self, "_last_zero_state"): 332 (last_state_size, last_batch_size, last_dtype, 333 last_output) = getattr(self, "_last_zero_state") 334 if (last_batch_size == batch_size and 335 last_dtype == dtype and 336 last_state_size == state_size): 337 return last_output 338 with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 339 output = _zero_state_tensors(state_size, batch_size, dtype) 340 if is_eager: 341 self._last_zero_state = (state_size, batch_size, dtype, output) 342 return output 343 344 345class LayerRNNCell(RNNCell): 346 """Subclass of RNNCells that act like proper `tf.Layer` objects. 347 348 For backwards compatibility purposes, most `RNNCell` instances allow their 349 `call` methods to instantiate variables via `tf.get_variable`. The underlying 350 variable scope thus keeps track of any variables, and returning cached 351 versions. This is atypical of `tf.layer` objects, which separate this 352 part of layer building into a `build` method that is only called once. 353 354 Here we provide a subclass for `RNNCell` objects that act exactly as 355 `Layer` objects do. They must provide a `build` method and their 356 `call` methods do not access Variables `tf.get_variable`. 357 """ 358 359 def __call__(self, inputs, state, scope=None, *args, **kwargs): 360 """Run this RNN cell on inputs, starting from the given state. 361 362 Args: 363 inputs: `2-D` tensor with shape `[batch_size, input_size]`. 364 state: if `self.state_size` is an integer, this should be a `2-D Tensor` 365 with shape `[batch_size, self.state_size]`. Otherwise, if 366 `self.state_size` is a tuple of integers, this should be a tuple 367 with shapes `[batch_size, s] for s in self.state_size`. 368 scope: optional cell scope. 369 *args: Additional positional arguments. 370 **kwargs: Additional keyword arguments. 371 372 Returns: 373 A pair containing: 374 375 - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`. 376 - New state: Either a single `2-D` tensor, or a tuple of tensors matching 377 the arity and shapes of `state`. 378 """ 379 # Bypass RNNCell's variable capturing semantics for LayerRNNCell. 380 # Instead, it is up to subclasses to provide a proper build 381 # method. See the class docstring for more details. 382 return base_layer.Layer.__call__(self, inputs, state, scope=scope, 383 *args, **kwargs) 384 385 386@tf_export(v1=["nn.rnn_cell.BasicRNNCell"]) 387class BasicRNNCell(LayerRNNCell): 388 """The most basic RNN cell. 389 390 Note that this cell is not optimized for performance. Please use 391 `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU. 392 393 Args: 394 num_units: int, The number of units in the RNN cell. 395 activation: Nonlinearity to use. Default: `tanh`. It could also be string 396 that is within Keras activation function names. 397 reuse: (optional) Python boolean describing whether to reuse variables 398 in an existing scope. If not `True`, and the existing scope already has 399 the given variables, an error is raised. 400 name: String, the name of the layer. 
Layers with the same name will 401 share weights, but to avoid mistakes we require reuse=True in such 402 cases. 403 dtype: Default dtype of the layer (default of `None` means use the type 404 of the first input). Required when `build` is called before `call`. 405 **kwargs: Dict, keyword named properties for common layer attributes, like 406 `trainable` etc when constructing the cell from configs of get_config(). 407 """ 408 409 @deprecated(None, "This class is equivalent as tf.keras.layers.SimpleRNNCell," 410 " and will be replaced by that in Tensorflow 2.0.") 411 def __init__(self, 412 num_units, 413 activation=None, 414 reuse=None, 415 name=None, 416 dtype=None, 417 **kwargs): 418 super(BasicRNNCell, self).__init__( 419 _reuse=reuse, name=name, dtype=dtype, **kwargs) 420 _check_supported_dtypes(self.dtype) 421 if context.executing_eagerly() and context.num_gpus() > 0: 422 logging.warn("%s: Note that this cell is not optimized for performance. " 423 "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better " 424 "performance on GPU.", self) 425 426 # Inputs must be 2-dimensional. 427 self.input_spec = input_spec.InputSpec(ndim=2) 428 429 self._num_units = num_units 430 if activation: 431 self._activation = activations.get(activation) 432 else: 433 self._activation = math_ops.tanh 434 435 @property 436 def state_size(self): 437 return self._num_units 438 439 @property 440 def output_size(self): 441 return self._num_units 442 443 @tf_utils.shape_type_conversion 444 def build(self, inputs_shape): 445 if inputs_shape[-1] is None: 446 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" 447 % str(inputs_shape)) 448 _check_supported_dtypes(self.dtype) 449 450 input_depth = inputs_shape[-1] 451 self._kernel = self.add_variable( 452 _WEIGHTS_VARIABLE_NAME, 453 shape=[input_depth + self._num_units, self._num_units]) 454 self._bias = self.add_variable( 455 _BIAS_VARIABLE_NAME, 456 shape=[self._num_units], 457 initializer=init_ops.zeros_initializer(dtype=self.dtype)) 458 459 self.built = True 460 461 def call(self, inputs, state): 462 """Most basic RNN: output = new_state = act(W * input + U * state + B).""" 463 _check_rnn_cell_input_dtypes([inputs, state]) 464 gate_inputs = math_ops.matmul( 465 array_ops.concat([inputs, state], 1), self._kernel) 466 gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) 467 output = self._activation(gate_inputs) 468 return output, output 469 470 def get_config(self): 471 config = { 472 "num_units": self._num_units, 473 "activation": activations.serialize(self._activation), 474 "reuse": self._reuse, 475 } 476 base_config = super(BasicRNNCell, self).get_config() 477 return dict(list(base_config.items()) + list(config.items())) 478 479 480@tf_export(v1=["nn.rnn_cell.GRUCell"]) 481class GRUCell(LayerRNNCell): 482 """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078). 483 484 Note that this cell is not optimized for performance. Please use 485 `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or 486 `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU. 487 488 Args: 489 num_units: int, The number of units in the GRU cell. 490 activation: Nonlinearity to use. Default: `tanh`. 491 reuse: (optional) Python boolean describing whether to reuse variables 492 in an existing scope. If not `True`, and the existing scope already has 493 the given variables, an error is raised. 494 kernel_initializer: (optional) The initializer to use for the weight and 495 projection matrices. 
496 bias_initializer: (optional) The initializer to use for the bias. 497 name: String, the name of the layer. Layers with the same name will 498 share weights, but to avoid mistakes we require reuse=True in such 499 cases. 500 dtype: Default dtype of the layer (default of `None` means use the type 501 of the first input). Required when `build` is called before `call`. 502 **kwargs: Dict, keyword named properties for common layer attributes, like 503 `trainable` etc when constructing the cell from configs of get_config(). 504 """ 505 506 @deprecated(None, "This class is equivalent as tf.keras.layers.GRUCell," 507 " and will be replaced by that in Tensorflow 2.0.") 508 def __init__(self, 509 num_units, 510 activation=None, 511 reuse=None, 512 kernel_initializer=None, 513 bias_initializer=None, 514 name=None, 515 dtype=None, 516 **kwargs): 517 super(GRUCell, self).__init__( 518 _reuse=reuse, name=name, dtype=dtype, **kwargs) 519 _check_supported_dtypes(self.dtype) 520 521 if context.executing_eagerly() and context.num_gpus() > 0: 522 logging.warn("%s: Note that this cell is not optimized for performance. " 523 "Please use tf.contrib.cudnn_rnn.CudnnGRU for better " 524 "performance on GPU.", self) 525 # Inputs must be 2-dimensional. 526 self.input_spec = input_spec.InputSpec(ndim=2) 527 528 self._num_units = num_units 529 if activation: 530 self._activation = activations.get(activation) 531 else: 532 self._activation = math_ops.tanh 533 self._kernel_initializer = initializers.get(kernel_initializer) 534 self._bias_initializer = initializers.get(bias_initializer) 535 536 @property 537 def state_size(self): 538 return self._num_units 539 540 @property 541 def output_size(self): 542 return self._num_units 543 544 @tf_utils.shape_type_conversion 545 def build(self, inputs_shape): 546 if inputs_shape[-1] is None: 547 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" 548 % str(inputs_shape)) 549 _check_supported_dtypes(self.dtype) 550 input_depth = inputs_shape[-1] 551 self._gate_kernel = self.add_variable( 552 "gates/%s" % _WEIGHTS_VARIABLE_NAME, 553 shape=[input_depth + self._num_units, 2 * self._num_units], 554 initializer=self._kernel_initializer) 555 self._gate_bias = self.add_variable( 556 "gates/%s" % _BIAS_VARIABLE_NAME, 557 shape=[2 * self._num_units], 558 initializer=( 559 self._bias_initializer 560 if self._bias_initializer is not None 561 else init_ops.constant_initializer(1.0, dtype=self.dtype))) 562 self._candidate_kernel = self.add_variable( 563 "candidate/%s" % _WEIGHTS_VARIABLE_NAME, 564 shape=[input_depth + self._num_units, self._num_units], 565 initializer=self._kernel_initializer) 566 self._candidate_bias = self.add_variable( 567 "candidate/%s" % _BIAS_VARIABLE_NAME, 568 shape=[self._num_units], 569 initializer=( 570 self._bias_initializer 571 if self._bias_initializer is not None 572 else init_ops.zeros_initializer(dtype=self.dtype))) 573 574 self.built = True 575 576 def call(self, inputs, state): 577 """Gated recurrent unit (GRU) with nunits cells.""" 578 _check_rnn_cell_input_dtypes([inputs, state]) 579 580 gate_inputs = math_ops.matmul( 581 array_ops.concat([inputs, state], 1), self._gate_kernel) 582 gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) 583 584 value = math_ops.sigmoid(gate_inputs) 585 r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) 586 587 r_state = r * state 588 589 candidate = math_ops.matmul( 590 array_ops.concat([inputs, r_state], 1), self._candidate_kernel) 591 candidate = nn_ops.bias_add(candidate, 
self._candidate_bias) 592 593 c = self._activation(candidate) 594 new_h = u * state + (1 - u) * c 595 return new_h, new_h 596 597 def get_config(self): 598 config = { 599 "num_units": self._num_units, 600 "kernel_initializer": initializers.serialize(self._kernel_initializer), 601 "bias_initializer": initializers.serialize(self._bias_initializer), 602 "activation": activations.serialize(self._activation), 603 "reuse": self._reuse, 604 } 605 base_config = super(GRUCell, self).get_config() 606 return dict(list(base_config.items()) + list(config.items())) 607 608 609_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) 610 611 612@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"]) 613class LSTMStateTuple(_LSTMStateTuple): 614 """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. 615 616 Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state 617 and `h` is the output. 618 619 Only used when `state_is_tuple=True`. 620 """ 621 __slots__ = () 622 623 @property 624 def dtype(self): 625 (c, h) = self 626 if c.dtype != h.dtype: 627 raise TypeError("Inconsistent internal state: %s vs %s" % 628 (str(c.dtype), str(h.dtype))) 629 return c.dtype 630 631 632@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"]) 633class BasicLSTMCell(LayerRNNCell): 634 """DEPRECATED: Please use `tf.nn.rnn_cell.LSTMCell` instead. 635 636 Basic LSTM recurrent network cell. 637 638 The implementation is based on: http://arxiv.org/abs/1409.2329. 639 640 We add forget_bias (default: 1) to the biases of the forget gate in order to 641 reduce the scale of forgetting in the beginning of the training. 642 643 It does not allow cell clipping, a projection layer, and does not 644 use peep-hole connections: it is the basic baseline. 645 646 For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell` 647 that follows. 648 649 Note that this cell is not optimized for performance. Please use 650 `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or 651 `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for 652 better performance on CPU. 653 """ 654 655 @deprecated(None, "This class is equivalent as tf.keras.layers.LSTMCell," 656 " and will be replaced by that in Tensorflow 2.0.") 657 def __init__(self, 658 num_units, 659 forget_bias=1.0, 660 state_is_tuple=True, 661 activation=None, 662 reuse=None, 663 name=None, 664 dtype=None, 665 **kwargs): 666 """Initialize the basic LSTM cell. 667 668 Args: 669 num_units: int, The number of units in the LSTM cell. 670 forget_bias: float, The bias added to forget gates (see above). 671 Must set to `0.0` manually when restoring from CudnnLSTM-trained 672 checkpoints. 673 state_is_tuple: If True, accepted and returned states are 2-tuples of 674 the `c_state` and `m_state`. If False, they are concatenated 675 along the column axis. The latter behavior will soon be deprecated. 676 activation: Activation function of the inner states. Default: `tanh`. It 677 could also be string that is within Keras activation function names. 678 reuse: (optional) Python boolean describing whether to reuse variables 679 in an existing scope. If not `True`, and the existing scope already has 680 the given variables, an error is raised. 681 name: String, the name of the layer. Layers with the same name will 682 share weights, but to avoid mistakes we require reuse=True in such 683 cases. 684 dtype: Default dtype of the layer (default of `None` means use the type 685 of the first input). Required when `build` is called before `call`. 
686 **kwargs: Dict, keyword named properties for common layer attributes, like 687 `trainable` etc when constructing the cell from configs of get_config(). 688 689 When restoring from CudnnLSTM-trained checkpoints, must use 690 `CudnnCompatibleLSTMCell` instead. 691 """ 692 super(BasicLSTMCell, self).__init__( 693 _reuse=reuse, name=name, dtype=dtype, **kwargs) 694 _check_supported_dtypes(self.dtype) 695 if not state_is_tuple: 696 logging.warn("%s: Using a concatenated state is slower and will soon be " 697 "deprecated. Use state_is_tuple=True.", self) 698 if context.executing_eagerly() and context.num_gpus() > 0: 699 logging.warn("%s: Note that this cell is not optimized for performance. " 700 "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " 701 "performance on GPU.", self) 702 703 # Inputs must be 2-dimensional. 704 self.input_spec = input_spec.InputSpec(ndim=2) 705 706 self._num_units = num_units 707 self._forget_bias = forget_bias 708 self._state_is_tuple = state_is_tuple 709 if activation: 710 self._activation = activations.get(activation) 711 else: 712 self._activation = math_ops.tanh 713 714 @property 715 def state_size(self): 716 return (LSTMStateTuple(self._num_units, self._num_units) 717 if self._state_is_tuple else 2 * self._num_units) 718 719 @property 720 def output_size(self): 721 return self._num_units 722 723 @tf_utils.shape_type_conversion 724 def build(self, inputs_shape): 725 if inputs_shape[-1] is None: 726 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" 727 % str(inputs_shape)) 728 _check_supported_dtypes(self.dtype) 729 input_depth = inputs_shape[-1] 730 h_depth = self._num_units 731 self._kernel = self.add_variable( 732 _WEIGHTS_VARIABLE_NAME, 733 shape=[input_depth + h_depth, 4 * self._num_units]) 734 self._bias = self.add_variable( 735 _BIAS_VARIABLE_NAME, 736 shape=[4 * self._num_units], 737 initializer=init_ops.zeros_initializer(dtype=self.dtype)) 738 739 self.built = True 740 741 def call(self, inputs, state): 742 """Long short-term memory cell (LSTM). 743 744 Args: 745 inputs: `2-D` tensor with shape `[batch_size, input_size]`. 746 state: An `LSTMStateTuple` of state tensors, each shaped 747 `[batch_size, num_units]`, if `state_is_tuple` has been set to 748 `True`. Otherwise, a `Tensor` shaped 749 `[batch_size, 2 * num_units]`. 750 751 Returns: 752 A pair containing the new hidden state, and the new state (either a 753 `LSTMStateTuple` or a concatenated state, depending on 754 `state_is_tuple`). 755 """ 756 _check_rnn_cell_input_dtypes([inputs, state]) 757 758 sigmoid = math_ops.sigmoid 759 one = constant_op.constant(1, dtype=dtypes.int32) 760 # Parameters of gates are concatenated into one multiply for efficiency. 761 if self._state_is_tuple: 762 c, h = state 763 else: 764 c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one) 765 766 gate_inputs = math_ops.matmul( 767 array_ops.concat([inputs, h], 1), self._kernel) 768 gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) 769 770 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 771 i, j, f, o = array_ops.split( 772 value=gate_inputs, num_or_size_splits=4, axis=one) 773 774 forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) 775 # Note that using `add` and `multiply` instead of `+` and `*` gives a 776 # performance improvement. So using those at the cost of readability. 
777 add = math_ops.add 778 multiply = math_ops.multiply 779 new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))), 780 multiply(sigmoid(i), self._activation(j))) 781 new_h = multiply(self._activation(new_c), sigmoid(o)) 782 783 if self._state_is_tuple: 784 new_state = LSTMStateTuple(new_c, new_h) 785 else: 786 new_state = array_ops.concat([new_c, new_h], 1) 787 return new_h, new_state 788 789 def get_config(self): 790 config = { 791 "num_units": self._num_units, 792 "forget_bias": self._forget_bias, 793 "state_is_tuple": self._state_is_tuple, 794 "activation": activations.serialize(self._activation), 795 "reuse": self._reuse, 796 } 797 base_config = super(BasicLSTMCell, self).get_config() 798 return dict(list(base_config.items()) + list(config.items())) 799 800 801@tf_export(v1=["nn.rnn_cell.LSTMCell"]) 802class LSTMCell(LayerRNNCell): 803 """Long short-term memory unit (LSTM) recurrent network cell. 804 805 The default non-peephole implementation is based on: 806 807 https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf 808 809 Felix Gers, Jurgen Schmidhuber, and Fred Cummins. 810 "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. 811 812 The peephole implementation is based on: 813 814 https://research.google.com/pubs/archive/43905.pdf 815 816 Hasim Sak, Andrew Senior, and Francoise Beaufays. 817 "Long short-term memory recurrent neural network architectures for 818 large scale acoustic modeling." INTERSPEECH, 2014. 819 820 The class uses optional peep-hole connections, optional cell clipping, and 821 an optional projection layer. 822 823 Note that this cell is not optimized for performance. Please use 824 `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or 825 `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for 826 better performance on CPU. 827 """ 828 829 @deprecated(None, "This class is equivalent as tf.keras.layers.LSTMCell," 830 " and will be replaced by that in Tensorflow 2.0.") 831 def __init__(self, num_units, 832 use_peepholes=False, cell_clip=None, 833 initializer=None, num_proj=None, proj_clip=None, 834 num_unit_shards=None, num_proj_shards=None, 835 forget_bias=1.0, state_is_tuple=True, 836 activation=None, reuse=None, name=None, dtype=None, **kwargs): 837 """Initialize the parameters for an LSTM cell. 838 839 Args: 840 num_units: int, The number of units in the LSTM cell. 841 use_peepholes: bool, set True to enable diagonal/peephole connections. 842 cell_clip: (optional) A float value, if provided the cell state is clipped 843 by this value prior to the cell output activation. 844 initializer: (optional) The initializer to use for the weight and 845 projection matrices. 846 num_proj: (optional) int, The output dimensionality for the projection 847 matrices. If None, no projection is performed. 848 proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 849 provided, then the projected values are clipped elementwise to within 850 `[-proj_clip, proj_clip]`. 851 num_unit_shards: Deprecated, will be removed by Jan. 2017. 852 Use a variable_scope partitioner instead. 853 num_proj_shards: Deprecated, will be removed by Jan. 2017. 854 Use a variable_scope partitioner instead. 855 forget_bias: Biases of the forget gate are initialized by default to 1 856 in order to reduce the scale of forgetting at the beginning of 857 the training. Must set it manually to `0.0` when restoring from 858 CudnnLSTM trained checkpoints. 
859 state_is_tuple: If True, accepted and returned states are 2-tuples of 860 the `c_state` and `m_state`. If False, they are concatenated 861 along the column axis. This latter behavior will soon be deprecated. 862 activation: Activation function of the inner states. Default: `tanh`. It 863 could also be string that is within Keras activation function names. 864 reuse: (optional) Python boolean describing whether to reuse variables 865 in an existing scope. If not `True`, and the existing scope already has 866 the given variables, an error is raised. 867 name: String, the name of the layer. Layers with the same name will 868 share weights, but to avoid mistakes we require reuse=True in such 869 cases. 870 dtype: Default dtype of the layer (default of `None` means use the type 871 of the first input). Required when `build` is called before `call`. 872 **kwargs: Dict, keyword named properties for common layer attributes, like 873 `trainable` etc when constructing the cell from configs of get_config(). 874 875 When restoring from CudnnLSTM-trained checkpoints, use 876 `CudnnCompatibleLSTMCell` instead. 877 """ 878 super(LSTMCell, self).__init__( 879 _reuse=reuse, name=name, dtype=dtype, **kwargs) 880 _check_supported_dtypes(self.dtype) 881 if not state_is_tuple: 882 logging.warn("%s: Using a concatenated state is slower and will soon be " 883 "deprecated. Use state_is_tuple=True.", self) 884 if num_unit_shards is not None or num_proj_shards is not None: 885 logging.warn( 886 "%s: The num_unit_shards and proj_unit_shards parameters are " 887 "deprecated and will be removed in Jan 2017. " 888 "Use a variable scope with a partitioner instead.", self) 889 if context.executing_eagerly() and context.num_gpus() > 0: 890 logging.warn("%s: Note that this cell is not optimized for performance. " 891 "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " 892 "performance on GPU.", self) 893 894 # Inputs must be 2-dimensional. 
895 self.input_spec = input_spec.InputSpec(ndim=2) 896 897 self._num_units = num_units 898 self._use_peepholes = use_peepholes 899 self._cell_clip = cell_clip 900 self._initializer = initializers.get(initializer) 901 self._num_proj = num_proj 902 self._proj_clip = proj_clip 903 self._num_unit_shards = num_unit_shards 904 self._num_proj_shards = num_proj_shards 905 self._forget_bias = forget_bias 906 self._state_is_tuple = state_is_tuple 907 if activation: 908 self._activation = activations.get(activation) 909 else: 910 self._activation = math_ops.tanh 911 912 if num_proj: 913 self._state_size = ( 914 LSTMStateTuple(num_units, num_proj) 915 if state_is_tuple else num_units + num_proj) 916 self._output_size = num_proj 917 else: 918 self._state_size = ( 919 LSTMStateTuple(num_units, num_units) 920 if state_is_tuple else 2 * num_units) 921 self._output_size = num_units 922 923 @property 924 def state_size(self): 925 return self._state_size 926 927 @property 928 def output_size(self): 929 return self._output_size 930 931 @tf_utils.shape_type_conversion 932 def build(self, inputs_shape): 933 if inputs_shape[-1] is None: 934 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" 935 % str(inputs_shape)) 936 _check_supported_dtypes(self.dtype) 937 input_depth = inputs_shape[-1] 938 h_depth = self._num_units if self._num_proj is None else self._num_proj 939 maybe_partitioner = ( 940 partitioned_variables.fixed_size_partitioner(self._num_unit_shards) 941 if self._num_unit_shards is not None 942 else None) 943 self._kernel = self.add_variable( 944 _WEIGHTS_VARIABLE_NAME, 945 shape=[input_depth + h_depth, 4 * self._num_units], 946 initializer=self._initializer, 947 partitioner=maybe_partitioner) 948 if self.dtype is None: 949 initializer = init_ops.zeros_initializer 950 else: 951 initializer = init_ops.zeros_initializer(dtype=self.dtype) 952 self._bias = self.add_variable( 953 _BIAS_VARIABLE_NAME, 954 shape=[4 * self._num_units], 955 initializer=initializer) 956 if self._use_peepholes: 957 self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units], 958 initializer=self._initializer) 959 self._w_i_diag = self.add_variable("w_i_diag", shape=[self._num_units], 960 initializer=self._initializer) 961 self._w_o_diag = self.add_variable("w_o_diag", shape=[self._num_units], 962 initializer=self._initializer) 963 964 if self._num_proj is not None: 965 maybe_proj_partitioner = ( 966 partitioned_variables.fixed_size_partitioner(self._num_proj_shards) 967 if self._num_proj_shards is not None 968 else None) 969 self._proj_kernel = self.add_variable( 970 "projection/%s" % _WEIGHTS_VARIABLE_NAME, 971 shape=[self._num_units, self._num_proj], 972 initializer=self._initializer, 973 partitioner=maybe_proj_partitioner) 974 975 self.built = True 976 977 def call(self, inputs, state): 978 """Run one step of LSTM. 979 980 Args: 981 inputs: input Tensor, must be 2-D, `[batch, input_size]`. 982 state: if `state_is_tuple` is False, this must be a state Tensor, 983 `2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a 984 tuple of state Tensors, both `2-D`, with column sizes `c_state` and 985 `m_state`. 986 987 Returns: 988 A tuple containing: 989 990 - A `2-D, [batch, output_dim]`, Tensor representing the output of the 991 LSTM after reading `inputs` when previous state was `state`. 992 Here output_dim is: 993 num_proj if num_proj was set, 994 num_units otherwise. 995 - Tensor(s) representing the new state of LSTM after reading `inputs` when 996 the previous state was `state`. 
Same type and shape(s) as `state`. 997 998 Raises: 999 ValueError: If input size cannot be inferred from inputs via 1000 static shape inference. 1001 """ 1002 _check_rnn_cell_input_dtypes([inputs, state]) 1003 1004 num_proj = self._num_units if self._num_proj is None else self._num_proj 1005 sigmoid = math_ops.sigmoid 1006 1007 if self._state_is_tuple: 1008 (c_prev, m_prev) = state 1009 else: 1010 c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) 1011 m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) 1012 1013 input_size = inputs.get_shape().with_rank(2).dims[1].value 1014 if input_size is None: 1015 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 1016 1017 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 1018 lstm_matrix = math_ops.matmul( 1019 array_ops.concat([inputs, m_prev], 1), self._kernel) 1020 lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias) 1021 1022 i, j, f, o = array_ops.split( 1023 value=lstm_matrix, num_or_size_splits=4, axis=1) 1024 # Diagonal connections 1025 if self._use_peepholes: 1026 c = (sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + 1027 sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) 1028 else: 1029 c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * 1030 self._activation(j)) 1031 1032 if self._cell_clip is not None: 1033 # pylint: disable=invalid-unary-operand-type 1034 c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 1035 # pylint: enable=invalid-unary-operand-type 1036 if self._use_peepholes: 1037 m = sigmoid(o + self._w_o_diag * c) * self._activation(c) 1038 else: 1039 m = sigmoid(o) * self._activation(c) 1040 1041 if self._num_proj is not None: 1042 m = math_ops.matmul(m, self._proj_kernel) 1043 1044 if self._proj_clip is not None: 1045 # pylint: disable=invalid-unary-operand-type 1046 m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) 1047 # pylint: enable=invalid-unary-operand-type 1048 1049 new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else 1050 array_ops.concat([c, m], 1)) 1051 return m, new_state 1052 1053 def get_config(self): 1054 config = { 1055 "num_units": self._num_units, 1056 "use_peepholes": self._use_peepholes, 1057 "cell_clip": self._cell_clip, 1058 "initializer": initializers.serialize(self._initializer), 1059 "num_proj": self._num_proj, 1060 "proj_clip": self._proj_clip, 1061 "num_unit_shards": self._num_unit_shards, 1062 "num_proj_shards": self._num_proj_shards, 1063 "forget_bias": self._forget_bias, 1064 "state_is_tuple": self._state_is_tuple, 1065 "activation": activations.serialize(self._activation), 1066 "reuse": self._reuse, 1067 } 1068 base_config = super(LSTMCell, self).get_config() 1069 return dict(list(base_config.items()) + list(config.items())) 1070 1071 1072def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs): 1073 ix = [0] 1074 def enumerated_fn(*inner_args, **inner_kwargs): 1075 r = map_fn(ix[0], *inner_args, **inner_kwargs) 1076 ix[0] += 1 1077 return r 1078 return nest.map_structure_up_to(shallow_structure, 1079 enumerated_fn, *args, **kwargs) 1080 1081 1082def _default_dropout_state_filter_visitor(substate): 1083 if isinstance(substate, LSTMStateTuple): 1084 # Do not perform dropout on the memory state. 
1085 return LSTMStateTuple(c=False, h=True) 1086 elif isinstance(substate, tensor_array_ops.TensorArray): 1087 return False 1088 return True 1089 1090 1091class _RNNCellWrapperV1(RNNCell): 1092 """Base class for cells wrappers V1 compatibility. 1093 1094 This class along with `_RNNCellWrapperV2` allows to define cells wrappers that 1095 are compatible with V1 and V2, and defines helper methods for this purpose. 1096 """ 1097 1098 def __init__(self, cell): 1099 super(_RNNCellWrapperV1, self).__init__() 1100 self.cell = cell 1101 if isinstance(cell, trackable.Trackable): 1102 self._track_trackable(self.cell, name="cell") 1103 1104 def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): 1105 """Calls the wrapped cell and performs the wrapping logic. 1106 1107 This method is called from the wrapper's `call` or `__call__` methods. 1108 1109 Args: 1110 inputs: A tensor with wrapped cell's input. 1111 state: A tensor or tuple of tensors with wrapped cell's state. 1112 cell_call_fn: Wrapped cell's method to use for step computation (cell's 1113 `__call__` or 'call' method). 1114 **kwargs: Additional arguments. 1115 1116 Returns: 1117 A pair containing: 1118 - Output: A tensor with cell's output. 1119 - New state: A tensor or tuple of tensors with new wrapped cell's state. 1120 """ 1121 raise NotImplementedError 1122 1123 def __call__(self, inputs, state, scope=None): 1124 """Runs the RNN cell step computation. 1125 1126 We assume that the wrapped RNNCell is being built within its `__call__` 1127 method. We directly use the wrapped cell's `__call__` in the overridden 1128 wrapper `__call__` method. 1129 1130 This allows to use the wrapped cell and the non-wrapped cell equivalently 1131 when using `__call__`. 1132 1133 Args: 1134 inputs: A tensor with wrapped cell's input. 1135 state: A tensor or tuple of tensors with wrapped cell's state. 1136 scope: VariableScope for the subgraph created in the wrapped cells' 1137 `__call__`. 1138 1139 Returns: 1140 A pair containing: 1141 1142 - Output: A tensor with cell's output. 1143 - New state: A tensor or tuple of tensors with new wrapped cell's state. 1144 """ 1145 return self._call_wrapped_cell( 1146 inputs, state, cell_call_fn=self.cell.__call__, scope=scope) 1147 1148 1149class _RNNCellWrapperV2(keras_layer.AbstractRNNCell): 1150 """Base class for cells wrappers V2 compatibility. 1151 1152 This class along with `_RNNCellWrapperV1` allows to define cells wrappers that 1153 are compatible with V1 and V2, and defines helper methods for this purpose. 1154 """ 1155 1156 def __init__(self, cell, *args, **kwargs): 1157 super(_RNNCellWrapperV2, self).__init__(*args, **kwargs) 1158 self.cell = cell 1159 1160 def call(self, inputs, state, **kwargs): 1161 """Runs the RNN cell step computation. 1162 1163 When `call` is being used, we assume that the wrapper object has been built, 1164 and therefore the wrapped cells has been built via its `build` method and 1165 its `call` method can be used directly. 1166 1167 This allows to use the wrapped cell and the non-wrapped cell equivalently 1168 when using `call` and `build`. 1169 1170 Args: 1171 inputs: A tensor with wrapped cell's input. 1172 state: A tensor or tuple of tensors with wrapped cell's state. 1173 **kwargs: Additional arguments passed to the wrapped cell's `call`. 1174 1175 Returns: 1176 A pair containing: 1177 1178 - Output: A tensor with cell's output. 1179 - New state: A tensor or tuple of tensors with new wrapped cell's state. 
1180 """ 1181 return self._call_wrapped_cell( 1182 inputs, state, cell_call_fn=self.cell.call, **kwargs) 1183 1184 def build(self, inputs_shape): 1185 """Builds the wrapped cell.""" 1186 self.cell.build(inputs_shape) 1187 self.built = True 1188 1189 1190class DropoutWrapperBase(object): 1191 """Operator adding dropout to inputs and outputs of the given cell.""" 1192 1193 def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0, 1194 state_keep_prob=1.0, variational_recurrent=False, 1195 input_size=None, dtype=None, seed=None, 1196 dropout_state_filter_visitor=None): 1197 """Create a cell with added input, state, and/or output dropout. 1198 1199 If `variational_recurrent` is set to `True` (**NOT** the default behavior), 1200 then the same dropout mask is applied at every step, as described in: 1201 1202 Y. Gal, Z Ghahramani. "A Theoretically Grounded Application of Dropout in 1203 Recurrent Neural Networks". https://arxiv.org/abs/1512.05287 1204 1205 Otherwise a different dropout mask is applied at every time step. 1206 1207 Note, by default (unless a custom `dropout_state_filter` is provided), 1208 the memory state (`c` component of any `LSTMStateTuple`) passing through 1209 a `DropoutWrapper` is never modified. This behavior is described in the 1210 above article. 1211 1212 Args: 1213 cell: an RNNCell, a projection to output_size is added to it. 1214 input_keep_prob: unit Tensor or float between 0 and 1, input keep 1215 probability; if it is constant and 1, no input dropout will be added. 1216 output_keep_prob: unit Tensor or float between 0 and 1, output keep 1217 probability; if it is constant and 1, no output dropout will be added. 1218 state_keep_prob: unit Tensor or float between 0 and 1, output keep 1219 probability; if it is constant and 1, no output dropout will be added. 1220 State dropout is performed on the outgoing states of the cell. 1221 **Note** the state components to which dropout is applied when 1222 `state_keep_prob` is in `(0, 1)` are also determined by 1223 the argument `dropout_state_filter_visitor` (e.g. by default dropout 1224 is never applied to the `c` component of an `LSTMStateTuple`). 1225 variational_recurrent: Python bool. If `True`, then the same 1226 dropout pattern is applied across all time steps per run call. 1227 If this parameter is set, `input_size` **must** be provided. 1228 input_size: (optional) (possibly nested tuple of) `TensorShape` objects 1229 containing the depth(s) of the input tensors expected to be passed in to 1230 the `DropoutWrapper`. Required and used **iff** 1231 `variational_recurrent = True` and `input_keep_prob < 1`. 1232 dtype: (optional) The `dtype` of the input, state, and output tensors. 1233 Required and used **iff** `variational_recurrent = True`. 1234 seed: (optional) integer, the randomness seed. 1235 dropout_state_filter_visitor: (optional), default: (see below). Function 1236 that takes any hierarchical level of the state and returns 1237 a scalar or depth=1 structure of Python booleans describing 1238 which terms in the state should be dropped out. In addition, if the 1239 function returns `True`, dropout is applied across this sublevel. If 1240 the function returns `False`, dropout is not applied across this entire 1241 sublevel. 
1242 Default behavior: perform dropout on all terms except the memory (`c`) 1243 state of `LSTMCellState` objects, and don't try to apply dropout to 1244 `TensorArray` objects: 1245 ``` 1246 def dropout_state_filter_visitor(s): 1247 if isinstance(s, LSTMCellState): 1248 # Never perform dropout on the c state. 1249 return LSTMCellState(c=False, h=True) 1250 elif isinstance(s, TensorArray): 1251 return False 1252 return True 1253 ``` 1254 1255 Raises: 1256 TypeError: if `cell` is not an `RNNCell`, or `keep_state_fn` is provided 1257 but not `callable`. 1258 ValueError: if any of the keep_probs are not between 0 and 1. 1259 """ 1260 super(DropoutWrapperBase, self).__init__(cell) 1261 assert_like_rnncell("cell", cell) 1262 1263 if (dropout_state_filter_visitor is not None 1264 and not callable(dropout_state_filter_visitor)): 1265 raise TypeError("dropout_state_filter_visitor must be callable") 1266 self._dropout_state_filter = ( 1267 dropout_state_filter_visitor or _default_dropout_state_filter_visitor) 1268 with ops.name_scope("DropoutWrapperInit"): 1269 def tensor_and_const_value(v): 1270 tensor_value = ops.convert_to_tensor(v) 1271 const_value = tensor_util.constant_value(tensor_value) 1272 return (tensor_value, const_value) 1273 for prob, attr in [(input_keep_prob, "input_keep_prob"), 1274 (state_keep_prob, "state_keep_prob"), 1275 (output_keep_prob, "output_keep_prob")]: 1276 tensor_prob, const_prob = tensor_and_const_value(prob) 1277 if const_prob is not None: 1278 if const_prob < 0 or const_prob > 1: 1279 raise ValueError("Parameter %s must be between 0 and 1: %d" 1280 % (attr, const_prob)) 1281 setattr(self, "_%s" % attr, float(const_prob)) 1282 else: 1283 setattr(self, "_%s" % attr, tensor_prob) 1284 1285 # Set variational_recurrent, seed before running the code below 1286 self._variational_recurrent = variational_recurrent 1287 self._seed = seed 1288 1289 self._recurrent_input_noise = None 1290 self._recurrent_state_noise = None 1291 self._recurrent_output_noise = None 1292 1293 if variational_recurrent: 1294 if dtype is None: 1295 raise ValueError( 1296 "When variational_recurrent=True, dtype must be provided") 1297 1298 def convert_to_batch_shape(s): 1299 # Prepend a 1 for the batch dimension; for recurrent 1300 # variational dropout we use the same dropout mask for all 1301 # batch elements. 
1302 return array_ops.concat( 1303 ([1], tensor_shape.TensorShape(s).as_list()), 0) 1304 1305 def batch_noise(s, inner_seed): 1306 shape = convert_to_batch_shape(s) 1307 return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype) 1308 1309 if (not isinstance(self._input_keep_prob, numbers.Real) or 1310 self._input_keep_prob < 1.0): 1311 if input_size is None: 1312 raise ValueError( 1313 "When variational_recurrent=True and input_keep_prob < 1.0 or " 1314 "is unknown, input_size must be provided") 1315 self._recurrent_input_noise = _enumerated_map_structure_up_to( 1316 input_size, 1317 lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)), 1318 input_size) 1319 self._recurrent_state_noise = _enumerated_map_structure_up_to( 1320 cell.state_size, 1321 lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)), 1322 cell.state_size) 1323 self._recurrent_output_noise = _enumerated_map_structure_up_to( 1324 cell.output_size, 1325 lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)), 1326 cell.output_size) 1327 1328 def _gen_seed(self, salt_prefix, index): 1329 if self._seed is None: 1330 return None 1331 salt = "%s_%d" % (salt_prefix, index) 1332 string = (str(self._seed) + salt).encode("utf-8") 1333 return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF 1334 1335 @property 1336 def wrapped_cell(self): 1337 return self.cell 1338 1339 @property 1340 def state_size(self): 1341 return self.cell.state_size 1342 1343 @property 1344 def output_size(self): 1345 return self.cell.output_size 1346 1347 def zero_state(self, batch_size, dtype): 1348 with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 1349 return self.cell.zero_state(batch_size, dtype) 1350 1351 def _variational_recurrent_dropout_value( 1352 self, index, value, noise, keep_prob): 1353 """Performs dropout given the pre-calculated noise tensor.""" 1354 # uniform [keep_prob, 1.0 + keep_prob) 1355 random_tensor = keep_prob + noise 1356 1357 # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) 1358 binary_tensor = math_ops.floor(random_tensor) 1359 ret = math_ops.div(value, keep_prob) * binary_tensor 1360 ret.set_shape(value.get_shape()) 1361 return ret 1362 1363 def _dropout(self, values, salt_prefix, recurrent_noise, keep_prob, 1364 shallow_filtered_substructure=None): 1365 """Decides whether to perform standard dropout or recurrent dropout.""" 1366 1367 if shallow_filtered_substructure is None: 1368 # Put something so we traverse the entire structure; inside the 1369 # dropout function we check to see if leafs of this are bool or not. 1370 shallow_filtered_substructure = values 1371 1372 if not self._variational_recurrent: 1373 def dropout(i, do_dropout, v): 1374 if not isinstance(do_dropout, bool) or do_dropout: 1375 return nn_ops.dropout( 1376 v, keep_prob=keep_prob, seed=self._gen_seed(salt_prefix, i)) 1377 else: 1378 return v 1379 return _enumerated_map_structure_up_to( 1380 shallow_filtered_substructure, dropout, 1381 *[shallow_filtered_substructure, values]) 1382 else: 1383 def dropout(i, do_dropout, v, n): 1384 if not isinstance(do_dropout, bool) or do_dropout: 1385 return self._variational_recurrent_dropout_value(i, v, n, keep_prob) 1386 else: 1387 return v 1388 return _enumerated_map_structure_up_to( 1389 shallow_filtered_substructure, dropout, 1390 *[shallow_filtered_substructure, values, recurrent_noise]) 1391 1392 def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): 1393 """Runs the wrapped cell and applies dropout. 
1394 1395 Args: 1396 inputs: A tensor with wrapped cell's input. 1397 state: A tensor or tuple of tensors with wrapped cell's state. 1398 cell_call_fn: Wrapped cell's method to use for step computation (cell's 1399 `__call__` or 'call' method). 1400 **kwargs: Additional arguments. 1401 1402 Returns: 1403 A pair containing: 1404 1405 - Output: A tensor with cell's output. 1406 - New state: A tensor or tuple of tensors with new wrapped cell's state. 1407 """ 1408 def _should_dropout(p): 1409 return (not isinstance(p, float)) or p < 1 1410 1411 if _should_dropout(self._input_keep_prob): 1412 inputs = self._dropout(inputs, "input", 1413 self._recurrent_input_noise, 1414 self._input_keep_prob) 1415 output, new_state = cell_call_fn(inputs, state, **kwargs) 1416 if _should_dropout(self._state_keep_prob): 1417 # Identify which subsets of the state to perform dropout on and 1418 # which ones to keep. 1419 shallow_filtered_substructure = nest.get_traverse_shallow_structure( 1420 self._dropout_state_filter, new_state) 1421 new_state = self._dropout(new_state, "state", 1422 self._recurrent_state_noise, 1423 self._state_keep_prob, 1424 shallow_filtered_substructure) 1425 if _should_dropout(self._output_keep_prob): 1426 output = self._dropout(output, "output", 1427 self._recurrent_output_noise, 1428 self._output_keep_prob) 1429 return output, new_state 1430 1431 1432@tf_export(v1=["nn.rnn_cell.DropoutWrapper"]) 1433class DropoutWrapper(DropoutWrapperBase, _RNNCellWrapperV1): 1434 """Operator adding dropout to inputs and outputs of the given cell.""" 1435 1436 def __init__(self, *args, **kwargs): 1437 super(DropoutWrapper, self).__init__(*args, **kwargs) 1438 1439 __init__.__doc__ = DropoutWrapperBase.__init__.__doc__ 1440 1441 1442@tf_export("nn.RNNCellDropoutWrapper", v1=[]) 1443class DropoutWrapperV2(DropoutWrapperBase, _RNNCellWrapperV2): 1444 """Operator adding dropout to inputs and outputs of the given cell.""" 1445 1446 def __init__(self, *args, **kwargs): 1447 super(DropoutWrapperV2, self).__init__(*args, **kwargs) 1448 1449 __init__.__doc__ = DropoutWrapperBase.__init__.__doc__ 1450 1451 1452class ResidualWrapperBase(object): 1453 """RNNCell wrapper that ensures cell inputs are added to the outputs.""" 1454 1455 def __init__(self, cell, residual_fn=None): 1456 """Constructs a `ResidualWrapper` for `cell`. 1457 1458 Args: 1459 cell: An instance of `RNNCell`. 1460 residual_fn: (Optional) The function to map raw cell inputs and raw cell 1461 outputs to the actual cell outputs of the residual network. 1462 Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs 1463 and outputs. 1464 """ 1465 super(ResidualWrapperBase, self).__init__(cell) 1466 self._residual_fn = residual_fn 1467 1468 @property 1469 def state_size(self): 1470 return self.cell.state_size 1471 1472 @property 1473 def output_size(self): 1474 return self.cell.output_size 1475 1476 def zero_state(self, batch_size, dtype): 1477 with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 1478 return self.cell.zero_state(batch_size, dtype) 1479 1480 def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs): 1481 """Run the cell and then apply the residual_fn on its inputs to its outputs. 1482 1483 Args: 1484 inputs: cell inputs. 1485 state: cell state. 1486 cell_call_fn: Wrapped cell's method to use for step computation (cell's 1487 `__call__` or 'call' method). 1488 **kwargs: Additional arguments passed to the wrapped cell's `call`. 1489 1490 Returns: 1491 Tuple of cell outputs and new state. 

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = cell_call_fn(inputs, state, **kwargs)
    # Ensure shapes match.
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())
    def default_residual_fn(inputs, outputs):
      nest.assert_same_structure(inputs, outputs)
      nest.map_structure(assert_shape_match, inputs, outputs)
      return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)
    res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
    return (res_outputs, new_state)


@tf_export(v1=["nn.rnn_cell.ResidualWrapper"])
class ResidualWrapper(ResidualWrapperBase, _RNNCellWrapperV1):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, *args, **kwargs):
    super(ResidualWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = ResidualWrapperBase.__init__.__doc__


@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapperV2(ResidualWrapperBase, _RNNCellWrapperV2):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, *args, **kwargs):
    super(ResidualWrapperV2, self).__init__(*args, **kwargs)

  __init__.__doc__ = ResidualWrapperBase.__init__.__doc__


class DeviceWrapperBase(object):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, cell, device):
    """Construct a `DeviceWrapper` for `cell` with device `device`.

    Ensures the wrapped `cell` is called with `tf.device(device)`.

    Args:
      cell: An instance of `RNNCell`.
      device: A device string or function, for passing to `tf.device`.
    """
    super(DeviceWrapperBase, self).__init__(cell)
    self._device = device

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      with ops.device(self._device):
        return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the wrapped cell on the specified device."""
    with ops.device(self._device):
      return cell_call_fn(inputs, state, **kwargs)


@tf_export(v1=["nn.rnn_cell.DeviceWrapper"])
class DeviceWrapper(DeviceWrapperBase, _RNNCellWrapperV1):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(DeviceWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = DeviceWrapperBase.__init__.__doc__


@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapperV2(DeviceWrapperBase, _RNNCellWrapperV2):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(DeviceWrapperV2, self).__init__(*args, **kwargs)

  __init__.__doc__ = DeviceWrapperBase.__init__.__doc__


@tf_export(v1=["nn.rnn_cell.MultiRNNCell"])
class MultiRNNCell(RNNCell):
  """RNN cell composed sequentially of multiple simple cells.
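
  Each wrapped cell runs in its own variable scope (`cell_0`, `cell_1`, ...),
  and with `state_is_tuple=True` (the default) the stacked cell's state is a
  tuple with one entry per layer. Complementing the construction example
  below, here is a minimal sketch of stepping a stacked cell once (cell types
  and sizes are hypothetical):

  ```python
  cells = [tf.compat.v1.nn.rnn_cell.GRUCell(n) for n in [128, 64]]
  stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells)
  state = stacked.zero_state(batch_size=4, dtype=tf.float32)
  inputs = tf.zeros([4, 32])
  output, state = stacked(inputs, state)  # output: [4, 64]; state: 2-tuple
  ```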

  Example:

  ```python
  num_units = [128, 64]
  cells = [BasicLSTMCell(num_units=n) for n in num_units]
  stacked_rnn_cell = MultiRNNCell(cells)
  ```
  """

  @deprecated(None, "This class is equivalent to "
              "tf.keras.layers.StackedRNNCells, and will be replaced by "
              "that in TensorFlow 2.0.")
  def __init__(self, cells, state_is_tuple=True):
    """Create an RNN cell composed sequentially of a number of RNNCells.

    Args:
      cells: list of RNNCells that will be composed in this order.
      state_is_tuple: If True, accepted and returned states are n-tuples, where
        `n = len(cells)`. If False, the states are all concatenated along the
        column axis. This latter behavior will soon be deprecated.

    Raises:
      ValueError: if cells is empty (not allowed), or at least one of the cells
        returns a state tuple but the flag `state_is_tuple` is `False`.
    """
    super(MultiRNNCell, self).__init__()
    if not cells:
      raise ValueError("Must specify at least one cell for MultiRNNCell.")
    if not nest.is_sequence(cells):
      raise TypeError(
          "cells must be a list or tuple, but saw: %s." % cells)

    if len(set([id(cell) for cell in cells])) < len(cells):
      logging.log_first_n(logging.WARN,
                          "At least two cells provided to MultiRNNCell "
                          "are the same object and will share weights.", 1)

    self._cells = cells
    for cell_number, cell in enumerate(self._cells):
      # Add Trackable dependencies on these cells so their variables get
      # saved with this object when using object-based saving.
      if isinstance(cell, trackable.Trackable):
        # TODO(allenl): Track down non-Trackable callers.
        self._track_trackable(cell, name="cell-%d" % (cell_number,))
    self._state_is_tuple = state_is_tuple
    if not state_is_tuple:
      if any(nest.is_sequence(c.state_size) for c in self._cells):
        raise ValueError("Some cells return tuples of states, but the flag "
                         "state_is_tuple is not set. "
                         "State sizes are: %s"
                         % str([c.state_size for c in self._cells]))

  @property
  def state_size(self):
    if self._state_is_tuple:
      return tuple(cell.state_size for cell in self._cells)
    else:
      return sum(cell.state_size for cell in self._cells)

  @property
  def output_size(self):
    return self._cells[-1].output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      if self._state_is_tuple:
        return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
      else:
        # We know here that state_size of each cell is not a tuple and
        # presumably does not contain TensorArrays or anything else fancy.
        return super(MultiRNNCell, self).zero_state(batch_size, dtype)

  @property
  def trainable_weights(self):
    if not self.trainable:
      return []
    weights = []
    for cell in self._cells:
      if isinstance(cell, base_layer.Layer):
        weights += cell.trainable_weights
    return weights

  @property
  def non_trainable_weights(self):
    weights = []
    for cell in self._cells:
      if isinstance(cell, base_layer.Layer):
        weights += cell.non_trainable_weights
    if not self.trainable:
      trainable_weights = []
      for cell in self._cells:
        if isinstance(cell, base_layer.Layer):
          trainable_weights += cell.trainable_weights
      return trainable_weights + weights
    return weights

  def call(self, inputs, state):
    """Run this multi-layer cell on inputs, starting from state."""
    cur_state_pos = 0
    cur_inp = inputs
    new_states = []
    for i, cell in enumerate(self._cells):
      with vs.variable_scope("cell_%d" % i):
        if self._state_is_tuple:
          if not nest.is_sequence(state):
            raise ValueError(
                "Expected state to be a tuple of length %d, but received: %s" %
                (len(self.state_size), state))
          cur_state = state[i]
        else:
          cur_state = array_ops.slice(state, [0, cur_state_pos],
                                      [-1, cell.state_size])
          cur_state_pos += cell.state_size
        cur_inp, new_state = cell(cur_inp, cur_state)
        new_states.append(new_state)

    new_states = (tuple(new_states) if self._state_is_tuple else
                  array_ops.concat(new_states, 1))

    return cur_inp, new_states


def _check_rnn_cell_input_dtypes(inputs):
  """Check whether the input tensors have supported dtypes.

  Default RNN cells only support float and complex dtypes, since their
  activation functions (tanh and sigmoid) only allow those types. This
  function raises a descriptive error if an input is not of a supported type.

  Args:
    inputs: tensor or nested structure of tensors that are fed to the RNN cell
      as input or state.

  Raises:
    ValueError: if any of the input tensors has a dtype that is not float or
      complex.
  """
  for t in nest.flatten(inputs):
    _check_supported_dtypes(t.dtype)


def _check_supported_dtypes(dtype):
  if dtype is None:
    return
  dtype = dtypes.as_dtype(dtype)
  if not (dtype.is_floating or dtype.is_complex):
    raise ValueError("RNN cell only supports floating point or complex "
                     "inputs, but saw dtype: %s" % dtype)
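
# A hedged illustration of the dtype check above (hypothetical calls, shown as
# comments only so the module's import-time behavior is unchanged): float and
# complex dtypes pass, while integer dtypes raise a ValueError.
#
#   _check_supported_dtypes(dtypes.float32)    # OK
#   _check_supported_dtypes(dtypes.complex64)  # OK
#   _check_supported_dtypes(dtypes.int32)      # raises ValueError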