# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for origin_info module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test


class OriginInfoTest(test.TestCase):
  """Tests source-map creation and origin resolution in `origin_info`."""

  def _assert_origin(self, annotated, lineno, col_offset, source_line,
                     comment):
    """Checks the `anno.Basic.ORIGIN` annotation attached to `annotated`.

    Args:
      annotated: AST node that is expected to carry an ORIGIN annotation.
      lineno: expected value of `origin.loc.lineno`.
      col_offset: expected value of `origin.loc.col_offset`.
      source_line: expected value of `origin.source_code_line`.
      comment: expected value of `origin.comment`; pass None to assert that
        no comment was recorded.
    """
    origin = anno.getanno(annotated, anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, lineno)
    self.assertEqual(origin.loc.col_offset, col_offset)
    self.assertEqual(origin.source_code_line, source_line)
    if comment is None:
      self.assertIsNone(origin.comment)
    else:
      self.assertEqual(origin.comment, comment)

  def test_create_source_map(self):

    def test_fn(x):
      return x + 1

    root, _, _ = parser.parse_entity(test_fn)
    stub_origin = origin_info.OriginInfo(
        loc=origin_info.Location('fake_filename', 3, 7),
        function_name='fake_function_name',
        source_code_line='fake source line',
        comment=None)
    # Attach the stub origin to the function's only body statement (the
    # `return`), then regenerate source code from the annotated AST.
    anno.setanno(root.body[0], anno.Basic.ORIGIN, stub_origin)
    generated = compiler.ast_to_source(root)

    source_map = origin_info.create_source_map(
        root, generated, 'test_filename', [0])

    # The annotated `return` is expected on line 2 of the generated code, and
    # the map must hand back the very same OriginInfo object.
    expected_key = origin_info.LineLocation('test_filename', 2)
    self.assertIn(expected_key, source_map)
    self.assertIs(source_map[expected_key], stub_origin)

  def test_source_map_no_origin(self):

    def test_fn(x):
      return x + 1

    root, _, _ = parser.parse_entity(test_fn)
    generated = compiler.ast_to_source(root)

    # No ORIGIN annotation was set explicitly here, so the map is expected
    # to be empty.
    source_map = origin_info.create_source_map(
        root, generated, 'test_filename', [0])

    self.assertEqual(len(source_map), 0)

  def test_resolve(self):

    def test_fn(x):
      """Docstring."""
      return x  # comment

    root, source, _ = parser.parse_entity(test_fn)

    origin_info.resolve(root, source)

    # Line numbers are relative to the parsed entity's source, which starts
    # at the `def` line.
    self._assert_origin(
        root,
        lineno=1,
        col_offset=0,
        source_line='def test_fn(x):',
        comment=None)
    self._assert_origin(
        root.body[0],
        lineno=2,
        col_offset=2,
        source_line='  """Docstring."""',
        comment=None)
    self._assert_origin(
        root.body[1],
        lineno=3,
        col_offset=2,
        source_line='  return x  # comment',
        comment='comment')

  # Not collected by the test runner: the name does not start with `test`.
  def disabled_test_resolve_with_future_imports(self):

    def test_fn(x):
      """Docstring."""
      print(x)
      return x  # comment

    root, source, _ = parser.parse_entity(test_fn)

    origin_info.resolve(root, source)

    # NOTE(review): expected line numbers are one higher than in
    # test_resolve; presumably this accounts for a line prepended to the
    # parsed source (e.g. a `from __future__` import) — confirm.
    self._assert_origin(
        root,
        lineno=2,
        col_offset=0,
        source_line='def test_fn(x):',
        comment=None)
    self._assert_origin(
        root.body[0],
        lineno=3,
        col_offset=2,
        source_line='  """Docstring."""',
        comment=None)
    self._assert_origin(
        root.body[2],
        lineno=5,
        col_offset=2,
        source_line='  return x  # comment',
        comment='comment')


if __name__ == '__main__':
  test.main()