# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for V2 control flow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import gradients_util
from tensorflow.python.util import keras_deps
from tensorflow.python.util import tf_contextlib


_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
_DISABLE_LOWER_USING_SWITCH_MERGE = False


CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph
WhileCondFuncGraph = control_flow_v2_func_graphs.WhileCondFuncGraph
WhileBodyFuncGraph = control_flow_v2_func_graphs.WhileBodyFuncGraph


def in_defun():
  """Returns whether the current graph is, or is nested in, a defun."""
  if context.executing_eagerly(): return False

  graph = ops.get_default_graph()
  while (isinstance(graph, CondBranchFuncGraph) or
         isinstance(graph, WhileBodyFuncGraph) or
         isinstance(graph, WhileCondFuncGraph)):
    graph = graph.outer_graph
  return isinstance(graph, FuncGraph)


def in_while_loop_defun(graph):
  """Returns whether `graph` is a while loop cond or body FuncGraph."""
  if context.executing_eagerly(): return False
  return (isinstance(graph, WhileCondFuncGraph) or
          isinstance(graph, WhileBodyFuncGraph))


def create_new_tf_function(func_graph):
  """Converts `func_graph` to a TF_Function and adds it to its outer graph.

  Args:
    func_graph: The FuncGraph to convert.

  Returns:
    The name of the new TF_Function.
  """
  func = function._EagerDefinedFunction(  # pylint: disable=protected-access
      func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
  func.add_to_graph(func_graph.outer_graph)
  return func_graph.name
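
# Illustrative usage (a minimal sketch; `branch_graph` stands for a FuncGraph
# built elsewhere, e.g. with func_graph.func_graph_from_py_func):
#
#   fn_name = create_new_tf_function(branch_graph)
#   # `fn_name` can now be referenced as the `then_branch`/`else_branch`/`body`
#   # function attribute of a functional `If`/`While` op created in
#   # `branch_graph.outer_graph`.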


def unique_fn_name(scope, name):
  """Returns a unique name to use for a control flow function.

  Args:
    scope: A name scope string.
    name: An identifier for this function (e.g. "true", "body").

  Returns:
    A string, the name to use for the function.
  """
  return ("%s%s_%s" % (scope, name, ops.uid())).replace("/", "_")
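
# Illustrative results (a sketch; the trailing uid varies per call):
#
#   unique_fn_name("cond/", "true")    # -> e.g. "cond_true_12"
#   unique_fn_name("while/", "body")   # -> e.g. "while_body_13"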


def unique_grad_fn_name(forward_name):
  """Returns a unique name for the gradient function of `forward_name`."""
  return "%s_grad_%s" % (forward_name, ops.uid())


def maybe_set_lowering_attr(op, lower_using_switch_merge=None):
  """Sets the flag to enable lowering on `op` if necessary.

  Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
  Functions, allowing users to specify devices & colocation inside of cond_v2
  and while_v2 input functions, and enabling non-strict evaluation & partial
  pruning. This brings v2 control flow closer to feature parity with v1 control
  flow.

  However, we do not lower in the following cases:
    - When the `If` or `While` op is in an XLA context, because it is easier
      for XLA to apply its own optimizations to un-lowered control flow
      operators than to low-level control flow primitives.
    - When the eager execution context specifies the single threaded executor
      for functions (see context.function_executor_type()), because the single
      threaded executor does not support v1 control flow ops.
    - When `lower_using_switch_merge` is explicitly set to False.

  Args:
    op: An `If` or `While` Operation.
    lower_using_switch_merge: Optional explicit value that forces lowering on
      or off.
  """
  if lower_using_switch_merge is not None:
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge",
                 attr_value_pb2.AttrValue(b=lower_using_switch_merge))
    # pylint: enable=protected-access
  elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and
        not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
        context.context().function_call_options.executor_type !=
        "SINGLE_THREADED_EXECUTOR"):
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access
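
# Illustrative call sites (a sketch; `if_op` stands for a freshly constructed
# functional `If` operation):
#
#   maybe_set_lowering_attr(if_op)  # lower unless XLA/single-threaded
#   maybe_set_lowering_attr(if_op, lower_using_switch_merge=False)  # force off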


def maybe_propagate_compile_time_consts_in_xla(op):
  """Tells XLA whether to propagate compile-time consts in the loop body.

  This is needed to make compile-time constants available to ops inside the
  loop body, for example the `max_num_elements` argument of `EmptyTensorList`.
  Ideally this would always be turned on, but that doesn't work with legacy
  functionalized while_loops.

  Args:
    op: A `While` Operation.
  """
  if control_flow_util.GraphOrParentsInXlaContext(op.graph):
    # pylint: disable=protected-access
    op._set_attr("_xla_propagate_compile_time_consts",
                 attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access
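
# Illustrative call site (a sketch; `while_op` stands for a functional `While`
# op that may be inside an XLA context):
#
#   maybe_propagate_compile_time_consts_in_xla(while_op)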


def resource_input_index(tensor_name, input_names, node_defs, functions):
  """Returns the index of the input corresponding to `tensor_name`.

  This method is used to find the corresponding index of an arbitrary resource
  tensor in a function (the function could be a loop body). We assume that
  resource handles are never created in functions, so that every resource
  tensor can be traced back to a function input.

  The awkward signature of this method is needed to make it work with both
  FuncGraphs and FunctionDefs. This is so we can recurse on function call ops
  without building the corresponding FuncGraph (note that even if a FuncGraph
  for a FunctionDef already exists, the input/output/node names may have been
  changed when the FuncGraph was serialized to the FunctionDef, which makes it
  unusable with this algorithm).

  Args:
    tensor_name: the name of the resource tensor to be resolved to an input.
    input_names: a list of the names of all inputs to the function.
    node_defs: a dict mapping op name -> NodeDef for every op in the function.
    functions: a dict mapping function name -> _EagerDefinedFunction.

  Returns:
    The index into input_names corresponding to `tensor_name`.
  """
  while tensor_name not in input_names:
    # FunctionDefs and graphs use different tensor naming conventions.
    parts = tensor_name.split(":")
    if len(parts) == 3:
      op_name, _, output_idx = parts
    elif len(parts) == 2:
      op_name, output_idx = parts
    else:
      assert len(parts) == 1
      op_name = parts[0]
      output_idx = 0
      tensor_name = "%s:%d" % (tensor_name, output_idx)
      # Check again for cases where the tensor suffix (":0") is stripped out.
      if tensor_name in input_names:
        break
    output_idx = int(output_idx)
    node_def = node_defs[op_name]

    def _extract_input_index(function_attribute_name):
      func_name = node_def.attr[function_attribute_name].func.name
      fdef = functions[func_name].definition
      output_arg_name = fdef.signature.output_arg[output_idx].name
      output_tensor_name = fdef.ret[output_arg_name]
      return resource_input_index(
          output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
          {ndef.name: ndef for ndef in fdef.node_def}, functions)

    if node_def.op in ("Identity", "While"):
      # Captured resources occur at the same index in the lists of inputs and
      # outputs of a while or identity op. So we look up the input of
      # `tensor.op` at the same index as the index of `tensor` in
      # `tensor.op.outputs`.
      tensor_name = node_def.input[output_idx]
    elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
      # Functions output any captured resource tensors used by their
      # gradients.  `tensor_name` is one of these outputs from a nested
      # function call, so recursively find the corresponding input in the
      # nested FunctionDef.
      tensor_name = node_def.input[_extract_input_index("f")]
    elif node_def.op in ("If", "StatelessIf"):
      input_index = _extract_input_index("then_branch")
      if input_index != _extract_input_index("else_branch"):
        raise AssertionError(
            ("Expected cond branches ({} op) to each have the same "
             "input->output mapping of resources.").format(node_def.op))
      tensor_name = node_def.input[
          # Ignore the `cond` input; the function inputs come after it.
          input_index + 1]
    else:
      # We assume there are no other op types that will "forward" resource
      # handles like this, so all other handles must have been created by the
      # op. (Note that cond_v2 wraps resource handle outputs in optionals,
      # which we'll end up accumulating.)
      raise ValueError("Taking gradient of a while loop which creates "
                       "a resource in its body is not supported: %s (%s)"
                       % (op_name, node_def.op))

  return input_names.index(tensor_name)
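
# Illustrative tensor names this resolver walks backwards from (a sketch of
# the two naming conventions mentioned above):
#
#   "while/Identity:0"          # graph convention:       op_name:output_idx
#   "while/identity:output:0"   # FunctionDef convention: op_name:arg_name:idx
#   "resource_placeholder"      # bare op name, treated as output 0
#
# Each name is traced through Identity/While/If/PartitionedCall nodes until it
# matches an entry in `input_names`; anything else is assumed to have created
# the resource and raises the ValueError above.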


@tf_contextlib.contextmanager
def clear_control_inputs():
  """Clears the control inputs but preserves the ControlFlowContext.

  This is needed to preserve the XLAControlFlowContext when clearing
  control inputs for the gradient accumulators in while_v2.
  `ops.control_dependencies` does not allow that.

  Yields:
    A context manager in which the ops created will not have any control inputs
    by default but the control flow context is the same.
  """
  # pylint: disable=protected-access
  control_flow_context = ops.get_default_graph()._get_control_flow_context()
  with ops.control_dependencies(None):
    ops.get_default_graph()._set_control_flow_context(control_flow_context)
    yield
  # pylint: enable=protected-access
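
# Illustrative usage (a sketch; `build_accumulator` stands for any op
# construction that must not inherit the surrounding control dependencies but
# must stay in the current control flow context):
#
#   with clear_control_inputs():
#     acc = build_accumulator()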


def _is_tpu_strategy(strategy):
  return (strategy is not None and
          strategy.__class__.__name__.startswith("TPUStrategy"))


def _is_building_keras_layer():
  # TODO(srbs): Remove this function when we no longer support sessions with
  # Keras.
  keras_call_context_function = keras_deps.get_call_context_function()
  if keras_call_context_function:
    return keras_call_context_function().layer is not None
  else:
    return False


def output_all_intermediates():
  """Whether to output all intermediates of a functional control flow op.

  The default behavior is to output intermediates only when building a Keras
  Layer in graph mode, and even then only when the following conditions are
  met:
  1. We do not output intermediates if the functional control flow op is being
     built inside a FuncGraph which is not an If/While graph. This guards
     against outputting intermediates in eager mode, since Keras adds tensors
     to a FuncGraph named "keras_graph" in that case. Also, since we do not
     output intermediates of tf.function (this feature exists only for
     backwards compatibility), outputting intermediates of functional control
     flow ops built inside tf.function is of no value.
  2. We do not output intermediates when compiling with XLA or for a TPU.
  3. We do not output intermediates when a single threaded executor is used,
     since that executor does not perform inlining and pruning.

  Returns:
    A bool telling whether to output all intermediates.
  """
  if _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE is not None:
    return _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  if in_defun():
    return False
  if (control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()) or
      _is_tpu_strategy(distribution_strategy_context.get_strategy())):
    return False
  if (context.context().function_call_options.executor_type ==
      "SINGLE_THREADED_EXECUTOR"):
    return False
  return _is_building_keras_layer()
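
# Illustrative override (a sketch; tests or callers that need to force the
# legacy behavior can set the module-level flag directly):
#
#   control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
#   ...  # build functional control flow ops
#   control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None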


def get_func_graph(op, input_shapes, func_name):
  """Generates and returns a FuncGraph for the given op and input_shapes."""
  fdef = None
  graph = op.graph
  # Search for the function in this graph and its outer graphs.
  while graph is not None:
    func = graph._get_function(func_name)  # pylint: disable=protected-access
    if func is not None:
      fdef = func.definition
      break
    if hasattr(graph, "outer_graph"):
      graph = graph.outer_graph
    else:
      break

  if fdef is None:
    raise KeyError("%s cannot be found in the graph" % func_name)

  # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # in the case of nested if ops or when the gradient is being computed
  # from inside a Defun. We build the `func_graph` with `op.graph` as its
  # `outer_graph`. This resembles how the `FuncGraph` was built in the
  # forward pass. We need this so that we can resolve references to tensors
  # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
  with op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes)
  return func_graph
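
# Illustrative call (a sketch of how gradient code can rebuild a forward body
# graph; `while_op` stands for a functional `While` operation, and reading the
# function name from its `body` attr is an assumption of this sketch):
#
#   input_shapes = None  # or a list of shapes for the function's inputs
#   body_name = while_op.get_attr("body").name
#   body_graph = get_func_graph(while_op, input_shapes, body_name)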


def get_op_and_outputs(op_or_outputs):
  """Returns an (op, outputs) pair from an Operation or its list of outputs."""
  if isinstance(op_or_outputs, ops.Operation):
    return op_or_outputs, []
  elif not op_or_outputs:  # Empty list.
    return None, []
  else:
    return op_or_outputs[0].op, op_or_outputs


def graph_wrapped_for_higher_order_tape_gradients(graph):
  """Check if `graph` is wrapped by `run_as_function_for_tape_gradients`."""
  while graph is not None:
    if "cflow_gradient_wrapper" in getattr(graph, "name", ""):
      return True
    graph = getattr(graph, "outer_graph", None)
  return False


def run_as_function_for_tape_gradients(make_op, inputs):
  """Fix higher-order tape gradients by wrapping `make_op` in a function.

  Args:
    make_op: A function that takes a list of inputs and returns a list of
      output tensors. This function should set any handle data relevant to its
      outputs before returning.
    inputs: A list of tensors to check for tape gradients and pass to
      `make_op`. These should include all tensors used in `make_op`.

  Returns:
    Tensors corresponding to `make_op`'s output.
  """
  # GradientTapes created inside a function currently don't work well with
  # un-wrapped control flow ops in that same function. Wrapping the op
  # construction in an extra layer of intermediate function means we run extra
  # logic in the function gradient code to record the correct intermediates on
  # the tape.
  #
  # The function attribute inputs to control flow ops are not hashable, so we
  # pass everything as a capture to bypass defun's caching.
  if (gradients_util.PossibleTapeGradientTypes(inputs)
      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
      # We only need one function between the tape and the op; if we've already
      # wrapped once, we stop wrapping to avoid infinite recursion.
      and not (ops.get_default_graph().building_function
               and "cflow_gradient_wrapper" in ops.get_default_graph().name)):
    results = function.defun_with_attributes(
        make_op,
        autograph=False,
        attributes=dict(func_name="cflow_gradient_wrapper"))(inputs)
    return results
  else:
    return make_op(inputs)
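
# Illustrative usage (a sketch; `_build_while_op` stands for a hypothetical
# callable that creates a functional `While` op from a flat list of tensors
# and returns its outputs):
#
#   outputs = run_as_function_for_tape_gradients(_build_while_op, flat_inputs)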