# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for managing state of v1 control flow for computing gradients."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import resource_variable_ops

# pylint: disable=protected-access


def _GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
  """Calculate a max_size for use by stack ops inside an XLA while_loop.

  Args:
    value: The value inside the while_loop forward context.  Used for printing
      error messages.
    while_ctxt: The forward context inside which value resides.  This does not
      always match the value's immediate context, as `value` may be inside e.g.
      a cond context inside the while_loop.

  Returns:
    A tensor containing the `max_size` to feed to a Stack initializer.

  Raises:
    ValueError: If `value` is nested inside a `while_loop` that either
      lacks a `maximum_iterations` parameter, or the `maximum_iterations`
      parameter:

        - is inside a `while_loop` that is a parent of the calling context, and
        - cannot be evaluated at graph build time to a constant.
  """
  value_name = value.name
  # curr_ctxt is the context that tf.gradients was called in.
  curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access

  curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
  max_size = constant_op.constant(1)

  # Loop through all containing while contexts between value and the
  # current context, multiplying together each context's
  # max_iterations to get the maximum stack size.
  while while_ctxt not in (None, curr_ctxt):
    max_iter = while_ctxt.maximum_iterations
    if max_iter is None:
      raise ValueError(
          "Cannot create a gradient accumulator for tensor '%s' inside "
          "XLA while_loop because maximum_iterations was not passed to "
          "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))

    # pylint: disable=protected-access
    max_iter_ctxt = max_iter.op._get_control_flow_context()
    # pylint: enable=protected-access

    # If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use.
    if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
      max_size *= max_iter
    else:
      # We cannot use max_iter because it's defined in a nested while
      # or cond context, so will fail if we try to use it as input to
      # any ops in curr_ctxt (e.g. max_size or the final accumulator
      # stack). Attempt to get a constant value out to use instead.
      const_max_iter = tensor_util.constant_value(max_iter)
      if const_max_iter is None:
        raise ValueError(
            "Cannot create a gradient accumulator for tensor '%s' inside XLA "
            "while_loop. maximum_iterations tensor '%s' for while_loop context "
            "'%s' must be statically known (e.g. a constant value or known "
            "shape dimension), or be defined at or outside the while loop "
            "context '%s' (currently defined in '%s')." %
            (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
             max_iter_ctxt.name))
      max_size *= const_max_iter

    # Find the next outer WhileContext (or stop if we reach the
    # tf.gradient's context).
    while_ctxt = util.GetContainingWhileContext(
        while_ctxt.outer_context, stop_ctxt=curr_ctxt)

  return max_size

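# Illustrative sketch (not part of the library API; loop bounds are made up):
# for a value two while loops deep inside an XLA context, the accumulator
# stack must hold one entry per combined iteration, so the bounds multiply:
#
#   outer loop: tf.while_loop(..., maximum_iterations=10)
#   inner loop: tf.while_loop(..., maximum_iterations=20)   # nested in outer
#
#   _GetMaxSizeFromNestedMaximumIterations(value, inner_while_ctxt)
#   # walks inner -> outer, computing 20 * 10, and returns a tensor holding
#   # 200 to use as the stack's max_size.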

class _GradLoopState:
  """The state used for constructing the gradient graph for a while loop.

  We create a _GradLoopState for each while loop in the forward graph and
  its corresponding while loop in backprop. This gives us access to both
  the forward and the backprop WhileContexts.

  During construction of the gradient graph, whenever we detect a forward
  value that is needed for backprop, we create a history accumulator and
  add it to `history_map`. Whenever we backprop a loop switch op
  (in _SwitchGrad), we add the grad merge op to `switch_map`.
  """

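  # Rough sketch of the pairing (illustrative only; names are made up): for
  #
  #   r = tf.while_loop(outer_cond, outer_body, [x])   # outer_body itself
  #                                                     # calls tf.while_loop
  #   g = tf.gradients(r, x)
  #
  # gradients() creates one _GradLoopState per forward WhileContext, and the
  # inner state's `outer_grad_state` points at the outer loop's state, so the
  # forward and backprop contexts stay paired at every nesting level.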
  def __init__(self, forward_ctxt, outer_grad_state):
    # The grad loop state for the outer while loop.
    self._outer_grad_state = None

    # The while loop context for forward.
    self._forward_context = None

    # The loop counter added by AddForwardLoopCounter. It is the value
    # of the loop counter for the next iteration.
    self._forward_index = None

    # A sync op for forward.
    self._forward_sync = None

    # The while loop context for backprop.
    self._grad_context = None

    # The loop counter added by AddBackpropLoopCounter. It is the value
    # of the loop counter for the current iteration.
    self._grad_index = None

    # A sync op for backprop.
    self._grad_sync = None

    # Information needed by backprop.
    self._history_map = {}
    self._switch_map = {}
    self._unused_exits = []
    self._deferred_exits = []
    self._forward_loop_exits = list(forward_ctxt.loop_exits)
    self._pending_exits_count = len(forward_ctxt.loop_exits)

    self._outer_grad_state = outer_grad_state
    if outer_grad_state:
      outer_forward_ctxt = outer_grad_state.forward_context
    else:
      if not hasattr(forward_ctxt, "outer_context"):
        raise ValueError("Failed to call gradients on a while loop without "
                         "properly serializing graph via MetaGraphDef")
      outer_forward_ctxt = forward_ctxt.outer_context

    # Add the forward loop counter.
    with forward_ctxt._graph.as_default():  # pylint: disable=protected-access
      if outer_forward_ctxt:
        outer_forward_ctxt.Enter()
      cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
      if outer_forward_ctxt:
        outer_forward_ctxt.Exit()
    self._forward_context = forward_ctxt
    self._forward_index = forward_index

    # Add the backprop WhileContext, and the backprop loop counter.
    if outer_grad_state:
      # This is a nested loop. Remember the iteration counts for each
      # execution of this inner loop.
      outer_forward_ctxt.AddName(cnt.name)
      history_cnt = outer_grad_state.AddForwardAccumulator(cnt)

      outer_grad_ctxt = outer_grad_state.grad_context
      outer_grad_ctxt.Enter()
      self._grad_context = control_flow_ops.WhileContext(
          maximum_iterations=forward_ctxt.maximum_iterations,
          parallel_iterations=forward_ctxt.parallel_iterations,
          back_prop=forward_ctxt.back_prop,
          swap_memory=forward_ctxt.swap_memory,
          name=forward_ctxt.name,
          grad_state=self)
      real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
      self._grad_index = self._grad_context.AddBackpropLoopCounter(
          real_cnt, outer_grad_state)
      outer_grad_ctxt.Exit()
    else:
      if outer_forward_ctxt:
        outer_forward_ctxt.Enter()
      self._grad_context = control_flow_ops.WhileContext(
          maximum_iterations=forward_ctxt.maximum_iterations,
          parallel_iterations=forward_ctxt.parallel_iterations,
          back_prop=forward_ctxt.back_prop,
          swap_memory=forward_ctxt.swap_memory,
          name=forward_ctxt.name,
          grad_state=self)
      self._grad_index = self._grad_context.AddBackpropLoopCounter(
          cnt, outer_grad_state)
      if outer_forward_ctxt:
        outer_forward_ctxt.Exit()

  @property
  def outer_grad_state(self):
    """The grad loop state for outer loop."""
    return self._outer_grad_state

  @property
  def forward_context(self):
    """The while loop context for forward."""
    return self._forward_context

  @property
  def forward_index(self):
    """The loop index of forward loop."""
    return self._forward_index

  @property
  def forward_sync(self):
    """A control trigger node for synchronization in the forward loop.

    One main use is to keep the push ops of a stack executed in the
    iteration order.
    """
    if self._forward_sync is None:
      with ops.control_dependencies(None):
        self._forward_sync = control_flow_ops.control_trigger(name="f_sync")
      self._forward_sync._set_control_flow_context(self._forward_context)
      self._forward_index.op._add_control_input(self._forward_sync)
    return self._forward_sync

  @property
  def grad_context(self):
    """The corresponding WhileContext for gradient."""
    return self._grad_context

  @property
  def grad_index(self):
    """The loop index of backprop loop."""
    return self._grad_index

  @property
  def grad_sync(self):
    """A control trigger node for synchronization in the grad loop.

    One main use is to keep the pop ops of a stack executed in the
    iteration order.
    """
    if self._grad_sync is None:
      with ops.control_dependencies(None):
        self._grad_sync = control_flow_ops.control_trigger(name="b_sync")
      self._grad_sync._set_control_flow_context(self._grad_context)
      self._grad_index.op._add_control_input(self._grad_sync)
      if self._grad_context.outer_context:
        self._grad_context.outer_context.AddInnerOp(self._grad_sync)
    return self._grad_sync

  @property
  def history_map(self):
    """The map that records all the tensors needed for backprop."""
    return self._history_map

  @property
  def switch_map(self):
    """The map that records all the Switch ops for the while loop."""
    return self._switch_map

  @property
  def unused_exits(self):
    """The list of "unused" exits."""
    return self._unused_exits

  @property
  def deferred_exits(self):
    """The list of "deferred" exits."""
    return self._deferred_exits

  @property
  def forward_loop_exits(self):
    """The list of exits of the forward loop."""
    return self._forward_loop_exits

  @property
  def pending_exits_count(self):
    """The number of exits we expect to see but haven't."""
    return self._pending_exits_count

  @pending_exits_count.setter
  def pending_exits_count(self, cnt):
    """Set the pending count to cnt."""
    self._pending_exits_count = cnt

  def AddForwardAccumulator(self, value, dead_branch=False):
    """Add an accumulator for each forward tensor that is needed in backprop.

    This is added to the forward loop the first time a tensor from the
    forward loop is used by the backprop gradient computation loop.
    We create an accumulator that accumulates the value of the tensor at each
    iteration. Called in the control flow context where gradients() is called.

    The pseudocode is:
    ```
      acc = stack();
      while (_pivot) {
        acc = stack_push(acc, value);
      }
    ```

    We make sure that the stack push op in one iteration is executed before
    the next iteration. This is achieved by adding a control edge from
    `forward_index.op.inputs[0].op` to the push op, and another control
    edge from the push op to either `forward_index.op` or `forward_sync`.

    Args:
      value: The source tensor in forward that is to be accumulated.
      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:
      The stack that contains the accumulated history of the tensor.

    Raises:
      TypeError: For internal errors involving the value condition context.
      ValueError: If `value` is inside an XLA scope and a valid max size
        for the stack can't be found.
    """
    # curr_ctxt is the context that tf.gradients was called in.
    with self._forward_index.graph.as_default():
      curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
      with ops.control_dependencies(None):
        if curr_ctxt:
          curr_ctxt.Enter()
        with ops.colocate_with(value):
          # We only need to pass maximum_iterations to the stack if
          # we're inside an XLA context.
          if not util.IsInXLAContext(value.op):
            max_size = constant_op.constant(-1, dtypes.int32)
          else:
            max_size = _GetMaxSizeFromNestedMaximumIterations(
                value, self.forward_context)
          acc = gen_data_flow_ops.stack_v2(
              max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
        if curr_ctxt:
          curr_ctxt.Exit()

        # Make acc available in the forward context.
        enter_acc = self.forward_context.AddValue(acc)

        # Add the stack_push op in the context of value.op.
        swap_enabled = self.forward_context.swap_memory
        value_ctxt = util.GetOutputContext(value.op)
        if value_ctxt == self.forward_context:
          # value is directly in the forward context (not nested in a cond).
          self.forward_context.Enter()
          push = gen_data_flow_ops.stack_push_v2(
              enter_acc, value, swap_memory=swap_enabled)
          self.forward_context.Exit()
          # Protect stack push and order it before forward_index.
          self.forward_index.op._add_control_input(push.op)
        else:
          # value is in a cond context within the forward context.
          if not isinstance(value_ctxt, control_flow_ops.CondContext):
            raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
          if dead_branch:
            # The special case for creating a zero tensor for a dead
            # branch of a switch. See _ControlFlowState.ZerosLikeV1WhileLoop().
            value_ctxt.outer_context.Enter()
            push = gen_data_flow_ops.stack_push_v2(
                enter_acc, value, swap_memory=swap_enabled)
            value_ctxt.outer_context.Exit()
            push.op._set_control_flow_context(value_ctxt)
          else:
            value_ctxt.Enter()
            push = gen_data_flow_ops.stack_push_v2(
                enter_acc, value, swap_memory=swap_enabled)
            value_ctxt.Exit()
          # Protect stack push and order it before forward_sync.
          self.forward_sync._add_control_input(push.op)
        # Order stack push after the successor of forward_index
        add_op = self.forward_index.op.inputs[0].op
        push.op._add_control_input(add_op)
        return acc

  def AddBackpropAccumulatedValue(self, history_value, value,
                                  dead_branch=False):
    """Add the getter for an accumulated value in the grad context.

    This is added to the backprop loop. Called in the grad context to
    get the value of an accumulated value. The stack pop op must be guarded
    by the pred of the controlling cond.

    Args:
      history_value: The history (a stack) of a value.
      value: The value that is pushed onto the stack.
      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:
      The current value (the top of the stack).
    """
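    # Rough pseudocode for what is emitted below (mirrors the push loop built
    # by AddForwardAccumulator); the switch guard is only added when a cond
    # controls the value:
    #
    #   while (_grad_pivot) {
    #     history_value = switch(history_value, pred)[branch];  # optional
    #     value = stack_pop(history_value);
    #   }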
    history_ctxt = history_value.op._get_control_flow_context()
    # Find the cond context that controls history_value if any.
    cond_ctxt = None
    value_ctxt = value.op._get_control_flow_context()
    while value_ctxt and value_ctxt != history_ctxt:
      if isinstance(value_ctxt, control_flow_ops.CondContext):
        cond_ctxt = value_ctxt
        break
      value_ctxt = value_ctxt.outer_context
    with ops.control_dependencies(None):
      self.grad_context.Enter()
      if cond_ctxt:
        # Guard stack pop with a switch if it is controlled by a cond.
        grad_state = self
        pred = None
        while pred is None and grad_state:
          pred = grad_state.history_map.get(cond_ctxt.pred.name)
          grad_state = grad_state.outer_grad_state
        if pred is None:
          pred = cond_ctxt.pred
        branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
        history_value = control_flow_ops._SwitchRefOrTensor(
            history_value, pred)[branch]
      pop = gen_data_flow_ops.stack_pop_v2(history_value,
                                           value.dtype.base_dtype)
      pop.set_shape(value.get_shape())
      self.grad_context.Exit()
    parallel_iterations = self.grad_context.parallel_iterations
    if parallel_iterations > 1:
      # All pops are ordered after pivot_for_body and before grad_sync.
      self.grad_sync._add_control_input(pop.op)
    return pop

  def GetRealValue(self, value):
    """Get the real value of `value`.

    If backprop "uses" a value produced by forward inference, an accumulator
    is added in the forward loop to accumulate its values.  We use the
    accumulated value. This method must be called in the grad loop context.
    `value` must be in forward and needed for backprop.

    Args:
      value: A tensor to be captured.

    Returns:
      The same tensor obtained from the saved history.
    """
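    # Resolution order, roughly (an illustrative summary of the code below,
    # not additional behavior):
    #   1. `value` enters the loop through a loop-invariant Enter node: reuse
    #      the outer value directly.
    #   2. `value` is a graph constant: clone the constant into the grad loop.
    #   3. otherwise: accumulate `value` in the forward loop with
    #      AddForwardAccumulator and read it back with
    #      AddBackpropAccumulatedValue.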
    assert value.op.type not in ["Variable", "VariableV2"]
    real_value = self._history_map.get(value.name)
    if real_value is None:
      cur_value = value
      cur_grad_state = self
      while True:
        enter_op = util.GetLoopConstantEnter(cur_value)
        if enter_op:
          # Special case: cur_value comes from a constant Enter node.
          cur_value = enter_op.inputs[0]
          cur_grad_state = cur_grad_state.outer_grad_state
          if cur_grad_state is None:
            # We are now outside all nested loops for this gradient(),
            # so `value` is a loop invariant and there is no need to
            # save the history of value. Just make cur_value enter the
            # right control flow context.
            real_value = self._grad_context.AddValue(cur_value)
            break
        elif constant_op.is_constant(cur_value):
          # If the value to be forwarded is a constant, clone the constant in
          # the gradient loop rather than using a stack.
          # TODO(phawkins): consider hoisting the constant out of the loop
          # instead.
          real_value = constant_op.constant(
              tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
          break
        else:
          # Record the history of this value in forward_ctxt.
          self._grad_context.Exit()
          history_value = cur_grad_state.AddForwardAccumulator(cur_value)
          self._grad_context.Enter()
          break

      if real_value is None:
        # Add the stack pop op in the grad context.
        real_value = cur_grad_state.AddBackpropAccumulatedValue(
            history_value, cur_value)
        if cur_grad_state != self:
          real_value = self._grad_context.AddValue(real_value)
      self._history_map[value.name] = real_value
    return real_value


class _ControlFlowState:
  """Maintain the mapping from the loops to their grad states."""

  def __init__(self):
    self._map = {}  # maps forward loop context to _GradLoopState

  def GetGradState(self, op, before):
    """Return the grad state for this op if it's in a forward loop context."""
    if before and util.IsLoopExit(op):
      forward_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
      forward_ctxt = forward_ctxt.outer_context
      if forward_ctxt:
        forward_ctxt = forward_ctxt.GetWhileContext()
    else:
      forward_ctxt = util.GetWhileContext(op)
    if forward_ctxt:
      return self._map.get(forward_ctxt)
    return None

  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
    """Process all the "unused" loop exits.

    The "unused" exits of the loops are added to `unused_exits`. An exit is
    unused if its pending_count is 0. If there is an exit with real gradient,
    all these deferred exits will enter the backprop loop with zero gradient.
    Otherwise, they will enter the backprop loop with None. As an example,
    people often write:

    ```python
    v1, _ = tf.while_loop(p, b, [x1, x2])
    result = gradients(v1, x1)
    ```

    The exit node for x2 is not included by the betweenness analysis. But we
    need to backprop x2 if x2 is involved in computing v1.

    Args:
      pending_count: The number of backprop inputs for every op.
      to_ops_set: The set of ops for ys in gradients(ys, xs)

    Returns:
      The set of unused loop exits that we know at this point we need
      to backprop.
    """
    loop_exits = []
    for grad_state in self._map.values():
      for y in grad_state.forward_loop_exits:
        if pending_count[y.op] == 0:
          grad_state.pending_exits_count -= 1
          if y.op not in to_ops_set:
            grad_state.unused_exits.append(y)
          if grad_state.pending_exits_count == 0:
            loop_exits.extend(grad_state.unused_exits)
      # Need to include Enters in backprop for higher-order gradients.
      for y in grad_state.forward_context.loop_enters:
        if pending_count[y.op] == 0:
          pending_count[y.op] = 1
    return loop_exits

  def EnterGradWhileContext(self, op, before):
    """Enter the WhileContext for gradient computation."""
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Enter()

  def ExitGradWhileContext(self, op, before):
    """Exit the WhileContext for gradient computation."""
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Exit()

  def AddWhileContext(self, op, between_op_list, between_ops):
    """Add the grad state for the while loop that op belongs to.

    Note that op is an Exit, and this method must be called in
    the control flow context where gradients() is called.

    Note that this method modifies `between_op_list` and `between_ops`.
    """
    forward_ctxt = util.GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # This is a new while loop so create a grad state for it.
      outer_forward_ctxt = forward_ctxt.outer_context
      if outer_forward_ctxt:
        outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
      outer_grad_state = None
      if outer_forward_ctxt:
        outer_grad_state = self._map.get(outer_forward_ctxt)
      grad_state = _GradLoopState(forward_ctxt, outer_grad_state)
      self._map[forward_ctxt] = grad_state

      # We need to include all exits of a loop for backprop.
      for loop_exit in grad_state.forward_loop_exits:
        if loop_exit.op not in between_ops:
          between_ops.add(loop_exit.op)
          between_op_list.append(loop_exit.op)

  def ZerosLikeForExit(self, val):
    """Create zeros_like gradient for a loop exit.

    If the result of a loop variable is not used but is involved in
    computing the result of some needed loop variable, we create a
    zero-valued tensor that is fed as gradient for the Exit node of that
    loop variable. Note that val.op is an Exit, and this method must be
    called in the control flow context where gradients() is called.

    Args:
      val: The output tensor of an Exit op.

    Returns:
      A zero tensor of the same shape as val.
602    """
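    # Illustrative scenario (made-up names):
    #
    #   v1, v2 = tf.while_loop(c, b, [x1, x2])
    #   g = tf.gradients(v1, x1)
    #
    # v2's Exit receives no real gradient, but if x2/v2 feeds the computation
    # of v1 inside the loop, its Exit still needs an incoming (zero) gradient,
    # which this method builds in the right context.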
    val_shape = val.get_shape()
    forward_ctxt = val.op._get_control_flow_context()
    outer_forward_ctxt = forward_ctxt.outer_context
    if outer_forward_ctxt:
      outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
    outer_grad_state = None
    if outer_forward_ctxt:
      outer_grad_state = self._map.get(outer_forward_ctxt)
    if outer_grad_state:
      # This is a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape in the right context.
        outer_grad_state.grad_context.Enter()
        result = array_ops.zeros(val_shape.dims, val.dtype)
        outer_grad_state.grad_context.Exit()
      else:
        # Only the shape of value is needed for backprop.
        forward_ctxt.outer_context.Enter()
        shape = array_ops.shape_internal(val, optimize=False)
        forward_ctxt.outer_context.Exit()
        # Save the shape to a stack.
        history_shape = outer_grad_state.AddForwardAccumulator(shape)
        # Get the shape back from the stack.
        outer_grad_ctxt = outer_grad_state.grad_context
        outer_grad_ctxt.Enter()
        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
            history_shape, shape)
        result = array_ops.zeros(real_shape, val.dtype)
        outer_grad_ctxt.Exit()
    else:
      # This is not a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape.
        result = array_ops.zeros(val_shape.dims, val.dtype)
      else:
        result = array_ops.zeros_like(val, optimize=False)
    return result

  def ZerosLikeV1WhileLoop(self, op, index):
    """Create zeros_like for the specified output of an op.

    If op is in a while loop that is part of gradients(), this method
    must be called in its grad loop context.

    Args:
      op: A tensorflow operation.
      index: the index for a specific output of the op.

    Returns:
      A zero tensor of the same shape as op.outputs[index].
655    """
656    if util.IsLoopSwitch(op):
657      return None
658    if op.graph.building_function:
659      # The optimization here is tricky to apply to functions
660      return array_ops.zeros_like(op.outputs[index])
661    dead_branch = util.IsSwitch(op)
662    forward_ctxt = util.GetWhileContext(op)
663    grad_state = self._map.get(forward_ctxt)
664    if grad_state is None:
665      # op is not in a while loop that is part of gradients().
666      return ZerosLike(op, index)
667    op_ctxt = op._get_control_flow_context()
668    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
669    shape = val.get_shape()
670    if shape.is_fully_defined():
671      # If the shape is known statically, just create a zero tensor with
672      # the right shape in the grad loop context.
673      if val.dtype == dtypes.resource:
674        result = array_ops.zeros(
675            resource_variable_ops.variable_shape(val),
676            dtype=default_gradient.get_zeros_dtype(val))
677      else:
678        result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
679      if dead_branch:
680        # op is a cond switch. Guard the zero tensor with a switch.
681        pred = grad_state.history_map.get(op_ctxt.pred.name)
682        branch = op_ctxt.branch
683        result = control_flow_ops._SwitchRefOrTensor(result, pred)[1 - branch]
684    else:
685      # Unknown shape so keep a history of the shape at runtime.
686      if dead_branch:
687        # Need to add a special switch to guard the value.
688        pred = op_ctxt.pred
689        branch = op_ctxt.branch
690        op_ctxt.outer_context.Enter()
691        val = control_flow_ops._SwitchRefOrTensor(op.inputs[0],
692                                                  pred)[1 - branch]
693        zeros_shape = array_ops.shape_internal(val, optimize=False)
694        op_ctxt.outer_context.Exit()
695        val.op._set_control_flow_context(op_ctxt)
696        zeros_shape.op._set_control_flow_context(op_ctxt)
697      else:
698        op_ctxt.Enter()
699        zeros_shape = array_ops.shape_internal(val, optimize=False)
700        op_ctxt.Exit()
701
702      # Add forward accumulator for shape.
703      grad_state.grad_context.Exit()
704      history_zeros_shape = grad_state.AddForwardAccumulator(
705          zeros_shape, dead_branch=dead_branch)
706      grad_state.grad_context.Enter()
707
708      # Create a zero tensor with the right shape.
709      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
710                                                     zeros_shape, dead_branch)
711      result = array_ops.zeros(shape, val.dtype)
712    return result
713
714  def PostProcessing(self):
715    """Perform postprocessing at the end of gradients().
716
717    We have created the gradient graph at this point. So this function
718    can be used to perform any postprocessing on the gradient graph.
719    We currently perform the following postprocessing:
720      1. Patch the gradient graph if the output of a loop variable
721         doesn't depend on its input.
722    """
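    # Example of the pattern patched here (illustrative only): a loop variable
    # whose next value ignores its current value, e.g.
    #
    #   _, y = tf.while_loop(c, lambda i, y: (i + 1, tf.ones([2])), [i0, y0])
    #
    # The grad Merge for y then sees the same tensor on both inputs, and the
    # code below feeds zeros as its gradient for all iterations > 0.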
    for _, grad_state in self._map.items():
      for _, b_merge in grad_state.switch_map.items():
        if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
          # The value of this loop variable at iteration i+1 doesn't
          # depend on its value at iteration i. So use zeros as the
          # gradients for all iterations > 0.
          dtype = b_merge.op.inputs[0].dtype
          shape = b_merge.op.inputs[0].get_shape()
          # pylint: disable=protected-access
          if shape.is_fully_defined():
            grad_state.grad_context.Enter()
            # Create a zeros and use it for iterations > 0.
            grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
            next_grad_val = control_flow_ops._NextIteration(grad_val)
            grad_state.grad_context.Exit()
          else:
            # Create a zeros in the outer grad context.
            outer_grad_ctxt = grad_state.grad_context.outer_context
            if outer_grad_ctxt:
              outer_grad_ctxt.Enter()
            enter_grad_op = b_merge.op.inputs[0].op
            enter_grad = enter_grad_op.inputs[0]
            grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
            grad_val = array_ops.zeros(grad_shape)
            if outer_grad_ctxt:
              outer_grad_ctxt.Exit()
            # Use the zeros for iterations > 0.
            grad_state.grad_context.Enter()
            next_grad_val = control_flow_ops._NextIteration(grad_val)
            grad_state.grad_context.Exit()
          b_merge.op._update_input(1, next_grad_val)
          # pylint: enable=protected-access


def MaybeCreateControlFlowState(between_op_list, between_ops,
                                colocate_gradients_with_ops):
  """Create the state for all the while loops involved in one gradients().

  We create a _ControlFlowState when there are while loops involved in
  gradients(). In gradients(), control flow logic is only invoked when
  the _ControlFlowState is not None.

  Note that this method modifies `between_op_list` and `between_ops`.
  """
  loop_state = None
  for op in between_op_list:
    if util.IsLoopExit(op):
      if loop_state is None:
        loop_state = _ControlFlowState()
      if colocate_gradients_with_ops:
        with ops.colocate_with(op):
          loop_state.AddWhileContext(op, between_op_list, between_ops)
      else:
        loop_state.AddWhileContext(op, between_op_list, between_ops)
  return loop_state

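# Rough usage sketch (illustrative; not the exact gradients() call sites): the
# gradients machinery builds its between-ops sets first, then asks this module
# for loop state and only takes the control-flow paths when state was created.
#
#   loop_state = MaybeCreateControlFlowState(
#       between_op_list, between_ops, colocate_gradients_with_ops)
#   if loop_state:
#     unused = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)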

def _ZerosLikeV1(op, index):
  """Branch of ZerosLike for TF1."""
  val = op.outputs[index]
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if op_ctxt:
    # We are in a cond context. Use a switch to create zeros only when needed.
    pred = op_ctxt.pred
    branch = op_ctxt.branch
    switch_val = control_flow_ops.switch(op.inputs[0], pred)[1 - branch]
    # An op is created along the branch taken, as control dependencies are on
    # the whole op and not on the tensor output.
    pivot = array_ops.identity(switch_val)
    if val.dtype == dtypes.resource:
      with ops.control_dependencies([pivot]):
        return array_ops.zeros(
            gen_resource_variable_ops.variable_shape(switch_val),
            dtype=default_gradient.get_zeros_dtype(val))
    zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
    # Ensure ops created within array_ops.zeros are dominated by switch in
    # cond context.
    with ops.control_dependencies([pivot]):
      return array_ops.zeros(zeros_shape, dtype=val.dtype)
  else:
    return array_ops.zeros_like(val, optimize=False)


def _ZerosLikeV2(op, index):
  """Branch of ZerosLike for TF2."""
  val = op.outputs[index]
  if val.dtype == dtypes.resource:
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(val),
        dtype=default_gradient.get_zeros_dtype(val))
  if (isinstance(val.op.graph, control_flow_v2_func_graphs.WhileBodyFuncGraph)
      and val.dtype != dtypes.variant):
    # In while_v2 we do not want to add a `ZerosLike` op because that will
    # trigger accumulation of `val`. Normally `ZerosLike` is preferred because
    # it helps avoid creating extra nodes (possibly Consts) for the shape.
    # For variants, we must use ZerosLike.
    if val.shape.is_fully_defined():
      return constant_op.constant(0, shape=val.shape.dims, dtype=val.dtype)
    else:
      # Note: Even though we add `Shape` in the default graph, while_v2 is smart
      # enough to place it in the forward graph i.e. `val.graph`.
      zeros_shape = array_ops.shape_internal(val, optimize=False)
      return array_ops.zeros(zeros_shape, val.dtype)
  else:
    return array_ops.zeros_like(val, optimize=False)


def ZerosLike(op, index):
  """Create zeros_like for the specified output of an op."""
  if not util.IsSwitch(op):
    return _ZerosLikeV2(op, index)
  else:
    return _ZerosLikeV1(op, index)