• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Upgrader for Python scripts from 1.x TensorFlow to 2.0 TensorFlow."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import argparse
23
24import six
25
26from tensorflow.tools.compatibility import ast_edits
27from tensorflow.tools.compatibility import ipynb
28from tensorflow.tools.compatibility import tf_upgrade_v2
29from tensorflow.tools.compatibility import tf_upgrade_v2_safety
30
# Values accepted by the --mode flag:
#   DEFAULT -- apply the straightforward 2.0 rewrites and fall back to
#              compat.v1 only where an automatic conversion is hard.
#   SAFETY  -- leave the 1.x code as it is and route it through compat.v1.
_DEFAULT_MODE, _SAFETY_MODE = "DEFAULT", "SAFETY"

# Value handed to tf_upgrade_v2.TFAPIChangeSpec as `import_rename` (whether
# imports get rewritten to compat.v2) when --no_import_rename is not given.
_IMPORT_RENAME_DEFAULT = False
40
41
def process_file(in_filename, out_filename, upgrader):
  """Upgrade one `.py` or `.ipynb` file, dispatching on its extension.

  Args:
    in_filename: path of the file to convert.
    out_filename: path the converted file is written to.
    upgrader: object exposing `process_file(in_filename, out_filename)`,
      e.g. an `ast_edits.ASTCodeUpgrader`; notebooks are converted by
      `ipynb.process_file`, which is handed this same upgrader.

  Returns:
    The `(files_processed, report_text, errors)` triple produced by the
    python- or notebook-specific converter.

  Raises:
    NotImplementedError: if the name ends in neither ".py" nor ".ipynb".
  """
  # Normalize once so the suffix checks work for bytes and text names alike.
  name = six.ensure_str(in_filename)
  if name.endswith(".py"):
    outcome = upgrader.process_file(in_filename, out_filename)
  elif name.endswith(".ipynb"):
    outcome = ipynb.process_file(in_filename, out_filename, upgrader)
  else:
    raise NotImplementedError(
        "Currently converter only supports python or ipynb")

  # Unpack explicitly so a converter returning the wrong arity still fails.
  files_processed, report_text, errors = outcome
  return files_processed, report_text, errors
56
57
def _parse_bool(value):
  """Convert a command-line string such as "true" or "0" into a bool.

  argparse with `type=bool` calls `bool()` on the raw string, so every
  non-empty string -- including "False" and "0" -- would become True.
  This converter accepts the usual spellings instead.

  Args:
    value: the raw flag value (a string), or an already-parsed bool.

  Returns:
    The parsed boolean. An empty string parses as False, which is what
    `bool("")` produced previously.

  Raises:
    argparse.ArgumentTypeError: if `value` is not a recognized boolean.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ("true", "t", "yes", "y", "1"):
    return True
  if normalized in ("false", "f", "no", "n", "0", ""):
    return False
  raise argparse.ArgumentTypeError(
      "expected a boolean (true/false), got %r" % (value,))


def _build_arg_parser():
  """Return the argparse parser describing this script's command line."""
  parser = argparse.ArgumentParser(
      formatter_class=argparse.RawDescriptionHelpFormatter,
      description="""Convert a TensorFlow Python file from 1.x to 2.0

Simple usage:
  tf_upgrade_v2.py --infile foo.py --outfile bar.py
  tf_upgrade_v2.py --infile foo.ipynb --outfile bar.ipynb
  tf_upgrade_v2.py --intree ~/code/old --outtree ~/code/new
""")
  parser.add_argument(
      "--infile",
      dest="input_file",
      help="If converting a single file, the name of the file "
      "to convert")
  parser.add_argument(
      "--outfile",
      dest="output_file",
      help="If converting a single file, the output filename.")
  parser.add_argument(
      "--intree",
      dest="input_tree",
      help="If converting a whole tree of files, the directory "
      "to read from (relative or absolute).")
  parser.add_argument(
      "--outtree",
      dest="output_tree",
      help="If converting a whole tree of files, the output "
      "directory (relative or absolute).")
  parser.add_argument(
      "--copyotherfiles",
      dest="copy_other_files",
      help=("If converting a whole tree of files, whether to "
            "copy the other files (true/false; default: %(default)s)."),
      # Not `type=bool`: that would turn the string "False" into True.
      type=_parse_bool,
      default=True)
  parser.add_argument(
      "--inplace",
      dest="in_place",
      help=("If converting a set of files, whether to "
            "allow the conversion to be performed on the "
            "input files."),
      action="store_true")
  parser.add_argument(
      "--no_import_rename",
      dest="no_import_rename",
      help=("Not to rename import to compat.v2 explicitly."),
      action="store_true")
  parser.add_argument(
      "--no_upgrade_compat_v1_import",
      dest="no_upgrade_compat_v1_import",
      help=("If specified, don't upgrade explicit imports of "
            "`tensorflow.compat.v1 as tf` to the v2 APIs. Otherwise, "
            "explicit imports of  the form `tensorflow.compat.v1 as tf` will "
            "be upgraded."),
      action="store_true")
  parser.add_argument(
      "--reportfile",
      dest="report_filename",
      help=("The name of the file where the report log is "
            "stored. "
            "(default: %(default)s)"),
      default="report.txt")
  parser.add_argument(
      "--mode",
      dest="mode",
      choices=[_DEFAULT_MODE, _SAFETY_MODE],
      help=("Upgrade script mode. Supported modes:\n"
            "%s: Perform only straightforward conversions to upgrade to "
            "2.0. In more difficult cases, switch to use compat.v1.\n"
            "%s: Keep 1.* code intact and import compat.v1 "
            "module." %
            (_DEFAULT_MODE, _SAFETY_MODE)),
      default=_DEFAULT_MODE)
  parser.add_argument(
      "--print_all",
      dest="print_all",
      help="Print full log to stdout instead of just printing errors",
      action="store_true")
  return parser


