• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=line-too-long
16"""Script for updating tensorflow/tools/compatibility/renames_v2.py.
17
18To update renames_v2.py, run:
19  bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map
20  bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map
21  pyformat --in_place third_party/tensorflow/tools/compatibility/renames_v2.py
22"""
23# pylint: enable=line-too-long
24import sys
25
26from absl import app
27import tensorflow as tf
28
29from tensorflow import python as tf_python  # pylint: disable=unused-import
30from tensorflow.python.lib.io import file_io
31from tensorflow.python.util import tf_decorator
32from tensorflow.python.util import tf_export
33from tensorflow.tools.common import public_api
34from tensorflow.tools.common import traverse
35from tensorflow.tools.compatibility import all_renames_v2
36
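# The `tf_python` import above (`from tensorflow import python`) is needed so
# that TensorFlow python modules are in sys.modules, which
# collect_constant_renames() scans for exported constants.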
38
# Destination of the generated renames module; main() passes it to
# update_renames_v2(), which replaces any existing contents. This is a relative
# path, presumably resolved against the workspace root the script is run from
# (TODO confirm).
_OUTPUT_FILE_PATH = 'third_party/tensorflow/tools/compatibility/renames_v2.py'
# Text written verbatim at the top of the generated renames_v2.py, immediately
# before the `renames = {...}` dictionary emitted by update_renames_v2(). The
# embedded module docstring uses escaped quotes (\"\"\") so it can sit inside
# this triple-quoted literal.
# NOTE(review): the `pyformat` line in the embedded instructions is not
# indented like the two `bazel` lines above it. It is left as-is because
# editing it would change the generated file.
_FILE_HEADER = """# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
\"\"\"List of renames to apply when converting from TF 1.0 to TF 2.0.

THIS FILE IS AUTOGENERATED: To update, please run:
  bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map
  bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map
pyformat --in_place third_party/tensorflow/tools/compatibility/renames_v2.py
This file should be updated whenever endpoints are deprecated.
\"\"\"
"""
64
65
def get_canonical_name(v2_names, v1_name):
  """Picks the name a deprecated TF 1.x symbol should be rewritten to.

  Args:
    v2_names: API names the symbol is exported under in TF 2.0 (may be empty).
    v1_name: The TF 1.x API name being replaced.

  Returns:
    The first entry of `v2_names`, or `'compat.v1.<v1_name>'` when the symbol
    has no 2.0 export.
  """
  return v2_names[0] if v2_names else 'compat.v1.%s' % v1_name
70
71
def get_all_v2_names():
  """Returns the set of every API name exported for TensorFlow 2.0.

  Walks the public `tf.compat.v2` API, skipping `contrib` and `compat.v1`, and
  gathers the V2 export names of each symbol it encounters.
  """
  collected = set()

  def _gather(unused_path, unused_parent, children):
    """Traversal callback: adds the V2 export names of each child."""
    for child in children:
      target = tf_decorator.unwrap(child[1])[1]
      collected.update(tf_export.get_v2_names(target))

  api_visitor = public_api.PublicAPIVisitor(_gather)
  api_visitor.do_not_descend_map['tf'].append('contrib')
  api_visitor.do_not_descend_map['tf.compat'] = ['v1']
  traverse.traverse(tf.compat.v2, api_visitor)
  return collected
89
90
def collect_constant_renames():
  """Looks for constants that need to be renamed in TF 2.0.

  Scans every module currently present in `sys.modules` for constants
  registered through `tf_export`. It reports each V1 API name that is not also
  a V2 API name for the same constant.

  Returns:
    Set of tuples of the form (current name, new name). When a constant has no
    V2 API name, the new name is its `compat.v1.` alias (see
    get_canonical_name).
  """
  renames = set()
  # Iterate over a snapshot: sys.modules can change size during iteration if
  # anything (e.g. a lazily loaded module touched by the tf_export helpers)
  # triggers an import, which would raise RuntimeError.
  for module in list(sys.modules.values()):
    constants_v1_list = tf_export.get_v1_constants(module)
    constants_v2_list = tf_export.get_v2_constants(module)

    # Both lists hold tuples of the form (api_names_list, constant_name).
    # Index them by constant name so the V1 and V2 exports of the same
    # constant can be compared.
    constants_v1 = {constant_name: api_names
                    for api_names, constant_name in constants_v1_list}
    constants_v2 = {constant_name: api_names
                    for api_names, constant_name in constants_v2_list}
    # Report names that are in V1 but not in V2. A constant with a V1 entry
    # but no V2 entry is treated as having no V2 names (so it maps to its
    # compat.v1 alias) instead of raising KeyError.
    for constant_name, api_names_v1 in constants_v1.items():
      api_names_v2 = constants_v2.get(constant_name, [])
      for name in api_names_v1:
        if name not in api_names_v2:
          renames.add((name, get_canonical_name(api_names_v2, name)))
  return renames
120
121
def collect_function_renames():
  """Looks for functions/classes that need to be renamed in TF 2.0.

  Returns:
    Set of tuples of the form (current name, new name).
  """
  # (deprecated_name, canonical_name) pairs, both without the 'tf.' prefix.
  candidates = set()

  def _record_deprecated(unused_path, unused_parent, children):
    """Traversal callback: records the V1-only export names of each child."""
    for child in children:
      target = tf_decorator.unwrap(child[1])[1]
      names_v1 = tf_export.get_v1_names(target)
      names_v2 = tf_export.get_v2_names(target)
      for old_name in set(names_v1).difference(names_v2):
        candidates.add((old_name, get_canonical_name(names_v2, old_name)))

  api_visitor = public_api.PublicAPIVisitor(_record_deprecated)
  api_visitor.do_not_descend_map['tf'].append('contrib')
  api_visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, api_visitor)

  # A deprecated V1 name can still be exported in 2.0 by a *different* symbol
  # (e.g. a new function written to rename arguments). Such names are not
  # renames, so drop them.
  v2_names = get_all_v2_names()
  return {(old_name, new_name) for old_name, new_name in candidates
          if old_name not in v2_names}
154
155
def get_rename_line(name, canonical_name):
  """Formats one `'tf.old': 'tf.new'` entry of the generated dictionary."""
  return "    'tf.{}': 'tf.{}'".format(name, canonical_name)
158
159
def update_renames_v2(output_file_path):
  """Regenerates the deprecated-to-canonical API name map.

  Writes `_FILE_HEADER` followed by a `renames = {...}` dictionary whose
  entries are sorted lines of the form `'tf.deprecated_name':
  'tf.canonical_name'`.

  Args:
    output_file_path: Destination file. Whatever it currently contains is
      overwritten.
  """
  all_renames = collect_function_renames() | collect_constant_renames()
  # Symbols already listed in the hand-maintained map are skipped. Its keys
  # carry the 'tf.' prefix; the collected names do not.
  manual_keys = set(all_renames_v2.manual_symbol_renames)
  rename_lines = sorted(
      get_rename_line(old_name, new_name)
      for old_name, new_name in all_renames
      if 'tf.' + old_name not in manual_keys)
  dict_body = ',\n'.join(rename_lines)
  file_io.write_string_to_file(
      output_file_path, '%srenames = {\n%s\n}\n' % (_FILE_HEADER, dict_body))
183
184
def main(unused_argv):
  """Entry point: regenerates renames_v2.py at `_OUTPUT_FILE_PATH`."""
  del unused_argv  # Command-line arguments are not used.
  update_renames_v2(_OUTPUT_FILE_PATH)
187
188
if __name__ == '__main__':
  # When executed as a script, hand control to absl's app runner with main()
  # as the program entry point.
  app.run(main=main)
191