# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""RNN helpers for TensorFlow models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=protected-access
_concat = rnn_cell_impl._concat
# pylint: enable=protected-access


def _transpose_batch_time(x):
  """Transposes the batch and time dimensions of a Tensor.

  If the input tensor has rank < 2 it returns the original tensor. Retains as
  much of the static shape information as possible.

  Args:
    x: A Tensor.

  Returns:
    x transposed along the first two dimensions.
  """
  x_static_shape = x.get_shape()
  if x_static_shape.rank is not None and x_static_shape.rank < 2:
    return x

  x_rank = array_ops.rank(x)
  x_t = array_ops.transpose(
      x, array_ops.concat(
          ([1, 0], math_ops.range(2, x_rank)), axis=0))
  x_t.set_shape(
      tensor_shape.TensorShape([
          x_static_shape.dims[1].value, x_static_shape.dims[0].value
      ]).concatenate(x_static_shape[2:]))
  return x_t


def _best_effort_input_batch_size(flat_input):
  """Get static input batch size if available, with fallback to the dynamic one.

  Args:
    flat_input: An iterable of time major input Tensors of shape
      `[max_time, batch_size, ...]`.
      All inputs should have compatible batch sizes.

  Returns:
    The batch size in Python integer if available, or a scalar Tensor otherwise.

  Raises:
    ValueError: if there is any input with an invalid shape.
  """
  for input_ in flat_input:
    shape = input_.shape
    if shape.rank is None:
      continue
    if shape.rank < 2:
      raise ValueError(
          "Expected input tensor %s to have rank at least 2" % input_)
    batch_size = shape.dims[1].value
    if batch_size is not None:
      return batch_size
  # Fallback to the dynamic batch size of the first input.
  return array_ops.shape(flat_input[0])[1]


def _infer_state_dtype(explicit_dtype, state):
  """Infer the dtype of an RNN state.

  Args:
    explicit_dtype: explicitly declared dtype or None.
    state: RNN's hidden state. Must be a Tensor or a nested iterable containing
      Tensors.

  Returns:
    dtype: inferred dtype of hidden state.

  Raises:
    ValueError: if `state` has heterogeneous dtypes or is empty.
  """
  if explicit_dtype is not None:
    return explicit_dtype
  elif nest.is_sequence(state):
    inferred_dtypes = [element.dtype for element in nest.flatten(state)]
    if not inferred_dtypes:
      raise ValueError("Unable to infer dtype from empty state.")
    all_same = all(x == inferred_dtypes[0] for x in inferred_dtypes)
    if not all_same:
      raise ValueError(
          "State has tensors of different inferred dtypes. Unable to infer a "
          "single representative dtype.")
    return inferred_dtypes[0]
  else:
    return state.dtype


def _maybe_tensor_shape_from_tensor(shape):
  if isinstance(shape, ops.Tensor):
    return tensor_shape.as_shape(tensor_util.constant_value(shape))
  else:
    return shape


def _should_cache():
  """Returns True if a default caching device should be set, otherwise False."""
  if context.executing_eagerly():
    return False
  # Don't set a caching device when running in a loop, since it is possible that
  # train steps could be wrapped in a tf.while_loop. In that scenario caching
  # prevents forward computations in loop iterations from re-reading the
  # updated weights.
  ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
  return control_flow_util.GetContainingWhileContext(ctxt) is None


def _is_keras_rnn_cell(rnn_cell):
  """Checks whether the cell is a Keras RNN cell.

  A Keras RNN cell accepts the state as a list even when the state is a single
  tensor, whereas a TF RNN cell does not wrap a single state tensor in a list.
  This behavior difference should be unified in a future version.

  Args:
    rnn_cell: An RNN cell instance that follows either the Keras interface or
      the TF RNN interface.

  Returns:
    Boolean, whether the cell is a Keras RNN cell.
  """
  # A cell type check is not strict enough, since there are cells created by
  # other libraries (e.g. DeepMind) that do not inherit from
  # tf.nn.rnn_cell.RNNCell. Keras cells never had a zero_state method, which
  # comes from the original TF RNN cell interface.
  return (not isinstance(rnn_cell, rnn_cell_impl.RNNCell)
          and isinstance(rnn_cell, base_layer.Layer)
          and getattr(rnn_cell, "zero_state", None) is None)


# pylint: disable=unused-argument
def _rnn_step(
    time, sequence_length, min_sequence_length, max_sequence_length,
    zero_output, state, call_cell, state_size, skip_conditionals=False):
  """Calculate one step of a dynamic RNN minibatch.

  Returns an (output, state) pair conditioned on `sequence_length`.
  When skip_conditionals=False, the pseudocode is something like:

  if t >= max_sequence_length:
    return (zero_output, state)
  if t < min_sequence_length:
    return call_cell()

  # Selectively output zeros or output, old state or new state depending
  # on whether we've finished calculating each row.
187 new_output, new_state = call_cell() 188 final_output = np.vstack([ 189 zero_output if time >= sequence_length[r] else new_output_r 190 for r, new_output_r in enumerate(new_output) 191 ]) 192 final_state = np.vstack([ 193 state[r] if time >= sequence_length[r] else new_state_r 194 for r, new_state_r in enumerate(new_state) 195 ]) 196 return (final_output, final_state) 197 198 Args: 199 time: int32 `Tensor` scalar. 200 sequence_length: int32 `Tensor` vector of size [batch_size]. 201 min_sequence_length: int32 `Tensor` scalar, min of sequence_length. 202 max_sequence_length: int32 `Tensor` scalar, max of sequence_length. 203 zero_output: `Tensor` vector of shape [output_size]. 204 state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`, 205 or a list/tuple of such tensors. 206 call_cell: lambda returning tuple of (new_output, new_state) where 207 new_output is a `Tensor` matrix of shape `[batch_size, output_size]`. 208 new_state is a `Tensor` matrix of shape `[batch_size, state_size]`. 209 state_size: The `cell.state_size` associated with the state. 210 skip_conditionals: Python bool, whether to skip using the conditional 211 calculations. This is useful for `dynamic_rnn`, where the input tensor 212 matches `max_sequence_length`, and using conditionals just slows 213 everything down. 214 215 Returns: 216 A tuple of (`final_output`, `final_state`) as given by the pseudocode above: 217 final_output is a `Tensor` matrix of shape [batch_size, output_size] 218 final_state is either a single `Tensor` matrix, or a tuple of such 219 matrices (matching length and shapes of input `state`). 220 221 Raises: 222 ValueError: If the cell returns a state tuple whose length does not match 223 that returned by `state_size`. 224 """ 225 226 # Convert state to a list for ease of use 227 flat_state = nest.flatten(state) 228 flat_zero_output = nest.flatten(zero_output) 229 230 # Vector describing which batch entries are finished. 231 copy_cond = time >= sequence_length 232 233 def _copy_one_through(output, new_output): 234 # TensorArray and scalar get passed through. 235 if isinstance(output, tensor_array_ops.TensorArray): 236 return new_output 237 if output.shape.rank == 0: 238 return new_output 239 # Otherwise propagate the old or the new value. 240 with ops.colocate_with(new_output): 241 return array_ops.where(copy_cond, output, new_output) 242 243 def _copy_some_through(flat_new_output, flat_new_state): 244 # Use broadcasting select to determine which values should get 245 # the previous state & zero output, and which values should get 246 # a calculated state & output. 247 flat_new_output = [ 248 _copy_one_through(zero_output, new_output) 249 for zero_output, new_output in zip(flat_zero_output, flat_new_output)] 250 flat_new_state = [ 251 _copy_one_through(state, new_state) 252 for state, new_state in zip(flat_state, flat_new_state)] 253 return flat_new_output + flat_new_state 254 255 def _maybe_copy_some_through(): 256 """Run RNN step. 
Pass through either no or some past state.""" 257 new_output, new_state = call_cell() 258 259 nest.assert_same_structure(state, new_state) 260 261 flat_new_state = nest.flatten(new_state) 262 flat_new_output = nest.flatten(new_output) 263 return control_flow_ops.cond( 264 # if t < min_seq_len: calculate and return everything 265 time < min_sequence_length, lambda: flat_new_output + flat_new_state, 266 # else copy some of it through 267 lambda: _copy_some_through(flat_new_output, flat_new_state)) 268 269 # TODO(ebrevdo): skipping these conditionals may cause a slowdown, 270 # but benefits from removing cond() and its gradient. We should 271 # profile with and without this switch here. 272 if skip_conditionals: 273 # Instead of using conditionals, perform the selective copy at all time 274 # steps. This is faster when max_seq_len is equal to the number of unrolls 275 # (which is typical for dynamic_rnn). 276 new_output, new_state = call_cell() 277 nest.assert_same_structure(state, new_state) 278 new_state = nest.flatten(new_state) 279 new_output = nest.flatten(new_output) 280 final_output_and_state = _copy_some_through(new_output, new_state) 281 else: 282 empty_update = lambda: flat_zero_output + flat_state 283 final_output_and_state = control_flow_ops.cond( 284 # if t >= max_seq_len: copy all state through, output zeros 285 time >= max_sequence_length, empty_update, 286 # otherwise calculation is required: copy some or all of it through 287 _maybe_copy_some_through) 288 289 if len(final_output_and_state) != len(flat_zero_output) + len(flat_state): 290 raise ValueError("Internal error: state and output were not concatenated " 291 "correctly.") 292 final_output = final_output_and_state[:len(flat_zero_output)] 293 final_state = final_output_and_state[len(flat_zero_output):] 294 295 for output, flat_output in zip(final_output, flat_zero_output): 296 output.set_shape(flat_output.get_shape()) 297 for substate, flat_substate in zip(final_state, flat_state): 298 if not isinstance(substate, tensor_array_ops.TensorArray): 299 substate.set_shape(flat_substate.get_shape()) 300 301 final_output = nest.pack_sequence_as( 302 structure=zero_output, flat_sequence=final_output) 303 final_state = nest.pack_sequence_as( 304 structure=state, flat_sequence=final_state) 305 306 return final_output, final_state 307 308 309def _reverse_seq(input_seq, lengths): 310 """Reverse a list of Tensors up to specified lengths. 311 312 Args: 313 input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features) 314 or nested tuples of tensors. 315 lengths: A `Tensor` of dimension batch_size, containing lengths for each 316 sequence in the batch. If "None" is specified, simply reverses 317 the list. 
318 319 Returns: 320 time-reversed sequence 321 """ 322 if lengths is None: 323 return list(reversed(input_seq)) 324 325 flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq) 326 327 flat_results = [[] for _ in range(len(input_seq))] 328 for sequence in zip(*flat_input_seq): 329 input_shape = tensor_shape.unknown_shape( 330 rank=sequence[0].get_shape().rank) 331 for input_ in sequence: 332 input_shape.merge_with(input_.get_shape()) 333 input_.set_shape(input_shape) 334 335 # Join into (time, batch_size, depth) 336 s_joined = array_ops.stack(sequence) 337 338 # Reverse along dimension 0 339 s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1) 340 # Split again into list 341 result = array_ops.unstack(s_reversed) 342 for r, flat_result in zip(result, flat_results): 343 r.set_shape(input_shape) 344 flat_result.append(r) 345 346 results = [nest.pack_sequence_as(structure=input_, flat_sequence=flat_result) 347 for input_, flat_result in zip(input_seq, flat_results)] 348 return results 349 350 351@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional(" 352 "keras.layers.RNN(cell))`, which is equivalent to " 353 "this API") 354@tf_export(v1=["nn.bidirectional_dynamic_rnn"]) 355def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, 356 initial_state_fw=None, initial_state_bw=None, 357 dtype=None, parallel_iterations=None, 358 swap_memory=False, time_major=False, scope=None): 359 """Creates a dynamic version of bidirectional recurrent neural network. 360 361 Takes input and builds independent forward and backward RNNs. The input_size 362 of forward and backward cell must match. The initial state for both directions 363 is zero by default (but can be set optionally) and no intermediate states are 364 ever returned -- the network is fully unrolled for the given (passed in) 365 length(s) of the sequence(s) or completely unrolled if length(s) is not 366 given. 367 368 Args: 369 cell_fw: An instance of RNNCell, to be used for forward direction. 370 cell_bw: An instance of RNNCell, to be used for backward direction. 371 inputs: The RNN inputs. 372 If time_major == False (default), this must be a tensor of shape: 373 `[batch_size, max_time, ...]`, or a nested tuple of such elements. 374 If time_major == True, this must be a tensor of shape: 375 `[max_time, batch_size, ...]`, or a nested tuple of such elements. 376 sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 377 containing the actual lengths for each of the sequences in the batch. 378 If not provided, all batch entries are assumed to be full sequences; and 379 time reversal is applied from time `0` to `max_time` for each sequence. 380 initial_state_fw: (optional) An initial state for the forward RNN. 381 This must be a tensor of appropriate type and shape 382 `[batch_size, cell_fw.state_size]`. 383 If `cell_fw.state_size` is a tuple, this should be a tuple of 384 tensors having shapes `[batch_size, s] for s in cell_fw.state_size`. 385 initial_state_bw: (optional) Same as for `initial_state_fw`, but using 386 the corresponding properties of `cell_bw`. 387 dtype: (optional) The data type for the initial states and expected output. 388 Required if initial_states are not provided or RNN states have a 389 heterogeneous dtype. 390 parallel_iterations: (Default: 32). The number of iterations to run in 391 parallel. Those operations which do not have any temporal dependency 392 and can be run in parallel, will be. This parameter trades off 393 time for space. 
Values >> 1 use more memory but take less time, 394 while smaller values use less memory but computations take longer. 395 swap_memory: Transparently swap the tensors produced in forward inference 396 but needed for back prop from GPU to CPU. This allows training RNNs 397 which would typically not fit on a single GPU, with very minimal (or no) 398 performance penalty. 399 time_major: The shape format of the `inputs` and `outputs` Tensors. 400 If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. 401 If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 402 Using `time_major = True` is a bit more efficient because it avoids 403 transposes at the beginning and end of the RNN calculation. However, 404 most TensorFlow data is batch-major, so by default this function 405 accepts input and emits output in batch-major form. 406 scope: VariableScope for the created subgraph; defaults to 407 "bidirectional_rnn" 408 409 Returns: 410 A tuple (outputs, output_states) where: 411 outputs: A tuple (output_fw, output_bw) containing the forward and 412 the backward rnn output `Tensor`. 413 If time_major == False (default), 414 output_fw will be a `Tensor` shaped: 415 `[batch_size, max_time, cell_fw.output_size]` 416 and output_bw will be a `Tensor` shaped: 417 `[batch_size, max_time, cell_bw.output_size]`. 418 If time_major == True, 419 output_fw will be a `Tensor` shaped: 420 `[max_time, batch_size, cell_fw.output_size]` 421 and output_bw will be a `Tensor` shaped: 422 `[max_time, batch_size, cell_bw.output_size]`. 423 It returns a tuple instead of a single concatenated `Tensor`, unlike 424 in the `bidirectional_rnn`. If the concatenated one is preferred, 425 the forward and backward outputs can be concatenated as 426 `tf.concat(outputs, 2)`. 427 output_states: A tuple (output_state_fw, output_state_bw) containing 428 the forward and the backward final states of bidirectional rnn. 429 430 Raises: 431 TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 
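
  Example (a minimal usage sketch; `num_units`, `inputs` and `lengths` are
  assumed to be defined by the caller):

  ```python
  cell_fw = tf.nn.rnn_cell.LSTMCell(num_units)
  cell_bw = tf.nn.rnn_cell.LSTMCell(num_units)
  # inputs: [batch_size, max_time, depth]; lengths: [batch_size]
  (outputs, output_states) = tf.nn.bidirectional_dynamic_rnn(
      cell_fw, cell_bw, inputs, sequence_length=lengths, dtype=tf.float32)
  # Depth-concatenate the forward and backward outputs, as described above.
  outputs_concat = tf.concat(outputs, 2)
  ```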
432 """ 433 rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) 434 rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) 435 436 with vs.variable_scope(scope or "bidirectional_rnn"): 437 # Forward direction 438 with vs.variable_scope("fw") as fw_scope: 439 output_fw, output_state_fw = dynamic_rnn( 440 cell=cell_fw, inputs=inputs, sequence_length=sequence_length, 441 initial_state=initial_state_fw, dtype=dtype, 442 parallel_iterations=parallel_iterations, swap_memory=swap_memory, 443 time_major=time_major, scope=fw_scope) 444 445 # Backward direction 446 if not time_major: 447 time_axis = 1 448 batch_axis = 0 449 else: 450 time_axis = 0 451 batch_axis = 1 452 453 def _reverse(input_, seq_lengths, seq_axis, batch_axis): 454 if seq_lengths is not None: 455 return array_ops.reverse_sequence( 456 input=input_, seq_lengths=seq_lengths, 457 seq_axis=seq_axis, batch_axis=batch_axis) 458 else: 459 return array_ops.reverse(input_, axis=[seq_axis]) 460 461 with vs.variable_scope("bw") as bw_scope: 462 463 def _map_reverse(inp): 464 return _reverse( 465 inp, 466 seq_lengths=sequence_length, 467 seq_axis=time_axis, 468 batch_axis=batch_axis) 469 470 inputs_reverse = nest.map_structure(_map_reverse, inputs) 471 tmp, output_state_bw = dynamic_rnn( 472 cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length, 473 initial_state=initial_state_bw, dtype=dtype, 474 parallel_iterations=parallel_iterations, swap_memory=swap_memory, 475 time_major=time_major, scope=bw_scope) 476 477 output_bw = _reverse( 478 tmp, seq_lengths=sequence_length, 479 seq_axis=time_axis, batch_axis=batch_axis) 480 481 outputs = (output_fw, output_bw) 482 output_states = (output_state_fw, output_state_bw) 483 484 return (outputs, output_states) 485 486 487@deprecation.deprecated( 488 None, 489 "Please use `keras.layers.RNN(cell)`, which is equivalent to this API") 490@tf_export(v1=["nn.dynamic_rnn"]) 491def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, 492 dtype=None, parallel_iterations=None, swap_memory=False, 493 time_major=False, scope=None): 494 """Creates a recurrent neural network specified by RNNCell `cell`. 495 496 Performs fully dynamic unrolling of `inputs`. 497 498 Example: 499 500 ```python 501 # create a BasicRNNCell 502 rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) 503 504 # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size] 505 506 # defining initial state 507 initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32) 508 509 # 'state' is a tensor of shape [batch_size, cell_state_size] 510 outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data, 511 initial_state=initial_state, 512 dtype=tf.float32) 513 ``` 514 515 ```python 516 # create 2 LSTMCells 517 rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 518 519 # create a RNN cell composed sequentially of a number of RNNCells 520 multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers) 521 522 # 'outputs' is a tensor of shape [batch_size, max_time, 256] 523 # 'state' is a N-tuple where N is the number of LSTMCells containing a 524 # tf.contrib.rnn.LSTMStateTuple for each cell 525 outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell, 526 inputs=data, 527 dtype=tf.float32) 528 ``` 529 530 531 Args: 532 cell: An instance of RNNCell. 533 inputs: The RNN inputs. 534 If `time_major == False` (default), this must be a `Tensor` of shape: 535 `[batch_size, max_time, ...]`, or a nested tuple of such 536 elements. 
537 If `time_major == True`, this must be a `Tensor` of shape: 538 `[max_time, batch_size, ...]`, or a nested tuple of such 539 elements. 540 This may also be a (possibly nested) tuple of Tensors satisfying 541 this property. The first two dimensions must match across all the inputs, 542 but otherwise the ranks and other shape components may differ. 543 In this case, input to `cell` at each time-step will replicate the 544 structure of these tuples, except for the time dimension (from which the 545 time is taken). 546 The input to `cell` at each time step will be a `Tensor` or (possibly 547 nested) tuple of Tensors each with dimensions `[batch_size, ...]`. 548 sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. 549 Used to copy-through state and zero-out outputs when past a batch 550 element's sequence length. So it's more for performance than correctness. 551 initial_state: (optional) An initial state for the RNN. 552 If `cell.state_size` is an integer, this must be 553 a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 554 If `cell.state_size` is a tuple, this should be a tuple of 555 tensors having shapes `[batch_size, s] for s in cell.state_size`. 556 dtype: (optional) The data type for the initial state and expected output. 557 Required if initial_state is not provided or RNN state has a heterogeneous 558 dtype. 559 parallel_iterations: (Default: 32). The number of iterations to run in 560 parallel. Those operations which do not have any temporal dependency 561 and can be run in parallel, will be. This parameter trades off 562 time for space. Values >> 1 use more memory but take less time, 563 while smaller values use less memory but computations take longer. 564 swap_memory: Transparently swap the tensors produced in forward inference 565 but needed for back prop from GPU to CPU. This allows training RNNs 566 which would typically not fit on a single GPU, with very minimal (or no) 567 performance penalty. 568 time_major: The shape format of the `inputs` and `outputs` Tensors. 569 If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. 570 If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 571 Using `time_major = True` is a bit more efficient because it avoids 572 transposes at the beginning and end of the RNN calculation. However, 573 most TensorFlow data is batch-major, so by default this function 574 accepts input and emits output in batch-major form. 575 scope: VariableScope for the created subgraph; defaults to "rnn". 576 577 Returns: 578 A pair (outputs, state) where: 579 580 outputs: The RNN output `Tensor`. 581 582 If time_major == False (default), this will be a `Tensor` shaped: 583 `[batch_size, max_time, cell.output_size]`. 584 585 If time_major == True, this will be a `Tensor` shaped: 586 `[max_time, batch_size, cell.output_size]`. 587 588 Note, if `cell.output_size` is a (possibly nested) tuple of integers 589 or `TensorShape` objects, then `outputs` will be a tuple having the 590 same structure as `cell.output_size`, containing Tensors having shapes 591 corresponding to the shape data in `cell.output_size`. 592 593 state: The final state. If `cell.state_size` is an int, this 594 will be shaped `[batch_size, cell.state_size]`. If it is a 595 `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 596 If it is a (possibly nested) tuple of ints or `TensorShape`, this will 597 be a tuple having the corresponding shapes. 
If cells are `LSTMCells` 598 `state` will be a tuple containing a `LSTMStateTuple` for each cell. 599 600 Raises: 601 TypeError: If `cell` is not an instance of RNNCell. 602 ValueError: If inputs is None or an empty list. 603 """ 604 rnn_cell_impl.assert_like_rnncell("cell", cell) 605 606 with vs.variable_scope(scope or "rnn") as varscope: 607 # Create a new scope in which the caching device is either 608 # determined by the parent scope, or is set to place the cached 609 # Variable using the same placement as for the rest of the RNN. 610 if _should_cache(): 611 if varscope.caching_device is None: 612 varscope.set_caching_device(lambda op: op.device) 613 614 # By default, time_major==False and inputs are batch-major: shaped 615 # [batch, time, depth] 616 # For internal calculations, we transpose to [time, batch, depth] 617 flat_input = nest.flatten(inputs) 618 619 if not time_major: 620 # (B,T,D) => (T,B,D) 621 flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] 622 flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) 623 624 parallel_iterations = parallel_iterations or 32 625 if sequence_length is not None: 626 sequence_length = math_ops.cast(sequence_length, dtypes.int32) 627 if sequence_length.get_shape().rank not in (None, 1): 628 raise ValueError( 629 "sequence_length must be a vector of length batch_size, " 630 "but saw shape: %s" % sequence_length.get_shape()) 631 sequence_length = array_ops.identity( # Just to find it in the graph. 632 sequence_length, name="sequence_length") 633 634 batch_size = _best_effort_input_batch_size(flat_input) 635 636 if initial_state is not None: 637 state = initial_state 638 else: 639 if not dtype: 640 raise ValueError("If there is no initial_state, you must give a dtype.") 641 if getattr(cell, "get_initial_state", None) is not None: 642 state = cell.get_initial_state( 643 inputs=None, batch_size=batch_size, dtype=dtype) 644 else: 645 state = cell.zero_state(batch_size, dtype) 646 647 def _assert_has_shape(x, shape): 648 x_shape = array_ops.shape(x) 649 packed_shape = array_ops.stack(shape) 650 return control_flow_ops.Assert( 651 math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), 652 ["Expected shape for Tensor %s is " % x.name, 653 packed_shape, " but saw shape: ", x_shape]) 654 655 if not context.executing_eagerly() and sequence_length is not None: 656 # Perform some shape validation 657 with ops.control_dependencies( 658 [_assert_has_shape(sequence_length, [batch_size])]): 659 sequence_length = array_ops.identity( 660 sequence_length, name="CheckSeqLen") 661 662 inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) 663 664 (outputs, final_state) = _dynamic_rnn_loop( 665 cell, 666 inputs, 667 state, 668 parallel_iterations=parallel_iterations, 669 swap_memory=swap_memory, 670 sequence_length=sequence_length, 671 dtype=dtype) 672 673 # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 674 # If we are performing batch-major calculations, transpose output back 675 # to shape [batch, time, depth] 676 if not time_major: 677 # (T,B,D) => (B,T,D) 678 outputs = nest.map_structure(_transpose_batch_time, outputs) 679 680 return (outputs, final_state) 681 682 683def _dynamic_rnn_loop(cell, 684 inputs, 685 initial_state, 686 parallel_iterations, 687 swap_memory, 688 sequence_length=None, 689 dtype=None): 690 """Internal implementation of Dynamic RNN. 691 692 Args: 693 cell: An instance of RNNCell. 
694 inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested 695 tuple of such elements. 696 initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if 697 `cell.state_size` is a tuple, then this should be a tuple of 698 tensors having shapes `[batch_size, s] for s in cell.state_size`. 699 parallel_iterations: Positive Python int. 700 swap_memory: A Python boolean 701 sequence_length: (optional) An `int32` `Tensor` of shape [batch_size]. 702 dtype: (optional) Expected dtype of output. If not specified, inferred from 703 initial_state. 704 705 Returns: 706 Tuple `(final_outputs, final_state)`. 707 final_outputs: 708 A `Tensor` of shape `[time, batch_size, cell.output_size]`. If 709 `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape` 710 objects, then this returns a (possibly nested) tuple of Tensors matching 711 the corresponding shapes. 712 final_state: 713 A `Tensor`, or possibly nested tuple of Tensors, matching in length 714 and shapes to `initial_state`. 715 716 Raises: 717 ValueError: If the input depth cannot be inferred via shape inference 718 from the inputs. 719 ValueError: If time_step is not the same for all the elements in the 720 inputs. 721 ValueError: If batch_size is not the same for all the elements in the 722 inputs. 723 """ 724 state = initial_state 725 assert isinstance(parallel_iterations, int), "parallel_iterations must be int" 726 727 state_size = cell.state_size 728 729 flat_input = nest.flatten(inputs) 730 flat_output_size = nest.flatten(cell.output_size) 731 732 # Construct an initial output 733 input_shape = array_ops.shape(flat_input[0]) 734 time_steps = input_shape[0] 735 batch_size = _best_effort_input_batch_size(flat_input) 736 737 inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3) 738 for input_ in flat_input) 739 740 const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2] 741 742 for shape in inputs_got_shape: 743 if not shape[2:].is_fully_defined(): 744 raise ValueError( 745 "Input size (depth of inputs) must be accessible via shape inference," 746 " but saw value None.") 747 got_time_steps = shape.dims[0].value 748 got_batch_size = shape.dims[1].value 749 if const_time_steps != got_time_steps: 750 raise ValueError( 751 "Time steps is not the same for all the elements in the input in a " 752 "batch.") 753 if const_batch_size != got_batch_size: 754 raise ValueError( 755 "Batch_size is not the same for all the elements in the input.") 756 757 # Prepare dynamic conditional copying of state & output 758 def _create_zero_arrays(size): 759 size = _concat(batch_size, size) 760 return array_ops.zeros( 761 array_ops.stack(size), _infer_state_dtype(dtype, state)) 762 763 flat_zero_output = tuple(_create_zero_arrays(output) 764 for output in flat_output_size) 765 zero_output = nest.pack_sequence_as(structure=cell.output_size, 766 flat_sequence=flat_zero_output) 767 768 if sequence_length is not None: 769 min_sequence_length = math_ops.reduce_min(sequence_length) 770 max_sequence_length = math_ops.reduce_max(sequence_length) 771 else: 772 max_sequence_length = time_steps 773 774 time = array_ops.constant(0, dtype=dtypes.int32, name="time") 775 776 with ops.name_scope("dynamic_rnn") as scope: 777 base_name = scope 778 779 def _create_ta(name, element_shape, dtype): 780 return tensor_array_ops.TensorArray(dtype=dtype, 781 size=time_steps, 782 element_shape=element_shape, 783 tensor_array_name=base_name + name) 784 785 in_graph_mode = not context.executing_eagerly() 786 if in_graph_mode: 787 
output_ta = tuple( 788 _create_ta( 789 "output_%d" % i, 790 element_shape=(tensor_shape.TensorShape([const_batch_size]) 791 .concatenate( 792 _maybe_tensor_shape_from_tensor(out_size))), 793 dtype=_infer_state_dtype(dtype, state)) 794 for i, out_size in enumerate(flat_output_size)) 795 input_ta = tuple( 796 _create_ta( 797 "input_%d" % i, 798 element_shape=flat_input_i.shape[1:], 799 dtype=flat_input_i.dtype) 800 for i, flat_input_i in enumerate(flat_input)) 801 input_ta = tuple(ta.unstack(input_) 802 for ta, input_ in zip(input_ta, flat_input)) 803 else: 804 output_ta = tuple([0 for _ in range(time_steps.numpy())] 805 for i in range(len(flat_output_size))) 806 input_ta = flat_input 807 808 def _time_step(time, output_ta_t, state): 809 """Take a time step of the dynamic RNN. 810 811 Args: 812 time: int32 scalar Tensor. 813 output_ta_t: List of `TensorArray`s that represent the output. 814 state: nested tuple of vector tensors that represent the state. 815 816 Returns: 817 The tuple (time + 1, output_ta_t with updated flow, new_state). 818 """ 819 820 if in_graph_mode: 821 input_t = tuple(ta.read(time) for ta in input_ta) 822 # Restore some shape information 823 for input_, shape in zip(input_t, inputs_got_shape): 824 input_.set_shape(shape[1:]) 825 else: 826 input_t = tuple(ta[time.numpy()] for ta in input_ta) 827 828 input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t) 829 # Keras RNN cells only accept state as list, even if it's a single tensor. 830 is_keras_rnn_cell = _is_keras_rnn_cell(cell) 831 if is_keras_rnn_cell and not nest.is_sequence(state): 832 state = [state] 833 call_cell = lambda: cell(input_t, state) 834 835 if sequence_length is not None: 836 (output, new_state) = _rnn_step( 837 time=time, 838 sequence_length=sequence_length, 839 min_sequence_length=min_sequence_length, 840 max_sequence_length=max_sequence_length, 841 zero_output=zero_output, 842 state=state, 843 call_cell=call_cell, 844 state_size=state_size, 845 skip_conditionals=True) 846 else: 847 (output, new_state) = call_cell() 848 849 # Keras cells always wrap state as list, even if it's a single tensor. 850 if is_keras_rnn_cell and len(new_state) == 1: 851 new_state = new_state[0] 852 # Pack state if using state tuples 853 output = nest.flatten(output) 854 855 if in_graph_mode: 856 output_ta_t = tuple( 857 ta.write(time, out) for ta, out in zip(output_ta_t, output)) 858 else: 859 for ta, out in zip(output_ta_t, output): 860 ta[time.numpy()] = out 861 862 return (time + 1, output_ta_t, new_state) 863 864 if in_graph_mode: 865 # Make sure that we run at least 1 step, if necessary, to ensure 866 # the TensorArrays pick up the dynamic shape. 867 loop_bound = math_ops.minimum( 868 time_steps, math_ops.maximum(1, max_sequence_length)) 869 else: 870 # Using max_sequence_length isn't currently supported in the Eager branch. 871 loop_bound = time_steps 872 873 _, output_final_ta, final_state = control_flow_ops.while_loop( 874 cond=lambda time, *_: time < loop_bound, 875 body=_time_step, 876 loop_vars=(time, output_ta, state), 877 parallel_iterations=parallel_iterations, 878 maximum_iterations=time_steps, 879 swap_memory=swap_memory) 880 881 # Unpack final output if not using output tuples. 
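  # In graph mode the per-step outputs were written to TensorArrays and are
  # stacked back into time-major [time, batch, depth] Tensors below; in eager
  # mode they are plain Python lists and are stacked after repacking.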
882 if in_graph_mode: 883 final_outputs = tuple(ta.stack() for ta in output_final_ta) 884 # Restore some shape information 885 for output, output_size in zip(final_outputs, flat_output_size): 886 shape = _concat( 887 [const_time_steps, const_batch_size], output_size, static=True) 888 output.set_shape(shape) 889 else: 890 final_outputs = output_final_ta 891 892 final_outputs = nest.pack_sequence_as( 893 structure=cell.output_size, flat_sequence=final_outputs) 894 if not in_graph_mode: 895 final_outputs = nest.map_structure_up_to( 896 cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs) 897 898 return (final_outputs, final_state) 899 900 901@tf_export(v1=["nn.raw_rnn"]) 902def raw_rnn(cell, loop_fn, 903 parallel_iterations=None, swap_memory=False, scope=None): 904 """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`. 905 906 **NOTE: This method is still in testing, and the API may change.** 907 908 This function is a more primitive version of `dynamic_rnn` that provides 909 more direct access to the inputs each iteration. It also provides more 910 control over when to start and finish reading the sequence, and 911 what to emit for the output. 912 913 For example, it can be used to implement the dynamic decoder of a seq2seq 914 model. 915 916 Instead of working with `Tensor` objects, most operations work with 917 `TensorArray` objects directly. 918 919 The operation of `raw_rnn`, in pseudo-code, is basically the following: 920 921 ```python 922 time = tf.constant(0, dtype=tf.int32) 923 (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn( 924 time=time, cell_output=None, cell_state=None, loop_state=None) 925 emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype) 926 state = initial_state 927 while not all(finished): 928 (output, cell_state) = cell(next_input, state) 929 (next_finished, next_input, next_state, emit, loop_state) = loop_fn( 930 time=time + 1, cell_output=output, cell_state=cell_state, 931 loop_state=loop_state) 932 # Emit zeros and copy forward state for minibatch entries that are finished. 933 state = tf.where(finished, state, next_state) 934 emit = tf.where(finished, tf.zeros_like(emit_structure), emit) 935 emit_ta = emit_ta.write(time, emit) 936 # If any new minibatch entries are marked as finished, mark these. 937 finished = tf.logical_or(finished, next_finished) 938 time += 1 939 return (emit_ta, state, loop_state) 940 ``` 941 942 with the additional properties that output and state may be (possibly nested) 943 tuples, as determined by `cell.output_size` and `cell.state_size`, and 944 as a result the final `state` and `emit_ta` may themselves be tuples. 
945 946 A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this: 947 948 ```python 949 inputs = tf.placeholder(shape=(max_time, batch_size, input_depth), 950 dtype=tf.float32) 951 sequence_length = tf.placeholder(shape=(batch_size,), dtype=tf.int32) 952 inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time) 953 inputs_ta = inputs_ta.unstack(inputs) 954 955 cell = tf.contrib.rnn.LSTMCell(num_units) 956 957 def loop_fn(time, cell_output, cell_state, loop_state): 958 emit_output = cell_output # == None for time == 0 959 if cell_output is None: # time == 0 960 next_cell_state = cell.zero_state(batch_size, tf.float32) 961 else: 962 next_cell_state = cell_state 963 elements_finished = (time >= sequence_length) 964 finished = tf.reduce_all(elements_finished) 965 next_input = tf.cond( 966 finished, 967 lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32), 968 lambda: inputs_ta.read(time)) 969 next_loop_state = None 970 return (elements_finished, next_input, next_cell_state, 971 emit_output, next_loop_state) 972 973 outputs_ta, final_state, _ = raw_rnn(cell, loop_fn) 974 outputs = outputs_ta.stack() 975 ``` 976 977 Args: 978 cell: An instance of RNNCell. 979 loop_fn: A callable that takes inputs 980 `(time, cell_output, cell_state, loop_state)` 981 and returns the tuple 982 `(finished, next_input, next_cell_state, emit_output, next_loop_state)`. 983 Here `time` is an int32 scalar `Tensor`, `cell_output` is a 984 `Tensor` or (possibly nested) tuple of tensors as determined by 985 `cell.output_size`, and `cell_state` is a `Tensor` 986 or (possibly nested) tuple of tensors, as determined by the `loop_fn` 987 on its first call (and should match `cell.state_size`). 988 The outputs are: `finished`, a boolean `Tensor` of 989 shape `[batch_size]`, `next_input`: the next input to feed to `cell`, 990 `next_cell_state`: the next state to feed to `cell`, 991 and `emit_output`: the output to store for this iteration. 992 993 Note that `emit_output` should be a `Tensor` or (possibly nested) 994 tuple of tensors which is aggregated in the `emit_ta` inside the 995 `while_loop`. For the first call to `loop_fn`, the `emit_output` 996 corresponds to the `emit_structure` which is then used to determine the 997 size of the `zero_tensor` for the `emit_ta` (defaults to 998 `cell.output_size`). For the subsequent calls to the `loop_fn`, the 999 `emit_output` corresponds to the actual output tensor 1000 that is to be aggregated in the `emit_ta`. The parameter `cell_state` 1001 and output `next_cell_state` may be either a single or (possibly nested) 1002 tuple of tensors. The parameter `loop_state` and 1003 output `next_loop_state` may be either a single or (possibly nested) tuple 1004 of `Tensor` and `TensorArray` objects. This last parameter 1005 may be ignored by `loop_fn` and the return value may be `None`. If it 1006 is not `None`, then the `loop_state` will be propagated through the RNN 1007 loop, for use purely by `loop_fn` to keep track of its own state. 1008 The `next_loop_state` parameter returned may be `None`. 1009 1010 The first call to `loop_fn` will be `time = 0`, `cell_output = None`, 1011 `cell_state = None`, and `loop_state = None`. For this call: 1012 The `next_cell_state` value should be the value with which to initialize 1013 the cell's state. It may be a final state from a previous RNN or it 1014 may be the output of `cell.zero_state()`. It should be a 1015 (possibly nested) tuple structure of tensors. 
1016 If `cell.state_size` is an integer, this must be 1017 a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 1018 If `cell.state_size` is a `TensorShape`, this must be a `Tensor` of 1019 appropriate type and shape `[batch_size] + cell.state_size`. 1020 If `cell.state_size` is a (possibly nested) tuple of ints or 1021 `TensorShape`, this will be a tuple having the corresponding shapes. 1022 The `emit_output` value may be either `None` or a (possibly nested) 1023 tuple structure of tensors, e.g., 1024 `(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`. 1025 If this first `emit_output` return value is `None`, 1026 then the `emit_ta` result of `raw_rnn` will have the same structure and 1027 dtypes as `cell.output_size`. Otherwise `emit_ta` will have the same 1028 structure, shapes (prepended with a `batch_size` dimension), and dtypes 1029 as `emit_output`. The actual values returned for `emit_output` at this 1030 initializing call are ignored. Note, this emit structure must be 1031 consistent across all time steps. 1032 1033 parallel_iterations: (Default: 32). The number of iterations to run in 1034 parallel. Those operations which do not have any temporal dependency 1035 and can be run in parallel, will be. This parameter trades off 1036 time for space. Values >> 1 use more memory but take less time, 1037 while smaller values use less memory but computations take longer. 1038 swap_memory: Transparently swap the tensors produced in forward inference 1039 but needed for back prop from GPU to CPU. This allows training RNNs 1040 which would typically not fit on a single GPU, with very minimal (or no) 1041 performance penalty. 1042 scope: VariableScope for the created subgraph; defaults to "rnn". 1043 1044 Returns: 1045 A tuple `(emit_ta, final_state, final_loop_state)` where: 1046 1047 `emit_ta`: The RNN output `TensorArray`. 1048 If `loop_fn` returns a (possibly nested) set of Tensors for 1049 `emit_output` during initialization, (inputs `time = 0`, 1050 `cell_output = None`, and `loop_state = None`), then `emit_ta` will 1051 have the same structure, dtypes, and shapes as `emit_output` instead. 1052 If `loop_fn` returns `emit_output = None` during this call, 1053 the structure of `cell.output_size` is used: 1054 If `cell.output_size` is a (possibly nested) tuple of integers 1055 or `TensorShape` objects, then `emit_ta` will be a tuple having the 1056 same structure as `cell.output_size`, containing TensorArrays whose 1057 elements' shapes correspond to the shape data in `cell.output_size`. 1058 1059 `final_state`: The final cell state. If `cell.state_size` is an int, this 1060 will be shaped `[batch_size, cell.state_size]`. If it is a 1061 `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 1062 If it is a (possibly nested) tuple of ints or `TensorShape`, this will 1063 be a tuple having the corresponding shapes. 1064 1065 `final_loop_state`: The final loop state as returned by `loop_fn`. 1066 1067 Raises: 1068 TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not 1069 a `callable`. 1070 """ 1071 rnn_cell_impl.assert_like_rnncell("cell", cell) 1072 1073 if not callable(loop_fn): 1074 raise TypeError("loop_fn must be a callable") 1075 1076 parallel_iterations = parallel_iterations or 32 1077 1078 # Create a new scope in which the caching device is either 1079 # determined by the parent scope, or is set to place the cached 1080 # Variable using the same placement as for the rest of the RNN. 
1081 with vs.variable_scope(scope or "rnn") as varscope: 1082 if _should_cache(): 1083 if varscope.caching_device is None: 1084 varscope.set_caching_device(lambda op: op.device) 1085 1086 time = constant_op.constant(0, dtype=dtypes.int32) 1087 (elements_finished, next_input, initial_state, emit_structure, 1088 init_loop_state) = loop_fn( 1089 time, None, None, None) # time, cell_output, cell_state, loop_state 1090 flat_input = nest.flatten(next_input) 1091 1092 # Need a surrogate loop state for the while_loop if none is available. 1093 loop_state = (init_loop_state if init_loop_state is not None 1094 else constant_op.constant(0, dtype=dtypes.int32)) 1095 1096 input_shape = [input_.get_shape() for input_ in flat_input] 1097 static_batch_size = tensor_shape.dimension_at_index(input_shape[0], 0) 1098 1099 for input_shape_i in input_shape: 1100 # Static verification that batch sizes all match 1101 static_batch_size.merge_with( 1102 tensor_shape.dimension_at_index(input_shape_i, 0)) 1103 1104 batch_size = tensor_shape.dimension_value(static_batch_size) 1105 const_batch_size = batch_size 1106 if batch_size is None: 1107 batch_size = array_ops.shape(flat_input[0])[0] 1108 1109 nest.assert_same_structure(initial_state, cell.state_size) 1110 state = initial_state 1111 flat_state = nest.flatten(state) 1112 flat_state = [ops.convert_to_tensor(s) for s in flat_state] 1113 state = nest.pack_sequence_as(structure=state, 1114 flat_sequence=flat_state) 1115 1116 if emit_structure is not None: 1117 flat_emit_structure = nest.flatten(emit_structure) 1118 flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else 1119 array_ops.shape(emit) for emit in flat_emit_structure] 1120 flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure] 1121 else: 1122 emit_structure = cell.output_size 1123 flat_emit_size = nest.flatten(emit_structure) 1124 flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size) 1125 1126 flat_emit_ta = [ 1127 tensor_array_ops.TensorArray( 1128 dtype=dtype_i, 1129 dynamic_size=True, 1130 element_shape=(tensor_shape.TensorShape([const_batch_size]) 1131 .concatenate( 1132 _maybe_tensor_shape_from_tensor(size_i))), 1133 size=0, 1134 name="rnn_output_%d" % i) 1135 for i, (dtype_i, size_i) 1136 in enumerate(zip(flat_emit_dtypes, flat_emit_size))] 1137 emit_ta = nest.pack_sequence_as(structure=emit_structure, 1138 flat_sequence=flat_emit_ta) 1139 flat_zero_emit = [ 1140 array_ops.zeros(_concat(batch_size, size_i), dtype_i) 1141 for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)] 1142 zero_emit = nest.pack_sequence_as(structure=emit_structure, 1143 flat_sequence=flat_zero_emit) 1144 1145 def condition(unused_time, elements_finished, *_): 1146 return math_ops.logical_not(math_ops.reduce_all(elements_finished)) 1147 1148 def body(time, elements_finished, current_input, 1149 emit_ta, state, loop_state): 1150 """Internal while loop body for raw_rnn. 1151 1152 Args: 1153 time: time scalar. 1154 elements_finished: batch-size vector. 1155 current_input: possibly nested tuple of input tensors. 1156 emit_ta: possibly nested tuple of output TensorArrays. 1157 state: possibly nested tuple of state tensors. 1158 loop_state: possibly nested tuple of loop state tensors. 1159 1160 Returns: 1161 Tuple having the same size as Args but with updated values. 
1162 """ 1163 (next_output, cell_state) = cell(current_input, state) 1164 1165 nest.assert_same_structure(state, cell_state) 1166 nest.assert_same_structure(cell.output_size, next_output) 1167 1168 next_time = time + 1 1169 (next_finished, next_input, next_state, emit_output, 1170 next_loop_state) = loop_fn( 1171 next_time, next_output, cell_state, loop_state) 1172 1173 nest.assert_same_structure(state, next_state) 1174 nest.assert_same_structure(current_input, next_input) 1175 nest.assert_same_structure(emit_ta, emit_output) 1176 1177 # If loop_fn returns None for next_loop_state, just reuse the 1178 # previous one. 1179 loop_state = loop_state if next_loop_state is None else next_loop_state 1180 1181 def _copy_some_through(current, candidate): 1182 """Copy some tensors through via array_ops.where.""" 1183 def copy_fn(cur_i, cand_i): 1184 # TensorArray and scalar get passed through. 1185 if isinstance(cur_i, tensor_array_ops.TensorArray): 1186 return cand_i 1187 if cur_i.shape.rank == 0: 1188 return cand_i 1189 # Otherwise propagate the old or the new value. 1190 with ops.colocate_with(cand_i): 1191 return array_ops.where(elements_finished, cur_i, cand_i) 1192 return nest.map_structure(copy_fn, current, candidate) 1193 1194 emit_output = _copy_some_through(zero_emit, emit_output) 1195 next_state = _copy_some_through(state, next_state) 1196 1197 emit_ta = nest.map_structure( 1198 lambda ta, emit: ta.write(time, emit), emit_ta, emit_output) 1199 1200 elements_finished = math_ops.logical_or(elements_finished, next_finished) 1201 1202 return (next_time, elements_finished, next_input, 1203 emit_ta, next_state, loop_state) 1204 1205 returned = control_flow_ops.while_loop( 1206 condition, body, loop_vars=[ 1207 time, elements_finished, next_input, 1208 emit_ta, state, loop_state], 1209 parallel_iterations=parallel_iterations, 1210 swap_memory=swap_memory) 1211 1212 (emit_ta, final_state, final_loop_state) = returned[-3:] 1213 1214 if init_loop_state is None: 1215 final_loop_state = None 1216 1217 return (emit_ta, final_state, final_loop_state) 1218 1219 1220@deprecation.deprecated( 1221 None, "Please use `keras.layers.RNN(cell, unroll=True)`, " 1222 "which is equivalent to this API") 1223@tf_export(v1=["nn.static_rnn"]) 1224def static_rnn(cell, 1225 inputs, 1226 initial_state=None, 1227 dtype=None, 1228 sequence_length=None, 1229 scope=None): 1230 """Creates a recurrent neural network specified by RNNCell `cell`. 1231 1232 The simplest form of RNN network generated is: 1233 1234 ```python 1235 state = cell.zero_state(...) 1236 outputs = [] 1237 for input_ in inputs: 1238 output, state = cell(input_, state) 1239 outputs.append(output) 1240 return (outputs, state) 1241 ``` 1242 However, a few other options are available: 1243 1244 An initial state can be provided. 1245 If the sequence_length vector is provided, dynamic calculation is performed. 1246 This method of calculation does not compute the RNN steps past the maximum 1247 sequence length of the minibatch (thus saving computational time), 1248 and properly propagates the state at an example's sequence length 1249 to the final state output. 1250 1251 The dynamic calculation performed is, at time `t` for batch row `b`, 1252 1253 ```python 1254 (output, state)(b, t) = 1255 (t >= sequence_length(b)) 1256 ? (zeros(cell.output_size), states(b, sequence_length(b) - 1)) 1257 : cell(input(b, t), state(b, t - 1)) 1258 ``` 1259 1260 Args: 1261 cell: An instance of RNNCell. 
1262 inputs: A length T list of inputs, each a `Tensor` of shape 1263 `[batch_size, input_size]`, or a nested tuple of such elements. 1264 initial_state: (optional) An initial state for the RNN. 1265 If `cell.state_size` is an integer, this must be 1266 a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 1267 If `cell.state_size` is a tuple, this should be a tuple of 1268 tensors having shapes `[batch_size, s] for s in cell.state_size`. 1269 dtype: (optional) The data type for the initial state and expected output. 1270 Required if initial_state is not provided or RNN state has a heterogeneous 1271 dtype. 1272 sequence_length: Specifies the length of each sequence in inputs. 1273 An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`. 1274 scope: VariableScope for the created subgraph; defaults to "rnn". 1275 1276 Returns: 1277 A pair (outputs, state) where: 1278 1279 - outputs is a length T list of outputs (one for each input), or a nested 1280 tuple of such elements. 1281 - state is the final state 1282 1283 Raises: 1284 TypeError: If `cell` is not an instance of RNNCell. 1285 ValueError: If `inputs` is `None` or an empty list, or if the input depth 1286 (column size) cannot be inferred from inputs via shape inference. 1287 """ 1288 rnn_cell_impl.assert_like_rnncell("cell", cell) 1289 if not nest.is_sequence(inputs): 1290 raise TypeError("inputs must be a sequence") 1291 if not inputs: 1292 raise ValueError("inputs must not be empty") 1293 1294 outputs = [] 1295 # Create a new scope in which the caching device is either 1296 # determined by the parent scope, or is set to place the cached 1297 # Variable using the same placement as for the rest of the RNN. 1298 with vs.variable_scope(scope or "rnn") as varscope: 1299 if _should_cache(): 1300 if varscope.caching_device is None: 1301 varscope.set_caching_device(lambda op: op.device) 1302 1303 # Obtain the first sequence of the input 1304 first_input = inputs 1305 while nest.is_sequence(first_input): 1306 first_input = first_input[0] 1307 1308 # Temporarily avoid EmbeddingWrapper and seq2seq badness 1309 # TODO(lukaszkaiser): remove EmbeddingWrapper 1310 if first_input.get_shape().rank != 1: 1311 1312 input_shape = first_input.get_shape().with_rank_at_least(2) 1313 fixed_batch_size = input_shape.dims[0] 1314 1315 flat_inputs = nest.flatten(inputs) 1316 for flat_input in flat_inputs: 1317 input_shape = flat_input.get_shape().with_rank_at_least(2) 1318 batch_size, input_size = tensor_shape.dimension_at_index( 1319 input_shape, 0), input_shape[1:] 1320 fixed_batch_size.merge_with(batch_size) 1321 for i, size in enumerate(input_size.dims): 1322 if tensor_shape.dimension_value(size) is None: 1323 raise ValueError( 1324 "Input size (dimension %d of inputs) must be accessible via " 1325 "shape inference, but saw value None." 
% i) 1326 else: 1327 fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0] 1328 1329 if tensor_shape.dimension_value(fixed_batch_size): 1330 batch_size = tensor_shape.dimension_value(fixed_batch_size) 1331 else: 1332 batch_size = array_ops.shape(first_input)[0] 1333 if initial_state is not None: 1334 state = initial_state 1335 else: 1336 if not dtype: 1337 raise ValueError("If no initial_state is provided, " 1338 "dtype must be specified") 1339 if getattr(cell, "get_initial_state", None) is not None: 1340 state = cell.get_initial_state( 1341 inputs=None, batch_size=batch_size, dtype=dtype) 1342 else: 1343 state = cell.zero_state(batch_size, dtype) 1344 1345 if sequence_length is not None: # Prepare variables 1346 sequence_length = ops.convert_to_tensor( 1347 sequence_length, name="sequence_length") 1348 if sequence_length.get_shape().rank not in (None, 1): 1349 raise ValueError( 1350 "sequence_length must be a vector of length batch_size") 1351 1352 def _create_zero_output(output_size): 1353 # convert int to TensorShape if necessary 1354 size = _concat(batch_size, output_size) 1355 output = array_ops.zeros( 1356 array_ops.stack(size), _infer_state_dtype(dtype, state)) 1357 shape = _concat(tensor_shape.dimension_value(fixed_batch_size), 1358 output_size, 1359 static=True) 1360 output.set_shape(tensor_shape.TensorShape(shape)) 1361 return output 1362 1363 output_size = cell.output_size 1364 flat_output_size = nest.flatten(output_size) 1365 flat_zero_output = tuple( 1366 _create_zero_output(size) for size in flat_output_size) 1367 zero_output = nest.pack_sequence_as( 1368 structure=output_size, flat_sequence=flat_zero_output) 1369 1370 sequence_length = math_ops.cast(sequence_length, dtypes.int32) 1371 min_sequence_length = math_ops.reduce_min(sequence_length) 1372 max_sequence_length = math_ops.reduce_max(sequence_length) 1373 1374 # Keras RNN cells only accept state as list, even if it's a single tensor. 1375 is_keras_rnn_cell = _is_keras_rnn_cell(cell) 1376 if is_keras_rnn_cell and not nest.is_sequence(state): 1377 state = [state] 1378 for time, input_ in enumerate(inputs): 1379 if time > 0: 1380 varscope.reuse_variables() 1381 # pylint: disable=cell-var-from-loop 1382 call_cell = lambda: cell(input_, state) 1383 # pylint: enable=cell-var-from-loop 1384 if sequence_length is not None: 1385 (output, state) = _rnn_step( 1386 time=time, 1387 sequence_length=sequence_length, 1388 min_sequence_length=min_sequence_length, 1389 max_sequence_length=max_sequence_length, 1390 zero_output=zero_output, 1391 state=state, 1392 call_cell=call_cell, 1393 state_size=cell.state_size) 1394 else: 1395 (output, state) = call_cell() 1396 outputs.append(output) 1397 # Keras RNN cells only return state as list, even if it's a single tensor. 1398 if is_keras_rnn_cell and len(state) == 1: 1399 state = state[0] 1400 1401 return (outputs, state) 1402 1403 1404@tf_export("nn.static_state_saving_rnn") 1405def static_state_saving_rnn(cell, 1406 inputs, 1407 state_saver, 1408 state_name, 1409 sequence_length=None, 1410 scope=None): 1411 """RNN that accepts a state saver for time-truncated RNN calculation. 1412 1413 Args: 1414 cell: An instance of `RNNCell`. 1415 inputs: A length T list of inputs, each a `Tensor` of shape 1416 `[batch_size, input_size]`. 1417 state_saver: A state saver object with methods `state` and `save_state`. 1418 state_name: Python string or tuple of strings. The name to use with the 1419 state_saver. 


@tf_export("nn.static_state_saving_rnn")
def static_state_saving_rnn(cell,
                            inputs,
                            state_saver,
                            state_name,
                            sequence_length=None,
                            scope=None):
  """RNN that accepts a state saver for time-truncated RNN calculation.

  Args:
    cell: An instance of `RNNCell`.
    inputs: A length T list of inputs, each a `Tensor` of shape
      `[batch_size, input_size]`.
    state_saver: A state saver object with methods `state` and `save_state`.
    state_name: Python string or tuple of strings. The name to use with the
      state_saver. If the cell returns tuples of states (i.e.,
      `cell.state_size` is a tuple) then `state_name` should be a tuple of
      strings having the same length as `cell.state_size`. Otherwise it should
      be a single string.
    sequence_length: (optional) An int32/int64 vector of size `[batch_size]`.
      See the documentation for static_rnn() for more details about
      sequence_length.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:
      outputs is a length T list of outputs (one for each input).
      state is the final state.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If `inputs` is `None` or an empty list, or if the arity and
      type of `state_name` do not match those of `cell.state_size`.
  """
  state_size = cell.state_size
  state_is_tuple = nest.is_sequence(state_size)
  state_name_tuple = nest.is_sequence(state_name)

  if state_is_tuple != state_name_tuple:
    raise ValueError("state_name should be the same type as cell.state_size. "
                     "state_name: %s, cell.state_size: %s" % (str(state_name),
                                                              str(state_size)))

  if state_is_tuple:
    state_name_flat = nest.flatten(state_name)
    state_size_flat = nest.flatten(state_size)

    if len(state_name_flat) != len(state_size_flat):
      raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" %
                       (len(state_name_flat), len(state_size_flat)))

    initial_state = nest.pack_sequence_as(
        structure=state_size,
        flat_sequence=[state_saver.state(s) for s in state_name_flat])
  else:
    initial_state = state_saver.state(state_name)

  (outputs, state) = static_rnn(
      cell,
      inputs,
      initial_state=initial_state,
      sequence_length=sequence_length,
      scope=scope)

  if state_is_tuple:
    flat_state = nest.flatten(state)
    state_name = nest.flatten(state_name)
    save_state = [
        state_saver.save_state(name, substate)
        for name, substate in zip(state_name, flat_state)
    ]
  else:
    save_state = [state_saver.save_state(state_name, state)]

  with ops.control_dependencies(save_state):
    last_output = outputs[-1]
    flat_last_output = nest.flatten(last_output)
    flat_last_output = [
        array_ops.identity(output) for output in flat_last_output
    ]
    outputs[-1] = nest.pack_sequence_as(
        structure=last_output, flat_sequence=flat_last_output)

    if state_is_tuple:
      state = nest.pack_sequence_as(
          structure=state,
          flat_sequence=[array_ops.identity(s) for s in flat_state])
    else:
      state = array_ops.identity(state)

  return (outputs, state)
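

# A minimal sketch (illustrative only) of wiring `static_state_saving_rnn` to
# a toy state saver for truncated backpropagation through time. The
# `_VariableStateSaver` helper below is hypothetical and not a TensorFlow API;
# real input pipelines usually provide a state saver from their batching
# utilities. The local imports keep the sketch self-contained.
def _static_state_saving_rnn_usage_sketch():
  from tensorflow.python.ops import state_ops
  from tensorflow.python.ops import variables

  batch_size, input_size, num_units = 4, 8, 16
  cell = rnn_cell_impl.BasicRNNCell(num_units)

  class _VariableStateSaver(object):
    """Keeps the carried-over RNN state in a non-trainable variable."""

    def __init__(self):
      self._state = variables.Variable(
          array_ops.zeros([batch_size, num_units], dtypes.float32),
          trainable=False,
          name="saved_rnn_state")

    def state(self, name):
      # Only one state is tracked here, so `name` is ignored.
      return self._state

    def save_state(self, name, value):
      return state_ops.assign(self._state, value)

  inputs = [
      array_ops.zeros([batch_size, input_size], dtypes.float32)
      for _ in range(3)
  ]
  outputs, state = static_state_saving_rnn(
      cell, inputs, state_saver=_VariableStateSaver(), state_name="rnn_state")
  # Fetching the returned `state` (or the last output) also runs the
  # `save_state` op, so this chunk's final state seeds the next chunk.
  return outputs, state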


@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
                        "keras.layers.RNN(cell, unroll=True))`, which is "
                        "equivalent to this API")
@tf_export(v1=["nn.static_bidirectional_rnn"])
def static_bidirectional_rnn(cell_fw,
                             cell_bw,
                             inputs,
                             initial_state_fw=None,
                             initial_state_bw=None,
                             dtype=None,
                             sequence_length=None,
                             scope=None):
  """Creates a bidirectional recurrent neural network.

  Similar to the unidirectional case above (static_rnn) but takes input and
  builds independent forward and backward RNNs, with the final forward and
  backward outputs depth-concatenated so that the output has the format
  [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
  the forward and backward cells must match. The initial state for both
  directions is zero by default (but can be set optionally), and no
  intermediate states are ever returned -- the network is fully unrolled for
  the given length(s) of the sequence(s), or completely unrolled if no
  lengths are given.

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: A length T list of inputs, each a tensor of shape
      [batch_size, input_size], or a nested tuple of such elements.
    initial_state_fw: (optional) An initial state for the forward RNN. This
      must be a tensor of appropriate type and shape
      `[batch_size, cell_fw.state_size]`. If `cell_fw.state_size` is a tuple,
      this should be a tuple of tensors having shapes
      `[batch_size, s] for s in cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial state. Required if either
      of the initial states is not provided.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn".

  Returns:
    A tuple (outputs, output_state_fw, output_state_bw) where:
      outputs is a length `T` list of outputs (one for each input), which
        are depth-concatenated forward and backward outputs.
      output_state_fw is the final state of the forward rnn.
      output_state_bw is the final state of the backward rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
    ValueError: If inputs is None or an empty list.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
  if not nest.is_sequence(inputs):
    raise TypeError("inputs must be a sequence")
  if not inputs:
    raise ValueError("inputs must not be empty")

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = static_rnn(
          cell_fw,
          inputs,
          initial_state_fw,
          dtype,
          sequence_length,
          scope=fw_scope)

    # Backward direction
    with vs.variable_scope("bw") as bw_scope:
      reversed_inputs = _reverse_seq(inputs, sequence_length)
      tmp, output_state_bw = static_rnn(
          cell_bw,
          reversed_inputs,
          initial_state_bw,
          dtype,
          sequence_length,
          scope=bw_scope)

  output_bw = _reverse_seq(tmp, sequence_length)
  # Concat each of the forward/backward outputs.
  flat_output_fw = nest.flatten(output_fw)
  flat_output_bw = nest.flatten(output_bw)

  flat_outputs = tuple(
      array_ops.concat([fw, bw], 1)
      for fw, bw in zip(flat_output_fw, flat_output_bw))

  outputs = nest.pack_sequence_as(
      structure=output_fw, flat_sequence=flat_outputs)

  return (outputs, output_state_fw, output_state_bw)
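

# A minimal usage sketch (illustrative only): runs `static_bidirectional_rnn`
# with two independent cells over three timesteps in graph mode. The shapes
# and the choice of `BasicRNNCell` are assumptions made for demonstration.
def _static_bidirectional_rnn_usage_sketch():
  batch_size, input_size, num_units = 4, 8, 16
  cell_fw = rnn_cell_impl.BasicRNNCell(num_units)
  cell_bw = rnn_cell_impl.BasicRNNCell(num_units)
  inputs = [
      array_ops.zeros([batch_size, input_size], dtypes.float32)
      for _ in range(3)
  ]
  outputs, state_fw, state_bw = static_bidirectional_rnn(
      cell_fw, cell_bw, inputs, dtype=dtypes.float32)
  # Each element of `outputs` concatenates the forward and backward outputs
  # along the feature axis: shape [batch_size, num_units + num_units].
  return outputs, state_fw, state_bw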