1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Utility functions for control flow. 17 18This file is necessary to avoid cyclic dependencies between ops.py and 19control_flow_ops.py. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import os 27import traceback 28 29from tensorflow.python import tf2 30from tensorflow.python.platform import tf_logging as logging 31 32ENABLE_CONTROL_FLOW_V2 = ((tf2.enabled() and 33 os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or 34 os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or 35 os.getenv("TF_ENABLE_COND_V2", "0") != "0" or 36 os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or 37 os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0") 38 39 40# TODO(b/137793122): Remove this. 41def enable_control_flow_v2(): # pylint: disable=invalid-name 42 """Use control flow v2. 43 44 Do not use this symbol. This will be removed. 45 """ 46 global ENABLE_CONTROL_FLOW_V2 47 ENABLE_CONTROL_FLOW_V2 = True 48 49 50def EnableControlFlowV2(graph): 51 """Returns whether control flow v2 should be used in `graph`.""" 52 # Enable new control flow in FuncGraphs (but not legacy _FuncGraphs). 53 # TODO(skyewm): do something better than hasattr without messing up imports. 54 return ENABLE_CONTROL_FLOW_V2 or ( 55 graph.building_function and not hasattr(graph, "_captured")) 56 57 58def IsInXLAContext(op): 59 try: 60 xla_compile = op.get_attr("_XlaCompile") 61 if xla_compile: return True 62 except ValueError: 63 pass 64 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 65 return GetContainingXLAContext(ctxt) is not None 66 67 68def InXlaContext(graph): 69 ctxt = graph._get_control_flow_context() # pylint: disable=protected-access 70 return GetContainingXLAContext(ctxt) is not None 71 72 73def GraphOrParentsInXlaContext(graph): 74 while True: 75 if InXlaContext(graph): return True 76 try: 77 graph = graph.outer_graph 78 except AttributeError: 79 return False 80 81 82def IsInWhileLoop(op): 83 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 84 return GetContainingWhileContext(ctxt) is not None 85 86 87def IsInCond(op): 88 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 89 return GetContainingCondContext(ctxt) is not None 90 91 92def IsSwitch(op): 93 """Return true if `op` is a Switch.""" 94 return op.type == "Switch" or op.type == "RefSwitch" 95 96 97def IsMerge(op): 98 """Return true if `op` is a Merge.""" 99 return op.type == "Merge" or op.type == "RefMerge" 100 101 102def IsLoopEnter(op): 103 """Returns true if `op` is an Enter.""" 104 return op.type == "Enter" or op.type == "RefEnter" 105 106 107def IsLoopExit(op): 108 """Return true if `op` is an Exit.""" 109 return op.type == "Exit" or op.type == "RefExit" 110 111 112def IsCondSwitch(op): 113 """Return true if `op` is the Switch for a conditional.""" 114 if not IsSwitch(op): 115 return False 116 if not op.outputs: 117 return False 118 # Switch nodes are not part of the cond control flow context that they 119 # represent, so consider the consumers of its outputs to determine if it is 120 # cond switch or not. A switch is a cond switch iff all its consumers are in 121 # cond contexts. 122 is_cond_switch = True 123 for o in op.outputs: 124 for c in o.consumers(): 125 ctxt = c._get_control_flow_context() # pylint: disable=protected-access 126 if IsLoopEnter(c): 127 ctxt = ctxt.outer_context 128 is_cond_switch = is_cond_switch and (ctxt is not None and 129 ctxt.IsCondContext()) 130 return is_cond_switch 131 132 133def IsCondMerge(op): 134 """Return true if `op` is the Merge for a conditional.""" 135 if not IsMerge(op): 136 return False 137 if not op.inputs: 138 return False 139 # Merge nodes are not part of the cond control flow context that they 140 # represent, so consider the inputs to the merge of to determine if it is 141 # cond merge or not: A merge is a cond merge iff all its inputs are in 142 # cond contexts. 143 is_cond_merge = True 144 for i in op.inputs: 145 ctxt = GetOutputContext(i.op) 146 is_cond_merge = is_cond_merge and ctxt is not None and ctxt.IsCondContext() 147 return is_cond_merge 148 149 150def IsLoopSwitch(op): 151 """Return true if `op` is the Switch for a while loop.""" 152 if IsSwitch(op): 153 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 154 return ctxt is not None and ctxt.IsWhileContext() and not IsCondSwitch(op) 155 return False 156 157 158def IsLoopMerge(op): 159 """Return true if `op` is the Merge for a while loop.""" 160 if IsMerge(op): 161 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 162 return ctxt is not None and ctxt.IsWhileContext() and not IsCondMerge(op) 163 return False 164 165 166def IsLoopConstantEnter(op): 167 """Return true iff op is a loop invariant.""" 168 return IsLoopEnter(op) and op.get_attr("is_constant") 169 170 171def GetLoopConstantEnter(value): 172 """Return the enter op if we can infer `value` to be a loop invariant.""" 173 id_ops = {"Switch", "RefSwitch", "Identity", "RefIdentity"} 174 op = value.op 175 while op.type in id_ops: 176 op = op.inputs[0].op 177 return op if IsLoopConstantEnter(op) else None 178 179 180def GetOutputContext(op): 181 """Return the control flow context for the output of an op.""" 182 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 183 # Exit nodes usually have a control flow context, except in the case where the 184 # exit node was imported via import_graph_def (in which case no nodes have 185 # control flow contexts). 186 if ctxt is not None and IsLoopExit(op): 187 ctxt = ctxt.outer_context 188 return ctxt 189 190 191def GetContainingWhileContext(ctxt, stop_ctxt=None): 192 """Returns the first ancestor WhileContext of `ctxt`. 193 194 Returns `ctxt` if `ctxt` is a WhileContext, or None if `ctxt` is not in a 195 while loop. 196 197 Args: 198 ctxt: ControlFlowContext 199 stop_ctxt: ControlFlowContext, optional. If provided, the search will end 200 if it sees stop_ctxt. 201 202 Returns: 203 `ctxt` if `ctxt` is a WhileContext, the most nested WhileContext containing 204 `ctxt`, or None if `ctxt` is not in a while loop. If `stop_ctxt` is not 205 `None`, this returns `ctxt` if it matches `stop_ctxt` in its traversal. 206 """ 207 while ctxt: 208 if ctxt.IsWhileContext() or ctxt == stop_ctxt: return ctxt 209 ctxt = ctxt.outer_context 210 return None 211 212 213def GetContainingXLAContext(ctxt): 214 """Returns the first ancestor XLAContext of `ctxt`. 215 216 Returns `ctxt` if `ctxt` is a XLAContext, or None if `ctxt` is not in a 217 while loop. 218 219 Args: 220 ctxt: ControlFlowContext 221 222 Returns: 223 `ctxt` if `ctxt` is a XLAContext, the most nested XLAContext containing 224 `ctxt`, or None if `ctxt` is not in a while loop. 225 """ 226 while ctxt: 227 if ctxt.IsXLAContext(): return ctxt 228 ctxt = ctxt.outer_context 229 return None 230 231 232def GetContainingCondContext(ctxt): 233 """Returns the first ancestor CondContext of `ctxt`. 234 235 Returns `ctxt` if `ctxt` is a CondContext, or None if `ctxt` is not in a cond. 236 237 Args: 238 ctxt: ControlFlowContext 239 240 Returns: 241 `ctxt` if `ctxt` is a CondContext, the most nested CondContext containing 242 `ctxt`, or None if `ctxt` is not in a cond. 243 """ 244 while ctxt: 245 if ctxt.IsCondContext(): return ctxt 246 ctxt = ctxt.outer_context 247 return None 248 249 250def IsContainingContext(ctxt, maybe_containing_ctxt): 251 """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.""" 252 while ctxt is not maybe_containing_ctxt: 253 if ctxt is None: return False 254 ctxt = ctxt.outer_context 255 return True 256 257 258def OpInContext(op, ctxt): 259 return IsContainingContext(op._get_control_flow_context(), ctxt) # pylint: disable=protected-access 260 261 262def TensorInContext(tensor, ctxt): 263 return OpInContext(tensor.op, ctxt) 264 265 266def CheckInputFromValidContext(op, input_op): 267 """Returns whether `input_op` can be used from `op`s context. 268 269 Conceptually, only inputs from op's while context or any ancestor while 270 context (including outside of any context) are valid. In practice, there are 271 many other edge cases as well. 272 273 Args: 274 op: Operation 275 input_op: Operation 276 277 Raises: 278 ValueError: if input_op is from an invalid context. 279 """ 280 op_ctxt = op._get_control_flow_context() # pylint: disable=protected-access 281 input_ctxt = GetOutputContext(input_op) 282 valid = False 283 284 if not input_ctxt: 285 # input_op isn't in a control flow context. 286 valid = True 287 elif op_ctxt is input_ctxt: 288 # input_op is in the same context as op. 289 valid = True 290 else: 291 while_ctxt = GetContainingWhileContext(op_ctxt) 292 input_while_ctxt = GetContainingWhileContext(input_ctxt) 293 294 if while_ctxt is None: 295 if input_while_ctxt is None: 296 # Neither op nor input_op is in a while loop, but one or both are in 297 # conds. We allow this, although execution will fail if the branch 298 # corresponding to input_op's cond context isn't taken. 299 valid = True 300 # Invalid if op isn't in a while loop and input_op is. Unless... 301 if IsLoopEnter(op): 302 # WhileContext._BuildLoop clears context for Enter nodes. 303 valid = True 304 if IsSwitch(op): 305 # CondContext.AddValue clears context for Switch nodes. 306 valid = True 307 elif IsContainingContext(while_ctxt, input_while_ctxt): 308 # input_op is in a while loop which contains op's while loop (or not in a 309 # while loop at all). 310 valid = True 311 elif (while_ctxt.grad_state and 312 IsContainingContext(while_ctxt.grad_state.forward_context, 313 input_while_ctxt)): 314 # op is in a gradient context and input_op is in the associated forward 315 # pass context or an ancestor thereof. This case is need to build while 316 # loop gradients. 317 # NOTE(skyewm): we theoretically also need this case for custom gradient 318 # functions that close over tensors from ancestor contexts, but I haven't 319 # verified this. 320 valid = True 321 elif (while_ctxt.grad_state and 322 while_ctxt.grad_state.forward_context is 323 input_while_ctxt._outer_context): # pylint: disable=protected-access 324 # op is in a gradient context and input_op is in a child of the associated 325 # forward pass context. This case is needed for the gradients of while 326 # loops with conds. 327 valid = True 328 elif (input_while_ctxt.grad_state and 329 input_while_ctxt.grad_state.forward_context is while_ctxt): 330 # input_op is in the gradient context of op's context. This case is needed 331 # when the gradient of a while loop gradient is requested (this will 332 # eventually fail unless there is a stop_gradient() or similar). 333 valid = True 334 elif (input_while_ctxt.grad_state and 335 input_ctxt.grad_state.forward_context.grad_state and 336 input_ctxt.grad_state.forward_context.grad_state.forward_context is 337 while_ctxt): 338 # input_op is in the grad grad context of op's context. This case is 339 # needed when the gradient of a while loop gradient is requested (this 340 # will eventually fail unless there is a stop_gradient() or similar). 341 valid = True 342 343 if not valid: 344 if while_ctxt: 345 error_msg = ( 346 "Cannot use '%s' as input to '%s' because they are in different while" 347 " loops." % (input_op.name, op.name)) 348 else: 349 error_msg = ( 350 "Cannot use '%s' as input to '%s' because '%s' is in a while loop." 351 % (input_op.name, op.name, input_op.name)) 352 353 # Log the error message plus the relevant stack traces. The stacks may be 354 # useful for debugging this error, but we don't want to raise an 355 # unreadable exception. 356 log_msg = error_msg 357 log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt) 358 log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt) 359 log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % ( 360 op.name, "".join(traceback.format_list(op.traceback)), 361 input_op.name, "".join(traceback.format_list(input_op.traceback))) 362 logging.info(log_msg) 363 raise ValueError(error_msg + " See info log for more details.") 364 365 366def GetWhileContext(op): 367 """Get the WhileContext to which this op belongs.""" 368 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 369 if ctxt: 370 ctxt = ctxt.GetWhileContext() 371 return ctxt 372