• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
"""Utilities for V2 control flow."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.core.framework import attr_value_pb2
23from tensorflow.python.eager import context
24from tensorflow.python.eager import function
25from tensorflow.python.framework import ops
26from tensorflow.python.framework.func_graph import FuncGraph
27from tensorflow.python.ops import control_flow_util
28
29
class CondBranchFuncGraph(FuncGraph):
  """Marker FuncGraph subclass for the branch functions of tf.cond().

  The class adds no behavior on top of FuncGraph; having a dedicated type
  lets code recognize cond branch graphs with an isinstance check.
  """
36
37
class WhileCondFuncGraph(FuncGraph):
  """Marker FuncGraph subclass for the condition function of tf.while_loop().

  The class adds no behavior on top of FuncGraph; having a dedicated type
  lets code recognize while condition graphs with an isinstance check.
  """
44
45
class WhileBodyFuncGraph(FuncGraph):
  """Marker FuncGraph subclass for the body function of tf.while_loop().

  The class adds no behavior on top of FuncGraph; having a dedicated type
  lets code recognize while body graphs with an isinstance check.
  """
52
53
def in_defun():
  """Returns if the current graph is, or is nested in, a defun.

  The graphs built for v2 control flow (cond branches, while conditions and
  while bodies) are FuncGraph subclasses themselves, so they are skipped by
  following `outer_graph` until a graph of any other type is reached. Only
  that graph decides the result.

  Returns:
    False when executing eagerly. Otherwise True iff the first enclosing graph
    that is not a control flow function graph is a FuncGraph.
  """
  if context.executing_eagerly():
    return False

  graph = ops.get_default_graph()
  # WhileCondFuncGraph has to be skipped together with the other two marker
  # types: it also derives from FuncGraph, so stopping on it would report a
  # defun even when the while loop is not nested inside one.
  while isinstance(graph, (CondBranchFuncGraph, WhileBodyFuncGraph,
                           WhileCondFuncGraph)):
    graph = graph.outer_graph
  return isinstance(graph, FuncGraph)
63
64
def create_new_tf_function(func_graph):
  """Wraps `func_graph` in a TF_Function and registers it in its outer graph.

  Args:
    func_graph: The FuncGraph to convert. Its name, inputs and outputs are
      used to define the function, and its `outer_graph` receives it.

  Returns:
    The name of the new TF_Function, i.e. `func_graph.name`.
  """
  no_attrs = {}
  # pylint: disable=protected-access
  eager_fn = function._EagerDefinedFunction(
      func_graph.name, func_graph, func_graph.inputs, func_graph.outputs,
      no_attrs)
  # pylint: enable=protected-access
  eager_fn.add_to_graph(func_graph.outer_graph)
  return func_graph.name
78
79
def unique_fn_name(scope, name):
  """Builds the name to use for a control flow function.

  The result is `scope`, `name` and a fresh `ops.uid()` value joined as
  "<scope><name>_<uid>", with every "/" replaced by "_".

  Args:
    scope: A name scope string.
    name: An identifier for this function (e.g. "true", "body").

  Returns:
    A string, the name to use for the function.
  """
  raw_name = "%s%s_%s" % (scope, name, ops.uid())
  return raw_name.replace("/", "_")
91
92
def unique_grad_fn_name(forward_name):
  """Builds the name to use for the gradient function of `forward_name`.

  Args:
    forward_name: The name of the forward function, used as the prefix.

  Returns:
    A string of the form "<forward_name>_grad_<uid>", where the uid comes from
    `ops.uid()`.
  """
  uid = ops.uid()
  return "%s_grad_%s" % (forward_name, uid)
95
96
def maybe_set_lowering_attr(op):
  """Marks `op` for lowering to switch/merge unless lowering must be avoided.

  Lowering lets cond_v2 and while_v2 sidestep some limitations of Functions:
  users can specify devices & colocation inside the cond_v2 and while_v2 input
  functions, and non-strict evaluation & partial pruning become possible. This
  brings v2 control flow closer to feature parity with v1 control flow.

  The attribute is left unset in two cases:
    - `op` lives in an XLA context. It is easier for XLA to apply its own
      optimizations to un-lowered control flow operators than to low-level
      control flow primitives.
    - The eager context's function call options select the single threaded
      executor, which does not support v1 control flow ops.

  Args:
    op: An `If` or `While` Operation.
  """
  if control_flow_util.GraphOrParentsInXlaContext(op.graph):
    return

  executor_type = context.context().function_call_options.executor_type
  if executor_type == "SINGLE_THREADED_EXECUTOR":
    return

  lower = attr_value_pb2.AttrValue(b=True)
  op._set_attr("_lower_using_switch_merge", lower)  # pylint: disable=protected-access
123
124
def maybe_propagate_compile_time_consts_in_xla(op):
  """Asks XLA to propagate compile-time consts into the loop body of `op`.

  Ops inside the loop body may need compile time constants, for example
  `max_num_elements` in `EmptyTensorList`. Ideally the attribute would always
  be set, but that doesn't work with legacy functionalized while_loops, so it
  is only set when `op`'s graph (or one of its parents) is in an XLA context.

  Args:
    op: A `While` Operation.
  """
  if not control_flow_util.GraphOrParentsInXlaContext(op.graph):
    return

  propagate = attr_value_pb2.AttrValue(b=True)
  op._set_attr("_xla_propagate_compile_time_consts", propagate)  # pylint: disable=protected-access
141