# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for break_statements module."""

from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.platform import test


class BreakCanonicalizationTest(converter_testing.TestCase):
  """Tests for the `break_statements` converter.

  Most tests run a plain Python function and its `break_statements`-transformed
  version on the same inputs and require identical results. The
  `*_preserves_directives` tests instead check that an annotation attached to
  a loop node is still present on that loop after the transformation.

  NOTE: the inner functions `f` are the ASTs under test; do not add docstrings
  or extra statements to them, as that shifts the `node.body` indices the
  directive tests rely on.
  """

  def assertTransformedEquivalent(self, f, *inputs):
    """Asserts that `f` and its transformed version return equal results.

    Args:
      f: the function to run through `break_statements`.
      *inputs: positional arguments passed unchanged to both `f` and the
        transformed function.
    """
    tr = self.transform(f, break_statements)
    self.assertEqual(f(*inputs), tr(*inputs))

  def test_while_loop(self):
    """`break` in a `while` loop: never entered, first iteration, later."""

    def f(x):
      v = []
      while x > 0:
        x -= 1
        if x % 2 == 0:
          break
        v.append(x)
      return v

    self.assertTransformedEquivalent(f, 0)
    self.assertTransformedEquivalent(f, 1)
    self.assertTransformedEquivalent(f, 4)

  def test_while_loop_preserves_directives(self):
    """An annotation on a `while` loop survives the transformation.

    The sentinel is attached to `node.body[0]` (the loop). After conversion
    the same object is expected on `node.body[1]`, i.e. the converter is
    expected to insert one statement ahead of the loop (presumably the break
    flag initialization -- see break_statements).
    """

    def f(x):
      while x > 0:
        x -= 1
        if x % 2 == 0:
          break

    _, node, ctx = self.transform(f, (), include_ast=True)
    fake_annotation = object()
    anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
    node = break_statements.transform(node, ctx)

    self.assertIs(
        anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)

  def test_for_loop(self):
    """`break` in a `for` loop over a Python list."""

    def f(a):
      v = []
      for x in a:
        x -= 1
        if x % 2 == 0:
          break
        v.append(x)
      return v

    # When no `break` executes, the transformed function must match `f`
    # exactly: an empty iterable, and elements that all become odd.
    self.assertTransformedEquivalent(f, [])
    self.assertTransformedEquivalent(f, [2, 4])

    tr = self.transform(f, break_statements)

    # Deliberately not assertTransformedEquivalent: plain `f([5, 4])` returns
    # [] (5 - 1 is even, so it breaks on the first element), whereas the
    # transformed function is expected to return [3]. The rest of the
    # iteration containing the `break` is skipped, but the Python `for` loop
    # itself keeps iterating.
    # NOTE(review): presumably the loop-exit check for `for` loops is left for
    # a later conversion pass to apply -- confirm against break_statements.
    self.assertEqual([3], tr([5, 4]))

  def test_for_loop_preserves_directives(self):
    """An annotation on a `for` loop survives the transformation.

    As in the `while` variant, the sentinel set on `node.body[0]` is expected
    on `node.body[1]` after conversion.
    """

    def f(a):
      for x in a:
        if x % 2 == 0:
          break

    _, node, ctx = self.transform(f, (), include_ast=True)
    fake_annotation = object()
    anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
    node = break_statements.transform(node, ctx)
    self.assertIs(
        anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)

  def test_nested(self):
    """`break` nested inside `if`/`else` branches within a `while` loop."""

    def f(x):
      v = []
      u = []
      w = []
      while x > 0:
        x -= 1
        if x % 2 == 0:
          if x % 3 != 0:
            u.append(x)
          else:
            w.append(x)
            break
        v.append(x)
      return v, u, w

    self.assertTransformedEquivalent(f, 0)
    self.assertTransformedEquivalent(f, 3)
    self.assertTransformedEquivalent(f, 11)

  def test_nested_loops(self):
    """Breaks in an inner and an outer loop each apply to their own loop."""

    def f(x):
      v = []
      u = []
      while x > 0:
        x -= 1
        y = x
        while y > 0:
          y -= 1
          if y % 2 == 0:
            break
          u.append(y)
        if x == 0:
          break
        v.append(x)
      return v, u

    self.assertTransformedEquivalent(f, 0)
    self.assertTransformedEquivalent(f, 2)
    self.assertTransformedEquivalent(f, 3)
    self.assertTransformedEquivalent(f, 5)

  def test_loop_orelse(self):
    """A `break` in an inner loop's `else` clause exits the enclosing loop."""

    def f(x):
      v = []
      u = []
      while x > 0:
        x -= 1
        y = x
        while y > 1:
          break
        else:
          u.append(y)
          break
        v.append(x)
      return v, u

    self.assertTransformedEquivalent(f, 0)
    self.assertTransformedEquivalent(f, 2)
    self.assertTransformedEquivalent(f, 3)

  def test_multiple_correlated_breaks_with_side_effects(self):
    """Statements after a taken `break` must not run.

    With `cond1` false the first `break` is taken while `lst` is still empty;
    if the transformed code evaluated the second condition anyway, `lst[-1]`
    would raise IndexError.
    """

    def f(cond1):
      lst = []
      while True:
        if cond1:
          lst.append(1)
        else:
          break
        if lst[-1] > 0:  # lst always has an element here
          break
      return lst

    self.assertTransformedEquivalent(f, True)
    self.assertTransformedEquivalent(f, False)


if __name__ == '__main__':
  test.main()