• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ast_edits, which is used by the TF upgrade scripts.

All of the tests assume that we want to change from an API containing

    import foo as f

    def f(a, b, kw1, kw2): ...
    def g(a, b, kw1, c, kw1_alias): ...
    def g2(a, b, kw1, c, d, kw1_alias): ...
    def h(a, kw1, kw2, kw1_alias, kw2_alias): ...

and the changes to the API consist of renaming, reordering, and/or removing
arguments. Thus, we want to be able to generate changes to produce each of the
following new APIs:

    import bar as f

    def f(a, b, kw1, kw3): ...
    def f(a, b, kw2, kw1): ...
    def f(a, b, kw3, kw1): ...
    def g(a, b, kw1, c): ...
    def g(a, b, c, kw1): ...
    def g2(a, b, kw1, c, d): ...
    def g2(a, b, c, d, kw1): ...
    def h(a, kw1, kw2): ...

Further tests cover module deprecations, function warnings, and import
renaming (foo -> bar), including how symlinks are handled when upgrading a tree.
"""
42
43import ast
44import io
45import os
46
47from tensorflow.python.framework import test_util
48from tensorflow.python.platform import test as test_lib
49from tensorflow.tools.compatibility import ast_edits
50
51
class ModuleDeprecationSpec(ast_edits.NoUpdateSpec):
  """Spec that reports every use of module `a.b` as an ERROR."""

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    # (severity, message) pair reported for members accessed through a.b.
    deprecation = (ast_edits.ERROR, "a.b is evil.")
    self.module_deprecations["a.b"] = deprecation
58
59
class RenameKeywordSpec(ast_edits.NoUpdateSpec):
  """Spec that renames keyword argument `kw2` of `f` to `kw3`.

  Resulting API:

    def f(a, b, kw1, kw3): ...

  """

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    self.update_renames()

  def update_renames(self):
    """Installs the kw2 -> kw3 keyword rename for `f`."""
    old_to_new = {"kw2": "kw3"}
    self.function_keyword_renames["f"] = old_to_new
75
76
class ReorderKeywordSpec(ast_edits.NoUpdateSpec):
  """Spec that moves `kw2` in front of `kw1` in `f`.

  Resulting API:

    def f(a, b, kw2, kw1): ...

  """

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    self.update_reorders()

  def update_reorders(self):
    """Registers the positional argument names of `f`."""
    # The list must follow the OLD (pre-upgrade) positional order.
    old_positional_order = ["a", "b", "kw1", "kw2"]
    self.function_reorders["f"] = old_positional_order
93
94
class ReorderAndRenameKeywordSpec(ReorderKeywordSpec, RenameKeywordSpec):
  """A specification where kw2 gets moved in front of kw1 and is changed to kw3.

  The new API is

    def f(a, b, kw3, kw1): ...

  """

  def __init__(self):
    # Run the shared NoUpdateSpec initialization exactly once, then layer on
    # both parents' updates. Chaining ReorderKeywordSpec.__init__ and
    # RenameKeywordSpec.__init__ would run NoUpdateSpec.__init__ twice; the
    # second run (which presumably resets the spec tables) would discard the
    # reorders installed by the first, forcing both updates to be re-applied.
    ast_edits.NoUpdateSpec.__init__(self)
    self.update_renames()
    self.update_reorders()
109
110
class RemoveDeprecatedAliasKeyword(ast_edits.NoUpdateSpec):
  """Spec that drops the `kw1_alias` keyword from both `g` and `g2`.

  Calls that still pass `kw1_alias` are rewritten to pass `kw1`. Resulting API:

    def g(a, b, kw1, c): ...
    def g2(a, b, kw1, c, d): ...

  """

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    # Each function gets its own rename table.
    for func_name in ("g", "g2"):
      self.function_keyword_renames[func_name] = {"kw1_alias": "kw1"}
125
126
class RemoveDeprecatedAliasAndReorderRest(RemoveDeprecatedAliasKeyword):
  """Spec that drops `kw1_alias` from `g`/`g2` and moves `kw1` to the end.

  Resulting API:

    def g(a, b, c, kw1): ...
    def g2(a, b, c, d, kw1): ...

  """

  def __init__(self):
    RemoveDeprecatedAliasKeyword.__init__(self)
    # Positional names are listed in the OLD (pre-upgrade) order.
    self.function_reorders.update({
        "g": ["a", "b", "kw1", "c"],
        "g2": ["a", "b", "kw1", "c", "d"],
    })
142
143
class RemoveMultipleKeywordArguments(ast_edits.NoUpdateSpec):
  """Spec that removes both keyword aliases of `h` in a single pass.

  `kw1_alias` becomes `kw1` and `kw2_alias` becomes `kw2`. Resulting API:

    def h(a, kw1, kw2): ...

  """

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    alias_renames = dict(kw1_alias="kw1", kw2_alias="kw2")
    self.function_keyword_renames["h"] = alias_renames
159
160
class RenameImports(ast_edits.NoUpdateSpec):
  """Spec that rewrites imports of `foo` to `bar`, except under `foo.baz`."""

  def __init__(self):
    ast_edits.NoUpdateSpec.__init__(self)
    foo_to_bar = ast_edits.ImportRename("bar", excluded_prefixes=["foo.baz"])
    self.import_renames = {"foo": foo_to_bar}
171
172
class TestAstEdits(test_util.TensorFlowTestCase):
  """End-to-end tests of ast_edits.ASTCodeUpgrader using this module's specs."""

  def _upgrade(self, spec, old_file_text):
    """Runs an ASTCodeUpgrader built from `spec` over in-memory source.

    Args:
      spec: The API change spec the upgrader applies.
      old_file_text: Python source text to upgrade.

    Returns:
      A pair `((count, report, errors), new_text)`: the triple returned by
      `process_opened_file`, and the upgraded source text.
    """
    in_file = io.StringIO(old_file_text)
    out_file = io.StringIO()
    upgrader = ast_edits.ASTCodeUpgrader(spec)
    count, report, errors = (
        upgrader.process_opened_file("test.py", in_file,
                                     "test_out.py", out_file))
    return (count, report, errors), out_file.getvalue()

  def testModuleDeprecation(self):
    """Each member accessed through deprecated module a.b yields an error."""
    text = "a.b.c(a.b.x)"
    (_, _, errors), new_text = self._upgrade(ModuleDeprecationSpec(), text)
    self.assertEqual(text, new_text)
    self.assertIn("Using member a.b.c", errors[0])
    self.assertIn("1:0", errors[0])
    # The second error must describe a.b.x (column 6), not repeat a.b.c.
    self.assertIn("Using member a.b.x", errors[1])
    self.assertIn("1:6", errors[1])

  def testNoTransformIfNothingIsSupplied(self):
    """An empty NoUpdateSpec leaves calls untouched."""
    text = "f(a, b, kw1=c, kw2=d)\n"
    _, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text)
    self.assertEqual(new_text, text)

    text = "f(a, b, c, d)\n"
    _, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text)
    self.assertEqual(new_text, text)

  def testKeywordRename(self):
    """Test that we get the expected result if renaming kw2 to kw3."""
    text = "f(a, b, kw1=c, kw2=d)\n"
    expected = "f(a, b, kw1=c, kw3=d)\n"
    (_, report, _), new_text = self._upgrade(RenameKeywordSpec(), text)
    self.assertEqual(new_text, expected)
    self.assertNotIn("Manual check required", report)

    # No keywords specified, no reordering, so we should get input as output
    text = "f(a, b, c, d)\n"
    (_, report, _), new_text = self._upgrade(RenameKeywordSpec(), text)
    self.assertEqual(new_text, text)
    self.assertNotIn("Manual check required", report)

    # Positional *args cannot be inspected, but a pure keyword rename does not
    # depend on argument positions, so no manual check is requested.
    text = "f(a, *args)\n"
    (_, report, _), _ = self._upgrade(RenameKeywordSpec(), text)
    self.assertNotIn("Manual check required", report)

    # **kwargs may hide the keyword being renamed, so this should warn.
    text = "f(a, b, kw1=c, **kwargs)\n"
    (_, report, _), _ = self._upgrade(RenameKeywordSpec(), text)
    self.assertIn("Manual check required", report)

  def testKeywordReorderWithParens(self):
    """Test that we get the expected result if there are parens around args."""
    text = "f((a), ( ( b ) ))\n"
    acceptable_outputs = [
        # No change is a valid output
        text,
        # Also cases where all arguments are fully specified are allowed
        "f(a=(a), b=( ( b ) ))\n",
        # Making the parens canonical is ok
        "f(a=(a), b=((b)))\n",
    ]
    _, new_text = self._upgrade(ReorderKeywordSpec(), text)
    self.assertIn(new_text, acceptable_outputs)

  def testKeywordReorder(self):
    """Test that we get the expected result if kw2 is now before kw1."""
    text = "f(a, b, kw1=c, kw2=d)\n"
    acceptable_outputs = [
        # No change is a valid output
        text,
        # Just reordering the kw.. args is also ok
        "f(a, b, kw2=d, kw1=c)\n",
        # Also cases where all arguments are fully specified are allowed
        "f(a=a, b=b, kw1=c, kw2=d)\n",
        "f(a=a, b=b, kw2=d, kw1=c)\n",
    ]
    (_, report, _), new_text = self._upgrade(ReorderKeywordSpec(), text)
    self.assertIn(new_text, acceptable_outputs)
    self.assertNotIn("Manual check required", report)

    # Keywords are reordered, so we should reorder arguments too
    text = "f(a, b, c, d)\n"
    acceptable_outputs = [
        "f(a, b, d, c)\n",
        "f(a=a, b=b, kw1=c, kw2=d)\n",
        "f(a=a, b=b, kw2=d, kw1=c)\n",
    ]
    (_, report, _), new_text = self._upgrade(ReorderKeywordSpec(), text)
    self.assertIn(new_text, acceptable_outputs)
    self.assertNotIn("Manual check required", report)

    # Positional *args passed in that we cannot inspect, should warn
    text = "f(a, b, *args)\n"
    (_, report, _), _ = self._upgrade(ReorderKeywordSpec(), text)
    self.assertIn("Manual check required", report)

    # **kwargs cannot be inspected, but a reorder only affects positional
    # arguments, so no manual check is requested.
    text = "f(a, b, kw1=c, **kwargs)\n"
    (_, report, _), _ = self._upgrade(ReorderKeywordSpec(), text)
    self.assertNotIn("Manual check required", report)

  def testKeywordReorderAndRename(self):
    """Test that we get the expected result if kw2 is renamed and moved."""
    text = "f(a, b, kw1=c, kw2=d)\n"
    acceptable_outputs = [
        "f(a, b, kw3=d, kw1=c)\n",
        "f(a=a, b=b, kw1=c, kw3=d)\n",
        "f(a=a, b=b, kw3=d, kw1=c)\n",
    ]
    (_, report, _), new_text = self._upgrade(
        ReorderAndRenameKeywordSpec(), text)
    self.assertIn(new_text, acceptable_outputs)
    self.assertNotIn("Manual check required", report)

    # Keywords are reordered, so we should reorder arguments too
    text = "f(a, b, c, d)\n"
    acceptable_outputs = [
        "f(a, b, d, c)\n",
        "f(a=a, b=b, kw1=c, kw3=d)\n",
        "f(a=a, b=b, kw3=d, kw1=c)\n",
    ]
    (_, report, _), new_text = self._upgrade(
        ReorderAndRenameKeywordSpec(), text)
    self.assertIn(new_text, acceptable_outputs)
    self.assertNotIn("Manual check required", report)

    # Positional *args cannot be inspected and matter for the reorder: warn.
    text = "f(a, *args, kw1=c)\n"
    (_, report, _), _ = self._upgrade(ReorderAndRenameKeywordSpec(), text)
    self.assertIn("Manual check required", report)

    # **kwargs cannot be inspected and matter for the rename: warn.
    text = "f(a, b, kw1=c, **kwargs)\n"
    (_, report, _), _ = self._upgrade(ReorderAndRenameKeywordSpec(), text)
    self.assertIn("Manual check required", report)

  def testRemoveDeprecatedKeywordAlias(self):
    """Test that we get the expected result if a keyword alias is removed."""
    text = "g(a, b, kw1=x, c=c)\n"
    acceptable_outputs = [
        # Not using deprecated alias, so original is ok
        text,
        "g(a=a, b=b, kw1=x, c=c)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasKeyword(), text)
    self.assertIn(new_text, acceptable_outputs)

    # No keyword used, should be no change
    text = "g(a, b, x, c)\n"
    _, new_text = self._upgrade(RemoveDeprecatedAliasKeyword(), text)
    self.assertEqual(new_text, text)

    # If we used the alias, it should get renamed
    text = "g(a, b, kw1_alias=x, c=c)\n"
    acceptable_outputs = [
        "g(a, b, kw1=x, c=c)\n",
        "g(a, b, c=c, kw1=x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
        "g(a=a, b=b, c=c, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasKeyword(), text)
    self.assertIn(new_text, acceptable_outputs)

    # It should get renamed even if it's last
    text = "g(a, b, c=c, kw1_alias=x)\n"
    acceptable_outputs = [
        "g(a, b, kw1=x, c=c)\n",
        "g(a, b, c=c, kw1=x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
        "g(a=a, b=b, c=c, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasKeyword(), text)
    self.assertIn(new_text, acceptable_outputs)

  def testRemoveDeprecatedKeywordAndReorder(self):
    """Test for when a keyword alias is removed and args are reordered."""
    text = "g(a, b, kw1=x, c=c)\n"
    acceptable_outputs = [
        "g(a, b, c=c, kw1=x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # Keywords are reordered, so we should reorder arguments too
    text = "g(a, b, x, c)\n"
    # Don't accept an output which doesn't move x (kw1) after c
    acceptable_outputs = [
        "g(a, b, c, x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # If we used the alias, it should get renamed
    text = "g(a, b, kw1_alias=x, c=c)\n"
    acceptable_outputs = [
        "g(a, b, kw1=x, c=c)\n",
        "g(a, b, c=c, kw1=x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
        "g(a=a, b=b, c=c, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # It should get renamed and reordered even if it's last
    text = "g(a, b, c=c, kw1_alias=x)\n"
    acceptable_outputs = [
        "g(a, b, kw1=x, c=c)\n",
        "g(a, b, c=c, kw1=x)\n",
        "g(a=a, b=b, kw1=x, c=c)\n",
        "g(a=a, b=b, c=c, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

  def testRemoveDeprecatedKeywordAndReorder2(self):
    """Same as testRemoveDeprecatedKeywordAndReorder but on g2 (more args)."""
    text = "g2(a, b, kw1=x, c=c, d=d)\n"
    acceptable_outputs = [
        "g2(a, b, c=c, d=d, kw1=x)\n",
        "g2(a=a, b=b, kw1=x, c=c, d=d)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # Keywords are reordered, so we should reorder arguments too
    text = "g2(a, b, x, c, d)\n"
    # Don't accept an output which doesn't move x (kw1) after c and d
    acceptable_outputs = [
        "g2(a, b, c, d, x)\n",
        "g2(a=a, b=b, kw1=x, c=c, d=d)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # If we used the alias, it should get renamed
    text = "g2(a, b, kw1_alias=x, c=c, d=d)\n"
    acceptable_outputs = [
        "g2(a, b, kw1=x, c=c, d=d)\n",
        "g2(a, b, c=c, d=d, kw1=x)\n",
        "g2(a=a, b=b, kw1=x, c=c, d=d)\n",
        "g2(a=a, b=b, c=c, d=d, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

    # It should get renamed and reordered even if it's not in order
    text = "g2(a, b, d=d, c=c, kw1_alias=x)\n"
    acceptable_outputs = [
        "g2(a, b, kw1=x, c=c, d=d)\n",
        "g2(a, b, c=c, d=d, kw1=x)\n",
        "g2(a, b, d=d, c=c, kw1=x)\n",
        "g2(a=a, b=b, kw1=x, c=c, d=d)\n",
        "g2(a=a, b=b, c=c, d=d, kw1=x)\n",
        "g2(a=a, b=b, d=d, c=c, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveDeprecatedAliasAndReorderRest(), text)
    self.assertIn(new_text, acceptable_outputs)

  def testRemoveMultipleKeywords(self):
    """Remove multiple keywords at once."""
    # Not using deprecated keywords -> no rename
    text = "h(a, kw1=x, kw2=y)\n"
    _, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
    self.assertEqual(new_text, text)

    # Using positional arguments (in proper order) -> no change
    text = "h(a, x, y)\n"
    _, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
    self.assertEqual(new_text, text)

    # Use only the old names, in order
    text = "h(a, kw1_alias=x, kw2_alias=y)\n"
    acceptable_outputs = [
        "h(a, x, y)\n",
        "h(a, kw1=x, kw2=y)\n",
        "h(a=a, kw1=x, kw2=y)\n",
        "h(a, kw2=y, kw1=x)\n",
        "h(a=a, kw2=y, kw1=x)\n",
    ]
    _, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
    self.assertIn(new_text, acceptable_outputs)

    # Use only the old names, in reverse order, should give one of same outputs
    text = "h(a, kw2_alias=y, kw1_alias=x)\n"
    _, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
    self.assertIn(new_text, acceptable_outputs)

    # Mix old and new names
    text = "h(a, kw1=x, kw2_alias=y)\n"
    _, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
    self.assertIn(new_text, acceptable_outputs)

  def testUnrestrictedFunctionWarnings(self):
    """A "*.foo" warning fires for any call of an attribute named foo."""

    class FooWarningSpec(ast_edits.NoUpdateSpec):
      """Usages of function attribute foo() prints out a warning."""

      def __init__(self):
        ast_edits.NoUpdateSpec.__init__(self)
        self.function_warnings = {"*.foo": (ast_edits.WARNING, "not good")}

    texts = ["object.foo()", "get_object().foo()", "object.foo().bar()"]
    for text in texts:
      with self.subTest(text=text):
        (_, report, _), _ = self._upgrade(FooWarningSpec(), text)
        self.assertIn("not good", report)

    # No warning for a bare name foo (not an attribute), for foo used only as
    # a prefix (foo.bar()), for a differently named attribute (run_foo), or
    # for an attribute foo that is referenced but not called (obj.foo).
    false_alarms = ["foo", "foo()", "foo.bar()", "obj.run_foo()", "obj.foo"]
    for text in false_alarms:
      with self.subTest(text=text):
        (_, report, _), _ = self._upgrade(FooWarningSpec(), text)
        self.assertNotIn("not good", report)

  def testFullNameNode(self):
    """full_name_node builds nested Attribute/Name nodes for a dotted name."""
    t = ast_edits.full_name_node("a.b.c")
    self.assertEqual(
        ast.dump(t),
        "Attribute(value=Attribute(value=Name(id='a', ctx=Load()), attr='b', "
        "ctx=Load()), attr='c', ctx=Load())")

  def testImport(self):
    """`import foo...` statements are rewritten to import from bar."""
    # foo should be renamed to bar.
    text = "import foo as f"
    expected_text = "import bar as f"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    # The local name foo is preserved via an alias.
    text = "import foo"
    expected_text = "import bar as foo"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "import foo.test"
    expected_text = "import bar.test"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "import foo.test as t"
    expected_text = "import bar.test as t"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "import foo as f, a as b"
    expected_text = "import bar as f, a as b"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

  def testFromImport(self):
    """`from foo... import` statements are rewritten to import from bar."""
    # foo should be renamed to bar.
    text = "from foo import a"
    expected_text = "from bar import a"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "from foo.a import b"
    expected_text = "from bar.a import b"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "from foo import *"
    expected_text = "from bar import *"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    text = "from foo import a, b"
    expected_text = "from bar import a, b"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

  def testImport_NoChangeNeeded(self):
    """An import that already uses bar is left alone."""
    text = "import bar as b"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

  def testFromImport_NoChangeNeeded(self):
    """A from-import that already uses bar is left alone."""
    text = "from bar import a as b"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

  def testExcludedImport(self):
    """Imports under the excluded prefix foo.baz are not rewritten."""
    # foo.baz module is excluded from changes.
    text = "import foo.baz"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

    text = "import foo.baz as a"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

    text = "from foo import baz as a"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

    text = "from foo.baz import a"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(text, new_text)

  def testMultipleImports(self):
    """Mixed excluded/renamed names in one statement are handled per name."""
    text = "import foo.bar as a, foo.baz as b, foo.baz.c, foo.d"
    expected_text = "import bar.bar as a, foo.baz as b, foo.baz.c, bar.d"
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

    # A from-import mixing excluded and renamed names is split in two.
    text = "from foo import baz, a, c"
    expected_text = """from foo import baz
from bar import a, c"""
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

  def testImportInsideFunction(self):
    """Imports nested in a function body are split and keep their indent."""
    text = """
def t():
  from c import d
  from foo import baz, a
  from e import y
"""
    expected_text = """
def t():
  from c import d
  from foo import baz
  from bar import a
  from e import y
"""
    _, new_text = self._upgrade(RenameImports(), text)
    self.assertEqual(expected_text, new_text)

  def testUpgradeInplaceWithSymlink(self):
    """In-place upgrade rewrites an in-tree symlink target, keeping the link."""
    if os.name == "nt":
      self.skipTest("os.symlink doesn't work uniformly on Windows.")

    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
    os.mkdir(upgrade_dir)
    file_a = os.path.join(upgrade_dir, "a.py")
    file_b = os.path.join(upgrade_dir, "b.py")

    with open(file_a, "w") as f:
      f.write("import foo as f")
    os.symlink(file_a, file_b)

    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
    upgrader.process_tree_inplace(upgrade_dir)

    self.assertTrue(os.path.islink(file_b))
    self.assertEqual(file_a, os.readlink(file_b))
    with open(file_a, "r") as f:
      self.assertEqual("import bar as f", f.read())

  def testUpgradeInPlaceWithSymlinkInDifferentDir(self):
    """In-place upgrade does not touch symlink targets outside the tree."""
    if os.name == "nt":
      self.skipTest("os.symlink doesn't work uniformly on Windows.")

    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
    other_dir = os.path.join(self.get_temp_dir(), "bar")
    os.mkdir(upgrade_dir)
    os.mkdir(other_dir)
    file_c = os.path.join(other_dir, "c.py")
    file_d = os.path.join(upgrade_dir, "d.py")

    with open(file_c, "w") as f:
      f.write("import foo as f")
    os.symlink(file_c, file_d)

    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
    upgrader.process_tree_inplace(upgrade_dir)

    self.assertTrue(os.path.islink(file_d))
    self.assertEqual(file_c, os.readlink(file_d))
    # File pointed to by symlink is in a different directory.
    # Therefore, it should not be upgraded.
    with open(file_c, "r") as f:
      self.assertEqual("import foo as f", f.read())

  def testUpgradeCopyWithSymlink(self):
    """Copying a tree recreates in-tree symlinks against the copied target."""
    if os.name == "nt":
      self.skipTest("os.symlink doesn't work uniformly on Windows.")

    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
    output_dir = os.path.join(self.get_temp_dir(), "bar")
    os.mkdir(upgrade_dir)
    file_a = os.path.join(upgrade_dir, "a.py")
    file_b = os.path.join(upgrade_dir, "b.py")

    with open(file_a, "w") as f:
      f.write("import foo as f")
    os.symlink(file_a, file_b)

    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
    upgrader.process_tree(upgrade_dir, output_dir, copy_other_files=True)

    new_file_a = os.path.join(output_dir, "a.py")
    new_file_b = os.path.join(output_dir, "b.py")
    self.assertTrue(os.path.islink(new_file_b))
    self.assertEqual(new_file_a, os.readlink(new_file_b))
    with open(new_file_a, "r") as f:
      self.assertEqual("import bar as f", f.read())

  def testUpgradeCopyWithSymlinkInDifferentDir(self):
    """Copying keeps links to out-of-tree files and leaves those unmodified."""
    if os.name == "nt":
      self.skipTest("os.symlink doesn't work uniformly on Windows.")

    upgrade_dir = os.path.join(self.get_temp_dir(), "foo")
    other_dir = os.path.join(self.get_temp_dir(), "bar")
    output_dir = os.path.join(self.get_temp_dir(), "baz")
    os.mkdir(upgrade_dir)
    os.mkdir(other_dir)
    file_a = os.path.join(other_dir, "a.py")
    file_b = os.path.join(upgrade_dir, "b.py")

    with open(file_a, "w") as f:
      f.write("import foo as f")
    os.symlink(file_a, file_b)

    upgrader = ast_edits.ASTCodeUpgrader(RenameImports())
    upgrader.process_tree(upgrade_dir, output_dir, copy_other_files=True)

    new_file_b = os.path.join(output_dir, "b.py")
    self.assertTrue(os.path.islink(new_file_b))
    self.assertEqual(file_a, os.readlink(new_file_b))
    with open(file_a, "r") as f:
      self.assertEqual("import foo as f", f.read())
699
if __name__ == "__main__":
  # Entry point when this file is executed directly: hand off to the test
  # main imported from tensorflow.python.platform.
  test_lib.main()
702