# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Recurrent computation.
16
17The main interface of this module is Recurrent().
18A recurrent computation describes an auto-regressive process, where outputs
19of one time step are fed to the output of the next time step.
20
21This module uses:
22  theta: the "weights" each RNN uses.
23  state0: the initial state of each RNN.
24  cell_fn: A python function describing RNN cell. It must has the following
25    signature:
26         cell_fn: (theta, state0, inputs) -> (state1, extras)
27    state1 is the next RNN state, extras are computed by cell_fn
28    and the library forwards extras to cell_fn's gradient function.
29  cell_grad: A python function describing the backprop gradient function
30    for the RNN cell. It must has the following signature:
31         cell_grad: (theta, state0, inputs, extras, dstate1) -> (
32                  dtheta, dstate0, dinputs)
33    dstate1 is what the backprop algorithm provides representing
34    gradients of state1 w.r.t. the final loss.
35
36In this module, we handle structures of tensors for theta, state0, inputs,
37and extras. The structure is an arbitrarily nested python structure, such
38as a dictionary of named tuples.
39
40Because the computation is a left-to-right chain, a single in-place accumulator
41can be used rather than a stack. Thus a special gradient was written to reduce
42unnecessary memory usage.
43"""
44
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import inplace_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.inplace_ops import alias_inplace_update
from tensorflow.python.util import nest


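# The following is a hypothetical example (not used anywhere in this module)
# of a cell_fn/cell_grad pair with the signatures described in the module
# docstring above. The cell computes the elementwise update
#   state1['h'] = state0['h'] * theta['w'] + inputs['x'],
# where all tensors share the same shape; the structure keys 'w', 'h', and
# 'x' are illustrative assumptions only.
def _ExampleElementwiseCell(theta, state0, inputs):
  """Example cell_fn: returns (state1, extras)."""
  state1 = {'h': state0['h'] * theta['w'] + inputs['x']}
  # Nothing needs to be forwarded to the gradient function.
  extras = {}
  return state1, extras


def _ExampleElementwiseCellGrad(theta, state0, inputs, extras, dstate1):
  """Example cell_grad: returns (dtheta, dstate0, dinputs)."""
  del extras  # The example cell stashes nothing for its backward pass.
  dh1 = dstate1['h']
  dtheta = {'w': dh1 * state0['h']}
  dstate0 = {'h': dh1 * theta['w']}
  dinputs = {'x': dh1}
  return dtheta, dstate0, dinputs

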
def _AssertIsCompatible(a, b):
  """Checks that `a` and `b` are nested structures of the same type."""
  # TODO(drpng): implement.
  del a
  del b


def _Index(struct, index):
  """Returns a structure with `x[index]` for each tensor `x` in the structure.

  Args:
    struct: A structure of tensors.
    index: A scalar integer tensor. Performance is better if `index` is
      on the host memory.

  Returns:
    A structure of tensors congruent to `struct`.
    For each key in `ret`, `ret[key] = struct[key][index]`.
  """
  index = ops.convert_to_tensor(index)
  index.get_shape().assert_has_rank(0)
  return nest.map_structure(lambda x: array_ops.gather(x, index), struct)


def _Update(struct_acc, struct_x, t):
  """Updates the t-th row in the accumulators.

  Args:
    struct_acc: The accumulators. A structure of tensors.
    struct_x: The new values. A structure of tensors congruent to `struct_acc`.
    t: A scalar integer. Performance is better if `t` is on the device
      memory.

  Returns:
    A structure of tensors. Say, ret is a returned dictionary. Then, for
    each key, we have:
      ret[key] = struct_acc[key];
      ret[key][t, :] = struct_x[key]
  """
  to_skip_update = set()
  acc_lst = nest.flatten(struct_acc)
  x_lst = nest.flatten(struct_x)
  t = math_ops.cast([t], dtypes.int32)  # tf.to_int32 casts on-device tensors.
  lst = []
  for acc, x in zip(acc_lst, x_lst):
    if acc in to_skip_update:
      # Until b/62105730 is fixed, we need to avoid inplace update for tensors
      # of rank 1. We could reshape to handle it, but we don't really need the
      # values applied to these, so just skip their modification.
      lst += [acc]
    else:
      lst += [alias_inplace_update(acc, t, array_ops.expand_dims(x, 0))]
  return nest.pack_sequence_as(struct_acc, lst)


def _SeqLenDim(struct):
  """Returns the 0-th dim size of tensors in a structure of tensors.

  This is the max sequence length according to the shape of the inputs.

  Args:
    struct: A structure of tensors. Every tensor's 0-th dim has the same size.

  Returns:
    A scalar tensor which is the size of the 0-th dim of every tensor in
    `struct`.
  """
  xs = nest.flatten(struct)
  assert xs
  dim0 = array_ops.shape(xs[0])[0]
  return dim0


def _Flatten(struct):
  """Flattens a structure."""
  return nest.flatten(struct)

def _Pack(elements, struct_template):
  """Packs the list of tensors according to the structure.

  In the event that `elements` should be a scalar, `struct_template` must
  contain exactly one non-trivial element (for instance, `[[], {'x':elt}]`).

  Args:
    elements: Elements to be packed. A list of tensors, or a single tensor.
    struct_template: The container structure in which to pack them.

  Returns:
    A python structure of the same type as `struct_template`, containing
    `elements` as its contained elements.
  """
  if not nest.is_sequence(elements):
    return nest.pack_sequence_as(struct_template, [elements])
  return nest.pack_sequence_as(struct_template, elements)


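# A hypothetical illustration (not used by this module) of the _Flatten/_Pack
# round trip used throughout the code below: structures are flattened into
# plain lists of tensors before crossing a Defun boundary and re-packed
# against a structure template on the other side.
def _ExampleFlattenPackRoundTrip(struct):
  """Returns a structure equal to `struct`, rebuilt via _Flatten/_Pack."""
  # E.g. {'a': t0, 'b': (t1, t2)} flattens to [t0, t1, t2].
  flat = _Flatten(struct)
  # Packing the flat list against `struct` itself (the template) rebuilds
  # the original nesting.
  return _Pack(flat, struct)

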
def _EmptyAcc(slen, struct_template):
  """Creates a set of accumulators for tensors in structure.

  Args:
    slen: The sequence length. A scalar tensor.
    struct_template: A structure of tensors.

  Returns:
    A structure congruent to `struct_template`. Say ret is a returned
    dictionary. Then, `ret.key`, a tensor, has the same dtype as
    `struct_template.key`. The tensor's shape has 1 more dimension
    than the tensor `struct_template.key`. The extra 0-th dimension is of size
    `slen`. E.g., if `slen=10` and `struct_template.key`'s shape is `[3, 5]`,
    then, `ret.key`'s shape is `[10, 3, 5]`.
  """

  def _EmptyAccForTensor(tensor):
    return inplace_ops.empty(
        array_ops.concat([[slen], array_ops.shape(tensor)], axis=0),
        tensor.dtype,
        init=True)

  return nest.map_structure(_EmptyAccForTensor, struct_template)


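# A hypothetical sketch (not used by this module) of the accumulator pattern
# the forward loop relies on: allocate a [slen, ...] buffer with _EmptyAcc,
# write row t in place with _Update, then read it back with _Index.
def _ExampleAccumulatorRoundTrip(slen, state, t):
  """Writes `state` into row `t` of a fresh accumulator and reads it back."""
  acc = _EmptyAcc(slen, state)  # Each tensor gains a leading dim of size slen.
  acc = _Update(acc, state, t)  # acc[key][t, :] = state[key], updated in place.
  return _Index(acc, t)         # A structure equal to `state`.

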
def _EmptyLike(struct):
  """Creates a set of empty initialized tensors.

  Args:
    struct: A structure of tensors.

  Returns:
    A struct of tensors. Each tensor has the same shape and dtype as
    its corresponding tensor in `struct`. And each tensor is initialized.
  """
  return nest.map_structure(
      lambda x: inplace_ops.empty_like(x, init=True), struct)


def _Add(struct_x, struct_y):
  """Adds tensors in `struct_x` with respective tensors in `struct_y`.

  Args:
    struct_x: A struct of tensors.
    struct_y: A struct of tensors congruent to `struct_x`.

  Returns:
    A struct of tensors. Each element of the returned value
    equals `x + y`, with corresponding values in `struct_x` and `struct_y`.
  """
  list_x = nest.flatten(struct_x)
  list_y = nest.flatten(struct_y)
  z = []
  for x, y in zip(list_x, list_y):
    z += [math_ops.add(x, y)]
  return nest.pack_sequence_as(struct_x, z)


def _Dtypes(struct):
  """Returns all tensors' data types in a list."""
  return [x.dtype for x in nest.flatten(struct)]


def _ConvertNoneGradientToZeros(xs, dxs):
  """Sanitizes `dxs` so that `None` becomes zeros appropriately.

  Args:
    xs: A list of tensors.
    dxs: A list of tensors. dxs[i] corresponds to xs[i]'s gradient.

  Returns:
    A structure the same as `dxs`, with `None` replaced by a zero tensor.
  """
  list_xs = nest.flatten(xs)
  list_dxs = nest.flatten(dxs)

  # If x does not get any backprop-ed gradient, propagate zeros.
  rets = []
  for (x, dx) in zip(list_xs, list_dxs):
    if dx is None:
      rets.append(array_ops.zeros_like(x))
    else:
      rets.append(dx)

  return nest.pack_sequence_as(dxs, rets)

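# A hypothetical illustration (not used by this module) of why the helper
# above exists: when an input does not affect the loss, tf.gradients returns
# None for it, and substituting zeros keeps the gradient structures congruent
# with their inputs.
def _ExampleNoneGradientToZeros(x_used, x_unused, loss):
  """Returns gradients of `loss` w.r.t. both inputs, with None -> zeros."""
  dxs = gradients_impl.gradients(ys=[loss], xs=[x_used, x_unused])
  # dxs[1] is None if `loss` does not depend on `x_unused`.
  return _ConvertNoneGradientToZeros([x_used, x_unused], dxs)
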

# All structures are flattened for use internally. This is for simplicity
# and also to use the Defun construct.
# In the forward pass (inference), the computation is structured as follows.
# Forward: [gradient = _Recurrent.Grad]
#   Flatten structures, create accumulators.
#   for t = 0..max_input_length:
#     Defun ForwardLoopBody:
#       Defun Fwd: flatten/pack around cell_fn
#       state1 = Fwd(inputs[t], state0)
#       acc_state += [state1]
#   Pack structures.
# During the backward pass (backpropping the gradient from the last time
# step to the first, through the structure), the computation is structured
# as follows.
# Grad:
#   Flatten structures.
#   Defun Backward:
#     Create accumulated derivatives: d_theta, d_inputs, d_acc_state.
#     Regarding the note at the top of the file, there is only one accumulator
#     for d_theta accumulated over the whole sequence.
#     for t = max_input_length - 1..0:
#       Defun BackwardLoopBody:
#         Retrieve acc_state[t] computed in the forward pass.
#         Defun Bak: flatten/pack around cell_grad.
#         d_state1 is d_state0 from the previous step (i.e., the next time
#         step).
#         d_state1 += d_acc_state[t]
#         d_theta_t, d_state0, d_inputs_t = Bak()
#         d_inputs[dev_t] = d_inputs_t
#         d_theta += d_theta_t
#   Pack structures and return.
class _Recurrent(object):
  """A helper class to construct a recurrent neural net."""

  def __init__(self,
               cell_fn,
               cell_grad,
               theta,
               state0,
               inputs,
               max_input_length,
               extras,
               use_tpu,
               aligned_end=False):
    """RNN helper class.

    Args:
      cell_fn: A python function, which computes:
         state1, extras = cell_fn(theta, state0, inputs[t, :])
      cell_grad: A python function which computes:
         dtheta, dstate0, dinputs[t, :] = cell_grad(
           theta, state0, inputs[t, :], extras, dstate1)
      theta: weights. A structure of tensors.
      state0: initial state. A structure of tensors.
      inputs: inputs. A structure of tensors.
      max_input_length: None, or the maximum effective length of the input over
        all batches. A scalar tensor.
      extras: A structure of tensors. The 2nd return value of every
        invocation of cell_fn is a structure of tensors whose keys and
        shapes match this `extras`.
      use_tpu: A boolean indicating whether the computation is meant to
        run on a TPU.
      aligned_end: A boolean indicating whether the sequence is aligned at
        the end.
    """
    self._theta = theta
    self._state = state0
    self._inputs = inputs
    self._max_input_length = self._MaybeComputeMaxInputLength(
        inputs, max_input_length)
    self._cell_fn = cell_fn
    self._cell_grad = cell_grad
    self._extras = extras
    self._aligned_end = aligned_end

    # pylint: disable=unbalanced-tuple-unpacking

    # NOTE: Each TF Function (Fwd, Bak, ForwardLoopBody, BackwardLoopBody,
    # Forward and Backward, defined below) simply takes a list of
    # Tensors and returns a list of Tensors. When we pass in a
    # structure (a list of structures of Tensors), we use _Flatten to
    # convert the structure into a list of tensors. Conversely, the
    # following code often uses _Pack to formulate a structure from a
    # list of tensors based on a "template".

    # Wraps cell_fn in a TF Function:
    #    state1 = cell_fn(theta, state0, inputs)
    fwd_sig = [self._theta, self._state, self._inputs]

    compiled = use_tpu
    noinline = not compiled
    dev_t_type = dtypes.int32 if use_tpu else dtypes.int64

    @function.Defun(*_Dtypes(fwd_sig))
    def Fwd(*args):
      (theta, state0, inputs) = _Pack(args, fwd_sig)
      state1, extras = self._cell_fn(theta, state0, inputs)
      assert not function.get_extra_args(), (
          'cell_fn is not pure with extra args: %s.' %
          (function.get_extra_args()))
      _AssertIsCompatible(state1, self._state)
      _AssertIsCompatible(extras, self._extras)
      return _Flatten([state1, extras])

    # Wraps cell_fn in a TF Function as a for-loop's body.
    #
    # The loop state is composed of:
    #  t: The loop variable. Timestep id.
    #  dev_t: The loop variable mirrored on the device.
    #  theta: the recurrent net's weights.
    #  state0: the previous recurrent state.
    #  inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
    #  acc_state: Each timestep's computed new state is also stashed into
    #    acc_state.
    #  acc_extras: Each timestep's computed extras is stashed into acc_extras.
    fwdloop_sig = [
        self._theta, self._state, self._inputs, self._state, self._extras
    ]

    @function.Defun(dtypes.int32, dev_t_type, *_Dtypes(fwdloop_sig))
    def ForwardLoopBody(*args):
      """The body of the forward loop."""
      t, dev_t = args[0], args[1]
      (theta, state0, inputs, acc_state, acc_extras) = _Pack(
          args[2:], fwdloop_sig)
      inputs_t = _Index(inputs, t)  # External input at time step t.
      fwd = Fwd(*_Flatten([theta, state0, inputs_t]))
      state1, extras = _Pack(fwd, [self._state, self._extras])
      # Saves state1 and extras in their accumulators.
      acc_state = _Update(acc_state, state1, dev_t)
      acc_extras = _Update(acc_extras, extras, dev_t)

      return [math_ops.add(dev_t, 1)] + _Flatten(
          [theta, state1, inputs, acc_state, acc_extras])

    def Grad(op, *args):
      """The python grad function for the Forward function."""

      # NOTE: tf.gradients backprops None for int32/int64 while zeros
      # for float32/float64. For consistency, we always backprop
      # zeros.
      args = list(args)
      for i, dy in enumerate(args):
        if dy is None:
          args[i] = array_ops.zeros_like(op.outputs[i])
      # TODO(drpng): getting the extra state here?
      op_inputs = [x for x in op.inputs]
      op_struct = [
          self._theta, self._state, self._inputs, self._max_input_length,
          self._extras
      ]
      (theta, state0, inputs, max_input_length, _) = _Pack(op_inputs, op_struct)
      # acc_state and acc_extras are computed by the Forward pass and
      # needed by the Backward pass.
      acc_state, _, acc_extras = _Pack([x for x in op.outputs],
                                       [self._state, self._state, self._extras])

      # Forward computes acc_state, the final state, and
      # acc_extras. tf.gradients gives us their gradients w.r.t. the
      # final loss. Because acc_extras is not exposed by Compute(),
      # it has no gradient w.r.t. the final loss (i.e., by
      # construction, it must be zeros).
      d_acc_state, d_state1, _ = _Pack(args,
                                       [self._state, self._state, self._extras])
      return Backward(*_Flatten([
          theta, state0, inputs, max_input_length, acc_state, acc_extras,
          d_acc_state, d_state1
      ]))

    # Forward calls ForwardLoopBody n times. Each call computes one
    # time step of the recurrent net.
    forward_sig = [
        self._theta, self._state, self._inputs, self._max_input_length,
        self._extras
    ]

    @function.Defun(
        *_Dtypes(forward_sig), python_grad_func=Grad, noinline=noinline)
    def Forward(*args):
      """Forward pass of the recurrent net."""
      theta, state0, inputs, max_input_length, extras = _Pack(args, forward_sig)

      slen_dim = _SeqLenDim(inputs)

      # Creates accumulators for state0 and extras.
      acc_state = _EmptyAcc(slen_dim, state0)
      acc_extras = _EmptyAcc(slen_dim, extras)

      t = slen_dim - max_input_length if self._aligned_end else 0
      dev_t = math_ops.cast(t, dtypes.int32) if use_tpu else math_ops.cast(
          t, dtypes.int64)
      run = functional_ops.For(
          start=t,
          limit=slen_dim if self._aligned_end else max_input_length,
          delta=1,
          inputs=[dev_t] + _Flatten(
              [theta, state0, inputs, acc_state, acc_extras]),
          body=ForwardLoopBody,
          rewrite_with_while=compiled)
      _, state1, _, acc_state, acc_extras = _Pack(
          run[1:],
          [self._theta, self._state, self._inputs, self._state, self._extras])

      return _Flatten([acc_state, state1, acc_extras])

    # The per-step backward computes:
    #    d_theta, d_state0, d_inputs = cell_grad(
    #        theta, state0, inputs, extras, d_state1)
    # where d_state1 is the backprop-ed gradient for state1, and
    # extras is computed by the forward step to facilitate the
    # backward step.
    bak_sig = [
        self._theta, self._state, self._inputs, self._extras, self._state
    ]

    @function.Defun(*_Dtypes(bak_sig))
    def Bak(*args):
      """Backward step."""
      (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig)
      (dtheta, dstate0, dinputs) = self._cell_grad(theta, state0, inputs,
                                                   extras, d_state1)
      assert not function.get_extra_args(), (
          'cell_grad is not pure with extra args: %s.' %
          (function.get_extra_args()))
      _AssertIsCompatible(dtheta, self._theta)
      _AssertIsCompatible(dstate0, self._state)
      _AssertIsCompatible(dinputs, self._inputs)
      return _Flatten(
          _ConvertNoneGradientToZeros([theta, state0, inputs],
                                      [dtheta, dstate0, dinputs]))

    # Define defuns used by a functional_ops.If in BackwardLoopBody.
    state_if_sig = [self._state, self._state]

    @function.Defun(*_Dtypes(state_if_sig))
    def ReturnOrigState0(*args):
      """Returns original state0 from inputs."""
      (_, orig_state0) = _Pack(args, state_if_sig)
      return nest.flatten(orig_state0)

    @function.Defun(*_Dtypes(state_if_sig))
    def ReturnAccState(*args):
      """Returns acc_state[t-1] from inputs."""
      (acc_state, _) = _Pack(args, state_if_sig)
      return nest.flatten(acc_state)

    # Wraps the cell_grad gradient function in a TF Function as a
    # for-loop's body for the Backward pass.
    #
    # The loop state is composed of:
    #  t: The loop variable. Timestep id.
    #  state0: the initial state for the entire backward loop.
    #  dev_t: The loop variable mirrored on the device.
    #  theta: the recurrent net's weights.
    #  inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
    #  acc_state: Each timestep's computed new state was stashed into
    #    acc_state by the Forward pass.
    #  acc_extras: Each timestep's computed extras was stashed into
    #    acc_extras by the Forward pass.
    #  d_theta: Every timestep's gradient for theta is accumulated (added) into
    #      d_theta.
    #  d_state1: The backprop-ed gradient for the new state computed by
    #      timestep t.
    #  d_inputs: d_inputs[t, :] is populated by the backward time step t.
    #  d_acc_state: The backprop-ed gradient for acc_state.
    bakloop_sig = [
        self._theta, self._state, self._inputs, self._state, self._extras,
        self._theta, self._state, self._inputs, self._state
    ]

    @function.Defun(dtypes.int32, dev_t_type, *_Dtypes(bakloop_sig))
    def BackwardLoopBody(*args):
      """Backward loop body function."""
      t, dev_t = args[0], args[1]
      (theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state1,
       d_inputs, d_acc_state) = _Pack(args[2:], bakloop_sig)

      # The input recurrent state for time step t is the previous time step's
      # output, or the original state0 at time step 0.
      state_from_acc = _Index(acc_state, math_ops.maximum(0, t - 1))
      state0 = functional_ops.If(
          math_ops.equal(t, array_ops.constant(0, dtypes.int32)),
          _Flatten([state_from_acc, orig_state0]), ReturnOrigState0,
          ReturnAccState)
      state0 = nest.pack_sequence_as(orig_state0, state0)

      # The external inputs for time step t.
      inputs_t = _Index(inputs, t)
      # The extras for time step t.
      extras_t = _Index(acc_extras, t)

      d_state1 = _Add(_Index(d_acc_state, t), d_state1)
      (d_theta_t, d_state0, d_inputs_t) = _Pack(
          Bak(*_Flatten([theta, state0, inputs_t, extras_t, d_state1])),
          [self._theta, self._state, self._inputs])
      d_theta = _Add(d_theta, d_theta_t)
      d_inputs = _Update(d_inputs, d_inputs_t, dev_t)
      return [math_ops.subtract(dev_t, 1)] + _Flatten([
          theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state0,
          d_inputs, d_acc_state
      ])

    # Backward calls BackwardLoopBody n times. Each call computes the backprop
    # for one time step of the recurrent net.
    backward_sig = [
        self._theta, self._state, self._inputs, self._max_input_length,
        self._state, self._extras, self._state, self._state
    ]

    @function.Defun(*_Dtypes(backward_sig), noinline=noinline)
    def Backward(*args):
      """Backward pass for the recurrent net."""
      # theta, state0, inputs are Forward's inputs.
      # acc_state is the accumulated 1st output of Forward.
      # acc_extras is the accumulated 2nd output of Forward.
      # d_acc_state is the gradient for acc_state.
      # d_state1 is the gradient for the final state computed by Forward.
      (theta, state0, inputs, max_input_length, acc_state, acc_extras,
       d_acc_state, d_state1) = _Pack(args, backward_sig)

      # Accumulators for gradients.
      d_theta = _EmptyLike(theta)
      d_inputs = _EmptyLike(inputs)

      slen_dim = _SeqLenDim(inputs)

      # Loop backwards. Note the loop's limit is open-ended, so it goes through
      # t=0.
      t = slen_dim - 1 if self._aligned_end else max_input_length - 1
      dev_t = math_ops.cast(t, dtypes.int32) if use_tpu else math_ops.cast(
          t, dtypes.int64)
      limit = slen_dim - max_input_length - 1 if self._aligned_end else -1
      run = functional_ops.For(
          start=t,
          limit=limit,
          delta=-1,
          inputs=[dev_t] + _Flatten([
              theta, state0, inputs, acc_state, acc_extras, d_theta, d_state1,
              d_inputs, d_acc_state
          ]),
          body=BackwardLoopBody,
          rewrite_with_while=compiled)

      (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0,
       d_inputs, d_acc_state) = _Pack(run[1:], bakloop_sig)

      d_max_input_length = array_ops.constant(0, dtype=max_input_length.dtype)
      return _Flatten(
          [d_theta, d_state0, d_inputs, d_max_input_length, acc_extras])

    self._forward = Forward

  def _MaybeComputeMaxInputLength(self, inputs, max_input_length):
    if max_input_length is not None:
      return max_input_length
    return math_ops.reduce_max(array_ops.shape(nest.flatten(inputs)[0])[0])

  def Compute(self):
    return _Pack(
        self._forward(*_Flatten([
            self._theta, self._state, self._inputs, self._max_input_length,
            self._extras
        ])), [self._state, self._state, self._extras])[:2]


def _GetCellGrad(cell_fn, cell_grad):
  """Returns the gradient function for cell_fn.

  Args:
    cell_fn: The recurrent neural net's cell function.
    cell_grad: If not None, cell_fn's gradient function.

  Returns:
    cell_grad if not None. Otherwise, assumes cell_fn is a python
    function representing the recurrent neural net's cell function, i.e.,
      cell_fn: (theta, state0, inputs) -> (state1, extras)
    and returns its default gradient python function, i.e.,
      cell_grad: (theta, state0, inputs, extras, dstate1) -> (
                  dtheta, dstate0, dinputs)
  """

  if cell_grad:
    return cell_grad

  def CellGrad(theta, state0, inputs, extras, dstate1):
    """Default gradient function for cell_fn."""
    # NOTE: The default grad function recomputes the forward
    # function and does not take advantage of 'extras' returned by
    # the forward function.
    del extras
    state1, extras = cell_fn(theta, state0, inputs)
    ys = _Flatten([state1])
    xs = _Flatten([theta, state0, inputs])
    grad_ys = _Flatten([dstate1])
    grads = gradients_impl.gradients(ys=ys, xs=xs, grad_ys=grad_ys)
    return _ConvertNoneGradientToZeros([theta, state0, inputs],
                                       _Pack(grads, [theta, state0, inputs]))

  return CellGrad


def _IsSingleTimeStep(inputs, max_input_length):
  """Returns True only if the time dimension of inputs is 1."""
  if not isinstance(max_input_length, ops.Tensor):
    return max_input_length == 1
  for x in nest.flatten(inputs):
    if x.shape.dims is None or x.shape[0].value != 1:
      return False
  return True


def Recurrent(theta,
              state0,
              inputs,
              cell_fn,
              cell_grad=None,
              extras=None,
              max_input_length=None,
              use_tpu=False,
              aligned_end=False):
  """Computes a recurrent neural net.

  Roughly, Recurrent() computes the following:
    state = state0
    for t in inputs' sequence length:
      state = cell_fn(theta, state, inputs[t, :])
      accumulate_state[t, :] = state
    return accumulate_state, state

  theta, state, inputs are all structures of tensors.

  inputs[t, :] means taking a slice out from every tensor in the inputs.

  accumulate_state[t, :] = state means that we stash every tensor in
  'state' into a slice of the corresponding tensor in
  accumulate_state.

  cell_fn is a python callable that computes (i.e., builds up a TensorFlow
  graph for) one forward step of the recurrent neural network. Two calls of
  cell_fn must describe two identical computations.

  By construction, Recurrent()'s backward computation does not access
  any intermediate values computed by cell_fn during the forward
  computation. We may extend Recurrent() to support that by taking a
  customized backward function of cell_fn.

  Args:
    theta: weights. A structure of tensors.
    state0: initial state. A structure of tensors.
    inputs: inputs. A structure of tensors.
    cell_fn: A python function, which computes:
      state1, extras = cell_fn(theta, state0, inputs[t, :])
    cell_grad: A python function which computes:
      dtheta, dstate0, dinputs[t, :] = cell_grad(
        theta, state0, inputs[t, :], extras, dstate1)
    extras: A structure of tensors. The 2nd return value of every
      invocation of cell_fn is a structure of tensors whose keys and
      shapes match this `extras`.
    max_input_length: maximum length of effective input. This is used to
      truncate the computation if the inputs have been allocated to a
      larger size. A scalar tensor.
    use_tpu: whether or not we are on TPU.
    aligned_end: A boolean indicating whether the sequence is aligned at
      the end.

  Returns:
    accumulate_state and the final state.
  """
  if cell_grad is None and _IsSingleTimeStep(inputs, max_input_length):
    # The sequence length is statically known to be 1. Hence, we just need to
    # call cell_fn once without putting it into a loop.
    inputs = nest.map_structure(lambda x: array_ops.squeeze(x, axis=0), inputs)
    state1, _ = cell_fn(theta, state0, inputs)
    acc_state = nest.map_structure(lambda x: array_ops.expand_dims(x, axis=0),
                                   state1)
    return acc_state, state1

  # If cell_grad is not given, derives the gradient function from
  # cell_fn.
  cell_grad = _GetCellGrad(cell_fn, cell_grad)

  if extras is None:
    # Derives 'extras' so that we can allocate extras' accumulator.
    _, extras = cell_fn(theta, state0, _Index(inputs, 0))
    extras = nest.map_structure(array_ops.zeros_like, extras)
  else:
    _, actual = cell_fn(theta, state0, _Index(inputs, 0))
    _AssertIsCompatible(extras, actual)

  return _Recurrent(
      cell_fn=cell_fn,
      cell_grad=cell_grad,
      theta=theta,
      state0=state0,
      inputs=inputs,
      max_input_length=max_input_length,
      extras=extras,
      use_tpu=use_tpu,
      aligned_end=aligned_end).Compute()

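
# A hypothetical end-to-end usage sketch (not part of this module): it unrolls
# a trivial elementwise cell over 5 time steps. The shapes
# ([time, batch, depth] = [5, 2, 3]) and structure keys are illustrative
# assumptions only.
def _ExampleRecurrentUsage():
  """Builds a small graph that runs Recurrent() with a toy cell."""

  def Cell(theta, state0, inputs):
    # state1['h'] = state0['h'] * theta['w'] + inputs['x'].
    state1 = {'h': state0['h'] * theta['w'] + inputs['x']}
    return state1, {}  # No extras are needed by the gradient.

  theta = {'w': array_ops.ones([2, 3])}
  state0 = {'h': array_ops.zeros([2, 3])}
  inputs = {'x': array_ops.ones([5, 2, 3])}  # [time, batch, depth].
  # acc_state['h'] has shape [5, 2, 3]; state1['h'] has shape [2, 3].
  acc_state, state1 = Recurrent(theta, state0, inputs, cell_fn=Cell)
  return acc_state, state1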