• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Upgrader for Python scripts from 1.* TensorFlow to 2.0 TensorFlow."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import ast
23import copy
24import functools
25import sys
26
27import pasta
28import six
29
30from tensorflow.tools.compatibility import all_renames_v2
31from tensorflow.tools.compatibility import ast_edits
32from tensorflow.tools.compatibility import module_deprecations_v2
33from tensorflow.tools.compatibility import reorders_v2
34
35# These pylint warnings are a mistake.
36# pylint: disable=g-explicit-bool-comparison,g-bool-id-comparison
37
38
class UnaliasedTFImport(ast_edits.AnalysisResult):
  """Analysis result emitted for a bare `import tensorflow` (no alias).

  Reported at ERROR level: per its message, the upgrade script can only run
  when TensorFlow is imported as `import tensorflow as tf`.
  """

  def __init__(self):
    message_parts = (
        "The tf_upgrade_v2 script detected an unaliased ",
        "`import tensorflow`. The script can only run when ",
        "importing with `import tensorflow as tf`.",
    )
    self.log_message = "".join(message_parts)
    self.log_level = ast_edits.ERROR
46
47
class VersionedTFImport(ast_edits.AnalysisResult):
  """Analysis result emitted when `tensorflow.<version>` is imported as `tf`.

  Reported at INFO level; the message states that symbols are not upgraded
  in that case.

  Args:
    version: Submodule path under `tensorflow` that was aliased to `tf`
      (this file passes "compat.v1" and "compat.v2"). It is coerced with
      `six.ensure_str` before being embedded in the message.
  """

  def __init__(self, version):
    module_suffix = six.ensure_str(version)
    self.log_message = (
        "Not upgrading symbols because `tensorflow.%s` was directly imported "
        "as `tf`." % module_suffix)
    self.log_level = ast_edits.INFO
55
56
# Single shared analysis-result instances for `import tensorflow.compat.v1 as tf`
# and `import tensorflow.compat.v2 as tf`; TFAPIImportAnalysisSpec maps those
# imports to these objects.
compat_v1_import, compat_v2_import = (
    VersionedTFImport("compat.v1"),
    VersionedTFImport("compat.v2"),
)
59
60
class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec):
  """Analysis spec that detects TensorFlow import styles, not symbols.

  `symbols_to_detect` is left empty. `imports_to_detect` maps what appear to
  be `(module_name, alias)` pairs, with `None` standing for an unaliased
  import, to the analysis result reported for that import.
  """

  def __init__(self):
    self.symbols_to_detect = {}
    detected = {("tensorflow", None): UnaliasedTFImport()}
    detected[("tensorflow.compat.v1", "tf")] = compat_v1_import
    detected[("tensorflow.compat.v2", "tf")] = compat_v2_import
    self.imports_to_detect = detected
70
71
class CompatV1ImportReplacer(ast.NodeVisitor):
  """Rewrites `import tensorflow.compat.v1 as tf` to `import tensorflow as tf`.

  The tree is modified in place. Only aliases whose module name is exactly
  `tensorflow.compat.v1` and whose alias is exactly `tf` are renamed; every
  other import alias is left untouched.
  """

  def visit_Import(self, node):  # pylint: disable=invalid-name
    """Renames matching aliases of an `import` statement, then recurses.

    Args:
      node: The `ast.Import` node being visited.
    """
    for alias in node.names:
      is_compat_v1_as_tf = (alias.name, alias.asname) == (
          "tensorflow.compat.v1", "tf")
      if is_compat_v1_as_tf:
        alias.name = "tensorflow"
    # Continue the standard traversal into this node's children.
    self.generic_visit(node)
