# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Recurrent computation.

The main interface of this module is Recurrent().
A recurrent computation describes an auto-regressive process, where outputs
of one time step are fed to the input of the next time step.

This module uses:
  theta: the "weights" each RNN uses.
  state0: the initial state of each RNN.
  cell_fn: A python function describing RNN cell. It must have the following
    signature:
      cell_fn: (theta, state0, inputs) -> (state1, extras)
    state1 is the next RNN state, extras are computed by cell_fn
    and the library forwards extras to cell_fn's gradient function.
  cell_grad: A python function describing the backprop gradient function
    for the RNN cell. It must have the following signature:
      cell_grad: (theta, state0, inputs, extras, dstate1) -> (
          dtheta, dstate0, dinputs)
    dstate1 is what the backprop algorithm provides representing
    gradients of state1 w.r.t. the final loss.

In this module, we handle structures of tensors for theta, state0, inputs,
and extras. The structure is an arbitrarily nested python structure, such
as a dictionary of named tuples.

Because the computation is a left-to-right chain, a single in-place accumulator
can be used rather than a stack. Thus a special gradient was written to reduce
unnecessary memory usage.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import inplace_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.inplace_ops import alias_inplace_update
from tensorflow.python.util import nest
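

# An illustrative sketch (not used by the library) of the cell_fn contract
# described in the module docstring; a matching cell_grad sketch appears
# further below. The cell, its parameter names, and its shapes are
# hypothetical: state0['h'] and inputs['x'] are [batch, dims], and theta['w']
# is [dims], broadcast over the batch.
def _ExampleCellFn(theta, state0, inputs):
  """Sketch of a cell_fn: h := h * w + x."""
  state1 = {'h': state0['h'] * theta['w'] + inputs['x']}
  extras = {}  # This simple cell needs nothing extra for its gradient.
  return state1, extras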


def _AssertIsCompatible(a, b):
  """Checks that `a` and `b` are nested structures of the same type."""
  # TODO(drpng): implement.
  del a
  del b


def _Index(struct, index):
  """Returns a structure with `x[index]` for each tensor `x` in the structure.

  Args:
    struct: A structure of tensors.
    index: A scalar integer tensor. Performance is better if `index` is
      on the host memory.

  Returns:
    A structure of tensors congruent to `struct`.
    For each key in `ret`, `ret[key] = struct[key][index]`.
  """
  index = ops.convert_to_tensor(index)
  index.get_shape().assert_has_rank(0)
  return nest.map_structure(lambda x: array_ops.gather(x, index), struct)


def _Update(struct_acc, struct_x, t):
  """Updates t-th row in accumulators.

  Args:
    struct_acc: The accumulators. A structure of tensors.
    struct_x: The new values. A structure of tensors congruent to `struct_acc`.
    t: A scalar integer. Performance is better if `t` is on the device
      memory.

  Returns:
    A structure of tensors. Say, ret is a returned dictionary. Then, for
    each key, we have:
      ret[key] = struct_acc[key];
      ret[key][t, :] = struct_x[key]
  """
  to_skip_update = set()
  acc_lst = nest.flatten(struct_acc)
  x_lst = nest.flatten(struct_x)
  t = math_ops.cast([t], dtypes.int32)  # tf.to_int32 casts on-device tensors.
  lst = []
  for acc, x in zip(acc_lst, x_lst):
    if acc in to_skip_update:
      # Until b/62105730 is fixed, we need to avoid inplace update for tensors
      # of rank 1. Could reshape to handle it, but we don't really need the
      # values applied to these, so just skip their modification.
      lst += [acc]
    else:
      lst += [alias_inplace_update(acc, t, array_ops.expand_dims(x, 0))]
  return nest.pack_sequence_as(struct_acc, lst)


def _SeqLenDim(struct):
  """Returns the 0-th dim size of tensors in a structure of tensors.

  This is the max sequence length according to the shape of the inputs.

  Args:
    struct: A structure of tensors. Every tensor's 0-th dim has the same size.

  Returns:
    A scalar tensor which is the size of 0-th dim of every tensor in struct.
  """
  xs = nest.flatten(struct)
  assert xs
  dim0 = array_ops.shape(xs[0])[0]
  return dim0


def _Flatten(struct):
  """Flattens a structure."""
  return nest.flatten(struct)


def _Pack(elements, struct_template):
  """Packs the list of tensors according to the structure.

  In the event that `elements` should be a scalar, `struct_template` must
  contain exactly one non-trivial element (for instance, `[[], {'x':elt}]`).

  Args:
    elements: Elements to be packed. A list of tensors, or a single tensor.
    struct_template: The container structure in which to pack them.

  Returns:
    A python structure of the same type as `struct_template`, containing
    `elements` as its contained elements.
  """
  if not nest.is_sequence(elements):
    return nest.pack_sequence_as(struct_template, [elements])
  return nest.pack_sequence_as(struct_template, elements)


def _EmptyAcc(slen, struct_template):
  """Creates a set of accumulators for tensors in structure.

  Args:
    slen: The sequence length. A scalar tensor.
    struct_template: A structure of tensors.

  Returns:
    A structure congruent to `struct_template`. Say ret is a returned
    dictionary. Then, `ret.key`, a tensor, has the same dtype as
    `struct_template.key`. The tensor's shape has 1 more dimension
    than the tensor `struct_template.key`. The extra 0-th dimension is of size
    `slen`. E.g., if `slen=10` and `struct_template.key`'s shape is `[3, 5]`,
    then, `ret.key`'s shape is `[10, 3, 5]`.
  """

  def _EmptyAccForTensor(tensor):
    return inplace_ops.empty(
        array_ops.concat([[slen], array_ops.shape(tensor)], axis=0),
        tensor.dtype,
        init=True)

  return nest.map_structure(_EmptyAccForTensor, struct_template)
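

# An illustrative sketch (not used by the library) of how the accumulator
# helpers above fit together: the forward loop preallocates an accumulator per
# state tensor with _EmptyAcc, writes row t in place with _Update, and the
# backward pass later reads row t back with _Index. The structure and shapes
# of `state0` here are hypothetical.
def _ExampleAccumulatorRoundTrip(state0, slen, t):
  """Sketch: stash `state0` into row `t` of a fresh accumulator, read it back."""
  acc_state = _EmptyAcc(slen, state0)  # Each tensor gains a leading `slen` dim.
  acc_state = _Update(acc_state, state0, t)  # acc_state[key][t, :] = state0[key]
  return _Index(acc_state, t)  # A structure congruent to `state0`.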


def _EmptyLike(struct):
  """Creates a set of empty initialized tensors.

  Args:
    struct: A structure of tensors.

  Returns:
    A struct of tensors. Each tensor has the same shape and dtype as
    its corresponding tensor in `struct`. And each tensor is initialized.
  """
  return nest.map_structure(
      lambda x: inplace_ops.empty_like(x, init=True), struct)


def _Add(struct_x, struct_y):
  """Adds tensors in `struct_x` with respective tensors in `struct_y`.

  Args:
    struct_x: A struct of tensors.
    struct_y: A struct of tensors congruent to `struct_x`.

  Returns:
    A struct of tensors. Each element of the returned value
    equals `x + y`, with corresponding values in `struct_x` and `struct_y`.
  """
  list_x = nest.flatten(struct_x)
  list_y = nest.flatten(struct_y)
  z = []
  for x, y in zip(list_x, list_y):
    z += [math_ops.add(x, y)]
  return nest.pack_sequence_as(struct_x, z)


def _Dtypes(struct):
  """Returns all tensors' data types in a list."""
  return [x.dtype for x in nest.flatten(struct)]


def _ConvertNoneGradientToZeros(xs, dxs):
  """Sanitizes dxs so that None becomes zeros appropriately.

  Args:
    xs: A list of tensors.
    dxs: A list of tensors. dxs[i] corresponds to xs[i]'s gradient.

  Returns:
    A structure same as `dxs` with `None` replaced by a zero tensor.
  """
  list_xs = nest.flatten(xs)
  list_dxs = nest.flatten(dxs)

  # If x does not get any backprop-ed gradient, propagate zeros.
  rets = []
  for (x, dx) in zip(list_xs, list_dxs):
    if dx is None:
      rets.append(array_ops.zeros_like(x))
    else:
      rets.append(dx)

  return nest.pack_sequence_as(dxs, rets)
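

# A plain-Python sketch (not used by the library) of the forward/backward
# schedule that the _Recurrent class below implements. Structures of tensors,
# device bookkeeping, and the gradient w.r.t. the accumulated states are
# elided; `cell_fn` and `cell_grad` follow the signatures documented at the
# top of this file.
def _SketchForwardBackward(theta, state0, inputs, cell_fn, cell_grad, d_state1):
  """Runs the forward chain, then backprops from the last step to the first."""
  # Forward: left-to-right, stashing every step's state and extras.
  acc_state = []
  state = state0
  for x_t in inputs:
    state, extras = cell_fn(theta, state, x_t)
    acc_state.append((state, extras))
  # Backward: right-to-left. The d_state1 flowing into step t is the d_state0
  # produced by step t + 1; d_theta is accumulated over all steps.
  d_theta = None
  d_inputs = [None] * len(inputs)
  for t in reversed(range(len(inputs))):
    prev_state = state0 if t == 0 else acc_state[t - 1][0]
    extras_t = acc_state[t][1]
    d_theta_t, d_state1, d_inputs[t] = cell_grad(theta, prev_state, inputs[t],
                                                 extras_t, d_state1)
    d_theta = d_theta_t if d_theta is None else d_theta + d_theta_t
  # On exit, d_state1 holds the gradient w.r.t. the original state0.
  return acc_state, state, d_theta, d_state1, d_inputs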


# All structures are flattened for use internally. This is for simplicity
# and also to use the Defun construct.
# In the forward pass (inference), the computation is structured as follows.
# Forward: [gradient = _Recurrent.Grad]
#   Flatten structures, create accumulators.
#   for t = 0..max_input_length:
#     Defun ForwardLoopBody:
#       Defun Fwd: flatten/pack around cell_fn
#       state1 = Fwd(inputs[t], state0)
#       acc_state += [state1]
#   Pack structures.
# During the backward pass (backpropping the gradient from the last time
# step to the first, through the structure), the computation is structured
# as follows.
# Grad:
#   Flatten structures.
#   Defun Backward:
#     Create accumulated derivatives: d_theta, d_inputs, d_acc_state.
#     Regarding the note at the top of the file, there is only one accumulator
#     for d_theta accumulated over the whole sequence.
#     for t = max_input_length - 1..0:
#       Defun BackwardLoopBody:
#         Retrieve acc_state[t] computed in the forward pass.
#         Defun Bak: flatten/pack around cell_grad.
#         d_state1 is d_state0 from the previous step (i.e., the next time
#         step), plus the incoming gradient d_acc_state[t].
#         d_theta_t, d_state0, d_inputs_t = Bak()
#         d_inputs[dev_t] = d_inputs_t
#         d_theta += d_theta_t
#   Pack structures and return.
class _Recurrent(object):
  """A helper class to construct a recurrent neural net."""

  def __init__(self,
               cell_fn,
               cell_grad,
               theta,
               state0,
               inputs,
               max_input_length,
               extras,
               use_tpu,
               aligned_end=False):
    """RNN helper class.

    Args:
      cell_fn: A python function, which computes:
        state1, extras = cell_fn(theta, state0, inputs[t, :])
      cell_grad: A python function which computes:
        dtheta, dstate0, dinputs[t, :] = cell_grad(
            theta, state0, inputs[t, :], extras, dstate1)
      theta: weights. A structure of tensors.
      state0: initial state. A structure of tensors.
      inputs: inputs. A structure of tensors.
      max_input_length: None, or the maximum effective length of the input over
        all batches. A scalar tensor.
      extras: A structure of tensors. The 2nd return value of every
        invocation of cell_fn is a structure of tensors with matching keys
        and shapes of this `extras`.
      use_tpu: A boolean indicating whether the computation is meant to
        run on a TPU.
      aligned_end: A boolean indicating whether the sequence is aligned at
        the end.
    """
    self._theta = theta
    self._state = state0
    self._inputs = inputs
    self._max_input_length = self._MaybeComputeMaxInputLength(
        inputs, max_input_length)
    self._cell_fn = cell_fn
    self._cell_grad = cell_grad
    self._extras = extras
    self._aligned_end = aligned_end

    # pylint: disable=unbalanced-tuple-unpacking

    # NOTE: TF Function (Fwd, Bak, ForwardLoopBody, BackwardLoopBody,
    # Forward and Backward defined below) simply takes a list of
    # Tensors and returns a list of Tensors. When we pass in a
    # structure (a list of structures of Tensors), we use _Flatten to
    # convert the structure into a list of tensors. Conversely, the
    # following code often uses _Pack to formulate a structure from a
    # list of tensors based on a "template".

    # Wraps cell_fn in a TF Function:
    #   state1 = cell_fn(theta, state0, inputs)
    fwd_sig = [self._theta, self._state, self._inputs]

    compiled = use_tpu
    noinline = not compiled
    dev_t_type = dtypes.int32 if use_tpu else dtypes.int64

    @function.Defun(*_Dtypes(fwd_sig))
    def Fwd(*args):
      (theta, state0, inputs) = _Pack(args, fwd_sig)
      state1, extras = self._cell_fn(theta, state0, inputs)
      assert not function.get_extra_args(), (
          'cell_fn is not pure with extra args: %s.' %
          (function.get_extra_args()))
      _AssertIsCompatible(state1, self._state)
      _AssertIsCompatible(extras, self._extras)
      return _Flatten([state1, extras])

    # Wraps cell_fn in a TF Function as a for-loop's body.
    #
    # The loop state is composed of:
    #   t: The loop variable. Timestep id.
    #   dev_t: The loop variable mirrored on the device.
    #   theta: the recurrent net's weights.
    #   state0: the previous recurrent state.
    #   inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
    #   acc_state: Each timestep's computed new state is also stashed into
    #     acc_state.
    #   acc_extras: Each timestep's computed extras is stashed into acc_extras.
    fwdloop_sig = [
        self._theta, self._state, self._inputs, self._state, self._extras
    ]

    @function.Defun(dtypes.int32, dev_t_type, *_Dtypes(fwdloop_sig))
    def ForwardLoopBody(*args):
      """The body of forward loop."""
      t, dev_t = args[0], args[1]
      (theta, state0, inputs, acc_state, acc_extras) = _Pack(
          args[2:], fwdloop_sig)
      inputs_t = _Index(inputs, t)  # external input at time step t.
      fwd = Fwd(*_Flatten([theta, state0, inputs_t]))
      state1, extras = _Pack(fwd, [self._state, self._extras])
      # Saves state1 and extras in their accumulators.
      acc_state = _Update(acc_state, state1, dev_t)
      acc_extras = _Update(acc_extras, extras, dev_t)

      return [math_ops.add(dev_t, 1)] + _Flatten(
          [theta, state1, inputs, acc_state, acc_extras])
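
    # For illustration only (nothing here is executed): per the NOTE above,
    # the Defun-wrapped functions exchange flat lists of tensors, so call
    # sites flatten structures on the way in and pack them on the way out,
    # exactly as ForwardLoopBody does around Fwd:
    #   flat_ret = Fwd(*_Flatten([theta, state0, inputs_t]))
    #   state1, extras = _Pack(flat_ret, [self._state, self._extras])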

    def Grad(op, *args):
      """The python grad function for the Forward function."""

      # NOTE: tf.gradient backprops None for int32/int64 while zeros
      # for float32/float64. For consistency, we always backprop
      # zeros.
      args = list(args)
      for i, dy in enumerate(args):
        if dy is None:
          args[i] = array_ops.zeros_like(op.outputs[i])
      # TODO(drpng): getting the extra state here?
      op_inputs = [x for x in op.inputs]
      op_struct = [
          self._theta, self._state, self._inputs, self._max_input_length,
          self._extras
      ]
      (theta, state0, inputs, max_input_length, _) = _Pack(op_inputs, op_struct)
      # acc_state and acc_extras are computed by the Forward pass and
      # needed by the Backward pass.
      acc_state, _, acc_extras = _Pack([x for x in op.outputs],
                                       [self._state, self._state, self._extras])

      # Forward computes acc_state, the final state and
      # acc_extras. tf.gradients gives us their gradients w.r.t. the
      # final loss. Because acc_extras is not exposed by Compute(),
      # it has no gradients w.r.t. the final loss (i.e., by
      # construction, it must be zeros).
      d_acc_state, d_state1, _ = _Pack(args,
                                       [self._state, self._state, self._extras])
      return Backward(*_Flatten([
          theta, state0, inputs, max_input_length, acc_state, acc_extras,
          d_acc_state, d_state1
      ]))

    # Forward calls ForwardLoopBody n times. Each time computes one
    # time step of the recurrent net.
    forward_sig = [
        self._theta, self._state, self._inputs, self._max_input_length,
        self._extras
    ]

    @function.Defun(
        *_Dtypes(forward_sig), python_grad_func=Grad, noinline=noinline)
    def Forward(*args):
      """Forward pass of the recurrent net."""
      theta, state0, inputs, max_input_length, extras = _Pack(args, forward_sig)

      slen_dim = _SeqLenDim(inputs)

      # Creates accumulators for state0 and extras.
      acc_state = _EmptyAcc(slen_dim, state0)
      acc_extras = _EmptyAcc(slen_dim, extras)

      t = slen_dim - max_input_length if self._aligned_end else 0
      dev_t = math_ops.cast(t, dtypes.int32) if use_tpu else math_ops.cast(
          t, dtypes.int64)
      run = functional_ops.For(
          start=t,
          limit=slen_dim if self._aligned_end else max_input_length,
          delta=1,
          inputs=[dev_t] + _Flatten(
              [theta, state0, inputs, acc_state, acc_extras]),
          body=ForwardLoopBody,
          rewrite_with_while=compiled)
      _, state1, _, acc_state, acc_extras = _Pack(
          run[1:],
          [self._theta, self._state, self._inputs, self._state, self._extras])

      return _Flatten([acc_state, state1, acc_extras])

    # The per-step backward computes:
    #   d_theta, d_state0, d_inputs = cell_grad(
    #       theta, state0, inputs, extras, d_state1)
    # where d_state1 is the backprop-ed gradient for state1, and
    # extras is computed by the forward step to facilitate the
    # backward step.
    bak_sig = [
        self._theta, self._state, self._inputs, self._extras, self._state
    ]

    @function.Defun(*_Dtypes(bak_sig))
    def Bak(*args):
      """Backward step."""
      (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig)
      (dtheta, dstate0, dinputs) = self._cell_grad(theta, state0, inputs,
                                                   extras, d_state1)
      assert not function.get_extra_args(), (
          'cell_grad is not pure with extra args: %s.' %
          (function.get_extra_args()))
      _AssertIsCompatible(dtheta, self._theta)
      _AssertIsCompatible(dstate0, self._state)
      _AssertIsCompatible(dinputs, self._inputs)
      return _Flatten(
          _ConvertNoneGradientToZeros([theta, state0, inputs],
                                      [dtheta, dstate0, dinputs]))

    # Define defuns used by a functional_ops.If in BackwardLoopBody.
    state_if_sig = [self._state, self._state]

    @function.Defun(*_Dtypes(state_if_sig))
    def ReturnOrigState0(*args):
      """Returns original state0 from inputs."""
      (_, orig_state0) = _Pack(args, state_if_sig)
      return nest.flatten(orig_state0)

    @function.Defun(*_Dtypes(state_if_sig))
    def ReturnAccState(*args):
      """Returns acc_state[t-1] from inputs."""
      (acc_state, _) = _Pack(args, state_if_sig)
      return nest.flatten(acc_state)

    # Wraps cell_grad gradient function in a TF Function as a
    # for-loop's body for the Backward pass.
    #
    # The loop state is composed of:
    #   t: The loop variable. Timestep id.
    #   state0: the initial state for the entire backward loop.
    #   dev_t: The loop variable mirrored on the device.
    #   theta: the recurrent net's weights.
    #   inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
    #   acc_state: Each timestep's computed new state was stashed into
    #     acc_state by the Forward pass.
    #   acc_extras: Each timestep's computed extras was stashed into
    #     acc_extras by the Forward pass.
    #   d_theta: All timesteps' gradients for theta are accumulated (added)
    #     into d_theta.
    #   d_state1: The backprop-ed gradient for the new state computed by
    #     timestep t.
    #   d_inputs: d_inputs[t, :] is populated by the backward time step t.
    #   d_acc_state: The backprop-ed gradient for acc_state.
    bakloop_sig = [
        self._theta, self._state, self._inputs, self._state, self._extras,
        self._theta, self._state, self._inputs, self._state
    ]

    @function.Defun(dtypes.int32, dev_t_type, *_Dtypes(bakloop_sig))
    def BackwardLoopBody(*args):
      """Backward loop body function."""
      t, dev_t = args[0], args[1]
      (theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state1,
       d_inputs, d_acc_state) = _Pack(args[2:], bakloop_sig)

      # The input recurrent state for time step t is previous time step's
      # output, or the original state0 when on time step 0.
      state_from_acc = _Index(acc_state, math_ops.maximum(0, t - 1))
      state0 = functional_ops.If(
          math_ops.equal(t, array_ops.constant(0, dtypes.int32)),
          _Flatten([state_from_acc, orig_state0]), ReturnOrigState0,
          ReturnAccState)
      state0 = nest.pack_sequence_as(orig_state0, state0)

      # The external inputs for time step t.
      inputs_t = _Index(inputs, t)
      # The extras for time step t.
      extras_t = _Index(acc_extras, t)

      d_state1 = _Add(_Index(d_acc_state, t), d_state1)
      (d_theta_t, d_state0, d_inputs_t) = _Pack(
          Bak(*_Flatten([theta, state0, inputs_t, extras_t, d_state1])),
          [self._theta, self._state, self._inputs])
      d_theta = _Add(d_theta, d_theta_t)
      d_inputs = _Update(d_inputs, d_inputs_t, dev_t)
      return [math_ops.subtract(dev_t, 1)] + _Flatten([
          theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state0,
          d_inputs, d_acc_state
      ])
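
    # A worked example with hypothetical sizes: if the padded sequence length
    # is slen_dim=10 and max_input_length=7, the forward loop (in Forward
    # above) visits t=0..6 when aligned_end=False and t=3..9 when
    # aligned_end=True. The backward loop (in Backward below) revisits the
    # same steps in reverse, t=6..0 or t=9..3 respectively, since
    # functional_ops.For treats `limit` as exclusive.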

    # Backward calls BackwardLoopBody n times. Each time computes the backprop
    # for one time step of the recurrent net.
    backward_sig = [
        self._theta, self._state, self._inputs, self._max_input_length,
        self._state, self._extras, self._state, self._state
    ]

    @function.Defun(*_Dtypes(backward_sig), noinline=noinline)
    def Backward(*args):
      """Backward pass for the recurrent net."""
      # theta, state0, inputs are Forward's inputs.
      # acc_state is the accumulated 1st output of Forward.
      # acc_extras is the accumulated 2nd output of Forward.
      # d_acc_state is the gradient for acc_state.
      # d_state1 is the gradient for the final state computed by Forward.
      (theta, state0, inputs, max_input_length, acc_state, acc_extras,
       d_acc_state, d_state1) = _Pack(args, backward_sig)

      # Accumulators for gradients.
      d_theta = _EmptyLike(theta)
      d_inputs = _EmptyLike(inputs)

      slen_dim = _SeqLenDim(inputs)

      # Loop backwards. Note the loop's limit is open (exclusive), so the loop
      # goes through t=0.
      t = slen_dim - 1 if self._aligned_end else max_input_length - 1
      dev_t = math_ops.cast(t, dtypes.int32) if use_tpu else math_ops.cast(
          t, dtypes.int64)
      limit = slen_dim - max_input_length - 1 if self._aligned_end else -1
      run = functional_ops.For(
          start=t,
          limit=limit,
          delta=-1,
          inputs=[dev_t] + _Flatten([
              theta, state0, inputs, acc_state, acc_extras, d_theta, d_state1,
              d_inputs, d_acc_state
          ]),
          body=BackwardLoopBody,
          rewrite_with_while=compiled)

      (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0,
       d_inputs, d_acc_state) = _Pack(run[1:], bakloop_sig)

      d_max_input_length = array_ops.constant(0, dtype=max_input_length.dtype)
      return _Flatten(
          [d_theta, d_state0, d_inputs, d_max_input_length, acc_extras])

    self._forward = Forward

  def _MaybeComputeMaxInputLength(self, inputs, max_input_length):
    if max_input_length is not None:
      return max_input_length
    return math_ops.reduce_max(array_ops.shape(nest.flatten(inputs)[0])[0])

  def Compute(self):
    return _Pack(
        self._forward(*_Flatten([
            self._theta, self._state, self._inputs, self._max_input_length,
            self._extras
        ])), [self._state, self._state, self._extras])[:2]
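

# An illustrative hand-written gradient (not used by the library), matching
# the hypothetical _ExampleCellFn sketch near the top of this file, where
# state1['h'] = state0['h'] * theta['w'] + inputs['x'] and `w` is broadcast
# over the batch dimension. A user-provided cell_grad must follow this same
# signature.
def _ExampleCellGrad(theta, state0, inputs, extras, dstate1):
  """Sketch of a cell_grad matching _ExampleCellFn."""
  del inputs, extras  # This particular cell's gradient needs neither.
  dtheta = {'w': math_ops.reduce_sum(dstate1['h'] * state0['h'], axis=0)}
  dstate0 = {'h': dstate1['h'] * theta['w']}
  dinputs = {'x': dstate1['h']}
  return dtheta, dstate0, dinputs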


def _GetCellGrad(cell_fn, cell_grad):
  """Returns the gradient function for cell_fn.

  Args:
    cell_fn: The recurrent neural net's cell function.
    cell_grad: If not None, cell_fn's gradient function.

  Returns:
    Returns cell_grad if not None. Otherwise, assumes cell_fn is a python
    function representing the recurrent neural net's cell function, i.e.,
      cell_fn: (theta, state0, inputs) -> (state1, extras)
    and returns its default gradient python function, i.e.,
      cell_grad: (theta, state0, inputs, extras, dstate1) -> (
          dtheta, dstate0, dinputs)
  """

  if cell_grad:
    return cell_grad

  def CellGrad(theta, state0, inputs, extras, dstate1):
    """Default gradient function for cell_fn."""
    # NOTE: The default grad function recomputes the forward
    # function and does not take advantage of 'extras' returned by
    # the forward function.
    del extras
    state1, extras = cell_fn(theta, state0, inputs)
    ys = _Flatten([state1])
    xs = _Flatten([theta, state0, inputs])
    grad_ys = _Flatten([dstate1])
    grads = gradients_impl.gradients(ys=ys, xs=xs, grad_ys=grad_ys)
    return _ConvertNoneGradientToZeros([theta, state0, inputs],
                                       _Pack(grads, [theta, state0, inputs]))

  return CellGrad


def _IsSingleTimeStep(inputs, max_input_length):
  """Returns True only if the time dimension of inputs is 1."""
  if not isinstance(max_input_length, ops.Tensor):
    return max_input_length == 1
  for x in nest.flatten(inputs):
    if x.shape.dims is None or x.shape[0].value != 1:
      return False
  return True


def Recurrent(theta,
              state0,
              inputs,
              cell_fn,
              cell_grad=None,
              extras=None,
              max_input_length=None,
              use_tpu=False,
              aligned_end=False):
  """Computes a recurrent neural net.

  Roughly, Recurrent() computes the following:
    state = state0
    for t in inputs' sequence length:
      state = cell_fn(theta, state, inputs[t, :])
      accumulate_state[t, :] = state
    return accumulate_state, state

  theta, state, inputs are all structures of tensors.

  inputs[t, :] means taking a slice out from every tensor in the inputs.

  accumulate_state[t, :] = state means that we stash every tensor in
  'state' into a slice of the corresponding tensor in
  accumulate_state.

  cell_fn is a python callable computing (building up a TensorFlow
  graph) the recurrent neural network's one forward step. Two calls of
  cell_fn must describe two identical computations.

  By construction, Recurrent()'s backward computation does not access
  any intermediate values computed by cell_fn during forward
  computation. We may extend Recurrent() to support that by taking a
  customized backward function of cell_fn.

  Args:
    theta: weights. A structure of tensors.
    state0: initial state. A structure of tensors.
    inputs: inputs. A structure of tensors.
    cell_fn: A python function, which computes:
      state1, extras = cell_fn(theta, state0, inputs[t, :])
    cell_grad: A python function which computes:
      dtheta, dstate0, dinputs[t, :] = cell_grad(
          theta, state0, inputs[t, :], extras, dstate1)
    extras: A structure of tensors. The 2nd return value of every
      invocation of cell_fn is a structure of tensors with matching keys
      and shapes of this `extras`.
    max_input_length: maximum length of effective input. This is used to
      truncate the computation if the inputs have been allocated to a
      larger size. A scalar tensor.
    use_tpu: whether or not we are on TPU.
    aligned_end: A boolean indicating whether the sequence is aligned at
      the end.

  Returns:
    accumulate_state and the final state.
  """
  if cell_grad is None and _IsSingleTimeStep(inputs, max_input_length):
    # The sequence length is statically known to be 1. Hence, we just need to
    # call cell_fn once without putting it into a loop.
    inputs = nest.map_structure(lambda x: array_ops.squeeze(x, axis=0), inputs)
    state1, _ = cell_fn(theta, state0, inputs)
    acc_state = nest.map_structure(lambda x: array_ops.expand_dims(x, axis=0),
                                   state1)
    return acc_state, state1

  # If cell_grad is not given, derives the gradient function from
  # cell_fn.
  cell_grad = _GetCellGrad(cell_fn, cell_grad)

  if extras is None:
    # Derives 'extras' so that we can allocate extras' accumulator.
    _, extras = cell_fn(theta, state0, _Index(inputs, 0))
    extras = nest.map_structure(array_ops.zeros_like, extras)
  else:
    _, actual = cell_fn(theta, state0, _Index(inputs, 0))
    _AssertIsCompatible(extras, actual)

  return _Recurrent(
      cell_fn=cell_fn,
      cell_grad=cell_grad,
      theta=theta,
      state0=state0,
      inputs=inputs,
      max_input_length=max_input_length,
      extras=extras,
      use_tpu=use_tpu,
      aligned_end=aligned_end).Compute()
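

# A minimal usage sketch (not part of the library): wires the hypothetical
# _ExampleCellFn/_ExampleCellGrad sketches above through Recurrent(). Shapes
# and dtypes are made up; when called inside a graph, it returns the
# accumulated states ([time, batch, dims]) and the final state ([batch, dims]).
def _ExampleRecurrentUsage():
  """Sketch: run the hypothetical example cell over 16 time steps."""
  theta = {'w': array_ops.ones([4])}
  state0 = {'h': array_ops.zeros([8, 4])}      # [batch, dims]
  inputs = {'x': array_ops.zeros([16, 8, 4])}  # [time, batch, dims]
  return Recurrent(
      theta,
      state0,
      inputs,
      cell_fn=_ExampleCellFn,
      cell_grad=_ExampleCellGrad)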