• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the autograph directives converter module."""
16
17from tensorflow.python.autograph.converters import directives as directives_converter
18from tensorflow.python.autograph.core import converter_testing
19from tensorflow.python.autograph.lang import directives
20from tensorflow.python.autograph.pyct import anno
21from tensorflow.python.platform import test
22
23
class DirectivesTest(converter_testing.TestCase):
  """Tests for the autograph directives converter.

  Each test defines a small function `f`, runs the directives converter over
  it via `self.transform(..., include_ast=True)` and inspects the annotations
  left on the resulting AST (`anno.Static.DEFINITIONS`,
  `anno.Basic.DIRECTIVES`) or the error raised.

  The bodies of the inner `f` functions are test data: assertions index into
  their statements (e.g. `node.body[0]`, `node.body[1]`) and some tests check
  statement order, so do not reorder or add statements (including
  docstrings) inside them. The trailing `pass` statements are presumably there
  so a body is never left empty if the converter drops the directive call --
  TODO confirm before removing them.
  """

  def test_local_target(self):
    """set_element_type on a local variable annotates that variable."""

    def f():
      l = []
      string_var = 0
      directives.set_element_type(l, 'a', string_var)

    _, node, _ = self.transform(f, directives_converter, include_ast=True)

    # node.body[0] is the `l = []` assignment. Its target must carry exactly
    # one definition (the single-element unpacking enforces this).
    (def_,) = anno.getanno(node.body[0].targets[0],
                           anno.Static.DEFINITIONS)
    d = def_.directives[directives.set_element_type]
    # Positional call arguments are recorded under their parameter names.
    self.assertEqual(d['dtype'].value, 'a')
    self.assertEqual(d['shape'].id, 'string_var')

  def test_argument_target(self):
    """set_element_type on a function argument annotates that argument."""

    def f(a):
      directives.set_element_type(a, 1, shape=2)
      pass

    _, node, _ = self.transform(f, directives_converter, include_ast=True)

    (def_,) = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
    d = def_.directives[directives.set_element_type]
    # Positional (`dtype`) and keyword (`shape`) arguments are both recorded.
    self.assertEqual(d['dtype'].value, 1)
    self.assertEqual(d['shape'].value, 2)

  def test_loop_target(self):
    """set_loop_options as a loop's first statement annotates the loop."""

    def f():
      a = True
      while True:
        directives.set_loop_options(parallel_iterations=10, back_prop=a)  # pylint: disable=unexpected-keyword-arg
        pass

    _, node, _ = self.transform(f, directives_converter, include_ast=True)

    # node.body[1] is the `while` statement (node.body[0] is `a = True`).
    loop_directives = anno.getanno(node.body[1], anno.Basic.DIRECTIVES)
    d = loop_directives[directives.set_loop_options]
    self.assertEqual(d['parallel_iterations'].value, 10)
    self.assertEqual(d['back_prop'].id, 'a')
    # An option that was not passed in the call is not recorded.
    self.assertNotIn('swap_memory', d)

  def test_loop_target_no_loop(self):
    """set_loop_options outside of any loop is rejected."""

    def f():
      directives.set_loop_options()
      pass

    with self.assertRaisesRegex(ValueError, 'must be used inside a statement'):
      self.transform(f, directives_converter, include_ast=True)

  def test_loop_target_not_first(self):
    """set_loop_options that is not the loop's first statement is rejected."""

    def f():
      a = 1
      while True:
        a = 2
        directives.set_loop_options(parallel_iterations=10, back_prop=a)  # pylint: disable=unexpected-keyword-arg

    with self.assertRaisesRegex(ValueError, 'must be the first statement'):
      self.transform(f, directives_converter, include_ast=True)

  def test_value_verification_does_not_trigger_properties(self):
    """Converting `f` must not evaluate the `b` property it references."""

    # The property below runs with `self` bound to the TestClass instance,
    # so keep a handle to the test case for fail().
    self_test = self

    class TestClass(object):

      def __init__(self):
        self.property_evaluated = False

      @property
      def b(self):
        # Record the evaluation before failing. If the code under test catches
        # exceptions broadly while probing values, the AssertionError from
        # fail() would be swallowed and the test would pass vacuously; the
        # flag is checked after conversion and catches that case.
        self.property_evaluated = True
        self_test.fail('This should never be evaluated')

    tc = TestClass()

    def f():
      return tc.b + 1

    _, node, _ = self.transform(f, directives_converter, include_ast=True)

    self.assertIsNotNone(node)
    self.assertFalse(tc.property_evaluated)

  def test_value_verification_does_not_trigger_getattr(self):
    """Converting `f` must not fall back to `__getattr__` for `tc.b`."""

    class TestClass(object):

      def __init__(self):
        self.getattr_called = False

      def __getattr__(self, _):
        # Don't fail or raise here: hasattr() treats an AttributeError as
        # "attribute missing" and returns False instead of propagating it, and
        # the code under test may absorb other exceptions as well. Recording
        # the call in a flag is detected regardless.
        self.getattr_called = True

    tc = TestClass()

    def f():
      return tc.b + 1

    _, node, _ = self.transform(f, directives_converter, include_ast=True)

    self.assertIsNotNone(node)
    self.assertFalse(tc.getattr_called)
131
if __name__ == '__main__':
  # When this file is executed directly as a script, hand control to
  # TensorFlow's `test.main()` runner.
  test.main()
134