1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Utilities for V2 control flow.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22from tensorflow.core.framework import attr_value_pb2 23from tensorflow.python.distribute import distribution_strategy_context 24from tensorflow.python.eager import context 25from tensorflow.python.eager import function 26from tensorflow.python.framework import function_def_to_graph 27from tensorflow.python.framework import ops 28from tensorflow.python.framework.func_graph import FuncGraph 29from tensorflow.python.ops import control_flow_util 30from tensorflow.python.ops import control_flow_v2_func_graphs 31from tensorflow.python.ops import gradients_util 32from tensorflow.python.util import keras_deps 33from tensorflow.python.util import tf_contextlib 34 35 36_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None 37_DISABLE_LOWER_USING_SWITCH_MERGE = False 38 39 40CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph 41WhileCondFuncGraph = control_flow_v2_func_graphs.WhileCondFuncGraph 42WhileBodyFuncGraph = control_flow_v2_func_graphs.WhileBodyFuncGraph 43 44 45def in_defun(): 46 """Returns if the current graph is, or is nested in, a defun.""" 47 if context.executing_eagerly(): return False 48 49 graph = ops.get_default_graph() 50 while (isinstance(graph, CondBranchFuncGraph) or 51 isinstance(graph, WhileBodyFuncGraph) or 52 isinstance(graph, WhileCondFuncGraph)): 53 graph = graph.outer_graph 54 return isinstance(graph, FuncGraph) 55 56 57def in_while_loop_defun(graph): 58 """Returns if the graph is a while loop FuncGraph.""" 59 if context.executing_eagerly(): return False 60 return (isinstance(graph, WhileCondFuncGraph) or 61 isinstance(graph, WhileBodyFuncGraph)) 62 63 64def create_new_tf_function(func_graph): 65 """Converts func_graph to a TF_Function and adds it to the current graph. 66 67 Args: 68 func_graph: FuncGraph 69 70 Returns: 71 The name of the new TF_Function. 72 """ 73 func = function._EagerDefinedFunction( # pylint: disable=protected-access 74 func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {}) 75 func.add_to_graph(func_graph.outer_graph) 76 return func_graph.name 77 78 79def unique_fn_name(scope, name): 80 """Returns a unique name to use for a control flow function. 81 82 Args: 83 scope: A name scope string. 84 name: An identifier for this function (e.g. "true", "body"). 85 86 Returns: 87 A string, the name to use for the function. 88 """ 89 return ("%s%s_%s" % (scope, name, ops.uid())).replace("/", "_") 90 91 92def unique_grad_fn_name(forward_name): 93 return "%s_grad_%s" % (forward_name, ops.uid()) 94 95 96def maybe_set_lowering_attr(op, lower_using_switch_merge=None): 97 """Sets the flag to enable lowering on `op` if necessary. 98 99 Lowering allows cond_v2 and while_v2 to avoid some of the limitations of 100 Functions, allowing users to specify devices & colocation inside of cond_v2 101 and while_v2 input functions, and enabling non-strict evaluation & partial 102 pruning. This brings v2 control flow closer to feature parity with v1 control 103 flow. 104 105 However, we do not lower in the following cases: 106 - When the `If` or `While` ops are in the XLA context. Because it is easier 107 for XLA to apply its own optimizations when dealing with un-lowered 108 control flow operators than with low-level control flow primitives. 109 - When the eager execution context specifies the executor of functions to 110 be the single threaded executor (see context.function_executor_type()). 111 Because the single threaded executor does not support v1 control flow ops. 112 - When 'lower_using_switch_merge' is explicitly set to False. 113 114 Args: 115 op: An `If` or `While` Operation. 116 lower_using_switch_merge: Explicit value to lower or not (optional). 117 """ 118 if lower_using_switch_merge is not None: 119 # pylint: disable=protected-access 120 op._set_attr("_lower_using_switch_merge", 121 attr_value_pb2.AttrValue(b=lower_using_switch_merge)) 122 # pylint: enable=protected-access 123 elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and 124 not control_flow_util.GraphOrParentsInXlaContext(op.graph) and 125 context.context().function_call_options.executor_type != 126 "SINGLE_THREADED_EXECUTOR"): 127 # pylint: disable=protected-access 128 op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True)) 129 # pylint: enable=protected-access 130 131 132def maybe_propagate_compile_time_consts_in_xla(op): 133 """Tells XLA whether to propagate compile-time consts in the loop body. 134 135 This is needed to make compile time constants available to ops, for example 136 `max_num_elements` in `EmptyTensorList`, inside the loop body. Ideally this 137 would always be turned on, but that doesn't work with legacy functionalized 138 while_loops. 139 140 Args: 141 op: A `While` Operation. 142 """ 143 if control_flow_util.GraphOrParentsInXlaContext(op.graph): 144 # pylint: disable=protected-access 145 op._set_attr("_xla_propagate_compile_time_consts", 146 attr_value_pb2.AttrValue(b=True)) 147 # pylint: enable=protected-access 148 149 150def resource_input_index(tensor_name, input_names, node_defs, functions): 151 """Returns the index of the input corresponding to `tensor_name`. 152 153 This method is used to find the corresponding index of an arbitrary resource 154 tensor in a function (the function could be a loop body). We assume that 155 resource handles are never created in functions, so that every resource 156 tensor can be traced back to a function input. 157 158 The awkward signature of this method is to make it work with both FuncGraphs 159 and FunctionDefs. This is so we can recurse on function call ops without 160 building the corresponding FuncGraph (note that even if a FuncGraph for a 161 FunctionDef already exists, the input/output/node names may have been 162 changed when the FuncGraph was serialized to the FunctionDef, which makes it 163 unusable with this algorithm). 164 165 Args: 166 tensor_name: the name of the resource tensor to be resolved to an input. 167 input_names: a list of the names of all inputs to the function. 168 node_defs: a dict mapping op name -> NodeDef for every op in the function. 169 functions: a dict mapping function name -> _EagerDefinedFunction. 170 171 Returns: 172 The index into input_names corresponding to `tensor_name`. 173 """ 174 while tensor_name not in input_names: 175 # FunctionDefs and graphs use different tensor naming conventions. 176 parts = tensor_name.split(":") 177 if len(parts) == 3: 178 op_name, _, output_idx = parts 179 elif len(parts) == 2: 180 op_name, output_idx = parts 181 else: 182 assert len(parts) == 1 183 op_name = parts[0] 184 output_idx = 0 185 tensor_name = "%s:%d" % (tensor_name, output_idx) 186 # Check again for cases where the tensor suffix (":0") is stripped out. 187 if tensor_name in input_names: 188 break 189 output_idx = int(output_idx) 190 node_def = node_defs[op_name] 191 192 def _extract_input_index(function_attribute_name): 193 func_name = node_def.attr[function_attribute_name].func.name 194 fdef = functions[func_name].definition 195 output_arg_name = fdef.signature.output_arg[output_idx].name 196 output_tensor_name = fdef.ret[output_arg_name] 197 return resource_input_index( 198 output_tensor_name, [arg.name for arg in fdef.signature.input_arg], 199 {ndef.name: ndef for ndef in fdef.node_def}, functions) 200 201 if node_def.op in ("Identity", "While"): 202 # Captured resources occur at the same index in the lists of inputs and 203 # outputs of a while or identity op. So we lookup the input of `tensor.op` 204 # at the same index as the index of `tensor` in the `tensor.op.outputs`. 205 tensor_name = node_def.input[output_idx] 206 elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"): 207 # Functions output any captured resource tensors used by their 208 # gradients. `tensor_name` is one of these outputs from a nested 209 # function call, so recursively find the corresponding input in the 210 # nested FunctionDef. 211 tensor_name = node_def.input[_extract_input_index("f")] 212 elif node_def.op in ("If", "StatelessIf"): 213 input_index = _extract_input_index("then_branch") 214 if input_index != _extract_input_index("else_branch"): 215 raise AssertionError( 216 ("Expected cond branches ({} op) to each have the same " 217 "input->output mapping of resources.").format(node_def.op)) 218 tensor_name = node_def.input[ 219 # Ignore the `cond` input; the function inputs come after. 220 input_index + 1] 221 else: 222 # We assume there are no other ops types that will "forward" resource 223 # handles like this, so all other handles must have been created by the 224 # op. (Note that cond_v2 wraps resource handle outputs in optionals, 225 # which we'll end up accumulating). 226 raise ValueError("Taking gradient of a while loop which creates " 227 "a resource in its body is not supported: %s (%s)" 228 % (op_name, node_def.op)) 229 230 return input_names.index(tensor_name) 231 232 233@tf_contextlib.contextmanager 234def clear_control_inputs(): 235 """Clears the control inputs but preserves the ControlFlowContext. 236 237 This is needed to preserve the XLAControlFlowControl when clearing 238 control inputs for the gradient accumulators in while_v2. 239 `ops.control_dependencies` does not allow that. 240 241 Yields: 242 A context manager in which the ops created will not have any control inputs 243 by default but the control flow context is the same. 244 """ 245 # pylint: disable=protected-access 246 control_flow_context = ops.get_default_graph()._get_control_flow_context() 247 with ops.control_dependencies(None): 248 ops.get_default_graph()._set_control_flow_context(control_flow_context) 249 yield 250 # pylint: enable=protected-access 251 252 253def _is_tpu_strategy(strategy): 254 return (strategy is not None and 255 strategy.__class__.__name__.startswith("TPUStrategy")) 256 257 258def _is_building_keras_layer(): 259 # TODO(srbs): Remove this function when we no long support session with Keras. 260 keras_call_context_function = keras_deps.get_call_context_function() 261 if keras_call_context_function: 262 return keras_call_context_function().layer is not None 263 else: 264 return False 265 266 267def output_all_intermediates(): 268 """Whether to output all intermediates of a functional control flow op. 269 270 The default behavior is to output intermediates only when building a Keras 271 Layer in graph mode and that too when certain other conditions are met: 272 1. We do not output intermediates if the functional control flow op 273 is being built inside a FuncGraph which is not a If/While graph. This 274 guards against outputting intermediates in eager mode since keras adds 275 tensors to a FuncGraph named "keras_graph" in that case. Also because we 276 do not output intermediates of tf.function (since this feature is only for 277 backwards compatibility) outputting intermediates of functional control 278 flow ops built inside tf.function is of no value. 279 2. We do not output intermediates when the compilation is using XLA or for a 280 TPU. 281 3. We do not output intermediates when a single threaded executor is used 282 since that does not perform inlining and pruning. 283 284 Returns: 285 A bool telling whether to output all intermediates. 286 """ 287 if _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE is not None: 288 return _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE 289 if in_defun(): 290 return False 291 if (control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()) or 292 _is_tpu_strategy(distribution_strategy_context.get_strategy())): 293 return False 294 if (context.context().function_call_options.executor_type == 295 "SINGLE_THREADED_EXECUTOR"): 296 return False 297 return _is_building_keras_layer() 298 299 300def get_func_graph(op, input_shapes, func_name): 301 """Generates and returns a FuncGraph for the given op and input_shapes.""" 302 fdef = None 303 graph = op.graph 304 # Recursively search the func in graphs. 305 while graph is not None: 306 func = graph._get_function(func_name) # pylint: disable=protected-access 307 if func is not None: 308 fdef = func.definition 309 break 310 if hasattr(graph, "outer_graph"): 311 graph = graph.outer_graph 312 else: 313 break 314 315 if fdef is None: 316 raise KeyError("%s cannot be found in the graph" % func_name) 317 318 # `op.graph` may not be the same as `ops.get_default_graph()` e.g. 319 # in the case of nested if ops or when the gradient is being computed 320 # from inside a Defun. We build the `func_graph` with `op.graph` as its 321 # `outer_graph`. This resembles how the `FuncGraph` was built in the 322 # forward pass. We need this so that we can resolve references to tensors 323 # in `func_graph` from its gradient graph in `_resolve_grad_inputs`. 324 with op.graph.as_default(): 325 func_graph = function_def_to_graph.function_def_to_graph( 326 fdef, input_shapes) 327 return func_graph 328 329 330def get_op_and_outputs(op_or_outputs): 331 if isinstance(op_or_outputs, ops.Operation): 332 return op_or_outputs, [] 333 elif not op_or_outputs: # Empty list. 334 return None, [] 335 else: 336 return op_or_outputs[0].op, op_or_outputs 337 338 339def graph_wrapped_for_higher_order_tape_gradients(graph): 340 """Check if `graph` is wrapped by `run_as_function_for_tape_gradients`.""" 341 while graph is not None: 342 if "cflow_gradient_wrapper" in getattr(graph, "name", ""): 343 return True 344 graph = getattr(graph, "outer_graph", None) 345 return False 346 347 348def run_as_function_for_tape_gradients(make_op, inputs): 349 """Fix higher-order tape gradients by wrapping `make_op` in a function. 350 351 Args: 352 make_op: A function that takes a list of inputs and returns a list of output 353 tensors. This function should set any handle data relevant to its outputs 354 before returning. 355 inputs: A list of tensors to check for tape gradients and pass to 356 `make_op`. These should include all tensors used in `make_op`. 357 358 Returns: 359 Tensors corresponding to `make_op`'s output. 360 """ 361 # GradientTapes created inside a function currently don't work well with 362 # un-wrapped control flow ops in that same function. Wrapping in an extra 363 # layer of intermediate function means we run extra logic in the function 364 # gradient code to record the correct intermediates on the tape. 365 # 366 # The function attribute inputs to control flow ops are not hashable, so we 367 # pass everything as a capture to bypass defun's caching. 368 if (gradients_util.PossibleTapeGradientTypes(inputs) 369 == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER 370 # We only need one function between the tape and the op; if we've already 371 # wrapped once, we stop wrapping to avoid infinite recursion. 372 and not (ops.get_default_graph().building_function 373 and "cflow_gradient_wrapper" in ops.get_default_graph().name)): 374 results = function.defun_with_attributes( 375 make_op, 376 autograph=False, 377 attributes=dict(func_name="cflow_gradient_wrapper"))(inputs) 378 return results 379 else: 380 return make_op(inputs) 381