90
91
92class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
93  """List of maps that describe what changed in the API."""
94
95  def __init__(self, import_rename=False, upgrade_compat_v1_import=False):
96    self.upgrade_compat_v1_import = upgrade_compat_v1_import
97
98    # Maps from a function name to a dictionary that describes how to
99    # map from an old argument keyword to the new argument keyword.
100    # If the new argument is None, it will be removed.
101    # Only keyword args are handled, so make sure to also put any function in
102    # function_reorders to ensure that all args are made into keywords first.
103    self.function_keyword_renames = {
104        # TODO(b/129398290)
105        # "tf.string_split": {
106        #     "delimiter": "sep",
107        # },
108        "tf.test.assert_equal_graph_def": {
109            "checkpoint_v2": None,
110            "hash_table_shared_name": None,
111        },
112        "tf.autograph.to_code": {
113            "arg_types": None,
114            "arg_values": None,
115            "indentation": None,
116        },
117        "tf.autograph.to_graph": {
118            "arg_types": None,
119            "arg_values": None,
120        },
121        "tf.nn.embedding_lookup": {
122            "validate_indices": None,
123        },
124        "tf.image.sample_distorted_bounding_box": {
125            "seed2": None,
126        },
127        "tf.gradients": {
128            "colocate_gradients_with_ops": None,
129        },
130        "tf.hessians": {
131            "colocate_gradients_with_ops": None,
132        },
133        "*.minimize": {
134            "colocate_gradients_with_ops": None,
135        },
136        "*.compute_gradients": {
137            "colocate_gradients_with_ops": None,
138        },
139        "tf.cond": {
140            "strict": None,
141            "fn1": "true_fn",
142            "fn2": "false_fn"
143        },
144        "tf.argmin": {
145            "dimension": "axis",
146        },
147        "tf.argmax": {
148            "dimension": "axis",
149        },
150        "tf.arg_min": {
151            "dimension": "axis",
152        },
153        "tf.arg_max": {
154            "dimension": "axis",
155        },
156        "tf.math.argmin": {
157            "dimension": "axis",
158        },
159        "tf.math.argmax": {
160            "dimension": "axis",
161        },
162        "tf.image.crop_and_resize": {
163            "box_ind": "box_indices",
164        },
165        "tf.extract_image_patches": {
166            "ksizes": "sizes",
167        },
168        "tf.image.extract_image_patches": {
169            "ksizes": "sizes",
170        },
171        "tf.image.resize": {
172            "align_corners": None,
173        },
174        "tf.image.resize_images": {
175            "align_corners": None,
176        },
177        "tf.expand_dims": {
178            "dim": "axis",
179        },
180        "tf.batch_to_space": {
181            "block_size": "block_shape",
182        },
183        "tf.space_to_batch": {
184            "block_size": "block_shape",
185        },
186        "tf.nn.space_to_batch": {
187            "block_size": "block_shape",
188        },
189        "tf.constant": {
190            "verify_shape": "verify_shape_is_now_always_true",
191        },
192        "tf.convert_to_tensor": {
193            "preferred_dtype": "dtype_hint"
194        },
195        "tf.nn.softmax_cross_entropy_with_logits": {
196            "dim": "axis",
197            "_sentinel": None,
198        },
199        "tf.nn.softmax_cross_entropy_with_logits_v2": {
200            "dim": "axis"
201        },
202        "tf.linalg.l2_normalize": {
203            "dim": "axis",
204        },
205        "tf.linalg.norm": {
206            "keep_dims": "keepdims",
207        },
208        "tf.norm": {
209            "keep_dims": "keepdims",
210        },
211        "tf.load_file_system_library": {
212            "library_filename": "library_location",
213        },
214        "tf.count_nonzero": {
215            "input_tensor": "input",
216            "keep_dims": "keepdims",
217            "reduction_indices": "axis",
218        },
219        "tf.math.count_nonzero": {
220            "input_tensor": "input",
221            "keep_dims": "keepdims",
222            "reduction_indices": "axis",
223        },
224        "tf.nn.erosion2d": {
225            "kernel": "filters",
226            "rates": "dilations",
227        },
228        "tf.math.l2_normalize": {
229            "dim": "axis",
230        },
231        "tf.math.log_softmax": {
232            "dim": "axis",
233        },
234        "tf.math.softmax": {
235            "dim": "axis"
236        },
237        "tf.nn.l2_normalize": {
238            "dim": "axis",
239        },
240        "tf.nn.log_softmax": {
241            "dim": "axis",
242        },
243        "tf.nn.moments": {
244            "keep_dims": "keepdims",
245        },
246        "tf.nn.pool": {
247            "dilation_rate": "dilations"
248        },
249        "tf.nn.separable_conv2d": {
250            "rate": "dilations"
251        },
252        "tf.nn.depthwise_conv2d": {
253            "rate": "dilations"
254        },
255        "tf.nn.softmax": {
256            "dim": "axis"
257        },
258        "tf.nn.sufficient_statistics": {
259            "keep_dims": "keepdims"
260        },
261        "tf.debugging.assert_all_finite": {
262            "t": "x",
263            "msg": "message",
264        },
265        "tf.sparse.add": {
266            "thresh": "threshold",
267        },
268        "tf.sparse_add": {
269            "thresh": "threshold",
270        },
271        "tf.sparse.concat": {
272            "concat_dim": "axis",
273            "expand_nonconcat_dim": "expand_nonconcat_dims",
274        },
275        "tf.sparse_concat": {
276            "concat_dim": "axis",
277            "expand_nonconcat_dim": "expand_nonconcat_dims",
278        },
279        "tf.sparse.split": {
280            "split_dim": "axis",
281        },
282        "tf.sparse_split": {
283            "split_dim": "axis",
284        },
285        "tf.sparse.reduce_max": {
286            "reduction_axes": "axis",
287            "keep_dims": "keepdims",
288        },
289        "tf.sparse_reduce_max": {
290            "reduction_axes": "axis",
291            "keep_dims": "keepdims",
292        },
293        "tf.sparse.reduce_sum": {
294            "reduction_axes": "axis",
295            "keep_dims": "keepdims",
296        },
297        "tf.sparse_reduce_sum": {
298            "reduction_axes": "axis",
299            "keep_dims": "keepdims",
300        },
301        "tf.nn.max_pool_with_argmax": {
302            "Targmax": "output_dtype",
303        },
304        "tf.nn.max_pool": {
305            "value": "input"
306        },
307        "tf.nn.avg_pool": {
308            "value": "input"
309        },
310        "tf.nn.avg_pool2d": {
311            "value": "input"
312        },
313        "tf.multinomial": {
314            "output_dtype": "dtype",
315        },
316        "tf.random.multinomial": {
317            "output_dtype": "dtype",
318        },
319        "tf.reverse_sequence": {
320            "seq_dim": "seq_axis",
321            "batch_dim": "batch_axis",
322        },
323        "tf.nn.batch_norm_with_global_normalization": {
324            "t": "input",
325            "m": "mean",
326            "v": "variance",
327        },
328        "tf.nn.dilation2d": {
329            "filter": "filters",
330            "rates": "dilations",
331        },
332        "tf.nn.conv3d": {
333            "filter": "filters"
334        },
335        "tf.zeros_like": {
336            "tensor": "input",
337        },
338        "tf.ones_like": {
339            "tensor": "input",
340        },
341        "tf.nn.conv2d_transpose": {
342            "value": "input",
343            "filter": "filters",
344        },
345        "tf.nn.conv3d_transpose": {
346            "value": "input",
347            "filter": "filters",
348        },
349        "tf.nn.convolution": {
350            "filter": "filters",
351            "dilation_rate": "dilations",
352        },
353        "tf.gfile.Exists": {
354            "filename": "path",
355        },
356        "tf.gfile.Remove": {
357            "filename": "path",
358        },
359        "tf.gfile.Stat": {
360            "filename": "path",
361        },
362        "tf.gfile.Glob": {
363            "filename": "pattern",
364        },
365        "tf.gfile.MkDir": {
366            "dirname": "path",
367        },
368        "tf.gfile.MakeDirs": {
369            "dirname": "path",
370        },
371        "tf.gfile.DeleteRecursively": {
372            "dirname": "path",
373        },
374        "tf.gfile.IsDirectory": {
375            "dirname": "path",
376        },
377        "tf.gfile.ListDirectory": {
378            "dirname": "path",
379        },
380        "tf.gfile.Copy": {
381            "oldpath": "src",
382            "newpath": "dst",
383        },
384        "tf.gfile.Rename": {
385            "oldname": "src",
386            "newname": "dst",
387        },
388        "tf.gfile.Walk": {
389            "in_order": "topdown",
390        },
391        "tf.random.stateless_multinomial": {
392            "output_dtype": "dtype",
393        },
394        "tf.string_to_number": {
395            "string_tensor": "input",
396        },
397        "tf.strings.to_number": {
398            "string_tensor": "input",
399        },
400        "tf.string_to_hash_bucket": {
401            "string_tensor": "input",
402        },
403        "tf.strings.to_hash_bucket": {
404            "string_tensor": "input",
405        },
406        "tf.reduce_all": {
407            "reduction_indices": "axis",
408            "keep_dims": "keepdims",
409        },
410        "tf.math.reduce_all": {
411            "reduction_indices": "axis",
412            "keep_dims": "keepdims",
413        },
414        "tf.reduce_any": {
415            "reduction_indices": "axis",
416            "keep_dims": "keepdims",
417        },
418        "tf.math.reduce_any": {
419            "reduction_indices": "axis",
420            "keep_dims": "keepdims",
421        },
422        "tf.reduce_min": {
423            "reduction_indices": "axis",
424            "keep_dims": "keepdims",
425        },
426        "tf.math.reduce_min": {
427            "reduction_indices": "axis",
428            "keep_dims": "keepdims",
429        },
430        "tf.reduce_max": {
431            "reduction_indices": "axis",
432            "keep_dims": "keepdims",
433        },
434        "tf.math.reduce_max": {
435            "reduction_indices": "axis",
436            "keep_dims": "keepdims",
437        },
438        "tf.reduce_sum": {
439            "reduction_indices": "axis",
440            "keep_dims": "keepdims",
441        },
442        "tf.math.reduce_sum": {
443            "reduction_indices": "axis",
444            "keep_dims": "keepdims",
445        },
446        "tf.reduce_mean": {
447            "reduction_indices": "axis",
448            "keep_dims": "keepdims",
449        },
450        "tf.math.reduce_mean": {
451            "reduction_indices": "axis",
452            "keep_dims": "keepdims",
453        },
454        "tf.reduce_prod": {
455            "reduction_indices": "axis",
456            "keep_dims": "keepdims",
457        },
458        "tf.math.reduce_prod": {
459            "reduction_indices": "axis",
460            "keep_dims": "keepdims",
461        },
462        "tf.reduce_logsumexp": {
463            "reduction_indices": "axis",
464            "keep_dims": "keepdims",
465        },
466        "tf.math.reduce_logsumexp": {
467            "reduction_indices": "axis",
468            "keep_dims": "keepdims",
469        },
470        "tf.reduce_join": {
471            "keep_dims": "keepdims",
472            "reduction_indices": "axis"
473        },
474        "tf.strings.reduce_join": {
475            "keep_dims": "keepdims",
476            "reduction_indices": "axis"
477        },
478        "tf.squeeze": {
479            "squeeze_dims": "axis",
480        },
481        "tf.nn.weighted_moments": {
482            "keep_dims": "keepdims"
483        },
484        "tf.nn.conv1d": {
485            "value": "input",
486            "use_cudnn_on_gpu": None,
487        },
488        "tf.nn.conv2d": {
489            "filter": "filters",
490            "use_cudnn_on_gpu": None,
491        },
492        "tf.nn.conv2d_backprop_input": {
493            "use_cudnn_on_gpu": None,
494            "input_sizes": "output_shape",
495            "out_backprop": "input",
496            "filter": "filters",
497        },
498        "tf.contrib.summary.audio": {
499            "tensor": "data",
500            "family": None,
501        },
502        "tf.contrib.summary.create_file_writer": {
503            "name": None,
504        },
505        "tf.contrib.summary.generic": {
506            "name": "tag",
507            "tensor": "data",
508            "family": None,
509        },
510        "tf.contrib.summary.histogram": {
511            "tensor": "data",
512            "family": None,
513        },
514        "tf.contrib.summary.image": {
515            "tensor": "data",
516            "bad_color": None,
517            "max_images": "max_outputs",
518            "family": None,
519        },
520        "tf.contrib.summary.scalar": {
521            "tensor": "data",
522            "family": None,
523        },
524        "tf.nn.weighted_cross_entropy_with_logits": {
525            "targets": "labels",
526        },
527        "tf.decode_raw": {
528            "bytes": "input_bytes",
529        },
530        "tf.io.decode_raw": {
531            "bytes": "input_bytes",
532        },
533        "tf.contrib.framework.load_variable": {
534            "checkpoint_dir": "ckpt_dir_or_file",
535        }
536    }
537    all_renames_v2.add_contrib_direct_import_support(
538        self.function_keyword_renames)
539
540    # Mapping from function to the new name of the function
541    # Add additional renames not in renames_v2.py to all_renames_v2.py.
542    self.symbol_renames = all_renames_v2.symbol_renames
543    self.import_rename = import_rename
544    if self.import_rename:
545      self.import_renames = {
546          "tensorflow":
547              ast_edits.ImportRename(
548                  "tensorflow.compat.v2",
549                  excluded_prefixes=[
550                      "tensorflow.contrib", "tensorflow.flags",
551                      "tensorflow.compat.v1", "tensorflow.compat.v2",
552                      "tensorflow.google"
553                  ],
554              )
555      }
556    else:
557      self.import_renames = {}
558
559    # Variables that should be changed to functions.
560    self.change_to_function = {}
561
562    # pylint: disable=line-too-long
563    # This list should just contain names of functions that had
564    # their arguments reordered. After adding a function name to the list
565    # run the following to update reorders_v2.py:
566    # bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
567    # bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
568    # pylint: enable=line-too-long
569    self.reordered_function_names = {
570        "tf.io.serialize_sparse",
571        "tf.io.serialize_many_sparse",
572        "tf.argmax",
573        "tf.argmin",
574        "tf.batch_to_space",
575        "tf.cond",
576        "tf.nn.space_to_batch",
577        "tf.boolean_mask",
578        "tf.convert_to_tensor",
579        "tf.nn.conv1d",
580        "tf.nn.conv2d",
581        "tf.nn.conv2d_backprop_input",
582        "tf.nn.ctc_beam_search_decoder",
583        "tf.nn.moments",
584        "tf.nn.convolution",
585        "tf.nn.crelu",
586        "tf.nn.weighted_moments",
587        "tf.nn.pool",
588        "tf.nn.separable_conv2d",
589        "tf.nn.depthwise_conv2d",
590        "tf.multinomial",
591        "tf.random.multinomial",
592        "tf.pad",
593        "tf.quantize_v2",
594        "tf.feature_column.categorical_column_with_vocabulary_file",
595        "tf.shape",
596        "tf.size",
597        # TODO(b/129398290)
598        # "tf.string_split",
599        "tf.random.poisson",
600        "tf.sparse.add",
601        "tf.sparse_add",
602        "tf.sparse.concat",
603        "tf.sparse_concat",
604        "tf.sparse.segment_mean",
605        "tf.sparse.segment_sqrt_n",
606        "tf.sparse.segment_sum",
607        "tf.sparse_matmul",
608        "tf.sparse.reduce_max",
609        "tf.sparse_reduce_max",
610        "tf.io.decode_csv",
611        "tf.strings.length",
612        "tf.strings.reduce_join",
613        "tf.strings.substr",
614        "tf.substr",
615        "tf.transpose",
616        "tf.tuple",
617        "tf.parse_example",
618        "tf.parse_single_example",
619        "tf.io.parse_example",
620        "tf.io.parse_single_example",
621        "tf.while_loop",
622        "tf.reduce_all",
623        "tf.math.reduce_all",
624        "tf.reduce_any",
625        "tf.math.reduce_any",
626        "tf.reduce_min",
627        "tf.math.reduce_min",
628        "tf.reduce_max",
629        "tf.math.reduce_max",
630        "tf.reduce_sum",
631        "tf.math.reduce_sum",
632        "tf.reduce_mean",
633        "tf.math.reduce_mean",
634        "tf.reduce_prod",
635        "tf.math.reduce_prod",
636        "tf.reduce_logsumexp",
637        "tf.math.reduce_logsumexp",
638        "tf.reduce_join",
639        "tf.confusion_matrix",
640        "tf.math.confusion_matrix",
641        "tf.math.in_top_k",
642        "tf.nn.depth_to_space",
643        "tf.nn.embedding_lookup",
644        "tf.nn.embedding_lookup_sparse",
645        "tf.nn.in_top_k",
646        "tf.nn.space_to_depth",
647        "tf.test.assert_equal_graph_def",
648        "tf.linalg.norm",
649        "tf.norm",
650        "tf.reverse_sequence",
651        "tf.sparse_split",
652        # tf.nn.softmax_cross_entropy_with_logits *must* be called with
653        # keyword arguments. Add keyword arguments in rare case when they
654        # are not specified.
655        "tf.nn.softmax_cross_entropy_with_logits",
656        "tf.nn.fractional_avg_pool",
657        "tf.nn.fractional_max_pool",
658        "tf.image.sample_distorted_bounding_box",
659        "tf.gradients",
660        "tf.hessians",
661        "tf.nn.max_pool",
662        "tf.nn.avg_pool",
663        "tf.estimator.LinearClassifier",
664        "tf.estimator.LinearRegressor",
665        "tf.estimator.DNNLinearCombinedClassifier",
666        "tf.estimator.DNNLinearCombinedRegressor",
667        "tf.estimator.DNNRegressor",
668        "tf.estimator.DNNClassifier",
669        "tf.estimator.BaselineClassifier",
670        "tf.estimator.BaselineRegressor",
671        "tf.initializers.uniform_unit_scaling",
672        "tf.uniform_unit_scaling_initializer",
673        "tf.train.sdca_fprint",
674        "tf.train.sdca_optimizer",
675        "tf.train.sdca_shrink_l1",
676        "tf.data.experimental.TensorStructure",
677        "tf.data.experimental.SparseTensorStructure",
678        "tf.data.experimental.RaggedTensorStructure",
679        "tf.data.experimental.TensorArrayStructure",
680    }
681
682    # Manual mapping of function names to be reordered to their list of argument
683    # names, in order. Only use this if argument names cannot be autodetected,
684    # e.g. if the functions are in contrib.
685    self.manual_function_reorders = {
686        "tf.contrib.summary.audio": [
687            "name", "tensor", "sample_rate", "max_outputs", "family", "step"],
688        "tf.contrib.summary.create_file_writer": [
689            "logdir", "max_queue", "flush_millis", "filename_suffix", "name"],
690        "tf.contrib.summary.generic": [
691            "name", "tensor", "metadata", "family", "step"],
692        "tf.contrib.summary.histogram": [
693            "name", "tensor", "family", "step"],
694        "tf.contrib.summary.image": [
695            "name", "tensor", "bad_color", "max_images", "family", "step"],
696        "tf.contrib.summary.scalar": [
697            "name", "tensor", "family", "step"],
698    }
699    # Functions that were reordered should be changed to the new keyword args
700    # for safety, if positional arguments are used. If you have reversed the
701    # positional arguments yourself, this could do the wrong thing.
702    self.function_reorders = dict(reorders_v2.reorders)
703    self.function_reorders.update(self.manual_function_reorders)
704
705    decay_function_comment = (
706        ast_edits.INFO,
707        "To use learning rate decay schedules with TensorFlow 2.0, switch to "
708        "the schedules in `tf.keras.optimizers.schedules`.\n"
709    )
710
711    assert_return_type_comment = (
712        ast_edits.INFO,
713        "<function name> has been changed to return None, the "
714        "data argument has been removed, and arguments have been reordered."
715        "\nThe calls have been converted to compat.v1 for safety (even though "
716        " they may already have been correct)."
717    )
718
719    assert_rank_comment = (
720        ast_edits.INFO,
721        "<function name> has been changed to return None, and"
722        " the data and summarize arguments have been removed."
723        "\nThe calls have been converted to compat.v1 for safety (even though "
724        " they may already have been correct)."
725    )
726
727    contrib_layers_layer_norm_comment = (
728        ast_edits.WARNING,
729        "(Manual edit required) `tf.contrib.layers.layer_norm` has been "
730        "deprecated, and its implementation has been integrated with "
731        "`tf.keras.layers.LayerNormalization` in TensorFlow 2.0. "
732        "Note that, the default value of `epsilon` is changed to `1e-3` in the "
733        "new API from `1e-12`, and this may introduce numerical differences. "
734        "Please check the new API and use that instead."
735    )
736
737    contrib_estimator_head_comment = (
738        ast_edits.WARNING,
739        "(Manual edit required) `tf.contrib.estimator.*_head` has been "
740        "deprecated, and its implementation has been integrated with "
741        "`tf.estimator.*Head` in TensorFlow 2.0. "
742        "Please check the new API and use that instead."
743    )
744
745    initializers_no_dtype_comment = (
746        ast_edits.INFO, "Initializers no longer have the "
747        "dtype argument in the constructor or partition_info argument in the "
748        "__call__ method.\nThe calls have been converted to compat.v1 for "
749        "safety (even though they may already have been correct).")
750
751    metrics_comment = (
752        ast_edits.INFO,
753        "tf.metrics have been replaced with object oriented versions in"
754        " TF 2.0 and after. The metric function calls have been converted to "
755        "compat.v1 for backward compatibility. Please update these calls to "
756        "the TF 2.0 versions.")
757
758    losses_comment = (
759        ast_edits.INFO,
760        "tf.losses have been replaced with object oriented versions in"
761        " TF 2.0 and after. The loss function calls have been converted to "
762        "compat.v1 for backward compatibility. Please update these calls to "
763        "the TF 2.0 versions.")
764
765    # This could be done with a _rename_if_arg_not_found_transformer
766    deprecate_partition_strategy_comment = (
767        ast_edits.WARNING,
768        "`partition_strategy` has been removed from <function name>. "
769        " The 'div' strategy will be used by default.")
770
771    # make change instead
772    uniform_unit_scaling_initializer_comment = (
773        ast_edits.ERROR,
774        "uniform_unit_scaling_initializer has been removed. Please use"
775        " tf.initializers.variance_scaling instead with distribution=uniform "
776        "to get equivalent behaviour.")
777
778    # Make change instead (issue warning about strip_...)
779    export_saved_model_renamed = (
780        ast_edits.ERROR,
781        "(Manual edit required) Please rename the method export_savedmodel() "
782        "to export_saved_model(). Two things to note:\n\t(1) The argument "
783        "strip_default_attributes has been removed. The function will always "
784        "strip the default attributes from ops. If this breaks your code, "
785        "please switch to tf.compat.v1.estimator.Estimator.\n\t(2) This change "
786        "only effects core estimator. If you are using "
787        "tf.contrib.learn.Estimator, please switch to using core estimator.")
788
789    summary_api_comment = (
790        ast_edits.INFO,
791        "The TF 1.x summary API cannot be automatically migrated to TF 2.0, so "
792        "symbols have been converted to tf.compat.v1.summary.* and must be "
793        "migrated manually. Typical usage will only require changes to the "
794        "summary writing logic, not to individual calls like scalar(). "
795        "For examples of the new summary API, see the Effective TF 2.0 "
796        "migration document or check the TF 2.0 TensorBoard tutorials.")
797
798    contrib_summary_comment = (
799        ast_edits.WARNING,
800        "tf.contrib.summary.* functions have been migrated best-effort to "
801        "tf.compat.v2.summary.* equivalents where possible, but the resulting "
802        "code is not guaranteed to work, so please check carefully. For more "
803        "information about the new summary API, see the Effective TF 2.0 "
804        "migration document or check the updated TensorBoard tutorials.")
805
806    contrib_summary_family_arg_comment = (
807        ast_edits.WARNING,
808        "<function name> replacement does not accept a 'family' argument; "
809        "instead regular name scoping should be used. This call site specifies "
810        "a family argument that has been removed on conversion, so the emitted "
811        "tag names may be incorrect without manual editing.")
812
813    contrib_create_file_writer_comment = (
814        ast_edits.WARNING,
815        "tf.contrib.summary.create_file_writer() has been ported to the new "
816        "tf.compat.v2.summary.create_file_writer(), which no longer re-uses "
817        "existing event files for the same logdir; instead it always opens a "
818        "new writer/file. The python writer objects must be re-used explicitly "
819        "if the reusing behavior is desired.")
820
821    contrib_summary_record_every_n_comment = (
822        ast_edits.ERROR,
823        "(Manual edit required) "
824        "tf.contrib.summary.record_summaries_every_n_global_steps(n, step) "
825        "should be replaced by a call to tf.compat.v2.summary.record_if() with "
826        "the argument `lambda: tf.math.equal(0, global_step % n)` (or in graph "
827        "mode, the lambda body can be used directly). If no global step was "
828        "passed, instead use tf.compat.v1.train.get_or_create_global_step().")
829
830    contrib_summary_graph_comment = (
831        ast_edits.ERROR,
832        "(Manual edit required) tf.contrib.summary.graph() has no direct "
833        "equivalent in TF 2.0 because manual graph construction has been "
834        "superseded by use of tf.function. To log tf.function execution graphs "
835        "to the summary writer, use the new tf.compat.v2.summary.trace_* "
836        "functions instead.")
837
838    contrib_summary_import_event_comment = (
839        ast_edits.ERROR,
840        "(Manual edit required) tf.contrib.summary.import_event() has no "
841        "direct equivalent in TF 2.0. For a similar experimental feature, try "
842        "tf.compat.v2.summary.experimental.write_raw_pb() which also accepts "
843        "serialized summary protocol buffer input, but for tf.Summary "
844        "protobufs rather than tf.Events.")
845
846    keras_default_save_format_comment = (
847        ast_edits.WARNING,
848        "(This warning is only applicable if the code saves a tf.Keras model) "
849        "Keras model.save now saves to the Tensorflow SavedModel format by "
850        "default, instead of HDF5. To continue saving to HDF5, add the "
851        "argument save_format='h5' to the save() function.")
852
853    distribute_strategy_api_changes = (
854        "If you're using the strategy with a "
855        "custom training loop, note the following changes in methods: "
856        "make_dataset_iterator->experimental_distribute_dataset, "
857        "experimental_make_numpy_iterator->experimental_make_numpy_dataset, "
858        "extended.call_for_each_replica->run, "
859        "reduce requires an axis argument, "
860        "unwrap->experimental_local_results "
861        "experimental_initialize and experimental_finalize no longer needed ")
862
863    contrib_mirrored_strategy_warning = (
864        ast_edits.ERROR,
865        "(Manual edit required) tf.contrib.distribute.MirroredStrategy has "
866        "been migrated to tf.distribute.MirroredStrategy. Things to note: "
867        "Constructor arguments have changed. If you are using "
868        "MirroredStrategy with Keras training framework, the input provided to "
869        "`model.fit` will be assumed to have global batch size and split "
870        "across the replicas. " + distribute_strategy_api_changes)
871
872    core_mirrored_strategy_warning = (
873        ast_edits.WARNING,
874        "(Manual edit may be required) tf.distribute.MirroredStrategy API has "
875        "changed. " + distribute_strategy_api_changes)
876
877    contrib_one_device_strategy_warning = (
878        ast_edits.ERROR,
879        "(Manual edit required) tf.contrib.distribute.OneDeviceStrategy has "
880        "been migrated to tf.distribute.OneDeviceStrategy. " +
881        distribute_strategy_api_changes)
882
883    contrib_tpu_strategy_warning = (
884        ast_edits.ERROR,
885        "(Manual edit required) tf.contrib.distribute.TPUStrategy has "
886        "been migrated to tf.distribute.TPUStrategy. Note the "
887        "slight changes in constructor. " + distribute_strategy_api_changes)
888
889    contrib_collective_strategy_warning = (
890        ast_edits.ERROR,
891        "(Manual edit required) "
892        "tf.contrib.distribute.CollectiveAllReduceStrategy has "
893        "been migrated to "
894        "tf.distribute.experimental.MultiWorkerMirroredStrategy. Note the "
895        "changes in constructor. " + distribute_strategy_api_changes)
896
897    contrib_ps_strategy_warning = (
898        ast_edits.ERROR, "(Manual edit required) "
899        "tf.contrib.distribute.ParameterServerStrategy has "
900        "been migrated to "
901        "tf.compat.v1.distribute.experimental.ParameterServerStrategy (multi "
902        "machine) and tf.distribute.experimental.CentralStorageStrategy (one "
903        "machine). Note the changes in constructors. " +
904        distribute_strategy_api_changes)
905
906    keras_experimental_export_comment = (
907        ast_edits.WARNING,
908        "tf.keras.experimental.export_saved_model and "
909        "tf.keras.experimental.load_from_saved_model have been deprecated."
910        "Please use model.save(path, save_format='tf') "
911        "(or alternatively tf.keras.models.save_model), and "
912        "tf.keras.models.load_model(path) instead.")
913
914    # Function warnings. <function name> placeholder inside warnings will be
915    # replaced by function name.
916    # You can use a `*.` prefix for entries that should match without checking
917    # the FQN, so they apply to any receiver (e.g., methods).
918    self.function_warnings = {
919        "*.export_savedmodel":
920            export_saved_model_renamed,
921        "*.save":
922            keras_default_save_format_comment,
923        "tf.assert_equal":
924            assert_return_type_comment,
925        "tf.assert_none_equal":
926            assert_return_type_comment,
927        "tf.assert_negative":
928            assert_return_type_comment,
929        "tf.assert_positive":
930            assert_return_type_comment,
931        "tf.assert_non_negative":
932            assert_return_type_comment,
933        "tf.assert_non_positive":
934            assert_return_type_comment,
935        "tf.assert_near":
936            assert_return_type_comment,
937        "tf.assert_less":
938            assert_return_type_comment,
939        "tf.assert_less_equal":
940            assert_return_type_comment,
941        "tf.assert_greater":
942            assert_return_type_comment,
943        "tf.assert_greater_equal":
944            assert_return_type_comment,
945        "tf.assert_integer":
946            assert_return_type_comment,
947        "tf.assert_type":
948            assert_return_type_comment,
949        "tf.assert_scalar":
950            assert_return_type_comment,
951        "tf.assert_rank":
952            assert_rank_comment,
953        "tf.assert_rank_at_least":
954            assert_rank_comment,
955        "tf.assert_rank_in":
956            assert_rank_comment,
957        "tf.contrib.layers.layer_norm":
958            contrib_layers_layer_norm_comment,
959        "tf.contrib.estimator.binary_classification_head":
960            contrib_estimator_head_comment,
961        "tf.contrib.estimator.logistic_regression_head":
962            contrib_estimator_head_comment,
963        "tf.contrib.estimator.multi_class_head":
964            contrib_estimator_head_comment,
965        "tf.contrib.estimator.multi_head":
966            contrib_estimator_head_comment,
967        "tf.contrib.estimator.multi_label_head":
968            contrib_estimator_head_comment,
969        "tf.contrib.estimator.poisson_regression_head":
970            contrib_estimator_head_comment,
971        "tf.contrib.estimator.regression_head":
972            contrib_estimator_head_comment,
973        "tf.contrib.saved_model.load_keras_model":
974            keras_experimental_export_comment,
975        "tf.contrib.saved_model.save_keras_model":
976            keras_experimental_export_comment,
977        "tf.contrib.summary.all_summary_ops":
978            contrib_summary_comment,
979        "tf.contrib.summary.audio":
980            contrib_summary_comment,
981        "tf.contrib.summary.create_file_writer":
982            contrib_create_file_writer_comment,
983        "tf.contrib.summary.generic":
984            contrib_summary_comment,
985        "tf.contrib.summary.graph":
986            contrib_summary_graph_comment,
987        "tf.contrib.summary.histogram":
988            contrib_summary_comment,
989        "tf.contrib.summary.import_event":
990            contrib_summary_import_event_comment,
991        "tf.contrib.summary.image":
992            contrib_summary_comment,
993        "tf.contrib.summary.record_summaries_every_n_global_steps":
994            contrib_summary_record_every_n_comment,
995        "tf.contrib.summary.scalar":
996            contrib_summary_comment,
997        "tf.debugging.assert_equal":
998            assert_return_type_comment,
999        "tf.debugging.assert_greater":
1000            assert_return_type_comment,
1001        "tf.debugging.assert_greater_equal":
1002            assert_return_type_comment,
1003        "tf.debugging.assert_integer":
1004            assert_return_type_comment,
1005        "tf.debugging.assert_less":
1006            assert_return_type_comment,
1007        "tf.debugging.assert_less_equal":
1008            assert_return_type_comment,
1009        "tf.debugging.assert_near":
1010            assert_return_type_comment,
1011        "tf.debugging.assert_negative":
1012            assert_return_type_comment,
1013        "tf.debugging.assert_non_negative":
1014            assert_return_type_comment,
1015        "tf.debugging.assert_non_positive":
1016            assert_return_type_comment,
1017        "tf.debugging.assert_none_equal":
1018            assert_return_type_comment,
1019        "tf.debugging.assert_positive":
1020            assert_return_type_comment,
1021        "tf.debugging.assert_type":
1022            assert_return_type_comment,
1023        "tf.debugging.assert_scalar":
1024            assert_return_type_comment,
1025        "tf.debugging.assert_rank":
1026            assert_rank_comment,
1027        "tf.debugging.assert_rank_at_least":
1028            assert_rank_comment,
1029        "tf.debugging.assert_rank_in":
1030            assert_rank_comment,
1031        "tf.train.exponential_decay":
1032            decay_function_comment,
1033        "tf.train.piecewise_constant_decay":
1034            decay_function_comment,
1035        "tf.train.polynomial_decay":
1036            decay_function_comment,
1037        "tf.train.natural_exp_decay":
1038            decay_function_comment,
1039        "tf.train.inverse_time_decay":
1040            decay_function_comment,
1041        "tf.train.cosine_decay":
1042            decay_function_comment,
1043        "tf.train.cosine_decay_restarts":
1044            decay_function_comment,
1045        "tf.train.linear_cosine_decay":
1046            decay_function_comment,
1047        "tf.train.noisy_linear_cosine_decay":
1048            decay_function_comment,
1049        "tf.nn.embedding_lookup":
1050            deprecate_partition_strategy_comment,
1051        "tf.nn.embedding_lookup_sparse":
1052            deprecate_partition_strategy_comment,
1053        "tf.nn.nce_loss":
1054            deprecate_partition_strategy_comment,
1055        "tf.nn.safe_embedding_lookup_sparse":
1056            deprecate_partition_strategy_comment,
1057        "tf.nn.sampled_softmax_loss":
1058            deprecate_partition_strategy_comment,
1059        "tf.keras.estimator.model_to_estimator":
1060            (ast_edits.WARNING,
1061             "Estimators from <function name> will save object-based "
1062             "checkpoints (format used by `keras_model.save_weights` and "
1063             "`keras_model.load_weights`) by default in 2.0. To continue "
1064             "saving name-based checkpoints, set `checkpoint_format='saver'`."),
1065        "tf.keras.experimental.export_saved_model":
1066            keras_experimental_export_comment,
1067        "tf.keras.experimental.load_from_saved_model":
1068            keras_experimental_export_comment,
1069        "tf.keras.initializers.Zeros":
1070            initializers_no_dtype_comment,
1071        "tf.keras.initializers.zeros":
1072            initializers_no_dtype_comment,
1073        "tf.keras.initializers.Ones":
1074            initializers_no_dtype_comment,
1075        "tf.keras.initializers.ones":
1076            initializers_no_dtype_comment,
1077        "tf.keras.initializers.Constant":
1078            initializers_no_dtype_comment,
1079        "tf.keras.initializers.constant":
1080            initializers_no_dtype_comment,
1081        "tf.keras.initializers.VarianceScaling":
1082            initializers_no_dtype_comment,
1083        "tf.keras.initializers.Orthogonal":
1084            initializers_no_dtype_comment,
1085        "tf.keras.initializers.orthogonal":
1086            initializers_no_dtype_comment,
1087        "tf.keras.initializers.Identity":
1088            initializers_no_dtype_comment,
1089        "tf.keras.initializers.identity":
1090            initializers_no_dtype_comment,
1091        "tf.keras.initializers.glorot_uniform":
1092            initializers_no_dtype_comment,
1093        "tf.keras.initializers.glorot_normal":
1094            initializers_no_dtype_comment,
1095        "tf.initializers.zeros":
1096            initializers_no_dtype_comment,
1097        "tf.zeros_initializer":
1098            initializers_no_dtype_comment,
1099        "tf.initializers.ones":
1100            initializers_no_dtype_comment,
1101        "tf.ones_initializer":
1102            initializers_no_dtype_comment,
1103        "tf.initializers.constant":
1104            initializers_no_dtype_comment,
1105        "tf.constant_initializer":
1106            initializers_no_dtype_comment,
1107        "tf.initializers.random_uniform":
1108            initializers_no_dtype_comment,
1109        "tf.random_uniform_initializer":
1110            initializers_no_dtype_comment,
1111        "tf.initializers.random_normal":
1112            initializers_no_dtype_comment,
1113        "tf.random_normal_initializer":
1114            initializers_no_dtype_comment,
1115        "tf.initializers.truncated_normal":
1116            initializers_no_dtype_comment,
1117        "tf.truncated_normal_initializer":
1118            initializers_no_dtype_comment,
1119        "tf.initializers.variance_scaling":
1120            initializers_no_dtype_comment,
1121        "tf.variance_scaling_initializer":
1122            initializers_no_dtype_comment,
1123        "tf.initializers.orthogonal":
1124            initializers_no_dtype_comment,
1125        "tf.orthogonal_initializer":
1126            initializers_no_dtype_comment,
1127        "tf.initializers.identity":
1128            initializers_no_dtype_comment,
1129        "tf.glorot_uniform_initializer":
1130            initializers_no_dtype_comment,
1131        "tf.initializers.glorot_uniform":
1132            initializers_no_dtype_comment,
1133        "tf.glorot_normal_initializer":
1134            initializers_no_dtype_comment,
1135        "tf.initializers.glorot_normal":
1136            initializers_no_dtype_comment,
1137        "tf.losses.absolute_difference":
1138            losses_comment,
1139        "tf.losses.add_loss":
1140            losses_comment,
1141        "tf.losses.compute_weighted_loss":
1142            losses_comment,
1143        "tf.losses.cosine_distance":
1144            losses_comment,
1145        "tf.losses.get_losses":
1146            losses_comment,
1147        "tf.losses.get_regularization_loss":
1148            losses_comment,
1149        "tf.losses.get_regularization_losses":
1150            losses_comment,
1151        "tf.losses.get_total_loss":
1152            losses_comment,
1153        "tf.losses.hinge_loss":
1154            losses_comment,
1155        "tf.losses.huber_loss":
1156            losses_comment,
1157        "tf.losses.log_loss":
1158            losses_comment,
1159        "tf.losses.mean_pairwise_squared_error":
1160            losses_comment,
1161        "tf.losses.mean_squared_error":
1162            losses_comment,
1163        "tf.losses.sigmoid_cross_entropy":
1164            losses_comment,
1165        "tf.losses.softmax_cross_entropy":
1166            losses_comment,
1167        "tf.losses.sparse_softmax_cross_entropy":
1168            losses_comment,
1169        "tf.metrics.accuracy":
1170            metrics_comment,
1171        "tf.metrics.auc":
1172            metrics_comment,
1173        "tf.metrics.average_precision_at_k":
1174            metrics_comment,
1175        "tf.metrics.false_negatives":
1176            metrics_comment,
1177        "tf.metrics.false_negatives_at_thresholds":
1178            metrics_comment,
1179        "tf.metrics.false_positives":
1180            metrics_comment,
1181        "tf.metrics.false_positives_at_thresholds":
1182            metrics_comment,
1183        "tf.metrics.mean":
1184            metrics_comment,
1185        "tf.metrics.mean_absolute_error":
1186            metrics_comment,
1187        "tf.metrics.mean_cosine_distance":
1188            metrics_comment,
1189        "tf.metrics.mean_iou":
1190            metrics_comment,
1191        "tf.metrics.mean_per_class_accuracy":
1192            metrics_comment,
1193        "tf.metrics.mean_relative_error":
1194            metrics_comment,
1195        "tf.metrics.mean_squared_error":
1196            metrics_comment,
1197        "tf.metrics.mean_tensor":
1198            metrics_comment,
1199        "tf.metrics.percentage_below":
1200            metrics_comment,
1201        "tf.metrics.precision":
1202            metrics_comment,
1203        "tf.metrics.precision_at_k":
1204            metrics_comment,
1205        "tf.metrics.precision_at_thresholds":
1206            metrics_comment,
1207        "tf.metrics.precision_at_top_k":
1208            metrics_comment,
1209        "tf.metrics.recall":
1210            metrics_comment,
1211        "tf.metrics.recall_at_k":
1212            metrics_comment,
1213        "tf.metrics.recall_at_thresholds":
1214            metrics_comment,
1215        "tf.metrics.recall_at_top_k":
1216            metrics_comment,
1217        "tf.metrics.root_mean_squared_error":
1218            metrics_comment,
1219        "tf.metrics.sensitivity_at_specificity":
1220            metrics_comment,
1221        "tf.metrics.sparse_average_precision_at_k":
1222            metrics_comment,
1223        "tf.metrics.sparse_precision_at_k":
1224            metrics_comment,
1225        "tf.metrics.specificity_at_sensitivity":
1226            metrics_comment,
1227        "tf.metrics.true_negatives":
1228            metrics_comment,
1229        "tf.metrics.true_negatives_at_thresholds":
1230            metrics_comment,
1231        "tf.metrics.true_positives":
1232            metrics_comment,
1233        "tf.metrics.true_positives_at_thresholds":
1234            metrics_comment,
1235        "tf.get_variable":
1236            (ast_edits.WARNING,
1237             "<function name> returns ResourceVariables by default in 2.0, "
1238             "which have well-defined semantics and are stricter about shapes. "
1239             "You can disable this behavior by passing use_resource=False, or "
1240             "by calling tf.compat.v1.disable_resource_variables()."),
1241        "tf.pywrap_tensorflow":
1242            (ast_edits.ERROR,
1243             "<function name> cannot be converted automatically. "
1244             "`tf.pywrap_tensorflow` will not be distributed with "
1245             "TensorFlow 2.0, please consider an alternative in public "
1246             "TensorFlow APIs."),
1247        "tf.contrib.distribute.MirroredStrategy":
1248            contrib_mirrored_strategy_warning,
1249        "tf.distribute.MirroredStrategy":
1250            core_mirrored_strategy_warning,
1251        "tf.contrib.distribute.OneDeviceStrategy":
1252            contrib_one_device_strategy_warning,
1253        "tf.contrib.distribute.TPUStrategy":
1254            contrib_tpu_strategy_warning,
1255        "tf.contrib.distribute.CollectiveAllReduceStrategy":
1256            contrib_collective_strategy_warning,
1257        "tf.contrib.distribute.ParameterServerStrategy":
1258            contrib_ps_strategy_warning,
1259        "tf.summary.FileWriter": summary_api_comment,
1260        "tf.summary.FileWriterCache": summary_api_comment,
1261        "tf.summary.Summary": summary_api_comment,
1262        "tf.summary.audio": summary_api_comment,
1263        "tf.summary.histogram": summary_api_comment,
1264        "tf.summary.image": summary_api_comment,
1265        "tf.summary.merge": summary_api_comment,
1266        "tf.summary.merge_all": summary_api_comment,
1267        "tf.summary.scalar": summary_api_comment,
1268        "tf.summary.tensor_summary": summary_api_comment,
1269        "tf.summary.text": summary_api_comment,
1270    }
1271    all_renames_v2.add_contrib_direct_import_support(self.function_warnings)
1272
1273    for symbol, replacement in all_renames_v2.addons_symbol_mappings.items():
1274      warning = (
1275          ast_edits.WARNING, (
1276              "(Manual edit required) `{}` has been migrated to `{}` in "
1277              "TensorFlow Addons. The API spec may have changed during the "
1278              "migration. Please see https://github.com/tensorflow/addons "
1279              "for more info.").format(symbol, replacement))
1280      self.function_warnings[symbol] = warning
1281
1282    # Warnings that are emitted only if a specific arg is found.
1283    self.function_arg_warnings = {
1284        "tf.nn.conv1d": {
1285            ("use_cudnn_on_gpu", 4):
1286                (ast_edits.WARNING,
1287                 "use_cudnn_on_gpu has been removed, behavior is now equivalent"
1288                 "to setting it to True."),
1289        },
1290        "tf.nn.conv2d": {
1291            ("use_cudnn_on_gpu", 4):
1292                (ast_edits.WARNING,
1293                 "use_cudnn_on_gpu has been removed, behavior is now equivalent"
1294                 "to setting it to True."),
1295        },
1296        "tf.nn.conv2d_backprop_filter": {
1297            ("use_cudnn_on_gpu", 5):
1298                (ast_edits.WARNING,
1299                 "use_cudnn_on_gpu has been removed, behavior is now equivalent"
1300                 "to setting it to True."),
1301        },
1302        "tf.nn.conv2d_backprop_input": {
1303            ("use_cudnn_on_gpu", 5):
1304                (ast_edits.WARNING,
1305                 "use_cudnn_on_gpu has been removed, behavior is now equivalent"
1306                 "to setting it to True."),
1307        },
1308        "tf.gradients": {
1309            ("colocate_gradients_with_ops", 4):
1310                (ast_edits.INFO, "tf.gradients no longer takes "
1311                 "'colocate_gradients_with_ops' argument, it behaves as if it "
1312                 "was set to True."),
1313        },
1314        "tf.hessians": {
1315            ("colocate_gradients_with_ops", 3):
1316                (ast_edits.INFO, "tf.hessians no longer takes "
1317                 "'colocate_gradients_with_ops' argument, it behaves as if it "
1318                 "was set to True."),
1319        },
1320        "*.minimize": {
1321            ("colocate_gradients_with_ops", 5):
1322                (ast_edits.INFO, "Optimizer.minimize no longer takes "
1323                 "'colocate_gradients_with_ops' argument, it behaves as if it "
1324                 "was set to True."),
1325        },
1326        "*.compute_gradients": {
1327            ("colocate_gradients_with_ops", 4):
1328                (ast_edits.INFO, "Optimizer.compute_gradients no "
1329                 "longer takes 'colocate_gradients_with_ops' argument, it "
1330                 "behaves as if it was set to True."),
1331        },
1332        "tf.cond": {
1333            ("strict", 3):
1334                (ast_edits.WARNING,
1335                 "tf.cond no longer takes 'strict' argument, it behaves as "
1336                 "if was set to True.")
1337        },
1338        "tf.contrib.summary.audio": {
1339            ("family", 4): contrib_summary_family_arg_comment,
1340        },
1341        "tf.contrib.summary.create_file_writer": {
1342            ("name", 4):
1343                (ast_edits.WARNING,
1344                 "tf.contrib.summary.create_file_writer() no longer supports "
1345                 "implicit writer re-use based on shared logdirs or resource "
1346                 "names; this call site passed a 'name' argument that has been "
1347                 "removed. The new tf.compat.v2.summary.create_file_writer() "
1348                 "replacement has a 'name' parameter but the semantics are "
1349                 "the usual ones to name the op itself and do not control "
1350                 "writer re-use; writers must be manually re-used if desired.")
1351        },
1352        "tf.contrib.summary.generic": {
1353            ("name", 0): (
1354                ast_edits.WARNING,
1355                "tf.contrib.summary.generic() takes a 'name' argument for the "
1356                "op name that also determines the emitted tag (prefixed by any "
1357                "active name scopes), but tf.compat.v2.summary.write(), which "
1358                "replaces it, separates these into 'tag' and 'name' arguments. "
1359                "The 'name' argument here has been converted to 'tag' to "
1360                "preserve a meaningful tag, but any name scopes will not be "
1361                "reflected in the tag without manual editing."),
1362            ("family", 3): contrib_summary_family_arg_comment,
1363        },
1364        "tf.contrib.summary.histogram": {
1365            ("family", 2): contrib_summary_family_arg_comment,
1366        },
1367        "tf.contrib.summary.image": {
1368            ("bad_color", 2): (
1369                ast_edits.WARNING,
1370                "tf.contrib.summary.image no longer takes the 'bad_color' "
1371                "argument; caller must now preprocess if needed. This call "
1372                "site specifies a bad_color argument so it cannot be converted "
1373                "safely."),
1374            ("family", 4): contrib_summary_family_arg_comment,
1375        },
1376        "tf.contrib.summary.scalar": {
1377            ("family", 2): contrib_summary_family_arg_comment,
1378        },
1379        "tf.image.resize": {
1380            ("align_corners", 3):
1381                (ast_edits.WARNING,
1382                 "align_corners is not supported by tf.image.resize, the new "
1383                 "default transformation is close to what v1 provided. If you "
1384                 "require exactly the same transformation as before, use "
1385                 "compat.v1.image.resize."),
1386        },
1387        "tf.image.resize_bilinear": {
1388            ("align_corners", 2):
1389                (ast_edits.WARNING,
1390                 "align_corners is not supported by tf.image.resize, the new "
1391                 "default transformation is close to what v1 provided. If you "
1392                 "require exactly the same transformation as before, use "
1393                 "compat.v1.image.resize_bilinear."),
1394        },
1395        "tf.image.resize_area": {
1396            ("align_corners", 2):
1397                (ast_edits.WARNING,
1398                 "align_corners is not supported by tf.image.resize, the new "
1399                 "default transformation is close to what v1 provided. If you "
1400                 "require exactly the same transformation as before, use "
1401                 "compat.v1.image.resize_area."),
1402        },
1403        "tf.image.resize_bicubic": {
1404            ("align_corners", 2):
1405                (ast_edits.WARNING,
1406                 "align_corners is not supported by tf.image.resize, the new "
1407                 "default transformation is close to what v1 provided. If you "
1408                 "require exactly the same transformation as before, use "
1409                 "compat.v1.image.resize_bicubic."),
1410        },
1411        "tf.image.resize_nearest_neighbor": {
1412            ("align_corners", 2):
1413                (ast_edits.WARNING,
1414                 "align_corners is not supported by tf.image.resize, the new "
1415                 "default transformation is close to what v1 provided. If you "
1416                 "require exactly the same transformation as before, use "
1417                 "compat.v1.image.resize_nearest_neighbor."),
1418        },
1419    }
1420    all_renames_v2.add_contrib_direct_import_support(self.function_arg_warnings)
1421
1422    # Specially handled functions
1423    # Each transformer is a callable which will be called with the arguments
1424    #   transformer(parent, node, full_name, name, logs)
1425    # Where logs is a list to which (level, line, col, msg) tuples can be
1426    # appended, full_name is the FQN of the function called (or None if that is
1427    # unknown), name is the name of the function called (or None if that is
1428    # unknown). node is an ast.Call node representing this function call, and
1429    # parent is its parent in the AST.
1430    # The function may modify node (but not parent), and must return
1431    # - None, if nothing was modified
1432    # - node, if node was modified in place (make sure to use
1433    #   pasta.ast_utils.replace_child to swap out children, otherwise formatting
1434    #   may get messy)
1435    # - a replacement for node, if the whole call node was replaced. The caller
1436    #   will take care of changing parent.
1437    canned_estimator_msg_optimizer = (
1438        "tf.keras.optimizers.* only, so the call was converted to compat.v1. "
1439        "Please note that tf.train.Optimizers have one-to-one correspondents "
1440        "in tf.keras.optimizers, so you may be able to convert to the new "
1441        "optimizers directly (See https://www.tensorflow.org/api_docs/python"
1442        "/tf/keras/optimizers). Checkpoint compatibility is not guaranteed, "
1443        "but there is a checkpoint converter tool that you can use.")
1444    canned_estimator_msg = (
1445        "no longer takes `input_layer_partitioner` arg, and it supports "
1446        + canned_estimator_msg_optimizer)
1447    self.function_transformers = {
1448        "*.make_initializable_iterator": _iterator_transformer,
1449        "*.make_one_shot_iterator": _iterator_transformer,
1450        "tf.nn.dropout": _dropout_transformer,
1451        "tf.to_bfloat16": _cast_transformer,
1452        "tf.to_complex128": _cast_transformer,
1453        "tf.to_complex64": _cast_transformer,
1454        "tf.to_double": _cast_transformer,
1455        "tf.to_float": _cast_transformer,
1456        "tf.to_int32": _cast_transformer,
1457        "tf.to_int64": _cast_transformer,
1458        "tf.nn.softmax_cross_entropy_with_logits":
1459            _softmax_cross_entropy_with_logits_transformer,
1460        "tf.image.extract_glimpse": _extract_glimpse_transformer,
1461        "tf.image.resize_area": _image_resize_transformer,
1462        "tf.image.resize_bicubic": _image_resize_transformer,
1463        "tf.image.resize_bilinear": _image_resize_transformer,
1464        "tf.image.resize_nearest_neighbor": _image_resize_transformer,
1465        "tf.nn.fractional_avg_pool": _pool_seed_transformer,
1466        "tf.nn.fractional_max_pool": _pool_seed_transformer,
1467        "tf.name_scope": _name_scope_transformer,
1468        # TODO(b/129398290)
1469        # "tf.string_split": _string_split_transformer,
1470        "tf.strings.split": _string_split_rtype_transformer,
1471        "tf.estimator.BaselineEstimator":
1472            functools.partial(
1473                _rename_if_arg_found_transformer,
1474                arg_name="optimizer",
1475                message=("tf.estimator.BaselineEstimator supports "
1476                         + canned_estimator_msg_optimizer),
1477            ),
1478        "tf.estimator.BaselineClassifier":
1479            functools.partial(
1480                _rename_if_arg_found_and_add_loss_reduction_transformer,
1481                arg_names=["optimizer"],
1482                message=("tf.estimator.BaselineClassifier supports "
1483                         + canned_estimator_msg_optimizer),
1484            ),
1485        "tf.estimator.BaselineRegressor":
1486            functools.partial(
1487                _rename_if_arg_found_and_add_loss_reduction_transformer,
1488                arg_names=["input_layer_partitioner", "optimizer"],
1489                message=("tf.estimator.BaselineRegressor supports "
1490                         + canned_estimator_msg_optimizer),
1491            ),
1492        "tf.estimator.DNNEstimator":
1493            functools.partial(
1494                _rename_if_any_arg_found_transformer,
1495                arg_names=["input_layer_partitioner", "optimizer"],
1496                message="tf.estimator.DNNEstimator no longer takes "
1497                "input_layer_partitioner, so the call was converted to "
1498                "compat.v1."
1499            ),
1500        "tf.estimator.DNNClassifier":
1501            functools.partial(
1502                _rename_if_arg_found_and_add_loss_reduction_transformer,
1503                arg_names=["input_layer_partitioner", "optimizer"],
1504                message="tf.estimator.DNNClassifier " + canned_estimator_msg,
1505            ),
1506        "tf.estimator.DNNRegressor":
1507            functools.partial(
1508                _rename_if_arg_found_and_add_loss_reduction_transformer,
1509                arg_names=["input_layer_partitioner", "optimizer"],
1510                message="tf.estimator.DNNRegressor " + canned_estimator_msg,
1511            ),
1512        "tf.estimator.LinearEstimator":
1513            functools.partial(
1514                _rename_if_any_arg_found_transformer,
1515                arg_names=["input_layer_partitioner", "optimizer"],
1516                message="tf.estimator.LinearEstimator " + canned_estimator_msg,
1517            ),
1518        "tf.estimator.LinearClassifier":
1519            functools.partial(
1520                _rename_if_arg_found_and_add_loss_reduction_transformer,
1521                arg_names=["input_layer_partitioner", "optimizer"],
1522                message="tf.estimator.LinearClassifier " + canned_estimator_msg,
1523            ),
1524        "tf.estimator.LinearRegressor":
1525            functools.partial(
1526                _rename_if_arg_found_and_add_loss_reduction_transformer,
1527                arg_names=["input_layer_partitioner", "optimizer"],
1528                message="tf.estimator.LinearRegressor " + canned_estimator_msg,
1529            ),
1530        "tf.estimator.DNNLinearCombinedEstimator":
1531            functools.partial(
1532                _rename_if_any_arg_found_transformer,
1533                arg_names=[
1534                    "input_layer_partitioner", "dnn_optimizer",
1535                    "linear_optimizer"
1536                ],
1537                message=("tf.estimator.DNNLinearCombinedEstimator "
1538                         + canned_estimator_msg),
1539            ),
1540        "tf.estimator.DNNLinearCombinedClassifier":
1541            functools.partial(
1542                _rename_if_arg_found_and_add_loss_reduction_transformer,
1543                arg_names=[
1544                    "input_layer_partitioner", "dnn_optimizer",
1545                    "linear_optimizer"
1546                ],
1547                message=("tf.estimator.DNNLinearCombinedClassifier "
1548                         + canned_estimator_msg),
1549            ),
1550        "tf.estimator.DNNLinearCombinedRegressor":
1551            functools.partial(
1552                _rename_if_arg_found_and_add_loss_reduction_transformer,
1553                arg_names=[
1554                    "input_layer_partitioner", "dnn_optimizer",
1555                    "linear_optimizer"
1556                ],
1557                message=("tf.estimator.DNNLinearCombinedRegressor "
1558                         + canned_estimator_msg),
1559            ),
1560        "tf.device": functools.partial(
1561            _rename_if_arg_found_transformer, arg_name="device_name",
1562            arg_ok_predicate=_is_ast_str, remove_if_ok=False,
1563            message="tf.device no longer takes functions as an argument. "
1564            "We could not determine that the argument value is a string, so "
1565            "the call was converted to compat.v1."),
1566        "tf.zeros_like": functools.partial(
1567            _rename_if_arg_found_transformer, arg_name="optimize",
1568            arg_ok_predicate=_is_ast_true, remove_if_ok=True,
1569            message="tf.zeros_like no longer takes an optimize argument, and "
1570            "behaves as if optimize=True. This call site specifies something "
1571            "other than optimize=True, so it was converted to compat.v1."),
1572        "tf.ones_like": functools.partial(
1573            _rename_if_arg_found_transformer, arg_name="optimize",
1574            arg_ok_predicate=_is_ast_true, remove_if_ok=True,
1575            message="tf.ones_like no longer takes an optimize argument, and "
1576            "behaves as if optimize=True. This call site specifies something "
1577            "other than optimize=True, so it was converted to compat.v1."),
1578        "tf.while_loop": functools.partial(
1579            _rename_if_arg_found_transformer,
1580            arg_name="return_same_structure",
1581            arg_ok_predicate=_is_ast_true, remove_if_ok=True,
1582            message="tf.while_loop no longer takes 'return_same_structure' "
1583            "argument and behaves as if return_same_structure=True. This call "
1584            "site specifies something other than return_same_structure=True, "
1585            "so it was converted to compat.v1."),
1586        "tf.nn.ctc_beam_search_decoder": functools.partial(
1587            _rename_if_arg_found_transformer,
1588            arg_name="merge_repeated",
1589            arg_ok_predicate=_is_ast_false, remove_if_ok=True,
1590            message="tf.nn.ctc_beam_search_decoder no longer takes the "
1591            "'merge_repeated' argument and behaves as if merge_repeated=False. "
1592            "This call site specifies something other than "
1593            "merge_repeated=False, so it was converted to compat.v1."),
1594        "tf.nn.dilation2d": functools.partial(
1595            _add_argument_transformer,
1596            arg_name="data_format",
1597            arg_value_ast=ast.Str("NHWC")),
1598        "tf.nn.erosion2d": functools.partial(
1599            _add_argument_transformer,
1600            arg_name="data_format",
1601            arg_value_ast=ast.Str("NHWC")),
1602        "tf.contrib.summary.always_record_summaries": functools.partial(
1603            _add_summary_recording_cond_transformer, cond="True"),
1604        "tf.contrib.summary.audio": _add_summary_step_transformer,
1605        "tf.contrib.summary.generic": _add_summary_step_transformer,
1606        "tf.contrib.summary.histogram": _add_summary_step_transformer,
1607        "tf.contrib.summary.image": _add_summary_step_transformer,
1608        "tf.contrib.summary.never_record_summaries": functools.partial(
1609            _add_summary_recording_cond_transformer, cond="False"),
1610        "tf.contrib.summary.scalar": _add_summary_step_transformer,
1611        "tf.contrib.layers.l1_regularizer":
1612            _contrib_layers_l1_regularizer_transformer,
1613        "tf.contrib.layers.l2_regularizer":
1614            _contrib_layers_l2_regularizer_transformer,
1615        "tf.contrib.layers.xavier_initializer":
1616            _contrib_layers_xavier_initializer_transformer,
1617        "tf.contrib.layers.xavier_initializer_conv2d":
1618            _contrib_layers_xavier_initializer_transformer,
1619        "tf.contrib.layers.variance_scaling_initializer":
1620            _contrib_layers_variance_scaling_initializer_transformer,
1621        "tf.initializers.uniform_unit_scaling":
1622            _add_uniform_scaling_initializer_transformer,
1623        "tf.uniform_unit_scaling_initializer":
1624            _add_uniform_scaling_initializer_transformer,
1625        "slim.l1_regularizer":
1626            _contrib_layers_l1_regularizer_transformer,
1627        "slim.l2_regularizer":
1628            _contrib_layers_l2_regularizer_transformer,
1629        "slim.xavier_initializer":
1630            _contrib_layers_xavier_initializer_transformer,
1631        "slim.xavier_initializer_conv2d":
1632            _contrib_layers_xavier_initializer_transformer,
1633        "slim.variance_scaling_initializer":
1634            _contrib_layers_variance_scaling_initializer_transformer,
1635        "tf.keras.models.save_model": functools.partial(
1636            _add_argument_transformer,
1637            arg_name="save_format",
1638            arg_value_ast=ast.Str("h5")),
1639    }
1640    all_renames_v2.add_contrib_direct_import_support(self.function_transformers)
1641
1642    self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
1643
1644  def preprocess(self, root_node, after_compat_v1_upgrade=False):
1645    visitor = ast_edits.PastaAnalyzeVisitor(TFAPIImportAnalysisSpec())
1646    visitor.visit(root_node)
1647    detections = set(visitor.results)
1648
1649    # Upgrade explicit compat v1 imports if `upgrade_compat_v1_import` is
1650    # enabled. Then preprocess the updated root node.
1651    # We only do this upgrading once, because some forms of the import may
1652    # still cause errors but aren't trivially upgradeable, and we don't want
1653    # to enter an infinite loop. E.g. `from tensorflow.compat import v1, v2`.
1654    if (compat_v1_import in detections and self.upgrade_compat_v1_import and
1655        not after_compat_v1_upgrade):
1656      CompatV1ImportReplacer().visit(root_node)
1657      return self.preprocess(root_node, after_compat_v1_upgrade=True)
1658
1659    # If we have detected the presence of imports of specific TF versions,
1660    # We want to modify the update spec to check only module deprecations
1661    # and skip all other conversions.
1662    if detections:
1663      self.function_handle = {}
1664      self.function_reorders = {}
1665      self.function_keyword_renames = {}
1666      self.symbol_renames = {}
1667      self.function_warnings = {}
1668      self.change_to_function = {}
1669      self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
1670      self.function_transformers = {}
1671      self.import_renames = {}
1672    return root_node, visitor.log, visitor.warnings_and_errors
1673
1674  def clear_preprocessing(self):
1675    self.__init__()
1676
1677
def _is_ast_str(node):
  """Determine whether this node represents a string.

  String and bytes literals are accepted, as are f-strings (`JoinedStr`) and
  their interpolated pieces (`FormattedValue`). Both the unified
  `ast.Constant` literal node produced by Python 3.8+ and the legacy
  `ast.Str` / `ast.Bytes` classes (deprecated, and missing from recent
  Pythons) are supported.

  Args:
    node: an AST expression node.

  Returns:
    True if `node` is a string-like literal node, else False.
  """
  # Python 3.8+ parses every literal into ast.Constant; inspect the payload.
  constant_type = getattr(ast, "Constant", None)
  if constant_type is not None and isinstance(node, constant_type):
    return isinstance(node.value, (str, bytes))

  allowed_types = []
  for type_name in ("Str", "Bytes", "JoinedStr", "FormattedValue"):
    node_type = getattr(ast, type_name, None)
    if node_type is not None:
      allowed_types.append(node_type)
  # isinstance() requires a class or a *tuple* of classes; a list raises
  # TypeError.
  return isinstance(node, tuple(allowed_types))
