# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for directives module."""

from tensorflow.python.autograph.converters import directives as directives_converter
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.platform import test


class DirectivesTest(converter_testing.TestCase):
  """Checks where the directives converter attaches directive arguments."""

  def _transform_with_directives(self, fn):
    """Runs the directives converter over `fn` and returns its converted AST.

    Args:
      fn: the Python function to convert.

    Returns:
      The AST node produced by `self.transform(..., include_ast=True)`; the
      transformed function and the transform context are discarded.
    """
    _, converted_node, _ = self.transform(
        fn, directives_converter, include_ast=True)
    return converted_node

  def test_local_target(self):

    def f():
      l = []
      string_var = 0
      directives.set_element_type(l, 'a', string_var)

    node = self._transform_with_directives(f)

    # The directive's arguments are looked up on the definition of the first
    # assignment target in the body of `f`, i.e. the local `l`.
    (definition,) = anno.getanno(node.body[0].targets[0],
                                 anno.Static.DEFINITIONS)
    element_type_args = definition.directives[directives.set_element_type]
    self.assertEqual(element_type_args['dtype'].value, 'a')
    self.assertEqual(element_type_args['shape'].id, 'string_var')

  def test_argument_target(self):

    def f(a):
      directives.set_element_type(a, 1, shape=2)
      pass

    node = self._transform_with_directives(f)

    # Here the target is a function parameter rather than a local.
    first_arg = node.args.args[0]
    (definition,) = anno.getanno(first_arg, anno.Static.DEFINITIONS)
    element_type_args = definition.directives[directives.set_element_type]
    self.assertEqual(element_type_args['dtype'].value, 1)
    self.assertEqual(element_type_args['shape'].value, 2)

  def test_loop_target(self):

    def f():
      a = True
      while True:
        directives.set_loop_options(parallel_iterations=10, back_prop=a)  # pylint: disable=unexpected-keyword-arg
        pass

    node = self._transform_with_directives(f)

    # node.body[1] is the `while` statement that follows `a = True`.
    loop_directives = anno.getanno(node.body[1], anno.Basic.DIRECTIVES)
    loop_options = loop_directives[directives.set_loop_options]
    self.assertEqual(loop_options['parallel_iterations'].value, 10)
    self.assertEqual(loop_options['back_prop'].id, 'a')
    # Options that were not passed must not show up in the annotation.
    self.assertNotIn('swap_memory', loop_options)

  def test_loop_target_no_loop(self):

    def f():
      directives.set_loop_options()
      pass

    with self.assertRaisesRegex(ValueError, 'must be used inside a statement'):
      self._transform_with_directives(f)

  def test_loop_target_not_first(self):

    def f():
      a = 1
      while True:
        a = 2
        directives.set_loop_options(parallel_iterations=10, back_prop=a)  # pylint: disable=unexpected-keyword-arg

    with self.assertRaisesRegex(ValueError, 'must be the first statement'):
      self._transform_with_directives(f)

  def test_value_verification_does_not_trigger_properties(self):

    outer_test = self

    class TestClass:

      @property
      def b(self):
        # Evaluating this property during conversion fails the test.
        outer_test.fail('This should never be evaluated')

    tc = TestClass()

    def f():
      return tc.b + 1

    self.assertIsNotNone(self._transform_with_directives(f))

  def test_value_verification_does_not_trigger_getattr(self):

    class TestClass:

      def __init__(self):
        self.getattr_called = False

      def __getattr__(self, _):
        # Exceptions raised here appear to be absorbed by hasattr, so the
        # access is recorded in a flag instead of calling test.fail or raising.
        self.getattr_called = True

    tc = TestClass()

    def f():
      return tc.b + 1

    converted = self._transform_with_directives(f)

    self.assertIsNotNone(converted)
    self.assertFalse(tc.getattr_called)


if __name__ == '__main__':
  test.main()