# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions for control flow.

This file is necessary to avoid cyclic dependencies between ops.py and
control_flow_ops.py.
"""

import os
import traceback

from tensorflow.python import tf2
from tensorflow.python.platform import tf_logging as logging

# Control flow v2 is on by default when TF2 behavior is enabled, unless
# TF_ENABLE_CONTROL_FLOW_V2 is set to "0". Independently of that, setting any
# of the TF_ENABLE_* variables below to a value other than "0" turns it on.
ENABLE_CONTROL_FLOW_V2 = ((tf2.enabled() and
                           os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
                          os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
                          os.getenv("TF_ENABLE_COND_V2", "0") != "0" or
                          os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or
                          os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0")


# TODO(b/137793122): Remove this.
def enable_control_flow_v2():  # pylint: disable=invalid-name
  """Use control flow v2.

  Do not use this symbol. This will be removed.
  """
  global ENABLE_CONTROL_FLOW_V2
  ENABLE_CONTROL_FLOW_V2 = True


def EnableControlFlowV2(graph):
  """Returns whether control flow v2 should be used in `graph`.

  Args:
    graph: The graph being built.

  Returns:
    Truthy if control flow v2 is enabled globally (`ENABLE_CONTROL_FLOW_V2`),
    or if `graph` is building a function and has no `_captured` attribute.
  """
  # Enable new control flow in FuncGraphs (but not legacy _FuncGraphs).
  # TODO(skyewm): do something better than hasattr without messing up imports.
  return ENABLE_CONTROL_FLOW_V2 or (
      graph.building_function and not hasattr(graph, "_captured"))


def IsInXLAContext(op):
  """Returns whether `op` should be treated as being in an XLA context.

  This holds if `op` has a truthy `_XlaCompile` attr, or if its control flow
  context is, or is nested inside, an XLAContext.

  Args:
    op: An `Operation`.

  Returns:
    A bool.
  """
  try:
    xla_compile = op.get_attr("_XlaCompile")
  except ValueError:
    # get_attr raised ValueError (presumably the attr is absent); fall back to
    # inspecting the control flow context.
    pass
  else:
    if xla_compile:
      return True
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return GetContainingXLAContext(ctxt) is not None


def InXlaContext(graph):
  """Returns whether `graph`'s control flow context is in an XLAContext.

  Args:
    graph: A graph exposing `_get_control_flow_context()`.

  Returns:
    True if the context returned by `graph._get_control_flow_context()` is, or
    is nested inside, an XLAContext.
  """
  ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  return GetContainingXLAContext(ctxt) is not None


def GraphOrParentsInXlaContext(graph):
  """Returns whether `graph` or any of its outer graphs is in an XLAContext.

  Follows `outer_graph` links starting at `graph`. The walk ends, returning
  False, at the first graph that has no `outer_graph` attribute.

  Args:
    graph: The graph to start from.

  Returns:
    A bool.
  """
  while True:
    if InXlaContext(graph):
      return True
    try:
      graph = graph.outer_graph
    except AttributeError:
      return False


def IsInWhileLoop(op):
  """Returns whether `op`'s control flow context is in a WhileContext."""
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return GetContainingWhileContext(ctxt) is not None


def IsInCond(op):
  """Returns whether `op`'s control flow context is in a CondContext."""
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return GetContainingCondContext(ctxt) is not None


def IsSwitch(op):
  """Return true if `op` is a Switch."""
  return op.type in ("Switch", "RefSwitch")


def IsMerge(op):
  """Return true if `op` is a Merge."""
  return op.type in ("Merge", "RefMerge")


def IsLoopEnter(op):
  """Returns true if `op` is an Enter."""
  return op.type in ("Enter", "RefEnter")


def IsLoopExit(op):
  """Return true if `op` is an Exit."""
  return op.type in ("Exit", "RefExit")


def _IsCondContext(ctxt):
  """Returns True iff `ctxt` is a non-None CondContext."""
  return ctxt is not None and ctxt.IsCondContext()


def IsCondSwitch(op):
  """Return true if `op` is the Switch for a conditional.

  Args:
    op: An `Operation`.

  Returns:
    True iff `op` is a Switch, has at least one output, and every consumer of
    its outputs is in a cond context. A Switch whose outputs have no consumers
    counts as a cond switch.
  """
  if not IsSwitch(op):
    return False
  if not op.outputs:
    return False
  # Switch nodes are not part of the cond control flow context that they
  # represent, so consider the consumers of its outputs to determine if it is
  # cond switch or not. A switch is a cond switch iff all its consumers are in
  # cond contexts.
  for o in op.outputs:
    for c in o.consumers():
      ctxt = c._get_control_flow_context()  # pylint: disable=protected-access
      # For Enter consumers, judge by the enclosing (outer) context instead of
      # the Enter's own. ctxt can be None (e.g. graphs imported via
      # import_graph_def have no contexts), so guard before dereferencing.
      if ctxt is not None and IsLoopEnter(c):
        ctxt = ctxt.outer_context
      if not _IsCondContext(ctxt):
        return False
  return True