1688
1689
def _is_ast_true(node):
  """Determine whether this node is the boolean literal `True`.

  Supports the unified `ast.Constant` node (Python 3.8+), the legacy
  `ast.NameConstant` node, and Python 2, where `True` parses as a Name.
  Checking `ast.Constant` first keeps this correct on Pythons that no
  longer provide `ast.NameConstant`.

  Args:
    node: an AST expression node.

  Returns:
    True if `node` is the literal True, else False.
  """
  constant_type = getattr(ast, "Constant", None)
  if constant_type is not None and isinstance(node, constant_type):
    return node.value is True
  if hasattr(ast, "NameConstant"):
    return isinstance(node, ast.NameConstant) and node.value is True
  return isinstance(node, ast.Name) and node.id == "True"
1695
1696
def _is_ast_false(node):
  """Determine whether this node is the boolean literal `False`.

  Supports the unified `ast.Constant` node (Python 3.8+), the legacy
  `ast.NameConstant` node, and Python 2, where `False` parses as a Name.
  Checking `ast.Constant` first keeps this correct on Pythons that no
  longer provide `ast.NameConstant`.

  Args:
    node: an AST expression node.

  Returns:
    True if `node` is the literal False, else False.
  """
  constant_type = getattr(ast, "Constant", None)
  if constant_type is not None and isinstance(node, constant_type):
    return node.value is False
  if hasattr(ast, "NameConstant"):
    return isinstance(node, ast.NameConstant) and node.value is False
  return isinstance(node, ast.Name) and node.id == "False"
