1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""RNN helpers for TensorFlow models.""" 16from tensorflow.python.eager import context 17from tensorflow.python.framework import constant_op 18from tensorflow.python.framework import dtypes 19from tensorflow.python.framework import ops 20from tensorflow.python.framework import tensor_shape 21from tensorflow.python.framework import tensor_util 22from tensorflow.python.ops import array_ops 23from tensorflow.python.ops import control_flow_ops 24from tensorflow.python.ops import control_flow_util 25from tensorflow.python.ops import control_flow_util_v2 26from tensorflow.python.ops import math_ops 27from tensorflow.python.ops import rnn_cell_impl 28from tensorflow.python.ops import tensor_array_ops 29from tensorflow.python.ops import variable_scope as vs 30from tensorflow.python.util import deprecation 31from tensorflow.python.util import dispatch 32from tensorflow.python.util import nest 33from tensorflow.python.util.tf_export import tf_export 34 35# pylint: disable=protected-access 36_concat = rnn_cell_impl._concat 37# pylint: enable=protected-access 38 39 40def _transpose_batch_time(x): 41 """Transposes the batch and time dimensions of a Tensor. 42 43 If the input tensor has rank < 2 it returns the original tensor. 
Retains as 44 much of the static shape information as possible. 45 46 Args: 47 x: A Tensor. 48 49 Returns: 50 x transposed along the first two dimensions. 51 """ 52 x_static_shape = x.get_shape() 53 if x_static_shape.rank is not None and x_static_shape.rank < 2: 54 return x 55 56 x_rank = array_ops.rank(x) 57 x_t = array_ops.transpose( 58 x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0)) 59 x_t.set_shape( 60 tensor_shape.TensorShape( 61 [x_static_shape.dims[1].value, 62 x_static_shape.dims[0].value]).concatenate(x_static_shape[2:])) 63 return x_t 64 65 66def _best_effort_input_batch_size(flat_input): 67 """Get static input batch size if available, with fallback to the dynamic one. 68 69 Args: 70 flat_input: An iterable of time major input Tensors of shape `[max_time, 71 batch_size, ...]`. All inputs should have compatible batch sizes. 72 73 Returns: 74 The batch size in Python integer if available, or a scalar Tensor otherwise. 75 76 Raises: 77 ValueError: if there is any input with an invalid shape. 78 """ 79 for input_ in flat_input: 80 shape = input_.shape 81 if shape.rank is None: 82 continue 83 if shape.rank < 2: 84 raise ValueError("Input tensor should have rank >= 2. Received input=" 85 f"{input_} of rank {shape.rank}") 86 batch_size = shape.dims[1].value 87 if batch_size is not None: 88 return batch_size 89 # Fallback to the dynamic batch size of the first input. 90 return array_ops.shape(flat_input[0])[1] 91 92 93def _infer_state_dtype(explicit_dtype, state): 94 """Infer the dtype of an RNN state. 95 96 Args: 97 explicit_dtype: explicitly declared dtype or None. 98 state: RNN's hidden state. Must be a Tensor or a nested iterable containing 99 Tensors. 100 101 Returns: 102 dtype: inferred dtype of hidden state. 103 104 Raises: 105 ValueError: if `state` has heterogeneous dtypes or is empty. 
106 """ 107 if explicit_dtype is not None: 108 return explicit_dtype 109 elif nest.is_nested(state): 110 inferred_dtypes = [element.dtype for element in nest.flatten(state)] 111 if not inferred_dtypes: 112 raise ValueError(f"Unable to infer dtype from argument state={state}.") 113 all_same = all(x == inferred_dtypes[0] for x in inferred_dtypes) 114 if not all_same: 115 raise ValueError( 116 f"Argument state={state} has tensors of different inferred dtypes. " 117 "Unable to infer a single representative dtype. Dtypes received: " 118 f"{inferred_dtypes}") 119 return inferred_dtypes[0] 120 else: 121 return state.dtype 122 123 124def _maybe_tensor_shape_from_tensor(shape): 125 if isinstance(shape, ops.Tensor): 126 return tensor_shape.as_shape(tensor_util.constant_value(shape)) 127 else: 128 return shape 129 130 131def _should_cache(): 132 """Returns True if a default caching device should be set, otherwise False.""" 133 if context.executing_eagerly(): 134 return False 135 # Don't set a caching device when running in a loop, since it is possible that 136 # train steps could be wrapped in a tf.while_loop. In that scenario caching 137 # prevents forward computations in loop iterations from re-reading the 138 # updated weights. 139 graph = ops.get_default_graph() 140 ctxt = graph._get_control_flow_context() # pylint: disable=protected-access 141 in_v1_while_loop = ( 142 control_flow_util.GetContainingWhileContext(ctxt) is not None) 143 in_v2_while_loop = control_flow_util_v2.in_while_loop_defun(graph) 144 return not in_v1_while_loop and not in_v2_while_loop 145 146 147# pylint: disable=unused-argument 148def _rnn_step(time, 149 sequence_length, 150 min_sequence_length, 151 max_sequence_length, 152 zero_output, 153 state, 154 call_cell, 155 state_size, 156 skip_conditionals=False): 157 """Calculate one step of a dynamic RNN minibatch. 158 159 Returns an (output, state) pair conditioned on `sequence_length`. 
160 When skip_conditionals=False, the pseudocode is something like: 161 162 if t >= max_sequence_length: 163 return (zero_output, state) 164 if t < min_sequence_length: 165 return call_cell() 166 167 # Selectively output zeros or output, old state or new state depending 168 # on whether we've finished calculating each row. 169 new_output, new_state = call_cell() 170 final_output = np.vstack([ 171 zero_output if time >= sequence_length[r] else new_output_r 172 for r, new_output_r in enumerate(new_output) 173 ]) 174 final_state = np.vstack([ 175 state[r] if time >= sequence_length[r] else new_state_r 176 for r, new_state_r in enumerate(new_state) 177 ]) 178 return (final_output, final_state) 179 180 Args: 181 time: int32 `Tensor` scalar. 182 sequence_length: int32 `Tensor` vector of size [batch_size]. 183 min_sequence_length: int32 `Tensor` scalar, min of sequence_length. 184 max_sequence_length: int32 `Tensor` scalar, max of sequence_length. 185 zero_output: `Tensor` vector of shape [output_size]. 186 state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`, 187 or a list/tuple of such tensors. 188 call_cell: lambda returning tuple of (new_output, new_state) where 189 new_output is a `Tensor` matrix of shape `[batch_size, output_size]`. 190 new_state is a `Tensor` matrix of shape `[batch_size, state_size]`. 191 state_size: The `cell.state_size` associated with the state. 192 skip_conditionals: Python bool, whether to skip using the conditional 193 calculations. This is useful for `dynamic_rnn`, where the input tensor 194 matches `max_sequence_length`, and using conditionals just slows 195 everything down. 196 197 Returns: 198 A tuple of (`final_output`, `final_state`) as given by the pseudocode above: 199 final_output is a `Tensor` matrix of shape [batch_size, output_size] 200 final_state is either a single `Tensor` matrix, or a tuple of such 201 matrices (matching length and shapes of input `state`). 
202 203 Raises: 204 ValueError: If the cell returns a state tuple whose length does not match 205 that returned by `state_size`. 206 """ 207 208 # Convert state to a list for ease of use 209 flat_state = nest.flatten(state) 210 flat_zero_output = nest.flatten(zero_output) 211 212 # Vector describing which batch entries are finished. 213 copy_cond = time >= sequence_length 214 215 def _copy_one_through(output, new_output): 216 # TensorArray and scalar get passed through. 217 if isinstance(output, tensor_array_ops.TensorArray): 218 return new_output 219 if output.shape.rank == 0: 220 return new_output 221 # Otherwise propagate the old or the new value. 222 with ops.colocate_with(new_output): 223 return array_ops.where(copy_cond, output, new_output) 224 225 def _copy_some_through(flat_new_output, flat_new_state): 226 # Use broadcasting select to determine which values should get 227 # the previous state & zero output, and which values should get 228 # a calculated state & output. 229 flat_new_output = [ 230 _copy_one_through(zero_output, new_output) 231 for zero_output, new_output in zip(flat_zero_output, flat_new_output) 232 ] 233 flat_new_state = [ 234 _copy_one_through(state, new_state) 235 for state, new_state in zip(flat_state, flat_new_state) 236 ] 237 return flat_new_output + flat_new_state 238 239 def _maybe_copy_some_through(): 240 """Run RNN step. 
Pass through either no or some past state.""" 241 new_output, new_state = call_cell() 242 243 nest.assert_same_structure(zero_output, new_output) 244 nest.assert_same_structure(state, new_state) 245 246 flat_new_state = nest.flatten(new_state) 247 flat_new_output = nest.flatten(new_output) 248 return control_flow_ops.cond( 249 # if t < min_seq_len: calculate and return everything 250 time < min_sequence_length, 251 lambda: flat_new_output + flat_new_state, 252 # else copy some of it through 253 lambda: _copy_some_through(flat_new_output, flat_new_state)) 254 255 # TODO(ebrevdo): skipping these conditionals may cause a slowdown, 256 # but benefits from removing cond() and its gradient. We should 257 # profile with and without this switch here. 258 if skip_conditionals: 259 # Instead of using conditionals, perform the selective copy at all time 260 # steps. This is faster when max_seq_len is equal to the number of unrolls 261 # (which is typical for dynamic_rnn). 262 new_output, new_state = call_cell() 263 nest.assert_same_structure(zero_output, new_output) 264 nest.assert_same_structure(state, new_state) 265 new_state = nest.flatten(new_state) 266 new_output = nest.flatten(new_output) 267 final_output_and_state = _copy_some_through(new_output, new_state) 268 else: 269 empty_update = lambda: flat_zero_output + flat_state 270 final_output_and_state = control_flow_ops.cond( 271 # if t >= max_seq_len: copy all state through, output zeros 272 time >= max_sequence_length, 273 empty_update, 274 # otherwise calculation is required: copy some or all of it through 275 _maybe_copy_some_through) 276 277 if len(final_output_and_state) != len(flat_zero_output) + len(flat_state): 278 raise ValueError("Internal error: state and output were not concatenated " 279 f"correctly. Received state length: {len(flat_state)}, " 280 f"output length: {len(flat_zero_output)}. 
Expected " 281 f"contatenated length: {len(final_output_and_state)}.") 282 final_output = final_output_and_state[:len(flat_zero_output)] 283 final_state = final_output_and_state[len(flat_zero_output):] 284 285 for output, flat_output in zip(final_output, flat_zero_output): 286 output.set_shape(flat_output.get_shape()) 287 for substate, flat_substate in zip(final_state, flat_state): 288 if not isinstance(substate, tensor_array_ops.TensorArray): 289 substate.set_shape(flat_substate.get_shape()) 290 291 final_output = nest.pack_sequence_as( 292 structure=zero_output, flat_sequence=final_output) 293 final_state = nest.pack_sequence_as( 294 structure=state, flat_sequence=final_state) 295 296 return final_output, final_state 297 298 299def _reverse_seq(input_seq, lengths): 300 """Reverse a list of Tensors up to specified lengths. 301 302 Args: 303 input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features) 304 or nested tuples of tensors. 305 lengths: A `Tensor` of dimension batch_size, containing lengths for each 306 sequence in the batch. If "None" is specified, simply reverses the list. 
307 308 Returns: 309 time-reversed sequence 310 """ 311 if lengths is None: 312 return list(reversed(input_seq)) 313 314 flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq) 315 316 flat_results = [[] for _ in range(len(input_seq))] 317 for sequence in zip(*flat_input_seq): 318 input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank) 319 for input_ in sequence: 320 input_shape.assert_is_compatible_with(input_.get_shape()) 321 input_.set_shape(input_shape) 322 323 # Join into (time, batch_size, depth) 324 s_joined = array_ops.stack(sequence) 325 326 # Reverse along dimension 0 327 s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1) 328 # Split again into list 329 result = array_ops.unstack(s_reversed) 330 for r, flat_result in zip(result, flat_results): 331 r.set_shape(input_shape) 332 flat_result.append(r) 333 334 results = [ 335 nest.pack_sequence_as(structure=input_, flat_sequence=flat_result) 336 for input_, flat_result in zip(input_seq, flat_results) 337 ] 338 return results 339 340 341@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional(" 342 "keras.layers.RNN(cell))`, which is equivalent to " 343 "this API") 344@tf_export(v1=["nn.bidirectional_dynamic_rnn"]) 345@dispatch.add_dispatch_support 346def bidirectional_dynamic_rnn(cell_fw, 347 cell_bw, 348 inputs, 349 sequence_length=None, 350 initial_state_fw=None, 351 initial_state_bw=None, 352 dtype=None, 353 parallel_iterations=None, 354 swap_memory=False, 355 time_major=False, 356 scope=None): 357 """Creates a dynamic version of bidirectional recurrent neural network. 358 359 Takes input and builds independent forward and backward RNNs. The input_size 360 of forward and backward cell must match. 
The initial state for both directions 361 is zero by default (but can be set optionally) and no intermediate states are 362 ever returned -- the network is fully unrolled for the given (passed in) 363 length(s) of the sequence(s) or completely unrolled if length(s) is not 364 given. 365 366 Args: 367 cell_fw: An instance of RNNCell, to be used for forward direction. 368 cell_bw: An instance of RNNCell, to be used for backward direction. 369 inputs: The RNN inputs. 370 If time_major == False (default), this must be a tensor of shape: 371 `[batch_size, max_time, ...]`, or a nested tuple of such elements. 372 If time_major == True, this must be a tensor of shape: `[max_time, 373 batch_size, ...]`, or a nested tuple of such elements. 374 sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 375 containing the actual lengths for each of the sequences in the batch. If 376 not provided, all batch entries are assumed to be full sequences; and time 377 reversal is applied from time `0` to `max_time` for each sequence. 378 initial_state_fw: (optional) An initial state for the forward RNN. This must 379 be a tensor of appropriate type and shape `[batch_size, 380 cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a 381 tuple of tensors having shapes `[batch_size, s] for s in 382 cell_fw.state_size`. 383 initial_state_bw: (optional) Same as for `initial_state_fw`, but using the 384 corresponding properties of `cell_bw`. 385 dtype: (optional) The data type for the initial states and expected output. 386 Required if initial_states are not provided or RNN states have a 387 heterogeneous dtype. 388 parallel_iterations: (Default: 32). The number of iterations to run in 389 parallel. Those operations which do not have any temporal dependency and 390 can be run in parallel, will be. This parameter trades off time for 391 space. Values >> 1 use more memory but take less time, while smaller 392 values use less memory but computations take longer. 
393 swap_memory: Transparently swap the tensors produced in forward inference 394 but needed for back prop from GPU to CPU. This allows training RNNs which 395 would typically not fit on a single GPU, with very minimal (or no) 396 performance penalty. 397 time_major: The shape format of the `inputs` and `outputs` Tensors. If true, 398 these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, 399 these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using 400 `time_major = True` is a bit more efficient because it avoids transposes 401 at the beginning and end of the RNN calculation. However, most TensorFlow 402 data is batch-major, so by default this function accepts input and emits 403 output in batch-major form. 404 scope: VariableScope for the created subgraph; defaults to 405 "bidirectional_rnn" 406 407 Returns: 408 A tuple (outputs, output_states) where: 409 outputs: A tuple (output_fw, output_bw) containing the forward and 410 the backward rnn output `Tensor`. 411 If time_major == False (default), 412 output_fw will be a `Tensor` shaped: 413 `[batch_size, max_time, cell_fw.output_size]` 414 and output_bw will be a `Tensor` shaped: 415 `[batch_size, max_time, cell_bw.output_size]`. 416 If time_major == True, 417 output_fw will be a `Tensor` shaped: 418 `[max_time, batch_size, cell_fw.output_size]` 419 and output_bw will be a `Tensor` shaped: 420 `[max_time, batch_size, cell_bw.output_size]`. 421 It returns a tuple instead of a single concatenated `Tensor`, unlike 422 in the `bidirectional_rnn`. If the concatenated one is preferred, 423 the forward and backward outputs can be concatenated as 424 `tf.concat(outputs, 2)`. 425 output_states: A tuple (output_state_fw, output_state_bw) containing 426 the forward and the backward final states of bidirectional rnn. 427 428 Raises: 429 TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 
430 """ 431 rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) 432 rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) 433 434 with vs.variable_scope(scope or "bidirectional_rnn"): 435 # Forward direction 436 with vs.variable_scope("fw") as fw_scope: 437 output_fw, output_state_fw = dynamic_rnn( 438 cell=cell_fw, 439 inputs=inputs, 440 sequence_length=sequence_length, 441 initial_state=initial_state_fw, 442 dtype=dtype, 443 parallel_iterations=parallel_iterations, 444 swap_memory=swap_memory, 445 time_major=time_major, 446 scope=fw_scope) 447 448 # Backward direction 449 if not time_major: 450 time_axis = 1 451 batch_axis = 0 452 else: 453 time_axis = 0 454 batch_axis = 1 455 456 def _reverse(input_, seq_lengths, seq_axis, batch_axis): 457 if seq_lengths is not None: 458 return array_ops.reverse_sequence( 459 input=input_, 460 seq_lengths=seq_lengths, 461 seq_axis=seq_axis, 462 batch_axis=batch_axis) 463 else: 464 return array_ops.reverse(input_, axis=[seq_axis]) 465 466 with vs.variable_scope("bw") as bw_scope: 467 468 def _map_reverse(inp): 469 return _reverse( 470 inp, 471 seq_lengths=sequence_length, 472 seq_axis=time_axis, 473 batch_axis=batch_axis) 474 475 inputs_reverse = nest.map_structure(_map_reverse, inputs) 476 tmp, output_state_bw = dynamic_rnn( 477 cell=cell_bw, 478 inputs=inputs_reverse, 479 sequence_length=sequence_length, 480 initial_state=initial_state_bw, 481 dtype=dtype, 482 parallel_iterations=parallel_iterations, 483 swap_memory=swap_memory, 484 time_major=time_major, 485 scope=bw_scope) 486 487 output_bw = _reverse( 488 tmp, 489 seq_lengths=sequence_length, 490 seq_axis=time_axis, 491 batch_axis=batch_axis) 492 493 outputs = (output_fw, output_bw) 494 output_states = (output_state_fw, output_state_bw) 495 496 return (outputs, output_states) 497 498 499@deprecation.deprecated( 500 None, 501 "Please use `keras.layers.RNN(cell)`, which is equivalent to this API") 502@tf_export(v1=["nn.dynamic_rnn"]) 503@dispatch.add_dispatch_support 504def 
dynamic_rnn(cell, 505 inputs, 506 sequence_length=None, 507 initial_state=None, 508 dtype=None, 509 parallel_iterations=None, 510 swap_memory=False, 511 time_major=False, 512 scope=None): 513 """Creates a recurrent neural network specified by RNNCell `cell`. 514 515 Performs fully dynamic unrolling of `inputs`. 516 517 Example: 518 519 ```python 520 # create a BasicRNNCell 521 rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) 522 523 # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size] 524 525 # defining initial state 526 initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32) 527 528 # 'state' is a tensor of shape [batch_size, cell_state_size] 529 outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data, 530 initial_state=initial_state, 531 dtype=tf.float32) 532 ``` 533 534 ```python 535 # create 2 LSTMCells 536 rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 537 538 # create a RNN cell composed sequentially of a number of RNNCells 539 multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers) 540 541 # 'outputs' is a tensor of shape [batch_size, max_time, 256] 542 # 'state' is a N-tuple where N is the number of LSTMCells containing a 543 # tf.nn.rnn_cell.LSTMStateTuple for each cell 544 outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell, 545 inputs=data, 546 dtype=tf.float32) 547 ``` 548 549 550 Args: 551 cell: An instance of RNNCell. 552 inputs: The RNN inputs. 553 If `time_major == False` (default), this must be a `Tensor` of shape: 554 `[batch_size, max_time, ...]`, or a nested tuple of such elements. 555 If `time_major == True`, this must be a `Tensor` of shape: `[max_time, 556 batch_size, ...]`, or a nested tuple of such elements. This may also be 557 a (possibly nested) tuple of Tensors satisfying this property. The 558 first two dimensions must match across all the inputs, but otherwise the 559 ranks and other shape components may differ. 
In this case, input to 560 `cell` at each time-step will replicate the structure of these tuples, 561 except for the time dimension (from which the time is taken). The input 562 to `cell` at each time step will be a `Tensor` or (possibly nested) 563 tuple of Tensors each with dimensions `[batch_size, ...]`. 564 sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used 565 to copy-through state and zero-out outputs when past a batch element's 566 sequence length. This parameter enables users to extract the last valid 567 state and properly padded outputs, so it is provided for correctness. 568 initial_state: (optional) An initial state for the RNN. If `cell.state_size` 569 is an integer, this must be a `Tensor` of appropriate type and shape 570 `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this 571 should be a tuple of tensors having shapes `[batch_size, s] for s in 572 cell.state_size`. 573 dtype: (optional) The data type for the initial state and expected output. 574 Required if initial_state is not provided or RNN state has a heterogeneous 575 dtype. 576 parallel_iterations: (Default: 32). The number of iterations to run in 577 parallel. Those operations which do not have any temporal dependency and 578 can be run in parallel, will be. This parameter trades off time for 579 space. Values >> 1 use more memory but take less time, while smaller 580 values use less memory but computations take longer. 581 swap_memory: Transparently swap the tensors produced in forward inference 582 but needed for back prop from GPU to CPU. This allows training RNNs which 583 would typically not fit on a single GPU, with very minimal (or no) 584 performance penalty. 585 time_major: The shape format of the `inputs` and `outputs` Tensors. If true, 586 these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, 587 these `Tensors` must be shaped `[batch_size, max_time, depth]`. 
Using 588 `time_major = True` is a bit more efficient because it avoids transposes 589 at the beginning and end of the RNN calculation. However, most TensorFlow 590 data is batch-major, so by default this function accepts input and emits 591 output in batch-major form. 592 scope: VariableScope for the created subgraph; defaults to "rnn". 593 594 Returns: 595 A pair (outputs, state) where: 596 597 outputs: The RNN output `Tensor`. 598 599 If time_major == False (default), this will be a `Tensor` shaped: 600 `[batch_size, max_time, cell.output_size]`. 601 602 If time_major == True, this will be a `Tensor` shaped: 603 `[max_time, batch_size, cell.output_size]`. 604 605 Note, if `cell.output_size` is a (possibly nested) tuple of integers 606 or `TensorShape` objects, then `outputs` will be a tuple having the 607 same structure as `cell.output_size`, containing Tensors having shapes 608 corresponding to the shape data in `cell.output_size`. 609 610 state: The final state. If `cell.state_size` is an int, this 611 will be shaped `[batch_size, cell.state_size]`. If it is a 612 `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 613 If it is a (possibly nested) tuple of ints or `TensorShape`, this will 614 be a tuple having the corresponding shapes. If cells are `LSTMCells` 615 `state` will be a tuple containing a `LSTMStateTuple` for each cell. 616 617 Raises: 618 TypeError: If `cell` is not an instance of RNNCell. 619 ValueError: If inputs is None or an empty list. 620 621 @compatibility(TF2) 622 `tf.compat.v1.nn.dynamic_rnn` is not compatible with eager execution and 623 `tf.function`. Please use `tf.keras.layers.RNN` instead for TF2 migration. 624 Take LSTM as an example, you can instantiate a `tf.keras.layers.RNN` layer 625 with `tf.keras.layers.LSTMCell`, or directly via `tf.keras.layers.LSTM`. Once 626 the keras layer is created, you can get the output and states by calling 627 the layer with input and states. 
Please refer to [this 628 guide](https://www.tensorflow.org/guide/keras/rnn) for more details about 629 Keras RNN. You can also find more details about the difference and comparison 630 between Keras RNN and TF compat v1 rnn in [this 631 document](https://github.com/tensorflow/community/blob/master/rfcs/20180920-unify-rnn-interface.md) 632 633 #### Structural Mapping to Native TF2 634 635 Before: 636 637 ```python 638 # create 2 LSTMCells 639 rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 640 641 # create a RNN cell composed sequentially of a number of RNNCells 642 multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers) 643 644 # 'outputs' is a tensor of shape [batch_size, max_time, 256] 645 # 'state' is a N-tuple where N is the number of LSTMCells containing a 646 # tf.nn.rnn_cell.LSTMStateTuple for each cell 647 outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell, 648 inputs=data, 649 dtype=tf.float32) 650 ``` 651 652 After: 653 654 ```python 655 # RNN layer can take a list of cells, which will then stack them together. 656 # By default, keras RNN will only return the last timestep output and will not 657 # return states. If you need whole time sequence output as well as the states, 658 # you can set `return_sequences` and `return_state` to True. 659 rnn_layer = tf.keras.layers.RNN([tf.keras.layers.LSTMCell(128), 660 tf.keras.layers.LSTMCell(256)], 661 return_sequences=True, 662 return_state=True) 663 outputs, output_states = rnn_layer(inputs, states) 664 ``` 665 666 #### How to Map Arguments 667 668 | TF1 Arg Name | TF2 Arg Name | Note | 669 | :-------------------- | :-------------- | :------------------------------- | 670 | `cell` | `cell` | In the RNN layer constructor | 671 | `inputs` | `inputs` | In the RNN layer `__call__` | 672 | `sequence_length` | Not used | Adding masking layer before RNN : 673 : : : to achieve the same result. 
: 674 | `initial_state` | `initial_state` | In the RNN layer `__call__` | 675 | `dtype` | `dtype` | In the RNN layer constructor | 676 | `parallel_iterations` | Not supported | | 677 | `swap_memory` | Not supported | | 678 | `time_major` | `time_major` | In the RNN layer constructor | 679 | `scope` | Not supported | | 680 @end_compatibility 681 """ 682 rnn_cell_impl.assert_like_rnncell("cell", cell) 683 684 with vs.variable_scope(scope or "rnn") as varscope: 685 # Create a new scope in which the caching device is either 686 # determined by the parent scope, or is set to place the cached 687 # Variable using the same placement as for the rest of the RNN. 688 if _should_cache(): 689 if varscope.caching_device is None: 690 varscope.set_caching_device(lambda op: op.device) 691 692 # By default, time_major==False and inputs are batch-major: shaped 693 # [batch, time, depth] 694 # For internal calculations, we transpose to [time, batch, depth] 695 flat_input = nest.flatten(inputs) 696 697 if not time_major: 698 # (B,T,D) => (T,B,D) 699 flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] 700 flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) 701 702 parallel_iterations = parallel_iterations or 32 703 if sequence_length is not None: 704 sequence_length = math_ops.cast(sequence_length, dtypes.int32) 705 if sequence_length.get_shape().rank not in (None, 1): 706 raise ValueError( 707 f"Argument sequence_length must be a vector of length batch_size." 708 f" Received sequence_length={sequence_length} of shape: " 709 f"{sequence_length.get_shape()}") 710 sequence_length = array_ops.identity( # Just to find it in the graph. 
          sequence_length,
          name="sequence_length")

    batch_size = _best_effort_input_batch_size(flat_input)

    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, argument `dtype` "
                         "must be specified")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    def _assert_has_shape(x, shape):
      x_shape = array_ops.shape(x)
      packed_shape = array_ops.stack(shape)
      return control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
              "Expected shape for Tensor %s is " % x.name, packed_shape,
              " but saw shape: ", x_shape
          ])

    if not context.executing_eagerly() and sequence_length is not None:
      # Perform some shape validation
      with ops.control_dependencies(
          [_assert_has_shape(sequence_length, [batch_size])]):
        sequence_length = array_ops.identity(
            sequence_length, name="CheckSeqLen")

    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

    (outputs, final_state) = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
    # If we are performing batch-major calculations, transpose output back
    # to shape [batch, time, depth]
    if not time_major:
      # (T,B,D) => (B,T,D)
      outputs = nest.map_structure(_transpose_batch_time, outputs)

    return (outputs, final_state)


def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
  """Internal implementation of Dynamic RNN.

  Runs `cell` over the time-major `inputs` inside a `while_loop`. In graph
  mode, per-step inputs are read from and outputs written to `TensorArray`s.
  In eager mode, they are indexed from / stored in plain Python sequences.

  Args:
    cell: An instance of RNNCell.
    inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
      tuple of such elements. Every element must have static rank >= 3 and
      fully-defined static dimensions from axis 2 onwards.
    initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
      `cell.state_size` is a tuple, then this should be a tuple of tensors
      having shapes `[batch_size, s] for s in cell.state_size`.
    parallel_iterations: Positive Python int. Checked with an `assert`.
    swap_memory: A Python boolean, forwarded to `while_loop`.
    sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
    dtype: (optional) Expected dtype of output. If not specified, inferred from
      initial_state.

  Returns:
    Tuple `(final_outputs, final_state)`.
    final_outputs:
      A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
      objects, then this returns a (possibly nested) tuple of Tensors matching
      the corresponding shapes.
    final_state:
      A `Tensor`, or possibly nested tuple of Tensors, matching in length
      and shapes to `initial_state`.

  Raises:
    AssertionError: If `parallel_iterations` is not a Python int (this check
      is an `assert`, so it is skipped when Python runs with `-O`).
    ValueError: If an input does not have rank >= 3 (from
      `TensorShape.with_rank_at_least`).
    ValueError: If the input depth cannot be inferred via shape inference
      from the inputs.
    ValueError: If time_step is not the same for all the elements in the
      inputs.
    ValueError: If batch_size is not the same for all the elements in the
      inputs.
  """
  state = initial_state
  assert isinstance(parallel_iterations, int), "parallel_iterations must be int"

  state_size = cell.state_size

  flat_input = nest.flatten(inputs)
  flat_output_size = nest.flatten(cell.output_size)

  # Read the dynamic number of time steps (axis 0 of the time-major inputs)
  # and the best-effort (static if known, else dynamic) batch size.
  input_shape = array_ops.shape(flat_input[0])
  time_steps = input_shape[0]
  batch_size = _best_effort_input_batch_size(flat_input)

  # Static shapes of all inputs; each must be at least rank 3.
  inputs_got_shape = tuple(
      input_.get_shape().with_rank_at_least(3) for input_ in flat_input)

  # Static time/batch sizes of the first input (either may be None).
  const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]

  # Every input must have fully-defined depth dims and agree with the first
  # input on the static time and batch dimensions.
  for i, shape in enumerate(inputs_got_shape):
    if not shape[2:].is_fully_defined():
      raise ValueError(
          "Input size (depth of inputs) must be accessible via shape inference,"
          f" but saw value None for input={flat_input[i]}.")
    got_time_steps = shape.dims[0].value
    got_batch_size = shape.dims[1].value
    if const_time_steps != got_time_steps:
      raise ValueError(
          "Time steps is not the same for all the elements in the input in a "
          f"batch. Received time steps={got_time_steps} for input="
          f"{flat_input[i]}.")
    if const_batch_size != got_batch_size:
      raise ValueError(
          "Batch_size is not the same for all the elements in the input. "
          f"Received batch size={got_batch_size} for input={flat_input[i]}.")

  # Prepare dynamic conditional copying of state & output:
  # zero tensors of shape [batch_size] + output_size, handed to `_rnn_step`
  # as `zero_output`.
  def _create_zero_arrays(size):
    size = _concat(batch_size, size)
    return array_ops.zeros(
        array_ops.stack(size), _infer_state_dtype(dtype, state))

  flat_zero_output = tuple(
      _create_zero_arrays(output) for output in flat_output_size)
  zero_output = nest.pack_sequence_as(
      structure=cell.output_size, flat_sequence=flat_zero_output)

  # `min_sequence_length` is only defined (and only used, in `_time_step`)
  # when `sequence_length` is given.
  if sequence_length is not None:
    min_sequence_length = math_ops.reduce_min(sequence_length)
    max_sequence_length = math_ops.reduce_max(sequence_length)
  else:
    max_sequence_length = time_steps

  time = array_ops.constant(0, dtype=dtypes.int32, name="time")

  # Only the scope string is needed: it prefixes the TensorArray names below.
  with ops.name_scope("dynamic_rnn") as scope:
    base_name = scope

  def _create_ta(name, element_shape, dtype):
    return tensor_array_ops.TensorArray(
        dtype=dtype,
        size=time_steps,
        element_shape=element_shape,
        tensor_array_name=base_name + name)

  in_graph_mode = not context.executing_eagerly()
  if in_graph_mode:
    # One output TensorArray per flattened output, elements shaped
    # [const_batch_size] + output_size.
    output_ta = tuple(
        _create_ta(
            "output_%d" % i,
            element_shape=(
                tensor_shape.TensorShape([const_batch_size]).concatenate(
                    _maybe_tensor_shape_from_tensor(out_size))),
            dtype=_infer_state_dtype(dtype, state))
        for i, out_size in enumerate(flat_output_size))
    # One input TensorArray per flattened input, unstacked along time.
    input_ta = tuple(
        _create_ta(
            "input_%d" % i,
            element_shape=flat_input_i.shape[1:],
            dtype=flat_input_i.dtype)
        for i, flat_input_i in enumerate(flat_input))
    input_ta = tuple(
        ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input))
  else:
    # Eager mode: per-output Python lists with one slot per time step; the
    # inputs themselves are indexed by time directly.
    output_ta = tuple([0 for _ in range(time_steps.numpy())]
                      for i in range(len(flat_output_size)))
    input_ta = flat_input

  def _time_step(time, output_ta_t, state):
    """Take a time step of the dynamic RNN.

    Args:
      time: int32 scalar Tensor.
      output_ta_t: Tuple with one entry per flattened output: a `TensorArray`
        in graph mode, or a Python list (written in place) in eager mode.
      state: nested tuple of vector tensors that represent the state.

    Returns:
      The tuple (time + 1, output_ta_t with updated flow, new_state).
    """

    if in_graph_mode:
      input_t = tuple(ta.read(time) for ta in input_ta)
      # Restore some shape information
      for input_, shape in zip(input_t, inputs_got_shape):
        input_.set_shape(shape[1:])
    else:
      input_t = tuple(ta[time.numpy()] for ta in input_ta)

    input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
    # Wrap the cell call in a thunk so it can either be handed to `_rnn_step`
    # or invoked directly below. `state` is passed through unchanged.
    call_cell = lambda: cell(input_t, state)

    if sequence_length is not None:
      (output, new_state) = _rnn_step(
          time=time,
          sequence_length=sequence_length,
          min_sequence_length=min_sequence_length,
          max_sequence_length=max_sequence_length,
          zero_output=zero_output,
          state=state,
          call_cell=call_cell,
          state_size=state_size,
          skip_conditionals=True)
    else:
      (output, new_state) = call_cell()

    # Flatten the (possibly nested) output so it lines up with output_ta_t.
    output = nest.flatten(output)

    if in_graph_mode:
      output_ta_t = tuple(
          ta.write(time, out) for ta, out in zip(output_ta_t, output))
    else:
      for ta, out in zip(output_ta_t, output):
        ta[time.numpy()] = out

    return (time + 1, output_ta_t, new_state)

  if in_graph_mode:
    # Make sure that we run at least 1 step, if necessary, to ensure
    # the TensorArrays pick up the dynamic shape.
    loop_bound = math_ops.minimum(time_steps,
                                  math_ops.maximum(1, max_sequence_length))
  else:
    # Using max_sequence_length isn't currently supported in the Eager branch.
    loop_bound = time_steps

  _, output_final_ta, final_state = control_flow_ops.while_loop(
      cond=lambda time, *_: time < loop_bound,
      body=_time_step,
      loop_vars=(time, output_ta, state),
      parallel_iterations=parallel_iterations,
      maximum_iterations=time_steps,
      swap_memory=swap_memory)

  # Turn the per-step outputs into [time, batch, ...] tensors. Graph mode
  # stacks the TensorArrays here; eager mode stacks the Python lists below.
  if in_graph_mode:
    final_outputs = tuple(ta.stack() for ta in output_final_ta)
    # Restore some shape information
    for output, output_size in zip(final_outputs, flat_output_size):
      shape = _concat([const_time_steps, const_batch_size],
                      output_size,
                      static=True)
      output.set_shape(shape)
  else:
    final_outputs = output_final_ta

  final_outputs = nest.pack_sequence_as(
      structure=cell.output_size, flat_sequence=final_outputs)
  if not in_graph_mode:
    final_outputs = nest.map_structure_up_to(
        cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs)

  return (final_outputs, final_state)


@tf_export(v1=["nn.raw_rnn"])
@dispatch.add_dispatch_support
def raw_rnn(cell,
            loop_fn,
            parallel_iterations=None,
            swap_memory=False,
            scope=None):
  """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.

  **NOTE: This method is still in testing, and the API may change.**

  This function is a more primitive version of `dynamic_rnn` that provides
  more direct access to the inputs each iteration.  It also provides more
  control over when to start and finish reading the sequence, and
  what to emit for the output.

  For example, it can be used to implement the dynamic decoder of a seq2seq
  model.

  Instead of working with `Tensor` objects, most operations work with
  `TensorArray` objects directly.

  The operation of `raw_rnn`, in pseudo-code, is basically the following:

  ```python
  time = tf.constant(0, dtype=tf.int32)
  (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
      time=time, cell_output=None, cell_state=None, loop_state=None)
  emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
  state = initial_state
  while not all(finished):
    (output, cell_state) = cell(next_input, state)
    (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
        time=time + 1, cell_output=output, cell_state=cell_state,
        loop_state=loop_state)
    # Emit zeros and copy forward state for minibatch entries that are finished.
    state = tf.where(finished, state, next_state)
    emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
    emit_ta = emit_ta.write(time, emit)
    # If any new minibatch entries are marked as finished, mark these.
    finished = tf.logical_or(finished, next_finished)
    time += 1
  return (emit_ta, state, loop_state)
  ```

  with the additional properties that output and state may be (possibly nested)
  tuples, as determined by `cell.output_size` and `cell.state_size`, and
  as a result the final `state` and `emit_ta` may themselves be tuples.

  A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this:

  ```python
  inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth),
                          dtype=tf.float32)
  sequence_length = tf.compat.v1.placeholder(shape=(batch_size,),
                                             dtype=tf.int32)
  inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
  inputs_ta = inputs_ta.unstack(inputs)

  cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)

  def loop_fn(time, cell_output, cell_state, loop_state):
    emit_output = cell_output  # == None for time == 0
    if cell_output is None:  # time == 0
      next_cell_state = cell.zero_state(batch_size, tf.float32)
    else:
      next_cell_state = cell_state
    elements_finished = (time >= sequence_length)
    finished = tf.reduce_all(elements_finished)
    next_input = tf.cond(
        finished,
        lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
        lambda: inputs_ta.read(time))
    next_loop_state = None
    return (elements_finished, next_input, next_cell_state,
            emit_output, next_loop_state)

  outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
  outputs = outputs_ta.stack()
  ```

  Args:
    cell: An instance of RNNCell.
    loop_fn: A callable that takes inputs `(time, cell_output, cell_state,
      loop_state)` and returns the tuple `(finished, next_input,
      next_cell_state, emit_output, next_loop_state)`. Here `time` is an int32
      scalar `Tensor`, `cell_output` is a `Tensor` or (possibly nested) tuple of
      tensors as determined by `cell.output_size`, and `cell_state` is a
      `Tensor` or (possibly nested) tuple of tensors, as determined by the
      `loop_fn` on its first call (and should match `cell.state_size`).
      The outputs are: `finished`, a boolean `Tensor` of
      shape `[batch_size]`, `next_input`: the next input to feed to `cell`,
      `next_cell_state`: the next state to feed to `cell`,
      and `emit_output`: the output to store for this iteration.  Note that
      `emit_output` should be a `Tensor` or (possibly nested) tuple of tensors
      which is aggregated in the `emit_ta` inside the `while_loop`. For the
      first call to `loop_fn`, the `emit_output` corresponds to the
      `emit_structure` which is then used to determine the size of the
      `zero_tensor` for the `emit_ta` (defaults to `cell.output_size`). For
      the subsequent calls to the `loop_fn`, the `emit_output` corresponds to
      the actual output tensor that is to be aggregated in the `emit_ta`. The
      parameter `cell_state` and output `next_cell_state` may be either a
      single or (possibly nested) tuple of tensors.  The parameter
      `loop_state` and output `next_loop_state` may be either a single or
      (possibly nested) tuple of `Tensor` and `TensorArray` objects.  This
      last parameter may be ignored by `loop_fn` and the return value may be
      `None`.  If it is not `None`, then the `loop_state` will be propagated
      through the RNN loop, for use purely by `loop_fn` to keep track of its
      own state. The `next_loop_state` parameter returned may be `None`; in
      that case the previous `loop_state` is reused.  The
      first call to `loop_fn` will be `time = 0`, `cell_output = None`,
      `cell_state = None`, and `loop_state = None`.  For this call: The
      `next_cell_state` value should be the value with which to initialize the
      cell's state.  It may be a final state from a previous RNN or it may be
      the output of `cell.zero_state()`.  It should be a (possibly nested)
      tuple structure of tensors. If `cell.state_size` is an integer, this
      must be a `Tensor` of appropriate type and shape `[batch_size,
      cell.state_size]`. If `cell.state_size` is a `TensorShape`, this must be
      a `Tensor` of appropriate type and shape `[batch_size] +
      cell.state_size`. If `cell.state_size` is a (possibly nested) tuple of
      ints or `TensorShape`, this will be a tuple having the corresponding
      shapes. The `emit_output` value may be either `None` or a (possibly
      nested) tuple structure of tensors, e.g., `(tf.zeros(shape_0,
      dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`. If this first
      `emit_output` return value is `None`, then the `emit_ta` result of
      `raw_rnn` will have the same structure as `cell.output_size`, and every
      element uses the dtype of the first (flattened) initial state tensor.
      Otherwise `emit_ta` will have the same structure, shapes (prepended with
      a `batch_size` dimension), and dtypes as `emit_output`.  The actual
      values returned for `emit_output` at this initializing call are ignored.
      Note, this emit structure must be consistent across all time steps.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A tuple `(emit_ta, final_state, final_loop_state)` where:

    `emit_ta`: The RNN output `TensorArray`.
       If `loop_fn` returns a (possibly nested) set of Tensors for
       `emit_output` during initialization, (inputs `time = 0`,
       `cell_output = None`, and `loop_state = None`), then `emit_ta` will
       have the same structure, dtypes, and shapes as `emit_output` instead.
       If `loop_fn` returns `emit_output = None` during this call,
       the structure of `cell.output_size` is used:
       If `cell.output_size` is a (possibly nested) tuple of integers
       or `TensorShape` objects, then `emit_ta` will be a tuple having the
       same structure as `cell.output_size`, containing TensorArrays whose
       elements' shapes correspond to the shape data in `cell.output_size`.

    `final_state`: The final cell state.  If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`.  If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes.

    `final_loop_state`: The final loop state as returned by `loop_fn`, or
      `None` if the initial call to `loop_fn` returned `None` for it.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not
      a `callable`.
    Structure mismatches between the initial state and `cell.state_size`, or
    between per-step `loop_fn` results and the loop variables, are reported
    by `nest.assert_same_structure`.
  """
  rnn_cell_impl.assert_like_rnncell("cell", cell)

  if not callable(loop_fn):
    raise TypeError("Argument `loop_fn` must be a callable. Received: "
                    f"{loop_fn}.")

  parallel_iterations = parallel_iterations or 32

  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    time = constant_op.constant(0, dtype=dtypes.int32)
    (elements_finished, next_input,
     initial_state, emit_structure, init_loop_state) = loop_fn(
         time, None, None, None)  # time, cell_output, cell_state, loop_state
    flat_input = nest.flatten(next_input)

    # Need a surrogate loop state for the while_loop if none is available.
    loop_state = (
        init_loop_state if init_loop_state is not None else
        constant_op.constant(0, dtype=dtypes.int32))

    input_shape = [input_.get_shape() for input_ in flat_input]
    static_batch_size = tensor_shape.dimension_at_index(input_shape[0], 0)

    for input_shape_i in input_shape:
      # Static verification that batch sizes all match
      static_batch_size.assert_is_compatible_with(
          tensor_shape.dimension_at_index(input_shape_i, 0))

    # `const_batch_size` keeps the static value (possibly None) for the
    # TensorArray element shapes; `batch_size` falls back to the dynamic size.
    batch_size = tensor_shape.dimension_value(static_batch_size)
    const_batch_size = batch_size
    if batch_size is None:
      batch_size = array_ops.shape(flat_input[0])[0]

    nest.assert_same_structure(initial_state, cell.state_size)
    state = initial_state
    flat_state = nest.flatten(state)
    flat_state = [ops.convert_to_tensor(s) for s in flat_state]
    state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state)

    if emit_structure is not None:
      flat_emit_structure = nest.flatten(emit_structure)
      # Prefer the static shape when fully known, else the dynamic shape.
      flat_emit_size = [
          emit.shape if emit.shape.is_fully_defined() else
          array_ops.shape(emit) for emit in flat_emit_structure
      ]
      flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
    else:
      # No explicit emit structure: follow `cell.output_size`, using the dtype
      # of the first flattened state tensor for every emitted element.
      emit_structure = cell.output_size
      flat_emit_size = nest.flatten(emit_structure)
      flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

    flat_emit_ta = [
        tensor_array_ops.TensorArray(
            dtype=dtype_i,
            dynamic_size=True,
            element_shape=(tensor_shape.TensorShape([
                const_batch_size
            ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
            size=0,
            name="rnn_output_%d" % i)
        for i, (dtype_i,
                size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
    ]
    emit_ta = nest.pack_sequence_as(
        structure=emit_structure, flat_sequence=flat_emit_ta)
    # Zeros emitted for batch entries that have already finished.
    flat_zero_emit = [
        array_ops.zeros(_concat(batch_size, size_i), dtype_i)
        for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
    ]
    zero_emit = nest.pack_sequence_as(
        structure=emit_structure, flat_sequence=flat_zero_emit)

    # Keep looping while at least one batch entry is unfinished.
    def condition(unused_time, elements_finished, *_):
      return math_ops.logical_not(math_ops.reduce_all(elements_finished))

    def body(time, elements_finished, current_input, emit_ta, state,
             loop_state):
      """Internal while loop body for raw_rnn.

      Args:
        time: time scalar.
        elements_finished: batch-size vector.
        current_input: possibly nested tuple of input tensors.
        emit_ta: possibly nested tuple of output TensorArrays.
        state: possibly nested tuple of state tensors.
        loop_state: possibly nested tuple of loop state tensors.

      Returns:
        Tuple having the same size as Args but with updated values.
      """
      (next_output, cell_state) = cell(current_input, state)

      nest.assert_same_structure(state, cell_state)
      nest.assert_same_structure(cell.output_size, next_output)

      next_time = time + 1
      (next_finished, next_input, next_state, emit_output,
       next_loop_state) = loop_fn(next_time, next_output, cell_state,
                                  loop_state)

      nest.assert_same_structure(state, next_state)
      nest.assert_same_structure(current_input, next_input)
      nest.assert_same_structure(emit_ta, emit_output)

      # If loop_fn returns None for next_loop_state, just reuse the
      # previous one.
      loop_state = loop_state if next_loop_state is None else next_loop_state

      def _copy_some_through(current, candidate):
        """Copy some tensors through via array_ops.where."""

        def copy_fn(cur_i, cand_i):
          # TensorArray and scalar get passed through.
          if isinstance(cur_i, tensor_array_ops.TensorArray):
            return cand_i
          if cur_i.shape.rank == 0:
            return cand_i
          # Otherwise propagate the old or the new value.
          with ops.colocate_with(cand_i):
            return array_ops.where(elements_finished, cur_i, cand_i)

        return nest.map_structure(copy_fn, current, candidate)

      # Finished entries emit zeros and keep their previous state; this uses
      # `elements_finished` from *before* this step's `next_finished`.
      emit_output = _copy_some_through(zero_emit, emit_output)
      next_state = _copy_some_through(state, next_state)

      emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit),
                                   emit_ta, emit_output)

      elements_finished = math_ops.logical_or(elements_finished, next_finished)

      return (next_time, elements_finished, next_input, emit_ta, next_state,
              loop_state)

    returned = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=[
            time, elements_finished, next_input, emit_ta, state, loop_state
        ],
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    # The last three loop variables are (emit_ta, state, loop_state).
    (emit_ta, final_state, final_loop_state) = returned[-3:]

    # Do not leak the int32 surrogate loop state to callers.
    if init_loop_state is None:
      final_loop_state = None

    return (emit_ta, final_state, final_loop_state)


@deprecation.deprecated(None,
                        "Please use `keras.layers.RNN(cell, unroll=True)`, "
                        "which is equivalent to this API")
@tf_export(v1=["nn.static_rnn"])
@dispatch.add_dispatch_support
def static_rnn(cell,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  The simplest form of RNN network generated is:

  ```python
    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)
  ```
  However, a few other options are available:

  An initial state can be provided.
  If the sequence_length vector is provided, dynamic calculation is performed.
  This method of calculation does not compute the RNN steps past the maximum
  sequence length of the minibatch (thus saving computational time),
  and properly propagates the state at an example's sequence length
  to the final state output.

  The dynamic calculation performed is, at time `t` for batch row `b`,

  ```python
    (output, state)(b, t) =
      (t >= sequence_length(b))
        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
        : cell(input(b, t), state(b, t - 1))
  ```

  Args:
    cell: An instance of RNNCell.
    inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size,
      input_size]`, or a nested tuple of such elements.
    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
      is an integer, this must be a `Tensor` of appropriate type and shape
      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
      should be a tuple of tensors having shapes `[batch_size, s] for s in
      cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    sequence_length: Specifies the length of each sequence in inputs. An int32
      or int64 vector (tensor) size `[batch_size]`, values in `[0, T]` (a
      length of `T` means the full sequence is processed).
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    - outputs is a length T list of outputs (one for each input), or a nested
      tuple of such elements.
    - state is the final state

  Raises:
    TypeError: If `cell` is not an instance of RNNCell, or if `inputs` is not
      a sequence (this includes `inputs=None`).
    ValueError: If `inputs` is an empty list, if the input depth
      (column size) cannot be inferred from inputs via shape inference, if
      `dtype` is missing while `initial_state` is `None`, or if
      `sequence_length` has a known rank other than 1.
  """
  rnn_cell_impl.assert_like_rnncell("cell", cell)
  if not nest.is_nested(inputs):
    raise TypeError(f"Argument `inputs` must be a sequence. Received: {inputs}")
  if not inputs:
    raise ValueError("Argument `inputs` must not be empty.")

  outputs = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    # Obtain the first sequence of the input
    first_input = inputs
    while nest.is_nested(first_input):
      first_input = first_input[0]

    # Temporarily avoid EmbeddingWrapper and seq2seq badness
    # TODO(lukaszkaiser): remove EmbeddingWrapper
    if first_input.get_shape().rank != 1:

      input_shape = first_input.get_shape().with_rank_at_least(2)
      fixed_batch_size = input_shape.dims[0]

      # All inputs must have compatible static batch sizes and fully
      # known depth dimensions.
      flat_inputs = nest.flatten(inputs)
      for flat_input in flat_inputs:
        input_shape = flat_input.get_shape().with_rank_at_least(2)
        batch_size, input_size = tensor_shape.dimension_at_index(
            input_shape, 0), input_shape[1:]
        fixed_batch_size.assert_is_compatible_with(batch_size)
        for i, size in enumerate(input_size.dims):
          if tensor_shape.dimension_value(size) is None:
            raise ValueError(
                f"Input size (dimension {i} of input {flat_input}) must be "
                "accessible via shape inference, but saw value None.")
    else:
      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]

    # Use the static batch size when it is known and truthy (not None or 0);
    # otherwise read it dynamically from the first input.
    if tensor_shape.dimension_value(fixed_batch_size):
      batch_size = tensor_shape.dimension_value(fixed_batch_size)
    else:
      batch_size = array_ops.shape(first_input)[0]
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, argument `dtype` "
                         "must be specified")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    if sequence_length is not None:  # Prepare variables
      sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if sequence_length.get_shape().rank not in (None, 1):
        raise ValueError(
            "Argument `sequence_length` must be a vector of length "
            f"{batch_size}. Received sequence_length={sequence_length}.")

      def _create_zero_output(output_size):
        # convert int to TensorShape if necessary
        size = _concat(batch_size, output_size)
        output = array_ops.zeros(
            array_ops.stack(size), _infer_state_dtype(dtype, state))
        shape = _concat(
            tensor_shape.dimension_value(fixed_batch_size),
            output_size,
            static=True)
        output.set_shape(tensor_shape.TensorShape(shape))
        return output

      output_size = cell.output_size
      flat_output_size = nest.flatten(output_size)
      flat_zero_output = tuple(
          _create_zero_output(size) for size in flat_output_size)
      zero_output = nest.pack_sequence_as(
          structure=output_size, flat_sequence=flat_zero_output)

      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

    for time, input_ in enumerate(inputs):
      # Share the cell's variables across all time steps after the first.
      if time > 0:
        varscope.reuse_variables()
      # pylint: disable=cell-var-from-loop
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()
      outputs.append(output)

    return (outputs, state)


@deprecation.deprecated(None,
                        "Please use `keras.layers.RNN(cell, stateful=True)`, "
                        "which is equivalent to this API")
@tf_export(v1=["nn.static_state_saving_rnn"])
@dispatch.add_dispatch_support
def static_state_saving_rnn(cell,
                            inputs,
                            state_saver,
                            state_name,
                            sequence_length=None,
                            scope=None):
  """RNN that accepts a state saver for time-truncated RNN calculation.

  The initial state is read from `state_saver.state(...)`, the RNN is run with
  `static_rnn`, and the final state is written back with
  `state_saver.save_state(...)`. The last output and the returned state carry
  a control dependency on those save ops.

  Args:
    cell: An instance of `RNNCell`.
    inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size,
      input_size]`.
    state_saver: A state saver object with methods `state` and `save_state`.
    state_name: Python string or tuple of strings.  The name to use with the
      state_saver. If the cell returns tuples of states (i.e., `cell.state_size`
      is a tuple) then `state_name` should be a tuple of strings having the same
      length as `cell.state_size`.  Otherwise it should be a single string.
    sequence_length: (optional) An int32/int64 vector size [batch_size]. See the
      documentation for rnn() for more details about sequence_length.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:
      outputs is a length T list of outputs (one for each input)
      state is the final state

  Raises:
    TypeError: If `cell` is not an instance of RNNCell, or if `inputs` is not
      a sequence (raised by `static_rnn`).
    ValueError: If `inputs` is an empty list (raised by `static_rnn`), or if
      the arity and type of `state_name` does not match that of
      `cell.state_size`.
  """
  state_size = cell.state_size
  state_is_tuple = nest.is_nested(state_size)
  state_name_tuple = nest.is_nested(state_name)

  if state_is_tuple != state_name_tuple:
    raise ValueError("Argument `state_name` should be the same type as "
                     f"`cell.state_size`. Received: state_name={state_name!s}, "
                     f"cell.state_size={state_size!s}.")

  if state_is_tuple:
    state_name_flat = nest.flatten(state_name)
    state_size_flat = nest.flatten(state_size)

    if len(state_name_flat) != len(state_size_flat):
      raise ValueError("Number of elements in argument `state_name` and "
                       "`cell.state_size` are mismatched. Received "
                       f"state_name={state_name} with {len(state_name_flat)} "
                       f"elements and cell.state_size={cell.state_size} with "
                       f"{len(state_size_flat)} elements.")

    initial_state = nest.pack_sequence_as(
        structure=state_size,
        flat_sequence=[state_saver.state(s) for s in state_name_flat])
  else:
    initial_state = state_saver.state(state_name)

  (outputs, state) = static_rnn(
      cell,
      inputs,
      initial_state=initial_state,
      sequence_length=sequence_length,
      scope=scope)

  if state_is_tuple:
    flat_state = nest.flatten(state)
    state_name = nest.flatten(state_name)
    save_state = [
        state_saver.save_state(name, substate)
        for name, substate in zip(state_name, flat_state)
    ]
  else:
    save_state = [state_saver.save_state(state_name, state)]

  # Route the last output and the final state through identity ops created
  # under the save ops' control dependencies, so consuming them runs the saves.
  with ops.control_dependencies(save_state):
    last_output = outputs[-1]
    flat_last_output = nest.flatten(last_output)
    flat_last_output = [
        array_ops.identity(output) for output in flat_last_output
    ]
    outputs[-1] = nest.pack_sequence_as(
        structure=last_output, flat_sequence=flat_last_output)

    if state_is_tuple:
      state = nest.pack_sequence_as(
          structure=state,
          flat_sequence=[array_ops.identity(s) for s in flat_state])
    else:
      state = array_ops.identity(state)

  return (outputs, state)


@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
                        "keras.layers.RNN(cell, unroll=True))`, which is "
                        "equivalent to this API")
@tf_export(v1=["nn.static_bidirectional_rnn"])
@dispatch.add_dispatch_support
def static_bidirectional_rnn(cell_fw,
                             cell_bw,
                             inputs,
                             initial_state_fw=None,
                             initial_state_bw=None,
                             dtype=None,
                             sequence_length=None,
                             scope=None):
  """Creates a bidirectional recurrent neural network.

  Similar to the unidirectional case above (rnn) but takes input and builds
  independent forward and backward RNNs with the final forward and backward
  outputs depth-concatenated, such that the output will have the format
  [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
  forward and backward cell must match. The initial state for both directions
  is zero by default (but can be set optionally) and no intermediate states are
  ever returned -- the network is fully unrolled for the given (passed in)
  length(s) of the sequence(s) or completely unrolled if length(s) is not given.

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: A length T list of inputs, each a tensor of shape [batch_size,
      input_size], or a nested tuple of such elements.
    initial_state_fw: (optional) An initial state for the forward RNN. This must
      be a tensor of appropriate type and shape `[batch_size,
      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
      tuple of tensors having shapes `[batch_size, s] for s in
      cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial state.  Required if either
      of the initial states are not provided.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn"

  Returns:
    A tuple (outputs, output_state_fw, output_state_bw) where:
      outputs is a length `T` list of outputs (one for each input), which
        are depth-concatenated forward and backward outputs.
      output_state_fw is the final state of the forward rnn.
      output_state_bw is the final state of the backward rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`, or
      if `inputs` is not a sequence (this includes `inputs=None`).
    ValueError: If inputs is an empty list.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
  if not nest.is_nested(inputs):
    raise TypeError(f"Argument `inputs` must be a sequence. Received: {inputs}")
  if not inputs:
    raise ValueError("Argument `inputs` must not be empty.")

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = static_rnn(
          cell_fw,
          inputs,
          initial_state_fw,
          dtype,
          sequence_length,
          scope=fw_scope)

    # Backward direction: run the backward cell over `_reverse_seq(inputs)`.
    with vs.variable_scope("bw") as bw_scope:
      reversed_inputs = _reverse_seq(inputs, sequence_length)
      tmp, output_state_bw = static_rnn(
          cell_bw,
          reversed_inputs,
          initial_state_bw,
          dtype,
          sequence_length,
          scope=bw_scope)

  # Apply `_reverse_seq` to the backward outputs too, mirroring the reversal
  # applied to its inputs, before pairing them with the forward outputs.
  output_bw = _reverse_seq(tmp, sequence_length)
  # Concat each of the forward/backward outputs
  flat_output_fw = nest.flatten(output_fw)
  flat_output_bw = nest.flatten(output_bw)

  flat_outputs = tuple(
      array_ops.concat([fw, bw], 1)
      for fw, bw in zip(flat_output_fw, flat_output_bw))

  outputs = nest.pack_sequence_as(
      structure=output_fw, flat_sequence=flat_outputs)

  return (outputs, output_state_fw, output_state_bw)