• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visitor restricting traversal to only the public TensorFlow API."""
16
17import re
18
19from tensorflow.python.util import tf_inspect
20
21
class PublicAPIVisitor:
  """Visitor to use with `traverse` to visit exactly the public TF API.

  On each call it removes private children, forwards the remaining children to
  the wrapped visitor, and then removes the children that should not be
  descended into.
  """

  # Names that start and end with a double underscore (e.g. `__init__`) are
  # exempt from the leading-underscore privacy rule in `_is_private`.
  _DUNDER_RE = re.compile(r'__.*__$')

  # Dunder names that are nevertheless treated as private.
  _PRIVATE_DUNDERS = frozenset(('__base__', '__class__', '__next_in_mro__'))

  def __init__(self, visitor):
    """Constructor.

    `visitor` should be a callable suitable as a visitor for `traverse`. It will
    be called only for members of the public TensorFlow API.

    Args:
      visitor: A visitor to call for the public API.
    """
    self._visitor = visitor
    self._root_name = 'tf'

    # Modules/classes we want to suppress entirely. Keys are fully qualified
    # paths (starting with the root name); values list member names to hide.
    self._private_map = {
        'tf': [
            'compiler',
            'core',
            # TODO(scottzhu): See b/227410870 for more details. Currently
            # dtensor API is exposed under tf.experimental.dtensor, but in the
            # meantime, we have tensorflow/dtensor directory which will be
            # treated as a python package. We want to avoid stepping into the
            # tensorflow/dtensor directory when visiting the API.
            # When the tf.dtensor becomes the public API, it will actually pick
            # up from tf.compat.v2.dtensor as priority and hide the
            # tensorflow/dtensor package.
            'dtensor',
            'python',
        ],
        # Some implementations have this internal module that we shouldn't
        # expose.
        'tf.flags': ['cpp_flags'],
    }

    # Modules/classes we do not want to descend into if we hit them. Usually,
    # system modules exposed through platforms for compatibility reasons.
    # Each entry maps a module path to the names to ignore in traversal.
    self._do_not_descend_map = {
        'tf': [
            'examples',
            'flags',  # Don't add flags
            # TODO(drpng): This can be removed once sealed off.
            'platform',
            # TODO(drpng): This can be removed once sealed.
            'pywrap_tensorflow',
            # TODO(drpng): This can be removed once sealed.
            'user_ops',
            'tools',
            'tensorboard',
        ],

        ## Everything below here is legitimate.
        # It'll stay, but it's not officially part of the API.
        'tf.app': ['flags'],
        # Imported for compatibility between py2/3.
        'tf.test': ['mock'],
    }

  @property
  def private_map(self):
    """A map from parents to symbols that should not be included at all.

    This map can be edited, but it should not be edited once traversal has
    begun.

    Returns:
      The map marking symbols to not include.
    """
    return self._private_map

  @property
  def do_not_descend_map(self):
    """A map from parents to symbols that should not be descended into.

    This map can be edited, but it should not be edited once traversal has
    begun.

    Returns:
      The map marking symbols to not explore.
    """
    return self._do_not_descend_map

  def set_root_name(self, root_name):
    """Override the default root name of 'tf'.

    The keys of `private_map` and `do_not_descend_map` are compared against
    paths that begin with the root name. The default entries are keyed under
    'tf', so they generally stop matching once a different root name is set;
    update the maps accordingly if that matters.

    Args:
      root_name: The new root name used to build fully qualified paths.
    """
    self._root_name = root_name

  def _is_private(self, path, name, obj=None):
    """Return whether a name is private.

    A name is private if it is listed under `path` in the private map, or if
    it starts with an underscore and is not a dunder name, or if it is one of
    the dunder names in `_PRIVATE_DUNDERS`.

    Args:
      path: Fully qualified path of the parent (including the root name).
      name: Member name to check.
      obj: The member itself; currently unused.

    Returns:
      True if `name` should be hidden from the public API.
    """
    # TODO(wicke): Find out what names to exclude.
    del obj  # Unused.
    if name in self._private_map.get(path, ()):
      return True
    if name.startswith('_') and not self._DUNDER_RE.match(name):
      return True
    return name in self._PRIVATE_DUNDERS

  def _do_not_descend(self, path, name):
    """Safely queries if a specific fully qualified name should be excluded."""
    return name in self._do_not_descend_map.get(path, ())

  @staticmethod
  def _prune(children, should_drop):
    """Remove, in place, every `(name, child)` entry for which `should_drop` is true.

    The list object is mutated, not replaced, so the caller still sees the
    pruning. Surviving entries keep their identity and relative order. Unlike
    `list.remove`, this never compares children for equality, and it runs in
    linear time.

    Args:
      children: A list of `(name, child)` pairs.
      should_drop: Callable `(name, child) -> bool`.
    """
    children[:] = [entry for entry in children if not should_drop(*entry)]

  def __call__(self, path, parent, children):
    """Visitor interface, see `traverse` for details.

    Args:
      path: Dotted path of `parent` relative to the root ('' for the root).
      parent: The object being visited.
      children: List of `(name, child)` pairs; pruned in place.

    Raises:
      RuntimeError: If `parent` is a module nested more than 10 levels deep.
    """

    # Avoid long waits in cases of pretty unambiguous failure.
    if tf_inspect.ismodule(parent) and len(path.split('.')) > 10:
      raise RuntimeError('Modules nested too deep:\n%s.%s\n\nThis is likely a '
                         'problem with an accidental public import.' %
                         (self._root_name, path))

    # Includes self._root_name
    full_path = '.'.join([self._root_name, path]) if path else self._root_name

    # Remove things that are not visible.
    self._prune(children,
                lambda name, child: self._is_private(full_path, name, child))

    self._visitor(path, parent, children)

    # Remove things that are visible, but which should not be descended into.
    # This runs after the wrapped visitor, so it still sees these children.
    self._prune(children,
                lambda name, _: self._do_not_descend(full_path, name))