1702
1703
1704# Lots of unused arguments below, since these are called in a standard manner.
1705# pylint: disable=unused-argument
1706
1707
def _rename_if_arg_found_transformer(parent, node, full_name, name, logs,
                                     arg_name=None,
                                     arg_ok_predicate=None,
                                     remove_if_ok=False,
                                     message=None):
  """Switches a call to its tf.compat.v1 variant when `arg_name` is passed.

  The argument is looked up with `ast_edits.get_arg_value`. The function is
  expected to be called with named arguments only, so any function using
  this transformer should also be added to the renames.

  Outcomes:
    * `arg_name` not passed: the call is left alone.
    * passed, and `arg_ok_predicate(value)` is truthy: the call keeps its
      name. With `remove_if_ok`, the first keyword named `arg_name` is
      deleted from the call.
    * otherwise: the first "tf." in `full_name` is replaced by
      "tf.compat.v1." and the call is repointed to that name.

  Args:
    parent: Parent of node.
    node: ast.Call node to maybe modify.
    full_name: full name of function to modify.
    name: name of function to modify.
    logs: list of logs to append to.
    arg_name: name of the argument to look for.
    arg_ok_predicate: predicate called with the ast of the argument value;
      returns whether that value is allowed.
    remove_if_ok: remove the argument if present and ok as determined by
      arg_ok_predicate.
    message: message appended to the log if a non-ok arg is found (and
      hence the function is renamed to its compat.v1 version).

  Returns:
    `node` when it was (potentially) modified, else None.
  """
  found, value = ast_edits.get_arg_value(node, arg_name)
  if not found:
    return None

  if arg_ok_predicate and arg_ok_predicate(value):
    # Acceptable value: the TF 2.0 symbol can be kept.
    if not remove_if_ok:
      return None
    position = next(
        (idx for idx, keyword in enumerate(node.keywords)
         if keyword.arg == arg_name),
        None)
    if position is not None:
      del node.keywords[position]
      logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                   "Removed argument %s for function %s" % (
                       arg_name, full_name or name)))
    return node

  # Problematic value: move to compat.v1. A full name is required here, as
  # the callee is expected to be a `tf.` attribute.
  v1_name = six.ensure_str(full_name).replace("tf.", "tf.compat.v1.", 1)
  node.func = ast_edits.full_name_node(v1_name)
  details = "" if message is None else message
  logs.append((
      ast_edits.INFO, node.lineno, node.col_offset,
      "Renaming %s to %s because argument %s is present. %s" %
      (full_name, v1_name, arg_name, details)))
  return node
