# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for while_v2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
from tensorflow.python.platform import test


def random_gamma(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(shape, 1.0)


def random_gamma_with_alpha_beta(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(
      shape, alpha=[[1.], [3.], [5.], [6.]], beta=[[3., 4.]])


def random_poisson_v2(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, 1.0)


def random_poisson_v2_with_lam(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, [12.2, 3.3])


def fill(shape):  # pylint: disable=invalid-name
  return array_ops.fill(shape, 1.0)


class WhileV2Test(test.TestCase, parameterized.TestCase):

  @test_util.run_deprecated_v1
  def testSingleLoopVar(self):
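    # The loop below squares v until it reaches 8: 2 -> 4 -> 16, i.e.
    # ret = x**4, so the gradient checked below is d(x**4)/dx = 4*x**3 = 32.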
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testSingleLoopVarBackPropFalse(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8.,
        lambda v: v * v, [x],
        return_same_structure=False,
        back_prop=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(grad, [None])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)

  @test_util.run_deprecated_v1
  def testCustomGradient(self):
    x = constant_op.constant(2.)
    n = constant_op.constant(1., name="const-n")
    m = variables.Variable(1.0)
    self.evaluate(variables.global_variables_initializer())

    def body_fn(v):  # pylint: disable=invalid-name

      @custom_gradient.custom_gradient
      def inner_fn(v):  # pylint: disable=invalid-name

        def grad_fn(dy, variables=None):  # pylint: disable=invalid-name, unused-argument, redefined-outer-name
          return dy * 2 * v * n * m, [v * v]

        return v * v * m, grad_fn

      return inner_fn(v)

    ret = while_loop_v2(
        lambda v: v < 8., body_fn, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_v1_only("b/120545219")
  def testReturnSameStructureTrue(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=True)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session() as sess:
      eval_result = sess.run(ret)
      self.assertIsInstance(eval_result, list)
      self.assertLen(eval_result, 1)
      self.assertEqual(16., eval_result[0])
      self.assertSequenceEqual(sess.run(grad), [32.])

  def testVerifyInputOutputTypesMatch(self):

    @def_function.function
    def BuildWhile():
      x = constant_op.constant(1., dtypes.float32)

      def Body(x):
        return math_ops.cast(x, dtypes.float16) + 1

      while_loop_v2(lambda x: x < 10, Body, [x])

    with self.assertRaisesRegex(
        TypeError,
        r"Loop var Const:0 enters the loop with type <dtype: 'float32'> "
        r"but has type <dtype: 'float16'> after 1 iteration."):
      BuildWhile()

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def testGradientTapeResourceVariable(self, dtype):
    with context.eager_mode():
      v = variables.Variable(1., dtype=dtype)

      @def_function.function
      def fnWithLoop():  # pylint: disable=invalid-name
        with backprop.GradientTape() as tape:
          _, x = while_loop_v2(
              lambda i, _: i < 2,
              lambda i, x: (i + 1, x * v),
              [0, constant_op.constant(2., dtype=dtype)])
        return tape.gradient(x, v)

      self.assertAllEqual(fnWithLoop(), 4.0)

  def checkIteratedGradients(self, func):
    with context.eager_mode():

      def _Grad(f):
        def _GradFunction(primal):
          with backprop.GradientTape() as tape:
            tape.watch(primal)
            primal_out = f(primal)
          return tape.gradient(primal_out, primal)
        return _GradFunction

      f = func
      one = constant_op.constant(1.)
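
      # Checks the first three orders of derivatives of `func`: at each step,
      # compare the theoretical and numerical Jacobians from compute_gradient,
      # then wrap `f` in a GradientTape to obtain the next-order derivative.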
      for _ in range(3):
        theoretical, numerical = gradient_checker_v2.compute_gradient(
            def_function.function(f), [one])
        self.assertAllClose(theoretical, numerical, rtol=1e-3)
        f = _Grad(f)
        self.assertAllClose(array_ops.reshape(numerical, []),
                            def_function.function(f)(one),
                            rtol=1e-3)

  def testIteratedGradients(self):

    def _Func(x):
      _, z = while_loop_v2(
          lambda i, _: i < 2,
          lambda i, y: (i + 1, math_ops.cos(y)),
          [0, x])
      return z

    self.checkIteratedGradients(_Func)

  def testIteratedGradientsWithList(self):

    def _Func(x):
      results = list_ops.empty_tensor_list(
          element_shape=[], element_dtype=dtypes.float32)

      def _LoopBody(i, y, handle):
        return (i + 1, math_ops.cos(y),
                list_ops.tensor_list_push_back(handle, y))

      _, z, results = while_loop_v2(
          lambda i, _, h: i < 2, _LoopBody, [0, x, results])
      return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
          results, dtypes.float32))

    self.checkIteratedGradients(_Func)

  def testGradWhileGradWhileWithVariable(self):
    with context.eager_mode():
      v = variables.Variable(1.)

      @def_function.function
      def _Func(x):

        def _Inner(a):
          with backprop.GradientTape() as tape:
            tape.watch(a)
            _, b = while_loop_v2(
                lambda i, _: i < 2,
                lambda i, y: (i + 1, math_ops.cos(v + y)),
                [0, a])
          return tape.gradient(b, a)

        _, z = while_loop_v2(
            lambda i, _: i < 2,
            lambda i, y: (i + 1, _Inner(y)),
            [0, x])
        return z

      with backprop.GradientTape(persistent=True) as tape:
        x = constant_op.constant(1.)
        tape.watch(x)
        y = _Func(x)
      dx, _ = tape.gradient(y, [x, v])
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          _Func, [x])
      self.assertAllClose(numerical, theoretical, rtol=1e-3)
      self.assertAllClose(array_ops.reshape(numerical, []),
                          dx, rtol=1e-3)

  def testThreeNestWithLists(self):
    with context.eager_mode():
      def _WrapInWhile(f):
        def _Wrapped(x):
          results = list_ops.empty_tensor_list(
              element_shape=[], element_dtype=dtypes.float32)

          def _LoopBody(i, y, handle):
            return (i + 1, f(math_ops.cos(y)),
                    list_ops.tensor_list_push_back(handle, y))

          _, z, results = control_flow_ops.while_loop(
              lambda i, _, h: i < 2, _LoopBody, [0, x, results])
          return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
              results, dtypes.float32))
        return _Wrapped

      f = math_ops.sin

      target_function = _WrapInWhile(_WrapInWhile(_WrapInWhile(f)))

      @def_function.function
      def _TapeFromGraphMode(x):
        with backprop.GradientTape(persistent=True) as tape:
          tape.watch(x)
          y = target_function(x)
        return tape.gradient(y, x)

      x = constant_op.constant(1.)
      dx = _TapeFromGraphMode(x)
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          target_function, [x])
      self.assertAllClose(numerical, theoretical, rtol=3e-3)
      self.assertAllClose(array_ops.reshape(numerical, []), dx, rtol=3e-3)

  def testDeviceLabelsInherited(self):
    def _LoopBody(i, y):
      result = math_ops.cos(y)
      self.assertIn("CPU:10", result.device)
      with ops.device("CPU:11"):
        result = array_ops.identity(result)
      self.assertIn("CPU:11", result.device)
      return i + 1, result

    @def_function.function
    def _FunctionWithWhileLoop():
      x = constant_op.constant(1.)
      with ops.device("CPU:10"):
        _, z = while_loop_v2(
            lambda i, _: i < 2,
            _LoopBody,
            [0, x])
      return z

    # The test assertion runs at trace time.
    _FunctionWithWhileLoop.get_concrete_function()

  def testExternalControlDependencies(self):
    with ops.Graph().as_default(), self.test_session():
      v = variables.Variable(1.)
      self.evaluate(v.initializer)
      op = v.assign_add(1.)

      def body_fn(i):  # pylint: disable=invalid-name
        with ops.control_dependencies([op]):
          return i + 1

      loop = while_loop_v2(lambda i: i < 1, body_fn, [0])
      loop[0].op.run()
      self.assertAllEqual(self.evaluate(v), 2.0)

  @test_util.run_deprecated_v1
  def testMultipleLoopVarsBasic(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret = [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [y**2]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopNonscalarCond(self):
    x = constant_op.constant([[5.]])
    y = constant_op.constant(3.)

    # x = [[5.]]
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret == [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [y**2]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopVars(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    #   y = x + y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, v + w), [x, y],
        return_same_structure=False)
    # ret = [y*x**2 + x*y**2, x*y + x + y]

    gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
    gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + y + 1]
    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
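    # With x = 5 and y = 3 the loop runs twice: (5, 3) -> (15, 8) -> (120, 23),
    # so the values checked below follow from the formulas above, e.g.
    # gradx_0 = 2*x*y + y**2 = 39 and grady_2 = 2*x*y + x**2 + x + 1 = 61.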
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
      self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
      self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
      self.assertSequenceEqual(self.evaluate(gradx_2), [43.])
      self.assertSequenceEqual(self.evaluate(grady_0), [55.])
      self.assertSequenceEqual(self.evaluate(grady_1), [6.])
      self.assertSequenceEqual(self.evaluate(grady_2), [61.])

  @test_util.run_deprecated_v1
  def testGradientTape(self):
    with backprop.GradientTape() as t:
      x = constant_op.constant(2.)
      t.watch(x)
      ret = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
    grad = t.gradient(ret, x)
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(grad), 4.0)

  @test_util.run_deprecated_v1
  def testMultipleWhileLoops(self):
    x = constant_op.constant(2.)
    ret1 = while_loop_v2(
        lambda v: v < 4., lambda v: v * v, [x],
        return_same_structure=False)  # x**2
    ret2 = while_loop_v2(
        lambda v: v < 16., lambda v: v * v, [ret1],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  def testMultipleWhileLoopsWithFunc(self):
    x = constant_op.constant(2.)

    @def_function.function
    def Fn():
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_2")  # x**4
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "StatelessWhile")
    self.assertEqual(while_2.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertEmpty(while_2.control_inputs)

  def testMultipleWhileLoopsGradStateless(self):

    @def_function.function
    def Fn():
      x = constant_op.constant(2.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        ret1 = while_loop_v2(
            lambda v: v < 4.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_1")  # x**2
        ret2 = while_loop_v2(
            lambda v: v < 16.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_2")  # x**4
        loss = ret1 + ret2
      return tape.gradient(loss, x)

    graph = Fn.get_concrete_function().graph
    while_ops = [op for op in graph.get_operations() if "While" in op.type]
    self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
                        "Must have exactly 4 StatelessWhile ops.")
    for op in while_ops:
      self.assertEmpty(op.control_inputs,
                       "{} should not have any control inputs".format(op.name))

  def testMultipleWhileLoopsWithDeps(self):
    x = variables.Variable(2.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():

      def Body1(v):
        x.assign(x)
        return v * x

      ret1 = while_loop_v2(
          lambda v: v < 4.,
          Body1, [c],
          return_same_structure=False,
          name="while_1")  # 2x

      def Body2(v):
        x.assign(x)
        return v * x * x

      ret2 = while_loop_v2(
          lambda v: v < 16.,
          Body2, [c],
          return_same_structure=False,
          name="while_2")  # 4x
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)

  def testMultipleWhileLoopsWithVarsDeps(self):
    x1 = variables.Variable(2.)
    x2 = variables.Variable(3.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():

      def Body1(v):
        x1.assign(x1)
        return v * x1

      ret1 = while_loop_v2(
          lambda v: v < 4.,
          Body1, [c],
          return_same_structure=False,
          name="while_1")  # 2x

      def Body2(v):
        x1.assign(x1)
        return v * x1 * x1

      ret2 = while_loop_v2(
          lambda v: v < 16.,
          Body2, [c],
          return_same_structure=False,
          name="while_2")  # 4x

      def Body3(v):
        x2.assign(x2)
        return v * x2

      ret3 = while_loop_v2(
          lambda v: v < 4.,
          Body3, [c],
          return_same_structure=False,
          name="while_3")  # 3x

      def Body4(v):
        x2.assign(x2)
        return v * x2 * x2

      ret4 = while_loop_v2(
          lambda v: v < 16.,
          Body4, [c],
          return_same_structure=False,
          name="while_4")  # 9x
      ret5 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [c],
          return_same_structure=False,
          name="while_stateless")  # x**2
      return ret1, ret2, ret3, ret4, ret5

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    while_3 = concrete_fn.graph.get_operation_by_name("while_3")
    while_4 = concrete_fn.graph.get_operation_by_name("while_4")
    while_stateless = concrete_fn.graph.get_operation_by_name(
        "while_stateless")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEqual(while_3.type, "While")
    self.assertEqual(while_4.type, "While")
    self.assertEqual(while_stateless.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)
    self.assertEmpty(while_3.control_inputs)
    self.assertLen(while_4.control_inputs, 1)
    self.assertIs(while_4.control_inputs[0], while_3)
    self.assertEmpty(while_stateless.control_inputs)

  @test_util.run_deprecated_v1
  def testDoubleDerivative(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v**2, [x],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  @test_util.run_v2_only
  def testMultipleWhileLoopsEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret1 = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [ret1],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret2, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return grad, grad_grad

    grad, grad_grad = Func()
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  @test_util.run_v2_only
  def testDoubleDerivativeEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret = while_loop_v2(
          lambda v: v < 8., lambda v: v**2, [x],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return ret, grad, grad_grad

    ret, grad, grad_grad = Func()
    self.assertEqual(ret.numpy(), 16.)
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  def _testPruning(self):
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])
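
    # GetOptimizedGraph() (defined at the bottom of this file) runs Grappler
    # over the default graph; the tensors appended to the TRAIN_OP collection
    # act as the fetches, so nodes that do not contribute to them are expected
    # to be pruned away.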
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not
    # pruned away, causing an extra Enter node.
    enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not
    # pruned away, causing an extra Enter node.
    enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.run_deprecated_v1
  def testPruningV1(self):
    self._testPruning()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testPruningV2(self):
    self._testPruning()

  def _testDoNotAccumulateInvariants(self):
    push_op = ("TensorListPushBack"
               if control_flow_v2_toggles.control_flow_v2_enabled() else
               "StackPushV2")

    # Tests that loop invariants, i.e., tensors that are "captured" by the
    # while loop and not passed as loop variables, are not accumulated in
    # gradient computation.
    v = constant_op.constant(5.0, name="v")

    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)

    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    # The gradient for v * x requires the value of both v and x. Since v is a
    # loop invariant it is not accumulated, so we have just one accumulator
    # for x.
    self.assertLen([n for n in g.node if n.op == push_op], 1)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateInvariantsV1(self):
    self._testDoNotAccumulateInvariants()

  @test_util.run_deprecated_v1
  @test_util.enable_control_flow_v2
  def testDoNotAccumulateInvariantsV2(self):
    self._testDoNotAccumulateInvariants()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    x = constant_op.constant(0)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 25

    def Body(x, tl):

      def InnerCond(inner_x, unused_outer_x, unused_tl):
        return inner_x < 5

      def InnerBody(inner_x, outer_x, tl):
        return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)

      inner_x = constant_op.constant(0)
      return control_flow_ops.while_loop(InnerCond, InnerBody,
                                         [inner_x, x, tl])[1:]

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not
    # pruned away, causing an extra Enter node.
    # enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
    self.assertEmpty([n for n in g.node if n.op == "_While"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not
    # pruned away, causing an extra Enter node.
    # enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested2(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    p = array_ops.placeholder(dtype=dtypes.int32)

    def MidBodyBuilder(iterations):

      def MidBody(i, x):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            lambda i, x: (i + 1, math_ops.multiply(v, x, name="my_mul")),
            (0, x),
            maximum_iterations=iterations,
            name="inner")
        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

      return MidBody

    def OuterBody(i, x):
      iterations = array_ops.size(p, name="iterations")
      return (i + 1, x + control_flow_ops.while_loop(
          lambda *_: True,
          MidBodyBuilder(iterations), (0, x),
          maximum_iterations=iterations,
          name="mid")[1])

    def CreateWhileLoop():
      with ops.device("/cpu:0"):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            OuterBody, (0, 1.0),
            maximum_iterations=5,
            name="outer")
        return array_ops.identity(r[1])

    output = CreateWhileLoop()
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested3(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    def CreateWhileLoop():
      r = control_flow_ops.while_loop(
          lambda _: True,
          lambda x: math_ops.multiply(v, x, name="my_mul"), [1.0],
          maximum_iterations=5,
          name="outer")
      return array_ops.identity(r)

    r = CreateWhileLoop()
    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  def _assertNotAccumulated(self, while_op, index):
    """Asserts that `while_op` input at `index` is not accumulated."""
    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    placeholder = body_graph.inputs[index]
    self.assertNotIn("TensorListPushBack",
                     [op.type for op in placeholder.consumers()])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopCounterAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    self._assertNotAccumulated(while_op, 0)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopInvariantAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE

    def GetInputIndex(op, tensor):
      for index, inp in enumerate(op.inputs):
        if inp is tensor:
          return index

    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    # We can't directly use while_op.inputs.index() because Tensors are not
    # hashable.
    index = GetInputIndex(while_op, v)
    self._assertNotAccumulated(while_op, index)

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInCond(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(1.)
    ret = while_loop_v2(
        lambda v: v + y < 9.,
        lambda v: v * 3., [x],
        return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInBody(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(3.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * y, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testLoopWithTensorListPushBack(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      tl = list_ops.tensor_list_push_back(tl, x)
      tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)
    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testDuplicateAccumulator(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)

    for op in ops.get_default_graph().get_operations():
      if op.type == "While" or op.type == "StatelessWhile":
        while_op = op

    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
    x_input_t = body_graph.inputs[x_input_index]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @parameterized.named_parameters(
      ("UnknownShape", None),
      ("PartiallyDefinedShape", [None, 2]),
      ("FullyDefinedShape", [1, 2]),
  )
  @test_util.run_deprecated_v1
  def testAccumulatorElementShape(self, shape):

    def MatchShape(actual_tensor_shape):
      # Compare the shapes, treating None dimensions as equal. We do not
      # directly check actual_tensor_shape and tf.TensorShape(shape) for
      # equality because tf.Dimension.__eq__ returns None if either dimension
      # is None.
      if shape is None:
        self.assertIsNone(actual_tensor_shape.dims)
      else:
        self.assertListEqual(actual_tensor_shape.as_list(), shape)

    def GetAccumulatorForInputAtIndex(while_op, idx):
      body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
      y_input_t = body_graph.inputs[idx]
      push_back_node = [c for c in y_input_t.consumers()
                        if c.type == "TensorListPushBack"][0]
      output_idx = body_graph.outputs.index(push_back_node.outputs[0])
      return while_op.outputs[output_idx]

    x = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
    y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

    # Forward pass.
    ret = while_loop_v2(lambda v, u: v < 8.,
                        lambda v, u: (math_ops.pow(v, u), u),
                        [x, y],
                        return_same_structure=True)
    while_op = ret[0].op.inputs[0].op
    # Gradient pass.
    grad = gradients_impl.gradients(ret[0], x)
    # Note: There is an Identity b/w grad[0] and the While op.
    grad_while_op = grad[0].op.inputs[0].op

    # Get the TensorList output of While op containing the accumulated values
    # of y.
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if x == inp][0]
    output = GetAccumulatorForInputAtIndex(while_op, x_input_index)
    _, val = list_ops.tensor_list_pop_back(output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

    # Take second derivative to generate intermediate grad_while_op outputs
    gradients_impl.gradients(grad, x)

    # Get the TensorList output of gradient While op containing the accumulated
    # values of grad_x (note that grad_x is needed by the second derivative).
    # grad_while_op.inputs:
    grad_output_index = grad_while_op.outputs.index(grad[0].op.inputs[0])
    grad_output = GetAccumulatorForInputAtIndex(grad_while_op,
                                                grad_output_index)
    _, val = list_ops.tensor_list_pop_back(grad_output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

  def _createWhile(self, name):
    """Helper function for testDefaultName."""
    output = while_v2.while_loop(
        lambda i: i < 3,
        lambda i: i + 1, [constant_op.constant(0)],
        return_same_structure=False)
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    return while_op

  def testDefaultName(self):
    with ops.Graph().as_default():
      while_op = self._createWhile(None)
      self.assertEqual(while_op.name, "while")
      self.assertRegex(while_op.get_attr("cond").name, r"while_cond_\d*")
      self.assertRegex(while_op.get_attr("body").name, r"while_body_\d*")

    with ops.Graph().as_default():
      with ops.name_scope("foo"):
        while1_op = self._createWhile("")
        self.assertEqual(while1_op.name, "foo/while")
        self.assertRegex(while1_op.get_attr("cond").name, r"foo_while_cond_\d*")
        self.assertRegex(while1_op.get_attr("body").name, r"foo_while_body_\d*")

        while2_op = self._createWhile(None)
        self.assertEqual(while2_op.name, "foo/while_1")
        self.assertRegex(
            while2_op.get_attr("cond").name, r"foo_while_1_cond_\d*")
        self.assertRegex(
            while2_op.get_attr("body").name, r"foo_while_1_body_\d*")

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testWhileAndTensorArray(self):
    param = constant_op.constant(2.0)
    y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
    # map_fn uses TensorArray internally.
    r = map_fn.map_fn(lambda x: math_ops.multiply(x, param), y0)
    grad = gradients_impl.gradients(r, param)[0]
    self.assertAllClose([2.0, 4.0, 6.0, 8.0, 10.0, 12.0], self.evaluate(r))
    self.assertAllClose(21.0, self.evaluate(grad))

  @test_util.run_deprecated_v1
  def testNestedWhile(self):
    # Compute sum of geometric progression: n^0 + n^1 + ... + n^m
    # We compute the pow using a while loop.
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)
      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          lambda c, v: (c - 1., v * n), [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
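    # For n = 3 and m = 5: sum = 1 + 3 + 9 + 27 + 81 + 243 = 364 and
    # d(sum)/dn = 1 + 2*3 + 3*9 + 4*27 + 5*81 = 547, matching the assertions.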
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testNestedWhileWithLegacyDefun(self):
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)

      def InnerBodyWrapper(c, v):

        @function.Defun(dtypes.float32, dtypes.float32)
        def InnerBody(c, v):
          return c - 1., v * n

        results = InnerBody(c, v)
        results[0].set_shape([])
        results[1].set_shape([])
        return results

      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          InnerBodyWrapper, [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testIdentityNodeInBody(self):

    def Body(v):
      v = array_ops.identity(v)
      v = array_ops.identity(v)
      return v * v

    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., Body, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(self.evaluate(ret), 16.)
    self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testForwardPassRewrite(self):
    x = constant_op.constant(1.0, name="x")
    output = while_v2.while_loop(lambda x: x < 10.0,
                                 lambda x: x * 2.0,
                                 [x])[0]
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    # outputs = [loop_counter, max_iters, x]
    self.assertLen(while_op.outputs, 3)

    gradients_impl.gradients(output, x)
    # while_op should have been rewritten to output intermediates.
    # outputs = [loop_counter, max_iters, x, x_accumulator]
    self.assertLen(while_op.outputs, 4)

    gradients_impl.gradients(output, x)
    # Computing the gradient again shouldn't rewrite while_op again.
    self.assertLen(while_op.outputs, 4)

  @parameterized.named_parameters(
      ("RandomUniform", random_ops.random_uniform, [5, 3]),
      ("RandomNormal", random_ops.random_normal, [5, 3]),
      ("ParameterizedTruncatedNormal",
       random_ops.parameterized_truncated_normal, [5, 3]),
      ("TruncatedNormal", random_ops.truncated_normal, [5, 3]),
      ("RandomGamma", random_gamma, [5, 3]),
      ("RandomPoissonV2", random_poisson_v2, [5, 3]),
      ("RandomGammaWithAlphaBeta", random_gamma_with_alpha_beta, [5, 3, 4, 2]),
      ("RandomPoissonV2WithLam", random_poisson_v2_with_lam, [5, 3, 2]),
  )
  @test_util.run_deprecated_v1
  def testRandomOpsShape(self, random_fn, expected_shape):
    shape = constant_op.constant([3])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = random_fn(shape_extended)
      assert u.shape.as_list() == expected_shape, str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros(expected_shape, dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testReshapeShape(self):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = array_ops.reshape(u, [-1])
      assert u.shape.as_list() == [60], str(u.shape.as_list())
      u = array_ops.reshape(u, shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @parameterized.named_parameters(
      ("Zeros", array_ops.zeros),
      ("Ones", array_ops.ones),
      ("Fill", fill),
  )
  @test_util.run_deprecated_v1
  def testFillOpsShape(self, fill_fn):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = fill_fn(shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testExternalColocationGrad(self):
    external_t = constant_op.constant(2.)
    v0 = constant_op.constant(2.)

    def Body(v):
      with ops.colocate_with(external_t):
        return v * v

    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    grad = gradients_impl.gradients(ret, [v0])[0]
    self.assertAllEqual(ret, 16.)
    self.assertAllEqual(grad, 32.)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateConstNodes(self):

    def Body(v):
      return v * 2.0

    v0 = constant_op.constant(2.)
    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    # Gradients computation has the side-effect of updating the forward op
    # which is what we want to test.
    unused_grad = gradients_impl.gradients(ret, [v0])[0]
    # ret is separated from the `While` op by an `Identity` so we skip over
    # that.
    forward_while_op = ret.op.inputs[0].op
    body_graph = while_v2._get_graph(forward_while_op, "body", "_body_graph")
    push_back_nodes = [
        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
    ]
    # Gradient of `Mul` requires accumulating both its inputs. But since one
    # of those is a Const (2.0), we should have just one accumulator.
    self.assertLen(push_back_nodes, 1)

  def testDoNotAccumulateForwardTensorsForReductionOps(self):

    @def_function.function
    def Fn():
      with backprop.GradientTape() as tape:
        x = constant_op.constant(2.)
        tape.watch(x)

        def Body(i, x):
          forward_graph = ops.get_default_graph()

          @custom_gradient.custom_gradient
          def SquaredWithZeroGrad(x):

            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
              del variables
              gradient_graph = ops.get_default_graph()
              shape = gen_array_ops.shape(x)
              assert shape.graph is forward_graph
              rank = gen_array_ops.rank(x)
              assert rank.graph is forward_graph
              size = gen_array_ops.size(x)
              assert size.graph is forward_graph
              zeros = array_ops.zeros(shape)
              assert zeros.graph is gradient_graph
              return zeros

            return x * 2, Grad

          return i + 1, SquaredWithZeroGrad(x)

        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
      grad = tape.gradient(result, x)
      return grad

    Fn()

  def testDoNotAccumulateForwardTensorsForTensorListReductionOps(self):

    @def_function.function
    def Fn():
      with backprop.GradientTape() as tape:
        e = constant_op.constant(2.)
        x = list_ops.empty_tensor_list(
            element_dtype=dtypes.float32, element_shape=e.shape)
        x = list_ops.tensor_list_push_back(x, e)
        tape.watch(x)

        def Body(i, x):
          forward_graph = ops.get_default_graph()

          @custom_gradient.custom_gradient
          def IdentityWithZeroGrad(x):

            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
              del variables
              gradient_graph = ops.get_default_graph()
              shape = gen_list_ops.tensor_list_element_shape(
                  x, shape_type=dtypes.int32)
              assert shape.graph is forward_graph
              size = gen_list_ops.tensor_list_length(x)
              assert size.graph is forward_graph
              zeros = gen_list_ops.tensor_list_reserve(shape, size,
                                                       dtypes.float32)
              assert zeros.graph is gradient_graph
              return zeros

            return x, Grad

          return i + 1, IdentityWithZeroGrad(x)

        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
      ones_like = list_ops.tensor_list_from_tensor(
          array_ops.ones_like(
              list_ops.tensor_list_stack(result, element_dtype=dtypes.float32)),
          element_shape=tensor_shape.TensorShape([]))
      grad = tape.gradient(result, x, output_gradients=[ones_like])
      return grad

    Fn()

  @test_util.run_v2_only
  def testInheritParentNameScope(self):

    @def_function.function
    def F():
      with ops.name_scope("foo"):

        def Cond(unused_i):
          with ops.name_scope("cond"):
            actual_name_scope = ops.get_name_scope()
            expected_name_scope = "foo/while/cond"
            assert actual_name_scope == expected_name_scope, (
                "%s does not match %s" %
                (actual_name_scope, expected_name_scope))
          return False

        def Body(i):
          with ops.name_scope("body"):
            actual_name_scope = ops.get_name_scope()
            expected_name_scope = "foo/while/body"
            assert actual_name_scope == expected_name_scope, (
                "%s does not match %s" %
                (actual_name_scope, expected_name_scope))
          return i

        return while_v2.while_loop(Cond, Body, [0.])

    F()

  @test_util.run_deprecated_v1  # Need to pass RunMetadata.
  def testDisableLowering(self):
    old = control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE
    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = True
    with self.session() as sess:
      x = constant_op.constant(2.)
      ret = while_loop_v2(
          lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)

      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      self.assertEqual(sess.run(ret, options=opts, run_metadata=run_metadata),
                       16)
      for dev_stat in run_metadata.step_stats.dev_stats:
        for ns in dev_stat.node_stats:
          self.assertNotIn("switch", ns.node_name)
    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = old

  def _runBasicWithConfig(self, config):
    with ops.device("/cpu:0"):
      x = constant_op.constant(0)
      ret, = while_loop_v2(lambda x: x < 1000, lambda x: x + 1, [x])
    with self.cached_session(config=config):
      self.assertEqual(1000, self.evaluate(ret))

  @test_util.run_deprecated_v1
  def testRunKernelsInline(self):
    config = config_pb2.ConfigProto()
    config.inter_op_parallelism_threads = -1
    self._runBasicWithConfig(config)

  @test_util.run_deprecated_v1
  def testSingleThreadedExecution(self):
    config = config_pb2.ConfigProto()
    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
    self._runBasicWithConfig(config)

  def testIsControlFlowGraph(self):
    x = constant_op.constant(0)

    @def_function.function
    def F(c):

      def Cond(i):
        self.assertTrue(i.graph.is_control_flow_graph)
        return i < 2

      def Body(i):
        i = i + 1
        self.assertTrue(i.graph.is_control_flow_graph)
        return i

      return while_loop_v2(Cond, Body, [c])

    ret, = F(x)
    self.assertEqual(2, self.evaluate(ret))

  def testImportFromSerializedWithFunctionInBody(self):
    serialized = """node {
      name: "Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape {
            }
            float_val: 1.0
          }
        }
      }
    }
    node {
      name: "while/maximum_iterations"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: -1
          }
        }
      }
    }
    node {
      name: "while/loop_counter"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node {
      name: "while"
      op: "StatelessWhile"
      input: "while/loop_counter"
      input: "while/maximum_iterations"
      input: "Const"
      attr {
        key: "T"
        value {
          list {
            type: DT_INT32
            type: DT_INT32
            type: DT_FLOAT
          }
        }
      }
      attr {
        key: "_lower_using_switch_merge"
        value {
          b: true
        }
      }
      attr {
        key: "_num_original_outputs"
        value {
          i: 3
        }
      }
      attr {
        key: "_read_only_resource_inputs"
        value {
          list {
          }
        }
      }
      attr {
        key: "body"
        value {
          func {
            name: "while_body_822"
          }
        }
      }
      attr {
        key: "cond"
        value {
          func {
            name: "while_cond_821"
          }
        }
      }
      attr {
        key: "output_shapes"
        value {
          list {
            shape {
            }
            shape {
            }
            shape {
            }
          }
        }
      }
      attr {
        key: "parallel_iterations"
        value {
          i: 10
        }
      }
    }
    node {
      name: "while/Identity"
      op: "Identity"
      input: "while"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_1"
      op: "Identity"
      input: "while:1"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_2"
      op: "Identity"
      input: "while:2"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    library {
      function {
        signature {
          name: "while_body_822"
          input_arg {
            name: "while_loop_counter"
            type: DT_INT32
          }
          input_arg {
            name: "while_maximum_iterations_0"
            type: DT_INT32
          }
          input_arg {
            name: "placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "add"
            type: DT_INT32
          }
          output_arg {
            name: "while_maximum_iterations"
            type: DT_INT32
          }
          output_arg {
            name: "partitionedcall"
            type: DT_FLOAT
          }
        }
        node_def {
          name: "PartitionedCall"
          op: "PartitionedCall"
          input: "placeholder"
          attr {
            key: "Tin"
            value {
              list {
                type: DT_FLOAT
              }
            }
          }
          attr {
            key: "Tout"
            value {
              list {
                type: DT_FLOAT
              }
            }
          }
          attr {
            key: "_collective_manager_ids"
            value {
              list {
              }
            }
          }
          attr {
            key: "_read_only_resource_inputs"
            value {
              list {
              }
            }
          }
          attr {
            key: "config"
            value {
              s: ""
            }
          }
          attr {
            key: "config_proto"
            value {
              s: ""
            }
          }
          attr {
            key: "executor_type"
            value {
              s: ""
            }
          }
          attr {
            key: "f"
            value {
              func {
                name: "__inference_f_841"
              }
            }
          }
          experimental_debug_info {
            original_node_names: "PartitionedCall"
          }
        }
        node_def {
          name: "add/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_INT32
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_INT32
                tensor_shape {
                }
                int_val: 1
              }
            }
          }
          experimental_debug_info {
            original_node_names: "add/y"
          }
        }
        node_def {
          name: "add_0"
          op: "AddV2"
          input: "while_loop_counter"
          input: "add/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_INT32
            }
          }
          experimental_debug_info {
            original_node_names: "add"
          }
        }
        ret {
          key: "add"
          value: "add_0:z:0"
        }
        ret {
          key: "partitionedcall"
          value: "PartitionedCall:output:0"
        }
        ret {
          key: "while_maximum_iterations"
          value: "while_maximum_iterations_0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 1
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 2
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
      function {
        signature {
          name: "while_cond_821"
          input_arg {
            name: "while_loop_counter"
            type: DT_INT32
          }
          input_arg {
            name: "while_maximum_iterations"
            type: DT_INT32
          }
          input_arg {
            name: "placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "less"
            type: DT_BOOL
          }
        }
        node_def {
          name: "Less/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_FLOAT
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_FLOAT
                tensor_shape {
                }
                float_val: 5.0
              }
            }
          }
          experimental_debug_info {
            original_node_names: "Less/y"
          }
        }
        node_def {
          name: "Less"
          op: "Less"
          input: "placeholder"
          input: "Less/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "Less"
          }
        }
        ret {
          key: "less"
          value: "Less:z:0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 1
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 2
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
      function {
        signature {
          name: "__inference_f_841"
          input_arg {
            name: "mul_placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "identity"
            type: DT_FLOAT
          }
        }
        node_def {
          name: "mul/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_FLOAT
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_FLOAT
                tensor_shape {
                }
                float_val: 2.0
              }
            }
          }
          experimental_debug_info {
            original_node_names: "mul/y"
          }
        }
        node_def {
          name: "mul"
          op: "Mul"
          input: "mul_placeholder"
          input: "mul/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "mul"
          }
        }
        node_def {
          name: "Identity"
          op: "Identity"
          input: "mul:z:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "Identity"
          }
        }
        ret {
          key: "identity"
          value: "Identity:output:0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
    }
    versions {
      producer: 399
      min_consumer: 12
    }
    """
    # Code for generating above graph:
    #
    # def Body(i):
    #   @tf.function
    #   def f():
    #     return i * 2
    #   return f()
    # tf.while_loop(lambda i: i < 5., Body, [tf.constant(1.)])
    graph_def = graph_pb2.GraphDef()
    text_format.Parse(serialized, graph_def)

    @def_function.function
    def F():
      x, y = importer.import_graph_def(
          graph_def, return_elements=["Const:0", "while:2"])
      grad_out, = gradients_impl.gradients(y, x)
      return grad_out

    self.assertAllEqual(F(), 8.0)

  def testIndexedSlicesInIncomingGrads(self):

    @def_function.function
    def F():
      x = constant_op.constant([2.])
      # Computes x^4
      ret = while_loop_v2(
          lambda _: True, lambda v: v * v, [x], return_same_structure=False,
          maximum_iterations=2)
      v = array_ops.gather(ret, [0])
      return gradients_impl.gradients(v, [x])[0]  # 4*x^3

    self.assertAllEqual(self.evaluate(F()), [32.])


def ScalarShape():
  return ops.convert_to_tensor([], dtype=dtypes.int32)


def GetOptimizedGraph():
  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.CopyFrom(
      rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
  return tf_optimizer.OptimizeGraph(config, mg)


if __name__ == "__main__":
  test.main()