# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Nested loops and conditional statements (e.g. while, for, if).

Meant to verify that arbitrarily nested statements are processed correctly.
"""

import itertools

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.autograph.tests import reference_test_base


# NOTE: The module-level functions below are the subjects under test. Their
# exact statement structure is the point of each test, so they are kept
# deliberately un-simplified.


def independent_ifs(x, y):
  """Returns `x + y` when both `x > 0` and `y > 0`, otherwise 0."""
  z = 0
  if x > 0:
    if y > 0:
      z = x + y
  return z


def dependent_inner_if(x):
  """Nested `if` whose inner condition reads a value set by the outer branch.

  Returns:
    The tuple `(x, y)` after the conditional updates.
  """
  y = 0
  if x > 0:
    y = -2 * x
    # The inner condition depends on `y`, assigned just above.
    if y > 0:
      x = -3 * x
  else:
    y = 4 * x
  return x, y


def dependent_imbalanced_inner_if(x):
  """Nested `if` without `else` branches that modifies both `x` and `y`.

  Returns:
    The tuple `(x, y)` after the conditional updates.
  """
  y = 0
  if x > 0:
    if x < 3:
      y = -2 * x
      x = -3 * x
  return x, y


def _hidden_raise():
  """Unconditionally raises ValueError; used to exit a block via exception."""
  raise ValueError('exception used for control flow')


def if_with_local_modification_masked_by_exception(x):
  """Assigns `y` inside a `try` block that may be interrupted by an exception.

  When `x > 1`, `_hidden_raise` raises before `y = 1` executes, so `y` is still
  0 after the `except` clause and is then set to 2.

  Returns:
    0 if `x <= 0`, 1 if `x > 0` and no exception was raised, 2 otherwise.
  """
  y = 0
  if x > 0:
    try:
      if x > 1:
        _hidden_raise()
      y = 1
    except ValueError:
      pass
    if y == 0:
      y = 2
  return y


# Global mutated by `if_nested_with_modification_of_global`; the tests reset it
# to None before each use.
_test_global = None


def if_nested_with_modification_of_global(x):
  """Initializes or increments `_test_global` inside two nested `if`s.

  Returns:
    The updated value of `_test_global` if `x > 0`, otherwise 0.
  """
  y = 0
  if x > 0:
    if x > 0:
      # The `global` declaration is intentionally placed inside the nested
      # block.
      global _test_global
      if _test_global is None:
        _test_global = 1
      else:
        _test_global += 1
      y += _test_global
  return y


def independent_inner_for(a, b):
  """Sums the elements of `b` once per element of `a`."""
  p = 0
  for _ in a:
    tmp = b
    for j in tmp:
      p += j
  return p


def independent_inner_while(a, b):
  """Counts inner-loop iterations: `b` per outer iteration, `a` outer ones."""
  p = 0
  while a > 0:
    tmp = b
    while tmp > 0:
      p += 1
      tmp -= 1
    a -= 1
  return p


def dependent_inner_for(a, b):
  """Nested `for` where the outer loop reads `s`, updated by the inner loop."""
  r = 1
  s = 0
  for _ in a:
    r += s
    tmp = b
    for j in tmp:
      s += j
  return r


def dependent_inner_while(a, b):
  """Nested `while` where the inner loop modifies the outer loop variable."""
  r = 1
  while a > 0:
    r += 1
    tmp = b
    while tmp > 0:
      # Decrements the *outer* loop variable from within the inner loop.
      a -= 1
      tmp -= 1
    a -= 1
  return r


def if_in_for(a):
  """Sums `i // 2` over the elements `i` of `a` for which `i % 2 > 0`."""
  k = 0
  for i in a:
    if i % 2 > 0:
      j = i // 2
      k += j
  return k


def while_with_continue_in_context_manager(x):
  """`while` loop with a `continue` nested inside a `with` block."""
  z = 0
  while x > 0:
    with tf.name_scope(''):
      x = x - 1
      if x < 5:
        continue
      z = z + 1
  return z


def while_continue_in_try(x):
  """`while` loop with a `continue` inside `try`; `finally` always runs."""
  z = 0
  while x > 0:
    x = x - 1
    try:
      if x < 5:
        continue
      z = z + 1
    finally:
      z = z + 10
  return z


def while_break_in_context_manager(x):
  """`while` loop with a `break` nested inside a `with` block."""
  z = 0
  while x > 0:
    with tf.name_scope(''):
      x = x - 1
      if x < 5:
        break
      z = z + 1
  return z


def while_break_in_try(x):
  """`while` loop with a `break` inside `try`; `finally` always runs."""
  z = 0
  while x > 0:
    x = x - 1
    try:
      if x < 5:
        break
      z = z + 1
    finally:
      z = z + 10
  return z


def loop_initializing_invariant_variable(n):
  """Defines `a` only inside the loop; each value is independent of the last.

  `a` is unbound if the loop body never executes.
  """
  for i in range(n):
    if i == 0:
      a = 1
    else:
      a = 2
  return a


def loop_initializing_variant_variable(n):
  """Defines `a` only inside the loop; later values depend on earlier ones.

  `a` is unbound if the loop body never executes.
  """
  for i in range(n):
    if i == 0:
      a = 1
    else:
      a = a + 1
  return a


def _int_tensor(x):
  """Wraps `x` in an int32 `tf.constant`."""
  return tf.constant(x, dtype=tf.int32)


class NestedControlFlowTest(
    reference_test_base.TestCase, parameterized.TestCase):
  """Runs the nested control flow functions above on Python and TF inputs."""

  @parameterized.parameters(*itertools.product(
      (-1, 1),
      (-1, 1),
      (int, _int_tensor),
      (int, _int_tensor),
  ))
  def test_independent_ifs(self, x, y, type_x, type_y):
    x = type_x(x)
    # Convert `y` with its own type so that mixed int/Tensor combinations are
    # actually exercised.
    y = type_y(y)
    self.assertFunctionMatchesEager(independent_ifs, x, y)

  @parameterized.parameters(*itertools.product(
      (-1, 1),
      (int, _int_tensor),
  ))
  def test_dependent_inner_if(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(dependent_inner_if, x)

  @parameterized.parameters(*itertools.product(
      (-1, 1),
      (int, _int_tensor),
  ))
  def test_dependent_imbalanced_inner_if(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(dependent_imbalanced_inner_if, x)

  @parameterized.parameters(
      (-1,),
      (0,),
      (1,),
      (2,),
  )
  def test_if_with_local_modification_masked_by_exception(self, x):
    # Note: If the input is a Tensor, the behavior is undefined.
    self.assertFunctionMatchesEager(
        if_with_local_modification_masked_by_exception, x)

  def test_if_nested_with_modification_of_global(self):
    global _test_global
    _test_global = None
    self.assertEqual(tf.function(if_nested_with_modification_of_global)(1), 1)
    self.assertEqual(_test_global, 1)

  def test_if_nested_with_modification_of_global_not_executed(self):
    global _test_global
    _test_global = None
    self.assertEqual(tf.function(if_nested_with_modification_of_global)(0), 0)
    self.assertIsNone(_test_global)

  @parameterized.parameters(*itertools.product(
      (0, 1, 2),
      (0, 1, 2),
      (range, tf.range),
      (range, tf.range),
  ))
  def test_independent_inner_for(self, a, b, type_a, type_b):
    a = type_a(a)
    b = type_b(b)
    self.assertFunctionMatchesEager(independent_inner_for, a, b)

  @parameterized.parameters(*itertools.product(
      (0, 1, 2),
      (0, 1, 2),
      (int, _int_tensor),
      (int, _int_tensor),
  ))
  def test_independent_inner_while(self, a, b, type_a, type_b):
    a = type_a(a)
    b = type_b(b)
    self.assertFunctionMatchesEager(independent_inner_while, a, b)

  @parameterized.parameters(*itertools.product(
      (0, 1, 2),
      (0, 1, 2),
      (range, tf.range),
      (range, tf.range),
  ))
  def test_dependent_inner_for(self, a, b, type_a, type_b):
    a = type_a(a)
    b = type_b(b)
    self.assertFunctionMatchesEager(dependent_inner_for, a, b)

  @parameterized.parameters(*itertools.product(
      (0, 1, 2, 3, 4),
      (0, 1, 2, 3, 4),
      (int, _int_tensor),
      (int, _int_tensor),
  ))
  def test_dependent_inner_while(self, a, b, type_a, type_b):
    # The (Python int `a`, Tensor `b`) combination is skipped; see the
    # referenced bug.
    if (type_a is int) and (type_b is _int_tensor):
      self.skipTest('b/124378596')
    a = type_a(a)
    b = type_b(b)
    self.assertFunctionMatchesEager(dependent_inner_while, a, b)

  @parameterized.parameters(*itertools.product(
      (0, 1, 2),
      (range, tf.range),
  ))
  def test_if_in_for(self, a, type_):
    a = type_(a)
    self.assertFunctionMatchesEager(if_in_for, a)

  @parameterized.parameters(*itertools.product(
      (0, 4, 10),
      (int, _int_tensor),
  ))
  def test_while_continue_in_context_manager(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(while_with_continue_in_context_manager, x)

  @parameterized.parameters(*itertools.product(
      (0, 4, 10),
      (int, _int_tensor),
  ))
  def test_while_continue_in_try(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(while_continue_in_try, x)

  @parameterized.parameters(*itertools.product(
      (0, 4, 10),
      (int, _int_tensor),
  ))
  def test_while_break_in_context_manager(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(while_break_in_context_manager, x)

  @parameterized.parameters(*itertools.product(
      (0, 4, 10),
      (int, _int_tensor),
  ))
  def test_while_break_in_try(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(while_break_in_try, x)

  @parameterized.parameters(*itertools.product(
      (1, 2),
      (int, _int_tensor),
  ))
  def test_loop_initializing_invariant_variable_legal(self, n, type_):
    n = type_(n)
    self.assertFunctionMatchesEager(loop_initializing_invariant_variable, n)

  def test_loop_initializing_invariant_variable_illegal(self):
    # With a Python int, the zero-iteration loop leaves `a` unbound.
    with self.assertRaises(UnboundLocalError):
      tf.function(loop_initializing_invariant_variable)(0)
    # With a Tensor, a different error type and message are expected.
    with self.assertRaisesRegex(
        tf.errors.InvalidArgumentError, 'loop must iterate at least once'):
      tf.function(loop_initializing_invariant_variable)(tf.constant(0))

  @parameterized.parameters(
      (1,),
      (2,),
  )
  def test_loop_initializing_variant_variable_legal(self, n):
    # Only checks that the call completes without raising; the result is not
    # compared against anything.
    tf.function(loop_initializing_variant_variable)(n)

  @parameterized.parameters(
      (0,),
      (1,),
      (2,),
  )
  def test_loop_initializing_variant_variable_illegal(self, n):
    with self.assertRaisesRegex(ValueError, 'must be defined before the loop'):
      tf.function(loop_initializing_variant_variable)(tf.constant(n))


if __name__ == '__main__':
  tf.test.main()