1775
1776
def _add_argument_transformer(parent, node, full_name, name, logs,
                              arg_name, arg_value_ast):
  """Adds an argument (as a final kwarg arg_name=arg_value_ast).

  `arg_value_ast` is typically bound once through `functools.partial` in the
  spec's transformer table, so the same object would otherwise be inserted
  at every rewritten call site. It is deep-copied so each call owns its own
  node and an in-place edit of one call site cannot leak into the others.

  Args:
    parent: Parent of node (unused).
    node: ast.Call node to modify.
    full_name: full name of the called function, if known.
    name: name of the called function.
    logs: list of logs to append to.
    arg_name: name of the keyword argument to add.
    arg_value_ast: AST expression used as the argument's value (copied).

  Returns:
    The modified node.
  """
  new_keyword = ast.keyword(arg=arg_name, value=copy.deepcopy(arg_value_ast))
  node.keywords.append(new_keyword)
  logs.append((
      ast_edits.INFO, node.lineno, node.col_offset,
      "Adding argument '%s' to call to %s." % (pasta.dump(new_keyword),
                                               full_name or name)
  ))
  return node
1787
1788
def _iterator_transformer(parent, node, full_name, name, logs):
  """Rewrites a dataset method call into a tf.compat.v1.data function call.

  `x.f(y)` becomes `tf.compat.v1.data.f(x, y)`. Calls whose full name is
  already under `tf.compat.v1.data` or `tf.data` are left to the renames;
  this transformer only performs the method-to-function conversion.

  Returns:
    The modified node, or None when nothing was changed.
  """
  if full_name:
    resolved_name = six.ensure_str(full_name)
    if resolved_name.startswith(("tf.compat.v1.data", "tf.data")):
      return None

  # Only `receiver.method(...)` shapes can be rewritten. This should always
  # hold, since the transformer is only invoked for Attribute callees.
  callee = node.func
  if not isinstance(callee, ast.Attribute):
    return None

  # The receiver becomes the first positional argument; it already carries
  # valid position information.
  node.args = [callee.value] + node.args
  callee.value = ast_edits.full_name_node("tf.compat.v1.data")

  logs.append((ast_edits.WARNING, node.lineno, node.col_offset,
               "Changing dataset.%s() to tf.compat.v1.data.%s(dataset). "
               "Please check this transformation.\n" % (name, name)))
  return node
