• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for origin_info module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.autograph.pyct import anno
22from tensorflow.python.autograph.pyct import compiler
23from tensorflow.python.autograph.pyct import origin_info
24from tensorflow.python.autograph.pyct import parser
25from tensorflow.python.platform import test
26
27
class OriginInfoTest(test.TestCase):
  """Tests for `origin_info.create_source_map` and `origin_info.resolve`."""

  def _assert_origin(self, annotated, lineno, col_offset, line, comment):
    """Checks the `anno.Basic.ORIGIN` annotation stored on `annotated`.

    Args:
      annotated: AST node from which the ORIGIN annotation is read.
      lineno: expected value of `origin.loc.lineno`.
      col_offset: expected value of `origin.loc.col_offset`.
      line: expected value of `origin.source_code_line`.
      comment: expected value of `origin.comment`; passing None asserts that
        `origin.comment` is None.
    """
    info = anno.getanno(annotated, anno.Basic.ORIGIN)
    self.assertEqual(info.loc.lineno, lineno)
    self.assertEqual(info.loc.col_offset, col_offset)
    self.assertEqual(info.source_code_line, line)
    if comment is None:
      self.assertIsNone(info.comment)
    else:
      self.assertEqual(info.comment, comment)

  def test_create_source_map(self):

    def test_fn(x):
      return x + 1

    # Attach a hand-built origin to the first statement of the parsed function.
    fn_node, _, _ = parser.parse_entity(test_fn)
    fake_loc = origin_info.Location('fake_filename', 3, 7)
    expected_origin = origin_info.OriginInfo(
        loc=fake_loc,
        function_name='fake_function_name',
        source_code_line='fake source line',
        comment=None)
    anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, expected_origin)

    generated_source = compiler.ast_to_source(fn_node)
    mapping = origin_info.create_source_map(
        fn_node, generated_source, 'test_filename', [0])

    # The map must return the very same origin object (identity, not equality)
    # under the ('test_filename', 2) key.
    key = origin_info.LineLocation('test_filename', 2)
    self.assertIn(key, mapping)
    self.assertIs(mapping[key], expected_origin)

  def test_source_map_no_origin(self):

    def test_fn(x):
      return x + 1

    # No ORIGIN annotation is set on any node here, so the map must be empty.
    fn_node, _, _ = parser.parse_entity(test_fn)
    generated_source = compiler.ast_to_source(fn_node)

    mapping = origin_info.create_source_map(
        fn_node, generated_source, 'test_filename', [0])

    self.assertEqual(len(mapping), 0)

  def test_resolve(self):

    def test_fn(x):
      """Docstring."""
      return x  # comment

    fn_node, fn_source, _ = parser.parse_entity(test_fn)
    origin_info.resolve(fn_node, fn_source)

    # Function header, then the docstring, then the commented return statement.
    self._assert_origin(
        fn_node, lineno=1, col_offset=0,
        line='def test_fn(x):', comment=None)
    self._assert_origin(
        fn_node.body[0], lineno=2, col_offset=2,
        line='  """Docstring."""', comment=None)
    self._assert_origin(
        fn_node.body[1], lineno=3, col_offset=2,
        line='  return x  # comment', comment='comment')

  # NOTE(review): the name lacks the `test` prefix, so the default unittest
  # loader presumably does not collect this method -- confirm why it was
  # disabled before re-enabling. The expected line numbers are shifted by one
  # relative to test_resolve.
  def disabled_test_resolve_with_future_imports(self):

    def test_fn(x):
      """Docstring."""
      print(x)
      return x  # comment

    fn_node, fn_source, _ = parser.parse_entity(test_fn)
    origin_info.resolve(fn_node, fn_source)

    # Function header, then the docstring, then the return (third statement).
    self._assert_origin(
        fn_node, lineno=2, col_offset=0,
        line='def test_fn(x):', comment=None)
    self._assert_origin(
        fn_node.body[0], lineno=3, col_offset=2,
        line='  """Docstring."""', comment=None)
    self._assert_origin(
        fn_node.body[2], lineno=5, col_offset=2,
        line='  return x  # comment', comment='comment')
120
121
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner only when run as a script,
  # never on import.
  test.main()
124