def IsCondMerge(op):
  """Return true if `op` is the Merge for a conditional.

  Args:
    op: An `Operation`.

  Returns:
    True iff `op` is a Merge, has at least one input, and the output context of
    every input's op is a cond context.
  """
  if not IsMerge(op):
    return False
  if not op.inputs:
    return False
  # Merge nodes are not part of the cond control flow context that they
  # represent, so consider the inputs to the merge to determine if it is a
  # cond merge or not: A merge is a cond merge iff all its inputs are in
  # cond contexts.
  return all(_IsCondContext(GetOutputContext(i.op)) for i in op.inputs)


def IsLoopSwitch(op):
  """Return true if `op` is the Switch for a while loop."""
  if IsSwitch(op):
    ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
    return ctxt is not None and ctxt.IsWhileContext() and not IsCondSwitch(op)
  return False


def IsLoopMerge(op):
  """Return true if `op` is the Merge for a while loop."""
  if IsMerge(op):
    ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
    return ctxt is not None and ctxt.IsWhileContext() and not IsCondMerge(op)
  return False


def IsLoopConstantEnter(op):
  """Return true iff op is a loop invariant.

  That is, `op` is an Enter whose `is_constant` attr is set.
  """
  return IsLoopEnter(op) and op.get_attr("is_constant")


def GetLoopConstantEnter(value):
  """Return the enter op if we can infer `value` to be a loop invariant.

  Walks from `value.op` back through first inputs while the op is a Switch or
  Identity (or their Ref variants), then checks the op reached.

  Args:
    value: A tensor-like object exposing `.op`.

  Returns:
    The constant Enter op reached, or None.
  """
  id_ops = {"Switch", "RefSwitch", "Identity", "RefIdentity"}
  op = value.op
  while op.type in id_ops:
    op = op.inputs[0].op
  return op if IsLoopConstantEnter(op) else None


def GetOutputContext(op):
  """Return the control flow context for the output of an op.

  This is `op`'s own context, except for Exit ops, whose outputs belong to the
  context enclosing the loop being exited.
  """
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # Exit nodes usually have a control flow context, except in the case where the
  # exit node was imported via import_graph_def (in which case no nodes have
  # control flow contexts).
  if ctxt is not None and IsLoopExit(op):
    ctxt = ctxt.outer_context
  return ctxt


def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Returns the first ancestor WhileContext of `ctxt`.

  Returns `ctxt` if `ctxt` is a WhileContext, or None if `ctxt` is not in a
  while loop.

  Args:
    ctxt: ControlFlowContext
    stop_ctxt: ControlFlowContext, optional. If provided, the search will end
      if it sees stop_ctxt.

  Returns:
    `ctxt` if `ctxt` is a WhileContext, the most nested WhileContext containing
    `ctxt`, or None if `ctxt` is not in a while loop.  If `stop_ctxt` is not
    `None`, this returns `ctxt` if it matches `stop_ctxt` in its traversal.
  """
  while ctxt:
    if ctxt.IsWhileContext() or ctxt == stop_ctxt:
      return ctxt
    ctxt = ctxt.outer_context
  return None


def GetContainingXLAContext(ctxt):
  """Returns the first ancestor XLAContext of `ctxt`.

  Returns `ctxt` if `ctxt` is a XLAContext, or None if `ctxt` is not in an
  XLA context.

  Args:
    ctxt: ControlFlowContext

  Returns:
    `ctxt` if `ctxt` is a XLAContext, the most nested XLAContext containing
    `ctxt`, or None if `ctxt` is not in an XLA context.
  """
  while ctxt:
    if ctxt.IsXLAContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None


def GetContainingCondContext(ctxt):
  """Returns the first ancestor CondContext of `ctxt`.

  Returns `ctxt` if `ctxt` is a CondContext, or None if `ctxt` is not in a cond.

  Args:
    ctxt: ControlFlowContext

  Returns:
    `ctxt` if `ctxt` is a CondContext, the most nested CondContext containing
    `ctxt`, or None if `ctxt` is not in a cond.
  """
  while ctxt:
    if ctxt.IsCondContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None


def IsContainingContext(ctxt, maybe_containing_ctxt):
  """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.

  Identity is used for comparison. Because the walk up `outer_context` links
  ends at None, a `maybe_containing_ctxt` of None contains every context.
  """
  while ctxt is not maybe_containing_ctxt:
    if ctxt is None:
      return False
    ctxt = ctxt.outer_context
  return True


def OpInContext(op, ctxt):
  """Returns whether `op`'s control flow context is, or is inside, `ctxt`."""
  return IsContainingContext(op._get_control_flow_context(), ctxt)  # pylint: disable=protected-access


def TensorInContext(tensor, ctxt):
  """Returns whether the op producing `tensor` is, or is inside, `ctxt`."""
  return OpInContext(tensor.op, ctxt)


