# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Upgrader for Python scripts from pre-1.0 TensorFlow to 1.0 TensorFlow."""

import argparse

from tensorflow.tools.compatibility import ast_edits


class TFAPIChangeSpec(ast_edits.APIChangeSpec):
  """List of maps that describe what changed in the API.

  Consumed by ast_edits.ASTCodeUpgrader, which walks a Python AST and
  applies these tables to rewrite pre-1.0 TensorFlow calls into their
  1.0 equivalents.
  """

  def __init__(self):
    # Maps from a function name to a dictionary that describes how to
    # map from an old argument keyword to the new argument keyword.
    self.function_keyword_renames = {
        "tf.batch_matmul": {
            "adj_x": "adjoint_a",
            "adj_y": "adjoint_b",
        },
        "tf.count_nonzero": {
            "reduction_indices": "axis"
        },
        "tf.reduce_all": {
            "reduction_indices": "axis"
        },
        "tf.reduce_any": {
            "reduction_indices": "axis"
        },
        "tf.reduce_max": {
            "reduction_indices": "axis"
        },
        "tf.reduce_mean": {
            "reduction_indices": "axis"
        },
        "tf.reduce_min": {
            "reduction_indices": "axis"
        },
        "tf.reduce_prod": {
            "reduction_indices": "axis"
        },
        "tf.reduce_sum": {
            "reduction_indices": "axis"
        },
        "tf.reduce_logsumexp": {
            "reduction_indices": "axis"
        },
        "tf.expand_dims": {
            "dim": "axis"
        },
        "tf.argmax": {
            "dimension": "axis"
        },
        "tf.argmin": {
            "dimension": "axis"
        },
        "tf.reduce_join": {
            "reduction_indices": "axis"
        },
        "tf.sparse_concat": {
            "concat_dim": "axis"
        },
        "tf.sparse_split": {
            "split_dim": "axis"
        },
        "tf.sparse_reduce_sum": {
            "reduction_axes": "axis"
        },
        "tf.reverse_sequence": {
            "seq_dim": "seq_axis",
            "batch_dim": "batch_axis"
        },
        "tf.sparse_reduce_sum_sparse": {
            "reduction_axes": "axis"
        },
        "tf.squeeze": {
            "squeeze_dims": "axis"
        },
        "tf.split": {
            "split_dim": "axis",
            "num_split": "num_or_size_splits"
        },
        "tf.concat": {
            "concat_dim": "axis"
        },
    }

    # Mapping from function to the new name of the function
    self.symbol_renames = {
        "tf.inv": "tf.reciprocal",
        "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar",
        "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram",
        "tf.listdiff": "tf.setdiff1d",
        "tf.list_diff": "tf.setdiff1d",
        "tf.mul": "tf.multiply",
        "tf.neg": "tf.negative",
        "tf.sub": "tf.subtract",
        "tf.train.SummaryWriter": "tf.summary.FileWriter",
        "tf.scalar_summary": "tf.summary.scalar",
        "tf.histogram_summary": "tf.summary.histogram",
        "tf.audio_summary": "tf.summary.audio",
        "tf.image_summary": "tf.summary.image",
        "tf.merge_summary": "tf.summary.merge",
        "tf.merge_all_summaries": "tf.summary.merge_all",
        "tf.image.per_image_whitening": "tf.image.per_image_standardization",
        "tf.all_variables": "tf.global_variables",
        "tf.VARIABLES": "tf.GLOBAL_VARIABLES",
        "tf.initialize_all_variables": "tf.global_variables_initializer",
        "tf.initialize_variables": "tf.variables_initializer",
        "tf.initialize_local_variables": "tf.local_variables_initializer",
        "tf.batch_matrix_diag": "tf.matrix_diag",
        "tf.batch_band_part": "tf.band_part",
        "tf.batch_set_diag": "tf.set_diag",
        "tf.batch_matrix_transpose": "tf.matrix_transpose",
        "tf.batch_matrix_determinant": "tf.matrix_determinant",
        "tf.batch_matrix_inverse": "tf.matrix_inverse",
        "tf.batch_cholesky": "tf.cholesky",
        "tf.batch_cholesky_solve": "tf.cholesky_solve",
        "tf.batch_matrix_solve": "tf.matrix_solve",
        "tf.batch_matrix_triangular_solve": "tf.matrix_triangular_solve",
        "tf.batch_matrix_solve_ls": "tf.matrix_solve_ls",
        "tf.batch_self_adjoint_eig": "tf.self_adjoint_eig",
        "tf.batch_self_adjoint_eigvals": "tf.self_adjoint_eigvals",
        "tf.batch_svd": "tf.svd",
        "tf.batch_fft": "tf.fft",
        "tf.batch_ifft": "tf.ifft",
        "tf.batch_fft2d": "tf.fft2d",
        "tf.batch_ifft2d": "tf.ifft2d",
        "tf.batch_fft3d": "tf.fft3d",
        "tf.batch_ifft3d": "tf.ifft3d",
        "tf.select": "tf.where",
        "tf.complex_abs": "tf.abs",
        "tf.batch_matmul": "tf.matmul",
        "tf.pack": "tf.stack",
        "tf.unpack": "tf.unstack",
        "tf.op_scope": "tf.name_scope",
    }

    # Symbols that used to be constants/objects but are now functions and
    # therefore need a trailing "()" appended at each use site.
    self.change_to_function = {
        "tf.ones_initializer",
        "tf.zeros_initializer",
    }

    # Functions that were reordered should be changed to the new keyword args
    # for safety, if positional arguments are used. If you have reversed the
    # positional arguments yourself, this could do the wrong thing.
    # NOTE: the keywords listed here are the OLD (pre-1.0) names; keyword
    # renames above are applied afterwards where applicable.
    self.function_reorders = {
        "tf.split": ["axis", "num_or_size_splits", "value", "name"],
        "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"],
        "tf.concat": ["concat_dim", "values", "name"],
        "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
        "tf.nn.softmax_cross_entropy_with_logits": [
            "logits", "labels", "dim", "name"
        ],
        "tf.nn.sparse_softmax_cross_entropy_with_logits": [
            "logits", "labels", "name"
        ],
        "tf.nn.sigmoid_cross_entropy_with_logits": ["logits", "labels", "name"],
        "tf.op_scope": ["values", "name", "default_name"],
    }

    # Warnings that should be printed if corresponding functions are used.
    self.function_warnings = {
        "tf.reverse": (
            ast_edits.ERROR,
            "tf.reverse has had its argument semantics changed "
            "significantly. The converter cannot detect this reliably, so "
            "you need to inspect this usage manually.\n"),
    }

    # No whole-module deprecations for the 0.x -> 1.0 transition.
    self.module_deprecations = {}


