• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utility functions for control flow.
17
18This file is necessary to avoid cyclic dependencies between ops.py and
19control_flow_ops.py.
20"""
21
22import os
23import traceback
24
25from tensorflow.python import tf2
26from tensorflow.python.platform import tf_logging as logging
27
# Module-wide switch for control flow v2, computed once at import time. It is
# true when TF2 behavior is enabled and TF_ENABLE_CONTROL_FLOW_V2 is not
# explicitly "0", or when any of the opt-in environment variables listed below
# is set to a value other than "0".
ENABLE_CONTROL_FLOW_V2 = (
    (tf2.enabled() and os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
    any(os.getenv(opt_in_var, "0") != "0"
        for opt_in_var in ("TF_ENABLE_CONTROL_FLOW_V2",
                           "TF_ENABLE_COND_V2",
                           "TF_ENABLE_WHILE_V2",
                           "TF_ENABLE_TENSOR_ARRAY_V2")))
34
35
# TODO(b/137793122): Remove this.
def enable_control_flow_v2():  # pylint: disable=invalid-name
  """Force the module-level `ENABLE_CONTROL_FLOW_V2` flag to `True`.

  This symbol is slated for removal; do not add new uses of it.
  """
  # Rebind the module global rather than creating a function-local name.
  global ENABLE_CONTROL_FLOW_V2
  ENABLE_CONTROL_FLOW_V2 = True
44
45
def EnableControlFlowV2(graph):
  """Returns whether control flow v2 should be used in `graph`.

  Args:
    graph: object exposing a `building_function` attribute.

  Returns:
    The module-level `ENABLE_CONTROL_FLOW_V2` flag when it is truthy. Otherwise
    a truthy value only if `graph.building_function` is truthy and `graph` has
    no `_captured` attribute.
  """
  globally_enabled = ENABLE_CONTROL_FLOW_V2
  if globally_enabled:
    return globally_enabled
  building_function = graph.building_function
  if not building_function:
    return building_function
  # New control flow is used in FuncGraphs but not in legacy _FuncGraphs; the
  # `_captured` attribute is what tells the two apart here.
  # TODO(skyewm): do something better than hasattr without messing up imports.
  return not hasattr(graph, "_captured")
52
53
def IsInXLAContext(op):
  """Returns whether `op` is flagged for XLA or nested in an XLAContext.

  Args:
    op: object exposing `get_attr` and `_get_control_flow_context`.

  Returns:
    True if the op's `_XlaCompile` attribute is truthy, or if the op's control
    flow context is (or is nested inside) an XLAContext.
  """
  try:
    if op.get_attr("_XlaCompile"):
      return True
  except ValueError:
    # The attribute lookup failed; fall back to inspecting the context chain.
    pass
  enclosing = GetContainingXLAContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing is not None
62
63
def InXlaContext(graph):
  """Returns whether `graph`'s control flow context is inside an XLAContext."""
  enclosing = GetContainingXLAContext(
      graph._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing is not None
67
68
def GraphOrParentsInXlaContext(graph):
  """Returns whether `graph` or any graph on its `outer_graph` chain is in XLA.

  The chain is followed through the `outer_graph` attribute and stops at the
  first graph that does not have that attribute.
  """
  current = graph
  while not InXlaContext(current):
    try:
      current = current.outer_graph
    except AttributeError:
      # No further enclosing graph to inspect.
      return False
  return True
76
77
def IsInWhileLoop(op):
  """Returns whether `op`'s control flow context is inside a WhileContext."""
  enclosing_while = GetContainingWhileContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_while is not None
81
82
def IsInCond(op):
  """Returns whether `op`'s control flow context is inside a CondContext."""
  enclosing_cond = GetContainingCondContext(
      op._get_control_flow_context())  # pylint: disable=protected-access
  return enclosing_cond is not None
86
87
def IsSwitch(op):
  """Returns true if `op.type` is "Switch" or "RefSwitch"."""
  return op.type in ("Switch", "RefSwitch")
91
92
def IsMerge(op):
  """Returns true if `op.type` is "Merge" or "RefMerge"."""
  return op.type in ("Merge", "RefMerge")
96
97
def IsLoopEnter(op):
  """Returns true if `op.type` is "Enter" or "RefEnter"."""
  return op.type in ("Enter", "RefEnter")
101
102
def IsLoopExit(op):
  """Returns true if `op.type` is "Exit" or "RefExit"."""
  return op.type in ("Exit", "RefExit")
106
107
def IsCondSwitch(op):
  """Returns true if `op` is the Switch for a conditional.

  Switch nodes are not part of the cond control flow context that they
  represent, so the decision is made from the consumers of the switch's
  outputs: a switch is a cond switch iff all of its consumers are in cond
  contexts.

  Args:
    op: the operation to classify.

  Returns:
    False if `op` is not a Switch, has no outputs, or has a consumer whose
    context is missing or is not a cond context; True otherwise (including
    when the outputs have no consumers at all).
  """
  if not IsSwitch(op):
    return False
  if not op.outputs:
    return False
  for output in op.outputs:
    for consumer in output.consumers():
      ctxt = consumer._get_control_flow_context()  # pylint: disable=protected-access
      # For an Enter consumer the enclosing (outer) context is the one that is
      # inspected. A consumer may have no context at all, so guard the
      # dereference instead of raising AttributeError on None.
      if ctxt is not None and IsLoopEnter(consumer):
        ctxt = ctxt.outer_context
      if ctxt is None or not ctxt.IsCondContext():
        # One consumer outside a cond context settles the answer.
        return False
  return True
127
128
def IsCondMerge(op):
  """Returns true if `op` is the Merge for a conditional.

  Merge nodes are not part of the cond control flow context that they
  represent, so the inputs of the merge are inspected instead: a merge is a
  cond merge iff the output context of every input's producing op is a cond
  context.
  """
  if not IsMerge(op):
    return False
  if not op.inputs:
    return False
  all_in_cond = True
  for tensor in op.inputs:
    producer_ctxt = GetOutputContext(tensor.op)
    if not all_in_cond:
      # Already disqualified by an earlier input; keep that result.
      continue
    if producer_ctxt is None:
      all_in_cond = False
    else:
      all_in_cond = producer_ctxt.IsCondContext()
  return all_in_cond
144
145
def IsLoopSwitch(op):
  """Returns true if `op` is the Switch for a while loop.

  That is: `op` is a Switch, its own context is a while context, and it is not
  classified as a cond switch.
  """
  if not IsSwitch(op):
    return False
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if ctxt is None:
    return False
  in_while = ctxt.IsWhileContext()
  if not in_while:
    return in_while
  return not IsCondSwitch(op)
152
153
def IsLoopMerge(op):
  """Returns true if `op` is the Merge for a while loop.

  That is: `op` is a Merge, its own context is a while context, and it is not
  classified as a cond merge.
  """
  if not IsMerge(op):
    return False
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if ctxt is None:
    return False
  in_while = ctxt.IsWhileContext()
  if not in_while:
    return in_while
  return not IsCondMerge(op)
160
161
def IsLoopConstantEnter(op):
  """Returns a truthy value iff `op` is an Enter marked as loop invariant.

  For Enter ops the result is the value of the op's "is_constant" attribute.
  """
  is_enter = IsLoopEnter(op)
  if not is_enter:
    return is_enter
  return op.get_attr("is_constant")
165
166
def GetLoopConstantEnter(value):
  """Returns the Enter op if `value` can be inferred to be a loop invariant.

  Starting from the op producing `value`, Switch/Identity ops (and their Ref
  variants) are skipped by following their first input. The op reached is
  returned if it is a loop-constant Enter; otherwise None is returned.
  """
  passthrough_types = frozenset(
      ("Switch", "RefSwitch", "Identity", "RefIdentity"))
  producer = value.op
  while producer.type in passthrough_types:
    producer = producer.inputs[0].op
  if IsLoopConstantEnter(producer):
    return producer
  return None
174
175
def GetOutputContext(op):
  """Returns the control flow context for the output of `op`.

  For an Exit op that has a context this is the `outer_context` of the op's
  own context; for every other op it is the op's own context.
  """
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # Exit nodes usually carry a context. The exception is an exit node imported
  # via import_graph_def, in which case no nodes have control flow contexts and
  # there is nothing to step out of.
  if ctxt is None or not IsLoopExit(op):
    return ctxt
  return ctxt.outer_context
185
186
def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Finds the closest WhileContext at or above `ctxt`.

  The search starts at `ctxt` itself and follows `outer_context` links.

  Args:
    ctxt: ControlFlowContext, or None.
    stop_ctxt: ControlFlowContext, optional. If provided, the search ends as
      soon as a context equal to `stop_ctxt` is encountered.

  Returns:
    `ctxt` if it is a WhileContext, else the most nested WhileContext that
    contains it, or None if `ctxt` is not in a while loop. If `stop_ctxt` is
    reached before any WhileContext, the context that matched `stop_ctxt` is
    returned instead.
  """
  current = ctxt
  while current:
    if current.IsWhileContext():
      return current
    if current == stop_ctxt:
      return current
    current = current.outer_context
  return None
207
208
def GetContainingXLAContext(ctxt):
  """Finds the closest XLAContext at or above `ctxt`.

  The search starts at `ctxt` itself and follows `outer_context` links.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` if it is an XLAContext, else the most nested XLAContext that
    contains it, or None if `ctxt` is not in an XLA context.
  """
  current = ctxt
  while current and not current.IsXLAContext():
    current = current.outer_context
  return current if current else None
226
227
def GetContainingCondContext(ctxt):
  """Finds the closest CondContext at or above `ctxt`.

  The search starts at `ctxt` itself and follows `outer_context` links.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` if it is a CondContext, else the most nested CondContext that
    contains it, or None if `ctxt` is not in a cond.
  """
  current = ctxt
  while current and not current.IsCondContext():
    current = current.outer_context
  return current if current else None
244
245
def IsContainingContext(ctxt, maybe_containing_ctxt):
  """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.

  Containment is decided by identity while following `outer_context` links
  from `ctxt`. Passing None as `maybe_containing_ctxt` therefore matches once
  the top of the chain is reached.
  """
  current = ctxt
  while True:
    if current is maybe_containing_ctxt:
      return True
    if current is None:
      return False
    current = current.outer_context
252
253
def OpInContext(op, ctxt):
  """Returns true if `ctxt` is or contains the control flow context of `op`."""
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return IsContainingContext(op_ctxt, ctxt)
256
257
def TensorInContext(tensor, ctxt):
  """Returns true if `ctxt` is or contains the context of `tensor`'s op."""
  producer = tensor.op
  return OpInContext(producer, ctxt)
260
261
def CheckInputFromValidContext(op, input_op):
  """Checks that `input_op` can be used from `op`'s context.

  Conceptually, only inputs from op's while context or any ancestor while
  context (including outside of any context) are valid. In practice, there are
  many other edge cases as well.

  On failure, the error message together with both while contexts and the
  tracebacks of both ops is written to the info log before raising.

  Args:
    op: Operation
    input_op: Operation

  Raises:
    ValueError: if input_op is from an invalid context.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  input_ctxt = GetOutputContext(input_op)
  valid = False
  # Only resolved when the quick checks below do not settle the question.
  # Initialised here so the error path never depends on which branch ran.
  while_ctxt = None
  input_while_ctxt = None

  if not input_ctxt:
    # input_op isn't in a control flow context.
    valid = True
  elif op_ctxt is input_ctxt:
    # input_op is in the same context as op.
    valid = True
  else:
    while_ctxt = GetContainingWhileContext(op_ctxt)
    input_while_ctxt = GetContainingWhileContext(input_ctxt)

    if while_ctxt is None:
      if input_while_ctxt is None:
        # Neither op nor input_op is in a while loop, but one or both are in
        # conds. We allow this, although execution will fail if the branch
        # corresponding to input_op's cond context isn't taken.
        valid = True
      # Invalid if op isn't in a while loop and input_op is. Unless...
      if IsLoopEnter(op):
        # WhileContext._BuildLoop clears context for Enter nodes.
        valid = True
      if IsSwitch(op):
        # CondContext.AddValue clears context for Switch nodes.
        valid = True
    elif IsContainingContext(while_ctxt, input_while_ctxt):
      # input_op is in a while loop which contains op's while loop (or not in a
      # while loop at all).
      valid = True
    elif (while_ctxt.grad_state and
          IsContainingContext(while_ctxt.grad_state.forward_context,
                              input_while_ctxt)):
      # op is in a gradient context and input_op is in the associated forward
      # pass context or an ancestor thereof. This case is needed to build while
      # loop gradients.
      # NOTE(skyewm): we theoretically also need this case for custom gradient
      # functions that close over tensors from ancestor contexts, but I haven't
      # verified this.
      valid = True
    elif (while_ctxt.grad_state and
          while_ctxt.grad_state.forward_context is
          input_while_ctxt._outer_context):  # pylint: disable=protected-access
      # op is in a gradient context and input_op is in a child of the associated
      # forward pass context. This case is needed for the gradients of while
      # loops with conds.
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context is while_ctxt):
      # input_op is in the gradient context of op's context. This case is needed
      # when the gradient of a while loop gradient is requested (this will
      # eventually fail unless there is a stop_gradient() or similar).
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context.grad_state and
          input_while_ctxt.grad_state.forward_context.grad_state.forward_context
          is while_ctxt):
      # input_op is in the grad grad context of op's context. This case is
      # needed when the gradient of a while loop gradient is requested (this
      # will eventually fail unless there is a stop_gradient() or similar).
      # The chain is walked from input_while_ctxt, the same object whose
      # grad_state is checked first, so the guard and the dereference agree.
      valid = True

  if not valid:
    if while_ctxt:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because they "
          "are in different while loops.")
    else:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because "
          f"'{input_op.name}' is in a while loop.")

    # Log the error message plus the relevant stack traces. The stacks may be
    # useful for debugging this error, but we don't want to raise an
    # unreadable exception.
    log_msg = error_msg
    log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt)
    log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt)
    log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % (
        op.name, "".join(traceback.format_list(op.traceback)),
        input_op.name, "".join(traceback.format_list(input_op.traceback)))
    logging.info(log_msg)
    raise ValueError(error_msg + " See info log for more details.")
360
361
def GetWhileContext(op):
  """Returns the WhileContext to which `op` belongs.

  If `op` has no (truthy) control flow context, that context value is returned
  unchanged; otherwise the result of the context's `GetWhileContext()` is
  returned.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return op_ctxt.GetWhileContext() if op_ctxt else op_ctxt
368