1813
1814
def _dropout_transformer(parent, node, full_name, name, logs):
  """Replaces tf.nn.dropout's keep_prob with rate=1-(keep_prob).

  A `keep_prob=` keyword is renamed to `rate` and its value rewritten in
  place. Otherwise, if at least two positional arguments are present, the
  second is moved into a `rate=` keyword with the rewritten value. With
  fewer positional arguments an ERROR is logged and the call is unchanged.

  Returns:
    The modified node, or None when no rewrite took place.
  """

  def _as_one_minus(holder, keep_prob_value):
    """Replaces keep_prob_value inside holder with 1-(keep_prob_value)."""
    one = ast.Num(n=1)
    one.lineno = 0
    one.col_offset = 0
    rate_value = ast.BinOp(left=one, op=ast.Sub(), right=keep_prob_value)
    # Copies the prefix and suffix of the old value onto the new value.
    pasta.ast_utils.replace_child(holder, keep_prob_value, rate_value)
    ast.copy_location(rate_value, keep_prob_value)
    # Parenthesize the original operand; its former prefix/suffix now
    # belongs to rate_value.
    pasta.base.formatting.set(keep_prob_value, "prefix", "(")
    pasta.base.formatting.set(keep_prob_value, "suffix", ")")

  keep_prob_kw = next(
      (kw for kw in node.keywords if kw.arg == "keep_prob"), None)
  if keep_prob_kw is not None:
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "Changing keep_prob arg of tf.nn.dropout to rate\n"))
    keep_prob_kw.arg = "rate"
    _as_one_minus(keep_prob_kw, keep_prob_kw.value)
    return node

  # keep_prob may have been passed positionally as the second argument.
  # NOTE(review): positional arguments after it shift left by one once it is
  # removed -- confirm callers only pass x (and keep_prob) positionally.
  if len(node.args) >= 2:
    rate_kw = ast.keyword(arg="rate", value=node.args[1])
    _as_one_minus(rate_kw, rate_kw.value)
    node.keywords.append(rate_kw)
    del node.args[1]
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "Changing keep_prob arg of tf.nn.dropout to rate, and "
                 "recomputing value.\n"))
    return node

  logs.append((ast_edits.ERROR, node.lineno, node.col_offset,
               "tf.nn.dropout called without arguments, so "
               "automatic fix was disabled. tf.nn.dropout has changed "
               "the semantics of the second argument."))
  return None
1857
1858
def _cast_transformer(parent, node, full_name, name, logs):
  """Rewrites tf.to_<dtype>(...) calls into tf.cast(..., dtype=tf.<dtype>).

  The dtype is the function name minus its first three characters (the
  `to_` prefix), with "float" mapped to "float32" and "double" to "float64".
  When exactly two positional arguments are given, the second (the name) is
  turned into a `name=` keyword so that it stays valid ahead of `dtype=`.

  Returns:
    The modified node.
  """
  dtype_aliases = {"float": "float32", "double": "float64"}
  dtype_str = name[3:]
  dtype_str = dtype_aliases.get(dtype_str, dtype_str)

  dtype_value = ast.Attribute(value=ast.Name(id="tf", ctx=ast.Load()),
                              attr=dtype_str, ctx=ast.Load())
  # Python 3's ast wants a position on the Attribute, but 0 would make
  # codegen mess up the argument order; offset it from the call instead.
  dtype_value.lineno = node.lineno
  dtype_value.col_offset = node.col_offset + 100
  dtype_kw = ast.keyword(arg="dtype", value=dtype_value)

  if len(node.args) == 2:
    node.keywords.append(ast.keyword(arg="name", value=node.args[-1]))
    node.args = node.args[:-1]

  node.keywords.append(dtype_kw)
  callee = node.func
  if isinstance(callee, ast.Attribute):
    callee.attr = "cast"
  else:
    assert isinstance(callee, ast.Name)
    callee.id = "cast"

  logs.append((ast_edits.INFO, node.lineno, node.col_offset,
               "Changed %s call to tf.cast(..., dtype=tf.%s)." % (full_name,
                                                                  dtype_str)))
  return node
1896
1897
def _softmax_cross_entropy_with_logits_transformer(
    parent, node, full_name, name, logs):
  """Wraps the `labels=` argument in tf.stop_gradient(...).

  Only the first `labels` keyword is examined, and it is left untouched when
  it is already spelled `tf.stop_gradient(...)`.

  Returns:
    The node, whether or not it was modified.
  """

  def _is_tf_stop_gradient_call(value):
    """True if value is a call written as tf.stop_gradient(...)."""
    if not isinstance(value, ast.Call):
      return False
    callee = value.func
    return (isinstance(callee, ast.Attribute) and
            callee.attr == "stop_gradient" and
            isinstance(callee.value, ast.Name) and
            callee.value.id == "tf")

  def _wrap_label(holder, label_value):
    """Replaces label_value in holder by tf.stop_gradient(label_value)."""
    if _is_tf_stop_gradient_call(label_value):
      return False
    callee = ast.Name(id="tf.stop_gradient", ctx=ast.Load())
    try:
      wrapped = ast.Call(callee, [label_value], [])
    except TypeError:
      # Older ast.Call constructors also require starargs and kwargs.
      wrapped = ast.Call(callee, [label_value], [], None, None)
    # Copies the prefix and suffix of label_value onto the wrapper.
    pasta.ast_utils.replace_child(holder, label_value, wrapped)
    ast.copy_location(wrapped, label_value)
    return True

  labels_kw = next((kw for kw in node.keywords if kw.arg == "labels"), None)
  if labels_kw is not None and _wrap_label(labels_kw, labels_kw.value):
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "Changing labels arg of "
                 "tf.nn.softmax_cross_entropy_with_logits to "
                 "tf.stop_gradient(labels). Please check this "
                 "transformation.\n"))
  return node
1935
1936
def _image_resize_transformer(parent, node, full_name, name, logs):
  """Transforms image.resize_<m>(...) into image.resize(..., method=<M>, ...).

  The method is the function name after its first seven characters (the
  `resize_` prefix), uppercased. Positional arguments are adjusted:
  with exactly four, the last becomes `preserve_aspect_ratio=`; then, with
  exactly three, the last (align_corners) is dropped. Any `align_corners=`
  keyword is dropped too, and `method=tf.image.ResizeMethod.<M>` is appended.

  Returns:
    The modified node.
  """
  resize_method = name[7:].upper()

  # Builds tf.image.ResizeMethod.<METHOD>.
  method_value = ast.Name(id="tf", ctx=ast.Load())
  for attr_name in ("image", "ResizeMethod", resize_method):
    method_value = ast.Attribute(value=method_value, attr=attr_name,
                                 ctx=ast.Load())
  # Python 3's ast wants a position on the Attribute, but 0 would make
  # codegen mess up the argument order; offset it from the call instead.
  method_value.lineno = node.lineno
  method_value.col_offset = node.col_offset + 100
  method_kw = ast.keyword(arg="method", value=method_value)

  # NOTE(review): confirm the fourth positional parameter of the v1 resize_*
  # functions really corresponds to preserve_aspect_ratio.
  if len(node.args) == 4:
    node.keywords.append(
        ast.keyword(arg="preserve_aspect_ratio", value=node.args[-1]))
    node.args = node.args[:-1]
  if len(node.args) == 3:
    # Positional align_corners: discarded, just like the keyword form.
    node.args = node.args[:-1]

  node.keywords = [kw for kw in node.keywords if kw.arg != "align_corners"]
  node.keywords.append(method_kw)

  callee = node.func
  if isinstance(callee, ast.Attribute):
    callee.attr = "resize"
  else:
    assert isinstance(callee, ast.Name)
    callee.id = "resize"

  logs.append((ast_edits.INFO, node.lineno, node.col_offset,
               "Changed %s call to tf.image.resize(..., "
               "method=tf.image.ResizeMethod.%s)." % (full_name,
                                                      resize_method)))
  return node
1983
1984
def _pool_seed_transformer(parent, node, full_name, name, logs):
  """Removes seed2 and deterministic, and adds non-zero seed if needed.

  This requires that the function is called with named arguments only (add
  it to the renames!).

  The `seed2=` and `deterministic=` keywords are dropped. When
  `deterministic` was passed with anything other than the literal False,
  determinism is considered requested: `seed=42` is appended if there is no
  `seed=` keyword, otherwise a warning is logged because the converter
  cannot tell whether the existing seed is non-zero.

  Args:
    parent: Parent of node (unused).
    node: ast.Call node to maybe modify.
    full_name: full name of the called function, if known.
    name: name of the called function.
    logs: list of logs to append to.

  Returns:
    The modified node, or None if neither seed2 nor deterministic was passed.
  """
  seed_arg = None
  deterministic = False
  modified = False
  new_keywords = []

  for kw in node.keywords:
    if sys.version_info[:2] >= (3, 5) and isinstance(kw, ast.Starred):
      # Defensive: starred entries are kept as-is.
      pass
    elif kw.arg == "seed":
      seed_arg = kw
    elif kw.arg in ("seed2", "deterministic"):
      lineno = getattr(kw, "lineno", node.lineno)
      col_offset = getattr(kw, "col_offset", node.col_offset)
      logs.append((ast_edits.INFO, lineno, col_offset,
                   "Removed argument %s for function %s" % (
                       kw.arg, full_name or name)))
      if kw.arg == "deterministic" and not _is_ast_false(kw.value):
        deterministic = True
      modified = True
      continue
    new_keywords.append(kw)

  if deterministic:
    # `logs` is a list (as everywhere else in this module): use append().
    if seed_arg is None:
      new_keywords.append(ast.keyword(arg="seed", value=ast.Num(42)))
      logs.append((
          ast_edits.INFO, node.lineno, node.col_offset,
          "Adding seed=42 to call to %s since determinism was requested" % (
              full_name or name)
      ))
    else:
      logs.append((
          ast_edits.WARNING, node.lineno, node.col_offset,
          "The deterministic argument is deprecated for %s, pass a "
          "non-zero seed for determinism. The deterministic argument is "
          "present, possibly not False, and the seed is already set. The "
          "converter cannot determine whether it is nonzero, please check."
          % (full_name or name)
      ))

  if not modified:
    return None
  node.keywords = new_keywords
  return node
