1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Utilty functions for control flow. 17 18This file is necessary to avoid cyclic dependencies between ops.py and 19control_flow_ops.py. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import os 27import traceback 28 29from tensorflow.python.platform import tf_logging as logging 30 31ENABLE_CONTROL_FLOW_V2 = (os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or 32 os.getenv("TF_ENABLE_COND_V2", "0") != "0" or 33 os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or 34 os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0") 35 36 37def EnableControlFlowV2(graph): 38 """Returns whether control flow v2 should be used in `graph`.""" 39 # Enable new control flow in FuncGraphs (but not legacy _FuncGraphs). 40 # TODO(skyewm): do something better than hasattr without messing up imports. 41 return ENABLE_CONTROL_FLOW_V2 or ( 42 graph.building_function and not hasattr(graph, "_captured")) 43 44 45def IsInXLAContext(op): 46 try: 47 xla_compile = op.get_attr("_XlaCompile") 48 if xla_compile: return True 49 except ValueError: 50 pass 51 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 52 return GetContainingXLAContext(ctxt) is not None 53 54 55def InXlaContext(graph): 56 ctxt = graph._get_control_flow_context() # pylint: disable=protected-access 57 return GetContainingXLAContext(ctxt) is not None 58 59 60def GraphOrParentsInXlaContext(graph): 61 while True: 62 if InXlaContext(graph): return True 63 try: 64 graph = graph.outer_graph 65 except AttributeError: 66 return False 67 68 69def IsInWhileLoop(op): 70 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 71 return GetContainingWhileContext(ctxt) is not None 72 73 74def IsInCond(op): 75 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 76 return GetContainingCondContext(ctxt) is not None 77 78 79def IsSwitch(op): 80 """Return true if `op` is a Switch.""" 81 return op.type == "Switch" or op.type == "RefSwitch" 82 83 84def IsMerge(op): 85 """Return true if `op` is a Merge.""" 86 return op.type == "Merge" or op.type == "RefMerge" 87 88 89def IsLoopEnter(op): 90 """Returns true if `op` is an Enter.""" 91 return op.type == "Enter" or op.type == "RefEnter" 92 93 94def IsLoopExit(op): 95 """Return true if `op` is an Exit.""" 96 return op.type == "Exit" or op.type == "RefExit" 97 98 99def IsCondSwitch(op): 100 """Return true if `op` is the Switch for a conditional.""" 101 if not IsSwitch(op): 102 return False 103 if not op.outputs: 104 return False 105 # Switch nodes are not part of the cond control flow context that they 106 # represent, so consider the consumers of its outputs to determine if it is 107 # cond switch or not. A switch is a cond switch iff all its consumers are in 108 # cond contexts. 109 is_cond_switch = True 110 for o in op.outputs: 111 for c in o.consumers(): 112 ctxt = c._get_control_flow_context() # pylint: disable=protected-access 113 if IsLoopEnter(c): 114 ctxt = ctxt.outer_context 115 is_cond_switch = is_cond_switch and (ctxt is not None and 116 ctxt.IsCondContext()) 117 return is_cond_switch 118 119 120def IsCondMerge(op): 121 """Return true if `op` is the Merge for a conditional.""" 122 if not IsMerge(op): 123 return False 124 if not op.inputs: 125 return False 126 # Merge nodes are not part of the cond control flow context that they 127 # represent, so consider the inputs to the merge of to determine if it is 128 # cond merge or not: A merge is a cond merge iff all its inputs are in 129 # cond contexts. 130 is_cond_merge = True 131 for i in op.inputs: 132 ctxt = GetOutputContext(i.op) 133 is_cond_merge = is_cond_merge and ctxt is not None and ctxt.IsCondContext() 134 return is_cond_merge 135 136 137def IsLoopSwitch(op): 138 """Return true if `op` is the Switch for a while loop.""" 139 if IsSwitch(op): 140 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 141 return ctxt is not None and ctxt.IsWhileContext() and not IsCondSwitch(op) 142 return False 143 144 145def IsLoopMerge(op): 146 """Return true if `op` is the Merge for a while loop.""" 147 if IsMerge(op): 148 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 149 return ctxt is not None and ctxt.IsWhileContext() and not IsCondMerge(op) 150 return False 151 152 153def IsLoopConstantEnter(op): 154 """Return true iff op is a loop invariant.""" 155 return IsLoopEnter(op) and op.get_attr("is_constant") 156 157 158def GetLoopConstantEnter(value): 159 """Return the enter op if we can infer `value` to be a loop invariant.""" 160 id_ops = {"Switch", "RefSwitch", "Identity", "RefIdentity"} 161 op = value.op 162 while op.type in id_ops: 163 op = op.inputs[0].op 164 return op if IsLoopConstantEnter(op) else None 165 166 167def GetOutputContext(op): 168 """Return the control flow context for the output of an op.""" 169 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 170 # Exit nodes usually have a control flow context, except in the case where the 171 # exit node was imported via import_graph_def (in which case no nodes have 172 # control flow contexts). 173 if ctxt is not None and IsLoopExit(op): 174 ctxt = ctxt.outer_context 175 return ctxt 176 177 178def GetContainingWhileContext(ctxt, stop_ctxt=None): 179 """Returns the first ancestor WhileContext of `ctxt`. 180 181 Returns `ctxt` if `ctxt` is a WhileContext, or None if `ctxt` is not in a 182 while loop. 183 184 Args: 185 ctxt: ControlFlowContext 186 stop_ctxt: ControlFlowContext, optional. If provided, the search will end 187 if it sees stop_ctxt. 188 189 Returns: 190 `ctxt` if `ctxt` is a WhileContext, the most nested WhileContext containing 191 `ctxt`, or None if `ctxt` is not in a while loop. If `stop_ctxt` is not 192 `None`, this returns `ctxt` if it matches `stop_ctxt` in its traversal. 193 """ 194 while ctxt: 195 if ctxt.IsWhileContext() or ctxt == stop_ctxt: return ctxt 196 ctxt = ctxt.outer_context 197 return None 198 199 200def GetContainingXLAContext(ctxt): 201 """Returns the first ancestor XLAContext of `ctxt`. 202 203 Returns `ctxt` if `ctxt` is a XLAContext, or None if `ctxt` is not in a 204 while loop. 205 206 Args: 207 ctxt: ControlFlowContext 208 209 Returns: 210 `ctxt` if `ctxt` is a XLAContext, the most nested XLAContext containing 211 `ctxt`, or None if `ctxt` is not in a while loop. 212 """ 213 while ctxt: 214 if ctxt.IsXLAContext(): return ctxt 215 ctxt = ctxt.outer_context 216 return None 217 218 219def GetContainingCondContext(ctxt): 220 """Returns the first ancestor CondContext of `ctxt`. 221 222 Returns `ctxt` if `ctxt` is a CondContext, or None if `ctxt` is not in a cond. 223 224 Args: 225 ctxt: ControlFlowContext 226 227 Returns: 228 `ctxt` if `ctxt` is a CondContext, the most nested CondContext containing 229 `ctxt`, or None if `ctxt` is not in a cond. 230 """ 231 while ctxt: 232 if ctxt.IsCondContext(): return ctxt 233 ctxt = ctxt.outer_context 234 return None 235 236 237def IsContainingContext(ctxt, maybe_containing_ctxt): 238 """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.""" 239 while ctxt is not maybe_containing_ctxt: 240 if ctxt is None: return False 241 ctxt = ctxt.outer_context 242 return True 243 244 245def OpInContext(op, ctxt): 246 return IsContainingContext(op._get_control_flow_context(), ctxt) # pylint: disable=protected-access 247 248 249def TensorInContext(tensor, ctxt): 250 return OpInContext(tensor.op, ctxt) 251 252 253def CheckInputFromValidContext(op, input_op): 254 """Returns whether `input_op` can be used from `op`s context. 255 256 Conceptually, only inputs from op's while context or any ancestor while 257 context (including outside of any context) are valid. In practice, there are 258 many other edge cases as well. 259 260 Args: 261 op: Operation 262 input_op: Operation 263 264 Raises: 265 ValueError: if input_op is from an invalid context. 266 """ 267 op_ctxt = op._get_control_flow_context() # pylint: disable=protected-access 268 input_ctxt = GetOutputContext(input_op) 269 valid = False 270 271 if not input_ctxt: 272 # input_op isn't in a control flow context. 273 valid = True 274 elif op_ctxt is input_ctxt: 275 # input_op is in the same context as op. 276 valid = True 277 else: 278 while_ctxt = GetContainingWhileContext(op_ctxt) 279 input_while_ctxt = GetContainingWhileContext(input_ctxt) 280 281 if while_ctxt is None: 282 if input_while_ctxt is None: 283 # Neither op nor input_op is in a while loop, but one or both are in 284 # conds. We allow this, although execution will fail if the branch 285 # corresponding to input_op's cond context isn't taken. 286 valid = True 287 # Invalid if op isn't in a while loop and input_op is. Unless... 288 if IsLoopEnter(op): 289 # WhileContext._BuildLoop clears context for Enter nodes. 290 valid = True 291 if IsSwitch(op): 292 # CondContext.AddValue clears context for Switch nodes. 293 valid = True 294 elif IsContainingContext(while_ctxt, input_while_ctxt): 295 # input_op is in a while loop which contains op's while loop (or not in a 296 # while loop at all). 297 valid = True 298 elif (while_ctxt.grad_state and 299 IsContainingContext(while_ctxt.grad_state.forward_context, 300 input_while_ctxt)): 301 # op is in a gradient context and input_op is in the associated forward 302 # pass context or an ancestor thereof. This case is need to build while 303 # loop gradients. 304 # NOTE(skyewm): we theoretically also need this case for custom gradient 305 # functions that close over tensors from ancestor contexts, but I haven't 306 # verified this. 307 valid = True 308 elif (while_ctxt.grad_state and 309 while_ctxt.grad_state.forward_context is 310 input_while_ctxt._outer_context): # pylint: disable=protected-access 311 # op is in a gradient context and input_op is in a child of the associated 312 # forward pass context. This case is needed for the gradients of while 313 # loops with conds. 314 valid = True 315 elif (input_while_ctxt.grad_state and 316 input_while_ctxt.grad_state.forward_context is while_ctxt): 317 # input_op is in the gradient context of op's context. This case is needed 318 # when the gradient of a while loop gradient is requested (this will 319 # eventually fail unless there is a stop_gradient() or similar). 320 valid = True 321 elif (input_while_ctxt.grad_state and 322 input_ctxt.grad_state.forward_context.grad_state and 323 input_ctxt.grad_state.forward_context.grad_state.forward_context is 324 while_ctxt): 325 # input_op is in the grad grad context of op's context. This case is 326 # needed when the gradient of a while loop gradient is requested (this 327 # will eventually fail unless there is a stop_gradient() or similar). 328 valid = True 329 330 if not valid: 331 if while_ctxt: 332 error_msg = ( 333 "Cannot use '%s' as input to '%s' because they are in different while" 334 " loops." % (input_op.name, op.name)) 335 else: 336 error_msg = ( 337 "Cannot use '%s' as input to '%s' because '%s' is in a while loop." 338 % (input_op.name, op.name, input_op.name)) 339 340 # Log the error message plus the relevant stack traces. The stacks may be 341 # useful for debugging this error, but we don't want to raise an 342 # unreadable exception. 343 log_msg = error_msg 344 log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt) 345 log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt) 346 log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % ( 347 op.name, "".join(traceback.format_list(op.traceback)), 348 input_op.name, "".join(traceback.format_list(input_op.traceback))) 349 logging.info(log_msg) 350 raise ValueError(error_msg + " See info log for more details.") 351