# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for V2 control flow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import control_flow_util


class CondBranchFuncGraph(FuncGraph):
  """FuncGraph for branches of tf.cond().

  This is used to distinguish cond branches from other functions.
  """


class WhileCondFuncGraph(FuncGraph):
  """FuncGraph for the condition of tf.while_loop().

  This is used to distinguish while conditions from other functions.
  """


class WhileBodyFuncGraph(FuncGraph):
  """FuncGraph for the body of tf.while_loop().

  This is used to distinguish while bodies from other functions.
  """


def in_defun():
  """Returns whether the current graph is, or is nested in, a defun.

  The control-flow FuncGraph subclasses defined in this module (cond branches,
  while conditions and while bodies) do not count as defuns themselves. They
  are skipped by following `outer_graph` until a graph of another type is
  reached.

  Returns:
    False when executing eagerly. Otherwise, True iff the first enclosing graph
    that is not a `CondBranchFuncGraph`, `WhileCondFuncGraph` or
    `WhileBodyFuncGraph` is a `FuncGraph`.
  """
  if context.executing_eagerly():
    return False

  graph = ops.get_default_graph()
  # All three control-flow graph types are FuncGraphs. Each one must be
  # skipped here; otherwise building inside it from a plain (non-defun) graph
  # would be misreported as being inside a defun. WhileCondFuncGraph in
  # particular must not be left out.
  while isinstance(graph, (CondBranchFuncGraph, WhileCondFuncGraph,
                           WhileBodyFuncGraph)):
    graph = graph.outer_graph
  return isinstance(graph, FuncGraph)


def create_new_tf_function(func_graph):
  """Converts func_graph to a TF_Function and adds it to its outer graph.

  Args:
    func_graph: FuncGraph. The resulting function is registered in
      `func_graph.outer_graph`, not necessarily the current default graph.

  Returns:
    The name of the new TF_Function, which is `func_graph.name`.
  """
  # The trailing `{}` is the attrs argument: no extra function attributes.
  func = function._EagerDefinedFunction(  # pylint: disable=protected-access
      func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
  func.add_to_graph(func_graph.outer_graph)
  return func_graph.name


def unique_fn_name(scope, name):
  """Returns a unique name to use for a control flow function.

  Args:
    scope: A name scope string.
    name: An identifier for this function (e.g. "true", "body").

  Returns:
    A string, the name to use for the function. It has the form
    "<scope><name>_<uid>", with every "/" replaced by "_".
  """
  # "/" (the name-scope separator) is replaced so the result is a flat name.
  return ("%s%s_%s" % (scope, name, ops.uid())).replace("/", "_")


def unique_grad_fn_name(forward_name):
  """Returns a unique name for the gradient function of a forward function.

  Args:
    forward_name: The name of the forward function.

  Returns:
    A string of the form "<forward_name>_grad_<uid>". Unlike
    `unique_fn_name`, "/" characters are not replaced.
  """
  return "%s_grad_%s" % (forward_name, ops.uid())


def maybe_set_lowering_attr(op):
  """Sets the flag to enable lowering on `op` if necessary.

  Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
  Functions, allowing users to specify devices & colocation inside of cond_v2
  and while_v2 input functions, and enabling non-strict evaluation & partial
  pruning. This brings v2 control flow closer to feature parity with v1 control
  flow.

  However, we do not lower in the following cases:
    - When the `If` or `While` ops are in the XLA context. This is because it is
      easier for XLA to apply its own optimizations to un-lowered control flow
      operators than to low-level control flow primitives.
    - When the eager context specifies the single threaded executor as the
      function executor (see
      `context.context().function_call_options.executor_type`). This is because
      the single threaded executor does not support v1 control flow ops.

  Args:
    op: An `If` or `While` Operation.
  """
  in_xla = control_flow_util.GraphOrParentsInXlaContext(op.graph)
  single_threaded = (
      context.context().function_call_options.executor_type ==
      "SINGLE_THREADED_EXECUTOR")
  if not in_xla and not single_threaded:
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access


def maybe_propagate_compile_time_consts_in_xla(op):
  """Tells XLA whether to propagate compile-time consts in the loop body.

  This is needed to make compile time constants available to ops, for example
  `max_num_elements` in `EmptyTensorList`, inside the loop body. Ideally this
  would always be turned on, but that doesn't work with legacy functionalized
  while_loops.

  The attribute is only set when `op.graph` (or one of its parents) is in an
  XLA context; otherwise `op` is left untouched.

  Args:
    op: A `While` Operation.
  """
  if control_flow_util.GraphOrParentsInXlaContext(op.graph):
    # pylint: disable=protected-access
    op._set_attr("_xla_propagate_compile_time_consts",
                 attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access