# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
"""Script for updating tensorflow/tools/compatibility/renames_v2.py.

To update renames_v2.py, run:
  bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map
  bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map
  pyformat --in_place third_party/tensorflow/tools/compatibility/renames_v2.py
"""
# pylint: enable=line-too-long
import sys

from absl import app
import tensorflow as tf

# This import is needed so that TensorFlow python modules are in sys.modules.
from tensorflow import python as tf_python  # pylint: disable=unused-import
from tensorflow.python.lib.io import file_io
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse
from tensorflow.tools.compatibility import all_renames_v2

# Output location; a relative path, so it is resolved against the current
# working directory when the script runs.
_OUTPUT_FILE_PATH = 'third_party/tensorflow/tools/compatibility/renames_v2.py'
_FILE_HEADER = """# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
\"\"\"List of renames to apply when converting from TF 1.0 to TF 2.0.

THIS FILE IS AUTOGENERATED: To update, please run:
  bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map
  bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map
pyformat --in_place third_party/tensorflow/tools/compatibility/renames_v2.py
This file should be updated whenever endpoints are deprecated.
\"\"\"
"""


def get_canonical_name(v2_names, v1_name):
  """Returns the name a deprecated V1 symbol should be rewritten to.

  Args:
    v2_names: Sequence of TF 2.0 API names exported for the same symbol;
      may be empty.
    v1_name: The TF 1.x API name, without the 'tf.' prefix.

  Returns:
    The first entry of `v2_names` if there is one; otherwise `v1_name`
    moved under `compat.v1`.
  """
  if v2_names:
    return v2_names[0]
  return 'compat.v1.%s' % v1_name


def get_all_v2_names():
  """Get a set of function/class names available in TensorFlow 2.0.

  Traverses the public `tf.compat.v2` API, skipping `tf.contrib` and
  `tf.compat.v1`.

  Returns:
    Set of TF 2.0 API names (without the 'tf.' prefix).
  """
  v2_names = set()  # All op names in TensorFlow 2.0

  def visit(unused_path, unused_parent, children):
    """Visitor that collects TF 2.0 names."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      api_names_v2 = tf_export.get_v2_names(attr)
      for name in api_names_v2:
        v2_names.add(name)

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  visitor.do_not_descend_map['tf.compat'] = ['v1']
  traverse.traverse(tf.compat.v2, visitor)
  return v2_names


def collect_constant_renames():
  """Looks for constants that need to be renamed in TF 2.0.

  Scans every module currently in sys.modules for TensorFlow API constant
  exports and reports each V1 API name that is not also a V2 API name for
  the same constant.

  Returns:
    Set of tuples of the form (current name, new name).
  """
  renames = set()
  # Iterate over a snapshot: the tf_export lookups below may cause further
  # imports, and mutating sys.modules while iterating its live view raises
  # "RuntimeError: dictionary changed size during iteration".
  for module in list(sys.modules.values()):
    constants_v1_list = tf_export.get_v1_constants(module)
    constants_v2_list = tf_export.get_v2_constants(module)

    # _tf_api_constants attribute contains a list of tuples:
    # (api_names_list, constant_name)
    # We want to find API names that are in V1 but not in V2 for the same
    # constant_names.

    # First, we convert constants_v1_list and constants_v2_list to
    # dictionaries for easier lookup.
    constants_v1 = {constant_name: api_names
                    for api_names, constant_name in constants_v1_list}
    constants_v2 = {constant_name: api_names
                    for api_names, constant_name in constants_v2_list}
    # Second, we look for names that are in V1 but not in V2. A constant
    # exported only under V1 has no V2 entry at all: treat it as having no
    # V2 names (rather than raising KeyError) so it maps to compat.v1, the
    # same way V1-only functions are handled in collect_function_renames.
    for constant_name, api_names_v1 in constants_v1.items():
      api_names_v2 = constants_v2.get(constant_name, [])
      for name in api_names_v1:
        if name not in api_names_v2:
          renames.add((name, get_canonical_name(api_names_v2, name)))
  return renames


def collect_function_renames():
  """Looks for functions/classes that need to be renamed in TF 2.0.

  Returns:
    Set of tuples of the form (current name, new name).
  """
  # Set of (deprecated_name, canonical_name) tuples; each one later becomes
  # an output line of the form 'tf.deprecated_name': 'tf.canonical_name'.
  renames = set()

  def visit(unused_path, unused_parent, children):
    """Visitor that adds (deprecated name, canonical name) pairs to renames."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      api_names_v1 = tf_export.get_v1_names(attr)
      api_names_v2 = tf_export.get_v2_names(attr)
      deprecated_api_names = set(api_names_v1) - set(api_names_v2)
      for name in deprecated_api_names:
        renames.add((name, get_canonical_name(api_names_v2, name)))

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, visitor)

  # It is possible that a different function is exported with the
  # same name. For e.g. when creating a different function to
  # rename arguments. Exclude it from renames in this case.
  v2_names = get_all_v2_names()
  return {(name, new_name) for name, new_name in renames
          if name not in v2_names}


def get_rename_line(name, canonical_name):
  """Formats one `'tf.<name>': 'tf.<canonical_name>'` dictionary entry line."""
  return '    \'tf.%s\': \'tf.%s\'' % (name, canonical_name)


def update_renames_v2(output_file_path):
  """Writes a Python dictionary mapping deprecated to canonical API names.

  Deprecated names whose 'tf.'-prefixed form is already a key of
  all_renames_v2.manual_symbol_renames are left out of the generated map.

  Args:
    output_file_path: File path to write output to. Any existing contents
      would be replaced.
  """
  function_renames = collect_function_renames()
  constant_renames = collect_constant_renames()
  all_renames = function_renames.union(constant_renames)
  manual_renames = set(all_renames_v2.manual_symbol_renames)

  # List of rename lines to write to output file in the form:
  # 'tf.deprecated_name': 'tf.canonical_name'
  rename_lines = [
      get_rename_line(name, canonical_name)
      for name, canonical_name in all_renames
      if ('tf.' + name) not in manual_renames
  ]
  # Sorting makes the generated file deterministic regardless of set order.
  renames_file_text = '%srenames = {\n%s\n}\n' % (
      _FILE_HEADER, ',\n'.join(sorted(rename_lines)))
  file_io.write_string_to_file(output_file_path, renames_file_text)


def main(unused_argv):
  """Regenerates renames_v2.py at _OUTPUT_FILE_PATH."""
  update_renames_v2(_OUTPUT_FILE_PATH)


if __name__ == '__main__':
  app.run(main=main)