def _make_change_spec(args):
  """Pick the TFAPIChangeSpec matching the parsed --mode and import flags."""
  if args.mode == _SAFETY_MODE:
    # Safety mode takes no options; the import-related flags are ignored.
    return tf_upgrade_v2_safety.TFAPIChangeSpec()
  import_rename = False if args.no_import_rename else _IMPORT_RENAME_DEFAULT
  return tf_upgrade_v2.TFAPIChangeSpec(
      import_rename=import_rename,
      upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import)


def _write_report(report_filename, files_processed, errors, report_text,
                  print_all):
  """Write the summary plus detailed log to a file and echo it to stdout.

  Args:
    report_filename: path of the report file; it is overwritten.
    files_processed: number of converted files, shown in the summary.
    errors: dict mapping each file name to a sequence of error strings.
    report_text: the detailed log produced by the converter.
    print_all: if True, also print the detailed log, not just the summary.
  """
  separator = six.ensure_str("-" * 80) + "\n"
  num_errors = 0
  error_sections = []
  for filename in errors:
    file_errors = errors[filename]
    if file_errors:
      num_errors += len(file_errors)
      error_sections.append(separator)
      error_sections.append("File: %s\n" % filename)
      error_sections.append(separator)
      error_sections.append("\n".join(file_errors) + "\n")

  report = ("TensorFlow 2.0 Upgrade Script\n"
            "-----------------------------\n"
            "Converted %d files\n" % files_processed +
            "Detected %d issues that require attention" % num_errors + "\n" +
            separator) + "".join(error_sections)
  rule = six.ensure_str("=" * 80) + "\n"
  detailed_report_header = rule + "Detailed log follows:\n\n" + rule

  with open(report_filename, "w") as report_file:
    report_file.write(report)
    report_file.write(detailed_report_header)
    report_file.write(six.ensure_str(report_text))

  print(report)
  if print_all:
    print(detailed_report_header)
    print(report_text)
  print("\nMake sure to read the detailed log %r\n" % report_filename)


def main():
  """Command-line entry point: parse flags, run the upgrade, write a report.

  Converts either a single file (--infile) or a directory tree (--intree);
  with neither, prints the usage text and does nothing else.

  Raises:
    ValueError: if the output location is missing, or is given together
      with --inplace.
  """
  parser = _build_arg_parser()
  args = parser.parse_args()
  upgrade = ast_edits.ASTCodeUpgrader(_make_change_spec(args))

  report_text = None
  errors = {}
  files_processed = 0
  if args.input_file:
    if not args.in_place and not args.output_file:
      raise ValueError(
          "--outfile=<output file> argument is required when converting a "
          "single file.")
    if args.in_place and args.output_file:
      raise ValueError("--outfile argument is invalid when converting in place")
    output_file = args.input_file if args.in_place else args.output_file
    _, report_text, file_errors = process_file(
        args.input_file, output_file, upgrade)
    errors = {args.input_file: file_errors}
    files_processed = 1
  elif args.input_tree:
    if not args.in_place and not args.output_tree:
      raise ValueError(
          "--outtree=<output directory> argument is required when converting a "
          "file tree.")
    if args.in_place and args.output_tree:
      raise ValueError("--outtree argument is invalid when converting in place")
    output_tree = args.input_tree if args.in_place else args.output_tree
    files_processed, report_text, errors = upgrade.process_tree(
        args.input_tree, output_tree, args.copy_other_files)
  else:
    parser.print_help()

  if report_text:
    _write_report(args.report_filename, files_processed, errors, report_text,
                  args.print_all)


if __name__ == "__main__":
  main()
214