• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Support for wrapping converted functions bodies with auxiliary logic."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.autograph.core import ag_ctx
22from tensorflow.python.autograph.core import converter
23from tensorflow.python.autograph.operators import variables
24from tensorflow.python.framework import auto_control_deps
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.util import nest
28
29
30# TODO(mdan): Move this into operators - it represents a function definition.
31
32
class FunctionScope(object):
  """Context manager that wraps the body of a converted function.

  This context manager handles various operations related to the scope of a
  function:
    * optional TF name scopes - these name scopes match the name of the
        function, for easy visualization in TensorBoard;
    * optional automatic control dependencies - this adds the same mechanism
        for control dependencies that is used by `@tf.function`; it can be
        optionally enabled when using `tf.autograph.to_graph`;
    * tracking of autograph conversion state (whether it's enabled by the user,
        conversion options).

  The enabled scopes are entered in the order listed above and exited in the
  reverse order, exactly as if they were written as nested `with` statements.
  Exceptions raised by the wrapped body are never suppressed.
  """

  def __init__(self, function_name, scope_name, options):
    """Creates (but does not enter) the scopes requested by `options`.

    Args:
      function_name: name of the converted function; sanitized and used as the
        TF name scope when the NAME_SCOPES feature is enabled.
      scope_name: stored as `self.name`.
      options: conversion options; consulted for `user_requested`,
        `call_options()` and which `converter.Feature`s are enabled.
    """
    self.name = scope_name
    self.options = options

    if options.user_requested:
      self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED,
                                                   options)
    self.callopts = options.call_options()

    use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
    self.use_name_scope = use_name_scope
    if use_name_scope:
      self.name_scope = ops.name_scope(self._sanitize(function_name))

    use_auto_deps = self.options.uses(converter.Feature.AUTO_CONTROL_DEPS)
    self.use_auto_deps = use_auto_deps
    if use_auto_deps:
      self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()
      self._return_value_marked = False

    # Scopes actually entered by __enter__; __exit__ unwinds them in LIFO
    # order. None while the scope is not active.
    self._exit_stack = None

  def _sanitize(self, name):
    """See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope."""
    # TensorFlow doesn't like leading underscores at the top level.
    if name and name.startswith('_'):
      name = 'fn' + name
    return name

  def _scopes_to_enter(self):
    """Returns the enabled optional scopes, outermost first."""
    scopes = []
    if self.options.user_requested:
      scopes.append(self.autograph_ctx)
    if self.use_name_scope:
      scopes.append(self.name_scope)
    if self.use_auto_deps:
      scopes.append(self.autodeps_scope)
    return scopes

  def __enter__(self):
    import contextlib  # pylint: disable=g-import-not-at-top

    # If entering any scope raises, the ExitStack unwinds the ones already
    # entered before the exception propagates (Python would not call our
    # __exit__ in that case).
    with contextlib.ExitStack() as stack:
      for scope in self._scopes_to_enter():
        stack.enter_context(scope)
      # Everything entered successfully: keep the scopes open past this block.
      self._exit_stack = stack.pop_all()
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    stack, self._exit_stack = self._exit_stack, None
    if stack is not None:
      # Exits in reverse order of entry and keeps unwinding the remaining
      # scopes even if one of them raises. The return value is deliberately
      # discarded so this scope never suppresses exceptions.
      stack.__exit__(exc_type, exc_val, exc_tb)

  def ret(self, value, did_return):
    """Marks a value as returned from the function guarded by the scope.

    Args:
      value: the value being returned by the converted function.
      did_return: unused; accepted for signature compatibility.

    Returns:
      None if `value` is a `variables.UndefinedReturnValue` (or is None);
      otherwise `value`, where, if automatic control dependencies are enabled,
      every TF value in its nested structure is replaced by the result of
      `self.autodeps_scope.mark_as_return`.
    """
    del did_return

    if isinstance(value, variables.UndefinedReturnValue):
      return None

    if self.use_auto_deps:
      self._return_value_marked = True
      if value is None:
        # We don't create dummy returns, to preserve Python semantics. The user
        # is responsible for adding a return value to the top-level function.
        return None

      def _mark_return_if_tensor(t):
        if tensor_util.is_tf_type(t):
          return self.autodeps_scope.mark_as_return(t)
        return t

      value = nest.map_structure(_mark_return_if_tensor, value)
    return value
112
113
def with_function_scope(thunk, scope_name, options):
  """Expression-friendly form of the `FunctionScope` context manager.

  Args:
    thunk: callable invoked with the entered `FunctionScope` as its only
      argument.
    scope_name: forwarded to `FunctionScope` as its `scope_name`.
    options: forwarded to `FunctionScope` as its `options`.

  Returns:
    Whatever `thunk` returns.
  """
  manager = FunctionScope('lambda_', scope_name, options)
  with manager as active_scope:
    return thunk(active_scope)
118