2033
2034
def _extract_glimpse_transformer(parent, node, full_name, name, logs):
  """Replaces extract_glimpse's boolean uniform_noise with the noise string.

  A `uniform_noise=<expr>` keyword becomes
  `noise="uniform" if (<expr>) else "gaussian"`. If uniform_noise is instead
  passed positionally (index 5, i.e. the sixth argument), that positional
  value is rewritten the same way.

  Args:
    parent: Parent of node (unused).
    node: ast.Call node to maybe modify.
    full_name: full name of the called function, if known.
    name: name of the called function.
    logs: list of logs to append to.

  Returns:
    The modified node, or None if uniform_noise was not passed.
  """

  def _replace_uniform_noise_node(parent, old_value):
    """Replaces old_value with 'uniform' or 'gaussian'."""
    uniform = ast.Str(s="uniform")
    gaussian = ast.Str(s="gaussian")
    new_value = ast.IfExp(body=uniform, test=old_value, orelse=gaussian)
    # This copies the prefix and suffix on old_value to new_value.
    pasta.ast_utils.replace_child(parent, old_value, new_value)
    ast.copy_location(new_value, old_value)
    # Put parentheses around noise.value.test (and remove the old prefix/
    # suffix, they should only be around new_value.test), so that:
    # "uniform" if (a if b else c) else "gaussian" is valid.
    pasta.base.formatting.set(new_value.test, "prefix", "(")
    pasta.base.formatting.set(new_value.test, "suffix", ")")

  # Check if we have a uniform_noise keyword arg
  for uniform_noise in node.keywords:
    if uniform_noise.arg == "uniform_noise":
      logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                   "Changing uniform_noise arg of tf.image.extract_glimpse "
                   "to noise, and recomputing value. Please check this "
                   "transformation.\n"))
      uniform_noise.arg = "noise"
      _replace_uniform_noise_node(uniform_noise, uniform_noise.value)
      return node

  # `noise`/`uniform_noise` is optional and sits at positional index 5, so
  # it was only passed positionally when there are at least six positional
  # arguments (indexing args[5] with only five would raise IndexError).
  if len(node.args) > 5:
    _replace_uniform_noise_node(node, node.args[5])
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "Changing uniform_noise arg of tf.image.extract_glimpse to "
                 "noise, and recomputing value.\n"))
    return node
  return None
2071
def _add_summary_step_transformer(parent, node, full_name, name, logs):
  """Appends a `step=` argument to a summary API call that lacks one.

  The inserted value is tf.compat.v1.train.get_or_create_global_step().
  Calls that already pass `step=` are returned unchanged.

  Returns:
    The node, whether or not it was modified.
  """
  if any(kw.arg == "step" for kw in node.keywords):
    return node

  default_value = "tf.compat.v1.train.get_or_create_global_step()"
  # pasta.parse (rather than ast.parse) avoids emitting a spurious trailing
  # newline.
  step_kw = ast.keyword(arg="step", value=pasta.parse(default_value))
  node.keywords.append(step_kw)
  logs.append((
      ast_edits.WARNING, node.lineno, node.col_offset,
      "Summary API writing function %s now requires a 'step' argument; "
      "inserting default of %s." % (full_name or name, default_value)))
  return node
2089
2090
def _add_summary_recording_cond_transformer(parent, node, full_name, name, logs,
                                            cond):
  """Appends `cond` as a positional argument to *_record_summaries() calls.

  This prepares tf.contrib.summary.xxx_record_summaries() for being renamed
  to tf.summary.record_if(), which requires the cond argument.

  Args:
    parent: Parent of node (unused).
    node: ast.Call node to modify.
    full_name: full name of the called function, if known.
    name: name of the called function.
    logs: list of logs to append to.
    cond: Python source text of the condition to append, e.g. "True".

  Returns:
    The modified node.
  """
  cond_node = pasta.parse(cond)
  node.args.append(cond_node)
  logs.append((
      ast_edits.INFO, node.lineno, node.col_offset,
      "Adding `%s` argument to %s in anticipation of it being renamed to "
      "tf.compat.v2.summary.record_if()" % (cond, full_name or name)))
  return node
2104
2105
def _add_loss_reduction_transformer(parent, node, full_name, name, logs):
  """Pins loss_reduction to its old default when the call does not set it.

  The default loss_reduction of tf.estimator.*Classifier and
  tf.estimator.*Regressor changed to SUM_OVER_BATCH_SIZE, so calls without
  an explicit `loss_reduction=` get `tf.keras.losses.Reduction.SUM`.

  Note: to apply this transformation, the symbol must be added to
  reordered_function_names.

  Returns:
    The node, whether or not it was modified.
  """
  if any(kw.arg == "loss_reduction" for kw in node.keywords):
    return node

  default_value = "tf.keras.losses.Reduction.SUM"
  # pasta.parse (rather than ast.parse) avoids emitting a spurious trailing
  # newline.
  reduction_kw = ast.keyword(arg="loss_reduction",
                             value=pasta.parse(default_value))
  node.keywords.append(reduction_kw)
  logs.append((
      ast_edits.INFO, node.lineno, node.col_offset,
      "%s: Default value of loss_reduction has been changed to "
      "SUM_OVER_BATCH_SIZE; inserting old default value %s.\n"
      % (full_name or name, default_value)))
  return node
2129
2130
def _rename_if_any_arg_found_transformer(
    parent,
    node,
    full_name,
    name,
    logs,
    arg_names=None,
    arg_ok_predicate=None,
    remove_if_ok=False,
    message=None):
  """Replaces the given call with tf.compat.v1 if any of the arg_names is found.

  Runs `_rename_if_arg_found_transformer` once per name in `arg_names`,
  carrying the (possibly modified) node from one step to the next.

  Args:
    parent: Parent of node.
    node: ast.Call node to modify.
    full_name: full name of function to modify.
    name: name of function to modify.
    logs: list of logs to append to.
    arg_names: list of names of the arguments to look for. None (the
      default) is treated as an empty list and leaves the call untouched.
    arg_ok_predicate: predicate callable with the ast of the argument value,
      returns whether the argument value is allowed.
    remove_if_ok: remove the argument if present and ok as determined by
      arg_ok_predicate.
    message: message to print if a non-ok arg is found (and hence, the function
      is renamed to its compat.v1 version).

  Returns:
    The node, whether or not it was modified.
  """
  for arg_name in arg_names or ():
    rename_node = _rename_if_arg_found_transformer(parent, node,
                                                   full_name, name, logs,
                                                   arg_name, arg_ok_predicate,
                                                   remove_if_ok, message)
    node = rename_node or node

  return node
2168
2169
def _rename_if_arg_found_and_add_loss_reduction_transformer(
    parent,
    node,
    full_name,
    name,
    logs,
    arg_names=None,
    arg_ok_predicate=None,
    remove_if_ok=False,
    message=None):
  """Combination of _rename_if_arg_found and _add_loss_reduction transformers.

  First runs `_add_loss_reduction_transformer` (which appends
  `loss_reduction=tf.keras.losses.Reduction.SUM` when the call has no
  `loss_reduction` keyword), then `_rename_if_arg_found_transformer` once per
  name in `arg_names`.

  Args:
    parent: Parent of node.
    node: ast.Call node to maybe modify.
    full_name: full name of function to modify
    name: name of function to modify
    logs: list of logs to append to
    arg_names: list of names of the argument to look for. None (the
      default) is treated as an empty list, so only the loss_reduction
      step is applied.
    arg_ok_predicate: predicate callable with the ast of the argument value,
      returns whether the argument value is allowed.
    remove_if_ok: remove the argument if present and ok as determined by
      arg_ok_predicate.
    message: message to print if a non-ok arg is found (and hence, the function
      is renamed to its compat.v1 version).

  Returns:
    The node, whether or not it was modified.
  """
  node = _add_loss_reduction_transformer(parent, node, full_name, name, logs)
  for arg_name in arg_names or ():
    rename_node = _rename_if_arg_found_transformer(parent, node, full_name,
                                                   name, logs, arg_name,
                                                   arg_ok_predicate,
                                                   remove_if_ok, message)
    node = rename_node or node

  return node
2209
2210
def _add_uniform_scaling_initializer_transformer(
    parent, node, full_name, name, logs):
  """Rewrites uniform_unit_scaling_initializer calls to VarianceScaling.

  Transforms, for example:
  tf.uniform_unit_scaling_initializer(factor=f, seed=s) to
  tf.compat.v1.keras.initializers.VarianceScaling(
      scale=f, seed=s, distribution="uniform")

  A `factor=` keyword is renamed to `scale=`, `distribution="uniform"` is
  appended, and every other keyword argument is kept as it is. The callee
  must be an attribute access (`module.function`).

  Note: to apply this transformation, the symbol must be added to
  reordered_function_names.

  Returns:
    The modified node.
  """
  for kw in node.keywords:
    if kw.arg == "factor":
      kw.arg = "scale"

  # pasta.parse (rather than ast.parse) avoids emitting a spurious trailing
  # newline.
  distribution = pasta.parse("\"uniform\"")
  node.keywords.append(ast.keyword(arg="distribution", value=distribution))

  # The new module path reuses the position of the receiver it replaces.
  old_receiver = node.func.value
  new_receiver = ast_edits.full_name_node("tf.compat.v1.keras.initializers")
  new_receiver.lineno = old_receiver.lineno
  new_receiver.col_offset = old_receiver.col_offset
  node.func.value = new_receiver
  node.func.attr = "VarianceScaling"
  return node
2239
2240
def _contrib_layers_xavier_initializer_transformer(
    parent, node, full_name, name, logs):
  """Rewrites contrib.layers.xavier_initializer as a VarianceScaling call.

  Transforms:
  tf.contrib.layers.xavier_initializer(uniform, seed, dtype) to
  tf.compat.v1.keras.initializers.VarianceScaling(
      scale=1.0, mode="fan_avg",
      distribution=("uniform" if uniform else "truncated_normal"),
      seed=seed, dtype=dtype)

  Positional arguments are converted to keywords. When `uniform` is not given
  at all, distribution="uniform" is emitted.

  Args:
    parent: Parent of `node` in the AST; not used.
    node: The ast.Call to rewrite; it is modified in place.
    full_name: Fully qualified name of the called symbol; not used.
    name: Short name of the called symbol; not used.
    logs: List to which an INFO entry describing the change is appended.

  Returns:
    `node`, after modification.
  """
  def _uniform_to_distribution(uniform_expr):
    """Builds ("uniform" if (uniform_expr) else "truncated_normal")."""
    template = pasta.parse("\"uniform\" if old_value else \"truncated_normal\"")
    conditional = template.body[0].value
    pasta.ast_utils.replace_child(conditional, conditional.test, uniform_expr)
    pasta.base.formatting.set(template, "prefix", "(")
    pasta.base.formatting.set(template, "suffix", ")")
    return template

  distribution_given = False
  for kwarg in node.keywords:
    if kwarg.arg != "uniform":
      continue
    distribution_given = True
    kwarg.arg = "distribution"
    uniform_expr = kwarg.value
    wrapped = _uniform_to_distribution(uniform_expr)
    # replace_child copies uniform_expr's prefix/suffix onto `wrapped`, so the
    # surrounding parentheses have to be applied again afterwards.
    pasta.ast_utils.replace_child(kwarg, uniform_expr, wrapped)
    pasta.base.formatting.set(wrapped, "prefix", "(")
    pasta.base.formatting.set(wrapped, "suffix", ")")

  positional = node.args
  converted = [
      ast.keyword(arg="scale", value=pasta.parse("1.0")),
      ast.keyword(arg="mode", value=pasta.parse("\"fan_avg\"")),
  ]
  if positional:
    converted.append(ast.keyword(
        arg="distribution", value=_uniform_to_distribution(positional[0])))
  elif not distribution_given:
    # Parse with pasta instead of ast to avoid emitting a spurious trailing \n.
    converted.append(
        ast.keyword(arg="distribution", value=pasta.parse("\"uniform\"")))
  # The 2nd and 3rd positional arguments (if any) are seed and dtype.
  converted.extend(
      ast.keyword(arg=kwarg_name, value=arg_value)
      for kwarg_name, arg_value in zip(("seed", "dtype"), positional[1:3]))
  node.args = []
  node.keywords = converted + node.keywords

  old_module = node.func.value
  new_module = ast_edits.full_name_node("tf.compat.v1.keras.initializers")
  new_module.lineno = old_module.lineno
  new_module.col_offset = old_module.col_offset
  node.func.value = new_module
  node.func.attr = "VarianceScaling"

  logs.append((ast_edits.INFO, node.lineno, node.col_offset,
               "Changing tf.contrib.layers xavier initializer"
               " to a tf.compat.v1.keras.initializers.VarianceScaling and"
               " converting arguments.\n"))

  return node
2317
2318
def _contrib_layers_variance_scaling_initializer_transformer(
    parent, node, full_name, name, logs):
  """Rewrites contrib.layers.variance_scaling_initializer as VarianceScaling.

  Transforms:
  tf.contrib.layers.variance_scaling_initializer(
    factor, mode, uniform, seed, dtype
  ) to
  tf.compat.v1.keras.initializers.VarianceScaling(
      scale=factor, mode=mode.lower(),
      distribution=("uniform" if uniform else "truncated_normal"),
      seed=seed, dtype=dtype)

  Positional arguments stay positional; only the values of the 2nd (mode) and
  3rd (uniform) ones are rewritten. When no factor is supplied at all,
  scale=2.0 is prepended to match contrib's default instead of
  tf.keras.initializer's default of 1.0.

  Args:
    parent: Parent of `node` in the AST; not used.
    node: The ast.Call to rewrite; it is modified in place.
    full_name: Fully qualified name of the called symbol; not used.
    name: Short name of the called symbol; not used.
    logs: List to which an INFO entry describing the change is appended.

  Returns:
    `node`, after modification.
  """
  def _wrap_distribution(owner, uniform_expr):
    """Swaps uniform_expr in owner for ("uniform" if (...) else ...)."""
    wrapped = pasta.parse(
        "\"uniform\" if old_value else \"truncated_normal\"")
    conditional = wrapped.body[0].value
    pasta.ast_utils.replace_child(conditional, conditional.test, uniform_expr)
    # `owner` still holds uniform_expr here, so it can now be swapped out.
    pasta.ast_utils.replace_child(owner, uniform_expr, wrapped)
    pasta.base.formatting.set(wrapped, "prefix", "(")
    pasta.base.formatting.set(wrapped, "suffix", ")")

  def _lowercase_mode(owner, mode_expr):
    """Swaps mode_expr in owner for (mode_expr).lower()."""
    lowered = pasta.parse("mode.lower()")
    lower_attr = lowered.body[0].value.func
    pasta.ast_utils.replace_child(lower_attr, lower_attr.value, mode_expr)
    # replace_child copies mode_expr's prefix and suffix onto `lowered`.
    pasta.ast_utils.replace_child(owner, mode_expr, lowered)
    # Parenthesize mode_expr itself; the old prefix/suffix should only
    # surround `lowered`.
    pasta.base.formatting.set(mode_expr, "prefix", "(")
    pasta.base.formatting.set(mode_expr, "suffix", ")")

  # slim and keras use different default scales, so track whether the caller
  # supplied one.
  has_scale = False
  for kwarg in node.keywords:
    if kwarg.arg == "factor":
      kwarg.arg = "scale"
      has_scale = True
    elif kwarg.arg == "mode":
      _lowercase_mode(kwarg, kwarg.value)
    elif kwarg.arg == "uniform":
      kwarg.arg = "distribution"
      _wrap_distribution(kwarg, kwarg.value)

  positional = node.args
  if positional:
    has_scale = True
  if len(positional) > 1:
    _lowercase_mode(node, positional[1])
  if len(positional) > 2:
    _wrap_distribution(node, positional[2])

  if not has_scale:
    # Parse with pasta instead of ast to avoid emitting a spurious trailing \n.
    default_scale = ast.keyword(arg="scale", value=pasta.parse("2.0"))
    node.keywords = [default_scale] + node.keywords

  old_module = node.func.value
  new_module = ast_edits.full_name_node("tf.compat.v1.keras.initializers")
  new_module.lineno = old_module.lineno
  new_module.col_offset = old_module.col_offset
  node.func.value = new_module
  node.func.attr = "VarianceScaling"

  logs.append((ast_edits.INFO, node.lineno, node.col_offset,
               "Changing tf.contrib.layers.variance_scaling_initializer"
               " to a tf.compat.v1.keras.initializers.VarianceScaling and"
               " converting arguments.\n"))

  return node
