# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tf.nn.dynamic_rnn variant, built on the Recurrent class.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.contrib.recurrent.python.ops import recurrent
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest


def _GetDTypesFromStructure(struct):
  dtypes_list = []
  for x in nest.flatten(struct):
    x = ops.convert_to_tensor(x)
    dtypes_list.append(x.dtype)
  return dtypes_list


def _SetShapeFromTemplate(struct, struct_template):
  as_list = nest.flatten(struct)
  template_as_list = nest.flatten(struct_template)
  for element, template in zip(as_list, template_as_list):
    element.set_shape(template.shape)


class _FunctionalRnnCell(object):
  """Wrapper around RNNCell which separates state from computation.

  This class accomplishes the following:
  * Turn the cell's `__call__` function into a pure function. The global
    side effects are separated as `theta`. They are the variables created
    for the weights of the computation.
  * Unless the output is aliased as part of the state, extend the state to
    contain the output so that we store the history in `Recurrent`.
  * Set static shapes as required.
  """

  def __init__(self, rnn_cell, seq_inputs, initial_state):
    assert initial_state is not None

    # TODO(drpng): Dtype needs to be configurable.
    input_dtypes = [seq_inputs.dtype] + _GetDTypesFromStructure(initial_state)
    # See _index.
    like_inputs_t = nest.map_structure(
        lambda x: array_ops.stop_gradient(array_ops.gather(x, 0)), seq_inputs)
    input_structure = (like_inputs_t, initial_state)

    @function.Defun(*input_dtypes)
    def FlatCellStep(*flat_inputs):
      """The flattened version of `rnn_cell`."""
      inputs_t, state0 = nest.pack_sequence_as(input_structure, flat_inputs)
      _SetShapeFromTemplate(state0, initial_state)
      _SetShapeFromTemplate(inputs_t, like_inputs_t)
      outputs_t, state1 = rnn_cell(inputs_t, state0)
      state_list = nest.flatten(state1)
      self._output_shape = outputs_t.shape

      if outputs_t in state_list:
        output_index_in_state = state_list.index(outputs_t)
      else:
        output_index_in_state = None

      if output_index_in_state is None:
        self._prepend_output = True
        self._output_state_idx = 0
        return [outputs_t] + state_list
      else:
        self._output_state_idx = output_index_in_state
        self._prepend_output = False
        # To save memory, we don't return the output separately from the
        # state list, since we know it's the same.
        return state_list

    def _ToPureFunction(func):
      # NOTE: This forces the creation of the function.
      if func.captured_inputs:
        pure_func = copy.copy(func)
        # pylint: disable=protected-access
        pure_func._extra_inputs = []
        return pure_func
      return func

    pure_flat_cell_step = _ToPureFunction(FlatCellStep)

    def CellStep(theta, extended_state0, inputs_t):
      """Performs one time step on structured inputs.

      The purpose of this function is to turn the parameters into flattened
      versions, and to resolve the parameter order difference between
      `Recurrent` and `RNNCell`.

      In the event the cell returns a transformed output that is not aliased
      within its state, the `extended_state0` also contains the output as its
      first element.

      Args:
        theta: Weights required for the computation. A structure of tensors.
        extended_state0: the state0, and possibly the output at the previous
          time step. A structure of tensors.
        inputs_t: the inputs at time t.

      Returns:
        A pair of the next state (inclusive of the output), and an empty list
        (unused `extras`).
        The next state is congruent to state0.
      """
      extended_state0_flat = nest.flatten(extended_state0)
      state0_flat = self.MaybeRemoveOutputFromState(extended_state0_flat)
      full_inputs = [inputs_t] + state0_flat + theta
      # Note that the thetas are additional inputs appended as extra
      # parameters.
      cell_out = pure_flat_cell_step(*full_inputs)
      return cell_out, []

    self._cell_step = CellStep
    self._theta = FlatCellStep.captured_inputs
    self._zero_state = rnn_cell.zero_state
    self._state_template = initial_state
    self._output_size = rnn_cell.output_size

  @property
  def extended_initial_state(self):
    if self._prepend_output:
      return [array_ops.zeros(
          self._output_shape,
          dtype=_GetDTypesFromStructure(self._state_template)[0]),
              self._state_template]
    else:
      # The base case, where the output is just the hidden state.
      return self._state_template

  @property
  def cell_step(self):
    return self._cell_step

  @property
  def theta(self):
    return self._theta

  @property
  def state_template(self):
    return self._state_template

  @property
  def output_shape(self):
    return self._output_shape

  def GetOutputFromState(self, state):
    return nest.flatten(state)[self._output_state_idx]

  def MaybeRemoveOutputFromState(self, flat_state):
    if self._prepend_output:
      return flat_state[1:]
    return flat_state


def _ApplyLengthsToBatch(sequence_lengths, tf_output):
  # TODO(drpng): just use Update so that we don't carry over the gradients?
  """Sets the output to be zero at the end of the sequence."""
  # output is batch major.
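  # Illustrative example (assumed values): with sequence_lengths = [1, 3],
  # max_time = 3 and vector_size = 2, `is_less` below evaluates to
  # [[1., 0., 0.], [1., 1., 1.]]; tiling it over the depth dimension yields
  # `keep_mask`, which zeroes every output step at or beyond each sequence's
  # length.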
  shape = array_ops.shape(tf_output)
  batch_size, max_time, vector_size = shape[0], shape[1], shape[2]
  output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
  output_time = array_ops.reshape(output_time, [batch_size, max_time])
  lengths = array_ops.tile(
      array_ops.reshape(sequence_lengths, [-1, 1]), [1, max_time])
  is_less = math_ops.cast(
      math_ops.less(output_time, lengths), dtype=tf_output.dtype)
  keep_mask = array_ops.tile(
      array_ops.expand_dims(is_less, -1),
      [1, 1, vector_size])
  final_output = keep_mask * tf_output
  return final_output


def _PickFinalStateFromHistory(acc_state, sequence_length):
  """Implements acc_state[sequence_length - 1]."""
  # This will work on all platforms, unlike the regular slice.
  last_value = []
  for state_var in nest.flatten(acc_state):
    # We compute the following with matrix operations:
    # last_var = state_var[sequence_length - 1]
    shape = array_ops.shape(state_var)
    max_time, batch_size = shape[0], shape[1]
    output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
    output_time = array_ops.reshape(output_time, [batch_size, max_time])
    lengths = array_ops.tile(
        array_ops.reshape(sequence_length, [-1, 1]), [1, max_time])
    last_idx = math_ops.cast(
        math_ops.equal(output_time, lengths - 1), dtype=state_var.dtype)
    last_idx = array_ops.transpose(last_idx)
    last_idx_for_bcast = array_ops.expand_dims(last_idx, -1)
    sliced = math_ops.multiply(last_idx_for_bcast, state_var)
    last_var = math_ops.reduce_sum(sliced, 0)
    last_value += [last_var]
  return nest.pack_sequence_as(acc_state, last_value)


def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
                       total_time, inputs_lengths, is_reversed):
  """Post-process output of recurrent.

  This function takes the accumulated extended state and extracts the
  requested state and output.

  When `inputs_lengths` has been set, it extracts the output from the
  accumulated state. It also zeroes out the output past each sequence's
  length.

  When `is_reversed` is true, the output will be reversed in this function.

  It also sets the static shape information.

  Args:
    extended_acc_state: A structure containing the accumulated state at each
      time. It may contain the output at each time as well.
    extended_final_state: A structure containing the final state. It may
      contain the output at the final time.
    func_cell: The functional wrapper around the cell.
    total_time: A scalar integer tensor.
    inputs_lengths: An integer tensor with one entry per input.
    is_reversed: A boolean to indicate if the sequence is reversed.

  Returns:
    A tuple with the outputs at each time, and the final state.
  """
  if inputs_lengths is None or is_reversed:
    flat_final_state = func_cell.MaybeRemoveOutputFromState(
        nest.flatten(extended_final_state))
    tf_state = nest.pack_sequence_as(func_cell.state_template, flat_final_state)
  else:
    # The accumulated state is over the entire sequence, so we pick it
    # out from the acc_state sequence.
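    # `_PickFinalStateFromHistory` emulates `acc_state[length - 1]` for each
    # batch element using dense ops only: it builds a 0/1 mask over the time
    # axis that is 1 exactly at step `length - 1`, multiplies it into the
    # accumulated state, and sums over time. Small example (assumed values):
    # with sequence_length = [2, 1] and max_time = 3, the time-major mask is
    # [[0., 1.], [1., 0.], [0., 0.]].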
    flat_acc_state = func_cell.MaybeRemoveOutputFromState(
        nest.flatten(extended_acc_state))
    acc_state = nest.pack_sequence_as(
        func_cell.state_template, flat_acc_state)
    tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths)

  output_from_state = func_cell.GetOutputFromState(extended_acc_state)
  if is_reversed:
    output_from_state = array_ops.reverse(output_from_state, [0])
  tf_output = array_ops.transpose(output_from_state, [1, 0, 2])
  tf_output.set_shape(
      [func_cell.output_shape[0], total_time, func_cell.output_shape[1]])
  if inputs_lengths is not None:
    # Need to set the outputs to zero.
    tf_output = _ApplyLengthsToBatch(inputs_lengths, tf_output)
  _SetShapeFromTemplate(tf_state, func_cell.state_template)
  return tf_output, tf_state


# pylint: disable=invalid-name
def functional_rnn(cell,
                   inputs,
                   sequence_length=None,
                   initial_state=None,
                   dtype=None,
                   time_major=False,
                   scope=None,
                   use_tpu=False,
                   reverse=False):
  """Same interface as `tf.nn.dynamic_rnn`."""
  with variable_scope.variable_scope(scope or 'rnn'):
    if not time_major:
      inputs = nest.map_structure(
          lambda t: array_ops.transpose(t, [1, 0, 2]), inputs)
    inputs_flat = nest.flatten(inputs)
    batch_size = array_ops.shape(inputs_flat[0])[1]
    if initial_state is None:
      initial_state = cell.zero_state(batch_size, dtype)
    func_cell = _FunctionalRnnCell(cell, inputs, initial_state)
    if sequence_length is not None:
      max_length = math_ops.reduce_max(sequence_length)
    else:
      max_length = None
    if reverse:
      inputs = array_ops.reverse(inputs, [0])
    extended_acc_state, extended_final_state = recurrent.Recurrent(
        theta=func_cell.theta,
        state0=func_cell.extended_initial_state,
        inputs=inputs,
        cell_fn=func_cell.cell_step,
        max_input_length=max_length,
        use_tpu=use_tpu,
        aligned_end=reverse)

    tf_output, tf_state = _PostProcessOutput(
        extended_acc_state,
        extended_final_state,
        func_cell,
        inputs_flat[0].shape[0],
        sequence_length,
        is_reversed=reverse)

    if time_major:
      tf_output = array_ops.transpose(tf_output, [1, 0, 2])
    return tf_output, tf_state


def bidirectional_functional_rnn(cell_fw,
                                 cell_bw,
                                 inputs,
                                 initial_state_fw=None,
                                 initial_state_bw=None,
                                 dtype=None,
                                 sequence_length=None,
                                 time_major=False,
                                 use_tpu=False,
                                 fast_reverse=False,
                                 scope=None):
  """Creates a bidirectional recurrent neural network.

  Performs fully dynamic unrolling of inputs in both directions. Built to be
  API compatible with `tf.nn.bidirectional_dynamic_rnn`, but implemented with
  functional control flow for TPU compatibility.

  Args:
    cell_fw: An instance of `tf.contrib.rnn.RNNCell`.
    cell_bw: An instance of `tf.contrib.rnn.RNNCell`.
    inputs: The RNN inputs. If time_major == False (default), this must be a
      Tensor (or hierarchical structure of Tensors) of shape
      [batch_size, max_time, ...]. If time_major == True, this must be a
      Tensor (or hierarchical structure of Tensors) of shape
      [max_time, batch_size, ...]. The first two dimensions must match across
      all the inputs, but otherwise the ranks and other shape components may
      differ.
    initial_state_fw: An optional initial state for `cell_fw`. Should match
      `cell_fw.zero_state` in structure and type.
    initial_state_bw: An optional initial state for `cell_bw`. Should match
      `cell_bw.zero_state` in structure and type.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_states are not provided or RNN state has a
      heterogeneous dtype.
    sequence_length: An optional int32/int64 vector sized [batch_size]. Used to
      copy-through state and zero-out outputs when past a batch element's
      sequence length. So it's more for correctness than performance.
    time_major: Whether the `inputs` tensor is in "time major" format.
    use_tpu: Whether to enable TPU-compatible operation. If True, does not
      truly reverse `inputs` in the backwards RNN. Once b/69305369 is fixed,
      we can remove this flag.
    fast_reverse: Whether to use fast tf.reverse to replace tf.reverse_sequence.
      This is only possible when either all sequence lengths are the same
      inside the batch, or when the cell function does not change the state
      on padded input.
    scope: An optional scope name for the dynamic RNN.

  Returns:
    outputs: A tuple of `(output_fw, output_bw)`. The output of the forward and
      backward RNN. If time_major == False (default), these will be Tensors
      shaped: [batch_size, max_time, cell.output_size]. If time_major == True,
      these will be Tensors shaped: [max_time, batch_size, cell.output_size].
      Note, if cell.output_size is a (possibly nested) tuple of integers or
      TensorShape objects, then the output for that direction will be a tuple
      having the same structure as cell.output_size, containing Tensors having
      shapes corresponding to the shape data in cell.output_size.
    final_states: A tuple of `(final_state_fw, final_state_bw)`. A Tensor or
      hierarchical structure of Tensors indicating the final cell state in
      each direction. Must have the same structure and shape as
      cell.zero_state.

  Raises:
    ValueError: If `initial_state_fw` is None or `initial_state_bw` is None
      and `dtype` is not provided.
  """
  # Keep this code in sync with tf.nn.dynamic_rnn for compatibility.
  with variable_scope.variable_scope(scope or 'bidirectional_rnn'):
    # Forward direction
    with variable_scope.variable_scope('fw') as fw_scope:
      output_fw, output_state_fw = functional_rnn(
          cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
          initial_state=initial_state_fw, dtype=dtype,
          time_major=time_major, scope=fw_scope, use_tpu=use_tpu)
    # Backward direction
    if not time_major:
      time_dim = 1
      batch_dim = 0
    else:
      time_dim = 0
      batch_dim = 1

    def _reverse(input_, seq_lengths, seq_dim, batch_dim):
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_, seq_lengths=seq_lengths,
            seq_dim=seq_dim, batch_dim=batch_dim)
      else:
        # See b/69305369.
        assert not use_tpu, (
            'Bidirectional with variable sequence lengths unsupported on TPU')
        return array_ops.reverse(input_, axis=[seq_dim])

    with variable_scope.variable_scope('bw') as bw_scope:
      if not fast_reverse:
        inputs = _reverse(
            inputs,
            seq_lengths=sequence_length,
            seq_dim=time_dim,
            batch_dim=batch_dim)
      output_bw, output_state_bw = functional_rnn(
          cell=cell_bw,
          inputs=inputs,
          sequence_length=sequence_length,
          initial_state=initial_state_bw,
          dtype=dtype,
          time_major=time_major,
          scope=bw_scope,
          use_tpu=use_tpu,
          reverse=fast_reverse)

    if not fast_reverse:
      output_bw = _reverse(
          output_bw,
          seq_lengths=sequence_length,
          seq_dim=time_dim,
          batch_dim=batch_dim)

  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)
# pylint: enable=invalid-name
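

# A minimal usage sketch (illustrative only, not exercised by this module).
# It assumes a `BasicLSTMCell` from `tf.contrib.rnn`, batch-major float32
# inputs of shape [batch, time, depth], and int32 sequence lengths; adapt the
# cell, shapes, and dtypes to your own model.
#
#   import tensorflow as tf
#   from tensorflow.contrib.recurrent.python.ops import functional_rnn
#
#   cell = tf.contrib.rnn.BasicLSTMCell(32)
#   inputs = tf.placeholder(tf.float32, [8, 20, 16])
#   lengths = tf.placeholder(tf.int32, [8])
#   outputs, final_state = functional_rnn.functional_rnn(
#       cell, inputs, sequence_length=lengths, dtype=tf.float32)
#
#   # Bidirectional variant; returns ((out_fw, out_bw), (state_fw, state_bw)).
#   outs, states = functional_rnn.bidirectional_functional_rnn(
#       tf.contrib.rnn.BasicLSTMCell(32), tf.contrib.rnn.BasicLSTMCell(32),
#       inputs, sequence_length=lengths, dtype=tf.float32)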