1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""RNN helpers for TensorFlow models.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.python.eager import context 21from tensorflow.python.framework import constant_op 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.framework import tensor_shape 25from tensorflow.python.framework import tensor_util 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import control_flow_ops 28from tensorflow.python.ops import control_flow_util 29from tensorflow.python.ops import control_flow_util_v2 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops import rnn_cell_impl 32from tensorflow.python.ops import tensor_array_ops 33from tensorflow.python.ops import variable_scope as vs 34from tensorflow.python.util import deprecation 35from tensorflow.python.util import dispatch 36from tensorflow.python.util import nest 37from tensorflow.python.util.tf_export import tf_export 38 39# pylint: disable=protected-access 40_concat = rnn_cell_impl._concat 41# pylint: enable=protected-access 42 43 44def _transpose_batch_time(x): 45 """Transposes the batch and time dimensions of a Tensor. 46 47 If the input tensor has rank < 2 it returns the original tensor. Retains as 48 much of the static shape information as possible. 49 50 Args: 51 x: A Tensor. 52 53 Returns: 54 x transposed along the first two dimensions. 55 """ 56 x_static_shape = x.get_shape() 57 if x_static_shape.rank is not None and x_static_shape.rank < 2: 58 return x 59 60 x_rank = array_ops.rank(x) 61 x_t = array_ops.transpose( 62 x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0)) 63 x_t.set_shape( 64 tensor_shape.TensorShape( 65 [x_static_shape.dims[1].value, 66 x_static_shape.dims[0].value]).concatenate(x_static_shape[2:])) 67 return x_t 68 69 70def _best_effort_input_batch_size(flat_input): 71 """Get static input batch size if available, with fallback to the dynamic one. 72 73 Args: 74 flat_input: An iterable of time major input Tensors of shape `[max_time, 75 batch_size, ...]`. All inputs should have compatible batch sizes. 76 77 Returns: 78 The batch size in Python integer if available, or a scalar Tensor otherwise. 79 80 Raises: 81 ValueError: if there is any input with an invalid shape. 82 """ 83 for input_ in flat_input: 84 shape = input_.shape 85 if shape.rank is None: 86 continue 87 if shape.rank < 2: 88 raise ValueError("Expected input tensor %s to have rank at least 2" % 89 input_) 90 batch_size = shape.dims[1].value 91 if batch_size is not None: 92 return batch_size 93 # Fallback to the dynamic batch size of the first input. 94 return array_ops.shape(flat_input[0])[1] 95 96 97def _infer_state_dtype(explicit_dtype, state): 98 """Infer the dtype of an RNN state. 99 100 Args: 101 explicit_dtype: explicitly declared dtype or None. 102 state: RNN's hidden state. Must be a Tensor or a nested iterable containing 103 Tensors. 104 105 Returns: 106 dtype: inferred dtype of hidden state. 107 108 Raises: 109 ValueError: if `state` has heterogeneous dtypes or is empty. 110 """ 111 if explicit_dtype is not None: 112 return explicit_dtype 113 elif nest.is_sequence(state): 114 inferred_dtypes = [element.dtype for element in nest.flatten(state)] 115 if not inferred_dtypes: 116 raise ValueError("Unable to infer dtype from empty state.") 117 all_same = all(x == inferred_dtypes[0] for x in inferred_dtypes) 118 if not all_same: 119 raise ValueError( 120 "State has tensors of different inferred_dtypes. Unable to infer a " 121 "single representative dtype.") 122 return inferred_dtypes[0] 123 else: 124 return state.dtype 125 126 127def _maybe_tensor_shape_from_tensor(shape): 128 if isinstance(shape, ops.Tensor): 129 return tensor_shape.as_shape(tensor_util.constant_value(shape)) 130 else: 131 return shape 132 133 134def _should_cache(): 135 """Returns True if a default caching device should be set, otherwise False.""" 136 if context.executing_eagerly(): 137 return False 138 # Don't set a caching device when running in a loop, since it is possible that 139 # train steps could be wrapped in a tf.while_loop. In that scenario caching 140 # prevents forward computations in loop iterations from re-reading the 141 # updated weights. 142 graph = ops.get_default_graph() 143 ctxt = graph._get_control_flow_context() # pylint: disable=protected-access 144 in_v1_while_loop = ( 145 control_flow_util.GetContainingWhileContext(ctxt) is not None) 146 in_v2_while_loop = control_flow_util_v2.in_while_loop_defun(graph) 147 return not in_v1_while_loop and not in_v2_while_loop 148 149 150# pylint: disable=unused-argument 151def _rnn_step(time, 152 sequence_length, 153 min_sequence_length, 154 max_sequence_length, 155 zero_output, 156 state, 157 call_cell, 158 state_size, 159 skip_conditionals=False): 160 """Calculate one step of a dynamic RNN minibatch. 161 162 Returns an (output, state) pair conditioned on `sequence_length`. 163 When skip_conditionals=False, the pseudocode is something like: 164 165 if t >= max_sequence_length: 166 return (zero_output, state) 167 if t < min_sequence_length: 168 return call_cell() 169 170 # Selectively output zeros or output, old state or new state depending 171 # on whether we've finished calculating each row. 172 new_output, new_state = call_cell() 173 final_output = np.vstack([ 174 zero_output if time >= sequence_length[r] else new_output_r 175 for r, new_output_r in enumerate(new_output) 176 ]) 177 final_state = np.vstack([ 178 state[r] if time >= sequence_length[r] else new_state_r 179 for r, new_state_r in enumerate(new_state) 180 ]) 181 return (final_output, final_state) 182 183 Args: 184 time: int32 `Tensor` scalar. 185 sequence_length: int32 `Tensor` vector of size [batch_size]. 186 min_sequence_length: int32 `Tensor` scalar, min of sequence_length. 187 max_sequence_length: int32 `Tensor` scalar, max of sequence_length. 188 zero_output: `Tensor` vector of shape [output_size]. 189 state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`, 190 or a list/tuple of such tensors. 191 call_cell: lambda returning tuple of (new_output, new_state) where 192 new_output is a `Tensor` matrix of shape `[batch_size, output_size]`. 193 new_state is a `Tensor` matrix of shape `[batch_size, state_size]`. 194 state_size: The `cell.state_size` associated with the state. 195 skip_conditionals: Python bool, whether to skip using the conditional 196 calculations. This is useful for `dynamic_rnn`, where the input tensor 197 matches `max_sequence_length`, and using conditionals just slows 198 everything down. 199 200 Returns: 201 A tuple of (`final_output`, `final_state`) as given by the pseudocode above: 202 final_output is a `Tensor` matrix of shape [batch_size, output_size] 203 final_state is either a single `Tensor` matrix, or a tuple of such 204 matrices (matching length and shapes of input `state`). 205 206 Raises: 207 ValueError: If the cell returns a state tuple whose length does not match 208 that returned by `state_size`. 209 """ 210 211 # Convert state to a list for ease of use 212 flat_state = nest.flatten(state) 213 flat_zero_output = nest.flatten(zero_output) 214 215 # Vector describing which batch entries are finished. 216 copy_cond = time >= sequence_length 217 218 def _copy_one_through(output, new_output): 219 # TensorArray and scalar get passed through. 220 if isinstance(output, tensor_array_ops.TensorArray): 221 return new_output 222 if output.shape.rank == 0: 223 return new_output 224 # Otherwise propagate the old or the new value. 225 with ops.colocate_with(new_output): 226 return array_ops.where(copy_cond, output, new_output) 227 228 def _copy_some_through(flat_new_output, flat_new_state): 229 # Use broadcasting select to determine which values should get 230 # the previous state & zero output, and which values should get 231 # a calculated state & output. 232 flat_new_output = [ 233 _copy_one_through(zero_output, new_output) 234 for zero_output, new_output in zip(flat_zero_output, flat_new_output) 235 ] 236 flat_new_state = [ 237 _copy_one_through(state, new_state) 238 for state, new_state in zip(flat_state, flat_new_state) 239 ] 240 return flat_new_output + flat_new_state 241 242 def _maybe_copy_some_through(): 243 """Run RNN step. Pass through either no or some past state.""" 244 new_output, new_state = call_cell() 245 246 nest.assert_same_structure(zero_output, new_output) 247 nest.assert_same_structure(state, new_state) 248 249 flat_new_state = nest.flatten(new_state) 250 flat_new_output = nest.flatten(new_output) 251 return control_flow_ops.cond( 252 # if t < min_seq_len: calculate and return everything 253 time < min_sequence_length, 254 lambda: flat_new_output + flat_new_state, 255 # else copy some of it through 256 lambda: _copy_some_through(flat_new_output, flat_new_state)) 257 258 # TODO(ebrevdo): skipping these conditionals may cause a slowdown, 259 # but benefits from removing cond() and its gradient. We should 260 # profile with and without this switch here. 261 if skip_conditionals: 262 # Instead of using conditionals, perform the selective copy at all time 263 # steps. This is faster when max_seq_len is equal to the number of unrolls 264 # (which is typical for dynamic_rnn). 265 new_output, new_state = call_cell() 266 nest.assert_same_structure(zero_output, new_output) 267 nest.assert_same_structure(state, new_state) 268 new_state = nest.flatten(new_state) 269 new_output = nest.flatten(new_output) 270 final_output_and_state = _copy_some_through(new_output, new_state) 271 else: 272 empty_update = lambda: flat_zero_output + flat_state 273 final_output_and_state = control_flow_ops.cond( 274 # if t >= max_seq_len: copy all state through, output zeros 275 time >= max_sequence_length, 276 empty_update, 277 # otherwise calculation is required: copy some or all of it through 278 _maybe_copy_some_through) 279 280 if len(final_output_and_state) != len(flat_zero_output) + len(flat_state): 281 raise ValueError("Internal error: state and output were not concatenated " 282 "correctly.") 283 final_output = final_output_and_state[:len(flat_zero_output)] 284 final_state = final_output_and_state[len(flat_zero_output):] 285 286 for output, flat_output in zip(final_output, flat_zero_output): 287 output.set_shape(flat_output.get_shape()) 288 for substate, flat_substate in zip(final_state, flat_state): 289 if not isinstance(substate, tensor_array_ops.TensorArray): 290 substate.set_shape(flat_substate.get_shape()) 291 292 final_output = nest.pack_sequence_as( 293 structure=zero_output, flat_sequence=final_output) 294 final_state = nest.pack_sequence_as( 295 structure=state, flat_sequence=final_state) 296 297 return final_output, final_state 298 299 300def _reverse_seq(input_seq, lengths): 301 """Reverse a list of Tensors up to specified lengths. 302 303 Args: 304 input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features) 305 or nested tuples of tensors. 306 lengths: A `Tensor` of dimension batch_size, containing lengths for each 307 sequence in the batch. If "None" is specified, simply reverses the list. 308 309 Returns: 310 time-reversed sequence 311 """ 312 if lengths is None: 313 return list(reversed(input_seq)) 314 315 flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq) 316 317 flat_results = [[] for _ in range(len(input_seq))] 318 for sequence in zip(*flat_input_seq): 319 input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank) 320 for input_ in sequence: 321 input_shape.assert_is_compatible_with(input_.get_shape()) 322 input_.set_shape(input_shape) 323 324 # Join into (time, batch_size, depth) 325 s_joined = array_ops.stack(sequence) 326 327 # Reverse along dimension 0 328 s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1) 329 # Split again into list 330 result = array_ops.unstack(s_reversed) 331 for r, flat_result in zip(result, flat_results): 332 r.set_shape(input_shape) 333 flat_result.append(r) 334 335 results = [ 336 nest.pack_sequence_as(structure=input_, flat_sequence=flat_result) 337 for input_, flat_result in zip(input_seq, flat_results) 338 ] 339 return results 340 341 342@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional(" 343 "keras.layers.RNN(cell))`, which is equivalent to " 344 "this API") 345@tf_export(v1=["nn.bidirectional_dynamic_rnn"]) 346@dispatch.add_dispatch_support 347def bidirectional_dynamic_rnn(cell_fw, 348 cell_bw, 349 inputs, 350 sequence_length=None, 351 initial_state_fw=None, 352 initial_state_bw=None, 353 dtype=None, 354 parallel_iterations=None, 355 swap_memory=False, 356 time_major=False, 357 scope=None): 358 """Creates a dynamic version of bidirectional recurrent neural network. 359 360 Takes input and builds independent forward and backward RNNs. The input_size 361 of forward and backward cell must match. The initial state for both directions 362 is zero by default (but can be set optionally) and no intermediate states are 363 ever returned -- the network is fully unrolled for the given (passed in) 364 length(s) of the sequence(s) or completely unrolled if length(s) is not 365 given. 366 367 Args: 368 cell_fw: An instance of RNNCell, to be used for forward direction. 369 cell_bw: An instance of RNNCell, to be used for backward direction. 370 inputs: The RNN inputs. 371 If time_major == False (default), this must be a tensor of shape: 372 `[batch_size, max_time, ...]`, or a nested tuple of such elements. 373 If time_major == True, this must be a tensor of shape: `[max_time, 374 batch_size, ...]`, or a nested tuple of such elements. 375 sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 376 containing the actual lengths for each of the sequences in the batch. If 377 not provided, all batch entries are assumed to be full sequences; and time 378 reversal is applied from time `0` to `max_time` for each sequence. 379 initial_state_fw: (optional) An initial state for the forward RNN. This must 380 be a tensor of appropriate type and shape `[batch_size, 381 cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a 382 tuple of tensors having shapes `[batch_size, s] for s in 383 cell_fw.state_size`. 384 initial_state_bw: (optional) Same as for `initial_state_fw`, but using the 385 corresponding properties of `cell_bw`. 386 dtype: (optional) The data type for the initial states and expected output. 387 Required if initial_states are not provided or RNN states have a 388 heterogeneous dtype. 389 parallel_iterations: (Default: 32). The number of iterations to run in 390 parallel. Those operations which do not have any temporal dependency and 391 can be run in parallel, will be. This parameter trades off time for 392 space. Values >> 1 use more memory but take less time, while smaller 393 values use less memory but computations take longer. 394 swap_memory: Transparently swap the tensors produced in forward inference 395 but needed for back prop from GPU to CPU. This allows training RNNs which 396 would typically not fit on a single GPU, with very minimal (or no) 397 performance penalty. 398 time_major: The shape format of the `inputs` and `outputs` Tensors. If true, 399 these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, 400 these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using 401 `time_major = True` is a bit more efficient because it avoids transposes 402 at the beginning and end of the RNN calculation. However, most TensorFlow 403 data is batch-major, so by default this function accepts input and emits 404 output in batch-major form. 405 scope: VariableScope for the created subgraph; defaults to 406 "bidirectional_rnn" 407 408 Returns: 409 A tuple (outputs, output_states) where: 410 outputs: A tuple (output_fw, output_bw) containing the forward and 411 the backward rnn output `Tensor`. 412 If time_major == False (default), 413 output_fw will be a `Tensor` shaped: 414 `[batch_size, max_time, cell_fw.output_size]` 415 and output_bw will be a `Tensor` shaped: 416 `[batch_size, max_time, cell_bw.output_size]`. 417 If time_major == True, 418 output_fw will be a `Tensor` shaped: 419 `[max_time, batch_size, cell_fw.output_size]` 420 and output_bw will be a `Tensor` shaped: 421 `[max_time, batch_size, cell_bw.output_size]`. 422 It returns a tuple instead of a single concatenated `Tensor`, unlike 423 in the `bidirectional_rnn`. If the concatenated one is preferred, 424 the forward and backward outputs can be concatenated as 425 `tf.concat(outputs, 2)`. 426 output_states: A tuple (output_state_fw, output_state_bw) containing 427 the forward and the backward final states of bidirectional rnn. 428 429 Raises: 430 TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 431 """ 432 rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) 433 rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) 434 435 with vs.variable_scope(scope or "bidirectional_rnn"): 436 # Forward direction 437 with vs.variable_scope("fw") as fw_scope: 438 output_fw, output_state_fw = dynamic_rnn( 439 cell=cell_fw, 440 inputs=inputs, 441 sequence_length=sequence_length, 442 initial_state=initial_state_fw, 443 dtype=dtype, 444 parallel_iterations=parallel_iterations, 445 swap_memory=swap_memory, 446 time_major=time_major, 447 scope=fw_scope) 448 449 # Backward direction 450 if not time_major: 451 time_axis = 1 452 batch_axis = 0 453 else: 454 time_axis = 0 455 batch_axis = 1 456 457 def _reverse(input_, seq_lengths, seq_axis, batch_axis): 458 if seq_lengths is not None: 459 return array_ops.reverse_sequence( 460 input=input_, 461 seq_lengths=seq_lengths, 462 seq_axis=seq_axis, 463 batch_axis=batch_axis) 464 else: 465 return array_ops.reverse(input_, axis=[seq_axis]) 466 467 with vs.variable_scope("bw") as bw_scope: 468 469 def _map_reverse(inp): 470 return _reverse( 471 inp, 472 seq_lengths=sequence_length, 473 seq_axis=time_axis, 474 batch_axis=batch_axis) 475 476 inputs_reverse = nest.map_structure(_map_reverse, inputs) 477 tmp, output_state_bw = dynamic_rnn( 478 cell=cell_bw, 479 inputs=inputs_reverse, 480 sequence_length=sequence_length, 481 initial_state=initial_state_bw, 482 dtype=dtype, 483 parallel_iterations=parallel_iterations, 484 swap_memory=swap_memory, 485 time_major=time_major, 486 scope=bw_scope) 487 488 output_bw = _reverse( 489 tmp, 490 seq_lengths=sequence_length, 491 seq_axis=time_axis, 492 batch_axis=batch_axis) 493 494 outputs = (output_fw, output_bw) 495 output_states = (output_state_fw, output_state_bw) 496 497 return (outputs, output_states) 498 499 500@deprecation.deprecated( 501 None, 502 "Please use `keras.layers.RNN(cell)`, which is equivalent to this API") 503@tf_export(v1=["nn.dynamic_rnn"]) 504@dispatch.add_dispatch_support 505def dynamic_rnn(cell, 506 inputs, 507 sequence_length=None, 508 initial_state=None, 509 dtype=None, 510 parallel_iterations=None, 511 swap_memory=False, 512 time_major=False, 513 scope=None): 514 """Creates a recurrent neural network specified by RNNCell `cell`. 515 516 Performs fully dynamic unrolling of `inputs`. 517 518 Example: 519 520 ```python 521 # create a BasicRNNCell 522 rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) 523 524 # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size] 525 526 # defining initial state 527 initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32) 528 529 # 'state' is a tensor of shape [batch_size, cell_state_size] 530 outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data, 531 initial_state=initial_state, 532 dtype=tf.float32) 533 ``` 534 535 ```python 536 # create 2 LSTMCells 537 rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 538 539 # create a RNN cell composed sequentially of a number of RNNCells 540 multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers) 541 542 # 'outputs' is a tensor of shape [batch_size, max_time, 256] 543 # 'state' is a N-tuple where N is the number of LSTMCells containing a 544 # tf.nn.rnn_cell.LSTMStateTuple for each cell 545 outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell, 546 inputs=data, 547 dtype=tf.float32) 548 ``` 549 550 551 Args: 552 cell: An instance of RNNCell. 553 inputs: The RNN inputs. 554 If `time_major == False` (default), this must be a `Tensor` of shape: 555 `[batch_size, max_time, ...]`, or a nested tuple of such elements. 556 If `time_major == True`, this must be a `Tensor` of shape: `[max_time, 557 batch_size, ...]`, or a nested tuple of such elements. This may also be 558 a (possibly nested) tuple of Tensors satisfying this property. The 559 first two dimensions must match across all the inputs, but otherwise the 560 ranks and other shape components may differ. In this case, input to 561 `cell` at each time-step will replicate the structure of these tuples, 562 except for the time dimension (from which the time is taken). The input 563 to `cell` at each time step will be a `Tensor` or (possibly nested) 564 tuple of Tensors each with dimensions `[batch_size, ...]`. 565 sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used 566 to copy-through state and zero-out outputs when past a batch element's 567 sequence length. This parameter enables users to extract the last valid 568 state and properly padded outputs, so it is provided for correctness. 569 initial_state: (optional) An initial state for the RNN. If `cell.state_size` 570 is an integer, this must be a `Tensor` of appropriate type and shape 571 `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this 572 should be a tuple of tensors having shapes `[batch_size, s] for s in 573 cell.state_size`. 574 dtype: (optional) The data type for the initial state and expected output. 575 Required if initial_state is not provided or RNN state has a heterogeneous 576 dtype. 577 parallel_iterations: (Default: 32). The number of iterations to run in 578 parallel. Those operations which do not have any temporal dependency and 579 can be run in parallel, will be. This parameter trades off time for 580 space. Values >> 1 use more memory but take less time, while smaller 581 values use less memory but computations take longer. 582 swap_memory: Transparently swap the tensors produced in forward inference 583 but needed for back prop from GPU to CPU. This allows training RNNs which 584 would typically not fit on a single GPU, with very minimal (or no) 585 performance penalty. 586 time_major: The shape format of the `inputs` and `outputs` Tensors. If true, 587 these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, 588 these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using 589 `time_major = True` is a bit more efficient because it avoids transposes 590 at the beginning and end of the RNN calculation. However, most TensorFlow 591 data is batch-major, so by default this function accepts input and emits 592 output in batch-major form. 593 scope: VariableScope for the created subgraph; defaults to "rnn". 594 595 Returns: 596 A pair (outputs, state) where: 597 598 outputs: The RNN output `Tensor`. 599 600 If time_major == False (default), this will be a `Tensor` shaped: 601 `[batch_size, max_time, cell.output_size]`. 602 603 If time_major == True, this will be a `Tensor` shaped: 604 `[max_time, batch_size, cell.output_size]`. 605 606 Note, if `cell.output_size` is a (possibly nested) tuple of integers 607 or `TensorShape` objects, then `outputs` will be a tuple having the 608 same structure as `cell.output_size`, containing Tensors having shapes 609 corresponding to the shape data in `cell.output_size`. 610 611 state: The final state. If `cell.state_size` is an int, this 612 will be shaped `[batch_size, cell.state_size]`. If it is a 613 `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 614 If it is a (possibly nested) tuple of ints or `TensorShape`, this will 615 be a tuple having the corresponding shapes. If cells are `LSTMCells` 616 `state` will be a tuple containing a `LSTMStateTuple` for each cell. 617 618 Raises: 619 TypeError: If `cell` is not an instance of RNNCell. 620 ValueError: If inputs is None or an empty list. 621 622 @compatibility(TF2) 623 `tf.compat.v1.nn.dynamic_rnn` is not compatible with eager execution and 624 `tf.function`. Please use `tf.keras.layers.RNN` instead for TF2 migration. 625 Take LSTM as an example, you can instantiate a `tf.keras.layers.RNN` layer 626 with `tf.keras.layers.LSTMCell`, or directly via `tf.keras.layers.LSTM`. Once 627 the keras layer is created, you can get the output and states by calling 628 the layer with input and states. Please refer to [this 629 guide](https://www.tensorflow.org/guide/keras/rnn) for more details about 630 Keras RNN. You can also find more details about the difference and comparison 631 between Keras RNN and TF compat v1 rnn in [this 632 document](https://github.com/tensorflow/community/blob/master/rfcs/20180920-unify-rnn-interface.md) 633 634 #### Structural Mapping to Native TF2 635 636 Before: 637 638 ```python 639 # create 2 LSTMCells 640 rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 641 642 # create a RNN cell composed sequentially of a number of RNNCells 643 multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers) 644 645 # 'outputs' is a tensor of shape [batch_size, max_time, 256] 646 # 'state' is a N-tuple where N is the number of LSTMCells containing a 647 # tf.nn.rnn_cell.LSTMStateTuple for each cell 648 outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell, 649 inputs=data, 650 dtype=tf.float32) 651 ``` 652 653 After: 654 655 ```python 656 # RNN layer can take a list of cells, which will then stack them together. 657 # By default, keras RNN will only return the last timestep output and will not 658 # return states. If you need whole time sequence output as well as the states, 659 # you can set `return_sequences` and `return_state` to True. 660 rnn_layer = tf.keras.layers.RNN([tf.keras.layers.LSTMCell(128), 661 tf.keras.layers.LSTMCell(256)], 662 return_sequences=True, 663 return_state=True) 664 outputs, output_states = rnn_layer(inputs, states) 665 ``` 666 667 #### How to Map Arguments 668 669 | TF1 Arg Name | TF2 Arg Name | Note | 670 | :-------------------- | :-------------- | :------------------------------- | 671 | `cell` | `cell` | In the RNN layer constructor | 672 | `inputs` | `inputs` | In the RNN layer `__call__` | 673 | `sequence_length` | Not used | Adding masking layer before RNN : 674 : : : to achieve the same result. : 675 | `initial_state` | `initial_state` | In the RNN layer `__call__` | 676 | `dtype` | `dtype` | In the RNN layer constructor | 677 | `parallel_iterations` | Not supported | | 678 | `swap_memory` | Not supported | | 679 | `time_major` | `time_major` | In the RNN layer constructor | 680 | `scope` | Not supported | | 681 @end_compatibility 682 """ 683 rnn_cell_impl.assert_like_rnncell("cell", cell) 684 685 with vs.variable_scope(scope or "rnn") as varscope: 686 # Create a new scope in which the caching device is either 687 # determined by the parent scope, or is set to place the cached 688 # Variable using the same placement as for the rest of the RNN. 689 if _should_cache(): 690 if varscope.caching_device is None: 691 varscope.set_caching_device(lambda op: op.device) 692 693 # By default, time_major==False and inputs are batch-major: shaped 694 # [batch, time, depth] 695 # For internal calculations, we transpose to [time, batch, depth] 696 flat_input = nest.flatten(inputs) 697 698 if not time_major: 699 # (B,T,D) => (T,B,D) 700 flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] 701 flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) 702 703 parallel_iterations = parallel_iterations or 32 704 if sequence_length is not None: 705 sequence_length = math_ops.cast(sequence_length, dtypes.int32) 706 if sequence_length.get_shape().rank not in (None, 1): 707 raise ValueError( 708 "sequence_length must be a vector of length batch_size, " 709 "but saw shape: %s" % sequence_length.get_shape()) 710 sequence_length = array_ops.identity( # Just to find it in the graph. 711 sequence_length, 712 name="sequence_length") 713 714 batch_size = _best_effort_input_batch_size(flat_input) 715 716 if initial_state is not None: 717 state = initial_state 718 else: 719 if not dtype: 720 raise ValueError("If there is no initial_state, you must give a dtype.") 721 if getattr(cell, "get_initial_state", None) is not None: 722 state = cell.get_initial_state( 723 inputs=None, batch_size=batch_size, dtype=dtype) 724 else: 725 state = cell.zero_state(batch_size, dtype) 726 727 def _assert_has_shape(x, shape): 728 x_shape = array_ops.shape(x) 729 packed_shape = array_ops.stack(shape) 730 return control_flow_ops.Assert( 731 math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [ 732 "Expected shape for Tensor %s is " % x.name, packed_shape, 733 " but saw shape: ", x_shape 734 ]) 735 736 if not context.executing_eagerly() and sequence_length is not None: 737 # Perform some shape validation 738 with ops.control_dependencies( 739 [_assert_has_shape(sequence_length, [batch_size])]): 740 sequence_length = array_ops.identity( 741 sequence_length, name="CheckSeqLen") 742 743 inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) 744 745 (outputs, final_state) = _dynamic_rnn_loop( 746 cell, 747 inputs, 748 state, 749 parallel_iterations=parallel_iterations, 750 swap_memory=swap_memory, 751 sequence_length=sequence_length, 752 dtype=dtype) 753 754 # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 755 # If we are performing batch-major calculations, transpose output back 756 # to shape [batch, time, depth] 757 if not time_major: 758 # (T,B,D) => (B,T,D) 759 outputs = nest.map_structure(_transpose_batch_time, outputs) 760 761 return (outputs, final_state) 762 763 764def _dynamic_rnn_loop(cell, 765 inputs, 766 initial_state, 767 parallel_iterations, 768 swap_memory, 769 sequence_length=None, 770 dtype=None): 771 """Internal implementation of Dynamic RNN. 772 773 Args: 774 cell: An instance of RNNCell. 775 inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested 776 tuple of such elements. 777 initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if 778 `cell.state_size` is a tuple, then this should be a tuple of tensors 779 having shapes `[batch_size, s] for s in cell.state_size`. 780 parallel_iterations: Positive Python int. 781 swap_memory: A Python boolean 782 sequence_length: (optional) An `int32` `Tensor` of shape [batch_size]. 783 dtype: (optional) Expected dtype of output. If not specified, inferred from 784 initial_state. 785 786 Returns: 787 Tuple `(final_outputs, final_state)`. 788 final_outputs: 789 A `Tensor` of shape `[time, batch_size, cell.output_size]`. If 790 `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape` 791 objects, then this returns a (possibly nested) tuple of Tensors matching 792 the corresponding shapes. 793 final_state: 794 A `Tensor`, or possibly nested tuple of Tensors, matching in length 795 and shapes to `initial_state`. 796 797 Raises: 798 ValueError: If the input depth cannot be inferred via shape inference 799 from the inputs. 800 ValueError: If time_step is not the same for all the elements in the 801 inputs. 802 ValueError: If batch_size is not the same for all the elements in the 803 inputs. 804 """ 805 state = initial_state 806 assert isinstance(parallel_iterations, int), "parallel_iterations must be int" 807 808 state_size = cell.state_size 809 810 flat_input = nest.flatten(inputs) 811 flat_output_size = nest.flatten(cell.output_size) 812 813 # Construct an initial output 814 input_shape = array_ops.shape(flat_input[0]) 815 time_steps = input_shape[0] 816 batch_size = _best_effort_input_batch_size(flat_input) 817 818 inputs_got_shape = tuple( 819 input_.get_shape().with_rank_at_least(3) for input_ in flat_input) 820 821 const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2] 822 823 for shape in inputs_got_shape: 824 if not shape[2:].is_fully_defined(): 825 raise ValueError( 826 "Input size (depth of inputs) must be accessible via shape inference," 827 " but saw value None.") 828 got_time_steps = shape.dims[0].value 829 got_batch_size = shape.dims[1].value 830 if const_time_steps != got_time_steps: 831 raise ValueError( 832 "Time steps is not the same for all the elements in the input in a " 833 "batch.") 834 if const_batch_size != got_batch_size: 835 raise ValueError( 836 "Batch_size is not the same for all the elements in the input.") 837 838 # Prepare dynamic conditional copying of state & output 839 def _create_zero_arrays(size): 840 size = _concat(batch_size, size) 841 return array_ops.zeros( 842 array_ops.stack(size), _infer_state_dtype(dtype, state)) 843 844 flat_zero_output = tuple( 845 _create_zero_arrays(output) for output in flat_output_size) 846 zero_output = nest.pack_sequence_as( 847 structure=cell.output_size, flat_sequence=flat_zero_output) 848 849 if sequence_length is not None: 850 min_sequence_length = math_ops.reduce_min(sequence_length) 851 max_sequence_length = math_ops.reduce_max(sequence_length) 852 else: 853 max_sequence_length = time_steps 854 855 time = array_ops.constant(0, dtype=dtypes.int32, name="time") 856 857 with ops.name_scope("dynamic_rnn") as scope: 858 base_name = scope 859 860 def _create_ta(name, element_shape, dtype): 861 return tensor_array_ops.TensorArray( 862 dtype=dtype, 863 size=time_steps, 864 element_shape=element_shape, 865 tensor_array_name=base_name + name) 866 867 in_graph_mode = not context.executing_eagerly() 868 if in_graph_mode: 869 output_ta = tuple( 870 _create_ta( 871 "output_%d" % i, 872 element_shape=( 873 tensor_shape.TensorShape([const_batch_size]).concatenate( 874 _maybe_tensor_shape_from_tensor(out_size))), 875 dtype=_infer_state_dtype(dtype, state)) 876 for i, out_size in enumerate(flat_output_size)) 877 input_ta = tuple( 878 _create_ta( 879 "input_%d" % i, 880 element_shape=flat_input_i.shape[1:], 881 dtype=flat_input_i.dtype) 882 for i, flat_input_i in enumerate(flat_input)) 883 input_ta = tuple( 884 ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input)) 885 else: 886 output_ta = tuple([0 for _ in range(time_steps.numpy())] 887 for i in range(len(flat_output_size))) 888 input_ta = flat_input 889 890 def _time_step(time, output_ta_t, state): 891 """Take a time step of the dynamic RNN. 892 893 Args: 894 time: int32 scalar Tensor. 895 output_ta_t: List of `TensorArray`s that represent the output. 896 state: nested tuple of vector tensors that represent the state. 897 898 Returns: 899 The tuple (time + 1, output_ta_t with updated flow, new_state). 900 """ 901 902 if in_graph_mode: 903 input_t = tuple(ta.read(time) for ta in input_ta) 904 # Restore some shape information 905 for input_, shape in zip(input_t, inputs_got_shape): 906 input_.set_shape(shape[1:]) 907 else: 908 input_t = tuple(ta[time.numpy()] for ta in input_ta) 909 910 input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t) 911 # Keras RNN cells only accept state as list, even if it's a single tensor. 912 call_cell = lambda: cell(input_t, state) 913 914 if sequence_length is not None: 915 (output, new_state) = _rnn_step( 916 time=time, 917 sequence_length=sequence_length, 918 min_sequence_length=min_sequence_length, 919 max_sequence_length=max_sequence_length, 920 zero_output=zero_output, 921 state=state, 922 call_cell=call_cell, 923 state_size=state_size, 924 skip_conditionals=True) 925 else: 926 (output, new_state) = call_cell() 927 928 # Pack state if using state tuples 929 output = nest.flatten(output) 930 931 if in_graph_mode: 932 output_ta_t = tuple( 933 ta.write(time, out) for ta, out in zip(output_ta_t, output)) 934 else: 935 for ta, out in zip(output_ta_t, output): 936 ta[time.numpy()] = out 937 938 return (time + 1, output_ta_t, new_state) 939 940 if in_graph_mode: 941 # Make sure that we run at least 1 step, if necessary, to ensure 942 # the TensorArrays pick up the dynamic shape. 943 loop_bound = math_ops.minimum(time_steps, 944 math_ops.maximum(1, max_sequence_length)) 945 else: 946 # Using max_sequence_length isn't currently supported in the Eager branch. 947 loop_bound = time_steps 948 949 _, output_final_ta, final_state = control_flow_ops.while_loop( 950 cond=lambda time, *_: time < loop_bound, 951 body=_time_step, 952 loop_vars=(time, output_ta, state), 953 parallel_iterations=parallel_iterations, 954 maximum_iterations=time_steps, 955 swap_memory=swap_memory) 956 957 # Unpack final output if not using output tuples. 958 if in_graph_mode: 959 final_outputs = tuple(ta.stack() for ta in output_final_ta) 960 # Restore some shape information 961 for output, output_size in zip(final_outputs, flat_output_size): 962 shape = _concat([const_time_steps, const_batch_size], 963 output_size, 964 static=True) 965 output.set_shape(shape) 966 else: 967 final_outputs = output_final_ta 968 969 final_outputs = nest.pack_sequence_as( 970 structure=cell.output_size, flat_sequence=final_outputs) 971 if not in_graph_mode: 972 final_outputs = nest.map_structure_up_to( 973 cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs) 974 975 return (final_outputs, final_state) 976 977 978@tf_export(v1=["nn.raw_rnn"]) 979@dispatch.add_dispatch_support 980def raw_rnn(cell, 981 loop_fn, 982 parallel_iterations=None, 983 swap_memory=False, 984 scope=None): 985 """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`. 986 987 **NOTE: This method is still in testing, and the API may change.** 988 989 This function is a more primitive version of `dynamic_rnn` that provides 990 more direct access to the inputs each iteration. It also provides more 991 control over when to start and finish reading the sequence, and 992 what to emit for the output. 993 994 For example, it can be used to implement the dynamic decoder of a seq2seq 995 model. 996 997 Instead of working with `Tensor` objects, most operations work with 998 `TensorArray` objects directly. 999 1000 The operation of `raw_rnn`, in pseudo-code, is basically the following: 1001 1002 ```python 1003 time = tf.constant(0, dtype=tf.int32) 1004 (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn( 1005 time=time, cell_output=None, cell_state=None, loop_state=None) 1006 emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype) 1007 state = initial_state 1008 while not all(finished): 1009 (output, cell_state) = cell(next_input, state) 1010 (next_finished, next_input, next_state, emit, loop_state) = loop_fn( 1011 time=time + 1, cell_output=output, cell_state=cell_state, 1012 loop_state=loop_state) 1013 # Emit zeros and copy forward state for minibatch entries that are finished. 1014 state = tf.where(finished, state, next_state) 1015 emit = tf.where(finished, tf.zeros_like(emit_structure), emit) 1016 emit_ta = emit_ta.write(time, emit) 1017 # If any new minibatch entries are marked as finished, mark these. 1018 finished = tf.logical_or(finished, next_finished) 1019 time += 1 1020 return (emit_ta, state, loop_state) 1021 ``` 1022 1023 with the additional properties that output and state may be (possibly nested) 1024 tuples, as determined by `cell.output_size` and `cell.state_size`, and 1025 as a result the final `state` and `emit_ta` may themselves be tuples. 1026 1027 A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this: 1028 1029 ```python 1030 inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth), 1031 dtype=tf.float32) 1032 sequence_length = tf.compat.v1.placeholder(shape=(batch_size,), 1033 dtype=tf.int32) 1034 inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time) 1035 inputs_ta = inputs_ta.unstack(inputs) 1036 1037 cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units) 1038 1039 def loop_fn(time, cell_output, cell_state, loop_state): 1040 emit_output = cell_output # == None for time == 0 1041 if cell_output is None: # time == 0 1042 next_cell_state = cell.zero_state(batch_size, tf.float32) 1043 else: 1044 next_cell_state = cell_state 1045 elements_finished = (time >= sequence_length) 1046 finished = tf.reduce_all(elements_finished) 1047 next_input = tf.cond( 1048 finished, 1049 lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32), 1050 lambda: inputs_ta.read(time)) 1051 next_loop_state = None 1052 return (elements_finished, next_input, next_cell_state, 1053 emit_output, next_loop_state) 1054 1055 outputs_ta, final_state, _ = raw_rnn(cell, loop_fn) 1056 outputs = outputs_ta.stack() 1057 ``` 1058 1059 Args: 1060 cell: An instance of RNNCell. 1061 loop_fn: A callable that takes inputs `(time, cell_output, cell_state, 1062 loop_state)` and returns the tuple `(finished, next_input, 1063 next_cell_state, emit_output, next_loop_state)`. Here `time` is an int32 1064 scalar `Tensor`, `cell_output` is a `Tensor` or (possibly nested) tuple of 1065 tensors as determined by `cell.output_size`, and `cell_state` is a 1066 `Tensor` or (possibly nested) tuple of tensors, as determined by the 1067 `loop_fn` on its first call (and should match `cell.state_size`). 1068 The outputs are: `finished`, a boolean `Tensor` of 1069 shape `[batch_size]`, `next_input`: the next input to feed to `cell`, 1070 `next_cell_state`: the next state to feed to `cell`, 1071 and `emit_output`: the output to store for this iteration. Note that 1072 `emit_output` should be a `Tensor` or (possibly nested) tuple of tensors 1073 which is aggregated in the `emit_ta` inside the `while_loop`. For the 1074 first call to `loop_fn`, the `emit_output` corresponds to the 1075 `emit_structure` which is then used to determine the size of the 1076 `zero_tensor` for the `emit_ta` (defaults to `cell.output_size`). For 1077 the subsequent calls to the `loop_fn`, the `emit_output` corresponds to 1078 the actual output tensor that is to be aggregated in the `emit_ta`. The 1079 parameter `cell_state` and output `next_cell_state` may be either a 1080 single or (possibly nested) tuple of tensors. The parameter 1081 `loop_state` and output `next_loop_state` may be either a single or 1082 (possibly nested) tuple of `Tensor` and `TensorArray` objects. This 1083 last parameter may be ignored by `loop_fn` and the return value may be 1084 `None`. If it is not `None`, then the `loop_state` will be propagated 1085 through the RNN loop, for use purely by `loop_fn` to keep track of its 1086 own state. The `next_loop_state` parameter returned may be `None`. The 1087 first call to `loop_fn` will be `time = 0`, `cell_output = None`, 1088 `cell_state = None`, and `loop_state = None`. For this call: The 1089 `next_cell_state` value should be the value with which to initialize the 1090 cell's state. It may be a final state from a previous RNN or it may be 1091 the output of `cell.zero_state()`. It should be a (possibly nested) 1092 tuple structure of tensors. If `cell.state_size` is an integer, this 1093 must be a `Tensor` of appropriate type and shape `[batch_size, 1094 cell.state_size]`. If `cell.state_size` is a `TensorShape`, this must be 1095 a `Tensor` of appropriate type and shape `[batch_size] + 1096 cell.state_size`. If `cell.state_size` is a (possibly nested) tuple of 1097 ints or `TensorShape`, this will be a tuple having the corresponding 1098 shapes. The `emit_output` value may be either `None` or a (possibly 1099 nested) tuple structure of tensors, e.g., `(tf.zeros(shape_0, 1100 dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`. If this first 1101 `emit_output` return value is `None`, then the `emit_ta` result of 1102 `raw_rnn` will have the same structure and dtypes as `cell.output_size`. 1103 Otherwise `emit_ta` will have the same structure, shapes (prepended with 1104 a `batch_size` dimension), and dtypes as `emit_output`. The actual 1105 values returned for `emit_output` at this initializing call are ignored. 1106 Note, this emit structure must be consistent across all time steps. 1107 parallel_iterations: (Default: 32). The number of iterations to run in 1108 parallel. Those operations which do not have any temporal dependency and 1109 can be run in parallel, will be. This parameter trades off time for 1110 space. Values >> 1 use more memory but take less time, while smaller 1111 values use less memory but computations take longer. 1112 swap_memory: Transparently swap the tensors produced in forward inference 1113 but needed for back prop from GPU to CPU. This allows training RNNs which 1114 would typically not fit on a single GPU, with very minimal (or no) 1115 performance penalty. 1116 scope: VariableScope for the created subgraph; defaults to "rnn". 1117 1118 Returns: 1119 A tuple `(emit_ta, final_state, final_loop_state)` where: 1120 1121 `emit_ta`: The RNN output `TensorArray`. 1122 If `loop_fn` returns a (possibly nested) set of Tensors for 1123 `emit_output` during initialization, (inputs `time = 0`, 1124 `cell_output = None`, and `loop_state = None`), then `emit_ta` will 1125 have the same structure, dtypes, and shapes as `emit_output` instead. 1126 If `loop_fn` returns `emit_output = None` during this call, 1127 the structure of `cell.output_size` is used: 1128 If `cell.output_size` is a (possibly nested) tuple of integers 1129 or `TensorShape` objects, then `emit_ta` will be a tuple having the 1130 same structure as `cell.output_size`, containing TensorArrays whose 1131 elements' shapes correspond to the shape data in `cell.output_size`. 1132 1133 `final_state`: The final cell state. If `cell.state_size` is an int, this 1134 will be shaped `[batch_size, cell.state_size]`. If it is a 1135 `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 1136 If it is a (possibly nested) tuple of ints or `TensorShape`, this will 1137 be a tuple having the corresponding shapes. 1138 1139 `final_loop_state`: The final loop state as returned by `loop_fn`. 1140 1141 Raises: 1142 TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not 1143 a `callable`. 1144 """ 1145 rnn_cell_impl.assert_like_rnncell("cell", cell) 1146 1147 if not callable(loop_fn): 1148 raise TypeError("loop_fn must be a callable") 1149 1150 parallel_iterations = parallel_iterations or 32 1151 1152 # Create a new scope in which the caching device is either 1153 # determined by the parent scope, or is set to place the cached 1154 # Variable using the same placement as for the rest of the RNN. 1155 with vs.variable_scope(scope or "rnn") as varscope: 1156 if _should_cache(): 1157 if varscope.caching_device is None: 1158 varscope.set_caching_device(lambda op: op.device) 1159 1160 time = constant_op.constant(0, dtype=dtypes.int32) 1161 (elements_finished, next_input, 1162 initial_state, emit_structure, init_loop_state) = loop_fn( 1163 time, None, None, None) # time, cell_output, cell_state, loop_state 1164 flat_input = nest.flatten(next_input) 1165 1166 # Need a surrogate loop state for the while_loop if none is available. 1167 loop_state = ( 1168 init_loop_state if init_loop_state is not None else 1169 constant_op.constant(0, dtype=dtypes.int32)) 1170 1171 input_shape = [input_.get_shape() for input_ in flat_input] 1172 static_batch_size = tensor_shape.dimension_at_index(input_shape[0], 0) 1173 1174 for input_shape_i in input_shape: 1175 # Static verification that batch sizes all match 1176 static_batch_size.assert_is_compatible_with( 1177 tensor_shape.dimension_at_index(input_shape_i, 0)) 1178 1179 batch_size = tensor_shape.dimension_value(static_batch_size) 1180 const_batch_size = batch_size 1181 if batch_size is None: 1182 batch_size = array_ops.shape(flat_input[0])[0] 1183 1184 nest.assert_same_structure(initial_state, cell.state_size) 1185 state = initial_state 1186 flat_state = nest.flatten(state) 1187 flat_state = [ops.convert_to_tensor(s) for s in flat_state] 1188 state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state) 1189 1190 if emit_structure is not None: 1191 flat_emit_structure = nest.flatten(emit_structure) 1192 flat_emit_size = [ 1193 emit.shape if emit.shape.is_fully_defined() else array_ops.shape(emit) 1194 for emit in flat_emit_structure 1195 ] 1196 flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure] 1197 else: 1198 emit_structure = cell.output_size 1199 flat_emit_size = nest.flatten(emit_structure) 1200 flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size) 1201 1202 flat_emit_ta = [ 1203 tensor_array_ops.TensorArray( 1204 dtype=dtype_i, 1205 dynamic_size=True, 1206 element_shape=(tensor_shape.TensorShape([ 1207 const_batch_size 1208 ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))), 1209 size=0, 1210 name="rnn_output_%d" % i) 1211 for i, (dtype_i, 1212 size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size)) 1213 ] 1214 emit_ta = nest.pack_sequence_as( 1215 structure=emit_structure, flat_sequence=flat_emit_ta) 1216 flat_zero_emit = [ 1217 array_ops.zeros(_concat(batch_size, size_i), dtype_i) 1218 for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes) 1219 ] 1220 zero_emit = nest.pack_sequence_as( 1221 structure=emit_structure, flat_sequence=flat_zero_emit) 1222 1223 def condition(unused_time, elements_finished, *_): 1224 return math_ops.logical_not(math_ops.reduce_all(elements_finished)) 1225 1226 def body(time, elements_finished, current_input, emit_ta, state, 1227 loop_state): 1228 """Internal while loop body for raw_rnn. 1229 1230 Args: 1231 time: time scalar. 1232 elements_finished: batch-size vector. 1233 current_input: possibly nested tuple of input tensors. 1234 emit_ta: possibly nested tuple of output TensorArrays. 1235 state: possibly nested tuple of state tensors. 1236 loop_state: possibly nested tuple of loop state tensors. 1237 1238 Returns: 1239 Tuple having the same size as Args but with updated values. 1240 """ 1241 (next_output, cell_state) = cell(current_input, state) 1242 1243 nest.assert_same_structure(state, cell_state) 1244 nest.assert_same_structure(cell.output_size, next_output) 1245 1246 next_time = time + 1 1247 (next_finished, next_input, next_state, emit_output, 1248 next_loop_state) = loop_fn(next_time, next_output, cell_state, 1249 loop_state) 1250 1251 nest.assert_same_structure(state, next_state) 1252 nest.assert_same_structure(current_input, next_input) 1253 nest.assert_same_structure(emit_ta, emit_output) 1254 1255 # If loop_fn returns None for next_loop_state, just reuse the 1256 # previous one. 1257 loop_state = loop_state if next_loop_state is None else next_loop_state 1258 1259 def _copy_some_through(current, candidate): 1260 """Copy some tensors through via array_ops.where.""" 1261 1262 def copy_fn(cur_i, cand_i): 1263 # TensorArray and scalar get passed through. 1264 if isinstance(cur_i, tensor_array_ops.TensorArray): 1265 return cand_i 1266 if cur_i.shape.rank == 0: 1267 return cand_i 1268 # Otherwise propagate the old or the new value. 1269 with ops.colocate_with(cand_i): 1270 return array_ops.where(elements_finished, cur_i, cand_i) 1271 1272 return nest.map_structure(copy_fn, current, candidate) 1273 1274 emit_output = _copy_some_through(zero_emit, emit_output) 1275 next_state = _copy_some_through(state, next_state) 1276 1277 emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), 1278 emit_ta, emit_output) 1279 1280 elements_finished = math_ops.logical_or(elements_finished, next_finished) 1281 1282 return (next_time, elements_finished, next_input, emit_ta, next_state, 1283 loop_state) 1284 1285 returned = control_flow_ops.while_loop( 1286 condition, 1287 body, 1288 loop_vars=[ 1289 time, elements_finished, next_input, emit_ta, state, loop_state 1290 ], 1291 parallel_iterations=parallel_iterations, 1292 swap_memory=swap_memory) 1293 1294 (emit_ta, final_state, final_loop_state) = returned[-3:] 1295 1296 if init_loop_state is None: 1297 final_loop_state = None 1298 1299 return (emit_ta, final_state, final_loop_state) 1300 1301 1302@deprecation.deprecated(None, 1303 "Please use `keras.layers.RNN(cell, unroll=True)`, " 1304 "which is equivalent to this API") 1305@tf_export(v1=["nn.static_rnn"]) 1306@dispatch.add_dispatch_support 1307def static_rnn(cell, 1308 inputs, 1309 initial_state=None, 1310 dtype=None, 1311 sequence_length=None, 1312 scope=None): 1313 """Creates a recurrent neural network specified by RNNCell `cell`. 1314 1315 The simplest form of RNN network generated is: 1316 1317 ```python 1318 state = cell.zero_state(...) 1319 outputs = [] 1320 for input_ in inputs: 1321 output, state = cell(input_, state) 1322 outputs.append(output) 1323 return (outputs, state) 1324 ``` 1325 However, a few other options are available: 1326 1327 An initial state can be provided. 1328 If the sequence_length vector is provided, dynamic calculation is performed. 1329 This method of calculation does not compute the RNN steps past the maximum 1330 sequence length of the minibatch (thus saving computational time), 1331 and properly propagates the state at an example's sequence length 1332 to the final state output. 1333 1334 The dynamic calculation performed is, at time `t` for batch row `b`, 1335 1336 ```python 1337 (output, state)(b, t) = 1338 (t >= sequence_length(b)) 1339 ? (zeros(cell.output_size), states(b, sequence_length(b) - 1)) 1340 : cell(input(b, t), state(b, t - 1)) 1341 ``` 1342 1343 Args: 1344 cell: An instance of RNNCell. 1345 inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size, 1346 input_size]`, or a nested tuple of such elements. 1347 initial_state: (optional) An initial state for the RNN. If `cell.state_size` 1348 is an integer, this must be a `Tensor` of appropriate type and shape 1349 `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this 1350 should be a tuple of tensors having shapes `[batch_size, s] for s in 1351 cell.state_size`. 1352 dtype: (optional) The data type for the initial state and expected output. 1353 Required if initial_state is not provided or RNN state has a heterogeneous 1354 dtype. 1355 sequence_length: Specifies the length of each sequence in inputs. An int32 1356 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`. 1357 scope: VariableScope for the created subgraph; defaults to "rnn". 1358 1359 Returns: 1360 A pair (outputs, state) where: 1361 1362 - outputs is a length T list of outputs (one for each input), or a nested 1363 tuple of such elements. 1364 - state is the final state 1365 1366 Raises: 1367 TypeError: If `cell` is not an instance of RNNCell. 1368 ValueError: If `inputs` is `None` or an empty list, or if the input depth 1369 (column size) cannot be inferred from inputs via shape inference. 1370 """ 1371 rnn_cell_impl.assert_like_rnncell("cell", cell) 1372 if not nest.is_sequence(inputs): 1373 raise TypeError("inputs must be a sequence") 1374 if not inputs: 1375 raise ValueError("inputs must not be empty") 1376 1377 outputs = [] 1378 # Create a new scope in which the caching device is either 1379 # determined by the parent scope, or is set to place the cached 1380 # Variable using the same placement as for the rest of the RNN. 1381 with vs.variable_scope(scope or "rnn") as varscope: 1382 if _should_cache(): 1383 if varscope.caching_device is None: 1384 varscope.set_caching_device(lambda op: op.device) 1385 1386 # Obtain the first sequence of the input 1387 first_input = inputs 1388 while nest.is_sequence(first_input): 1389 first_input = first_input[0] 1390 1391 # Temporarily avoid EmbeddingWrapper and seq2seq badness 1392 # TODO(lukaszkaiser): remove EmbeddingWrapper 1393 if first_input.get_shape().rank != 1: 1394 1395 input_shape = first_input.get_shape().with_rank_at_least(2) 1396 fixed_batch_size = input_shape.dims[0] 1397 1398 flat_inputs = nest.flatten(inputs) 1399 for flat_input in flat_inputs: 1400 input_shape = flat_input.get_shape().with_rank_at_least(2) 1401 batch_size, input_size = tensor_shape.dimension_at_index( 1402 input_shape, 0), input_shape[1:] 1403 fixed_batch_size.assert_is_compatible_with(batch_size) 1404 for i, size in enumerate(input_size.dims): 1405 if tensor_shape.dimension_value(size) is None: 1406 raise ValueError( 1407 "Input size (dimension %d of inputs) must be accessible via " 1408 "shape inference, but saw value None." % i) 1409 else: 1410 fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0] 1411 1412 if tensor_shape.dimension_value(fixed_batch_size): 1413 batch_size = tensor_shape.dimension_value(fixed_batch_size) 1414 else: 1415 batch_size = array_ops.shape(first_input)[0] 1416 if initial_state is not None: 1417 state = initial_state 1418 else: 1419 if not dtype: 1420 raise ValueError("If no initial_state is provided, " 1421 "dtype must be specified") 1422 if getattr(cell, "get_initial_state", None) is not None: 1423 state = cell.get_initial_state( 1424 inputs=None, batch_size=batch_size, dtype=dtype) 1425 else: 1426 state = cell.zero_state(batch_size, dtype) 1427 1428 if sequence_length is not None: # Prepare variables 1429 sequence_length = ops.convert_to_tensor( 1430 sequence_length, name="sequence_length") 1431 if sequence_length.get_shape().rank not in (None, 1): 1432 raise ValueError( 1433 "sequence_length must be a vector of length batch_size") 1434 1435 def _create_zero_output(output_size): 1436 # convert int to TensorShape if necessary 1437 size = _concat(batch_size, output_size) 1438 output = array_ops.zeros( 1439 array_ops.stack(size), _infer_state_dtype(dtype, state)) 1440 shape = _concat( 1441 tensor_shape.dimension_value(fixed_batch_size), 1442 output_size, 1443 static=True) 1444 output.set_shape(tensor_shape.TensorShape(shape)) 1445 return output 1446 1447 output_size = cell.output_size 1448 flat_output_size = nest.flatten(output_size) 1449 flat_zero_output = tuple( 1450 _create_zero_output(size) for size in flat_output_size) 1451 zero_output = nest.pack_sequence_as( 1452 structure=output_size, flat_sequence=flat_zero_output) 1453 1454 sequence_length = math_ops.cast(sequence_length, dtypes.int32) 1455 min_sequence_length = math_ops.reduce_min(sequence_length) 1456 max_sequence_length = math_ops.reduce_max(sequence_length) 1457 1458 for time, input_ in enumerate(inputs): 1459 if time > 0: 1460 varscope.reuse_variables() 1461 # pylint: disable=cell-var-from-loop 1462 call_cell = lambda: cell(input_, state) 1463 # pylint: enable=cell-var-from-loop 1464 if sequence_length is not None: 1465 (output, state) = _rnn_step( 1466 time=time, 1467 sequence_length=sequence_length, 1468 min_sequence_length=min_sequence_length, 1469 max_sequence_length=max_sequence_length, 1470 zero_output=zero_output, 1471 state=state, 1472 call_cell=call_cell, 1473 state_size=cell.state_size) 1474 else: 1475 (output, state) = call_cell() 1476 outputs.append(output) 1477 1478 return (outputs, state) 1479 1480 1481@deprecation.deprecated(None, 1482 "Please use `keras.layers.RNN(cell, stateful=True)`, " 1483 "which is equivalent to this API") 1484@tf_export(v1=["nn.static_state_saving_rnn"]) 1485@dispatch.add_dispatch_support 1486def static_state_saving_rnn(cell, 1487 inputs, 1488 state_saver, 1489 state_name, 1490 sequence_length=None, 1491 scope=None): 1492 """RNN that accepts a state saver for time-truncated RNN calculation. 1493 1494 Args: 1495 cell: An instance of `RNNCell`. 1496 inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size, 1497 input_size]`. 1498 state_saver: A state saver object with methods `state` and `save_state`. 1499 state_name: Python string or tuple of strings. The name to use with the 1500 state_saver. If the cell returns tuples of states (i.e., `cell.state_size` 1501 is a tuple) then `state_name` should be a tuple of strings having the same 1502 length as `cell.state_size`. Otherwise it should be a single string. 1503 sequence_length: (optional) An int32/int64 vector size [batch_size]. See the 1504 documentation for rnn() for more details about sequence_length. 1505 scope: VariableScope for the created subgraph; defaults to "rnn". 1506 1507 Returns: 1508 A pair (outputs, state) where: 1509 outputs is a length T list of outputs (one for each input) 1510 states is the final state 1511 1512 Raises: 1513 TypeError: If `cell` is not an instance of RNNCell. 1514 ValueError: If `inputs` is `None` or an empty list, or if the arity and 1515 type of `state_name` does not match that of `cell.state_size`. 1516 """ 1517 state_size = cell.state_size 1518 state_is_tuple = nest.is_sequence(state_size) 1519 state_name_tuple = nest.is_sequence(state_name) 1520 1521 if state_is_tuple != state_name_tuple: 1522 raise ValueError("state_name should be the same type as cell.state_size. " 1523 "state_name: %s, cell.state_size: %s" % 1524 (str(state_name), str(state_size))) 1525 1526 if state_is_tuple: 1527 state_name_flat = nest.flatten(state_name) 1528 state_size_flat = nest.flatten(state_size) 1529 1530 if len(state_name_flat) != len(state_size_flat): 1531 raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" % 1532 (len(state_name_flat), len(state_size_flat))) 1533 1534 initial_state = nest.pack_sequence_as( 1535 structure=state_size, 1536 flat_sequence=[state_saver.state(s) for s in state_name_flat]) 1537 else: 1538 initial_state = state_saver.state(state_name) 1539 1540 (outputs, state) = static_rnn( 1541 cell, 1542 inputs, 1543 initial_state=initial_state, 1544 sequence_length=sequence_length, 1545 scope=scope) 1546 1547 if state_is_tuple: 1548 flat_state = nest.flatten(state) 1549 state_name = nest.flatten(state_name) 1550 save_state = [ 1551 state_saver.save_state(name, substate) 1552 for name, substate in zip(state_name, flat_state) 1553 ] 1554 else: 1555 save_state = [state_saver.save_state(state_name, state)] 1556 1557 with ops.control_dependencies(save_state): 1558 last_output = outputs[-1] 1559 flat_last_output = nest.flatten(last_output) 1560 flat_last_output = [ 1561 array_ops.identity(output) for output in flat_last_output 1562 ] 1563 outputs[-1] = nest.pack_sequence_as( 1564 structure=last_output, flat_sequence=flat_last_output) 1565 1566 if state_is_tuple: 1567 state = nest.pack_sequence_as( 1568 structure=state, 1569 flat_sequence=[array_ops.identity(s) for s in flat_state]) 1570 else: 1571 state = array_ops.identity(state) 1572 1573 return (outputs, state) 1574 1575 1576@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional(" 1577 "keras.layers.RNN(cell, unroll=True))`, which is " 1578 "equivalent to this API") 1579@tf_export(v1=["nn.static_bidirectional_rnn"]) 1580@dispatch.add_dispatch_support 1581def static_bidirectional_rnn(cell_fw, 1582 cell_bw, 1583 inputs, 1584 initial_state_fw=None, 1585 initial_state_bw=None, 1586 dtype=None, 1587 sequence_length=None, 1588 scope=None): 1589 """Creates a bidirectional recurrent neural network. 1590 1591 Similar to the unidirectional case above (rnn) but takes input and builds 1592 independent forward and backward RNNs with the final forward and backward 1593 outputs depth-concatenated, such that the output will have the format 1594 [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of 1595 forward and backward cell must match. The initial state for both directions 1596 is zero by default (but can be set optionally) and no intermediate states are 1597 ever returned -- the network is fully unrolled for the given (passed in) 1598 length(s) of the sequence(s) or completely unrolled if length(s) is not given. 1599 1600 Args: 1601 cell_fw: An instance of RNNCell, to be used for forward direction. 1602 cell_bw: An instance of RNNCell, to be used for backward direction. 1603 inputs: A length T list of inputs, each a tensor of shape [batch_size, 1604 input_size], or a nested tuple of such elements. 1605 initial_state_fw: (optional) An initial state for the forward RNN. This must 1606 be a tensor of appropriate type and shape `[batch_size, 1607 cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a 1608 tuple of tensors having shapes `[batch_size, s] for s in 1609 cell_fw.state_size`. 1610 initial_state_bw: (optional) Same as for `initial_state_fw`, but using the 1611 corresponding properties of `cell_bw`. 1612 dtype: (optional) The data type for the initial state. Required if either 1613 of the initial states are not provided. 1614 sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 1615 containing the actual lengths for each of the sequences. 1616 scope: VariableScope for the created subgraph; defaults to 1617 "bidirectional_rnn" 1618 1619 Returns: 1620 A tuple (outputs, output_state_fw, output_state_bw) where: 1621 outputs is a length `T` list of outputs (one for each input), which 1622 are depth-concatenated forward and backward outputs. 1623 output_state_fw is the final state of the forward rnn. 1624 output_state_bw is the final state of the backward rnn. 1625 1626 Raises: 1627 TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 1628 ValueError: If inputs is None or an empty list. 1629 """ 1630 rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) 1631 rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) 1632 if not nest.is_sequence(inputs): 1633 raise TypeError("inputs must be a sequence") 1634 if not inputs: 1635 raise ValueError("inputs must not be empty") 1636 1637 with vs.variable_scope(scope or "bidirectional_rnn"): 1638 # Forward direction 1639 with vs.variable_scope("fw") as fw_scope: 1640 output_fw, output_state_fw = static_rnn( 1641 cell_fw, 1642 inputs, 1643 initial_state_fw, 1644 dtype, 1645 sequence_length, 1646 scope=fw_scope) 1647 1648 # Backward direction 1649 with vs.variable_scope("bw") as bw_scope: 1650 reversed_inputs = _reverse_seq(inputs, sequence_length) 1651 tmp, output_state_bw = static_rnn( 1652 cell_bw, 1653 reversed_inputs, 1654 initial_state_bw, 1655 dtype, 1656 sequence_length, 1657 scope=bw_scope) 1658 1659 output_bw = _reverse_seq(tmp, sequence_length) 1660 # Concat each of the forward/backward outputs 1661 flat_output_fw = nest.flatten(output_fw) 1662 flat_output_bw = nest.flatten(output_bw) 1663 1664 flat_outputs = tuple( 1665 array_ops.concat([fw, bw], 1) 1666 for fw, bw in zip(flat_output_fw, flat_output_bw)) 1667 1668 outputs = nest.pack_sequence_as( 1669 structure=output_fw, flat_sequence=flat_outputs) 1670 1671 return (outputs, output_state_fw, output_state_bw) 1672