1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Utilty functions for control flow. 17 18This file is necessary to avoid cyclic dependencies between ops.py and 19control_flow_ops.py. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import traceback 27 28from tensorflow.python.platform import tf_logging as logging 29 30 31def IsInXLAContext(op): 32 try: 33 xla_compile = op.get_attr("_XlaCompile") 34 if xla_compile: return True 35 except ValueError: 36 pass 37 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 38 return GetContainingXLAContext(ctxt) is not None 39 40 41def IsInWhileLoop(op): 42 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 43 return GetContainingWhileContext(ctxt) is not None 44 45 46def IsInCond(op): 47 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 48 return GetContainingCondContext(ctxt) is not None 49 50 51def IsSwitch(op): 52 """Return true if `op` is a Switch.""" 53 return op.type == "Switch" or op.type == "RefSwitch" 54 55 56def IsLoopEnter(op): 57 """Returns true if `op` is an Enter.""" 58 return op.type == "Enter" or op.type == "RefEnter" 59 60 61def IsLoopExit(op): 62 """Return true if `op` is an Exit.""" 63 return op.type == "Exit" or op.type == "RefExit" 64 65 66def IsLoopSwitch(op): 67 """Return true if `op` is the Switch for a while loop.""" 68 if IsSwitch(op): 69 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 70 return ctxt and ctxt.IsWhileContext() 71 return False 72 73 74def IsLoopConstantEnter(op): 75 """Return true iff op is a loop invariant.""" 76 return IsLoopEnter(op) and op.get_attr("is_constant") 77 78 79def GetLoopConstantEnter(value): 80 """Return the enter op if we can infer `value` to be a loop invariant.""" 81 id_ops = {"Switch", "RefSwitch", "Identity", "RefIdentity"} 82 op = value.op 83 while op.type in id_ops: 84 op = op.inputs[0].op 85 return op if IsLoopConstantEnter(op) else None 86 87 88def GetOutputContext(op): 89 """Return the control flow context for the output of an op.""" 90 ctxt = op._get_control_flow_context() # pylint: disable=protected-access 91 # Exit nodes usually have a control flow context, except in the case where the 92 # exit node was imported via import_graph_def (in which case no nodes have 93 # control flow contexts). 94 if ctxt is not None and IsLoopExit(op): 95 ctxt = ctxt.outer_context 96 return ctxt 97 98 99def GetContainingWhileContext(ctxt, stop_ctxt=None): 100 """Returns the first ancestor WhileContext of `ctxt`. 101 102 Returns `ctxt` if `ctxt` is a WhileContext, or None if `ctxt` is not in a 103 while loop. 104 105 Args: 106 ctxt: ControlFlowContext 107 stop_ctxt: ControlFlowContext, optional. If provided, the search will end 108 if it sees stop_ctxt. 109 110 Returns: 111 `ctxt` if `ctxt` is a WhileContext, the most nested WhileContext containing 112 `ctxt`, or None if `ctxt` is not in a while loop. If `stop_ctxt` is not 113 `None`, this returns `ctxt` if it matches `stop_ctxt` in its traversal. 114 """ 115 while ctxt: 116 if ctxt.IsWhileContext() or ctxt == stop_ctxt: return ctxt 117 ctxt = ctxt.outer_context 118 return None 119 120 121def GetContainingXLAContext(ctxt): 122 """Returns the first ancestor XLAContext of `ctxt`. 123 124 Returns `ctxt` if `ctxt` is a XLAContext, or None if `ctxt` is not in a 125 while loop. 126 127 Args: 128 ctxt: ControlFlowContext 129 130 Returns: 131 `ctxt` if `ctxt` is a XLAContext, the most nested XLAContext containing 132 `ctxt`, or None if `ctxt` is not in a while loop. 133 """ 134 while ctxt: 135 if ctxt.IsXLAContext(): return ctxt 136 ctxt = ctxt.outer_context 137 return None 138 139 140def GetContainingCondContext(ctxt): 141 """Returns the first ancestor CondContext of `ctxt`. 142 143 Returns `ctxt` if `ctxt` is a CondContext, or None if `ctxt` is not in a cond. 144 145 Args: 146 ctxt: ControlFlowContext 147 148 Returns: 149 `ctxt` if `ctxt` is a CondContext, the most nested CondContext containing 150 `ctxt`, or None if `ctxt` is not in a cond. 151 """ 152 while ctxt: 153 if ctxt.IsCondContext(): return ctxt 154 ctxt = ctxt.outer_context 155 return None 156 157 158def IsContainingContext(ctxt, maybe_containing_ctxt): 159 """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.""" 160 while ctxt is not maybe_containing_ctxt: 161 if ctxt is None: return False 162 ctxt = ctxt.outer_context 163 return True 164 165 166def CheckInputFromValidContext(op, input_op): 167 """Returns whether `input_op` can be used from `op`s context. 168 169 Conceptually, only inputs from op's while context or any ancestor while 170 context (including outside of any context) are valid. In practice, there are 171 many other edge cases as well. 172 173 Args: 174 op: Operation 175 input_op: Operation 176 177 Raises: 178 ValueError: if input_op is from an invalid context. 179 """ 180 op_ctxt = op._get_control_flow_context() # pylint: disable=protected-access 181 input_ctxt = GetOutputContext(input_op) 182 valid = False 183 184 if not input_ctxt: 185 # input_op isn't in a control flow context. 186 valid = True 187 elif op_ctxt is input_ctxt: 188 # input_op is in the same context as op. 189 valid = True 190 else: 191 while_ctxt = GetContainingWhileContext(op_ctxt) 192 input_while_ctxt = GetContainingWhileContext(input_ctxt) 193 194 if while_ctxt is None: 195 if input_while_ctxt is None: 196 # Neither op nor input_op is in a while loop, but one or both are in 197 # conds. We allow this, although execution will fail if the branch 198 # corresponding to input_op's cond context isn't taken. 199 valid = True 200 # Invalid if op isn't in a while loop and input_op is. Unless... 201 if IsLoopEnter(op): 202 # WhileContext._BuildLoop clears context for Enter nodes. 203 valid = True 204 if IsSwitch(op): 205 # CondContext.AddValue clears context for Switch nodes. 206 valid = True 207 elif IsContainingContext(while_ctxt, input_while_ctxt): 208 # input_op is in a while loop which contains op's while loop (or not in a 209 # while loop at all). 210 valid = True 211 elif (while_ctxt.grad_state and 212 IsContainingContext(while_ctxt.grad_state.forward_context, 213 input_while_ctxt)): 214 # op is in a gradient context and input_op is in the associated forward 215 # pass context or an ancestor thereof. This case is need to build while 216 # loop gradients. 217 # NOTE(skyewm): we theoretically also need this case for custom gradient 218 # functions that close over tensors from ancestor contexts, but I haven't 219 # verified this. 220 valid = True 221 elif (while_ctxt.grad_state and 222 while_ctxt.grad_state.forward_context is 223 input_while_ctxt._outer_context): # pylint: disable=protected-access 224 # op is in a gradient context and input_op is in a child of the associated 225 # forward pass context. This case is needed for the gradients of while 226 # loops with conds. 227 valid = True 228 elif (input_while_ctxt.grad_state and 229 input_while_ctxt.grad_state.forward_context is while_ctxt): 230 # input_op is in the gradient context of op's context. This case is needed 231 # when the gradient of a while loop gradient is requested (this will 232 # eventually fail unless there is a stop_gradient() or similar). 233 valid = True 234 elif (input_while_ctxt.grad_state and 235 input_ctxt.grad_state.forward_context.grad_state and 236 input_ctxt.grad_state.forward_context.grad_state.forward_context is 237 while_ctxt): 238 # input_op is in the grad grad context of op's context. This case is 239 # needed when the gradient of a while loop gradient is requested (this 240 # will eventually fail unless there is a stop_gradient() or similar). 241 valid = True 242 243 if not valid: 244 if while_ctxt: 245 error_msg = ( 246 "Cannot use '%s' as input to '%s' because they are in different while" 247 " loops." % (op.name, input_op.name)) 248 else: 249 error_msg = ( 250 "Cannot use '%s' as input to '%s' because '%s' is in a while loop." 251 % (input_op.name, op.name, input_op.name)) 252 253 # Log the error message plus the relevant stack traces. The stacks may be 254 # useful for debugging this error, but we don't want to raise an 255 # unreadable exception. 256 log_msg = error_msg 257 log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt) 258 log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt) 259 log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % ( 260 op.name, "".join(traceback.format_list(op.traceback)), 261 input_op.name, "".join(traceback.format_list(input_op.traceback))) 262 logging.info(log_msg) 263 raise ValueError(error_msg + " See info log for more details.") 264