def _IsValidInputWhileContext(op, while_ctxt, input_while_ctxt):
  """Returns whether an input from `input_while_ctxt` may be consumed by `op`.

  Helper for CheckInputFromValidContext. It is only called after the trivially
  valid cases (input outside any context, or in `op`'s own context) have been
  ruled out.

  Args:
    op: The consuming `Operation`.
    while_ctxt: The innermost WhileContext containing `op`'s context, or None.
    input_while_ctxt: The innermost WhileContext containing the input's output
      context, or None.

  Returns:
    A bool.
  """
  if while_ctxt is None:
    if input_while_ctxt is None:
      # Neither op nor input_op is in a while loop, but one or both are in
      # conds. We allow this, although execution will fail if the branch
      # corresponding to input_op's cond context isn't taken.
      return True
    # Invalid if op isn't in a while loop and input_op is. Unless op is an
    # Enter (WhileContext._BuildLoop clears context for Enter nodes) or a
    # Switch (CondContext.AddValue clears context for Switch nodes).
    return IsLoopEnter(op) or IsSwitch(op)

  if IsContainingContext(while_ctxt, input_while_ctxt):
    # input_op is in a while loop which contains op's while loop (or not in a
    # while loop at all).
    return True
  # From here on input_while_ctxt is not None: IsContainingContext(x, None) is
  # always True, so that case returned above.

  grad_state = while_ctxt.grad_state
  if grad_state:
    forward_ctxt = grad_state.forward_context
    if IsContainingContext(forward_ctxt, input_while_ctxt):
      # op is in a gradient context and input_op is in the associated forward
      # pass context or an ancestor thereof. This case is needed to build
      # while loop gradients.
      # NOTE(skyewm): we theoretically also need this case for custom gradient
      # functions that close over tensors from ancestor contexts, but I haven't
      # verified this.
      return True
    if forward_ctxt is input_while_ctxt._outer_context:  # pylint: disable=protected-access
      # op is in a gradient context and input_op is in a child of the
      # associated forward pass context. This case is needed for the gradients
      # of while loops with conds.
      return True

  # Always read grad_state from input_while_ctxt (the context whose grad_state
  # is checked), not from the input's raw context: when that raw context is not
  # itself a WhileContext, its grad_state may differ (e.g. be None).
  input_grad_state = input_while_ctxt.grad_state
  if input_grad_state:
    input_forward_ctxt = input_grad_state.forward_context
    if input_forward_ctxt is while_ctxt:
      # input_op is in the gradient context of op's context. This case is
      # needed when the gradient of a while loop gradient is requested (this
      # will eventually fail unless there is a stop_gradient() or similar).
      return True
    outer_grad_state = input_forward_ctxt.grad_state
    if outer_grad_state and outer_grad_state.forward_context is while_ctxt:
      # input_op is in the grad grad context of op's context. This case is
      # needed when the gradient of a while loop gradient is requested (this
      # will eventually fail unless there is a stop_gradient() or similar).
      return True
  return False


def _RaiseInvalidInputError(op, input_op, while_ctxt, input_while_ctxt):
  """Logs diagnostics for an invalid `input_op` -> `op` edge, then raises.

  Raises:
    ValueError: always.
  """
  if while_ctxt:
    error_msg = (
        f"Cannot use '{input_op.name}' as input to '{op.name}' because they "
        "are in different while loops.")
  else:
    error_msg = (
        f"Cannot use '{input_op.name}' as input to '{op.name}' because "
        f"'{input_op.name}' is in a while loop.")

  # Log the error message plus the relevant stack traces. The stacks may be
  # useful for debugging this error, but we don't want to raise an
  # unreadable exception.
  log_msg = error_msg
  log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt)
  log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt)
  log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % (
      op.name, "".join(traceback.format_list(op.traceback)),
      input_op.name, "".join(traceback.format_list(input_op.traceback)))
  logging.info(log_msg)
  raise ValueError(error_msg + " See info log for more details.")


def CheckInputFromValidContext(op, input_op):
  """Returns whether `input_op` can be used from `op`s context.

  Conceptually, only inputs from op's while context or any ancestor while
  context (including outside of any context) are valid. In practice, there are
  many other edge cases as well.

  Args:
    op: Operation
    input_op: Operation

  Returns:
    None. Validity is signalled by not raising.

  Raises:
    ValueError: if input_op is from an invalid context.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  input_ctxt = GetOutputContext(input_op)

  # input_op isn't in a control flow context, or is in the same context as op.
  if not input_ctxt or op_ctxt is input_ctxt:
    return

  while_ctxt = GetContainingWhileContext(op_ctxt)
  input_while_ctxt = GetContainingWhileContext(input_ctxt)
  if not _IsValidInputWhileContext(op, while_ctxt, input_while_ctxt):
    _RaiseInvalidInputError(op, input_op, while_ctxt, input_while_ctxt)


def GetWhileContext(op):
  """Get the WhileContext to which this op belongs.

  Returns:
    `ctxt.GetWhileContext()` for `op`'s control flow context, or that context
    unchanged if it is falsy (e.g. None).
  """
  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if ctxt:
    ctxt = ctxt.GetWhileContext()
  return ctxt