def _str_to_bool(value):
  """Parse a command-line boolean value.

  argparse's ``type=bool`` is a well-known trap: ``bool("False")`` is True,
  so any non-empty value would enable the option.  This helper accepts the
  usual true-ish spellings and treats everything else as False.

  Args:
    value: the raw command-line string.

  Returns:
    True if `value` is a true-ish spelling ("yes"/"true"/"t"/"1",
    case-insensitive), False otherwise.
  """
  return value.lower() in ("yes", "true", "t", "1")


if __name__ == "__main__":
  parser = argparse.ArgumentParser(
      formatter_class=argparse.RawDescriptionHelpFormatter,
      description="""Convert a TensorFlow Python file to 1.0

Simple usage:
  tf_convert.py --infile foo.py --outfile bar.py
  tf_convert.py --intree ~/code/old --outtree ~/code/new
""")
  parser.add_argument(
      "--infile",
      dest="input_file",
      help="If converting a single file, the name of the file "
      "to convert")
  parser.add_argument(
      "--outfile",
      dest="output_file",
      help="If converting a single file, the output filename.")
  parser.add_argument(
      "--intree",
      dest="input_tree",
      help="If converting a whole tree of files, the directory "
      "to read from (relative or absolute).")
  parser.add_argument(
      "--outtree",
      dest="output_tree",
      help="If converting a whole tree of files, the output "
      "directory (relative or absolute).")
  parser.add_argument(
      "--copyotherfiles",
      dest="copy_other_files",
      help=("If converting a whole tree of files, whether to "
            "copy the other files."),
      # BUG FIX: ``type=bool`` would turn ANY non-empty value (including
      # "False") into True.  Parse the string properly; ``nargs="?"`` with
      # ``const=True`` also lets a bare ``--copyotherfiles`` mean True.
      type=_str_to_bool,
      nargs="?",
      const=True,
      default=False)
  parser.add_argument(
      "--reportfile",
      dest="report_filename",
      help=("The name of the file where the report log is "
            "stored."
            "(default: %(default)s)"),
      default="report.txt")
  args = parser.parse_args()

  upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec())
  report_text = None
  report_filename = args.report_filename
  files_processed = 0
  # BUG FIX: initialize `errors` so the no-argument path (which only prints
  # help) does not hit a NameError at the summary printout below.
  errors = []
  if args.input_file:
    files_processed, report_text, errors = upgrade.process_file(
        args.input_file, args.output_file)
    files_processed = 1
  elif args.input_tree:
    files_processed, report_text, errors = upgrade.process_tree(
        args.input_tree, args.output_tree, args.copy_other_files)
  else:
    parser.print_help()
  if report_text:
    # BUG FIX: use a context manager instead of a dangling
    # ``open(...).write(...)`` so the report file is always closed/flushed.
    with open(report_filename, "w") as report_file:
      report_file.write(report_text)
  print("TensorFlow 1.0 Upgrade Script")
  print("-----------------------------")
  print("Converted %d files\n" % files_processed)
  print("Detected %d errors that require attention" % len(errors))
  print("-" * 80)
  print("\n".join(errors))
  print("\nMake sure to read the detailed log %r\n" % report_filename)