• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for break_statements module."""
16
17from tensorflow.python.autograph.converters import break_statements
18from tensorflow.python.autograph.core import converter_testing
19from tensorflow.python.autograph.pyct import anno
20from tensorflow.python.platform import test
21
22
class BreakCanonicalizationTest(converter_testing.TestCase):
  """Tests for the break_statements converter.

  Most tests convert a small Python function with `break_statements` alone and
  check that the converted function returns the same result as the original
  Python function for a few inputs.
  """

  def assertTransformedEquivalent(self, f, *inputs):
    """Asserts that converting `f` does not change its result on `inputs`.

    Args:
      f: The Python function to convert with `break_statements`.
      *inputs: Positional arguments passed to both the original and the
        converted function.
    """
    tr = self.transform(f, break_statements)
    self.assertEqual(f(*inputs), tr(*inputs))

  def _assert_loop_preserves_directives(self, f):
    """Asserts that a directive annotation on a loop survives conversion.

    Args:
      f: A function whose body is a single loop statement containing a
        `break`.
    """
    _, node, ctx = self.transform(f, (), include_ast=True)
    fake_annotation = object()
    anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
    node = break_statements.transform(node, ctx)

    # The converter inserts a statement ahead of the loop (presumably the
    # initialization of the break control variable -- TODO confirm), so the
    # loop that was body[0] is expected at body[1] afterwards.
    self.assertIs(
        anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)

  def test_while_loop(self):

    def f(x):
      v = []
      while x > 0:
        x -= 1
        if x % 2 == 0:
          break
        v.append(x)
      return v

    # x == 0 never enters the loop body.
    self.assertTransformedEquivalent(f, 0)
    self.assertTransformedEquivalent(f, 1)
    self.assertTransformedEquivalent(f, 4)

  def test_while_loop_preserves_directives(self):

    def f(x):
      while x > 0:
        x -= 1
        if x % 2 == 0:
          break

    self._assert_loop_preserves_directives(f)

  def test_for_loop(self):

    def f(a):
      v = []
      for x in a:
        x -= 1
        if x % 2 == 0:
          break
        v.append(x)
      return v

    tr = self.transform(f, break_statements)

    # Plain Python stops at the first element (5 - 1 == 4 is even).
    self.assertEqual([], f([5, 4]))
    # The break is incompletely canonicalized for `for` loops when this
    # converter runs alone: the rest of the iteration that hit the break is
    # skipped, but the loop keeps going, so the second element (4 - 1 == 3)
    # is still appended. Presumably the early exit is applied by a later
    # control-flow conversion stage -- TODO confirm.
    self.assertEqual([3], tr([5, 4]))

  def test_for_loop_preserves_directives(self):

    def f(a):
      for x in a:
        if x % 2 == 0:
          break

    self._assert_loop_preserves_directives(f)

  def test_nested(self):

    def f(x):
      v = []
      u = []
      w = []
      while x > 0:
        x -= 1
        if x % 2 == 0:
          if x % 3 != 0:
            u.append(x)
          else:
            w.append(x)
            break
        v.append(x)
      return v, u, w

    self.assertTransformedEquivalent(f, 0)
    self.assertTransformedEquivalent(f, 3)
    self.assertTransformedEquivalent(f, 11)

  def test_nested_loops(self):

    def f(x):
      v = []
      u = []
      while x > 0:
        x -= 1
        y = x
        # A break in the inner loop must only exit the inner loop.
        while y > 0:
          y -= 1
          if y % 2 == 0:
            break
          u.append(y)
        if x == 0:
          break
        v.append(x)
      return v, u

    self.assertTransformedEquivalent(f, 0)
    self.assertTransformedEquivalent(f, 2)
    self.assertTransformedEquivalent(f, 3)
    self.assertTransformedEquivalent(f, 5)

  def test_loop_orelse(self):

    def f(x):
      v = []
      u = []
      while x > 0:
        x -= 1
        y = x
        # The inner `else` runs only when the inner break is not taken; the
        # break inside that `else` exits the outer loop.
        while y > 1:
          break
        else:
          u.append(y)
          break
        v.append(x)
      return v, u

    self.assertTransformedEquivalent(f, 0)
    self.assertTransformedEquivalent(f, 2)
    self.assertTransformedEquivalent(f, 3)

  def test_multiple_correlated_breaks_with_side_effects(self):

    def f(cond1):
      lst = []
      while True:
        if cond1:
          lst.append(1)
        else:
          break
        if lst[-1] > 0:  # lst always has an element here
          break
      return lst

    self.assertTransformedEquivalent(f, True)
    self.assertTransformedEquivalent(f, False)
167
168
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's test runner
  # (`tensorflow.python.platform.test`).
  test.main()
171