1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for control_flow module.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23from tensorflow.python.autograph.converters import control_flow 24from tensorflow.python.autograph.core import converter_testing 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import test_util 28from tensorflow.python.platform import test 29 30 31class ControlFlowTest(converter_testing.TestCase): 32 33 def assertTransformedResult(self, test_fn, inputs, expected, symbols=None): 34 if not isinstance(inputs, tuple): 35 inputs = (inputs,) 36 if not symbols: 37 symbols = {} 38 with self.converted(test_fn, control_flow, symbols, 39 constant_op.constant) as result: 40 self.assertEqual(self.evaluate(result.test_fn(*inputs)), expected) 41 42 @test_util.run_deprecated_v1 43 def test_while_basic(self): 44 45 def test_fn(n): 46 i = 0 47 s = 0 48 while i < n: 49 s += i 50 i += 1 51 return s, i, n 52 53 self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5)) 54 55 @test_util.run_deprecated_v1 56 def test_while_nested(self): 57 58 def test_fn(n): 59 i = 0 60 j = 0 61 s = 0 62 while i < n: 63 while j < i: 64 j += 3 65 u = i + j # 'u' is not defined within the inner loop 66 s += u 67 i += 1 68 j = 0 69 return s, i, j, n 70 71 self.assertTransformedResult(test_fn, constant_op.constant(5), 72 (25, 5, 0, 5)) 73 74 @test_util.run_deprecated_v1 75 def test_while_single_output(self): 76 77 def test_fn(n): 78 while n > 0: 79 n -= 1 80 return n 81 82 self.assertTransformedResult(test_fn, constant_op.constant(5), 0) 83 84 def test_while_local_composite(self): 85 86 class TestClass(object): 87 88 def __init__(self): 89 self.x = constant_op.constant(3) 90 91 def test_fn(n): 92 while n > 0: 93 tc = TestClass() 94 tc.x = tc.x 95 n -= 1 96 return n 97 98 self.assertTransformedResult( 99 test_fn, constant_op.constant(5), 0, symbols={'TestClass': TestClass}) 100 101 # TODO(b/127642077): Add tests for x.y.z = 2*x.y.z and x.y[z] = 2*x.y[z]. 102 def test_while_local_composite_complex_nestable(self): 103 104 # This class is ok to be in a tf.while_loop's state. 105 class TestClass(collections.namedtuple('TestClass', ('x'))): 106 pass 107 108 def test_fn(n): 109 tc = TestClass([constant_op.constant(0)]) 110 while n > 0: 111 tc = TestClass([constant_op.constant(3)]) 112 tc.x[0] = tc.x[0] + 1 113 n -= 1 114 return tc.x[0] 115 116 ns = {'TestClass': TestClass, 'constant_op': constant_op} 117 self.assertTransformedResult( 118 test_fn, constant_op.constant(5), 4, symbols=ns) 119 120 def test_while_local_composite_complex_illegal(self): 121 122 class TestClass(object): 123 124 def __init__(self): 125 self.x = [constant_op.constant(3)] 126 127 def test_fn(n): 128 while n > 0: 129 tc = TestClass() 130 tc.x[0] = tc.x[0] + 1 131 n -= 1 132 return tc.x[0] 133 134 with self.converted( 135 test_fn, control_flow, {'TestClass': TestClass}) as result: 136 # The tested function would require `tc` to become part of the while loop 137 # state, but TensorFlow doesn't support classes at the moment. 138 with self.assertRaisesRegexp(ValueError, 'must.*initialize.*Tensor.*tc'): 139 result.test_fn(constant_op.constant(5)) 140 141 @test_util.run_deprecated_v1 142 def test_while_dispatches_by_cond_only(self): 143 144 class TensorIncompatibleNumeric(object): 145 """Works in arithmetic expression, but errors out with TF ops.""" 146 147 def __init__(self, val): 148 self.val = val 149 150 def __add__(self, other): 151 return TensorIncompatibleNumeric(self.val + other) 152 153 def test_fn(n, s): 154 while n > 0: 155 n -= 1 156 s += n 157 return s 158 159 self.assertTransformedResult(test_fn, (constant_op.constant(5), 0), 10) 160 with self.converted(test_fn, control_flow, {}) as result: 161 # n alone controls the staging. When the loop is not staged, Python 162 # knows how to add the two objects. But when staged, tf.while_loop will 163 # not know how to deal with the TensorIncompatibleNumeric object. 164 self.assertEqual(result.test_fn(5, TensorIncompatibleNumeric(0)).val, 10) 165 with self.assertRaises(TypeError): 166 result.test_fn(constant_op.constant(5), TensorIncompatibleNumeric(0)) 167 168 @test_util.run_deprecated_v1 169 def test_if_basic(self): 170 171 def test_fn(n): 172 a = 0 173 b = 0 174 if n > 0: 175 a = -n 176 else: 177 b = 2 * n 178 return a, b 179 180 self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0)) 181 self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2)) 182 183 @test_util.run_deprecated_v1 184 def test_if_complex_outputs(self): 185 186 class TestClass(object): 187 188 def __init__(self, a, b): 189 self.a = a 190 self.b = b 191 192 def test_fn(n, obj): 193 obj.a = 0 194 obj.b = 0 195 if n > 0: 196 obj.a = -n 197 else: 198 obj.b = 2 * n 199 return obj 200 201 with self.converted(test_fn, control_flow, {}) as result: 202 res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0)) 203 self.assertEqual(self.evaluate((res_obj.a, res_obj.b)), (-1, 0)) 204 res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0)) 205 self.assertEqual(self.evaluate((res_obj.a, res_obj.b)), (0, -2)) 206 207 @test_util.run_deprecated_v1 208 def test_if_single_output(self): 209 210 def test_fn(n): 211 if n > 0: 212 n = -n 213 return n 214 215 self.assertTransformedResult(test_fn, constant_op.constant(1), -1) 216 217 @test_util.run_deprecated_v1 218 def test_if_semi(self): 219 220 def test_fn(n): 221 if n > 0: 222 n = 3 223 return n 224 225 self.assertTransformedResult(test_fn, constant_op.constant(2), 3) 226 self.assertTransformedResult(test_fn, constant_op.constant(-3), -3) 227 228 @test_util.run_deprecated_v1 229 def test_if_local_var(self): 230 231 def test_fn(n): 232 if n > 0: 233 b = 4 234 n = b + 1 235 return n 236 237 self.assertTransformedResult(test_fn, constant_op.constant(1), 5) 238 self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) 239 240 @test_util.run_deprecated_v1 241 def test_if_no_outputs(self): 242 243 def test_fn(n): 244 if n > 0: 245 b = 4 # pylint:disable=unused-variable 246 return n 247 248 # Without side effect guards, the if statement will stage a cond, 249 # but that will be pruned at execution. 250 self.assertTransformedResult(test_fn, constant_op.constant(1), 1) 251 self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) 252 253 @test_util.run_deprecated_v1 254 def test_if_unbalanced_multiple_composites(self): 255 256 class Foo(object): 257 258 def __init__(self): 259 self.b = 2 260 self.c = 3 261 262 def test_fn(x, condition): 263 264 z = 5 265 if condition: 266 x.b = 7 267 x.c = 11 268 z = 13 269 270 return x.b, x.c, z 271 272 self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)), 273 (7, 11, 13)) 274 self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)), 275 (2, 3, 5)) 276 277 @test_util.run_deprecated_v1 278 def test_if_unbalanced_composite(self): 279 280 class Foo(object): 281 282 def __init__(self): 283 self.b = 2 284 285 def test_fn(x, condition): 286 287 z = 5 288 if condition: 289 x.b = 7 290 z = 13 291 292 return x.b, z 293 294 self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)), 295 (7, 13)) 296 self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)), 297 (2, 5)) 298 299 @test_util.run_deprecated_v1 300 def test_simple_for(self): 301 302 def test_fn(l): 303 s1 = 0 304 s2 = 0 305 for e in l: 306 s1 += e 307 s2 += e * e 308 return s1, s2 309 310 self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), (4, 10)) 311 empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32) 312 self.assertTransformedResult(test_fn, empty_vector, (0, 0)) 313 314 @test_util.run_deprecated_v1 315 def test_for_single_output(self): 316 317 def test_fn(l): 318 s = 0 319 for e in l: 320 s += e 321 return s 322 323 self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), 4) 324 empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32) 325 self.assertTransformedResult(test_fn, empty_vector, 0) 326 327 def test_for_iterated_expression(self): 328 329 eval_count = [0] 330 331 def count_evals(x): 332 eval_count[0] += 1 333 return x 334 335 def test_fn(n): 336 s = 0 337 for e in count_evals(range(n)): 338 s += e 339 return s 340 341 ns = {'count_evals': count_evals} 342 node, ctx = self.prepare(test_fn, ns) 343 node = control_flow.transform(node, ctx) 344 345 with self.compiled(node, ns) as result: 346 self.assertEqual(result.test_fn(5), 10) 347 self.assertEqual(eval_count[0], 1) 348 349 @test_util.run_deprecated_v1 350 def test_for_tuple_unpacking(self): 351 def test_fn(x_list): 352 z = tf.constant(0) # pylint:disable=undefined-variable 353 for i, x in enumerate(x_list): 354 z = z + x + i 355 return z 356 357 self.assertTransformedResult(test_fn, [3, 3], 7) 358 359 360if __name__ == '__main__': 361 test.main() 362