# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for while_v2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
from tensorflow.python.platform import test


def random_gamma(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(shape, 1.0)


def random_gamma_with_alpha_beta(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(
      shape, alpha=[[1.], [3.], [5.], [6.]], beta=[[3., 4.]])


def random_poisson_v2(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, 1.0)


def random_poisson_v2_with_lam(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, [12.2, 3.3])


def fill(shape):  # pylint: disable=invalid-name
  return array_ops.fill(shape, 1.0)


class WhileV2Test(test.TestCase, parameterized.TestCase):

  @test_util.run_deprecated_v1
  def testSingleLoopVar(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testSingleLoopVarBackPropFalse(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8.,
        lambda v: v * v, [x],
        return_same_structure=False,
        back_prop=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(grad, [None])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)

  @test_util.run_deprecated_v1
  def testCustomGradient(self):
    x = constant_op.constant(2.)
    n = constant_op.constant(1., name="const-n")
    m = variables.Variable(1.0)
    self.evaluate(variables.global_variables_initializer())

    def body_fn(v):  # pylint: disable=invalid-name

      @custom_gradient.custom_gradient
      def inner_fn(v):  # pylint: disable=invalid-name

        def grad_fn(dy, variables=None):  # pylint: disable=invalid-name, unused-argument, redefined-outer-name
          return dy * 2 * v * n * m, [v * v]

        return v * v * m, grad_fn

      return inner_fn(v)

    ret = while_loop_v2(
        lambda v: v < 8., body_fn, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_v1_only("b/120545219")
  def testReturnSameStructureTrue(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=True)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session() as sess:
      eval_result = sess.run(ret)
      self.assertIsInstance(eval_result, list)
      self.assertLen(eval_result, 1)
      self.assertEqual(16., eval_result[0])
      self.assertSequenceEqual(sess.run(grad), [32.])

  def testVerifyInputOutputTypesMatch(self):

    @def_function.function
    def BuildWhile():
      x = constant_op.constant(1., dtypes.float32)

      def Body(x):
        return math_ops.cast(x, dtypes.float16) + 1

      while_loop_v2(lambda x: x < 10, Body, [x])

    with self.assertRaisesRegexp(
        TypeError,
        r"Loop var Const:0 enters the loop with type <dtype: 'float32'> "
        r"but has type <dtype: 'float16'> after 1 iteration."):
      BuildWhile()

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def testGradientTapeResourceVariable(self, dtype):
    with context.eager_mode():
      v = variables.Variable(1., dtype=dtype)

      @def_function.function
      def fnWithLoop():  # pylint: disable=invalid-name
        with backprop.GradientTape() as tape:
          _, x = while_loop_v2(
              lambda i, _: i < 2,
              lambda i, x: (i + 1, x * v),
              [0, constant_op.constant(2., dtype=dtype)])
        return tape.gradient(x, v)

      self.assertAllEqual(fnWithLoop(), 4.0)

  def testExternalControlDependencies(self):
    with ops.Graph().as_default(), self.test_session():
      v = variables.Variable(1.)
      v.initializer.run()
      op = v.assign_add(1.)

      def body_fn(i):  # pylint: disable=invalid-name
        with ops.control_dependencies([op]):
          return i + 1

      loop = while_loop_v2(lambda i: i < 1, body_fn, [0])
      loop[0].op.run()
      self.assertAllEqual(self.evaluate(v), 2.0)

  @test_util.run_deprecated_v1
  def testMultipleLoopVarsBasic(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret = [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopNonscalarCond(self):
    x = constant_op.constant([[5.]])
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret == [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopVars(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x, y = x * y, x + y  # Both updates use the values from the same step.
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, v + w), [x, y],
        return_same_structure=False)
    # ret = [y*x**2 + x*y**2, x*y + x + y]

    gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
    gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + y + 1]
    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
      self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
      self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
      self.assertSequenceEqual(self.evaluate(gradx_2), [43.])
      self.assertSequenceEqual(self.evaluate(grady_0), [55.])
      self.assertSequenceEqual(self.evaluate(grady_1), [6.])
      self.assertSequenceEqual(self.evaluate(grady_2), [61.])

  @test_util.run_deprecated_v1
  def testGradientTape(self):
    with backprop.GradientTape() as t:
      x = constant_op.constant(2.)
      t.watch(x)
      ret = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
    grad = t.gradient(ret, x)
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(grad), 4.0)

  @test_util.run_deprecated_v1
  def testMultipleWhileLoops(self):
    x = constant_op.constant(2.)
    ret1 = while_loop_v2(
        lambda v: v < 4., lambda v: v * v, [x],
        return_same_structure=False)  # x**2
    ret2 = while_loop_v2(
        lambda v: v < 16., lambda v: v * v, [ret1],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  def testMultipleWhileLoopsWithFunc(self):
    x = constant_op.constant(2.)

    @def_function.function
    def Fn():
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_2")  # x**4
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "StatelessWhile")
    self.assertEqual(while_2.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertEmpty(while_2.control_inputs)

  def testMultipleWhileLoopsGradStateless(self):

    @def_function.function
    def Fn():
      x = constant_op.constant(2.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        ret1 = while_loop_v2(
            lambda v: v < 4.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_1")  # x**2
        ret2 = while_loop_v2(
            lambda v: v < 16.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_2")  # x**4
        loss = ret1 + ret2
      return tape.gradient(loss, x)

    graph = Fn.get_concrete_function().graph
    while_ops = [op for op in graph.get_operations() if "While" in op.type]
    self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
                        "Must have exactly 4 StatelessWhile ops.")
    for op in while_ops:
      self.assertEmpty(op.control_inputs,
                       "{} should not have any control inputs".format(op.name))

  def testMultipleWhileLoopsWithDeps(self):
    x = variables.Variable(2.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * x, [c],
          return_same_structure=False,
          name="while_1")  # 2x
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * x * x, [c],
          return_same_structure=False,
          name="while_2")  # 4x
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)

  def testMultipleWhileLoopsWithVarsDeps(self):
    x1 = variables.Variable(2.)
    x2 = variables.Variable(3.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * x1, [c],
          return_same_structure=False,
          name="while_1")  # 2x
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * x1 * x1, [c],
          return_same_structure=False,
          name="while_2")  # 4x
      ret3 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * x2, [c],
          return_same_structure=False,
          name="while_3")  # 3x
      ret4 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * x2 * x2, [c],
          return_same_structure=False,
          name="while_4")  # 9x
      ret5 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [c],
          return_same_structure=False,
          name="while_stateless")  # x**2
      return ret1, ret2, ret3, ret4, ret5

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    while_3 = concrete_fn.graph.get_operation_by_name("while_3")
    while_4 = concrete_fn.graph.get_operation_by_name("while_4")
    while_stateless = concrete_fn.graph.get_operation_by_name(
        "while_stateless")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEqual(while_3.type, "While")
    self.assertEqual(while_4.type, "While")
    self.assertEqual(while_stateless.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)
    self.assertEmpty(while_3.control_inputs)
    self.assertLen(while_4.control_inputs, 1)
    self.assertIs(while_4.control_inputs[0], while_3)
    self.assertEmpty(while_stateless.control_inputs)

  @test_util.run_deprecated_v1
  def testDoubleDerivative(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v**2, [x],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  @test_util.run_v2_only
  def testMultipleWhileLoopsEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret1 = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [ret1],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret2, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return grad, grad_grad

    grad, grad_grad = Func()
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  @test_util.run_v2_only
  def testDoubleDerivativeEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret = while_loop_v2(
          lambda v: v < 8., lambda v: v**2, [x],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return ret, grad, grad_grad

    ret, grad, grad_grad = Func()
    self.assertEqual(ret.numpy(), 16.)
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  def _testPruning(self):
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not
    # pruned away, causing an extra Enter node.
    enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not
    # pruned away, causing an extra Enter node.
    enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.run_deprecated_v1
  def testPruningV1(self):
    self._testPruning()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testPruningV2(self):
    self._testPruning()

  def _testDoNotAccumulateInvariants(self):
    push_op = ("TensorListPushBack"
               if control_flow_v2_toggles.control_flow_v2_enabled() else
               "StackPushV2")

    # Tests that loop invariants, i.e., tensors that are "captured" by the
    # while loop and not passed as loop variables, are not accumulated in
    # gradient computation.
    v = constant_op.constant(5.0, name="v")

    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)

    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    # The gradient for v * x requires the value of both v and x. Since v is a
    # loop invariant, it is not accumulated, so we have just one accumulator
    # for x.
    self.assertLen([n for n in g.node if n.op == push_op], 1)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateInvariantsV1(self):
    self._testDoNotAccumulateInvariants()

  @test_util.run_deprecated_v1
  @test_util.enable_control_flow_v2
  def testDoNotAccumulateInvariantsV2(self):
    self._testDoNotAccumulateInvariants()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    x = constant_op.constant(0)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 25

    def Body(x, tl):

      def InnerCond(inner_x, unused_outer_x, unused_tl):
        return inner_x < 5

      def InnerBody(inner_x, outer_x, tl):
        return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)

      inner_x = constant_op.constant(0)
      return control_flow_ops.while_loop(InnerCond, InnerBody,
                                         [inner_x, x, tl])[1:]

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not
    # pruned away, causing an extra Enter node.
    # enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
    self.assertEmpty([n for n in g.node if n.op == "_While"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not
    # pruned away, causing an extra Enter node.
    # enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested2(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    p = array_ops.placeholder(dtype=dtypes.int32)

    def MidBodyBuilder(iterations):

      def MidBody(i, x):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            lambda i, x: (i + 1, math_ops.multiply(v, x, name="my_mul")),
            (0, x),
            maximum_iterations=iterations,
            name="inner")
        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

      return MidBody

    def OuterBody(i, x):
      iterations = array_ops.size(p, name="iterations")
      return (i + 1, x + control_flow_ops.while_loop(
          lambda *_: True,
          MidBodyBuilder(iterations), (0, x),
          maximum_iterations=iterations,
          name="mid")[1])

    def CreateWhileLoop():
      with ops.device("/cpu:0"):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            OuterBody, (0, 1.0),
            maximum_iterations=5,
            name="outer")
        return array_ops.identity(r[1])

    output = CreateWhileLoop()
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested3(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    def CreateWhileLoop():
      r = control_flow_ops.while_loop(
          lambda _: True,
          lambda x: math_ops.multiply(v, x, name="my_mul"), [1.0],
          maximum_iterations=5,
          name="outer")
      return array_ops.identity(r)

    r = CreateWhileLoop()
    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  def _assertNotAccumulated(self, while_op, index):
    """Asserts that `while_op` input at `index` is not accumulated."""
    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    placeholder = body_graph.inputs[index]
    self.assertNotIn("TensorListPushBack",
                     [op.type for op in placeholder.consumers()])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopCounterAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    self._assertNotAccumulated(while_op, 0)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopInvariantAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE

    def GetInputIndex(op, tensor):
      for index, inp in enumerate(op.inputs):
        if inp is tensor:
          return index

    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    # We can't directly use while_op.inputs.index() because Tensors are not
    # hashable.
    index = GetInputIndex(while_op, v)
    self._assertNotAccumulated(while_op, index)

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInCond(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(1.)
    ret = while_loop_v2(
        lambda v: v + y < 9.,
        lambda v: v * 3., [x],
        return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInBody(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(3.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * y, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testLoopWithTensorListPushBack(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      tl = list_ops.tensor_list_push_back(tl, x)
      tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)
    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testDuplicateAccumulator(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)

    for op in ops.get_default_graph().get_operations():
      if op.type == "While" or op.type == "StatelessWhile":
        while_op = op

    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
    x_input_t = body_graph.inputs[x_input_index]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @parameterized.named_parameters(
      ("UnknownShape", None),
      ("PartiallyDefinedShape", [None, 2]),
      ("FullyDefinedShape", [1, 2]),
  )
  @test_util.run_deprecated_v1
  def testAccumulatorElementShape(self, shape):

    def MatchShape(actual_tensor_shape):
      # Compare the shapes, treating None dimensions as equal. We do not
      # directly check actual_tensor_shape and tf.TensorShape(shape) for
      # equality because tf.Dimension.__eq__ returns None if either dimension
      # is None.
      if shape is None:
        self.assertIsNone(actual_tensor_shape.dims)
      else:
        self.assertListEqual(actual_tensor_shape.as_list(), shape)

    def GetAccumulatorForInputAtIndex(while_op, idx):
      body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
      y_input_t = body_graph.inputs[idx]
      push_back_node = [c for c in y_input_t.consumers()
                        if c.type == "TensorListPushBack"][0]
      output_idx = body_graph.outputs.index(push_back_node.outputs[0])
      return while_op.outputs[output_idx]

    x = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
    y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

    # Forward pass.
    ret = while_loop_v2(lambda v, u: v < 8.,
                        lambda v, u: (math_ops.pow(v, u), u),
                        [x, y],
                        return_same_structure=True)
    while_op = ret[0].op.inputs[0].op
    # Gradient pass.
    grad = gradients_impl.gradients(ret[0], x)
    # Note: There is an Identity b/w grad[0] and the While op.
    grad_while_op = grad[0].op.inputs[0].op

    # Get the TensorList output of While op containing the accumulated values
    # of y.
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if x == inp][0]
    output = GetAccumulatorForInputAtIndex(while_op, x_input_index)
    _, val = list_ops.tensor_list_pop_back(output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

    # Take second derivative to generate intermediate grad_while_op outputs.
    gradients_impl.gradients(grad, x)

    # Get the TensorList output of gradient While op containing the accumulated
    # values of grad_x (note that grad_x is needed by the second derivative).
    # grad_while_op.inputs:
    grad_output_index = grad_while_op.outputs.index(grad[0].op.inputs[0])
    grad_output = GetAccumulatorForInputAtIndex(grad_while_op,
                                                grad_output_index)
    _, val = list_ops.tensor_list_pop_back(grad_output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

  def _createWhile(self, name):
    """Helper function for testDefaultName."""
    output = while_v2.while_loop(
        lambda i: i < 3,
        lambda i: i + 1, [constant_op.constant(0)],
        return_same_structure=False)
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    return while_op

  def testDefaultName(self):
    with ops.Graph().as_default():
      while_op = self._createWhile(None)
      self.assertEqual(while_op.name, "while")
      self.assertRegexpMatches(
          while_op.get_attr("cond").name, r"while_cond_\d*")
      self.assertRegexpMatches(
          while_op.get_attr("body").name, r"while_body_\d*")

    with ops.Graph().as_default():
      with ops.name_scope("foo"):
        while1_op = self._createWhile("")
        self.assertEqual(while1_op.name, "foo/while")
        self.assertRegexpMatches(
            while1_op.get_attr("cond").name, r"foo_while_cond_\d*")
        self.assertRegexpMatches(
            while1_op.get_attr("body").name, r"foo_while_body_\d*")

        while2_op = self._createWhile(None)
        self.assertEqual(while2_op.name, "foo/while_1")
        self.assertRegexpMatches(
            while2_op.get_attr("cond").name, r"foo_while_1_cond_\d*")
        self.assertRegexpMatches(
            while2_op.get_attr("body").name, r"foo_while_1_body_\d*")

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testWhileAndTensorArray(self):
    param = constant_op.constant(2.0)
    y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
    # map_fn uses TensorArray internally.
    r = map_fn.map_fn(lambda x: math_ops.multiply(x, param), y0)
    grad = gradients_impl.gradients(r, param)[0]
    self.assertAllClose([2.0, 4.0, 6.0, 8.0, 10.0, 12.0], self.evaluate(r))
    self.assertAllClose(21.0, self.evaluate(grad))

  @test_util.run_deprecated_v1
  def testNestedWhile(self):
    # Compute sum of geometric progression: n^0 + n^1 + ... + n^m
    # We compute the pow using a while loop.
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)
      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          lambda c, v: (c - 1., v * n), [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testNestedWhileWithLegacyDefun(self):
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)

      def InnerBodyWrapper(c, v):

        @function.Defun(dtypes.float32, dtypes.float32)
        def InnerBody(c, v):
          return c - 1., v * n

        results = InnerBody(c, v)
        results[0].set_shape([])
        results[1].set_shape([])
        return results

      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          InnerBodyWrapper, [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testIdentityNodeInBody(self):

    def Body(v):
      v = array_ops.identity(v)
      v = array_ops.identity(v)
      return v * v

    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., Body, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(self.evaluate(ret), 16.)
    self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testForwardPassRewrite(self):
    x = constant_op.constant(1.0, name="x")
    output = while_v2.while_loop(lambda x: x < 10.0,
                                 lambda x: x * 2.0,
                                 [x])[0]
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    # outputs = [loop_counter, max_iters, x]
    self.assertLen(while_op.outputs, 3)

    gradients_impl.gradients(output, x)
    # while_op should have been rewritten to output intermediates.
    # outputs = [loop_counter, max_iters, x, x_accumulator]
    self.assertLen(while_op.outputs, 4)

    gradients_impl.gradients(output, x)
    # Computing the gradient again shouldn't rewrite while_op again.
    self.assertLen(while_op.outputs, 4)

  @parameterized.named_parameters(
      ("RandomUniform", random_ops.random_uniform, [5, 3]),
      ("RandomNormal", random_ops.random_normal, [5, 3]),
      ("ParameterizedTruncatedNormal",
       random_ops.parameterized_truncated_normal, [5, 3]),
      ("TruncatedNormal", random_ops.truncated_normal, [5, 3]),
      ("RandomGamma", random_gamma, [5, 3]),
      ("RandomPoissonV2", random_poisson_v2, [5, 3]),
      ("RandomGammaWithAlphaBeta", random_gamma_with_alpha_beta, [5, 3, 4, 2]),
      ("RandomPoissonV2WithLam", random_poisson_v2_with_lam, [5, 3, 2]),
  )
  @test_util.run_deprecated_v1
  def testRandomOpsShape(self, random_fn, expected_shape):
    shape = constant_op.constant([3])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = random_fn(shape_extended)
      assert u.shape.as_list() == expected_shape, str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros(expected_shape, dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testReshapeShape(self):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = array_ops.reshape(u, [-1])
      assert u.shape.as_list() == [60], str(u.shape.as_list())
      u = array_ops.reshape(u, shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @parameterized.named_parameters(
      ("Zeros", array_ops.zeros),
      ("Ones", array_ops.ones),
      ("Fill", fill),
  )
  @test_util.run_deprecated_v1
  def testFillOpsShape(self, fill_fn):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = fill_fn(shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testExternalColocationGrad(self):
    external_t = constant_op.constant(2.)
    v0 = constant_op.constant(2.)

    def Body(v):
      with ops.colocate_with(external_t):
        return v * v

    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    grad = gradients_impl.gradients(ret, [v0])[0]
    self.assertAllEqual(ret, 16.)
    self.assertAllEqual(grad, 32.)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateConstNodes(self):

    def Body(v):
      return v * 2.0

    v0 = constant_op.constant(2.)
    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    # Gradient computation has the side effect of updating the forward op,
    # which is what we want to test.
    unused_grad = gradients_impl.gradients(ret, [v0])[0]
    # `ret` is separated from the `While` op by an `Identity`, so we skip over
    # that.
    forward_while_op = ret.op.inputs[0].op
    body_graph = while_v2._get_graph(forward_while_op, "body", "_body_graph")
    push_back_nodes = [
        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
    ]
    # Gradient of `Mul` requires accumulating both its inputs. But since one
    # of those is a Const (2.0), we should have just one accumulator.
    self.assertLen(push_back_nodes, 1)

  def testDoNotAccumulateForwardTensorsForReductionOps(self):

    @def_function.function
    def Fn():
      with backprop.GradientTape() as tape:
        x = constant_op.constant(2.)
        tape.watch(x)

        def Body(i, x):
          forward_graph = ops.get_default_graph()

          @custom_gradient.custom_gradient
          def SquaredWithZeroGrad(x):

            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
              del variables
              gradient_graph = ops.get_default_graph()
              shape = gen_array_ops.shape(x)
              assert shape.graph is forward_graph
              rank = gen_array_ops.rank(x)
              assert rank.graph is forward_graph
              size = gen_array_ops.size(x)
              assert size.graph is forward_graph
              zeros = array_ops.zeros(shape)
              assert zeros.graph is gradient_graph
              return zeros

            return x * 2, Grad

          return i + 1, SquaredWithZeroGrad(x)

        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
      grad = tape.gradient(result, x)
      return grad

    Fn()


def ScalarShape():
  return ops.convert_to_tensor([], dtype=dtypes.int32)


def GetOptimizedGraph():
  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.CopyFrom(
      rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
  return tf_optimizer.OptimizeGraph(config, mg)


if __name__ == "__main__":
  test.main()