# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf 2.0 upgrader."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import os
import tempfile

from absl.testing import parameterized
import six
import tensorflow.compat.v1 as tf
# OSS TF V2 import placeholder.

from tensorflow.python.framework import test_util
from tensorflow.python.platform import test as test_lib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade_v2


def get_symbol_for_name(root, name):
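  """Return the symbol for a dotted name, e.g. "tf.foo.bar", under `root`."""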
  name_parts = six.ensure_str(name).split(".")
  symbol = root
  # Iterate starting with the second item, since the first item is just "tf".
  for part in name_parts[1:]:
    symbol = getattr(symbol, part)
  return symbol


def get_args(symbol):
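  """Return the names of `symbol`'s positional-or-keyword arguments."""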
  if hasattr(inspect, "signature"):
    signature = inspect.signature(symbol)
    # Ignore *args and **kwargs for now.
    return [param.name for param in signature.parameters.values()
            if param.kind == param.POSITIONAL_OR_KEYWORD]
  return tf_inspect.getargspec(symbol)[0]


def get_func_and_args_from_str(call_str):
  """Parse call string to get function and argument names.

  Args:
    call_str: Call string must be in the form:
              `tf.foo(arg1=val1, arg2=val2, ...)`.

  Returns:
    (function_name, list of arg names) tuple.
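
  Example (illustrative only, not executed by this test):
    get_func_and_args_from_str("tf.foo(arg1=val1, arg2=val2)")
    # -> ("tf.foo", ["arg1", "arg2"])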
69  """
70  open_paren_index = six.ensure_str(call_str).find("(")
71  close_paren_index = call_str.rfind(")")
72
73  function_name = call_str[:six.ensure_str(call_str).find("(")]
74  args = six.ensure_str(call_str[open_paren_index +
75                                 1:close_paren_index]).split(",")
76  args = [six.ensure_str(arg).split("=")[0].strip() for arg in args]
77  args = [arg for arg in args if arg]  # filter out empty strings
78  return function_name, args


class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Test various APIs that have been changed in 2.0.

  We also test whether a converted file is executable. test_file_v1_10.py
  aims to exhaustively test that API changes are convertible and actually
  work when run with current TensorFlow.
  """

  @classmethod
  def setUpClass(cls):
    super(TestUpgrade, cls).setUpClass()
    cls.v2_symbols = {}
    cls.v1_symbols = {}
    if hasattr(tf.compat, "v2"):

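      # Record the attribute exported under each "tf.*" name in the v2 API.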
      def symbol_collector(unused_path, unused_parent, children):
        for child in children:
          _, attr = tf_decorator.unwrap(child[1])
          api_names_v2 = tf_export.get_v2_names(attr)
          for name in api_names_v2:
            cls.v2_symbols["tf." + six.ensure_str(name)] = attr

      visitor = public_api.PublicAPIVisitor(symbol_collector)
      visitor.private_map["tf.compat"] = ["v1", "v2"]
      traverse.traverse(tf.compat.v2, visitor)

    if hasattr(tf.compat, "v1"):

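      # Likewise, record each "tf.*" name exported in the v1 API.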
      def symbol_collector_v1(unused_path, unused_parent, children):
        for child in children:
          _, attr = tf_decorator.unwrap(child[1])
          api_names_v1 = tf_export.get_v1_names(attr)
          for name in api_names_v1:
            cls.v1_symbols["tf." + six.ensure_str(name)] = attr

      visitor = public_api.PublicAPIVisitor(symbol_collector_v1)
      visitor.private_map["tf.compat"] = ["v1", "v2"]
      traverse.traverse(tf.compat.v1, visitor)

  def _upgrade(self,
               old_file_text,
               import_rename=False,
               upgrade_compat_v1_import=False):
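    """Run the v2 upgrader on `old_file_text`.

    Returns:
      A (count, report, errors, new_text) tuple.
    """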
    in_file = six.StringIO(old_file_text)
    out_file = six.StringIO()
    upgrader = ast_edits.ASTCodeUpgrader(
        tf_upgrade_v2.TFAPIChangeSpec(
            import_rename, upgrade_compat_v1_import=upgrade_compat_v1_import))
    count, report, errors = (
        upgrader.process_opened_file("test.py", in_file,
                                     "test_out.py", out_file))
    return count, report, errors, out_file.getvalue()

  def _upgrade_multiple(self, old_file_texts):
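    """Run the v2 upgrader on each text, reusing a single upgrader."""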
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
    results = []
    for old_file_text in old_file_texts:
      in_file = six.StringIO(old_file_text)
      out_file = six.StringIO()
      count, report, errors = (
          upgrader.process_opened_file("test.py", in_file,
                                       "test_out.py", out_file))
      results.append([count, report, errors, out_file.getvalue()])
    return results

  def testParseError(self):
    _, report, unused_errors, unused_new_text = self._upgrade(
        "import tensorflow as tf\na + \n")
    self.assertNotEqual(six.ensure_str(report).find("Failed to parse"), -1)

  def testReport(self):
    text = "tf.angle(a)\n"
    _, report, unused_errors, unused_new_text = self._upgrade(text)
    # This is not a complete test, but it is a sanity check that the report
    # contains the expected information. (The previous assertTrue over
    # str.find() could never fail, since find() returns the truthy -1 when
    # the substring is absent.)
    self.assertIn(
        "Renamed function `tf.angle` to `tf.math.angle`",
        six.ensure_str(report))

  def testRename(self):
    text = "tf.conj(a)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.math.conj(a)\n")
    text = "tf.rsqrt(tf.log_sigmoid(3.8))\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log_sigmoid(3.8))\n")

  def testAllAPI(self):
    if not hasattr(tf.compat, "v2"):
      return

    # Converts all symbols in the v1 namespace to the v2 namespace, raising
    # an error if the target of the conversion is not in the v2 namespace.
    # Please regenerate the renames file or edit any manual renames if this
    # test fails.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        api_names = tf_export.get_v1_names(attr)
        for name in api_names:
          _, _, _, text = self._upgrade("tf." + six.ensure_str(name))
          if (text and
              not text.startswith("tf.compat.v1") and
              not text.startswith("tf.compat.v2") and
              text not in self.v2_symbols and
              # Ignore any symbol that contains __internal__.
              "__internal__" not in text and
              # Builds currently install an old version of estimator that
              # doesn't have some 2.0 symbols.
              not text.startswith("tf.estimator")):
            self.fail(
                "Symbol %s generated from %s not in v2 API" % (
                    text, name))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)

  def testAllAPIV1(self):
    collect = True
    v1_symbols = set([])

    # Converts all symbols in the v1 namespace to the v2 namespace, raising
    # an error if the target of the conversion is not in the v1 namespace.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        api_names = tf_export.get_v1_names(attr)
        for name in api_names:
          if collect:
            v1_symbols.add("tf." + six.ensure_str(name))
          else:
            _, _, _, text = self._upgrade("tf." + six.ensure_str(name))
            if (text and
                not text.startswith("tf.compat.v1") and
                not text.startswith("tf.compat.v2") and
                not text.startswith("tf.estimator") and
                text not in v1_symbols):
              self.fail(
                  "Symbol %s generated from %s not in v1 API" % (
                      text, name))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
    collect = False
    traverse.traverse(tf.compat.v1, visitor)

  def testV1KeywordArgNames(self):
    all_keyword_renames = (
        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)

    # Visitor that verifies V1 argument names.
    def arg_test_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        names_v1 = tf_export.get_v1_names(attr)

        for name in names_v1:
          name = "tf.%s" % name
          if name not in all_keyword_renames:
            continue
          arg_names_v1 = tf_inspect.getargspec(attr)[0]
          keyword_renames = all_keyword_renames[name]
          self.assertEqual(type(keyword_renames), dict)

          # Assert that v1 function has valid v1 argument names.
          for from_name, _ in keyword_renames.items():
            self.assertIn(
                from_name, arg_names_v1,
                "%s not found in %s arguments: %s" %
                (from_name, name, str(arg_names_v1)))

    visitor = public_api.PublicAPIVisitor(arg_test_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)

  def testV2KeywordArgNames(self):
    # This test converts a call of the form
    #   tf.foo(arg1=0, arg2=1, ...)
    # to 2.0, then checks that the converted function has valid argument names.
    if not hasattr(tf.compat, "v2"):
      return
    v2_arg_exceptions = {
        "verify_shape_is_now_always_true",
        # These arguments should not be used; they just specify
        # that a function takes named arguments.
        "keyword_required",
        "_sentinel",
    }
    v1_name_exceptions = {
        "tf.print",  # requires print_function import
    }
    function_warnings = (
        tf_upgrade_v2.TFAPIChangeSpec().function_warnings)
    function_transformers = (
        tf_upgrade_v2.TFAPIChangeSpec().function_transformers)
    keyword_renames = (
        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)

    # Visitor that converts to V2 and checks V2 argument names.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        if not tf_inspect.isfunction(attr):
          continue
        names_v1 = tf_export.get_v1_names(attr)
        arg_names_v1 = get_args(attr)

        for name in names_v1:
          tf_name = "tf.%s" % name
          if tf_name in function_warnings or tf_name in function_transformers:
            continue  # These require manual changes.
          if tf_name in v1_name_exceptions:
            continue
          # Assert that arg names after converting to v2 are present in the
          # v2 function.
          # 1. First, create an input of the form:
          #    tf.foo(arg1=val1, arg2=val2, ...)
          args = ",".join(
              ["%s=%d" % (from_name, from_index)
               for from_index, from_name in enumerate(arg_names_v1)])
          text_input = "%s(%s)" % (tf_name, args)
          # 2. Convert the input to V2.
          _, _, _, text = self._upgrade(text_input)
          new_function_name, new_args = get_func_and_args_from_str(text)
          if "__internal__" in new_function_name:
            # Skip the tf.__internal__ and tf.keras.__internal__ APIs.
            continue
          if new_function_name == "tf.compat.v1.%s" % name:
            if tf_name in keyword_renames:
              # If we rename arguments, the new function must be available in
              # 2.0. We should not be using compat.v1 in this case.
              self.fail(
                  "Function '%s' is not in 2.0 when converting\n%s\nto\n%s" %
                  (new_function_name, text_input, text))
            continue
          if new_function_name.startswith("tf.compat.v2"):
            self.assertIn(new_function_name.replace("tf.compat.v2.", "tf."),
                          self.v2_symbols)
            continue
          # 3. Verify the V2 function and arguments.
          args_v2 = get_args(self.v2_symbols[new_function_name])
          args_v2.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v2,
                "Invalid argument '%s' in 2.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v2)))
          # 4. Verify that the argument exists in v1 as well.
          if new_function_name in set(["tf.nn.ctc_loss",
                                       "tf.saved_model.save"]):
            continue
          args_v1 = get_args(self.v1_symbols[new_function_name])
          args_v1.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v1,
                "Invalid argument '%s' in 1.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v1)))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)

  def testPositionsMatchArgGiven(self):
    full_dict = tf_upgrade_v2.TFAPIChangeSpec().function_arg_warnings
    method_names = list(full_dict.keys())
    for method_name in method_names:
      args = list(full_dict[method_name].keys())
      if "contrib" in method_name:
        # Skip descending and fetching contrib methods during test. These are
        # not available in the repo anymore.
        continue
      elif six.ensure_str(method_name).startswith("*."):
        # Special case for optimizer methods.
        method = six.ensure_str(method_name).replace("*", "tf.train.Optimizer")
      else:
        method = method_name

      method = get_symbol_for_name(tf, method)
      arg_spec = tf_inspect.getfullargspec(method)
      for (arg, pos) in args:
        # Shift the position to account for the `self` argument on methods.
        if six.ensure_str(method_name).startswith("*."):
          pos += 1
        self.assertEqual(arg_spec[0][pos], arg)

  def testReorderFileNeedsUpdate(self):
    reordered_function_names = (
        tf_upgrade_v2.TFAPIChangeSpec().reordered_function_names)
    function_reorders = (
        tf_upgrade_v2.TFAPIChangeSpec().function_reorders)
    manual_function_reorders = (
        tf_upgrade_v2.TFAPIChangeSpec().manual_function_reorders)

    added_names_message = """Some function names in
self.reordered_function_names are not in reorders_v2.py.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    removed_names_message = """%s in self.reorders_v2 does not match
any name in self.reordered_function_names.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    self.assertTrue(
        reordered_function_names.issubset(function_reorders),
        added_names_message)
    # function_reorders should contain reordered_function_names
    # and their TensorFlow V1 aliases.
    for name in function_reorders:
      if name in manual_function_reorders:
        continue
      # Get the other names for this function.
      attr = get_symbol_for_name(tf.compat.v1, name)
      _, attr = tf_decorator.unwrap(attr)
      v1_names = tf_export.get_v1_names(attr)
      self.assertTrue(v1_names)
      v1_names = ["tf.%s" % n for n in v1_names]
      # Check that at least one of the other names is in
      # reordered_function_names.
      self.assertTrue(
          any(n in reordered_function_names for n in v1_names),
          removed_names_message % name)

  def testRenameConstant(self):
    text = "tf.MONOLITHIC_BUILD\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.sysconfig.MONOLITHIC_BUILD\n")
    text = "some_call(tf.MONOLITHIC_BUILD)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "some_call(tf.sysconfig.MONOLITHIC_BUILD)\n")

  def testRenameArgs(self):
    text = ("tf.nn.pool(input_a, window_shape_a, pooling_type_a, padding_a, "
            "dilation_rate_a, strides_a, name_a, data_format_a)\n")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text,
                     ("tf.nn.pool(input=input_a, window_shape=window_shape_a,"
                      " pooling_type=pooling_type_a, padding=padding_a, "
                      "dilations=dilation_rate_a, strides=strides_a, "
                      "name=name_a, data_format=data_format_a)\n"))

  def testReorder(self):
    text = "tf.boolean_mask(a, b, c, d)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text,
                     "tf.boolean_mask(tensor=a, mask=b, name=c, axis=d)\n")

  def testLearningRateDecay(self):
    for decay in ["tf.train.exponential_decay",
                  "tf.train.polynomial_decay", "tf.train.natural_exp_decay",
                  "tf.train.inverse_time_decay", "tf.train.cosine_decay",
                  "tf.train.cosine_decay_restarts",
                  "tf.train.linear_cosine_decay",
                  "tf.train.noisy_linear_cosine_decay",
                  "tf.train.piecewise_constant_decay",
                 ]:

      text = "%s(a, b)\n" % decay
      _, report, unused_errors, _ = self._upgrade(text)
      self.assertIn("switch to the schedules in "
                    "`tf.keras.optimizers.schedules`", report)

  def verify_compat_v1_rename_correctness(self, values, ns_prefix=""):
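    """Assert that each value is renamed to its tf.compat.v1 equivalent."""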
    if ns_prefix:
      ns_prefix += "."
    for v in values:
      text = "tf." + ns_prefix + v + "(a, b)"
      _, _, _, new_text = self._upgrade(text)
      self.assertEqual("tf.compat.v1." + ns_prefix + v + "(a, b)", new_text)

  def testInitializers(self):
    initializers = [
        "zeros",
        "ones",
        "constant",
        "random_uniform",
        "random_normal",
        "truncated_normal",
        "variance_scaling",
        "orthogonal",
        "glorot_uniform",
        "glorot_normal",
        "identity",
        "lecun_normal",
        "lecun_uniform",
        "he_normal",
        "he_uniform",
    ]
    self.verify_compat_v1_rename_correctness(
        initializers, ns_prefix="initializers")

    initializers = [
        "zeros_initializer",
        "ones_initializer",
        "constant_initializer",
        "random_uniform_initializer",
        "random_normal_initializer",
        "truncated_normal_initializer",
        "variance_scaling_initializer",
        "orthogonal_initializer",
        "glorot_uniform_initializer",
        "glorot_normal_initializer",
    ]
    self.verify_compat_v1_rename_correctness(initializers)

    initializers = [
        "zeros",
        "ones",
        "Ones",
        "Zeros",
        "constant",
        "Constant",
        "VarianceScaling",
        "Orthogonal",
        "orthogonal",
        "Identity",
        "identity",
        "glorot_uniform",
        "glorot_normal",
        "lecun_normal",
        "lecun_uniform",
        "he_normal",
        "he_uniform",
        "TruncatedNormal",
        "truncated_normal",
        "RandomUniform",
        "uniform",
        "random_uniform",
        "RandomNormal",
        "normal",
        "random_normal",
    ]
    self.verify_compat_v1_rename_correctness(
        initializers, ns_prefix="keras.initializers")

  def testContribXavierInitializer(self):
    for contrib_alias in ["tf.contrib.", "contrib_"]:
      text = contrib_alias + "layers.xavier_initializer()\n"
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(
          new_text,
          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
          "mode=\"fan_avg\", "
          "distribution=\"uniform\")\n",
      )

      text = "slim.xavier_initializer(True or False)\n"
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(
          new_text,
          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
          "mode=\"fan_avg\", "
          "distribution=(\"uniform\" if True or False else "
          "\"truncated_normal\"))\n",
      )

      text = "slim.xavier_initializer(uniform=(True or False))\n"
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(
          new_text,
          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
          "mode=\"fan_avg\", "
          "distribution=(\"uniform\" if True or False else "
          "\"truncated_normal\"))\n",
      )

      text = contrib_alias + "layers.xavier_initializer_conv2d(False, 12)\n"
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(
          new_text,
          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
          "mode=\"fan_avg\", "
          "distribution=(\"uniform\" if False else \"truncated_normal\"), "
          "seed=12)\n",
      )

      text = (contrib_alias + "layers.xavier_initializer_conv2d("
              "False, 12, tf.float32)\n")
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(
          new_text,
          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
          "mode=\"fan_avg\", "
          "distribution=(\"uniform\" if False else \"truncated_normal\"), "
          "seed=12, "
          "dtype=tf.float32)\n",
      )

      text = (contrib_alias + "layers.xavier_initializer("
              "False, 12, dtypes=tf.float32)\n")
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(
          new_text,
          "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, "
          "mode=\"fan_avg\", "
          "distribution=(\"uniform\" if False else \"truncated_normal\"), "
          "seed=12, "
          "dtypes=tf.float32)\n",
      )

  def testVarianceScalingInitializer(self):
    text = ("tf.contrib.layers.variance_scaling_initializer("
            "mode=(\"FAN\" + \"_AVG\"))\n")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.compat.v1.keras.initializers.VarianceScaling(scale=2.0, "
        "mode=(\"FAN\" + \"_AVG\").lower())\n",
    )

    text = ("slim.variance_scaling_initializer("
            "uniform=(True or False), mode=(\"FAN\" + \"_AVG\"))\n")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.compat.v1.keras.initializers.VarianceScaling(scale=2.0, "
        "distribution=(\"uniform\" if True or False else \"truncated_normal\"),"
        " mode=(\"FAN\" + \"_AVG\").lower())\n",
    )

    text = "tf.contrib.layers.variance_scaling_initializer(factor=1.0)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0)\n",
    )

    text = ("tf.contrib.layers.variance_scaling_initializer("
            "12.0, \"FAN_AVG\", True, dtypes=tf.float32)\n")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.compat.v1.keras.initializers.VarianceScaling(12.0, "
        "(\"FAN_AVG\").lower(), "
        "(\"uniform\" if True else \"truncated_normal\"), "
        "dtypes=tf.float32)\n",
    )

  def testMetrics(self):
    metrics = [
        "accuracy",
        "auc",
        "average_precision_at_k",
        "false_negatives",
        "false_negatives_at_thresholds",
        "false_positives",
        "false_positives_at_thresholds",
        "mean",
        "mean_absolute_error",
        "mean_cosine_distance",
        "mean_iou",
        "mean_per_class_accuracy",
        "mean_relative_error",
        "mean_squared_error",
        "mean_tensor",
        "percentage_below",
        "precision",
        "precision_at_k",
        "precision_at_thresholds",
        "precision_at_top_k",
        "recall",
        "recall_at_k",
        "recall_at_thresholds",
        "recall_at_top_k",
        "root_mean_squared_error",
        "sensitivity_at_specificity",
        "sparse_average_precision_at_k",
        "sparse_precision_at_k",
        "specificity_at_sensitivity",
        "true_negatives",
        "true_negatives_at_thresholds",
        "true_positives",
        "true_positives_at_thresholds",
    ]
    for m in metrics:
      text = "tf.metrics." + m + "(a, b)"
      _, report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual("tf.compat.v1.metrics." + m + "(a, b)", new_text)
      self.assertIn(
          "tf.metrics have been replaced with object oriented versions", report)

  def testLosses(self):
    losses = [
        "absolute_difference",
        "add_loss",
        "compute_weighted_loss",
        "cosine_distance",
        "get_losses",
        "get_regularization_loss",
        "get_regularization_losses",
        "get_total_loss",
        "hinge_loss",
        "huber_loss",
        "log_loss",
        "mean_pairwise_squared_error",
        "mean_squared_error",
        "sigmoid_cross_entropy",
        "softmax_cross_entropy",
        "sparse_softmax_cross_entropy",
    ]
    for l in losses:
      text = "tf.losses." + l + "(a, b)"
      _, report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual("tf.compat.v1.losses." + l + "(a, b)", new_text)
      self.assertIn(
          "tf.losses have been replaced with object oriented versions", report)

  def testEstimatorLossReductionChange(self):
    classes = [
        "LinearClassifier", "LinearRegressor", "DNNLinearCombinedClassifier",
        "DNNLinearCombinedRegressor", "DNNRegressor", "DNNClassifier",
        "BaselineClassifier", "BaselineRegressor"
    ]
    for c in classes:
      ns = "tf.estimator." + c
      text = ns + "()"
      expected_text = ns + "(loss_reduction=tf.keras.losses.Reduction.SUM)"
      _, report, errors, new_text = self._upgrade(text)
      self.assertEqual(expected_text, new_text)

      text = ns + "(loss_reduction=TEST)"
      expected_text = ns + "(loss_reduction=TEST)"
      _, report, errors, new_text = self._upgrade(text)
      self.assertEqual(text, new_text)
    text = "tf.estimator.BaselineClassifier(m, c, w, v, o, c, lr)"
    expected_text = (
        "tf.compat.v1.estimator.BaselineClassifier("
        "model_dir=m, n_classes=c, weight_column=w, label_vocabulary=v, "
        "optimizer=o, config=c, loss_reduction=lr)")
    _, report, errors, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "tf.estimator.BaselineClassifier(model_dir=model_dir)"
    expected_text = ("tf.estimator.BaselineClassifier(" +
                     "model_dir=model_dir, "
                     "loss_reduction=tf.keras.losses.Reduction.SUM)")
    _, report, errors, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def testBaseEstimatorPartitioner(self):
    classes = ["LinearEstimator", "DNNLinearCombinedEstimator", "DNNEstimator"]
    for c in classes:
      ns = "tf.estimator." + c
      suffix = "(input_layer_partitioner=TEST)"
      text = ns + suffix
      expected_text = "tf.compat.v1.estimator." + c + suffix
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(new_text, expected_text)

  def testCannedEstimatorPartitioner(self):
    classes = [
        "LinearClassifier", "LinearRegressor", "DNNLinearCombinedClassifier",
        "DNNLinearCombinedRegressor", "DNNRegressor", "DNNClassifier"
    ]

    for c in classes:
      ns = "tf.estimator." + c
      suffix = "(input_layer_partitioner=TEST)"
      text = ns + suffix
      suffix = ("(input_layer_partitioner=TEST, "
                "loss_reduction=tf.keras.losses.Reduction.SUM)")
      expected_text = "tf.compat.v1.estimator." + c + suffix
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(new_text, expected_text)

  def testBaseEstimatorOptimizer(self):
    classes = ["BaselineEstimator", "LinearEstimator", "DNNEstimator"]
    for c in classes:
      ns = "tf.estimator." + c
      suffix = "(optimizer=TEST)"
      text = ns + suffix
      expected_text = "tf.compat.v1.estimator." + c + suffix
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(new_text, expected_text)

  def testDNNLinearCombinedEstimatorOptimizer(self):
    classes = ["DNNLinearCombinedEstimator"]
    for c in classes:
      ns = "tf.estimator." + c
      suffix = "(dnn_optimizer=TEST, linear_optimizer=Test)"
      text = ns + suffix
      expected_text = "tf.compat.v1.estimator." + c + suffix
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(new_text, expected_text)

  def testCannedEstimatorOptimizer(self):
    classes = [
        "BaselineClassifier", "BaselineRegressor", "LinearClassifier",
        "LinearRegressor", "DNNRegressor", "DNNClassifier"
    ]

    for c in classes:
      ns = "tf.estimator." + c
      suffix = "(optimizer=TEST)"
      text = ns + suffix
      suffix = ("(optimizer=TEST, "
                "loss_reduction=tf.keras.losses.Reduction.SUM)")
      expected_text = "tf.compat.v1.estimator." + c + suffix
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(new_text, expected_text)

  def testDNNLinearCombinedOptimizer(self):
    classes = [
        "DNNLinearCombinedClassifier",
        "DNNLinearCombinedRegressor",
    ]
    for c in classes:
      ns = "tf.estimator." + c
      suffix = "(dnn_optimizer=TEST, linear_optimizer=Test)"
      text = ns + suffix
      suffix = ("(dnn_optimizer=TEST, linear_optimizer=Test, "
                "loss_reduction=tf.keras.losses.Reduction.SUM)")
      expected_text = "tf.compat.v1.estimator." + c + suffix
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(new_text, expected_text)

  def testBaseEstimatorPartitionerAndOptimizer(self):
    classes = ["LinearEstimator", "DNNEstimator"]
    for c in classes:
      ns = "tf.estimator." + c
      suffix = "(input_layer_partitioner=TEST, optimizer=TEST)"
      text = ns + suffix
      expected_text = "tf.compat.v1.estimator." + c + suffix
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(new_text, expected_text)

  def testDNNLinearCombinedEstimatorPartitionerAndOptimizer(self):
    classes = ["DNNLinearCombinedEstimator"]
    for c in classes:
      ns = "tf.estimator." + c
      suffix = ("(input_layer_partitioner=TEST, dnn_optimizer=TEST, "
                "linear_optimizer=TEST)")
      text = ns + suffix
      expected_text = "tf.compat.v1.estimator." + c + suffix
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(new_text, expected_text)

  def testCannedEstimatorPartitionerAndOptimizer(self):
    classes = [
        "LinearClassifier", "LinearRegressor", "DNNRegressor", "DNNClassifier"
    ]

    for c in classes:
      ns = "tf.estimator." + c
      suffix = "(input_layer_partitioner=TEST, optimizer=TEST)"
      text = ns + suffix
      suffix = ("(input_layer_partitioner=TEST, optimizer=TEST, "
                "loss_reduction=tf.keras.losses.Reduction.SUM)")
      expected_text = "tf.compat.v1.estimator." + c + suffix
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(new_text, expected_text)

  def testDNNLinearCombinedPartitionerAndOptimizer(self):
    classes = [
        "DNNLinearCombinedClassifier",
        "DNNLinearCombinedRegressor",
    ]

    for c in classes:
      ns = "tf.estimator." + c
      suffix = ("(input_layer_partitioner=TEST, dnn_optimizer=TEST, "
                "linear_optimizer=TEST)")
      text = ns + suffix
      suffix = ("(input_layer_partitioner=TEST, dnn_optimizer=TEST, "
                "linear_optimizer=TEST, "
                "loss_reduction=tf.keras.losses.Reduction.SUM)")
      expected_text = "tf.compat.v1.estimator." + c + suffix
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(new_text, expected_text)

  def testExtractGlimpse(self):
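    # The upgrader rewrites the boolean `uniform_noise` argument into the
    # string-valued `noise` argument ('uniform' vs. 'gaussian').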
    text = ("tf.image.extract_glimpse(x, size, off, False, "
            "False, False, name=\"foo\")\n")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.image.extract_glimpse(x, size, off, False, "
        "False, 'uniform' if (False) else 'gaussian', name=\"foo\")\n",
    )

    text = ("tf.image.extract_glimpse(x, size, off, centered=False, "
            "normalized=False, uniform_noise=True if uniform_noise else "
            "False, name=\"foo\")\n")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.image.extract_glimpse(x, size, off, centered=False, "
        "normalized=False, noise='uniform' if (True if uniform_noise else "
        "False) else 'gaussian', name=\"foo\")\n",
    )

    text = ("tf.image.extract_glimpse(x,\n"
            "                         size,\n"
            "                         off,\n"
            "                         centered=True,\n"
            "                         normalized=True, # Stuff before\n"
            "                         uniform_noise=False,\n"
            "                         name=\"foo\")# Stuff after\n")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text, "tf.image.extract_glimpse(x,\n"
        "                         size,\n"
        "                         off,\n"
        "                         centered=True,\n"
        "                         normalized=True, # Stuff before\n"
        "                         noise='uniform' if (False) else 'gaussian',\n"
        "                         name=\"foo\")# Stuff after\n")

    text = "tf.image.extract_glimpse(x)\n"
    _, unused_report, errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, text)
    self.assertEqual(errors, [])

  def testDropout(self):
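    # The upgrader rewrites `keep_prob` into the complementary
    # `rate=1 - (keep_prob)` argument.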
    text = "tf.nn.dropout(x, keep_prob, name=\"foo\")\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.nn.dropout(x, rate=1 - (keep_prob), name=\"foo\")\n",
    )

    text = "tf.nn.dropout(x, keep_prob=.4, name=\"foo\")\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.nn.dropout(x, rate=1 - (.4), name=\"foo\")\n",
    )

    text = (
        "tf.nn.dropout(x,  # Stuff before\n"
        "              keep_prob=.4,  # Stuff after\n"
        "              name=\"foo\")\n"
    )
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.nn.dropout(x,  # Stuff before\n"
        "              rate=1 - (.4),  # Stuff after\n"
        "              name=\"foo\")\n",
    )

    text = "tf.nn.dropout(x)\n"
    _, unused_report, errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, text)
    self.assertIn("tf.nn.dropout called without arguments", errors[0])

  def testDropoutExpr(self):
    text = "tf.nn.dropout(x, 1 - func(3 + 4.), name=\"foo\")\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.nn.dropout(x, rate=1 - (1 - func(3 + 4.)), name=\"foo\")\n",
    )

  def testContribL1(self):
    text = "tf.contrib.layers.l1_regularizer(scale)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.keras.regularizers.l1(scale)\n",
    )
    self.assertNotIn("Dropping scope", unused_report)

    text = "tf.contrib.layers.l1_regularizer(scale, scope)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.keras.regularizers.l1(scale)\n",
    )
    self.assertIn("Dropping scope", unused_report)

    text = (
        "slim.l1_regularizer(  # Stuff before\n"
        "                    scale=.4,"
        "                    scope=\"foo\")\n"
    )
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(
        new_text,
        "tf.keras.regularizers.l1(  # Stuff before\n"
        "                    l=.4)\n",
    )
    self.assertIn("Dropping scope", unused_report)

  def testContribL2(self):
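    # Note the 0.5 factor below: contrib's l2_regularizer was defined in
    # terms of tf.nn.l2_loss, which halves the sum of squares, while
    # tf.keras.regularizers.l2 does not include the 1/2.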
976    text = "tf.contrib.layers.l2_regularizer(scale)\n"
977    _, unused_report, unused_errors, new_text = self._upgrade(text)
978    self.assertEqual(
979        new_text,
980        "tf.keras.regularizers.l2(0.5 * (scale))\n",
981    )
982    self.assertNotIn("Dropping scope", unused_report)
983
984    text = "tf.contrib.layers.l2_regularizer(scale, scope)\n"
985    _, unused_report, unused_errors, new_text = self._upgrade(text)
986    self.assertEqual(
987        new_text,
988        "tf.keras.regularizers.l2(0.5 * (scale))\n",
989    )
990    self.assertIn("Dropping scope", unused_report)
991
992    text = (
993        "slim.l2_regularizer(  # Stuff before\n"
994        "                    scale=.4,"
995        "                    scope=\"foo\")\n"
996    )
997    _, unused_report, unused_errors, new_text = self._upgrade(text)
998    self.assertEqual(
999        new_text,
1000        "tf.keras.regularizers.l2(  # Stuff before\n"
1001        "                    l=0.5 * (.4))\n",
1002    )
1003    self.assertIn("Dropping scope", unused_report)
1004
1005  def testContribL2Expr(self):
1006    text = "tf.contrib.layers.l2_regularizer(1 - func(3 + 4.), scope=\"foo\")\n"
1007    _, unused_report, unused_errors, new_text = self._upgrade(text)
1008    self.assertEqual(
1009        new_text,
1010        "tf.keras.regularizers.l2(0.5 * (1 - func(3 + 4.)))\n",
1011    )
1012
1013  def testMathCountNonZeroChanges(self):
1014    text = (
1015        "tf.math.count_nonzero(input_tensor=input, dtype=dtype, name=name, "
1016        "reduction_indices=axis, keep_dims=keepdims)\n"
1017        )
1018    _, unused_report, unused_errors, new_text = self._upgrade(text)
1019    expected_text = (
1020        "tf.math.count_nonzero(input=input, dtype=dtype, name=name, "
1021        "axis=axis, keepdims=keepdims)\n"
1022        )
1023    self.assertEqual(new_text, expected_text)
1024
1025  def testCountNonZeroChanges(self):
1026    text = (
1027        "tf.count_nonzero(input_tensor=input, dtype=dtype, name=name, "
1028        "reduction_indices=axis, keep_dims=keepdims)\n"
1029        )
1030    _, unused_report, unused_errors, new_text = self._upgrade(text)
1031    expected_text = (
1032        "tf.math.count_nonzero(input=input, dtype=dtype, name=name, "
1033        "axis=axis, keepdims=keepdims)\n"
1034        )
1035    self.assertEqual(new_text, expected_text)
1036
1037  def testRandomMultinomialToRandomCategorical(self):
1038    text = (
1039        "tf.random.multinomial(logits, samples, seed, name, output_dtype)\n"
1040        )
1041    _, unused_report, unused_errors, new_text = self._upgrade(text)
1042    expected_text = (
1043        "tf.random.categorical(logits=logits, num_samples=samples, seed=seed, "
1044        "name=name, dtype=output_dtype)\n"
1045        )
1046    self.assertEqual(new_text, expected_text)
1047
1048    text = (
1049        "tf.multinomial(logits, samples, seed, name, output_dtype)\n"
1050        )
1051    _, unused_report, unused_errors, new_text = self._upgrade(text)
1052    expected_text = (
1053        "tf.random.categorical(logits=logits, num_samples=samples, seed=seed, "
1054        "name=name, dtype=output_dtype)\n"
1055        )
1056    self.assertEqual(new_text, expected_text)
1057
1058  def testRandomPoissonConversion(self):
1059    text1 = "tf.random_poisson(lam, shape, dtype)"
1060    text2 = "tf.random.poisson(lam, shape, dtype)"
1061    expected_text = "tf.random.poisson(lam=lam, shape=shape, dtype=dtype)"
1062    _, unused_report, unused_errors, new_text1 = self._upgrade(text1)
1063    self.assertEqual(new_text1, expected_text)
1064    _, unused_report, unused_errors, new_text2 = self._upgrade(text2)
1065    self.assertEqual(new_text2, expected_text)
1066
1067  def testConvolutionOpUpdate(self):
1068    text = (
1069        "tf.nn.convolution(input, filter, padding, strides, dilation_rate, "
1070        "name, data_format)"
1071    )
1072    _, unused_report, unused_errors, new_text = self._upgrade(text)
1073    expected_text = (
1074        "tf.nn.convolution(input=input, filters=filter, padding=padding, "
1075        "strides=strides, dilations=dilation_rate, name=name, "
1076        "data_format=data_format)"
1077    )
1078    self.assertEqual(new_text, expected_text)
1079
1080  def test_substr(self):
1081    text = "tf.substr(input, pos, len, name, unit)\n"
1082    _, unused_report, errors, new_text = self._upgrade(text)
1083    self.assertEqual("tf.strings.substr(input=input, pos=pos, len=len, "
1084                     "name=name, unit=unit)\n", new_text)
1085    self.assertEqual(errors, [])
1086
1087  def testColocateGradientsWithOps(self):
1088    text = "tf.gradients(yx=a, foo=False)\n"
1089    _, unused_report, errors, new_text = self._upgrade(text)
1090    self.assertEqual(text, new_text)
1091    self.assertEqual(errors, [])
1092
1093    text = "tf.gradients(yx=a, colocate_gradients_with_ops=False)\n"
1094    _, report, unused_errors, new_text = self._upgrade(text)
1095    self.assertEqual("tf.gradients(yx=a)\n", new_text)
1096    self.assertIn("tf.gradients no longer takes", report)
1097
1098    text = "tf.gradients(y, x, grad_ys, name, colocate, gate)\n"
1099    expected = ("tf.gradients(ys=y, xs=x, grad_ys=grad_ys, name=name, "
1100                "gate_gradients=gate)\n")
1101    _, unused_report, errors, new_text = self._upgrade(text)
1102    self.assertEqual(expected, new_text)
1103
1104  def testColocateGradientsWithOpsMinimize(self):
1105    text = "optimizer.minimize(a, foo=False)\n"
1106    _, unused_report, errors, new_text = self._upgrade(text)
1107    self.assertEqual(text, new_text)
1108    self.assertEqual(errors, [])
1109
1110    text = "optimizer.minimize(a, colocate_gradients_with_ops=False)\n"
1111    _, report, unused_errors, new_text = self._upgrade(text)
1112    self.assertEqual("optimizer.minimize(a)\n", new_text)
1113    self.assertIn("Optimizer.minimize no longer takes", report)
1114
1115  def testColocateGradientsWithOpsComputeGradients(self):
1116    text = "optimizer.compute_gradients(a, foo=False)\n"
1117    _, unused_report, errors, new_text = self._upgrade(text)
1118    self.assertEqual(text, new_text)
1119    self.assertEqual(errors, [])
1120
1121    text = "optimizer.compute_gradients(a, colocate_gradients_with_ops=False)\n"
1122    _, report, unused_errors, new_text = self._upgrade(text)
1123    self.assertEqual("optimizer.compute_gradients(a)\n", new_text)
1124    self.assertIn("Optimizer.compute_gradients no longer takes", report)
1125
1126  def testColocateGradientsWithHessians(self):
1127    text = "tf.hessians(ys=a, xs=b, colocate_gradients_with_ops=False)\n"
1128    _, report, unused_errors, new_text = self._upgrade(text)
1129    self.assertEqual("tf.hessians(ys=a, xs=b)\n", new_text)
1130    self.assertIn("tf.hessians no longer takes", report)
1131
1132  def testExportSavedModelRename(self):
1133    text = "self.est.export_savedmodel(path)"
1134    _, report, unused_errors, unused_new_text = self._upgrade(text)
1135    self.assertIn(
1136        "rename the method export_savedmodel() to export_saved_model()",
1137        report)
1138
1139  def testArgmin(self):
1140    text = "tf.argmin(input, name=n, dimension=1, output_type=type)"
1141    expected_text = "tf.argmin(input=input, name=n, axis=1, output_type=type)"
1142    _, unused_report, unused_errors, new_text = self._upgrade(text)
1143    self.assertEqual(new_text, expected_text)
1144
1145    text = "tf.argmin(input, 0)"
1146    expected_text = "tf.argmin(input=input, axis=0)"
1147    _, unused_report, unused_errors, new_text = self._upgrade(text)
1148    self.assertEqual(new_text, expected_text)
1149
1150    text = "tf.arg_min(input, 0)"
1151    expected_text = "tf.argmin(input, 0)"
1152    _, unused_report, unused_errors, new_text = self._upgrade(text)
1153    self.assertEqual(new_text, expected_text)
1154
1155  def testArgmax(self):
1156    text = "tf.argmax(input, name=n, dimension=1, output_type=type)"
1157    expected_text = "tf.argmax(input=input, name=n, axis=1, output_type=type)"
1158    _, unused_report, unused_errors, new_text = self._upgrade(text)
1159    self.assertEqual(new_text, expected_text)
1160
1161    text = "tf.argmax(input, 0)"
1162    expected_text = "tf.argmax(input=input, axis=0)"
1163    _, unused_report, unused_errors, new_text = self._upgrade(text)
1164    self.assertEqual(new_text, expected_text)
1165
1166    text = "tf.arg_max(input, 0)"
1167    expected_text = "tf.argmax(input, 0)"
1168    _, unused_report, unused_errors, new_text = self._upgrade(text)
1169    self.assertEqual(new_text, expected_text)
1170
1171  def testAutograph(self):
1172    text = "tf.autograph.to_graph(f, True, arg_values=None, arg_types=None)"
1173    expected_text = "tf.autograph.to_graph(f, True)"
1174    _, unused_report, unused_errors, new_text = self._upgrade(text)
1175    self.assertEqual(new_text, expected_text)
1176
1177    text = ("tf.autograph.to_code"
1178            "(f, False, arg_values=None, arg_types=None, indentation=' ')")
1179    expected_text = "tf.autograph.to_code(f, False)"
1180    _, unused_report, unused_errors, new_text = self._upgrade(text)
1181    self.assertEqual(new_text, expected_text)
1182
1183  def testEstimatorInputs(self):
1184    text = "tf.estimator.inputs.numpy_input_fn(0)"
1185    expected_text = "tf.compat.v1.estimator.inputs.numpy_input_fn(0)"
1186    _, unused_report, unused_errors, new_text = self._upgrade(text)
1187    self.assertEqual(new_text, expected_text)
1188
1189    text = "tf.estimator.inputs.pandas_input_fn(0)"
1190    expected_text = "tf.compat.v1.estimator.inputs.pandas_input_fn(0)"
1191    _, unused_report, unused_errors, new_text = self._upgrade(text)
1192    self.assertEqual(new_text, expected_text)
1193
1194  def testBatchToSpace(self):
1195    text = "tf.batch_to_space_nd(input, block_shape, crops, name)"
1196    expected_text = "tf.batch_to_space(input, block_shape, crops, name)"
1197    _, unused_report, unused_errors, new_text = self._upgrade(text)
1198    self.assertEqual(new_text, expected_text)
1199
1200    text = "tf.batch_to_space(input, crops, block_size, name)"
1201    expected_text = (
1202        "tf.batch_to_space(input=input, crops=crops, block_shape=block_size, "
1203        "name=name)")
1204    _, unused_report, unused_errors, new_text = self._upgrade(text)
1205    self.assertEqual(new_text, expected_text)
1206
1207    text = "tf.manip.batch_to_space_nd(input, block_shape, crops, name)"
1208    expected_text = "tf.batch_to_space(input, block_shape, crops, name)"
1209    _, unused_report, unused_errors, new_text = self._upgrade(text)
1210    self.assertEqual(new_text, expected_text)
1211
1212  def testExtractImagePatches(self):
1213    text = (
1214        "tf.extract_image_patches(images, ksizes=ksizes, strides=strides,"
1215        "rates=rates, padding=padding, name=name)")
1216    expected_text = (
1217        "tf.image.extract_patches(images, sizes=ksizes, strides=strides,"
1218        "rates=rates, padding=padding, name=name)")
1219    _, unused_report, unused_errors, new_text = self._upgrade(text)
1220    self.assertEqual(new_text, expected_text)
1221
1222  def testKerasSavedModel(self):
1223    text = (
1224        "tf.contrib.saved_model.save_keras_model(model, './saved_models')\n"
1225        "tf.contrib.saved_model.load_keras_model(saved_model_path)\n")
1226    expected_text = (
1227        "tf.compat.v1.keras.experimental.export_saved_model(model, "
1228        "'./saved_models')\ntf.compat.v1.keras.experimental."
1229        "load_from_saved_model(saved_model_path)\n"
1230    )
1231    _, report, unused_errors, new_text = self._upgrade(text)
1232    self.assertEqual(new_text, expected_text)
1233    expected_info = "Please use model.save"
1234    self.assertIn(expected_info, report)
1235
1236  def testStatelessMultinomial(self):
1237    text = (
1238        "tf.random.stateless_multinomial(logits, num_samples, seed, "
1239        "output_dtype=dtype, name=name)")
1240    expected_text = (
1241        "tf.random.stateless_categorical(logits, num_samples, seed, "
1242        "dtype=dtype, name=name)")
1243    _, unused_report, unused_errors, new_text = self._upgrade(text)
1244    self.assertEqual(new_text, expected_text)
1245
1246  def testSoftMaxCrossEntropyWithLogitsV2(self):
1247    text = (
1248        "tf.nn.softmax_cross_entropy_with_logits_v2("
1249        "labels=labels, logits=logits, dim=2)")
1250    expected_text = (
1251        "tf.nn.softmax_cross_entropy_with_logits("
1252        "labels=labels, logits=logits, axis=2)")
1253    _, unused_report, errors, new_text = self._upgrade(text)
1254    self.assertEqual(new_text, expected_text)
1255
1256    self.assertFalse(errors)
1257
1258  def testSoftMaxCrossEntropyWithLogits(self):
1259    text = ("tf.nn.softmax_cross_entropy_with_logits("
1260            "labels=labels, logits=logits, dim=2)")
1261    expected_text = (
1262        "tf.nn.softmax_cross_entropy_with_logits("
1263        "labels=tf.stop_gradient(labels), logits=logits, axis=2)")
1264    _, unused_report, unused_errors, new_text = self._upgrade(text)
1265    self.assertEqual(new_text, expected_text)
1266
1267    text = ("tf.nn.softmax_cross_entropy_with_logits("
1268            "labels=foo(bar))")
1269    expected_text = ("tf.nn.softmax_cross_entropy_with_logits("
1270                     "labels=tf.stop_gradient(foo(bar)))")
1271    _, unused_report, unused_errors, new_text = self._upgrade(text)
1272    self.assertEqual(expected_text, new_text)
1273
1274  def testSoftMaxCrossEntropyWithLogitsDoesntNest(self):
1275    text = ("tf.nn.softmax_cross_entropy_with_logits("
1276            "labels=tf.stop_gradient(labels), logits=logits, dim=2)")
1277    expected_text = (
1278        "tf.nn.softmax_cross_entropy_with_logits("
1279        "labels=tf.stop_gradient(labels), logits=logits, axis=2)")
1280    _, unused_report, unused_errors, new_text = self._upgrade(text)
1281    self.assertEqual(new_text, expected_text)
1282
1283    text = ("tf.nn.softmax_cross_entropy_with_logits("
1284            "labels=tf.stop_gradient(foo(bar)))")
1285    expected_text = ("tf.nn.softmax_cross_entropy_with_logits("
1286                     "labels=tf.stop_gradient(foo(bar)))")
1287    _, unused_report, unused_errors, new_text = self._upgrade(text)
1288    self.assertEqual(expected_text, new_text)
1289
1290    text = ("tf.nn.softmax_cross_entropy_with_logits("
1291            "labels=foo())")
1292    expected_text = ("tf.nn.softmax_cross_entropy_with_logits("
1293                     "labels=tf.stop_gradient(foo()))")
1294    _, unused_report, unused_errors, new_text = self._upgrade(text)
1295    self.assertEqual(expected_text, new_text)
1296
1297    text = ("tf.nn.softmax_cross_entropy_with_logits("
1298            "labels=foo().zz())")
1299    expected_text = ("tf.nn.softmax_cross_entropy_with_logits("
1300                     "labels=tf.stop_gradient(foo().zz()))")
1301    _, unused_report, unused_errors, new_text = self._upgrade(text)
1302    self.assertEqual(expected_text, new_text)
1303
1304  def testSparseMatmul(self):
1305    text = ("tf.sparse_matmul(a, b, c, d, e, f, g)\n")
1306    expected_text = ("tf.linalg.matmul(a=a, b=b, transpose_a=c, transpose_b=d, "
1307                     "a_is_sparse=e, b_is_sparse=f, name=g)\n")
1308    _, unused_report, unused_errors, new_text = self._upgrade(text)
1309    self.assertEqual(new_text, expected_text)
1310
1311  def testWeightedMoments(self):
1312    text = "tf.nn.weighted_moments(x, axes, freq, name, kd)"
1313    expected_text = (
1314        "tf.nn.weighted_moments(x=x, axes=axes, frequency_weights=freq, "
1315        "name=name, keepdims=kd)")
1316    _, unused_report, unused_errors, new_text = self._upgrade(text)
1317    self.assertEqual(new_text, expected_text)
1318
1319  def testSparseAdd(self):
1320    text = "tf.sparse.add(a, b, t)"
1321    expected_text = "tf.sparse.add(a=a, b=b, threshold=t)"
1322    _, unused_report, unused_errors, new_text = self._upgrade(text)
1323    self.assertEqual(new_text, expected_text)
1324
  def testSparseConcat(self):
    text = "tf.sparse.concat(ax, inp, name, exp, concat)"
    expected_text = (
        "tf.sparse.concat(axis=ax, sp_inputs=inp, name=name, "
        "expand_nonconcat_dims=exp, axis=concat)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testSeparableConv2D(self):
    text = "tf.nn.separable_conv2d(inp, d, pt, strides, pad, rate, name, fmt)"
    expected_text = (
        "tf.nn.separable_conv2d(input=inp, depthwise_filter=d, "
        "pointwise_filter=pt, strides=strides, padding=pad, "
        "dilations=rate, name=name, data_format=fmt)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

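  # `filter` is renamed to `filters`, and the removed `use_cudnn_on_gpu`
  # argument is dropped whether it is passed positionally or by keyword.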
  def testConv2D(self):
    text = (
        "tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu, "
        "data_format)")
    expected_text = (
        "tf.nn.conv2d(input=input, filters=filter, strides=strides, "
        "padding=padding, data_format=data_format)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

    text = (
        "tf.nn.conv2d(input, filter=filter, strides=strides, padding=padding, "
        "use_cudnn_on_gpu=use_cudnn_on_gpu)")
    expected_text = ("tf.nn.conv2d(input=input, filters=filter, "
                     "strides=strides, padding=padding)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testConv2DBackpropFilter(self):
    text = (
        "tf.nn.conv2d_backprop_filter(input, filter_sizes, out_backprop, "
        "strides, padding, use_cudnn_on_gpu, data_format)")
    expected_text = (
        "tf.compat.v1.nn.conv2d_backprop_filter(input, filter_sizes, "
        "out_backprop, strides, padding, use_cudnn_on_gpu, data_format)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

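  # conv2d_backprop_input is expressed as a transpose convolution in 2.0;
  # input_sizes/filter/out_backprop map onto output_shape/filters/input.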
  def testConv2DBackpropInput(self):
    text = (
        "tf.nn.conv2d_backprop_input(input_sizes, filter, out_backprop, "
        "strides, padding, use_cudnn_on_gpu, data_format)")
    expected_text = (
        "tf.nn.conv2d_transpose(output_shape=input_sizes, filters=filter, "
        "input=out_backprop, strides=strides, padding=padding, "
        "data_format=data_format)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testSpaceToBatch(self):
    text = "tf.space_to_batch_nd(input, shape, paddings, name)"
    expected_text = "tf.space_to_batch(input, shape, paddings, name)"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

    text = "tf.nn.space_to_batch(input, paddings, block_size, name)"
    expected_text = (
        "tf.space_to_batch(input=input, paddings=paddings, "
        "block_shape=block_size, name=name)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testInTopK(self):
    text = "tf.math.in_top_k(a, b, c, n)"
    expected_text = (
        "tf.math.in_top_k(predictions=a, targets=b, k=c, name=n)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testDepthToSpace(self):
    text = "tf.nn.depth_to_space(input, block_size, name, data_format)"
    expected_text = (
        "tf.nn.depth_to_space(input=input, block_size=block_size, "
        "name=name, data_format=data_format)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testEmbeddingLookup(self):
    text = ("tf.nn.embedding_lookup(params, ids, partition_strategy, name, "
            "validate_indices, max_norm)")
    expected_text = ("tf.nn.embedding_lookup(params=params, ids=ids, "
                     "partition_strategy=partition_strategy, name=name, "
                     "max_norm=max_norm)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testEmbeddingLookupSparse(self):
    text = ("tf.nn.embedding_lookup_sparse(params, sp_ids, sp_weights, "
            "partition_strategy, name, combiner, max_norm)")
    expected_text = ("tf.nn.embedding_lookup_sparse(params=params, "
                     "sp_ids=sp_ids, sp_weights=sp_weights, "
                     "partition_strategy=partition_strategy, name=name, "
                     "combiner=combiner, max_norm=max_norm)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testNnInTopK(self):
    text = "tf.nn.in_top_k(predictions, targets, k, name)"
    expected_text = ("tf.nn.in_top_k(predictions=predictions, "
                     "targets=targets, k=k, name=name)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testSpaceToDepth(self):
    text = "tf.nn.space_to_depth(input, block_size, name, data_format)"
    expected_text = ("tf.nn.space_to_depth(input=input, block_size=block_size, "
                     "name=name, data_format=data_format)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testPrint(self):
    # tf.print() cannot be parsed unless we import print_function
    text = """from __future__ import print_function
tf.print()
tf.print('abc')
"""
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, text)  # Text should stay the same

  def testSparseSplit(self):
    text = (
        "tf.sparse_split(sp_input=sp_input, num_split=num_split, axis=axis, "
        "name=name)")
    expected_text = (
        "tf.sparse.split(sp_input=sp_input, num_split=num_split, axis=axis, "
        "name=name)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

    text = (
        "tf.sparse_split(sp_input=sp_input, num_split=num_split, "
        "name=name, split_dim=axis)")
    expected_text = (
        "tf.sparse.split(sp_input=sp_input, num_split=num_split, "
        "name=name, axis=axis)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

    text = (
        "tf.sparse.split(sp_input=sp_input, num_split=num_split, "
        "name=name, split_dim=axis)")
    expected_text = (
        "tf.sparse.split(sp_input=sp_input, num_split=num_split, "
        "name=name, axis=axis)")
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

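  # TF 2.x datasets are plain Python iterables, so both the method form
  # (dataset.make_one_shot_iterator()) and the tf.data.* function form are
  # rewritten to the tf.compat.v1.data.* helpers; calls that already use
  # tf.compat.v1.data pass through unchanged.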
  def testIterators(self):
    for (text, expected) in [
        ("(expr + yielding(data)).make_one_shot_iterator()",
         "tf.compat.v1.data.make_one_shot_iterator((expr + yielding(data)))"),
        ("dataset.make_one_shot_iterator()",
         "tf.compat.v1.data.make_one_shot_iterator(dataset)"),
        ("dataset.make_one_shot_iterator(shared_name=foo)",
         "tf.compat.v1.data.make_one_shot_iterator(dataset, shared_name=foo)"),
        ("dataset.make_one_shot_iterator(x, y, z)",
         "tf.compat.v1.data.make_one_shot_iterator(dataset, x, y, z)"),
        ("dataset.make_initializable_iterator()",
         "tf.compat.v1.data.make_initializable_iterator(dataset)"),
        ("ds.make_initializable_iterator(shared_name=foo)",
         "tf.compat.v1.data.make_initializable_iterator(ds, shared_name=foo)"),
        ("dataset.make_initializable_iterator(x, y, z)",
         "tf.compat.v1.data.make_initializable_iterator(dataset, x, y, z)"),
        ("tf.data.make_one_shot_iterator(dataset)",
         "tf.compat.v1.data.make_one_shot_iterator(dataset)"),
        ("tf.data.make_one_shot_iterator(dataset, shared_name=foo)",
         "tf.compat.v1.data.make_one_shot_iterator(dataset, shared_name=foo)"),
        ("tf.data.make_one_shot_iterator(dataset, x, y, z)",
         "tf.compat.v1.data.make_one_shot_iterator(dataset, x, y, z)"),
        ("tf.data.make_initializable_iterator(dataset)",
         "tf.compat.v1.data.make_initializable_iterator(dataset)"),
        ("tf.data.make_initializable_iterator(ds, shared_name=foo)",
         "tf.compat.v1.data.make_initializable_iterator(ds, shared_name=foo)"),
        ("tf.data.make_initializable_iterator(dataset, x, y, z)",
         "tf.compat.v1.data.make_initializable_iterator(dataset, x, y, z)"),
        ("tf.compat.v1.data.make_one_shot_iterator(dataset)",
         "tf.compat.v1.data.make_one_shot_iterator(dataset)"),
        ("tf.compat.v1.data.make_one_shot_iterator(dataset, shared_name=foo)",
         "tf.compat.v1.data.make_one_shot_iterator(dataset, shared_name=foo)"),
        ("tf.compat.v1.data.make_one_shot_iterator(dataset, x, y, z)",
         "tf.compat.v1.data.make_one_shot_iterator(dataset, x, y, z)"),
        ("tf.compat.v1.data.make_initializable_iterator(dataset)",
         "tf.compat.v1.data.make_initializable_iterator(dataset)"),
        ("tf.compat.v1.data.make_initializable_iterator(ds, shared_name=foo)",
         "tf.compat.v1.data.make_initializable_iterator(ds, shared_name=foo)"),
        ("tf.compat.v1.data.make_initializable_iterator(dataset, x, y, z)",
         "tf.compat.v1.data.make_initializable_iterator(dataset, x, y, z)")]:
      _, unused_report, unused_errors, actual = self._upgrade(text)
      self.assertEqual(actual, expected)

  def testStructure(self):
    for (text, expected) in [
        ("tf.data.experimental.DatasetStructure", "tf.data.DatasetSpec"),
        ("tf.data.experimental.OptionalStructure", "tf.OptionalSpec"),
        ("tf.data.experimental.RaggedTensorStructure", "tf.RaggedTensorSpec"),
        ("tf.data.experimental.SparseTensorStructure", "tf.SparseTensorSpec"),
        ("tf.data.experimental.Structure", "tf.TypeSpec"),
        ("tf.data.experimental.TensorArrayStructure", "tf.TensorArraySpec"),
        ("tf.data.experimental.TensorStructure", "tf.TensorSpec"),
    ]:
      _, unused_report, unused_errors, actual = self._upgrade(text)
      self.assertEqual(actual, expected)

  def testMapAndBatch(self):
    suffix = ".data.experimental.map_and_batch_with_legacy_function(args)"
    text = "tf" + suffix
    expected = "tf.compat.v1" + suffix
    _, unused_report, unused_errors, actual = self._upgrade(text)
    self.assertEqual(actual, expected)

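  # Each tf.to_<type> helper becomes an explicit tf.cast. Note the dtype
  # names: `float` maps to float32 and `double` to float64.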
  def testCast(self):
    for (name, dtype) in [("int32", "int32"),
                          ("int64", "int64"),
                          ("float", "float32"),
                          ("double", "float64"),
                          ("complex64", "complex64"),
                          ("complex128", "complex128"),
                          ("bfloat16", "bfloat16")]:
      text = "tf.to_%s(x, name='test')" % name
      expected_text = "tf.cast(x, name='test', dtype=tf.%s)" % dtype
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(expected_text, new_text)

  def testCastPositionalSecondArgument(self):
    for (name, dtype) in [("int32", "int32"),
                          ("int64", "int64"),
                          ("float", "float32"),
                          ("double", "float64"),
                          ("complex64", "complex64"),
                          ("complex128", "complex128"),
                          ("bfloat16", "bfloat16")]:
      text = "tf.to_%s(x, 'test')" % name
      expected_text = "tf.cast(x, name='test', dtype=tf.%s)" % dtype
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(expected_text, new_text)

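  # The tf.image.resize_<method> family collapses into tf.image.resize with
  # the method selected via the tf.image.ResizeMethod enum, e.g.
  # resize_bilinear(i, s) -> resize(i, s, method=tf.image.ResizeMethod.BILINEAR).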
  def testImageResize(self):
    for method in ["bilinear", "area", "bicubic", "nearest_neighbor"]:
      text = "tf.image.resize_%s(i, s)" % method
      expected_text = ("tf.image.resize(i, s, "
                       "method=tf.image.ResizeMethod.%s)" % method.upper())
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(expected_text, new_text)

  def testImageResizeExtraPositionalArgs(self):
    for method in ["bilinear", "area", "bicubic", "nearest_neighbor"]:
      text = "tf.image.resize_%s(i, s, a, p)" % method
      expected_text = [
          "tf.image.resize(i, s, ", "preserve_aspect_ratio=p, ",
          "method=tf.image.ResizeMethod.%s)" % method.upper()
      ]
      _, unused_report, unused_errors, new_text = self._upgrade(text)
      for s in expected_text:
        self.assertIn(s, new_text)

  def testCond(self):
    text = "tf.cond(a, b, c, True)"
    expected_text = "tf.cond(pred=a, true_fn=b, false_fn=c)"
    _, unused_report, errors, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)
    self.assertIn("tf.cond", errors[0])
    self.assertIn("requires manual check", errors[0])

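  # Keywording a positional argument must not disturb parentheses that are
  # part of the argument expression itself.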
  def testParens(self):
    text = """
def _log_prob(self, x):
  return tf.reduce_logsumexp(
      (self.mixture_distribution.logits + self.distribution.log_prob(
          x[..., tf.newaxis])),
          axis=-1)"""
    expected_text = """
def _log_prob(self, x):
  return tf.reduce_logsumexp(
      input_tensor=(self.mixture_distribution.logits + self.distribution.log_prob(
          x[..., tf.newaxis])),
          axis=-1)"""
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def testAssertStatements(self):
    for name in [
        "assert_greater", "assert_equal", "assert_none_equal", "assert_less",
        "assert_negative", "assert_positive", "assert_non_negative",
        "assert_non_positive", "assert_near", "assert_less",
        "assert_less_equal", "assert_greater", "assert_greater_equal",
        "assert_scalar"
    ]:
      text = "tf.%s(a)" % name
      expected_text = "tf.compat.v1.%s(a)" % name
      _, report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(expected_text, new_text)
      self.assertIn("%s has been" % name, report)

      text = "tf.debugging.%s(a)" % name
      expected_text = "tf.compat.v1.debugging.%s(a)" % name
      _, report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(expected_text, new_text)
      self.assertIn("%s has been" % name, report)

  def testAssertRankStatements(self):
    for name in ["assert_rank", "assert_rank_at_least", "assert_rank_in"]:
      text = "tf.%s(a)" % name
      expected_text = "tf.compat.v1.%s(a)" % name
      _, report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(expected_text, new_text)
      self.assertIn("%s has been" % name, report)

      text = "tf.debugging.%s(a)" % name
      expected_text = "tf.compat.v1.debugging.%s(a)" % name
      _, report, unused_errors, new_text = self._upgrade(text)
      self.assertEqual(expected_text, new_text)
      self.assertIn("%s has been" % name, report)

  def test_assert_equal_graph_def(self):
    text = ("tf.test.assert_equal_graph_def(a, b, checkpoint_v2=x, "
            "hash_table_shared_name=y)")
    expected = "tf.test.assert_equal_graph_def(actual=a, expected=b)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_is_tensor_upgrade(self):
    text = "tf.contrib.framework.is_tensor(x)"
    expected = "tf.is_tensor(x)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_is_tensor_direct_import_upgrade(self):
    text = "contrib_framework.is_tensor(x)"
    expected = "tf.is_tensor(x)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_CriticalSection_upgrade(self):
    text = "tf.contrib.framework.CriticalSection(shared_name='blah')"
    expected = "tf.CriticalSection(shared_name='blah')"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_sample_distorted_bounding_box(self):
    # pylint: disable=line-too-long
    text = "tf.image.sample_distorted_bounding_box(a, b, c, d, e, f, g, h, i, j)"
    expected = "tf.image.sample_distorted_bounding_box(image_size=a, bounding_boxes=b, seed=c, min_object_covered=e, aspect_ratio_range=f, area_range=g, max_attempts=h, use_image_if_no_bounding_boxes=i, name=j)"
    # pylint: enable=line-too-long
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_contrib_initialize(self):
    text = "tf.contrib.summary.initialize"
    expected = "tf.compat.v1.summary.initialize"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_contrib_framework_argsort(self):
    text = "tf.contrib.framework.argsort"
    expected = "tf.argsort"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_flags_bare(self):
    _, _, errors, _ = self._upgrade("tf.flags")
    self.assertIn("tf.flags and tf.app.flags have been removed", errors[0])

  def test_flags_flags(self):
    _, _, errors, _ = self._upgrade("tf.flags.FLAGS")
    self.assertIn("tf.flags and tf.app.flags have been removed", errors[0])

  def test_contrib_estimator_head_deprecation(self):
    for contrib_alias in ["tf.contrib.", "contrib_"]:
      api_symbols = ["binary_classification_head", "logistic_regression_head",
                     "multi_class_head", "multi_head", "multi_label_head",
                     "poisson_regression_head", "regression_head"]
      for symbol in api_symbols:
        text = contrib_alias + "estimator." + symbol
        _, report, _, _ = self._upgrade(text)
        self.assertIn("`tf.contrib.estimator.*_head` has been deprecated",
                      report)

  def test_contrib_layers_layer_norm_deprecation(self):
    for contrib_alias in ["tf.contrib.", "contrib_"]:
      _, report, _, _ = self._upgrade(contrib_alias + "layers.layer_norm")
      self.assertIn(
          "`tf.contrib.layers.layer_norm` has been deprecated", report)

  def test_contrib_rnn_deprecation(self):
    _, report, _, _ = self._upgrade("tf.contrib.rnn")
    self.assertIn("tf.contrib.rnn.* has been deprecated", report)

  def test_contrib_cudnn_rnn_deprecation(self):
    _, report, _, _ = self._upgrade("tf.contrib.cudnn_rnn")
    self.assertIn("tf.contrib.cudnn_rnn.* has been deprecated", report)

  def test_max_pool_2d(self):
    text = "tf.nn.max_pool(value=4)"
    expected_text = "tf.nn.max_pool2d(input=4)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def test_contrib_estimator_early_stopping(self):
    for contrib_alias in ["tf.contrib.", "contrib_"]:
      api_symbols = [
          "make_early_stopping_hook", "stop_if_higher_hook",
          "stop_if_lower_hook",
          "stop_if_no_decrease_hook", "stop_if_no_increase_hook"
      ]
      for symbol in api_symbols:
        text = contrib_alias + "estimator." + symbol
        expected_text = "tf.estimator.experimental." + symbol
        _, _, _, new_text = self._upgrade(text)
        self.assertEqual(expected_text, new_text)

  def test_contrib_rnn_cell(self):
    api_symbols = ["RNNCell", "BasicLSTMCell", "BasicRNNCell", "GRUCell",
                   "LSTMCell", "MultiRNNCell"]
    for symbol in api_symbols:
      text = "tf.contrib.rnn." + symbol
      expected_text = "tf.compat.v1.nn.rnn_cell." + symbol
      _, _, _, new_text = self._upgrade(text)
      self.assertEqual(expected_text, new_text)

  def test_contrib_rnn_function(self):
    api_symbols = ["static_rnn", "static_state_saving_rnn",
                   "static_bidirectional_rnn"]
    for symbol in api_symbols:
      text = "tf.contrib.rnn." + symbol
      expected_text = "tf.compat.v1.nn." + symbol
      _, _, _, new_text = self._upgrade(text)
      self.assertEqual(expected_text, new_text)

  def test_contrib_summary_generic(self):
    text = "tf.contrib.summary.generic('foo', myval, meta, 'fam', 42)"
    expected = ("tf.compat.v2.summary.write(tag='foo', data=myval, "
                "metadata=meta, step=42)")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    # Arg errors come in alphabetical order of arguments, not appearance order.
    self.assertIn("'family' argument", errors[0])
    self.assertIn("'name' argument", errors[1])
    self.assertIn("tf.compat.v2.summary.*", errors[2])

  def test_contrib_summary_audio(self):
    text = "tf.contrib.summary.audio('foo', myval, 44100, 3, 'fam', 42)"
    expected = ("tf.compat.v2.summary.audio(name='foo', data=myval, "
                "sample_rate=44100, max_outputs=3, step=42)")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("'family' argument", errors[0])
    self.assertIn("tf.compat.v2.summary.*", errors[1])

  def test_contrib_summary_histogram(self):
    text = "tf.contrib.summary.histogram('foo', myval, 'fam', 42)"
    expected = ("tf.compat.v2.summary.histogram(name='foo', data=myval, "
                "step=42)")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("'family' argument", errors[0])
    self.assertIn("tf.compat.v2.summary.*", errors[1])

  def test_contrib_summary_image(self):
    text = "tf.contrib.summary.image('foo', myval, red, 3, 'fam', 42)"
    expected = ("tf.compat.v2.summary.image(name='foo', data=myval, "
                "max_outputs=3, step=42)")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("'bad_color' argument", errors[0])
    self.assertIn("'family' argument", errors[1])
    self.assertIn("tf.compat.v2.summary.*", errors[2])

  def test_contrib_summary_scalar(self):
    text = "tf.contrib.summary.scalar('foo', myval, 'fam', 42)"
    expected = ("tf.compat.v2.summary.scalar(name='foo', data=myval, "
                "step=42)")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("'family' argument", errors[0])
    self.assertIn("tf.compat.v2.summary.*", errors[1])

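  # When no `step` is supplied, the converter injects
  # step=tf.compat.v1.train.get_or_create_global_step(), since the v2 summary
  # ops need an explicit step value.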
  def test_contrib_summary_generic_nostep(self):
    text = "tf.contrib.summary.generic('foo', myval)"
    expected = ("tf.compat.v2.summary.write(tag='foo', data=myval, "
                "step=tf.compat.v1.train.get_or_create_global_step())")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("'name' argument", errors[0])
    self.assertIn("'step' argument", errors[1])
    self.assertIn("tf.compat.v2.summary.*", errors[2])

  def test_contrib_summary_audio_nostep(self):
    text = "tf.contrib.summary.audio('foo', myval, 44100)"
    expected = ("tf.compat.v2.summary.audio(name='foo', data=myval, "
                "sample_rate=44100, "
                "step=tf.compat.v1.train.get_or_create_global_step())")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("'step' argument", errors[0])
    self.assertIn("tf.compat.v2.summary.*", errors[1])

  def test_contrib_summary_histogram_nostep(self):
    text = "tf.contrib.summary.histogram('foo', myval)"
    expected = ("tf.compat.v2.summary.histogram(name='foo', data=myval, "
                "step=tf.compat.v1.train.get_or_create_global_step())")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("'step' argument", errors[0])
    self.assertIn("tf.compat.v2.summary.*", errors[1])

  def test_contrib_summary_image_nostep(self):
    text = "tf.contrib.summary.image('foo', myval)"
    expected = ("tf.compat.v2.summary.image(name='foo', data=myval, "
                "step=tf.compat.v1.train.get_or_create_global_step())")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("'step' argument", errors[0])
    self.assertIn("tf.compat.v2.summary.*", errors[1])

  def test_contrib_summary_scalar_nostep(self):
    text = "tf.contrib.summary.scalar('foo', myval)"
    expected = ("tf.compat.v2.summary.scalar(name='foo', data=myval, "
                "step=tf.compat.v1.train.get_or_create_global_step())")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("'step' argument", errors[0])
    self.assertIn("tf.compat.v2.summary.*", errors[1])

  def test_contrib_summary_graph(self):
    text = "tf.contrib.summary.graph(my_graph)"
    _, _, errors, _ = self._upgrade(text)
    expected_error = "tf.compat.v2.summary.trace"
    self.assertIn(expected_error, errors[0])

  def test_contrib_summary_import_event(self):
    text = "tf.contrib.summary.import_event(my_event)"
    _, _, errors, _ = self._upgrade(text)
    expected_error = "tf.compat.v2.summary.experimental.write_raw_pb"
    self.assertIn(expected_error, errors[0])

  def test_contrib_summary_flush(self):
    text = "tf.contrib.summary.flush(writer=foo)"
    expected = "tf.compat.v2.summary.flush(writer=foo)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_contrib_summary_create_file_writer(self):
    text = ("tf.contrib.summary.create_file_writer('my_logdir', 0, 1000, "
            "'.foo', 'shared-name')")
    expected = ("tf.compat.v2.summary.create_file_writer(logdir='my_logdir', "
                "max_queue=0, flush_millis=1000, filename_suffix='.foo')")
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("'name' argument", errors[0])
    self.assertIn("no longer re-uses existing event files", errors[1])

  def test_contrib_summary_always_record_summaries(self):
    text = "tf.contrib.summary.always_record_summaries()"
    expected = "tf.compat.v2.summary.record_if(True)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_contrib_summary_never_record_summaries(self):
    text = "tf.contrib.summary.never_record_summaries()"
    expected = "tf.compat.v2.summary.record_if(False)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_contrib_summary_record_summaries_every_n_global_steps(self):
    text = "tf.contrib.summary.record_summaries_every_n_global_steps(10)"
    _, _, errors, _ = self._upgrade(text)
    expected_error = "replaced by a call to tf.compat.v2.summary.record_if()"
    self.assertIn(expected_error, errors[0])

  def test_contrib_summary_all_summary_ops(self):
    text = "tf.contrib.summary.all_summary_ops()"
    expected = "tf.compat.v1.summary.all_v2_summary_ops()"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_contrib_summary_full_example(self):
    deindent = lambda n, s: "\n".join(line[n:] for line in s.split("\n"))
    text = deindent(4, """
    import tensorflow as tf
    tf.enable_eager_execution()
    writer = tf.contrib.summary.create_file_writer(
        "/tmp/migration_test", flush_millis=1000)
    with writer.as_default(), tf.contrib.summary.always_record_summaries():
      tf.contrib.summary.scalar("loss", 0.42)
      tf.contrib.summary.histogram("weights", [1.0, 2.0], step=7)
      tf.contrib.summary.flush()
    """)
    expected = deindent(4, """
    import tensorflow as tf
    tf.compat.v1.enable_eager_execution()
    writer = tf.compat.v2.summary.create_file_writer(
        logdir="/tmp/migration_test", flush_millis=1000)
    with writer.as_default(), tf.compat.v2.summary.record_if(True):
      tf.compat.v2.summary.scalar(name="loss", data=0.42, step=tf.compat.v1.train.get_or_create_global_step())
      tf.compat.v2.summary.histogram(name="weights", data=[1.0, 2.0], step=7)
      tf.compat.v2.summary.flush()
    """)
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_summary_api_warning(self):
    text = "tf.summary.scalar('foo', 42)"
    _, report, _, _ = self._upgrade(text)
    expected_info = "TF 1.x summary API cannot be automatically migrated"
    self.assertIn(expected_info, report)

  def test_avg_pool_2d(self):
    text = "tf.nn.avg_pool(value=4)"
    expected_text = "tf.nn.avg_pool2d(input=4)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def test_saved_model_load(self):
    text = "tf.saved_model.load(sess, ['foo_graph'])"
    expected = "tf.compat.v1.saved_model.load(sess, ['foo_graph'])"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_saved_model_load_v2(self):
    text = "tf.saved_model.load_v2('/tmp/blah')"
    expected = "tf.compat.v2.saved_model.load('/tmp/blah')"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_app_flags(self):
    text = "flags = tf.app.flags"
    expected = "flags = tf.compat.v1.app.flags"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_uniform_unit_scaling_initializer(self):
    text = "tf.uniform_unit_scaling_initializer(0.5)"
    expected_text = ("tf.compat.v1.keras.initializers.VarianceScaling("
                     "scale=0.5, distribution=\"uniform\")")
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "tf.initializers.uniform_unit_scaling(0.5)"
    expected_text = ("tf.compat.v1.keras.initializers.VarianceScaling("
                     "scale=0.5, distribution=\"uniform\")")
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def test_name_scope(self):
    text = "tf.name_scope(None, default_name, [some, values])"
    expected_text = "tf.name_scope(name=default_name)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "tf.name_scope(default_name=default_name, values=stuff)"
    expected_text = "tf.name_scope(name=default_name)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "tf.name_scope(name=n, default_name=d, values=s)"
    expected_text = "tf.compat.v1.name_scope(name=n, default_name=d, values=s)"
    _, report, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)
    self.assertIn("`name` passed to `name_scope`", report)

    text = "tf.name_scope(name=None, values=stuff)"
    _, _, errors, _ = self._upgrade(text)
    self.assertIn("name_scope call with neither name nor default_name",
                  errors[0])

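  # How tf.string_split is rewritten depends on the literal value of `sep`:
  # an empty-string literal maps to tf.strings.bytes_split, a non-empty
  # literal to tf.strings.split (with .to_sparse() appended unless
  # result_type is 'RaggedTensor'), and anything the converter cannot
  # inspect statically falls back to tf.compat.v1.string_split.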
  @parameterized.parameters(
      # Rename parameter: delimiter -> sep and add .to_sparse()
      ["tf.string_split('test', delimiter=' ')",
       "tf.strings.split(input='test', sep=' ').to_sparse()"],
      # Rename parameter: source -> input
      ["tf.strings.split(source='test1')",
       "tf.strings.split(input='test1').to_sparse()"],
      # Use compat.v1 for skip_empty parameter.
      ["tf.string_split('test', ' ', True)",
       "tf.compat.v1.string_split(source='test', sep=' ', skip_empty=True)"],
      ["tf.string_split('test', ' ', skip_empty=False)",
       "tf.strings.split(input='test', sep=' ').to_sparse()"],
      # Split behavior for sep=None changed.  (In particular, it now splits on
      # all whitespace, not just the space character)
      ["tf.string_split(x)",
       "tf.compat.v1.string_split(source=x)"],
      # Split behavior for sep='' changed:
      ["tf.string_split(x, '')",
       "tf.strings.bytes_split(input=x).to_sparse()"],
      ["tf.string_split(x, sep='')",
       "tf.strings.bytes_split(input=x).to_sparse()"],
      ["tf.string_split(x, delimiter='')",
       "tf.strings.bytes_split(input=x).to_sparse()"],
      ["tf.string_split(x, '', result_type='RaggedTensor')",
       "tf.strings.bytes_split(input=x)"],
      # If sep is a variable, we can't tell if it's empty:
      ["tf.string_split(x, sep)",
       "tf.compat.v1.string_split(source=x, sep=sep)"],
      # If sep is a non-empty string literal, then we don't need compat.v1.
      ["tf.string_split(x, 'non-empty-sep')",
       "tf.strings.split(input=x, sep='non-empty-sep').to_sparse()"],
      # Add to_sparse unless result_type is RaggedTensor:
      ["tf.string_split(x, ' ')",
       "tf.strings.split(input=x, sep=' ').to_sparse()"],
      ["tf.string_split(x, ' ', result_type='SparseTensor')",
       "tf.strings.split(input=x, sep=' ').to_sparse()"],
      ["tf.string_split(x, ' ', result_type='RaggedTensor')",
       "tf.strings.split(input=x, sep=' ')"],
      ["tf.string_split(x, ' ', result_type=x)",
       "tf.compat.v1.string_split(source=x, sep=' ', result_type=x)"],
  )  # pyformat: disable
  # TODO(b/129398290)
  def DISABLED_test_string_split(self, text, expected_text):
    """Tests for transforming from tf.string_split."""
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  @parameterized.parameters(
      # Add to_sparse unless result_type is RaggedTensor:
      ["tf.strings.split(x, sep)",
       "tf.strings.split(x, sep).to_sparse()"],
      ["tf.strings.split(x, sep, result_type='SparseTensor')",
       "tf.strings.split(x, sep).to_sparse()"],
      ["tf.strings.split(x, sep, result_type='RaggedTensor')",
       "tf.strings.split(x, sep)"],
      ["tf.strings.split(x, sep, result_type=x)",
       "tf.compat.v1.strings.split(x, sep, result_type=x)"],
  )  # pyformat: disable
  def test_strings_split(self, text, expected_text):
    """Tests for transforming from tf.strings.split."""
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

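  # The deprecated tf.train.sdca_* endpoints are mapped directly onto the
  # corresponding tf.raw_ops kernels, with every argument keyworded.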
  def test_sdca_to_raw_ops(self):
    text = "tf.train.sdca_fprint(input_tensor)"
    expected_text = "tf.raw_ops.SdcaFprint(input=input_tensor)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "tf.train.sdca_fprint(input, name=n)"
    expected_text = "tf.raw_ops.SdcaFprint(input=input, name=n)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = "tf.train.sdca_shrink_l1(w, l, ll)"
    expected_text = "tf.raw_ops.SdcaShrinkL1(weights=w, l1=l, l2=ll)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

    text = (
        "tf.train.sdca_optimizer(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o)")
    expected_text = (
        "tf.raw_ops.SdcaOptimizer(sparse_example_indices=a, "
        "sparse_feature_indices=b, sparse_feature_values=c, dense_features=d, "
        "example_weights=e, example_labels=f, sparse_indices=g, "
        "sparse_weights=h, dense_weights=i, example_state_data=j, loss_type=k, "
        "l1=l, l2=m, num_loss_partitions=n, num_inner_iterations=o)")
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def test_contrib_to_addons_move(self):
    small_mapping = {
        "tf.contrib.layers.poincare_normalize":
            "tfa.layers.PoincareNormalize",
        "tf.contrib.layers.maxout":
            "tfa.layers.Maxout",
        "tf.contrib.layers.group_norm":
            "tfa.layers.GroupNormalization",
        "tf.contrib.layers.instance_norm":
            "tfa.layers.InstanceNormalization",
    }
    for symbol, replacement in small_mapping.items():
      text = "{}('stuff', *args, **kwargs)".format(symbol)
      _, report, _, _ = self._upgrade(text)
      self.assertIn(replacement, report)

  def testXlaExperimental(self):
    text = "tf.xla.experimental.jit_scope(0)"
    expected_text = "tf.xla.experimental.jit_scope(0)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

    text = "tf.xla.experimental.compile(0)"
    expected_text = "tf.xla.experimental.compile(0)"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testNnErosion2d(self):
    text = "tf.nn.erosion2d(v, k, s, r, p)"
    expected_text = "tf.nn.erosion2d(v, k, s, r, p, data_format='NHWC')"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testNnDilation2d(self):
    text = "tf.nn.dilation2d(v, k, s, r, p)"
    expected_text = "tf.nn.dilation2d(v, k, s, r, p, data_format='NHWC')"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

  def testPywrapTensorflowWarning(self):
    text = "tf.pywrap_tensorflow.foo()"
    expected = "tf.pywrap_tensorflow.foo()"
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("`tf.pywrap_tensorflow` will not be distributed", errors[0])

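  # The converter pins save_format='h5' to keep the 1.x default format; a
  # bare model.save(path) call is left alone and only triggers a report note
  # about the new SavedModel default.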
  def testKerasSaveModelFormat(self):
    text = "tf.keras.models.save_model(model, path)"
    expected_text = "tf.keras.models.save_model(model, path, save_format='h5')"
    _, report, _, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)
    self.assertNotIn(
        "saves to the Tensorflow SavedModel format by default", report)

    _, report, _, _ = self._upgrade("model.save(path)")
    self.assertIn(
        "saves to the Tensorflow SavedModel format by default", report)

  def test_distribute_strategy(self):
    text = "tf.contrib.distribute.CrossDeviceOps()"
    expected = "tf.distribute.CrossDeviceOps()"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

    text = "tf.contrib.distribute.MirroredStrategy"
    expected = "tf.contrib.distribute.MirroredStrategy"
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("migrated to tf.distribute.MirroredStrategy", errors[0])

    text = "tf.distribute.MirroredStrategy"
    expected = "tf.distribute.MirroredStrategy"
    _, report, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("tf.distribute.MirroredStrategy API has changed", report)
    self.assertIn("make_dataset_iterator->experimental_distribute_dataset",
                  report)

    text = "tf.contrib.distribute.TPUStrategy"
    expected = "tf.contrib.distribute.TPUStrategy"
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("migrated to tf.distribute.TPUStrategy",
                  errors[0])

    text = "tf.contrib.distribute.foo"
    expected = "tf.contrib.distribute.foo"
    _, report, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)
    self.assertIn("tf.contrib.distribute.* have been migrated", report)

  def test_decode_raw(self):
    text = "tf.io.decode_raw(bytes=[1,2,3], output_dtype=tf.int32)"
    expected_text = (
        "tf.io.decode_raw(input_bytes=[1,2,3], output_dtype=tf.int32)")
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

  def testRecomputeGrad(self):
    text = "tf.contrib.layers.recompute_grad()"
    expected = "tf.recompute_grad()"
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected, new_text)

  def test_load_variable(self):
    text = "tf.contrib.framework.load_variable('a')"
    expected_text = (
        "tf.train.load_variable('a')")
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)
    text = "tf.contrib.framework.load_variable(checkpoint_dir='a')"
    expected_text = (
        "tf.train.load_variable(ckpt_dir_or_file='a')")
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(expected_text, new_text)

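  # With import_rename=True the script rewrites `import tensorflow as tf`
  # (and `from tensorflow import ...`) to the tensorflow.compat.v2 module;
  # upgrade_compat_v1_import additionally rewrites an existing compat.v1
  # import of `tf` itself (to compat.v2, or to plain tensorflow when
  # import_rename is off).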
  def test_import_rename_analysis(self):
    old_symbol = "tf.conj(a)"
    new_symbol = "tf.math.conj(a)"

    import_header = "import tensorflow as tf\n"
    text = import_header + old_symbol
    expected_text = "import tensorflow.compat.v2 as tf\n" + new_symbol
    _, unused_report, unused_errors, new_text = self._upgrade(
        text, import_rename=True)
    self.assertEqual(new_text, expected_text)

    import_header = "import tensorflow as tf, other_import as y\n"
    text = import_header + old_symbol
    new_import_header = "import tensorflow.compat.v2 as tf, other_import as y\n"
    expected_text = new_import_header + new_symbol
    _, unused_report, unused_errors, new_text = self._upgrade(
        text, import_rename=True)
    self.assertEqual(new_text, expected_text)

    import_header = ("import tensorflow as tf\n"
                     "import tensorflow.compat.v1 as tf_v1\n"
                     "import tensorflow.compat.v2 as tf_v2\n")
    text = import_header + old_symbol
    expected_header = ("import tensorflow.compat.v2 as tf\n"
                       "import tensorflow.compat.v1 as tf_v1\n"
                       "import tensorflow.compat.v2 as tf_v2\n")
    expected_text = expected_header + new_symbol
    _, _, _, new_text = self._upgrade(text, import_rename=True)
    self.assertEqual(new_text, expected_text)

    import_header = ("import tensorflow.compat.v1 as tf\n"
                     "import tensorflow.compat.v1 as tf_v1\n"
                     "import tensorflow.compat.v2 as tf_v2\n")
    text = import_header + old_symbol
    expected_header = ("import tensorflow.compat.v2 as tf\n"
                       "import tensorflow.compat.v1 as tf_v1\n"
                       "import tensorflow.compat.v2 as tf_v2\n")
    expected_text = expected_header + new_symbol
    _, _, _, new_text = self._upgrade(
        text, import_rename=True, upgrade_compat_v1_import=True)
    self.assertEqual(new_text, expected_text)

    import_header = ("import tensorflow.compat.v1 as tf\n"
                     "import tensorflow.compat.v1 as tf_v1\n"
                     "import tensorflow.compat.v2 as tf_v2\n")
    text = import_header + old_symbol
    expected_header = ("import tensorflow as tf\n"
                       "import tensorflow.compat.v1 as tf_v1\n"
                       "import tensorflow.compat.v2 as tf_v2\n")
    expected_text = expected_header + new_symbol
    _, _, _, new_text = self._upgrade(
        text, import_rename=False, upgrade_compat_v1_import=True)
    self.assertEqual(new_text, expected_text)

    import_header = "from tensorflow import foo\n"
    text = import_header + old_symbol
    expected_text = "from tensorflow.compat.v2 import foo\n" + new_symbol
    _, unused_report, unused_errors, new_text = self._upgrade(
        text, import_rename=True)
    self.assertEqual(new_text, expected_text)

    import_header = "from tensorflow import *\n"
    text = import_header + old_symbol
    expected_text = "from tensorflow.compat.v2 import *\n" + new_symbol
    _, unused_report, unused_errors, new_text = self._upgrade(
        text, import_rename=True)
    self.assertEqual(new_text, expected_text)

    import_header = "from tensorflow.foo import bar\n"
    text = import_header + old_symbol
    expected_text = "from tensorflow.compat.v2.foo import bar\n" + new_symbol
    _, unused_report, unused_errors, new_text = self._upgrade(
        text, import_rename=True)
    self.assertEqual(new_text, expected_text)

    import_header = ("from tensorflow import foo as tf\n"
                     "from tensorflow.compat import v1 as tf_v1\n"
                     "from tensorflow.compat import v2 as tf_v2\n")
    text = import_header + old_symbol
    expected_header = ("from tensorflow.compat.v2 import foo as tf\n"
                       "from tensorflow.compat import v1 as tf_v1\n"
                       "from tensorflow.compat import v2 as tf_v2\n")
    expected_text = expected_header + new_symbol
    _, _, _, new_text = self._upgrade(text, import_rename=True)
    self.assertEqual(new_text, expected_text)

  def test_import_analysis(self):
    old_symbol = "tf.conj(a)"
    new_symbol = "tf.math.conj(a)"

    # We upgrade the base un-versioned tensorflow aliased as tf
    import_header = "import tensorflow as tf\n"
    text = import_header + old_symbol
    expected_text = import_header + new_symbol
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

    import_header = ("import tensorflow as tf\n"
                     "import tensorflow.compat.v1 as tf_v1\n"
                     "import tensorflow.compat.v2 as tf_v2\n")
    text = import_header + old_symbol
    expected_text = import_header + new_symbol
    _, _, _, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)

    # We don't handle unaliased tensorflow imports currently,
    # so the upgrade script should log errors for them.
    import_header = "import tensorflow\n"
    text = import_header + old_symbol
    expected_text = import_header + old_symbol
    _, _, errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)
    self.assertIn("unaliased `import tensorflow`", "\n".join(errors))

    # Upgrading explicitly-versioned tf code is unsafe, but we don't
    # need to throw errors when we detect explicitly-versioned tf.
    import_header = "import tensorflow.compat.v1 as tf\n"
    text = import_header + old_symbol
    expected_text = import_header + old_symbol
    _, report, errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)
    self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
                  report)
    self.assertEmpty(errors)

    import_header = "from tensorflow.compat import v1 as tf\n"
    text = import_header + old_symbol
    expected_text = import_header + old_symbol
    _, report, errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)
    self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
                  report)
    self.assertEmpty(errors)

    import_header = "from tensorflow.compat import v1 as tf, v2 as tf2\n"
    text = import_header + old_symbol
    expected_text = import_header + old_symbol
    _, report, errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)
    self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
                  report)
    self.assertEmpty(errors)

    import_header = "import tensorflow.compat.v2 as tf\n"
    text = import_header + old_symbol
    expected_text = import_header + old_symbol
    _, report, errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)
    self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`",
                  report)
    self.assertEmpty(errors)

    import_header = "from tensorflow.compat import v1 as tf1, v2 as tf\n"
    text = import_header + old_symbol
    expected_text = import_header + old_symbol
    _, report, errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, expected_text)
    self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`",
                  report)
    self.assertEmpty(errors)

  def test_api_spec_reset_between_files(self):
    for old_symbol, new_symbol in [
        ("tf.conj(a)", "tf.math.conj(a)"),
        ("tf.to_int32(x)", "tf.cast(x, dtype=tf.int32)")]:

      ## Test that the api spec is reset in between files:
      import_header = "import tensorflow.compat.v2 as tf\n"
      text_a = import_header + old_symbol
      expected_text_a = import_header + old_symbol
      text_b = old_symbol
      expected_text_b = new_symbol
      results = self._upgrade_multiple([text_a, text_b])
      result_a, result_b = results[0], results[1]
      self.assertEqual(result_a[3], expected_text_a)
      self.assertEqual(result_b[3], expected_text_b)

  def test_model_to_estimator_checkpoint_warning(self):
    text = "tf.keras.estimator.model_to_estimator(model)"
    _, report, _, _ = self._upgrade(text)
    expected_info = "will save object-based checkpoints"
    self.assertIn(expected_info, report)

  def test_keras_experimental_export_warning(self):
    text = "tf.keras.experimental.export_saved_model"
    _, report, _, _ = self._upgrade(text)
    expected_info = "Please use model.save"
    self.assertIn(expected_info, report)


class TestUpgradeFiles(test_util.TensorFlowTestCase):

  def testInplace(self):
    """Check to make sure we don't have a file system race."""
    temp_file = tempfile.NamedTemporaryFile("w", delete=False)
    original = "tf.conj(a)\n"
    upgraded = "tf.math.conj(a)\n"
    temp_file.write(original)
    temp_file.close()
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
    upgrader.process_file(temp_file.name, temp_file.name)
    self.assertAllEqual(open(temp_file.name).read(), upgraded)
    os.unlink(temp_file.name)

  def testInplaceNoOutputChangeOnErrorHandling(self):
    """The in-place file should not be modified when a parsing error is handled."""
    temp_file = tempfile.NamedTemporaryFile("w", delete=False)
    original = "print 'a' \n"
    upgraded = "print 'a' \n"
    temp_file.write(original)
    temp_file.close()
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
    upgrader.process_file(
        temp_file.name, temp_file.name, no_change_to_outfile_on_error=True)
    self.assertAllEqual(open(temp_file.name).read(), upgraded)
    os.unlink(temp_file.name)

  def testInplaceEmptyOutputOnError(self):
    """The in-place file becomes empty when a parsing error is not handled."""
    temp_file = tempfile.NamedTemporaryFile("w", delete=False)
    original = "print 'a' \n"
    upgraded = ""
    temp_file.write(original)
    temp_file.close()
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
    upgrader.process_file(temp_file.name, temp_file.name)
    self.assertAllEqual(open(temp_file.name).read(), upgraded)
    os.unlink(temp_file.name)


if __name__ == "__main__":
  test_lib.main()