2403
2404
def _contrib_layers_l1_regularizer_transformer(
    parent, node, full_name, name, logs):
  """Replaces tf.contrib.layers.l1_regularizer with tf.keras.regularizers.l1.

  The `scale` keyword is renamed to `l`. Any scope argument is dropped,
  whether it is the `scope` keyword or any positional argument after the
  first.

  Args:
    parent: Parent of `node` in the AST; not used.
    node: The ast.Call to rewrite; it is modified in place.
    full_name: Fully qualified name of the called symbol; not used.
    name: Short name of the called symbol; not used.
    logs: List to which INFO entries describing the changes are appended.

  Returns:
    `node`, after modification.
  """
  drop_scope_message = (
      "Dropping scope arg from tf.contrib.layers.l1_regularizer,"
      " because it is unsupported in tf.keras.regularizers.l1\n")

  scope_kwarg = None
  for kwarg in node.keywords:
    if kwarg.arg == "scale":
      logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                   "Renaming scale arg of regularizer\n"))
      kwarg.arg = "l"
    elif kwarg.arg == "scope":
      scope_kwarg = kwarg

  if scope_kwarg is not None:
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 drop_scope_message))
    node.keywords.remove(scope_kwarg)
  if len(node.args) > 1:
    node.args = node.args[:1]
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 drop_scope_message))

  old_module = node.func.value
  new_module = ast_edits.full_name_node("tf.keras.regularizers")
  new_module.lineno = old_module.lineno
  new_module.col_offset = old_module.col_offset
  node.func.value = new_module
  node.func.attr = "l1"

  return node
2442
2443
def _contrib_layers_l2_regularizer_transformer(
    parent, node, full_name, name, logs):
  """Replace slim l2 regularizer with Keras one, with l=0.5*scale.

  The scale (the `scale` keyword or the first positional argument) is
  rewritten as `0.5 * (scale)`, and the keyword is renamed to `l`. Any scope
  argument is dropped, whether it is the `scope` keyword or any positional
  argument after the first.

  Args:
    parent: Parent of `node` in the AST; not used.
    node: The ast.Call to rewrite; it is modified in place.
    full_name: Fully qualified name of the called symbol; not used.
    name: Short name of the called symbol; not used.
    logs: List to which INFO entries describing the changes are appended.

  Returns:
    `node`, after modification.
  """
  def _replace_scale_node(parent, old_value):
    """Replaces old_value with 0.5*(old_value)."""
    # Build the literal by parsing instead of calling ast.Num. ast.Num has
    # been a deprecated alias of ast.Constant since Python 3.8 and was removed
    # in Python 3.14. Parsing yields whatever node type the running
    # interpreter uses for numbers.
    half = ast.parse("0.5", mode="eval").body
    half.lineno = 0
    half.col_offset = 0
    new_value = ast.BinOp(left=half, op=ast.Mult(),
                          right=old_value)
    # This copies the prefix and suffix on old_value to new_value.
    pasta.ast_utils.replace_child(parent, old_value, new_value)

    # Put parentheses around old_value (and remove the old prefix/suffix,
    # they should only be around new_value).
    pasta.base.formatting.set(old_value, "prefix", "(")
    pasta.base.formatting.set(old_value, "suffix", ")")

  # Check if we have a scale or scope keyword arg
  scope_keyword = None
  for keyword in node.keywords:
    if keyword.arg == "scale":
      keyword.arg = "l"
      _replace_scale_node(keyword, keyword.value)
    if keyword.arg == "scope":
      scope_keyword = keyword

  # Maybe it was a positional arg
  if len(node.args) >= 1:
    _replace_scale_node(node, node.args[0])

  # Remove the scope keyword or arg if it is present
  if scope_keyword is not None:
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "Dropping scope arg from tf.contrib.layers.l2_regularizer,"
                 " because it is unsupported in tf.keras.regularizers.l2\n"))
    node.keywords.remove(scope_keyword)
  if len(node.args) > 1:
    node.args = node.args[:1]
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "Dropping scope arg from tf.contrib.layers.l2_regularizer,"
                 " because it is unsupported in tf.keras.regularizers.l2\n"))

  logs.append((ast_edits.INFO, node.lineno, node.col_offset,
               "Multiplying scale arg of tf.contrib.layers.l2_regularizer"
               " by half to what tf.keras.regularizers.l2 expects.\n"))

  lineno = node.func.value.lineno
  col_offset = node.func.value.col_offset
  node.func.value = ast_edits.full_name_node("tf.keras.regularizers")
  node.func.value.lineno = lineno
  node.func.value.col_offset = col_offset
  node.func.attr = "l2"

  return node
2502
2503
def _name_scope_transformer(parent, node, full_name, name, logs):
  """Fix name scope invocation to use 'default_name' and omit 'values' args.

  - If a non-None `name` is passed, the call is renamed to
    tf.compat.v1.name_scope, because the v2 name_scope does not support
    re-entering scopes by name.
  - Otherwise, if `default_name` is passed, the call becomes
    name_scope(name=default_name) and every other argument (including
    `values`) is dropped.
  - Otherwise an ERROR is logged and the call is left unchanged.

  Args:
    parent: Parent of `node` in the AST; not used.
    node: The ast.Call to name_scope.
    full_name: Fully qualified name of the called symbol, used in logs.
    name: Short name of the called symbol; not used.
    logs: List to which log tuples are appended.

  Returns:
    `node` if it was modified, else None.
  """
  # Distinct local names so the `name` parameter is not shadowed by the AST
  # value of the call's `name` argument.
  name_found, name_value = ast_edits.get_arg_value(node, "name", 0)
  default_found, default_name = ast_edits.get_arg_value(node, "default_name", 1)

  # If an actual name was given...
  if name_found and pasta.dump(name_value) != "None":
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "`name` passed to `name_scope`. Because you may be re-entering"
                 " an existing scope, it is not safe to convert automatically;"
                 " the v2 name_scope does not support re-entering scopes by"
                 " name.\n"))
    # Rename to compat.v1
    new_name = "tf.compat.v1.name_scope"
    logs.append((ast_edits.INFO, node.func.lineno, node.func.col_offset,
                 "Renamed %r to %r" % (full_name, new_name)))
    new_name_node = ast_edits.full_name_node(new_name, node.func.ctx)
    ast.copy_location(new_name_node, node.func)
    pasta.ast_utils.replace_child(node, node.func, new_name_node)
    return node

  if default_found:
    # New name scope doesn't have name, but it has a default name. We use
    # name=default_name, and values can be dropped (it's only for
    # error reporting and useless outside of graph mode).
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "Using default_name as name in call to name_scope.\n"))
    # Remove all args other than name
    node.args = []
    node.keywords = [ast.keyword(arg="name", value=default_name)]
    return node

  logs.append((ast_edits.ERROR, node.lineno, node.col_offset,
               "name_scope call with neither name nor default_name cannot be "
               "converted properly."))
  return None
2540
2541
def _rename_to_compat_v1(node, full_name, logs, reason):
  """Renames the called `tf.` symbol to its `tf.compat.v1.` counterpart.

  Only the first occurrence of "tf." in `full_name` is rewritten.

  Args:
    node: The ast.Call whose callee is replaced.
    full_name: Fully qualified name of the called symbol.
    logs: List to which an INFO entry mentioning `reason` is appended.
    reason: Explanation included in the log entry.

  Returns:
    Whatever _rename_func returns (the modified `node`).
  """
  compat_name = six.ensure_str(full_name).replace("tf.", "tf.compat.v1.", 1)
  return _rename_func(node=node, full_name=full_name, new_name=compat_name,
                      logs=logs, reason=reason)
2545
2546
def _rename_func(node, full_name, new_name, logs, reason):
  """Replaces the callee of `node` with `new_name` and logs why.

  Args:
    node: The ast.Call whose `func` is replaced; modified in place.
    full_name: Previous fully qualified name, used in the log message.
    new_name: Dotted name of the replacement callee.
    logs: List to which an INFO entry is appended.
    reason: Explanation appended to the log message.

  Returns:
    `node`.
  """
  message = "Renamed %r to %r: %s" % (full_name, new_name, reason)
  logs.append((ast_edits.INFO, node.lineno, node.col_offset, message))
  old_func = node.func
  # copy_location returns the node it was given, so build and position the
  # replacement in one expression.
  replacement = ast.copy_location(
      ast_edits.full_name_node(new_name, old_func.ctx), old_func)
  pasta.ast_utils.replace_child(node, old_func, replacement)
  return node
2554
2555
def _string_split_transformer(parent, node, full_name, name, logs):
  """Update tf.string_split arguments: skip_empty, sep, result_type, source.

  The call is renamed to tf.compat.v1 when v2 semantics cannot be shown to
  match. That is the case when skip_empty is present but not a literal false,
  when sep is missing, or when sep is not a string literal. A literal empty
  sep renames the call to tf.strings.bytes_split. In every non-compat case,
  result_type and source are then handled by _string_split_rtype_transformer.

  Args:
    parent: Parent of `node` in the AST.
    node: The ast.Call to tf.string_split; modified in place.
    full_name: Fully qualified name of the called symbol.
    name: Short name of the called symbol.
    logs: List to which log tuples are appended.

  Returns:
    The node produced by the compat.v1 rename or by
    _string_split_rtype_transformer.
  """
  # Check the skip_empty parameter: if not false, then use compat.v1.
  # NOTE(review): an omitted skip_empty is not inspected here; confirm that
  # v1's default for it matches the v2 behavior.
  for i, kw in enumerate(node.keywords):
    if kw.arg != "skip_empty":
      continue
    if not _is_ast_false(kw.value):
      return _rename_to_compat_v1(
          node, full_name, logs, "tf.string_split's replacement no longer "
          "takes the skip_empty argument.")
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "removed argument skip_empty for tf.string_split."))
    node.keywords.pop(i)
    break

  # Check the sep parameter: if it's definitely an empty string, use
  # tf.strings.bytes_split().  If we can't tell, then use compat.v1.
  found_sep = False
  for i, kw in enumerate(node.keywords):
    if kw.arg != "sep":
      continue
    found_sep = True
    if not isinstance(kw.value, ast.Str):
      return _rename_to_compat_v1(
          node, full_name, logs,
          "The semantics for tf.string_split's sep parameter have changed "
          "when sep is the empty string; but sep is not a string literal, "
          "so we can't tell if it's an empty string.")
    if kw.value.s == "":
      node = _rename_func(
          node, full_name, "tf.strings.bytes_split", logs,
          "Splitting bytes is now handled by tf.strings.bytes_split().")
      node.keywords.pop(i)
    # A call has at most one `sep` keyword, and the list may just have been
    # mutated, so stop iterating.
    break
  if not found_sep:
    return _rename_to_compat_v1(
        node, full_name, logs,
        "The semantics for tf.string_split's sep parameter have changed "
        "when sep unspecified: it now splits on all whitespace, not just "
        "the space character.")
  # Check the result_type parameter
  return _string_split_rtype_transformer(parent, node, full_name, name, logs)
2597
2598
def _string_split_rtype_transformer(parent, node, full_name, name, logs):
  """Update tf.strings.split arguments: result_type, source.

  A literal result_type of "RaggedTensor" or "SparseTensor" is removed. Any
  other result_type makes the call fall back to tf.compat.v1. The `source`
  keyword is renamed to `input`. Unless the requested result was
  "RaggedTensor", the call is wrapped as `(...).to_sparse()`.

  Args:
    parent: Parent of `node` in the AST, used to avoid re-wrapping a call that
      is already the receiver of `.to_sparse`.
    node: The ast.Call to rewrite; modified in place.
    full_name: Fully qualified name of the called symbol.
    name: Short name of the called symbol, used in logs when full_name is
      empty.
    logs: List to which log tuples are appended.

  Returns:
    The rewritten node (possibly a new `.to_sparse()` call), the compat.v1
    renamed node, or None when the call already sits under `.to_sparse`.
  """
  wrap_with_to_sparse = True
  result_type_entry = next(
      ((idx, kw) for idx, kw in enumerate(node.keywords)
       if kw.arg == "result_type"),
      None)
  if result_type_entry is not None:
    idx, kw = result_type_entry
    value = kw.value
    if not (isinstance(value, ast.Str) and
            value.s in ("RaggedTensor", "SparseTensor")):
      return _rename_to_compat_v1(
          node, full_name, logs,
          "%s no longer takes the result_type parameter." % full_name)
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "Removed argument result_type=%r for function %s" %
                 (value.s, full_name or name)))
    node.keywords.pop(idx)
    wrap_with_to_sparse = value.s != "RaggedTensor"

  for kw in node.keywords:
    if kw.arg == "source":
      kw.arg = "input"

  if not wrap_with_to_sparse:
    return node

  # Child nodes are transformed too, so the copy placed under `.to_sparse`
  # would be visited again; returning None here prevents infinite recursion.
  if isinstance(parent, ast.Attribute) and parent.attr == "to_sparse":
    return None

  logs.append(
      (ast_edits.INFO, node.lineno, node.col_offset,
       "Adding call to RaggedTensor.to_sparse() to result of strings.split, "
       "since it now returns a RaggedTensor."))
  to_sparse_attr = ast.Attribute(value=copy.deepcopy(node), attr="to_sparse")
  try:
    return ast.Call(to_sparse_attr, [], [])
  except TypeError:
    # Fallback for ast versions whose Call also takes starargs and kwargs.
    return ast.Call(to_sparse_attr, [], [], None, None)
2639