1# Lint as: python2, python3 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Upgrader for Python scripts from 1.x TensorFlow to 2.0 TensorFlow.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import argparse 23 24import six 25 26from tensorflow.tools.compatibility import ast_edits 27from tensorflow.tools.compatibility import ipynb 28from tensorflow.tools.compatibility import tf_upgrade_v2 29from tensorflow.tools.compatibility import tf_upgrade_v2_safety 30 31# Make straightforward changes to convert to 2.0. In harder cases, 32# use compat.v1. 33_DEFAULT_MODE = "DEFAULT" 34 35# Convert to use compat.v1. 36_SAFETY_MODE = "SAFETY" 37 38# Whether to rename to compat.v2 39_IMPORT_RENAME_DEFAULT = False 40 41 42def process_file(in_filename, out_filename, upgrader): 43 """Process a file of type `.py` or `.ipynb`.""" 44 45 if six.ensure_str(in_filename).endswith(".py"): 46 files_processed, report_text, errors = \ 47 upgrader.process_file(in_filename, out_filename) 48 elif six.ensure_str(in_filename).endswith(".ipynb"): 49 files_processed, report_text, errors = \ 50 ipynb.process_file(in_filename, out_filename, upgrader) 51 else: 52 raise NotImplementedError( 53 "Currently converter only supports python or ipynb") 54 55 return files_processed, report_text, errors 56 57 58def main(): 59 parser = argparse.ArgumentParser( 60 formatter_class=argparse.RawDescriptionHelpFormatter, 61 description="""Convert a TensorFlow Python file from 1.x to 2.0 62 63Simple usage: 64 tf_upgrade_v2.py --infile foo.py --outfile bar.py 65 tf_upgrade_v2.py --infile foo.ipynb --outfile bar.ipynb 66 tf_upgrade_v2.py --intree ~/code/old --outtree ~/code/new 67""") 68 parser.add_argument( 69 "--infile", 70 dest="input_file", 71 help="If converting a single file, the name of the file " 72 "to convert") 73 parser.add_argument( 74 "--outfile", 75 dest="output_file", 76 help="If converting a single file, the output filename.") 77 parser.add_argument( 78 "--intree", 79 dest="input_tree", 80 help="If converting a whole tree of files, the directory " 81 "to read from (relative or absolute).") 82 parser.add_argument( 83 "--outtree", 84 dest="output_tree", 85 help="If converting a whole tree of files, the output " 86 "directory (relative or absolute).") 87 parser.add_argument( 88 "--copyotherfiles", 89 dest="copy_other_files", 90 help=("If converting a whole tree of files, whether to " 91 "copy the other files."), 92 type=bool, 93 default=True) 94 parser.add_argument( 95 "--inplace", 96 dest="in_place", 97 help=("If converting a set of files, whether to " 98 "allow the conversion to be performed on the " 99 "input files."), 100 action="store_true") 101 parser.add_argument( 102 "--no_import_rename", 103 dest="no_import_rename", 104 help=("Not to rename import to compat.v2 explicitly."), 105 action="store_true") 106 parser.add_argument( 107 "--no_upgrade_compat_v1_import", 108 dest="no_upgrade_compat_v1_import", 109 help=("If specified, don't upgrade explicit imports of " 110 "`tensorflow.compat.v1 as tf` to the v2 APIs. Otherwise, " 111 "explicit imports of the form `tensorflow.compat.v1 as tf` will " 112 "be upgraded."), 113 action="store_true") 114 parser.add_argument( 115 "--reportfile", 116 dest="report_filename", 117 help=("The name of the file where the report log is " 118 "stored." 119 "(default: %(default)s)"), 120 default="report.txt") 121 parser.add_argument( 122 "--mode", 123 dest="mode", 124 choices=[_DEFAULT_MODE, _SAFETY_MODE], 125 help=("Upgrade script mode. Supported modes:\n" 126 "%s: Perform only straightforward conversions to upgrade to " 127 "2.0. In more difficult cases, switch to use compat.v1.\n" 128 "%s: Keep 1.* code intact and import compat.v1 " 129 "module." % 130 (_DEFAULT_MODE, _SAFETY_MODE)), 131 default=_DEFAULT_MODE) 132 parser.add_argument( 133 "--print_all", 134 dest="print_all", 135 help="Print full log to stdout instead of just printing errors", 136 action="store_true") 137 args = parser.parse_args() 138 139 if args.mode == _SAFETY_MODE: 140 change_spec = tf_upgrade_v2_safety.TFAPIChangeSpec() 141 else: 142 if args.no_import_rename: 143 change_spec = tf_upgrade_v2.TFAPIChangeSpec( 144 import_rename=False, 145 upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import) 146 else: 147 change_spec = tf_upgrade_v2.TFAPIChangeSpec( 148 import_rename=_IMPORT_RENAME_DEFAULT, 149 upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import) 150 upgrade = ast_edits.ASTCodeUpgrader(change_spec) 151 152 report_text = None 153 report_filename = args.report_filename 154 files_processed = 0 155 if args.input_file: 156 if not args.in_place and not args.output_file: 157 raise ValueError( 158 "--outfile=<output file> argument is required when converting a " 159 "single file.") 160 if args.in_place and args.output_file: 161 raise ValueError("--outfile argument is invalid when converting in place") 162 output_file = args.input_file if args.in_place else args.output_file 163 files_processed, report_text, errors = process_file( 164 args.input_file, output_file, upgrade) 165 errors = {args.input_file: errors} 166 files_processed = 1 167 elif args.input_tree: 168 if not args.in_place and not args.output_tree: 169 raise ValueError( 170 "--outtree=<output directory> argument is required when converting a " 171 "file tree.") 172 if args.in_place and args.output_tree: 173 raise ValueError("--outtree argument is invalid when converting in place") 174 output_tree = args.input_tree if args.in_place else args.output_tree 175 files_processed, report_text, errors = upgrade.process_tree( 176 args.input_tree, output_tree, args.copy_other_files) 177 else: 178 parser.print_help() 179 if report_text: 180 num_errors = 0 181 report = [] 182 for f in errors: 183 if errors[f]: 184 num_errors += len(errors[f]) 185 report.append(six.ensure_str("-" * 80) + "\n") 186 report.append("File: %s\n" % f) 187 report.append(six.ensure_str("-" * 80) + "\n") 188 report.append("\n".join(errors[f]) + "\n") 189 190 report = ("TensorFlow 2.0 Upgrade Script\n" 191 "-----------------------------\n" 192 "Converted %d files\n" % files_processed + 193 "Detected %d issues that require attention" % num_errors + "\n" + 194 six.ensure_str("-" * 80) + "\n") + "".join(report) 195 detailed_report_header = six.ensure_str("=" * 80) + "\n" 196 detailed_report_header += "Detailed log follows:\n\n" 197 detailed_report_header += six.ensure_str("=" * 80) + "\n" 198 199 with open(report_filename, "w") as report_file: 200 report_file.write(report) 201 report_file.write(detailed_report_header) 202 report_file.write(six.ensure_str(report_text)) 203 204 if args.print_all: 205 print(report) 206 print(detailed_report_header) 207 print(report_text) 208 else: 209 print(report) 210 print("\nMake sure to read the detailed log %r\n" % report_filename) 211 212if __name__ == "__main__": 213 main() 214