1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OiR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=g-long-lambda 17"""Tests for tensorflow.ops.control_flow_ops.""" 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import collections 24import math 25import re 26import sys 27import time 28 29from absl.testing import parameterized 30import numpy as np 31from six.moves import xrange # pylint: disable=redefined-builtin 32 33from tensorflow.core.protobuf import config_pb2 34from tensorflow.python import tf2 35from tensorflow.python.client import device_lib 36from tensorflow.python.client import session 37from tensorflow.python.data.experimental.ops import cardinality 38from tensorflow.python.data.ops import dataset_ops 39from tensorflow.python.eager import context 40from tensorflow.python.eager import def_function 41from tensorflow.python.eager import function as eager_function 42from tensorflow.python.eager import wrap_function 43from tensorflow.python.framework import constant_op 44from tensorflow.python.framework import dtypes 45from tensorflow.python.framework import errors_impl 46from tensorflow.python.framework import function 47from tensorflow.python.framework import ops 48from tensorflow.python.framework import sparse_tensor 49from tensorflow.python.framework import tensor_shape 50from tensorflow.python.framework import tensor_spec 51from tensorflow.python.framework import test_util 52from tensorflow.python.ops import array_ops 53from tensorflow.python.ops import control_flow_ops 54from tensorflow.python.ops import control_flow_util 55from tensorflow.python.ops import data_flow_ops 56from tensorflow.python.ops import functional_ops 57from tensorflow.python.ops import gen_array_ops 58from tensorflow.python.ops import gen_control_flow_ops 59from tensorflow.python.ops import gen_data_flow_ops 60from tensorflow.python.ops import gen_logging_ops 61from tensorflow.python.ops import gen_state_ops 62from tensorflow.python.ops import gradient_checker_v2 63from tensorflow.python.ops import gradients_impl 64from tensorflow.python.ops import init_ops 65from tensorflow.python.ops import linalg_ops 66from tensorflow.python.ops import logging_ops 67from tensorflow.python.ops import map_fn 68from tensorflow.python.ops import math_ops 69from tensorflow.python.ops import nn_grad # pylint: disable=unused-import 70from tensorflow.python.ops import nn_ops 71from tensorflow.python.ops import random_ops 72from tensorflow.python.ops import resource_variable_ops 73from tensorflow.python.ops import script_ops 74from tensorflow.python.ops import sparse_ops 75from tensorflow.python.ops import state_ops 76from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import 77from tensorflow.python.ops import tensor_array_ops 78from tensorflow.python.ops import variable_scope 79from tensorflow.python.ops import 
variables 80from tensorflow.python.ops import while_v2 # pylint: disable=unused-import 81# pylint: disable=unused-import 82from tensorflow.python.ops.ragged import ragged_factory_ops 83from tensorflow.python.ops.ragged import ragged_tensor 84import tensorflow.python.ops.tensor_array_grad 85# pylint: enable=unused-import 86from tensorflow.python.platform import test 87from tensorflow.python.training import adam 88from tensorflow.python.training import gradient_descent 89from tensorflow.python.util import nest 90 91 92def check_consumers(graph): 93 """Sanity check on the consumer list of the tensors.""" 94 95 consumer_count = {} 96 for op in graph.get_operations(): 97 for v in op.inputs: 98 cnt = consumer_count.get(v, 0) 99 consumer_count[v] = cnt + 1 100 for k, v in consumer_count.items(): 101 if len(k.consumers()) != v: 102 return False 103 return True 104 105 106def all_fetchables(): 107 tensor_names = [] 108 graph = ops.get_default_graph() 109 for op in graph.get_operations(): 110 for t in op.outputs: 111 if graph.is_fetchable(t): 112 tensor_names.append(t.name) 113 return tensor_names 114 115 116def all_feedables(): 117 feedable_tensors = [] 118 graph = ops.get_default_graph() 119 for op in graph.get_operations(): 120 for t in op.inputs: 121 if graph.is_feedable(t): 122 feedable_tensors.append(t) 123 return feedable_tensors 124 125 126def opt_cfg(do_constant_folding=True): 127 return config_pb2.ConfigProto( 128 allow_soft_placement=True, 129 graph_options=config_pb2.GraphOptions( 130 optimizer_options=config_pb2.OptimizerOptions( 131 opt_level=config_pb2.OptimizerOptions.L1, 132 do_function_inlining=True, 133 do_constant_folding=do_constant_folding))) 134 135 136def isum(s, maximum_iterations=None): 137 i = constant_op.constant(0, name="i") 138 c = lambda i, s: math_ops.less(i, 10) 139 b = lambda i, s: [math_ops.add(i, 1), math_ops.add(i, s)] 140 _, r_s = control_flow_ops.while_loop( 141 c, b, [i, s], maximum_iterations=maximum_iterations) 142 return r_s 143 144 145def enqueue_print_op(s): 146 """Enqueues an op that prints a message to be captured in the test.""" 147 return logging_ops.print_v2("ControlFlowOpsTest: " + s) 148 149 150def filter_test_messages(s): 151 """Returns a list of messages printed by enqueue_print_op.""" 152 prefix = "ControlFlowOpsTest: " 153 return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)] 154 155 156def tf_function_in_tf2(f): 157 if tf2.enabled(): 158 # In TF1 do not wrap with tf.function so that we can test the v1 control 159 # flow code path. 
160 return def_function.function(f) 161 return f 162 163 164@test_util.with_control_flow_v2 165class ControlFlowTest(test.TestCase, parameterized.TestCase): 166 167 @test_util.run_v1_only("b/120545219") 168 def testRefIdentity(self): 169 with self.cached_session(): 170 v = variables.VariableV1(7) 171 172 v = control_flow_ops._Identity(v) 173 op = state_ops.assign(v, 9) 174 v2 = control_flow_ops.with_dependencies([op], v) 175 176 self.assertTrue(isinstance(v2, ops.Tensor)) 177 self.evaluate(variables.global_variables_initializer()) 178 self.assertEqual(9, self.evaluate(v2)) 179 180 @test_util.run_v1_only("b/120545219") 181 def testRefEnter(self): 182 with self.cached_session(): 183 v = variables.VariableV1(7) 184 185 enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True) 186 nine = constant_op.constant(9) 187 enter_nine = gen_control_flow_ops.enter(nine, "foo_1") 188 op = state_ops.assign(enter_v, enter_nine) 189 v2 = control_flow_ops.with_dependencies([op], enter_v) 190 v3 = control_flow_ops.exit(v2) 191 self.evaluate(variables.global_variables_initializer()) 192 self.assertEqual(9, self.evaluate(v3)) 193 194 @test_util.run_v1_only("b/120545219") 195 def testRefSwitch(self): 196 with self.cached_session(): 197 v = variables.VariableV1(7) 198 199 p = constant_op.constant(True) 200 v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p) # pylint: disable=protected-access 201 v2 = state_ops.assign(v1[1], 9) 202 self.evaluate(variables.global_variables_initializer()) 203 self.assertEqual(9, self.evaluate(v2)) 204 205 def testEnterMulExit(self): 206 with self.cached_session(): 207 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 208 enter_data = gen_control_flow_ops.enter(data, "foo_1", False) 209 five = constant_op.constant(5) 210 enter_five = gen_control_flow_ops.enter(five, "foo_1", False) 211 mul_op = math_ops.multiply(enter_data, enter_five) 212 exit_op = control_flow_ops.exit(mul_op) 213 214 result = self.evaluate(exit_op) 215 self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result) 216 217 @test_util.run_deprecated_v1 218 def testEnterShapePropagation(self): 219 with self.cached_session(): 220 v = variables.Variable([0.0, 0.0], dtype=dtypes.float32) 221 222 # If is_constant=True, the shape information should be propagated. 223 enter_v_constant = gen_control_flow_ops.enter( 224 v, "frame1", is_constant=True) 225 self.assertEqual(enter_v_constant.shape, [2]) 226 227 # Otherwise, the shape should be unknown. 
228 enter_v_non_constant = gen_control_flow_ops.enter( 229 v, "frame2", is_constant=False) 230 self.assertEqual(enter_v_non_constant.shape, None) 231 232 @test_util.run_v1_only("b/120545219") 233 def testSwitchMergeIndexedSlices(self): 234 with self.cached_session(): 235 values = constant_op.constant([1, 2, 3, 4, 5, 6]) 236 indices = constant_op.constant([0, 2, 4, 6, 8, 10]) 237 data = ops.IndexedSlices(values, indices) 238 pred = ops.convert_to_tensor(True) 239 switch_op = control_flow_ops.switch(data, pred) 240 merge_op = control_flow_ops.merge(switch_op)[0] 241 242 val = merge_op.values 243 ind = merge_op.indices 244 self.assertAllEqual(np.arange(1, 7), val) 245 self.assertAllEqual(np.arange(0, 12, 2), ind) 246 247 @test_util.run_v1_only("b/120545219") 248 def testSwitchDeadBranch(self): 249 with self.cached_session(): 250 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 251 ports = ops.convert_to_tensor(True, name="ports") 252 switch_op = control_flow_ops.switch(data, ports) 253 dead_branch = array_ops.identity(switch_op[0]) 254 255 with self.assertRaisesWithPredicateMatch( 256 errors_impl.InvalidArgumentError, 257 lambda e: "Retval[0] does not have value" in str(e)): 258 self.evaluate(dead_branch) 259 260 @test_util.run_v1_only("b/120545219") 261 def testSwitchMergeLess(self): 262 with self.cached_session(): 263 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 264 zero = ops.convert_to_tensor(0) 265 one = ops.convert_to_tensor(1) 266 less_op = math_ops.less(zero, one) 267 switch_op = control_flow_ops.switch(data, less_op) 268 merge_op = control_flow_ops.merge(switch_op)[0] 269 270 result = self.evaluate(merge_op) 271 self.assertAllEqual(np.arange(1, 7), result) 272 273 @test_util.run_v1_only("b/120545219") 274 def testSwitchMergeAddIdentity(self): 275 with self.cached_session(): 276 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 277 ports = ops.convert_to_tensor(False, name="ports") 278 switch_op = control_flow_ops.switch(data, ports) 279 one = constant_op.constant(1) 280 add_op = math_ops.add(switch_op[0], one) 281 id_op = array_ops.identity(switch_op[1]) 282 merge_op = control_flow_ops.merge([add_op, id_op])[0] 283 284 result = self.evaluate(merge_op) 285 self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result) 286 287 @test_util.run_v1_only("b/120545219") 288 def testSwitchMergeAddMul(self): 289 with self.cached_session(): 290 data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 291 ports = ops.convert_to_tensor(True, name="ports") 292 switch_op = control_flow_ops.switch(data, ports) 293 one = constant_op.constant(1) 294 add_op = math_ops.add(switch_op[0], one) 295 five = constant_op.constant(5) 296 mul_op = math_ops.multiply(switch_op[1], five) 297 merge_op = control_flow_ops.merge([add_op, mul_op])[0] 298 299 result = self.evaluate(merge_op) 300 self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result) 301 302 @test_util.run_v1_only("b/120545219") 303 def testLoop_false(self): 304 with self.cached_session(): 305 false = ops.convert_to_tensor(False) 306 n = constant_op.constant(10) 307 308 enter_false = gen_control_flow_ops.enter(false, "foo_1", False) 309 enter_n = gen_control_flow_ops.enter(n, "foo_1", False) 310 311 merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0] 312 switch_n = control_flow_ops.switch(merge_n, enter_false) 313 exit_n = control_flow_ops.exit(switch_n[0]) 314 next_n = control_flow_ops.next_iteration(switch_n[0]) 315 merge_n.op._update_input(1, next_n) 316 317 
result = self.evaluate(exit_n) 318 self.assertAllEqual(10, result) 319 320 @test_util.run_deprecated_v1 321 def testLoop_1(self): 322 with self.cached_session(): 323 zero = constant_op.constant(0) 324 one = constant_op.constant(1) 325 n = constant_op.constant(10) 326 327 enter_i = gen_control_flow_ops.enter(zero, "foo", False) 328 enter_one = gen_control_flow_ops.enter(one, "foo", True) 329 enter_n = gen_control_flow_ops.enter(n, "foo", True) 330 331 with ops.device(test.gpu_device_name()): 332 merge_i = control_flow_ops.merge([enter_i, enter_i])[0] 333 334 less_op = math_ops.less(merge_i, enter_n) 335 cond_op = control_flow_ops.loop_cond(less_op) 336 switch_i = control_flow_ops.switch(merge_i, cond_op) 337 338 add_i = math_ops.add(switch_i[1], enter_one) 339 340 next_i = control_flow_ops.next_iteration(add_i) 341 merge_i.op._update_input(1, next_i) 342 343 exit_i = control_flow_ops.exit(switch_i[0]) 344 result = self.evaluate(exit_i) 345 self.assertAllEqual(10, result) 346 347 @test_util.run_v1_only("b/120545219") 348 def testLoop_2(self): 349 with self.cached_session(): 350 zero = constant_op.constant(0) 351 one = constant_op.constant(1) 352 n = constant_op.constant(10) 353 354 enter_i = gen_control_flow_ops.enter(zero, "foo", False) 355 enter_one = gen_control_flow_ops.enter(one, "foo", True) 356 enter_n = gen_control_flow_ops.enter(n, "foo", True) 357 358 merge_i = control_flow_ops.merge([enter_i, enter_i])[0] 359 360 less_op = math_ops.less(merge_i, enter_n) 361 cond_op = control_flow_ops.loop_cond(less_op) 362 switch_i = control_flow_ops.switch(merge_i, cond_op) 363 364 add_i = math_ops.add(switch_i[1], enter_one) 365 366 with ops.device(test.gpu_device_name()): 367 next_i = control_flow_ops.next_iteration(add_i) 368 merge_i.op._update_input(1, next_i) 369 370 exit_i = control_flow_ops.exit(switch_i[0]) 371 result = self.evaluate(exit_i) 372 self.assertAllEqual(10, result) 373 374 @test_util.run_v1_only("b/120545219") 375 def testDifferentFrame(self): 376 with self.cached_session(): 377 data = array_ops.placeholder(dtypes.float32, shape=[]) 378 enter_1 = gen_control_flow_ops.enter(data, "foo_1", False) 379 enter_2 = gen_control_flow_ops.enter(data, "foo_2", False) 380 res = math_ops.add(enter_1, enter_2) 381 with self.assertRaisesOpError("has inputs from different frames"): 382 res.eval(feed_dict={data: 1.0}) 383 384 @test_util.run_deprecated_v1 385 def testCondBool(self): 386 values = constant_op.constant(10) 387 fn1 = lambda: math_ops.add(values, 1) 388 fn2 = lambda: math_ops.subtract(values, 1) 389 with self.assertRaisesRegex(TypeError, "must not be a Python bool"): 390 _ = control_flow_ops.cond(False, fn1, fn2) 391 392 @test_util.run_deprecated_v1 393 def testCondInt(self): 394 p = array_ops.placeholder(dtypes.bool, shape=[]) 395 v = constant_op.constant(10) 396 fn1 = lambda: math_ops.add(v, 1) 397 fn2 = lambda: math_ops.subtract(v, 1) 398 y = control_flow_ops.cond(p, fn1, fn2) 399 grad = gradients_impl.gradients(y, [v]) 400 self.assertAllEqual([None], grad) 401 402 def testCondOutputShape(self): 403 x = constant_op.constant(1.0) 404 b = control_flow_ops.cond( 405 constant_op.constant(True), lambda: math_ops.square(x), 406 lambda: math_ops.subtract(x, 1.)) 407 self.assertEqual(b.shape, tensor_shape.TensorShape([])) 408 409 @test_util.run_v1_only("b/120545219") 410 def testFetchable(self): 411 with self.cached_session() as sess: 412 x = array_ops.placeholder(dtypes.float32) 413 control_flow_ops.cond( 414 constant_op.constant(True), lambda: x + 2, lambda: x + 0) 415 graph = 
ops.get_default_graph() 416 for op in graph.get_operations(): 417 for t in op.inputs: 418 if graph.is_fetchable(t.op): 419 sess.run(t, feed_dict={x: 3}) 420 else: 421 with self.assertRaisesRegex(ValueError, 422 "has been marked as not fetchable"): 423 sess.run(t, feed_dict={x: 3}) 424 425 @test_util.disable_control_flow_v2("Not relevant") 426 @test_util.run_v1_only("b/120545219") 427 def testFeedable(self): 428 with self.cached_session() as sess: 429 c = constant_op.constant(2) 430 i0 = constant_op.constant(0) 431 r = control_flow_ops.while_loop(lambda i: i < 1000, 432 lambda i: math_ops.square(c) + i, [i0]) 433 self.assertEqual(1000, r.eval(feed_dict={i0: 0})) 434 feedable_tensors = all_feedables() 435 for t in feedable_tensors: 436 sess.run(r, feed_dict={t: 3}) 437 graph = ops.get_default_graph() 438 for op in graph.get_operations(): 439 for t in op.inputs: 440 if t not in feedable_tensors and t.dtype is dtypes.int32: 441 with self.assertRaisesRegex(ValueError, "may not be fed"): 442 sess.run(r, feed_dict={t: 3}) 443 444 @test_util.run_v1_only("b/120545219") 445 def testCondIndexedSlices(self): 446 with self.cached_session(): 447 values = constant_op.constant([10]) 448 indices = constant_op.constant([0]) 449 x = ops.IndexedSlices(values, indices) 450 pred = math_ops.less(1, 2) 451 fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices) 452 fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), indices) 453 r = control_flow_ops.cond(pred, fn1, fn2) 454 455 val = r.values 456 ind = r.indices 457 self.assertAllEqual([11], val) 458 self.assertAllEqual([0], ind) 459 460 def testCondMismatchedIndexedSlices(self): 461 @def_function.function 462 def foo(): 463 values = constant_op.constant([10]) 464 indices = constant_op.constant([0]) 465 x = ops.IndexedSlices(values, indices) 466 with self.assertRaisesRegex(TypeError, 467 "Cannot reconcile tf.cond 0-th outputs"): 468 control_flow_ops.cond( 469 constant_op.constant(True), 470 lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices), 471 lambda: math_ops.add(x.values, 1), indices) 472 foo() 473 474 def testCondSparseTensor(self): 475 values = constant_op.constant([2.0, 4.0], name="values") 476 indices = constant_op.constant([[0], [3]], 477 dtype=dtypes.int64, 478 name="indices") 479 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 480 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 481 pred = math_ops.less(1, 2) 482 fn1 = lambda: sparse_tensor.SparseTensor( 483 indices + 1, x.values + 1, dense_shape=shape) 484 fn2 = lambda: sparse_tensor.SparseTensor( 485 indices, x.values - 1, dense_shape=shape) 486 r = control_flow_ops.cond(pred, fn1, fn2) 487 self.assertAllEqual([3.0, 5.0], r.values) 488 self.assertAllEqual([[1], [4]], r.indices) 489 self.assertAllEqual(r.values.get_shape(), (2,)) 490 491 def testCondRaggedTensor(self): 492 rt = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]]) 493 pred = math_ops.less(1, 2) 494 fn1 = lambda: array_ops.concat([rt + 2, [[100]]], axis=0) 495 fn2 = lambda: rt[:2] - 2 496 result = control_flow_ops.cond(pred, fn1, fn2) 497 self.assertAllEqual([3, 4, 5, 6, 7, 8, 100], result.values) 498 self.assertAllEqual([0, 2, 3, 6, 7], result.row_splits) 499 500 @test_util.run_v1_only("b/120545219") 501 def testCondResource(self): 502 503 with self.cached_session(): 504 rv = resource_variable_ops.ResourceVariable(True) 505 self.evaluate(variables.global_variables_initializer()) 506 t = ops.convert_to_tensor(1.0) 507 508 def case(): 509 assign = 
resource_variable_ops.assign_variable_op(rv.handle, False) 510 with ops.control_dependencies([assign]): 511 return array_ops.identity(t) 512 513 self.assertEqual( 514 1.0, self.evaluate(control_flow_ops.cond(rv, case, lambda: t))) 515 516 @test_util.run_deprecated_v1 517 def testCondResourceGradShape(self): 518 rv1 = resource_variable_ops.ResourceVariable([1.0, 2.0]) 519 rv2 = resource_variable_ops.ResourceVariable([3.0, 4.0]) 520 pred = constant_op.constant(True) 521 result = control_flow_ops.cond(pred, lambda: rv1, lambda: rv2) 522 grads = gradients_impl.gradients(result, [rv1, rv2]) 523 self.assertAllEqual(grads[0].shape.as_list(), [2]) 524 self.assertAllEqual(grads[1].shape.as_list(), [2]) 525 526 @test_util.run_v1_only("b/120545219") 527 def testCondWithTensorArrayGrad(self): 528 with self.cached_session() as sess: 529 with ops.device(test.gpu_device_name()): 530 pred = array_ops.placeholder(dtypes.bool, []) 531 x = constant_op.constant([1.0, 2.0, 3.0]) 532 y = control_flow_ops.cond( 533 pred, lambda: map_fn.map_fn(lambda z: z * 2.0, x), 534 lambda: constant_op.constant([1.0, 1.0, 1.0])) 535 g = gradients_impl.gradients(y, x)[0] 536 537 self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0]) 538 self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0]) 539 540 @test_util.run_v1_only("b/120545219") 541 def testCondIndexedSlicesDifferentTypes(self): 542 with self.cached_session(): 543 values = constant_op.constant([10]) 544 i_32 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int32) 545 i_64 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int64) 546 x = ops.IndexedSlices(values, i_32) 547 pred = math_ops.less(1, 2) 548 fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), i_32) 549 fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), i_64) 550 r = control_flow_ops.cond(pred, fn1, fn2) 551 552 val = r.values 553 ind = r.indices 554 self.assertAllEqual([11], val) 555 self.assertAllEqual([0], ind) 556 self.assertTrue(ind.dtype == np.int64) 557 558 @test_util.run_v1_only("b/120545219") 559 def testCondColocation(self): 560 with self.session(): 561 with ops.device("/cpu:0"): 562 v = variables.Variable(7.0) 563 564 x = constant_op.constant(10.0) 565 pred = math_ops.less(1.0, 2.0) 566 fn1 = lambda: math_ops.add(v, 1.0) 567 fn2 = lambda: math_ops.subtract(x, 1.0) 568 r = control_flow_ops.cond(pred, fn1, fn2) 569 570 for op in x.graph.get_operations(): 571 if op.name == "cond/Add/Switch": 572 self.assertDeviceEqual(op.device, "/cpu:0") 573 574 def _testCond_1(self, use_gpu): 575 with self.cached_session(use_gpu=use_gpu): 576 x = constant_op.constant(10) 577 pred = math_ops.less(1, 2) 578 fn1 = lambda: math_ops.add(x, 1) 579 fn2 = lambda: math_ops.subtract(x, 1) 580 r = control_flow_ops.cond(pred, fn1, fn2) 581 582 result = self.evaluate(r) 583 self.assertAllEqual(11, result) 584 585 def testCond_1(self): 586 587 self._testCond_1(use_gpu=False) 588 # TODO(b/116526896): Enable GPU tests. 
589 # self._testCond_1(use_gpu=True) 590 591 def testCond_2(self): 592 593 with self.cached_session(): 594 x = constant_op.constant(10) 595 r = control_flow_ops.cond( 596 math_ops.less(1, 0), lambda: math_ops.add(x, 1), 597 lambda: math_ops.subtract(x, 1)) 598 result = self.evaluate(r) 599 self.assertAllEqual(9, result) 600 601 def testCond_3(self): 602 603 with self.cached_session(): 604 x = constant_op.constant(10) 605 pred = math_ops.less(1, 2) 606 fn1 = lambda: math_ops.add(x, 1) 607 fn2 = lambda: math_ops.subtract(x, 1) 608 fn3 = lambda: math_ops.add(control_flow_ops.cond(pred, fn1, fn2), 1) 609 r = control_flow_ops.cond(pred, fn3, fn2) 610 611 result = self.evaluate(r) 612 self.assertAllEqual(12, result) 613 614 @test_util.run_in_graph_and_eager_modes 615 def testCondPruning(self): 616 v1 = variables.Variable(7) 617 v2 = variables.Variable(7) 618 v3 = variables.Variable(7) 619 620 def f(): 621 age = constant_op.constant(3) 622 max_age = constant_op.constant(2) 623 pred = math_ops.greater(age, max_age) 624 fn1 = lambda: [state_ops.assign(v1, 1).op, state_ops.assign(v2, 2).op] 625 fn2 = lambda: [state_ops.assign(v3, 3).op, constant_op.constant(10).op] 626 r = control_flow_ops.cond(pred, fn1, fn2) 627 self.assertEqual(len(r), 2) 628 return r[1] 629 630 f_defun = eager_function.defun(f) 631 632 if not context.executing_eagerly(): 633 with self.cached_session(): 634 self.evaluate(variables.global_variables_initializer()) 635 result = self.evaluate(f()) 636 self.assertEqual(True, result) 637 # Only second cond result was fetched, so v1 assign shouldn't run. 638 self.assertEqual(7, self.evaluate(v1)) 639 self.assertEqual(2, self.evaluate(v2)) 640 self.assertEqual(7, self.evaluate(v3)) 641 642 result = f_defun() 643 self.assertEqual(True, self.evaluate(result)) 644 # Both v1 and v2 branch assignments should be run in defun. 645 self.assertEqual(1, self.evaluate(v1)) 646 self.assertEqual(2, self.evaluate(v2)) 647 self.assertEqual(7, self.evaluate(v3)) 648 649 def testCond_5(self): 650 with self.cached_session(): 651 alive = constant_op.constant(True, name="alive") 652 count = constant_op.constant(0, name="count") 653 654 def body(i): 655 return control_flow_ops.cond( 656 alive, lambda: [math_ops.less(i, 3), math_ops.add(count, 1)], 657 lambda: [alive, count]) 658 659 for i in range(10): 660 alive, count = body(i) 661 self.assertAllEqual(4, self.evaluate(count)) 662 663 @test_util.run_v1_only("b/120545219") 664 def testCond_6(self): 665 with self.cached_session(): 666 v1 = variables.Variable([7]) 667 668 age = constant_op.constant(3) 669 pred = math_ops.greater(age, 4) 670 fn1 = lambda: age 671 fn2 = lambda: v1 672 r = control_flow_ops.cond(pred, fn1, fn2) 673 674 self.evaluate(variables.global_variables_initializer()) 675 result = self.evaluate(r) 676 self.assertAllEqual(np.array([7]), result) 677 678 def testCond_7(self): 679 with self.cached_session() as sess: 680 x = constant_op.constant(10) 681 y = constant_op.constant(200) 682 pred = math_ops.less(1, 2) 683 fn1 = lambda: [math_ops.add(x, 1), math_ops.add(x, 2)] 684 fn2 = lambda: [y, y] 685 r = control_flow_ops.cond(pred, fn1, fn2) 686 self.assertAllEqual([11, 12], self.evaluate(r)) 687 688 @parameterized.parameters(dtypes.float32, dtypes.float64) 689 @test_util.run_v1_only("Uses tf.gradients") 690 def testCondResourceGrad(self, dtype): 691 init = constant_op.constant([7.], dtype=dtype) 692 v1 = variables.Variable(init) 693 694 age = constant_op.constant(3., dtype=dtype) 695 pred = math_ops.greater(age, 4.) 
696 fn1 = lambda: age 697 fn2 = lambda: v1 698 r = control_flow_ops.cond(pred, fn1, fn2) 699 700 grad = gradients_impl.gradients(r, v1)[0] 701 self.evaluate(variables.global_variables_initializer()) 702 self.assertAllEqual(grad, [1.]) 703 704 @test_util.run_gpu_only 705 @test_util.run_deprecated_v1 706 def testCond_Device(self): 707 x = constant_op.constant(-10.) 708 709 # True branch function defined outside of device scope 710 def true_fn(): 711 return math_ops.exp(x) 712 713 with ops.device("CPU:0"): 714 r = control_flow_ops.cond( 715 constant_op.constant(True), true_fn, lambda: 0.) 716 self.assertIn("cpu", r.device.lower()) 717 718 with session.Session() as sess: 719 options = config_pb2.RunOptions(output_partition_graphs=True) 720 run_metadata = config_pb2.RunMetadata() 721 sess.run(r, options=options, run_metadata=run_metadata) 722 # We expect that everything runs on CPU, even if GPU is available. 723 self.assertEqual(len(run_metadata.partition_graphs), 1) 724 725 def _count_matching_switch_nodes_on_device(self, run_metadata, device_str, 726 dtype): 727 # Returns the number of Switch nodes with type dtype placed on 728 # `device_str`. 729 device_graphs = [ 730 g for g in run_metadata.partition_graphs 731 if device_str in g.node[0].device 732 ] 733 self.assertLen(device_graphs, 1) 734 switch_nodes = [ 735 n for n in device_graphs[0].node 736 if n.op == "Switch" and n.attr["T"].type == dtype.as_datatype_enum 737 ] 738 return len(switch_nodes) 739 740 @test_util.run_gpu_only 741 @test_util.run_deprecated_v1 742 def testCondSwitchColocatedWithInputWhenInputExplicitlyPlacedOnCPU(self): 743 x = array_ops.placeholder(dtypes.float32) 744 745 # `arg` is used in the cond then branch so a Switch node is created for it. 746 # We test that the Switch node gets placed on the same device as `arg`. 747 # We force `arg` to be on CPU here. 748 with ops.device("CPU:0"): 749 arg = x + 10. 750 751 def true_fn(): 752 with ops.device("CPU:0"): 753 return arg + 1 754 755 r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.) 756 757 with session.Session() as sess: 758 run_metadata = config_pb2.RunMetadata() 759 options = config_pb2.RunOptions(output_partition_graphs=True) 760 sess.run( 761 r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) 762 self.assertLen(run_metadata.partition_graphs, 2) 763 # Check that the Switch for `arg` gets placed on CPU. 764 self.assertEqual( 765 self._count_matching_switch_nodes_on_device(run_metadata, "CPU", 766 dtypes.float32), 1) 767 self.assertEqual( 768 self._count_matching_switch_nodes_on_device(run_metadata, "GPU", 769 dtypes.float32), 0) 770 771 @test_util.run_gpu_only 772 @test_util.run_deprecated_v1 773 def testCondSwitchColocatedWithInputWhenInputPlacedOnCPU(self): 774 x = array_ops.placeholder(dtypes.float32) 775 776 # `arg` is used in the cond then branch so a Switch node is created for it. 777 # We test that the Switch node gets placed on the same device as `arg`. 778 # Since arg is a dataset (and only has a CPU kernel), it gets placed on CPU 779 # by placer. 
780 arg = dataset_ops.Dataset.range(8) 781 782 def true_fn(): 783 return cardinality.cardinality(arg) 784 785 r = control_flow_ops.cond( 786 constant_op.constant(True), true_fn, 787 lambda: constant_op.constant(0, dtypes.int64)) 788 789 with session.Session() as sess: 790 run_metadata = config_pb2.RunMetadata() 791 options = config_pb2.RunOptions(output_partition_graphs=True) 792 sess.run( 793 r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) 794 self.assertLen(run_metadata.partition_graphs, 2) 795 # Check that the Switch for `arg` gets placed on CPU. 796 self.assertEqual( 797 self._count_matching_switch_nodes_on_device(run_metadata, "CPU", 798 dtypes.variant), 1) 799 self.assertEqual( 800 self._count_matching_switch_nodes_on_device(run_metadata, "GPU", 801 dtypes.variant), 0) 802 803 @test_util.run_gpu_only 804 @test_util.run_deprecated_v1 805 def testCondSwitchColocatedWithInputWhenInputOnGPU(self): 806 x = array_ops.placeholder(dtypes.float32) 807 808 # `arg` is used in the cond then branch so a Switch node is created for it. 809 # We test that the Switch node gets placed on the same device as `arg`. 810 # Note: `arg` gets placed on GPU by default by the placer. 811 arg = x + 10. 812 813 def true_fn(): 814 with ops.device("CPU:0"): 815 return arg + 1 816 817 r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.) 818 819 with session.Session() as sess: 820 run_metadata = config_pb2.RunMetadata() 821 options = config_pb2.RunOptions(output_partition_graphs=True) 822 sess.run( 823 r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) 824 self.assertEqual(len(run_metadata.partition_graphs), 2) 825 # Check that the Switch for `arg` gets placed on GPU. 826 self.assertEqual( 827 self._count_matching_switch_nodes_on_device(run_metadata, "CPU", 828 dtypes.float32), 0) 829 self.assertEqual( 830 self._count_matching_switch_nodes_on_device(run_metadata, "GPU", 831 dtypes.float32), 1) 832 833 def testCondAccessTrueBranchTensorInFalseBranchRaises(self): 834 835 @def_function.function 836 def f(): 837 c = constant_op.constant(1.) 838 inputs = {"c": c} 839 840 def true_fn(inputs): 841 inputs["c"] = array_ops.identity(inputs["c"], name="true_branch") 842 return inputs["c"] 843 844 def false_fn(inputs): 845 return array_ops.identity(inputs["c"]) 846 847 pred = constant_op.constant(True) 848 return control_flow_ops.cond( 849 pred, lambda: true_fn(inputs), lambda: false_fn(inputs)) 850 851 # This was needed for backwards compatibility with TF2 Estimators which 852 # rely on variable names. 853 prefix = "cond/" if context.executing_eagerly() else "" 854 855 with self.assertRaisesRegex( 856 ValueError, 857 "Tensor %strue_branch:0 in true_fn is accessed from false_fn." % 858 prefix): 859 f() 860 861 def testSwitchCaseAccessBranch1TensorInBranch4Raises(self): 862 863 @def_function.function 864 def f(): 865 c = constant_op.constant(1.) 866 inputs = {"c": c} 867 868 def br1_fn(inputs): 869 inputs["c"] = array_ops.identity(inputs["c"], name="br1_identity") 870 return inputs["c"] 871 872 def br4_fn(inputs): 873 return array_ops.identity(inputs["c"]) 874 875 def other_fn(): 876 return array_ops.identity(c) 877 878 return control_flow_ops.switch_case( 879 constant_op.constant(2), 880 [other_fn, lambda: br1_fn(inputs), other_fn, other_fn, 881 lambda: br4_fn(inputs)]) 882 883 # This was needed for backwards compatibility with TF2 Estimators which 884 # rely on variable names. 
885 prefix = "switch_case/indexed_case/" if context.executing_eagerly() else "" 886 with self.assertRaisesRegex( 887 ValueError, "Tensor %sbr1_identity:0 in branch 1 is " 888 "accessed from branch 4." % prefix): 889 f() 890 891 def testCondListOutput(self): 892 with self.cached_session() as sess: 893 x = constant_op.constant(10) 894 y = constant_op.constant(200) 895 pred = math_ops.less(1, 2) 896 fn1 = lambda: [math_ops.add(x, y), math_ops.add(x, y)] 897 fn2 = lambda: [y, y] 898 r = control_flow_ops.cond(pred, fn1, fn2) 899 test_result = self.evaluate(r) 900 self.assertListEqual([210, 210], test_result) 901 902 def testTupleOutput(self): 903 with self.cached_session() as sess: 904 x = constant_op.constant(10) 905 y = constant_op.constant(200) 906 pred = math_ops.less(1, 2) 907 fn1 = lambda: (math_ops.add(x, y), math_ops.add(x, y)) 908 fn2 = lambda: (y, y) 909 r = control_flow_ops.cond(pred, fn1, fn2) 910 test_result = self.evaluate(r) 911 self.assertTupleEqual((210, 210), test_result) 912 913 def testDictOutput(self): 914 with self.cached_session() as sess: 915 x = constant_op.constant(10) 916 y = constant_op.constant(200) 917 pred = math_ops.less(1, 2) 918 fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)} 919 fn2 = lambda: {"a": y, "b": y} 920 r = control_flow_ops.cond(pred, fn1, fn2) 921 test_result = self.evaluate(r) 922 self.assertDictEqual({"a": 210, "b": 210}, test_result) 923 924 def testEmbeddedListOutput(self): 925 x = constant_op.constant(10) 926 y = constant_op.constant(200) 927 pred = math_ops.less(1, 2) 928 fn1 = lambda: [[math_ops.add(x, y), math_ops.add(x, y)]] 929 fn2 = lambda: [[y, y]] 930 # Pass strict=True flag as cond_v2 allows for tensors to be 931 # in nested output structures as singletons 932 r = control_flow_ops.cond(pred, fn1, fn2, strict=True) 933 test_result = self.evaluate(r) 934 self.assertListEqual([[210, 210]], test_result) 935 936 def testEmbeddedTupleOutput(self): 937 with self.cached_session() as sess: 938 x = constant_op.constant(10) 939 y = constant_op.constant(200) 940 pred = math_ops.less(1, 2) 941 fn1 = lambda: ((math_ops.add(x, y), math_ops.add(x, y))) 942 fn2 = lambda: ((y, y)) 943 r = control_flow_ops.cond(pred, fn1, fn2) 944 test_result = self.evaluate(r) 945 self.assertTupleEqual(((210, 210)), test_result) 946 947 def testEmbeddedDictOutput(self): 948 with self.cached_session() as sess: 949 x = constant_op.constant(10) 950 y = constant_op.constant(200) 951 pred = math_ops.less(1, 2) 952 fn1 = lambda: {"a": {"c": math_ops.add(x, y)}, 953 "b": {"d": math_ops.add(x, y)}} 954 fn2 = lambda: {"a": {"c": y}, 955 "b": {"d": y}} 956 r = control_flow_ops.cond(pred, fn1, fn2) 957 test_result = self.evaluate(r) 958 self.assertDictEqual({"a": {"c": 210}, "b": {"d": 210}}, test_result) 959 960 @test_util.run_v1_only("b/120545219") 961 def testCheckNestedOutputStruct(self): 962 with self.cached_session() as sess: 963 x = constant_op.constant(10) 964 y = constant_op.constant(200) 965 pred = math_ops.less(1, 2) 966 fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)} 967 fn2 = lambda: {"c": y, "d": y} 968 v1_msg = "The two structures don't have the same nested structure" 969 v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same " 970 "number, type, and overall structure of return values.") 971 with self.assertRaisesRegex( 972 TypeError if control_flow_util.ENABLE_CONTROL_FLOW_V2 else ValueError, 973 v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg): 974 control_flow_ops.cond(pred, fn1, fn2) 975 976 
@test_util.run_deprecated_v1 977 def testCondRef(self): 978 979 with self.cached_session(): 980 x = gen_state_ops.variable( 981 shape=[1], 982 dtype=dtypes.float32, 983 name="x", 984 container="", 985 shared_name="") 986 true_fn = lambda: x 987 false_fn = lambda: constant_op.constant([2.0]) 988 r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn) 989 self.assertAllEqual([2.0], self.evaluate(r)) 990 991 @test_util.run_v1_only("b/120545219") 992 def testCondWithControl(self): 993 with self.cached_session() as sess: 994 control_holder = array_ops.placeholder(dtypes.float32, shape=()) 995 a = constant_op.constant(3) 996 997 def true_branch(): 998 with ops.control_dependencies([control_holder]): 999 _ = a + 1 1000 return a + 2 1001 1002 r = control_flow_ops.cond( 1003 constant_op.constant(True), true_branch, 1004 lambda: constant_op.constant(1)) 1005 result = sess.run(r, feed_dict={control_holder: 5.}) 1006 self.assertEqual(5, result) 1007 1008 @test_util.run_v1_only("b/120545219") 1009 def testUninitializedRefIdentity(self): 1010 with self.cached_session() as sess: 1011 v = gen_state_ops.variable( 1012 shape=[1], 1013 dtype=dtypes.float32, 1014 name="v", 1015 container="", 1016 shared_name="") 1017 inited = state_ops.is_variable_initialized(v) 1018 v_f, v_t = control_flow_ops.ref_switch(v, inited) 1019 # Both v_f and v_t are uninitialized references. However, an actual use 1020 # of the reference in the 'true' branch in the 'tf.identity' op will 1021 # not 'fire' when v is uninitialized, so this is a valid construction. 1022 # This test tests that ref_identity allows uninitialized ref as input 1023 # so that this construction is allowed. 1024 v_f_op = gen_array_ops.ref_identity(v_f) 1025 v_t_op = gen_array_ops.ref_identity(v_t) 1026 with ops.control_dependencies([v_f_op]): 1027 assign_v = state_ops.assign(v, [1.0]) 1028 with ops.control_dependencies([v_t_op]): 1029 orig_v = array_ops.identity(v) 1030 merged_op = control_flow_ops.merge([assign_v, orig_v]) 1031 self.assertAllEqual([1.0], self.evaluate(merged_op.output)) 1032 1033 def testCondSwitchIdentity(self): 1034 # Make sure the recv identity is not removed by optimization. 1035 with session.Session(config=opt_cfg()) as sess: 1036 pred = constant_op.constant(True) 1037 1038 def fn1(): 1039 return control_flow_ops.no_op() 1040 1041 def fn2(): 1042 return control_flow_ops.Assert(False, ["Wrong branch!!!"]) 1043 1044 r = control_flow_ops.cond(pred, fn1, fn2) 1045 self.evaluate(r) 1046 1047 def testCondRecvIdentity(self): 1048 # Make sure the switch identity is not removed by optimization. 1049 with session.Session(config=opt_cfg()) as sess: 1050 with ops.device(test.gpu_device_name()): 1051 pred = constant_op.constant(True) 1052 1053 def fn1(): 1054 return control_flow_ops.no_op() 1055 1056 def fn2(): 1057 with ops.device("/cpu:0"): 1058 return control_flow_ops.Assert(False, ["Wrong branch!!!"]) 1059 1060 r = control_flow_ops.cond(pred, fn1, fn2) 1061 self.evaluate(r) 1062 1063 @test_util.run_deprecated_v1 1064 @test_util.enable_control_flow_v2 1065 def testDisableLoweringSwitchMerge(self): 1066 if test_util.is_gpu_available(): 1067 self.skipTest( 1068 "Single threaded executor doesn't support partitioned graphs. " 1069 "Skipping GPU test.") 1070 # Make pred feedable to ensure we don't constant-fold it out. 
1071 run_opts = config_pb2.RunOptions( 1072 trace_level=config_pb2.RunOptions.FULL_TRACE) 1073 run_metadata_no_lowering = config_pb2.RunMetadata() 1074 run_metadata_with_lowering = config_pb2.RunMetadata() 1075 1076 config = opt_cfg(do_constant_folding=False) 1077 1078 pred = array_ops.placeholder_with_default( 1079 constant_op.constant(True), shape=()) 1080 r = control_flow_ops.cond(pred, lambda: True, lambda: False) 1081 1082 with session.Session(config=config) as sess: 1083 r_value = sess.run( 1084 r, options=run_opts, run_metadata=run_metadata_with_lowering) 1085 self.assertEqual(r_value, True) 1086 1087 # Use the single threaded executor, which disables control flow lowering. 1088 config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR" 1089 with session.Session(config=config) as sess: 1090 r_value = sess.run( 1091 r, options=run_opts, run_metadata=run_metadata_no_lowering) 1092 self.assertEqual(r_value, True) 1093 1094 self.assertTrue( # pylint: disable=g-complex-comprehension 1095 any("switch" in ns.node_name 1096 for dev_stat in run_metadata_with_lowering.step_stats.dev_stats 1097 for ns in dev_stat.node_stats)) 1098 1099 self.assertTrue( # pylint: disable=g-complex-comprehension 1100 all("switch" not in ns.node_name 1101 for dev_stat in run_metadata_no_lowering.step_stats.dev_stats 1102 for ns in dev_stat.node_stats)) 1103 1104 @test_util.run_v1_only("b/120545219") 1105 def testCondGrad_1(self): 1106 with self.cached_session(): 1107 x = constant_op.constant(10.0, name="x") 1108 pred = math_ops.less(1, 2) 1109 fn1 = lambda: array_ops.identity(x) 1110 fn2 = lambda: array_ops.identity(x) 1111 r = control_flow_ops.cond(pred, fn1, fn2) 1112 1113 grad = gradients_impl.gradients(r, [x])[0] 1114 self.assertAllEqual(1.0, self.evaluate(grad)) 1115 1116 @test_util.run_deprecated_v1 1117 @test_util.enable_control_flow_v2 1118 def testCondComputeGradAfterSessRunFails(self): 1119 with self.cached_session(): 1120 x = constant_op.constant(10.0, name="x") 1121 pred = math_ops.less(1, 2) 1122 1123 def true_fn(): 1124 a = x * x 1125 return a * a 1126 1127 def false_fn(): 1128 return x * x 1129 1130 r = control_flow_ops.cond(pred, true_fn, false_fn) 1131 1132 self.assertAllEqual(r, 10000.) 1133 grad = gradients_impl.gradients(r, [x])[0] 1134 with self.assertRaisesRegex( 1135 errors_impl.InvalidArgumentError, 1136 r"Connecting to invalid output 1 of source node cond which has 1 " 1137 r"outputs. Try using " 1138 "tf.compat.v1.experimental.output_all_intermediates\(True\)."): 1139 self.evaluate(grad) 1140 1141 @test_util.run_deprecated_v1 1142 @test_util.enable_output_all_intermediates 1143 def testCondComputeGradAfterSessRun(self): 1144 with self.cached_session(): 1145 x = constant_op.constant(10.0, name="x") 1146 pred = math_ops.less(1, 2) 1147 1148 def true_fn(): 1149 a = x * x 1150 return a * a 1151 1152 def false_fn(): 1153 return x * x 1154 1155 r = control_flow_ops.cond(pred, true_fn, false_fn) 1156 1157 self.assertAllEqual(r, 10000.) 1158 grad = gradients_impl.gradients(r, [x])[0] 1159 self.assertAllEqual(grad, 4000.) 
1160 1161 @test_util.run_deprecated_v1 1162 @test_util.enable_output_all_intermediates 1163 def testNestedCondComputeGradAfterSessRun(self): 1164 with self.cached_session(): 1165 x = constant_op.constant(10.0, name="x") 1166 pred = math_ops.less(1, 2) 1167 1168 def true_fn(): 1169 1170 def inner_true_fn(): 1171 a = x * x 1172 return a * a 1173 1174 def inner_false_fn(): 1175 return x * x 1176 1177 return control_flow_ops.cond( 1178 constant_op.constant(True), inner_true_fn, inner_false_fn) 1179 1180 def false_fn(): 1181 return x * x 1182 1183 r = control_flow_ops.cond(pred, true_fn, false_fn) 1184 1185 self.assertAllEqual(r, 10000.) 1186 grad = gradients_impl.gradients(r, [x])[0] 1187 self.assertAllEqual(grad, 4000.) 1188 1189 @test_util.run_deprecated_v1 1190 def testCondGrad_2(self): 1191 with self.cached_session(): 1192 c = array_ops.placeholder(dtypes.int32, shape=[]) 1193 x = constant_op.constant(10.0) 1194 pred = math_ops.less(c, 2) 1195 fn1 = lambda: math_ops.multiply(x, 42.0) 1196 fn2 = lambda: math_ops.multiply(x, 3.0) 1197 r = control_flow_ops.cond(pred, fn1, fn2) 1198 1199 grad = gradients_impl.gradients(r, [x])[0] 1200 self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1})) 1201 self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3})) 1202 1203 @test_util.disable_control_flow_v2( 1204 "b/110550782 (gradient w.r.t external variable)") 1205 @test_util.run_deprecated_v1 1206 def testCondGrad_3(self): 1207 with self.cached_session(): 1208 c = array_ops.placeholder(dtypes.int32, shape=[]) 1209 ox = constant_op.constant(10.0) 1210 pred = math_ops.less(c, 2) 1211 1212 def fn1(x): 1213 m = x * x 1214 return gradients_impl.gradients(m, [ox])[0] 1215 1216 fn2 = lambda: math_ops.multiply(ox, 3.0) 1217 y = math_ops.multiply(7.0, ox) 1218 r = control_flow_ops.cond(pred, lambda: fn1(y), fn2) 1219 1220 self.assertAllEqual(980.0, r.eval(feed_dict={c: 1})) 1221 self.assertAllEqual(30.0, r.eval(feed_dict={c: 3})) 1222 1223 @test_util.run_deprecated_v1 1224 def testCondGradMultiDevice(self): 1225 config = config_pb2.ConfigProto(device_count={"CPU": 2}, 1226 allow_soft_placement=True) 1227 with self.cached_session(config=config) as sess: 1228 pred = array_ops.placeholder(dtypes.bool, []) 1229 x = array_ops.placeholder(dtypes.float32) 1230 y = array_ops.placeholder(dtypes.float32) 1231 1232 with ops.device("/cpu:0"): 1233 z = control_flow_ops.cond(pred, lambda: x * y * 2.0, lambda: 2.0) 1234 1235 with ops.device("/cpu:1"): 1236 grad = gradients_impl.gradients(z, x)[0] 1237 1238 with ops.device("/cpu:0"): 1239 grad_grad = gradients_impl.gradients(grad, x)[0] 1240 1241 self.assertEqual(sess.run(grad, {pred: True, x: 1.0, y: 2.0}), 4.0) 1242 self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0) 1243 1244 # v1 control flow gets None second derivative for some reason. 
1245 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: 1246 self.assertIsNone(grad_grad) 1247 return 1248 1249 self.assertEqual(sess.run(grad_grad, {pred: True, x: 1.0, y: 2.0}), 0.0) 1250 self.assertEqual(sess.run(grad_grad, {pred: False, x: 1.0, y: 2.0}), 0.0) 1251 1252 @test_util.run_v1_only("b/120545219") 1253 def testNestedCond_Simple(self): 1254 with self.cached_session(): 1255 x = constant_op.constant(0., name="X") 1256 y = control_flow_ops.cond( 1257 constant_op.constant(True), lambda: x, 1258 lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x)) 1259 result = gradients_impl.gradients(y, x)[0] 1260 self.assertEqual(1.0, self.evaluate(result)) 1261 1262 z = control_flow_ops.cond( 1263 constant_op.constant(False), lambda: x, 1264 lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x)) 1265 result = gradients_impl.gradients(z, x)[0] 1266 self.assertEqual(1.0, self.evaluate(result)) 1267 1268 @test_util.run_v1_only("b/120545219") 1269 def testCondGrad_Gather(self): 1270 with self.cached_session() as sess: 1271 v1 = variables.Variable([1.0, 42.0]) 1272 c = array_ops.placeholder(dtypes.int32, shape=[]) 1273 pred = math_ops.less(c, 2) 1274 fn1 = lambda: array_ops.identity(v1) 1275 fn2 = lambda: array_ops.gather(v1, [1, 1]) 1276 r = control_flow_ops.cond(pred, fn1, fn2) 1277 # The following `grad` is a Tensor since it is the aggregation of an 1278 # IndexedSlice and a Tensor. It is an `IndexedSlices` with control flow 1279 # v2. 1280 grad = gradients_impl.gradients(r, [v1])[0] 1281 self.evaluate(variables.global_variables_initializer()) 1282 1283 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 1284 self.assertIsInstance(grad, ops.IndexedSlices) 1285 1286 grad_value = sess.run(grad, feed_dict={c: 1}) 1287 self.assertAllEqual(gradient_checker_v2._to_numpy(grad_value), [1.0, 1.0]) 1288 1289 grad_value = sess.run(grad, feed_dict={c: 3}) 1290 self.assertAllEqual(gradient_checker_v2._to_numpy(grad_value), [0.0, 2.0]) 1291 1292 @test_util.run_deprecated_v1 1293 def testCondGrad_ResourceVarSparseRead(self): 1294 # NOTE(skyewm): this test is interesting because the 1295 # ResourceVariable.sparse_read gradient function returns IndexedSlices. 1296 var = resource_variable_ops.ResourceVariable( 1297 np.ones((4, 2), dtype=np.float32)) 1298 x = constant_op.constant(1.0) 1299 r = control_flow_ops.cond( 1300 constant_op.constant(True), 1301 lambda: x * math_ops.reduce_sum(var.sparse_read([1, 2])), 1302 lambda: constant_op.constant(np.zeros((2, 3)), 1303 dtype=dtypes.float32)) 1304 grad = gradients_impl.gradients(r, var)[0] 1305 1306 self.evaluate(variables.global_variables_initializer()) 1307 grad_val = self.evaluate(grad) 1308 self.assertIsInstance(grad_val, ops.IndexedSlicesValue) 1309 self.assertAllEqual(gradient_checker_v2._to_numpy(grad_val), [[0., 0.], 1310 [1., 1.], 1311 [1., 1.], 1312 [0., 0.]]) 1313 1314 def testCondGrad_MultiGather(self): 1315 # NOTE(skyewm): this test is interesting because the array_ops.gather and 1316 # ResourceVariable.sparse_read gradient functions returns IndexedSlices. 1317 var = resource_variable_ops.ResourceVariable( 1318 np.ones((4, 2), dtype=np.float32)) 1319 x1 = constant_op.constant(np.ones((3, 3), dtype=np.float32)) 1320 x2 = constant_op.constant(2.0) 1321 1322 def true_fn(): 1323 y1 = var.sparse_read([1, 2]) 1324 y2 = array_ops.gather(x1, [2]) * x2 1325 y3 = x2 * [1., 1., 1.] 
1326 return y1, y2, y3 1327 1328 def false_fn(): 1329 y1 = np.zeros((2, 2), dtype=np.float32) 1330 y2 = array_ops.gather(x1, [2]) * x2 1331 y3 = array_ops.gather(x1, [2]) 1332 return y1, y2, y3 1333 1334 @def_function.function 1335 def foo(): 1336 r = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn) 1337 return gradients_impl.gradients(r, [var, x1, x2]) 1338 1339 grad = foo() 1340 self.evaluate(variables.global_variables_initializer()) 1341 var_grad, x1_grad, x2_grad = self.evaluate(grad) 1342 self.assertIsInstance(var_grad, ops.IndexedSlicesValue) 1343 self.assertAllEqual(gradient_checker_v2._to_numpy(var_grad), [[0., 0.], 1344 [1., 1.], 1345 [1., 1.], 1346 [0., 0]]) 1347 self.assertIsInstance(x1_grad, ops.IndexedSlicesValue) 1348 self.assertAllEqual(gradient_checker_v2._to_numpy(x1_grad), [[0., 0., 0.], 1349 [0., 0., 0.], 1350 [2., 2., 2.]]) 1351 self.assertIsInstance(x1_grad, ops.IndexedSlicesValue) 1352 self.assertEqual(gradient_checker_v2._to_numpy(x2_grad), 6.) 1353 1354 @test_util.run_v1_only("b/120545219") 1355 def testCondPredicateTensor(self): 1356 """Regression test for lowering predicate from non-first output of an op.""" 1357 1358 @eager_function.defun 1359 def foo(): 1360 return constant_op.constant("foo"), constant_op.constant(True) 1361 1362 r = control_flow_ops.cond(foo()[1], lambda: 1.0, lambda: 2.0) 1363 self.assertEqual(self.evaluate(r), 1.0) 1364 1365 @test_util.run_v1_only("Tests Session.run() pruning logic.") 1366 def testCondFeedConstantPredicate(self): 1367 with self.cached_session() as sess: 1368 value = constant_op.constant(37.0) 1369 predicate = constant_op.constant(True) 1370 cond_output = control_flow_ops.cond( 1371 predicate, lambda: constant_op.constant(0.0), lambda: value) 1372 result = array_ops.identity(cond_output) 1373 self.assertEqual(37.0, sess.run(result, feed_dict={predicate: False})) 1374 self.assertEqual(0.0, sess.run(result, feed_dict={predicate: True})) 1375 self.assertEqual(0.0, sess.run(result)) 1376 1377 @test_util.run_v1_only("Tests Session.run() pruning logic.") 1378 def testCondFeedPlaceholderWithDefaultPredicate(self): 1379 with self.cached_session() as sess: 1380 value = constant_op.constant(37.0) 1381 predicate = array_ops.placeholder_with_default( 1382 constant_op.constant(True), []) 1383 cond_output = control_flow_ops.cond( 1384 predicate, lambda: constant_op.constant(0.0), lambda: value) 1385 result = array_ops.identity(cond_output) 1386 self.assertAllEqual(37.0, sess.run(result, feed_dict={predicate: False})) 1387 self.assertAllEqual(0.0, sess.run(result, feed_dict={predicate: True})) 1388 self.assertAllEqual(0.0, sess.run(result)) 1389 1390 @test_util.run_in_graph_and_eager_modes 1391 def testCondAutoControlDeps(self): 1392 if test_util.is_gpu_available(): 1393 self.skipTest("b/128676188 causes OOM on opensource gpu tests") 1394 1395 print_prefix = "testCondAutoControlDeps: " 1396 1397 def branch_fn(): 1398 enqueue_print_op("A") 1399 enqueue_print_op("B") 1400 with ops.control_dependencies([enqueue_print_op("C")]): 1401 return constant_op.constant(10) 1402 1403 def build_cond(): 1404 return control_flow_ops.cond( 1405 constant_op.constant(True), branch_fn, lambda: 0) 1406 1407 def build_nested_cond(): 1408 return control_flow_ops.cond( 1409 constant_op.constant(True), build_cond, lambda: 0) 1410 1411 # In v1 graph mode, pruning should make only "C" print. 
1412 if not context.executing_eagerly(): 1413 with self.cached_session(): 1414 with self.captureWritesToStream(sys.stderr) as printed: 1415 self.assertEqual(self.evaluate(build_cond()), 10) 1416 self.assertEqual(["C"], filter_test_messages(printed.contents())) 1417 1418 with self.captureWritesToStream(sys.stderr) as printed: 1419 self.assertEqual(self.evaluate(build_nested_cond()), 10) 1420 self.assertEqual(["C"], filter_test_messages(printed.contents())) 1421 1422 # In defuns, all prints should execute in program order. 1423 # This doesn't work with legacy control flow. 1424 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 1425 1426 @eager_function.defun 1427 def cond(): 1428 return build_cond() 1429 1430 with self.captureWritesToStream(sys.stderr) as printed: 1431 self.assertEqual(self.evaluate(cond()), 10) 1432 self.assertEqual(["A", "B", "C"], 1433 filter_test_messages(printed.contents())) 1434 1435 @eager_function.defun 1436 def nested_cond(): 1437 return build_nested_cond() 1438 1439 with self.captureWritesToStream(sys.stderr) as printed: 1440 self.assertEqual(self.evaluate(nested_cond()), 10) 1441 self.assertEqual(["A", "B", "C"], 1442 filter_test_messages(printed.contents())) 1443 1444 # wrap_function should prune. 1445 def pruned_cond(): 1446 return build_cond() 1447 pruned_cond = wrap_function.wrap_function(pruned_cond, []) 1448 1449 with self.captureWritesToStream(sys.stderr) as printed: 1450 self.assertEqual(self.evaluate(pruned_cond()), 10) 1451 self.assertEqual(["C"], filter_test_messages(printed.contents())) 1452 1453 def pruned_nested_cond(): 1454 return build_nested_cond() 1455 pruned_nested_cond = wrap_function.wrap_function(pruned_nested_cond, []) 1456 1457 with self.captureWritesToStream(sys.stderr) as printed: 1458 self.assertEqual(self.evaluate(pruned_nested_cond()), 10) 1459 self.assertEqual(["C"], filter_test_messages(printed.contents())) 1460 1461 1462 @test_util.run_in_graph_and_eager_modes 1463 @test_util.disable_tfrt("b/179459136") 1464 def testWhileAutoControlDeps(self): 1465 # Legacy while_loop fails this test because it produces deprecation notices 1466 # in stderr. 1467 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return 1468 1469 def cond(i, unused_x): 1470 enqueue_print_op("A") 1471 return i < 2 1472 1473 def body(i, x): 1474 enqueue_print_op("B") 1475 with ops.control_dependencies([enqueue_print_op("C")]): 1476 x = array_ops.identity(x) 1477 with ops.control_dependencies([enqueue_print_op("D")]): 1478 return i + 1, x 1479 1480 def build_while(): 1481 return control_flow_ops.while_loop( 1482 cond, body, [constant_op.constant(0), constant_op.constant(0)]) 1483 1484 def build_nested_while(): 1485 return control_flow_ops.cond( 1486 constant_op.constant(True), build_while, lambda: [0, 0]) 1487 1488 # In v1 graph mode, pruning should make only "D" print. 1489 if not context.executing_eagerly(): 1490 with self.cached_session(): 1491 with self.captureWritesToStream(sys.stderr) as printed: 1492 self.assertEqual(self.evaluate(build_while()[0]), 2) 1493 self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) 1494 1495 with self.captureWritesToStream(sys.stderr) as printed: 1496 self.assertEqual(self.evaluate(build_nested_while()[0]), 2) 1497 self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) 1498 1499 # In defuns, all prints should execute in program order. 
1500 @eager_function.defun 1501 def while_loop(): 1502 return build_while()[0] 1503 1504 with self.captureWritesToStream(sys.stderr) as printed: 1505 self.assertEqual(self.evaluate(while_loop()), 2) 1506 self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"], 1507 filter_test_messages(printed.contents())) 1508 1509 @eager_function.defun 1510 def nested_while_loop(): 1511 return build_nested_while()[0] 1512 1513 with self.captureWritesToStream(sys.stderr) as printed: 1514 self.assertEqual(self.evaluate(nested_while_loop()), 2) 1515 self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"], 1516 filter_test_messages(printed.contents())) 1517 1518 # wrap_function should prune. 1519 def pruned_while(): 1520 return build_while()[0] 1521 pruned_while = wrap_function.wrap_function(pruned_while, []) 1522 1523 with self.captureWritesToStream(sys.stderr) as printed: 1524 self.assertEqual(self.evaluate(pruned_while()), 2) 1525 self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) 1526 1527 def pruned_nested_while(): 1528 return build_nested_while()[0] 1529 pruned_nested_while = wrap_function.wrap_function(pruned_nested_while, []) 1530 1531 with self.captureWritesToStream(sys.stderr) as printed: 1532 self.assertEqual(self.evaluate(pruned_nested_while()), 2) 1533 self.assertEqual(["D", "D"], filter_test_messages(printed.contents())) 1534 1535 # Microbenchmark: 256,000 iterations/s. 1536 def testWhile_1(self): 1537 with self.cached_session(): 1538 n = constant_op.constant(0) 1539 c = lambda x: math_ops.less(x, 10000) 1540 b = lambda x: math_ops.add(x, 1) 1541 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 1542 self.assertEqual(10000, self.evaluate(r)) 1543 1544 @test_util.run_v1_only("b/120545219") 1545 def testWhileExternalControlDependencies(self): 1546 with self.cached_session(): 1547 v = variables.Variable(0.0) 1548 self.evaluate(v.initializer) 1549 increment = v.assign_add(1.0).read_value() 1550 1551 def body_fn(i): 1552 with ops.control_dependencies([increment]): 1553 return i + 1 1554 1555 result = control_flow_ops.while_loop(cond=lambda i: i < 2, 1556 body=body_fn, loop_vars=[1]) 1557 self.assertAllEqual(result, 2) 1558 self.assertAllEqual(v.read_value(), 1.0) 1559 1560 @test_util.run_v1_only("b/120545219") 1561 def testWhileExternalControlDependenciesNoInput(self): 1562 with self.cached_session(): 1563 v = variables.Variable(0.0) 1564 self.evaluate(v.initializer) 1565 # TODO(apassos): figure out why the reading is necessary here. 
1566 increment = v.assign_add(1.0).read_value() 1567 1568 def body_fn(unused_i): 1569 with ops.control_dependencies([increment]): 1570 return constant_op.constant(5, name="five") 1571 1572 result = control_flow_ops.while_loop(cond=lambda i: i < 5, 1573 body=body_fn, loop_vars=[0]) 1574 self.evaluate(result) 1575 self.assertAllEqual(self.evaluate(v), 1.0) 1576 1577 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 1578 @test_util.run_v1_only("b/120545219") 1579 def testWhileWithRefs_1(self): 1580 with self.cached_session() as sess: 1581 x = variables.VariableV1(0)._ref() # pylint: disable=protected-access 1582 i = constant_op.constant(0) 1583 c = lambda i, x: math_ops.less(i, 100) 1584 1585 self.assertEqual(x.dtype, dtypes.int32_ref) 1586 1587 def b(i, x): 1588 self.assertEqual(x.dtype, dtypes.int32_ref) 1589 return (i + 1, gen_array_ops.ref_identity(x)) 1590 1591 r = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=5) 1592 1593 self.evaluate(variables.global_variables_initializer()) 1594 1595 self.assertEqual(r[0].dtype, dtypes.int32) 1596 self.assertEqual(r[1].dtype, dtypes.int32_ref) 1597 1598 value_i, value_x = self.evaluate(r) 1599 1600 self.assertEqual(100, value_i) 1601 self.assertEqual(0, value_x) 1602 1603 def testWhile_2(self): 1604 with self.cached_session(): 1605 s = constant_op.constant(0) 1606 r = isum(s) 1607 self.assertAllEqual(45, self.evaluate(r)) 1608 1609 def testWhileWithMaximumIterations(self): 1610 with self.cached_session(): 1611 s = constant_op.constant([1, 2, 3, 4, 5]) 1612 r = isum(s, maximum_iterations=3) 1613 self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], self.evaluate(r)) 1614 1615 @test_util.run_v1_only("b/120545219") 1616 def testWhileWithMaximumIterationsAndSingleArgument(self): 1617 with self.cached_session(): 1618 r = control_flow_ops.while_loop( 1619 lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1) 1620 self.assertEqual(1, self.evaluate(r)) 1621 1622 @test_util.run_v1_only("b/120545219") 1623 def testXLAGradInLoop(self): 1624 # We have an optimization that moves certain reduction ops, this test makes 1625 # sure we don't do that for XLA ops. 1626 1627 # Use dynamic inputs, which triggers the creation of "BroadcastGradientArgs" 1628 # and "Shape" op. 1629 input1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None]) 1630 input2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None]) 1631 def cond(i1, i2): 1632 return False 1633 1634 def body(i1, i2): 1635 return math_ops.add(i1, i2), math_ops.add(i1, i2) 1636 1637 xla_context = control_flow_ops.XLAControlFlowContext() 1638 xla_context.Enter() 1639 1640 out1, _ = control_flow_ops.while_loop( 1641 cond, body, (input1, input2), maximum_iterations=2) 1642 g = gradients_impl.gradients(out1, [input1]) 1643 1644 for op in out1.graph.get_operations(): 1645 # Test that the "Shape" is directly passed to BroadcastGradientArgs 1646 # instead of being pushed to the stack. 
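      # (Under XLA the loop gradient is not expected to push intermediates to a
      # stack, so the shape information should reach the gradient ops directly;
      # that is what the check below looks for.)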
1647 if op.type == "BroadcastGradientArgs": 1648 self.assertEqual(op.inputs[0].op.type, "Shape") 1649 self.assertEqual(op.inputs[1].op.type, "Shape") 1650 xla_context.Exit() 1651 1652 1653 @test_util.disable_control_flow_v2("b/115776323 (max_iters)") 1654 @test_util.run_v1_only("b/120545219") 1655 def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self): 1656 v = constant_op.constant(1.0) 1657 1658 def training_loop_with_gradient(i): 1659 out = control_flow_ops.while_loop( 1660 lambda i_, _: i_ < 3, 1661 lambda i_, j: [i_ + 1, j * v], [0, 1.0], 1662 maximum_iterations=i) 1663 g = gradients_impl.gradients(out, v) 1664 with ops.control_dependencies(g): 1665 return i + 1 1666 1667 xla_context = control_flow_ops.XLAControlFlowContext() 1668 xla_context.Enter() 1669 # Create training loop, ensure we can call gradient() of 1670 # while_loop inside the training loop. 1671 loop = control_flow_ops.while_loop(lambda i: i < 3, 1672 training_loop_with_gradient, [0]) 1673 xla_context.Exit() 1674 1675 loop_execute = array_ops.identity(loop) # Because loop is not fetchable. 1676 1677 # Should execute without issue. 1678 self.assertEqual(3, self.evaluate(loop_execute)) 1679 1680 @test_util.run_v1_only("b/120545219") 1681 def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self): 1682 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 1683 self.skipTest("WhileV2 does lazy evaluation of maximum_iterations") 1684 v = constant_op.constant(1.0) 1685 1686 def inner_body(i, x): 1687 out = control_flow_ops.while_loop( 1688 lambda i, _: i < 3, 1689 lambda i, j: [i + 1, j * v], [0, x], 1690 maximum_iterations=i) 1691 return out 1692 1693 def create_while_loop(maximum_iterations=None): 1694 return control_flow_ops.while_loop( 1695 lambda i, _: i < 3, 1696 inner_body, [0, 1.0], 1697 maximum_iterations=maximum_iterations) 1698 1699 loop_no_xla = create_while_loop(maximum_iterations=5) 1700 # maximum_iterations is fine outside of an XLA scope 1701 gs = gradients_impl.gradients(loop_no_xla, v) 1702 self.evaluate(gs) # This should execute without error. 1703 1704 xla_context = control_flow_ops.XLAControlFlowContext() 1705 xla_context.Enter() 1706 loop_no_maxiter = create_while_loop() 1707 loop_with_maxiter = create_while_loop(maximum_iterations=2) 1708 xla_context.Exit() 1709 1710 with self.assertRaisesRegex( 1711 ValueError, 1712 r"Cannot create a gradient accumulator for tensor '.+' inside " 1713 r"XLA while_loop because maximum_iterations was not passed to " 1714 r"the tf.while_loop call \('.+'\)."): 1715 _ = gradients_impl.gradients(loop_no_maxiter, v) 1716 1717 with self.assertRaisesRegex( 1718 ValueError, 1719 r"Cannot create a gradient accumulator for tensor '.+' inside XLA " 1720 r"while_loop. maximum_iterations tensor '.+' for while_loop context " 1721 r"'.+' must be statically known \(e.g. 
a constant value or known " 1722 r"shape dimension\), or be defined at or outside the while loop " 1723 r"context '.*' \(currently defined in '.*'\)"): 1724 _ = gradients_impl.gradients(loop_with_maxiter, v) 1725 1726 @test_util.run_v1_only("b/120545219") 1727 def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self): 1728 v = constant_op.constant(1.0) 1729 1730 def create_while_loop(): 1731 max_iter_holder = [] 1732 1733 def create_mi(): 1734 max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=())) 1735 return 1.0 1736 1737 _ = control_flow_ops.cond( 1738 constant_op.constant(True), create_mi, create_mi) 1739 1740 return control_flow_ops.while_loop( 1741 lambda i, _: i < 3, 1742 lambda i, x: (i + 1, v * x), (0, 1.0), 1743 maximum_iterations=max_iter_holder[0]) 1744 1745 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 1746 xla_context = control_flow_ops.XLAControlFlowContext() 1747 xla_context.Enter() 1748 with self.assertRaisesRegex(ValueError, r"must be from the same graph.*"): 1749 loop = create_while_loop() 1750 xla_context.Exit() 1751 else: 1752 xla_context = control_flow_ops.XLAControlFlowContext() 1753 xla_context.Enter() 1754 loop = create_while_loop() 1755 xla_context.Exit() 1756 with self.assertRaisesRegex( 1757 ValueError, 1758 r"Cannot create a gradient accumulator for tensor '.+' inside XLA " 1759 r"while_loop. maximum_iterations tensor '.*Placeholder:0' for " 1760 r"while_loop context '.+' must be statically known \(e.g. a constant " 1761 r"value or known shape dimension\), or be defined at or outside the " 1762 r"while loop context '' \(currently defined in 'cond/.+'\)"): 1763 _ = gradients_impl.gradients(loop, v) 1764 1765 @test_util.run_v1_only("b/120545219") 1766 def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self): 1767 if test_util.is_gpu_available(): 1768 self.skipTest("b/128646372, b/128645947 fails in opensource build") 1769 1770 v = constant_op.constant(1.0) 1771 1772 p = array_ops.placeholder(dtype=dtypes.int32) 1773 1774 def mid_body_builder(iterations): 1775 1776 def mid_body(i, x): 1777 r = control_flow_ops.while_loop( 1778 lambda *_: True, 1779 lambda i, x: (i + 1, v * x), (0, x), 1780 maximum_iterations=iterations, 1781 name="inner") 1782 return (i + 1, gradients_impl.gradients(x + r[1], v)[0]) 1783 1784 return mid_body 1785 1786 def outer_body(i, x): 1787 iterations = array_ops.size(p, name="iterations") 1788 return (i + 1, x + control_flow_ops.while_loop( 1789 lambda *_: True, 1790 mid_body_builder(iterations), (0, x), 1791 maximum_iterations=iterations, 1792 name="mid")[1]) 1793 1794 def create_while_loop(): 1795 with ops.device("/cpu:0"): 1796 r = control_flow_ops.while_loop( 1797 lambda *_: True, 1798 outer_body, (0, 1.0), 1799 maximum_iterations=5, 1800 name="outer") 1801 return array_ops.identity(r[1]) 1802 1803 xla_context = control_flow_ops.XLAControlFlowContext() 1804 xla_context.Enter() 1805 final_with_xla_context = create_while_loop() 1806 xla_context.Exit() 1807 1808 final_without_xla_context = create_while_loop() 1809 1810 with self.session(use_gpu=False) as sess: 1811 opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) 1812 run_metadata_without_xla_context = config_pb2.RunMetadata() 1813 run_metadata = config_pb2.RunMetadata() 1814 1815 final_value_without_xla_context = sess.run( 1816 final_without_xla_context, 1817 feed_dict={p: [0, 0, 0]}, 1818 options=opts, 1819 run_metadata=run_metadata_without_xla_context) 1820 1821 final_value_with_xla_context = sess.run( 1822 
final_with_xla_context, 1823 feed_dict={p: [0, 0, 0]}, 1824 options=opts, 1825 run_metadata=run_metadata) 1826 1827 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 1828 # With while_v2 on xla, run_metadata only contains the unlowered While 1829 # op so node_stats does not have statistics for the pushes. So as a 1830 # loose check we check the pushes in the lowered version. 1831 for dev in run_metadata_without_xla_context.step_stats.dev_stats: 1832 if "/device:CPU" in dev.device: 1833 node_stats = dev.node_stats 1834 stack_push_count = len([ 1835 x for x in node_stats 1836 if re.match(r".*TensorListPushBack_?\d*", x.node_name) 1837 ]) 1838 else: 1839 for dev in run_metadata.step_stats.dev_stats: 1840 if "/device:CPU" in dev.device: 1841 node_stats = dev.node_stats 1842 stack_push_op = "StackPushV2" 1843 stack_push_count = len( 1844 [x for x in node_stats if x.node_name.endswith("StackPushV2")]) 1845 # Pushes to the stack = product of maximum_iterations values; 1846 # the last two "3"s comes from size(p), when p == [0, 0, 0]. 1847 self.assertEqual(stack_push_count, 5 * 3 * 3, str(node_stats)) 1848 1849 self.assertAllClose(final_value_with_xla_context, 1850 final_value_without_xla_context) 1851 1852 # Have more than 10 parallel iterations and hence exercise k-bound 1853 # most of the time. 1854 @test_util.run_deprecated_v1 1855 def testWhile_3(self): 1856 with self.cached_session(): 1857 1858 def compute(i, m, c, o): 1859 m, c = [math_ops.add(m, 1), math_ops.add(c, 1)] 1860 o = math_ops.add(o, m) 1861 o = math_ops.add(o, c) 1862 i = math_ops.add(i, 1) 1863 return [i, m, c, o] 1864 1865 i = ops.convert_to_tensor(0) 1866 m = ops.convert_to_tensor(0) 1867 c = ops.convert_to_tensor(0) 1868 o = ops.convert_to_tensor(0) 1869 d = ops.convert_to_tensor(100) 1870 r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, d), 1871 compute, [i, m, c, o]) 1872 result = r[3] 1873 self.assertAllEqual(10100, result) 1874 1875 @test_util.run_deprecated_v1 1876 def testWhile_4(self): 1877 with self.cached_session(): 1878 1879 def compute(i, m, c, o): 1880 m, c = [array_ops.gather(x, i), array_ops.gather(x, i)] 1881 o = math_ops.add(o, m) 1882 o = math_ops.add(o, c) 1883 i = math_ops.add(i, 1) 1884 return [i, m, c, o] 1885 1886 i = ops.convert_to_tensor(0) 1887 m = ops.convert_to_tensor(0) 1888 c = ops.convert_to_tensor(0) 1889 o = ops.convert_to_tensor(0) 1890 x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6]) 1891 s = array_ops.size(x) 1892 r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, s), 1893 compute, [i, m, c, o]) 1894 result = r[3] 1895 self.assertAllEqual(42, result) 1896 1897 @test_util.run_v1_only("b/120545219") 1898 def testWhile_5(self): 1899 with self.cached_session(): 1900 1901 def compute(i, c, o): 1902 c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0), 1903 [1] + array_ops.expand_dims(i, 0)) 1904 o = array_ops.concat([o, c], 0) 1905 i = math_ops.add(i, 1) 1906 return [i, c, o] 1907 1908 i = ops.convert_to_tensor(0) 1909 c = ops.convert_to_tensor([0]) 1910 o = ops.convert_to_tensor([0]) 1911 x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6]) 1912 s = array_ops.size(x) 1913 r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s), 1914 compute, [i, c, o], [ 1915 i.get_shape(), 1916 tensor_shape.unknown_shape(), 1917 tensor_shape.unknown_shape() 1918 ]) 1919 result = r[2] 1920 self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result) 1921 1922 @test_util.run_gpu_only 1923 @test_util.run_deprecated_v1 1924 def testWhile_Device(self): 1925 1926 # Body function 
defined outside of device scope 1927 def body(x): 1928 return math_ops.exp(x) 1929 1930 with ops.device("CPU:0"): 1931 r = control_flow_ops.while_loop( 1932 lambda x: x < 10, body, [constant_op.constant(-10.)]) 1933 self.assertIn("cpu", r.device.lower()) 1934 1935 with session.Session() as sess: 1936 options = config_pb2.RunOptions(output_partition_graphs=True) 1937 run_metadata = config_pb2.RunMetadata() 1938 sess.run(r, options=options, run_metadata=run_metadata) 1939 # We expect that everything runs on CPU, even if GPU is available. 1940 self.assertEqual(len(run_metadata.partition_graphs), 1) 1941 1942 @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)") 1943 @test_util.run_v1_only("b/120545219") 1944 def testBufferForwarding(self): 1945 run_options = config_pb2.RunOptions( 1946 trace_level=config_pb2.RunOptions.FULL_TRACE) 1947 run_metadata = config_pb2.RunMetadata() 1948 1949 with self.cached_session() as sess: 1950 with ops.device("/cpu:0"): 1951 c = constant_op.constant(2) 1952 i0 = constant_op.constant(0) 1953 r = control_flow_ops.while_loop(lambda i: i < 1000, 1954 lambda i: math_ops.square(c) + i, [i0]) 1955 r_val = sess.run(r, options=run_options, run_metadata=run_metadata) 1956 self.assertEqual(1000, r_val) 1957 self.assertTrue(run_metadata.HasField("step_stats")) 1958 unique_allocs = set() 1959 for node_stat in run_metadata.step_stats.dev_stats[0].node_stats: 1960 for output in node_stat.output: 1961 unique_allocs.add( 1962 output.tensor_description.allocation_description.ptr) 1963 # Prior to cl/147536680, the number of unique allocations was about 1005. 1964 self.assertLess(len(unique_allocs), 756) 1965 1966 def _testWhile_Gpu_1(self, use_gpu): 1967 with self.cached_session(use_gpu=use_gpu): 1968 n = constant_op.constant(1.0) 1969 c = lambda x: math_ops.less(x, 10.0) 1970 b = lambda x: math_ops.add(x, 1.0) 1971 r = control_flow_ops.while_loop(c, b, [n]) 1972 self.assertAllClose(10.0, self.evaluate(r)) 1973 1974 def testWhile_Gpu_1(self): 1975 self._testWhile_Gpu_1(use_gpu=False) 1976 self._testWhile_Gpu_1(use_gpu=True) 1977 1978 def _testWhile_Gpu_2(self, use_gpu): 1979 with self.cached_session(use_gpu=use_gpu): 1980 n = constant_op.constant(1.0) 1981 c = lambda x: math_ops.less(x, 10.0) 1982 1983 def b(x): 1984 with ops.device("/cpu:0"): 1985 return math_ops.add(x, 1.0) 1986 1987 r = control_flow_ops.while_loop(c, b, [n]) 1988 self.assertAllClose(10.0, self.evaluate(r)) 1989 1990 def testWhile_Gpu_2(self): 1991 self._testWhile_Gpu_2(use_gpu=False) 1992 self._testWhile_Gpu_2(use_gpu=True) 1993 1994 def testWhileShape(self): 1995 with self.cached_session(): 1996 i = constant_op.constant(0) 1997 m = array_ops.ones([2, 2]) 1998 c = lambda i, j: math_ops.less(i, 2) 1999 2000 def _b(i, j): 2001 new_i = math_ops.add(i, 1) 2002 new_j = array_ops.tile(j, [2, 2]) 2003 return [new_i, new_j] 2004 2005 r = control_flow_ops.while_loop( 2006 c, _b, [i, m], 2007 [i.get_shape(), tensor_shape.unknown_shape()]) 2008 r = r[1] * array_ops.ones([8, 8]) 2009 self.assertAllEqual(np.ones((8, 8)), self.evaluate(r)) 2010 2011 @test_util.disable_control_flow_v2("b/131265085") 2012 @test_util.run_v1_only("b/131265085") 2013 def testWhileBadShape(self): 2014 x = constant_op.constant([2.0, 4.0], name="values") 2015 i = constant_op.constant(0) 2016 c = lambda i, _: math_ops.less(i, 10) 2017 b = lambda i, x: [i + 1, x + 1] 2018 with self.assertRaisesRegex(ValueError, "is not compatible with"): 2019 # Shape of x is [2], but we specify a shape of [5]. 
2020 control_flow_ops.while_loop( 2021 c, b, [i, x], [i.shape, tensor_shape.TensorShape([5])]) 2022 2023 @test_util.run_in_graph_and_eager_modes 2024 def testWhileBadBodyReturn(self): 2025 x = constant_op.constant([2.0, 4.0], name="values") 2026 i = constant_op.constant(0) 2027 c = lambda i, *x: math_ops.less(i, 10) 2028 2029 # body accepts N values and returns N+1 values. 2030 b = lambda i, *x: (i, i) + x 2031 2032 with self.assertRaisesRegex( 2033 ValueError, "The two structures don't have the same nested structure."): 2034 control_flow_ops.while_loop(c, b, [i, x]) 2035 2036 @test_util.run_deprecated_v1 2037 def testWhileWithNonTensorInput_Scalar(self): 2038 with self.cached_session(): 2039 n = 0 2040 c = lambda x: x < 10000 2041 b = lambda x: x + 1 2042 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 2043 self.assertEqual(10000, self.evaluate(r)) 2044 2045 def testWhileWithNonTensorInput_Vector(self): 2046 with self.cached_session(): 2047 n = np.array([0]) # Note, [0] would not work here; that is a list 2048 c = lambda x: x[0] < 10000 2049 b = lambda x: array_ops.stack([x[0] + 1]) 2050 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 2051 self.assertEqual([10000], self.evaluate(r)) 2052 2053 def testWhileShapeInference(self): 2054 with self.cached_session(): 2055 i = constant_op.constant(0) 2056 m = array_ops.ones([2, 2]) 2057 c = lambda i, j: math_ops.less(i, 2) 2058 2059 def b(i, j): 2060 new_i = math_ops.add(i, 1) 2061 new_j = array_ops.concat([j, j], 0) 2062 return [new_i, new_j] 2063 2064 r = control_flow_ops.while_loop( 2065 c, b, [i, m], 2066 [i.get_shape(), tensor_shape.TensorShape([None, 2])]) 2067 self.assertTrue(r[1].shape.is_compatible_with([8, 2])) 2068 2069 @test_util.run_v1_only("b/120545219") 2070 def testWhileShapeInferenceBadShape(self): 2071 with self.cached_session(): 2072 i = constant_op.constant(0) 2073 m = array_ops.ones([2, 2]) 2074 c = lambda i, j: math_ops.less(i, 2) 2075 b = lambda i, j: [i + 1, array_ops.concat([j, j], 0)] 2076 with self.assertRaisesRegex( 2077 ValueError, 2078 r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has " 2079 r"shape \(4, 2\) after one iteration. To allow the shape to vary " 2080 r"across iterations, use the `shape_invariants` argument of " 2081 r"tf.while_loop to specify a less-specific shape."): 2082 control_flow_ops.while_loop(c, b, [i, m]) 2083 2084 def testWhileShapeInferenceSparseTensor(self): 2085 values = constant_op.constant([2.0, 4.0], name="values") 2086 indices = constant_op.constant([[0], [3]], 2087 dtype=dtypes.int64, 2088 name="indices") 2089 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 2090 i = constant_op.constant(0) 2091 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 2092 2093 def c(i, _): 2094 return i < 10 2095 2096 def b1(i, x): # modifies values. (shape of components is not changed.) 2097 return [ 2098 i + 1, 2099 sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape) 2100 ] 2101 2102 def b2(i, x): # adds new values. (shape of components is changed.) 2103 return [ 2104 i + 1, 2105 sparse_ops.sparse_add( 2106 x, 2107 sparse_tensor.SparseTensor( 2108 indices=math_ops.cast( 2109 array_ops.fill([1, 1], i), dtypes.int64), 2110 values=array_ops.fill([1], 1.0), 2111 dense_shape=x.dense_shape)) 2112 ] 2113 2114 def b3(i, x): # modifies rank. (shape of all components is changed.) 
2115 return [ 2116 i + 1, 2117 sparse_tensor.SparseTensor( 2118 array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0, 2119 array_ops.concat([x.dense_shape, [10]], axis=0)) 2120 ] 2121 2122 def check_shapes(r, indices, values, dense_shape): 2123 self.assertTrue(r.indices.shape.is_compatible_with(indices)) 2124 self.assertTrue(r.values.shape.is_compatible_with(values)) 2125 self.assertTrue(r.dense_shape.shape.is_compatible_with(dense_shape)) 2126 2127 # Default shape invariant; b1 only modifies values. 2128 _, r = control_flow_ops.while_loop(c, b1, [i, x]) 2129 check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1]) 2130 2131 # Default shape invariant; b2 adds new values 2132 _, r = control_flow_ops.while_loop(c, b2, [i, x]) 2133 check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1]) 2134 2135 # Explicit shape invariant, allowing any rank; b1 only modifies values. 2136 _, r = control_flow_ops.while_loop( 2137 c, b1, [i, x], 2138 [i.get_shape(), tensor_shape.TensorShape([None])]) 2139 check_shapes(r, indices=[None, None], values=[None], dense_shape=[None]) 2140 2141 # Explicit shape invariant, allowing any rank; b3 modifies rank. 2142 _, r = control_flow_ops.while_loop( 2143 c, b3, [i, x], 2144 [i.get_shape(), tensor_shape.TensorShape([None])]) 2145 check_shapes(r, indices=[None, None], values=[None], dense_shape=[None]) 2146 2147 # Shape invariant with ndims=None. Technically, this isn't supported 2148 # according to the docs, but we support it for backwards compatibility. 2149 _, r = control_flow_ops.while_loop( 2150 c, b1, [i, x], 2151 [i.get_shape(), tensor_shape.TensorShape(None)]) 2152 check_shapes(r, indices=[None, None], values=[None], dense_shape=[None]) 2153 _, r = control_flow_ops.while_loop( 2154 c, b3, [i, x], 2155 [i.get_shape(), tensor_shape.TensorShape(None)]) 2156 check_shapes(r, indices=[None, None], values=[None], dense_shape=[None]) 2157 2158 @test_util.disable_control_flow_v2("b/131265085") 2159 @test_util.run_v1_only("b/131265085") 2160 def testWhileBadShapeSparseTensor(self): 2161 values = constant_op.constant([2.0, 4.0], name="values") 2162 indices = constant_op.constant([[0], [3]], 2163 dtype=dtypes.int64, 2164 name="indices") 2165 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 2166 i = constant_op.constant(0) 2167 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 2168 c = lambda i, _: i < 10 2169 b1 = lambda i, x: [i+1, x] 2170 def b2(i, x): # modifies rank. (shape of all components is changed.) 2171 return [ 2172 i + 1, 2173 sparse_tensor.SparseTensor( 2174 array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0, 2175 array_ops.concat([x.dense_shape, [10]], axis=0)) 2176 ] 2177 2178 # Explicit shape invariant, with a specific (incompatible) rank. 2179 with self.assertRaisesRegex(ValueError, "is not compatible with"): 2180 control_flow_ops.while_loop( 2181 c, b1, [i, x], 2182 [i.get_shape(), tensor_shape.TensorShape([5])]) 2183 2184 # Default shape invariant, but b2 modifies rank (which is not allowed). 
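    # (With the default invariant each SparseTensor component keeps its shape
    # across iterations, which pins the dense rank; b2 grows the rank, so the
    # loop below is expected to be rejected.)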
2185 with self.assertRaises(ValueError): 2186 control_flow_ops.while_loop(c, b2, [i, x]) 2187 2188 def testWhileShapeInferenceIndexedSlices(self): 2189 with self.cached_session(): 2190 values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values") 2191 indices = constant_op.constant([0, 3], name="indices") 2192 shape = constant_op.constant([10, 2], name="dense_shape") 2193 i = constant_op.constant(0) 2194 x = ops.IndexedSlices(values, indices, dense_shape=shape) 2195 2196 def c(i, _): 2197 return i < 10 2198 2199 def b(i, x): 2200 return [ 2201 i + 1, 2202 ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape) 2203 ] 2204 2205 _, r = control_flow_ops.while_loop(c, b, [i, x]) 2206 self.assertEqual(r.dense_shape.get_shape()[0], 2) 2207 self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2])) 2208 2209 _, r = control_flow_ops.while_loop( 2210 c, b, [i, x], 2211 [i.get_shape(), tensor_shape.TensorShape([None, 2])]) 2212 self.assertEqual(r.dense_shape.get_shape()[0], 2) 2213 self.assertTrue(r.values.get_shape().is_compatible_with([None, 2])) 2214 2215 @test_util.disable_control_flow_v2("b/131265085") 2216 @test_util.run_v1_only("b/131265085") 2217 def testWhileBadShapeIndexedSlices(self): 2218 values = constant_op.constant([2.0, 4.0], name="values") 2219 indices = constant_op.constant([[0], [3]], 2220 dtype=dtypes.int64, 2221 name="indices") 2222 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 2223 i = constant_op.constant(0) 2224 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 2225 c = lambda i, _: 10 2226 b = lambda i, x: [i+1, x] 2227 2228 # Explicit shape invariant, with a specific (incompatible) rank. 2229 with self.assertRaisesRegex(ValueError, "is not compatible with"): 2230 control_flow_ops.while_loop( 2231 c, b, [i, x], 2232 [i.get_shape(), tensor_shape.TensorShape([5])]) 2233 2234 def testWhileShapeInferenceRaggedTensor(self): 2235 i = constant_op.constant(0) 2236 x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]]) 2237 c = lambda i, _: i < 10 2238 2239 def b1(i, x): # Adds new values to rows (but doesn't create new rows) 2240 return [ 2241 i + 1, 2242 array_ops.concat([x, x], axis=1) 2243 ] 2244 2245 def b2(i, x): # Adds new rows. 2246 return [ 2247 i + 1, 2248 array_ops.concat([x, x], axis=0) 2249 ] 2250 2251 def check_shapes(r, values, splits): 2252 self.assertTrue(r.values.shape.is_compatible_with(values)) 2253 self.assertTrue(r.row_splits.shape.is_compatible_with(splits)) 2254 2255 # Default shape invariant; b1 adds new values to rows. 2256 _, r = control_flow_ops.while_loop(c, b1, [i, x]) 2257 check_shapes(r, values=[None], splits=[4]) 2258 2259 # Default shape invariant; b2 adds new rows (not allowed). 2260 if not context.executing_eagerly(): 2261 with self.assertRaises(ValueError): 2262 _, r = control_flow_ops.while_loop(c, b2, [i, x]) 2263 2264 # Explicit shape invariant; b1 adds new values to rows. 2265 # (deprecated: use TensorShape instead of RaggedTensorSpec) 2266 _, r = control_flow_ops.while_loop( 2267 c, b1, [i, x], 2268 [i.get_shape(), tensor_shape.TensorShape([None, None])]) 2269 check_shapes(r, values=[None], splits=[None]) 2270 2271 # Explicit shape invariant; b1 adds new values to rows. 2272 _, r = control_flow_ops.while_loop( 2273 c, b1, [i, x], 2274 [i.get_shape(), ragged_tensor.RaggedTensorSpec([None, None], 2275 dtypes.int32)]) 2276 check_shapes(r, values=[None], splits=[None]) 2277 2278 # Explicit shape invariant; b2 adds new rows. 
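    # (RaggedTensorSpec([None, None]) leaves the number of rows unspecified, so
    # unlike the default invariant above, growing the tensor by whole rows is
    # allowed here.)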
2279 _, r = control_flow_ops.while_loop( 2280 c, b2, [i, x], 2281 [i.get_shape(), ragged_tensor.RaggedTensorSpec([None, None], 2282 dtypes.int32)]) 2283 check_shapes(r, values=[None], splits=[None]) 2284 2285 def testWhileShapeInferenceRaggedTensorRaggedRank2(self): 2286 i = constant_op.constant(0) 2287 x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]], 2288 [[], [8, 9, 10]]]) 2289 c = lambda i, _: i < 10 2290 def b(i, x): 2291 return [ 2292 i + 1, 2293 array_ops.concat([x, x[..., i:i+1]], axis=-1) 2294 ] 2295 _, r = control_flow_ops.while_loop(c, b, [i, x]) 2296 self.assertEqual(r.row_splits.shape.as_list(), [3]) 2297 self.assertTrue(r.values.row_splits.shape.as_list() in ([6], [None])) 2298 self.assertTrue(r.values.values.shape.as_list() in ([49], [None])) 2299 2300 def testWhileShapeInvariantTensorSpec(self): 2301 i = constant_op.constant(0) 2302 x = constant_op.constant([1]) 2303 c = lambda i, _: i < 10 2304 b = lambda i, x: (i + 1, array_ops.stack([x, x])) 2305 shape_invariants = [ 2306 tensor_spec.TensorSpec([], dtype=dtypes.int32), 2307 tensor_spec.TensorSpec(None, dtype=dtypes.int32)] 2308 control_flow_ops.while_loop(c, b, [i, x], shape_invariants) 2309 2310 # TODO(b/131265085) Remove this decorator when bug is fixed. 2311 @test_util.build_as_function_and_v1_graph 2312 def testWhileShapeInvariantWrongTypeSpecType(self): 2313 c = lambda i, _: i < 10 2314 b = lambda i, x: (i + 1, x) 2315 i = constant_op.constant(0) 2316 x = sparse_tensor.SparseTensor([[0]], [1.0], [10]) 2317 shape_invariants = [ 2318 tensor_spec.TensorSpec([], dtype=dtypes.int32), 2319 sparse_tensor.SparseTensorSpec([None])] 2320 control_flow_ops.while_loop(c, b, [i, x], shape_invariants) 2321 2322 x2 = constant_op.constant([1]) 2323 with self.assertRaises(TypeError): 2324 control_flow_ops.while_loop(c, b, [i, x2], shape_invariants) 2325 2326 x3 = ragged_factory_ops.constant([[1, 2], [3]]) 2327 with self.assertRaises(TypeError): 2328 control_flow_ops.while_loop(c, b, [i, x3], shape_invariants) 2329 2330 i2 = constant_op.constant(0.0) 2331 with self.assertRaises(TypeError): 2332 control_flow_ops.while_loop(c, b, [i2, x], shape_invariants) 2333 2334 # TODO(b/131265085) Remove this decorator when bug is fixed. 2335 @test_util.build_as_function_and_v1_graph 2336 def testWhileShapeInvariantBadType(self): 2337 i = constant_op.constant(0) 2338 x = constant_op.constant([1]) 2339 c = lambda i, _: i < 10 2340 b = lambda i, x: (i + 1, x) 2341 with self.assertRaises((ValueError, TypeError)): 2342 control_flow_ops.while_loop(c, b, [i, x], ["foo", "bar"]) 2343 2344 def _testNestedWhile_1(self, use_gpu): 2345 with self.cached_session(use_gpu=use_gpu): 2346 n = constant_op.constant(0) 2347 2348 def cpu_sum(s): 2349 c = lambda i, s: math_ops.less(i, 10) 2350 2351 def b(i, s): 2352 i1 = math_ops.add(i, 1) 2353 with ops.device("/cpu:0"): 2354 s1 = math_ops.add(i, s) 2355 return i1, s1 2356 2357 _, r_s = control_flow_ops.while_loop(c, b, [n, s]) 2358 return r_s 2359 2360 c = lambda x: math_ops.less(x, 200) 2361 b = lambda x: math_ops.add(x, cpu_sum(n)) 2362 r = control_flow_ops.while_loop(c, b, [n]) 2363 self.assertEqual(225, self.evaluate(r)) 2364 2365 def testNestedWhile_1(self): 2366 self._testNestedWhile_1(use_gpu=False) 2367 self._testNestedWhile_1(use_gpu=True) 2368 2369 def _testNestedWhile_2(self, use_gpu): 2370 # Test the cases that A -> Enter and Exit -> A are partitioned. 
2371 with self.cached_session(use_gpu=use_gpu): 2372 s0 = constant_op.constant(2.0) 2373 2374 def inner_loop(s): 2375 c = lambda s: math_ops.less(s, 20.0) 2376 2377 def b(s): 2378 s1 = math_ops.add(s, s) 2379 return s1 2380 2381 r_s = control_flow_ops.while_loop(c, b, [s], parallel_iterations=1) 2382 return r_s 2383 2384 outer_c = lambda x: math_ops.less(x, 3000.0) 2385 2386 def outer_b(x): 2387 x = logging_ops.Print(x, [x]) # Edge "Print -> Enter" is partitioned 2388 x = inner_loop(x) 2389 with ops.device("/cpu:0"): 2390 x = math_ops.square(x) # Edge "Exit -> Square" is partitioned 2391 return x 2392 2393 r = control_flow_ops.while_loop( 2394 outer_c, outer_b, [s0], parallel_iterations=1) 2395 self.assertEqual(1048576.0, self.evaluate(r)) 2396 2397 def testNestedWhile_2(self): 2398 self._testNestedWhile_2(use_gpu=False) 2399 self._testNestedWhile_2(use_gpu=True) 2400 2401 @test_util.run_v1_only("b/120545219") 2402 def testWhileWithControl_1(self): 2403 with self.cached_session(): 2404 n = constant_op.constant(0) 2405 r = constant_op.constant(0) 2406 condition = lambda n_, r_: math_ops.less(n_, 10) 2407 2408 def body(n_, r_): 2409 n_ = math_ops.add(n_, 1) 2410 with r_.graph.control_dependencies([r_]): 2411 r_ = constant_op.constant(12) 2412 return [n_, r_] 2413 2414 res = control_flow_ops.while_loop( 2415 condition, body, [n, r], parallel_iterations=1) 2416 self.assertAllEqual(12, res[1]) 2417 2418 @test_util.run_deprecated_v1 2419 def testWhileWithControl_2(self): 2420 with self.cached_session(): 2421 r = constant_op.constant(0) 2422 condition = lambda r_: math_ops.less(r_, 10) 2423 2424 def body(r_): 2425 with r_.graph.control_dependencies([r_]): 2426 r_ = constant_op.constant(12) 2427 return [r_] 2428 2429 res = control_flow_ops.while_loop( 2430 condition, body, [r], parallel_iterations=1) 2431 self.assertAllEqual(12, self.evaluate(res)) 2432 2433 @test_util.run_v1_only("b/120545219") 2434 def testWhileWithControl_3(self): 2435 with self.cached_session() as sess: 2436 b = array_ops.placeholder(dtypes.bool) 2437 c = constant_op.constant(1) 2438 x0 = constant_op.constant(0) 2439 with ops.control_dependencies([b]): 2440 r = control_flow_ops.while_loop(lambda x: x < 10, lambda x: x + c, [x0]) 2441 self.assertEqual(10, sess.run(r, {b: True})) 2442 2443 @test_util.run_v1_only("b/120545219") 2444 def testWhileWithControl_4(self): 2445 with self.cached_session() as sess: 2446 b = array_ops.placeholder(dtypes.bool) 2447 c = constant_op.constant(1) 2448 x0 = constant_op.constant(0) 2449 with ops.control_dependencies([b]): 2450 r = control_flow_ops.while_loop( 2451 lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0]) 2452 self.assertEqual(10, sess.run(r, {b: True})) 2453 2454 @test_util.run_v1_only("b/120545219") 2455 def testWhileWithControl_5(self): 2456 with self.cached_session() as sess: 2457 b = array_ops.placeholder(dtypes.bool) 2458 c = constant_op.constant(1) 2459 x0 = constant_op.constant(0) 2460 2461 def body(x): 2462 with ops.control_dependencies([b]): 2463 return x + c 2464 2465 r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0]) 2466 self.assertEqual(10, sess.run(r, {b: True})) 2467 2468 def testWhileCondWithControl(self): 2469 # Ensure that no control edges by an outer control dependency context are 2470 # added to nodes inside cond/while contexts. 
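    # (A control input that reached into the body of the cond/while would cross
    # the control-flow frame boundary; the dependency is expected to attach to
    # the loop as a whole instead.)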
2471 with self.cached_session() as sess: 2472 const_true = lambda: constant_op.constant(True) 2473 const_false = lambda: constant_op.constant(False) 2474 cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false) 2475 body = lambda i: control_flow_ops.cond(i > 0, lambda: i - 1, lambda: i) 2476 2477 with ops.control_dependencies([control_flow_ops.no_op()]): 2478 loop = control_flow_ops.while_loop(cond, body, 2479 (constant_op.constant(5),)) 2480 self.assertEqual(0, self.evaluate(loop)) 2481 2482 @test_util.disable_control_flow_v2("b/113324949 (ref vars)") 2483 @test_util.run_v1_only("b/120545219") 2484 def testWhileCondWithControl_1(self): 2485 with self.cached_session(): 2486 v = variable_scope.get_variable( 2487 "v", [], initializer=init_ops.constant_initializer(2)) 2488 i0 = constant_op.constant(0) 2489 with ops.control_dependencies([i0]): 2490 2491 def loop_condition(i): 2492 return i < 4 2493 2494 def loop_body(i): 2495 some_cond = control_flow_ops.cond( 2496 constant_op.constant(True), 2497 lambda: state_ops.assign(v, math_ops.square(v)), lambda: v) 2498 with ops.control_dependencies([some_cond]): 2499 return i + 1 2500 2501 r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,)) 2502 self.evaluate(variables.global_variables_initializer()) 2503 self.assertEqual(4, self.evaluate(r)) 2504 self.assertAllClose(65536.0, self.evaluate(v)) 2505 2506 @test_util.disable_control_flow_v2("b/113324949 (ref vars)") 2507 @test_util.run_v1_only("b/120545219") 2508 def testWhileCondExitControl(self): 2509 2510 with self.cached_session(): 2511 v = variables.Variable(1) 2512 2513 def false_branch(): 2514 cond = lambda i: i < 100 2515 2516 def body(i): 2517 x = state_ops.assign(v, i) 2518 return x + 1 2519 2520 loop = control_flow_ops.while_loop(cond, body, [0]) 2521 # Make sure to handle correctly control edge from Exit to a node. 
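        # The control dependency below makes the branch result wait on the
        # loop's Exit, so every assignment to v runs before the cond returns;
        # this is why v is expected to end up at 99.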
2522 with ops.control_dependencies([loop]): 2523 return constant_op.constant(6.0) 2524 2525 r = control_flow_ops.cond( 2526 constant_op.constant(False), lambda: constant_op.constant(1.0), 2527 false_branch) 2528 self.evaluate(variables.global_variables_initializer()) 2529 self.assertEqual(6.0, self.evaluate(r)) 2530 self.assertEqual(99, self.evaluate(v)) 2531 2532 def testCondWhile_1(self): 2533 2534 with self.cached_session(): 2535 n = ops.convert_to_tensor(0, name="n") 2536 c = lambda x: math_ops.less(x, 10) 2537 b = lambda x: math_ops.add(x, 1) 2538 r = control_flow_ops.cond( 2539 math_ops.less(0, 1), lambda: control_flow_ops.while_loop(c, b, [n]), 2540 lambda: n) 2541 self.assertAllEqual(10, self.evaluate(r)) 2542 2543 def testCondWhile_2(self): 2544 2545 with self.cached_session(): 2546 n = ops.convert_to_tensor(0) 2547 c = lambda x: math_ops.less(x, 10) 2548 b = lambda x: math_ops.add(x, 1) 2549 r = control_flow_ops.cond( 2550 math_ops.less(1, 0), lambda: math_ops.add(n, 1), 2551 lambda: control_flow_ops.while_loop(c, b, [n])) 2552 self.assertAllEqual(10, self.evaluate(r)) 2553 2554 def _testCondWhile_3(self, use_gpu): 2555 with self.cached_session(use_gpu=use_gpu) as sess: 2556 p = array_ops.placeholder(dtypes.bool) 2557 n = constant_op.constant(0.0) 2558 2559 def c(x): 2560 return math_ops.less(x, 10.0) 2561 2562 def b(x): 2563 with ops.device("/cpu:0"): 2564 x1 = math_ops.add(x, 1.0) 2565 return x1 2566 2567 r = control_flow_ops.cond(p, 2568 lambda: control_flow_ops.while_loop(c, b, [n]), 2569 lambda: math_ops.multiply(n, 2.0)) 2570 r1 = gradients_impl.gradients(r, [n]) 2571 self.assertEqual(10., sess.run(r, {p: True})) 2572 self.assertEqual([1.0], sess.run(r1, {p: True})) 2573 self.assertEqual(0.0, sess.run(r, {p: False})) 2574 self.assertEqual([2.0], sess.run(r1, {p: False})) 2575 2576 @test_util.run_deprecated_v1 2577 def testCondWhile_3(self): 2578 self._testCondWhile_3(use_gpu=False) 2579 self._testCondWhile_3(use_gpu=True) 2580 2581 def testWhileCond_1(self): 2582 2583 with self.cached_session(): 2584 i = ops.convert_to_tensor(0, name="i") 2585 n = ops.convert_to_tensor(10, name="n") 2586 one = ops.convert_to_tensor(1, name="one") 2587 c = lambda x: math_ops.less(x, n) 2588 # pylint: disable=undefined-variable 2589 # for OSS build 2590 b = lambda x: control_flow_ops.cond( 2591 constant_op.constant(True), 2592 lambda: math_ops.add(x, one), lambda: math_ops.subtract(x, one)) 2593 # pylint: enable=undefined-variable 2594 r = control_flow_ops.while_loop(c, b, [i]) 2595 self.assertAllEqual(10, self.evaluate(r)) 2596 2597 def testWhileCond_2(self): 2598 2599 with self.cached_session(): 2600 n = ops.convert_to_tensor(0, name="n") 2601 c = lambda x: math_ops.less(x, 10) 2602 b = lambda x: control_flow_ops.cond(constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n) 2603 r = control_flow_ops.while_loop(c, b, [n]) 2604 self.assertAllEqual(10, self.evaluate(r)) 2605 2606 def testWhileCond_3(self): 2607 2608 with self.cached_session(): 2609 n = ops.convert_to_tensor(0) 2610 c = lambda x: math_ops.less(x, 10) 2611 # pylint: disable=undefined-variable 2612 # for OSS build 2613 b = lambda x: control_flow_ops.cond(math_ops.less(0, 1), 2614 lambda: math_ops.add(x, 1), 2615 lambda: math_ops.subtract(x, 1)) 2616 # pylint: enable=undefined-variable 2617 r = control_flow_ops.while_loop(c, b, [n]) 2618 self.assertAllEqual(10, self.evaluate(r)) 2619 2620 @test_util.run_deprecated_v1 2621 def testWhileCondGradMultiDevice(self): 2622 config = config_pb2.ConfigProto(device_count={"CPU": 
2}, 2623 allow_soft_placement=True) 2624 with self.cached_session(config=config) as sess: 2625 pred = array_ops.placeholder(dtypes.bool, []) 2626 x_init = constant_op.constant(1.0) 2627 2628 with ops.device("/cpu:0"): 2629 z = control_flow_ops.while_loop( 2630 lambda i, _: i < 3, 2631 lambda i, x: (i + 1, control_flow_ops.cond( 2632 pred, lambda: x * 2.0, lambda: 10.0)), 2633 [0, x_init]) 2634 2635 with ops.device("/cpu:1"): 2636 grad = gradients_impl.gradients(z, x_init)[0] 2637 2638 with ops.device("/cpu:0"): 2639 grad_grad = gradients_impl.gradients(grad, x_init)[0] 2640 2641 self.assertEqual(sess.run(grad, {pred: True}), 8.0) 2642 self.assertEqual(sess.run(grad, {pred: False}), 0.0) 2643 2644 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: 2645 return 2646 2647 self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0) 2648 self.assertEqual(sess.run(grad_grad, {pred: False}), 0.0) 2649 2650 # NOTE: It is ok to have parallel_iterations > 1 2651 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2652 @test_util.run_deprecated_v1 2653 def testWhileUpdateVariable_1(self): 2654 with self.cached_session(): 2655 select = variables.Variable([3.0, 4.0, 5.0]) 2656 n = constant_op.constant(0) 2657 2658 def loop_iterator(j): 2659 return math_ops.less(j, 3) 2660 2661 def loop_body(j): 2662 ns = state_ops.scatter_update(select, j, 10.0) 2663 nj = math_ops.add(j, 1) 2664 op = control_flow_ops.group(ns) 2665 nj = control_flow_ops.with_dependencies([op], nj) 2666 return [nj] 2667 2668 r = control_flow_ops.while_loop( 2669 loop_iterator, loop_body, [n], parallel_iterations=1) 2670 self.evaluate(variables.global_variables_initializer()) 2671 self.assertEqual(3, self.evaluate(r)) 2672 result = self.evaluate(select) 2673 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result) 2674 2675 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2676 @test_util.run_v1_only("b/120545219") 2677 def testWhileUpdateVariable_2(self): 2678 with self.cached_session(): 2679 select1 = variables.Variable([3.0, 4.0, 5.0]) 2680 select2 = variables.Variable([3.0, 4.0, 5.0]) 2681 n = constant_op.constant(0) 2682 2683 def loop_iterator(j): 2684 return math_ops.less(j, 3) 2685 2686 def loop_body(j): 2687 ns1 = state_ops.scatter_update(select1, j, 10.0) 2688 ns2 = state_ops.scatter_update(select2, j, 10.0) 2689 nj = math_ops.add(j, 1) 2690 op = control_flow_ops.group(ns1, ns2) 2691 nj = control_flow_ops.with_dependencies([op], nj) 2692 return [nj] 2693 2694 r = control_flow_ops.while_loop( 2695 loop_iterator, loop_body, [n], parallel_iterations=1) 2696 self.evaluate(variables.global_variables_initializer()) 2697 self.assertEqual(3, self.evaluate(r)) 2698 result1 = self.evaluate(select1) 2699 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1) 2700 result2 = self.evaluate(select2) 2701 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2) 2702 2703 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2704 @test_util.run_v1_only("b/120545219") 2705 def testWhileUpdateVariable_3(self): 2706 with self.cached_session(): 2707 select = variables.Variable([3.0, 4.0, 5.0]) 2708 n = constant_op.constant(0) 2709 2710 def loop_iterator(j, _): 2711 return math_ops.less(j, 3) 2712 2713 def loop_body(j, _): 2714 ns = state_ops.scatter_update(select, j, 10.0) 2715 nj = math_ops.add(j, 1) 2716 return [nj, ns] 2717 2718 r = control_flow_ops.while_loop( 2719 loop_iterator, 2720 loop_body, [n, array_ops.identity(select)], 2721 parallel_iterations=1) 2722 
self.evaluate(variables.global_variables_initializer()) 2723 result = r[1] 2724 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result) 2725 2726 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2727 @test_util.run_v1_only("b/120545219") 2728 def testWhileUpdateVariable_4(self): 2729 with self.cached_session(): 2730 var_a = variables.Variable(0, name="a") 2731 var_b = variables.Variable(0, name="b") 2732 self.evaluate(variables.global_variables_initializer()) 2733 2734 c = constant_op.constant(0, name="c") 2735 asn1 = state_ops.assign_add(var_a, 1, name="a_add") 2736 2737 # Loop condition 2738 def pred(i): 2739 return math_ops.less(i, 10) 2740 2741 # Loop body 2742 def loop_body(i): 2743 asn2 = state_ops.assign_add(var_b, asn1, name="b_add") 2744 with ops.control_dependencies([asn2]): 2745 ni = math_ops.add(i, 1, name="i_add") 2746 return ni 2747 2748 lpa = control_flow_ops.while_loop( 2749 pred, loop_body, [c], parallel_iterations=1) 2750 2751 self.assertEqual(0, self.evaluate(var_b)) 2752 self.evaluate(lpa) # Run the loop 2753 self.assertEqual(10, self.evaluate(var_b)) 2754 2755 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2756 @test_util.run_v1_only("b/120545219") 2757 def testWhileUpdateVariable_5(self): 2758 with self.cached_session(): 2759 # Create some variables. 2760 var_a = variables.Variable(0, name="a") 2761 var_b = variables.Variable(0, name="b") 2762 self.evaluate(variables.global_variables_initializer()) 2763 2764 # Change condition to check var_b 2765 def pred(_): 2766 return math_ops.less(var_b, 10) 2767 2768 # Change body to increment var_b 2769 def loop_body(i): 2770 asn1 = state_ops.assign_add( 2771 var_a, constant_op.constant(1), name="a_add") 2772 asn2 = state_ops.assign_add( 2773 var_b, constant_op.constant(1), name="b_add") 2774 with ops.control_dependencies([asn1, asn2]): 2775 inc_b = array_ops.identity(var_b) 2776 return inc_b 2777 2778 lpa = control_flow_ops.while_loop( 2779 pred, loop_body, [var_b], parallel_iterations=1, name="loop") 2780 2781 self.assertEqual(0, self.evaluate(var_b)) 2782 self.evaluate(lpa) # Run the loop 2783 self.assertEqual(10, self.evaluate(var_a)) 2784 self.assertEqual(10, self.evaluate(var_b)) 2785 2786 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2787 @test_util.run_v1_only("b/120545219") 2788 def testWhileUpdateVariable_6(self): 2789 with self.cached_session(): 2790 # Create some variables. 
2791 var_a = variables.Variable(0, name="a") 2792 var_b = variables.Variable(0, name="b") 2793 c = constant_op.constant(0) 2794 self.evaluate(variables.global_variables_initializer()) 2795 2796 # Loop condition 2797 def pred(i): 2798 return math_ops.less(i, 10) 2799 2800 # Loop body 2801 def loop_body(i): 2802 asn1 = state_ops.assign_add(var_a, 1, name="a_add") 2803 with ops.control_dependencies([asn1]): 2804 asn2 = state_ops.assign_add(var_b, var_a, name="b_add") 2805 with ops.control_dependencies([asn2]): 2806 ni = math_ops.add(i, 1, name="i_add") 2807 return ni 2808 2809 lpa = control_flow_ops.while_loop( 2810 pred, loop_body, [c], parallel_iterations=1, name="loop") 2811 2812 self.assertEqual(0, self.evaluate(var_b)) 2813 self.evaluate(lpa) # Run the loop 2814 self.assertEqual(55, self.evaluate(var_b)) 2815 self.assertEqual(10, self.evaluate(var_a)) 2816 2817 @test_util.run_v1_only("b/120545219") 2818 def testWhileQueue_1(self): 2819 with self.cached_session(): 2820 q = data_flow_ops.FIFOQueue(-1, dtypes.int32) 2821 i = constant_op.constant(0) 2822 2823 def c(i): 2824 return math_ops.less(i, 10) 2825 2826 def b(i): 2827 ni = math_ops.add(i, 1) 2828 ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni) 2829 return ni 2830 2831 r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1) 2832 self.assertEqual([10], self.evaluate(r)) 2833 for i in xrange(10): 2834 self.assertEqual([i], self.evaluate(q.dequeue())) 2835 2836 @test_util.run_v1_only("b/120545219") 2837 def testWhileTimeOut(self): 2838 run_options = config_pb2.RunOptions(timeout_in_ms=1) 2839 with self.cached_session() as sess: 2840 n = constant_op.constant(0) 2841 c = lambda x: True 2842 b = lambda x: math_ops.add(x, 1) 2843 r = control_flow_ops.while_loop(c, b, [n]) 2844 with self.assertRaises(errors_impl.DeadlineExceededError): 2845 sess.run(r, options=run_options) 2846 2847 @test_util.disable_control_flow_v2("b/117119329 (stack)") 2848 @test_util.run_v1_only("b/120545219") 2849 def testWhileStack_1(self): 2850 with self.cached_session(): 2851 s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo") 2852 i = constant_op.constant(0) 2853 2854 def c(i): 2855 return math_ops.less(i, 10) 2856 2857 def b(i): 2858 ni = math_ops.add(i, 1) 2859 ni = control_flow_ops.with_dependencies( 2860 [gen_data_flow_ops.stack_push_v2(s, i)], ni) 2861 return ni 2862 2863 r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1) 2864 2865 x = constant_op.constant(0) 2866 2867 def c1(i, _): 2868 return math_ops.greater(i, 0) 2869 2870 def b1(i, x): 2871 ni = math_ops.subtract(i, 1) 2872 nx = x + gen_data_flow_ops.stack_pop_v2(s, dtypes.int32) 2873 return [ni, nx] 2874 2875 _, rx = control_flow_ops.while_loop( 2876 c1, 2877 b1, [r, x], 2878 [r.get_shape(), tensor_shape.unknown_shape()], 2879 parallel_iterations=1) 2880 self.assertEqual(45, self.evaluate(rx)) 2881 2882 def _testWhileGrad_ColocateGradients(self, colocate): 2883 gpu_dev_name = test.gpu_device_name() if test.is_gpu_available( 2884 ) else "/device:CPU:0" 2885 2886 graph = ops.Graph() 2887 with graph.as_default(): 2888 v = constant_op.constant(2.0, name="v") 2889 c = lambda v: math_ops.less(v, 100.0) 2890 2891 def b(x): 2892 with ops.device(gpu_dev_name): 2893 return math_ops.square(x) 2894 2895 loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 2896 r = gradients_impl.gradients( 2897 loop, v, colocate_gradients_with_ops=colocate)[0] 2898 2899 r_ops = graph.get_operations() 2900 r_devices = [(op.name, op.device) for op in r_ops] 2901 
2902 self.assertTrue(any("Square" in op.name for op in r_ops)) 2903 2904 for (name, dev) in r_devices: 2905 if not colocate and name.endswith("Square"): 2906 # Only forward graph contain gpu in Square device 2907 self.assertTrue(gpu_dev_name in dev) 2908 elif colocate and "Square" in name: 2909 # Forward and backward graphs contain gpu in Square/Square_grad devices 2910 self.assertTrue(gpu_dev_name in dev) 2911 else: 2912 self.assertFalse(gpu_dev_name in dev) 2913 2914 with self.session(graph=graph) as sess: 2915 self.assertAllClose(1024.0, self.evaluate(r)) 2916 2917 @test_util.disable_control_flow_v2("b/116351701 (colocation)") 2918 @test_util.run_v1_only("b/120545219") 2919 def testWhileGrad_ColocateGradients(self): 2920 self._testWhileGrad_ColocateGradients(colocate=False) 2921 self._testWhileGrad_ColocateGradients(colocate=True) 2922 2923 @test_util.run_v1_only("b/120545219") 2924 def testWhileGrad_Square(self): 2925 with self.cached_session(): 2926 v = constant_op.constant(2.0, name="v") 2927 c = lambda v: math_ops.less(v, 100.0) 2928 b = math_ops.square 2929 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 2930 r = control_flow_ops.cond(math_ops.less(1, 2), lambda: r, lambda: v) 2931 2932 r = gradients_impl.gradients(r, v)[0] 2933 self.assertAllClose(1024.0, self.evaluate(r)) 2934 2935 @test_util.run_v1_only("b/120545219") 2936 def testWhileGrad_Shape(self): 2937 with self.cached_session(): 2938 x = array_ops.placeholder(dtypes.float32, shape=[None]) 2939 v = constant_op.constant([2.0], name="v") 2940 n = constant_op.constant(0, name="n") 2941 c = lambda i, v: math_ops.less(i, 5) 2942 b = lambda i, v: [i + 1, math_ops.multiply(x, v)] 2943 r = control_flow_ops.while_loop( 2944 c, 2945 b, [n, v], 2946 [n.get_shape(), tensor_shape.unknown_shape()], 2947 parallel_iterations=1) 2948 2949 r = gradients_impl.gradients(r[1], x)[0] 2950 self.assertEqual([None], r.get_shape().as_list()) 2951 self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]})) 2952 2953 @test_util.run_deprecated_v1 2954 def testWhileGrad_BaseShape(self): 2955 with self.cached_session() as sess: 2956 x = array_ops.placeholder(dtypes.float32, [None]) 2957 v0 = constant_op.constant([2.0, 2.0], name="v") 2958 c = lambda v: constant_op.constant(False) 2959 b = lambda v: math_ops.multiply(v, x) 2960 r = control_flow_ops.while_loop(c, b, [v0]) 2961 y = math_ops.square(x) 2962 2963 r = gradients_impl.gradients([r, y], x)[0] 2964 self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]})) 2965 2966 @test_util.run_deprecated_v1 2967 @test_util.enable_output_all_intermediates 2968 def testWhileGradAfterSessionRun(self): 2969 v0 = constant_op.constant(2.) 2970 r = control_flow_ops.while_loop( 2971 lambda _: True, lambda v: v * v, [v0], maximum_iterations=3) 2972 2973 self.assertAllEqual(r, 256.) 2974 grad = gradients_impl.gradients(r, v0)[0] 2975 self.assertAllClose(grad, 1024.) 2976 2977 @test_util.run_deprecated_v1 2978 @test_util.enable_output_all_intermediates 2979 def testNestedWhileGradAfterSessionRun(self): 2980 v0 = constant_op.constant(2.) 2981 2982 def body(v): 2983 inner_v0 = constant_op.constant(1.) 2984 return control_flow_ops.while_loop( 2985 lambda _: True, lambda x: x * v, [inner_v0], maximum_iterations=2) 2986 2987 r = control_flow_ops.while_loop( 2988 lambda _: True, body, [v0], maximum_iterations=3) 2989 2990 self.assertAllEqual(r, 256.) 2991 grad = gradients_impl.gradients(r, v0)[0] 2992 self.assertAllClose(grad, 1024.) 
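  # The expected values above follow from unrolling the loops: three squarings
  # of v0 = 2.0 give r = v0**8 = 256, and d(v0**8)/dv0 = 8 * v0**7 = 1024. The
  # nested variant reaches the same closed form because each outer step squares
  # the value carried so far.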
2993 2994 @test_util.run_v1_only("b/120545219") 2995 def testWhileGrad_MultipleUses(self): 2996 with self.cached_session(): 2997 v = constant_op.constant(2.0, name="v") 2998 c = lambda v: math_ops.less(v, 100.0) 2999 b = math_ops.square 3000 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 3001 r = math_ops.multiply(r, r) 3002 3003 r = gradients_impl.gradients(r, v)[0] 3004 self.assertEqual(524288.0, self.evaluate(r)) 3005 3006 @test_util.run_v1_only("b/120545219") 3007 def testWhileGrad_LoopAdd(self): 3008 with self.cached_session(): 3009 v = constant_op.constant(2.0, name="v") 3010 c = lambda v: math_ops.less(v, 100.0) 3011 b = math_ops.square 3012 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 3013 r = math_ops.add(r, r) 3014 3015 r = gradients_impl.gradients(r, v)[0] 3016 self.assertAllClose(2048.0, self.evaluate(r)) 3017 3018 def _testWhileGrad_Mul(self, use_gpu, p_iters): 3019 with self.cached_session(use_gpu=use_gpu) as sess: 3020 a = constant_op.constant(3.0, name="a") 3021 v = constant_op.constant(2.0, name="v") 3022 c = lambda v: math_ops.less(v, 100.0) 3023 b = lambda v: math_ops.multiply(v, a) 3024 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=p_iters) 3025 3026 grad_a, grad_v = gradients_impl.gradients(r, [a, v]) 3027 grad_a_val, grad_v_val = self.evaluate([grad_a, grad_v]) 3028 self.assertAllClose(216.0, grad_a_val) 3029 self.assertAllClose(81.0, grad_v_val) 3030 3031 @test_util.run_deprecated_v1 3032 def testWhileGrad_Mul(self): 3033 self._testWhileGrad_Mul(use_gpu=False, p_iters=1) 3034 self._testWhileGrad_Mul(use_gpu=False, p_iters=10) 3035 self._testWhileGrad_Mul(use_gpu=True, p_iters=1) 3036 self._testWhileGrad_Mul(use_gpu=True, p_iters=10) 3037 3038 def testWhileGradInControlDeps(self): 3039 3040 @def_function.function 3041 def f(): 3042 x_init = constant_op.constant(2.) 3043 loop_cond = lambda i, x: math_ops.less(i, 2) 3044 loop_body = lambda i, x: [i + 1, x**2] 3045 _, x = control_flow_ops.while_loop(loop_cond, loop_body, [0, x_init]) 3046 with ops.control_dependencies([x]): 3047 (grad,) = gradients_impl.gradients(x, x_init) 3048 return grad 3049 3050 self.assertAllEqual(f(), 4. * 2.**3) # 4 * x_init ^ 3 3051 3052 @test_util.run_deprecated_v1 3053 def testTfFunctionInV1WhileLoop(self): 3054 3055 # This test specifically tests that creating a Const node inside a 3056 # tf.function inside a v1 while_loop while inlining is turned on works. 3057 config = opt_cfg() 3058 assert config.graph_options.optimizer_options.do_function_inlining 3059 with session.Session(config=config): 3060 3061 @def_function.function 3062 def loop_body(i): 3063 # Here we create the const. 3064 return i + 1. 3065 3066 loop_cond = lambda i: True 3067 x = control_flow_ops.while_loop( 3068 loop_cond, loop_body, [0.], maximum_iterations=5) 3069 self.assertAllEqual(x, 5.) 
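  # In _testNestedWhileCondWhileGrad below, the inner loop scales its input by
  # 2**4 = 16 and the taken cond branch squares that, so one outer iteration
  # maps v to (16*v)**2 = 256*v**2; starting from v = 1.0 the loop exits after
  # that single pass, and d(256*v**2)/dv at v = 1 is 512.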
3070 3071 def _testNestedWhileCondWhileGrad(self, use_gpu): 3072 3073 with self.cached_session(use_gpu=use_gpu): 3074 v = constant_op.constant(1.0) 3075 3076 def inner_loop(s): 3077 z = constant_op.constant(0) 3078 c = lambda i, x: math_ops.less(i, 4) 3079 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3080 return control_flow_ops.while_loop(c, b, [z, s]) 3081 3082 c = lambda x: math_ops.less(x, 128.0) 3083 3084 def b(x): 3085 return control_flow_ops.cond( 3086 constant_op.constant(True), 3087 lambda: math_ops.square(inner_loop(x)[1]), 3088 lambda: math_ops.multiply(x, 2.0)) 3089 3090 r = control_flow_ops.while_loop(c, b, [v]) 3091 r = gradients_impl.gradients(r, v)[0] 3092 self.assertAllClose(512.0, self.evaluate(r)) 3093 3094 @test_util.run_deprecated_v1 3095 def testNestedWhileCondWhileGrad(self): 3096 self._testNestedWhileCondWhileGrad(use_gpu=False) 3097 3098 @test_util.run_deprecated_v1 3099 def testNestedWhileCondWhileGradGpu(self): 3100 self._testNestedWhileCondWhileGrad(use_gpu=True) 3101 3102 @test_util.run_v1_only("b/120545219") 3103 def testWhileGrad_Variable(self): 3104 with self.cached_session(): 3105 a = variables.Variable(3.0) 3106 v = constant_op.constant(2.0, name="v") 3107 c = lambda v: math_ops.less(v, 100.0) 3108 b = lambda v: math_ops.multiply(v, a) 3109 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 3110 3111 r = gradients_impl.gradients(r, a) 3112 self.evaluate(variables.global_variables_initializer()) 3113 self.assertAllClose(216.0, r[0]) 3114 3115 @test_util.run_deprecated_v1 3116 def testWhileGrad_ResourceVariable(self): 3117 with self.cached_session(): 3118 a = resource_variable_ops.ResourceVariable(3.0) 3119 v = constant_op.constant(2.0, name="v") 3120 c = lambda v: math_ops.less(v, 100.0) 3121 b = lambda v: math_ops.multiply(v, a) 3122 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 3123 3124 g = gradients_impl.gradients(r, a) 3125 self.evaluate(variables.global_variables_initializer()) 3126 self.assertAllClose(216.0, g[0]) 3127 3128 def testWhileGrad_EagerResourceVariable(self): 3129 with context.eager_mode(): 3130 a = resource_variable_ops.ResourceVariable( 3131 np.ones([2, 2], dtype=np.float32)) 3132 v = constant_op.constant(1.0) 3133 3134 @eager_function.defun 3135 def fn(): 3136 r = control_flow_ops.while_loop( 3137 lambda i, _: i < 2, 3138 lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v), 3139 [0, 1.0])[1] 3140 return gradients_impl.gradients(r, [v])[0] 3141 3142 self.assertEqual(self.evaluate(fn()), 32.) 
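  # A minimal illustrative sketch (not part of the original suite): the
  # while-loop gradient tests in this file generally compare against a closed
  # form obtained by unrolling the loop, as this helper does.
  def _sketchClosedFormWhileGrad(self):
    k = constant_op.constant(3.0)
    v = constant_op.constant(2.0)
    # Four iterations of x *= k turn v into v * k**4.
    r = control_flow_ops.while_loop(
        lambda i, _: i < 4,
        lambda i, x: (i + 1, x * k),
        [0, v])[1]
    grad = gradients_impl.gradients(r, v)[0]  # d(v * k**4)/dv == k**4 == 81.
    self.assertAllClose(81.0, self.evaluate(grad))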
3143 3144 def testWhileGrad_ResourceVarInFunctionCall(self): 3145 3146 @def_function.function 3147 def foo(x, var): 3148 return x + math_ops.reduce_sum(var.sparse_read([1, 3])) 3149 3150 @def_function.function 3151 def bar(var): 3152 r = control_flow_ops.while_loop( 3153 lambda i, _: i < 2, 3154 lambda i, x: (i + 1, foo(x, var)), 3155 [0, 0.0])[1] 3156 return gradients_impl.gradients(r, var)[0] 3157 3158 var = resource_variable_ops.ResourceVariable([1., 2., 3., 4.]) 3159 self.evaluate(variables.global_variables_initializer()) 3160 grad = self.evaluate(bar(var)) 3161 self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.]) 3162 3163 def testWhileGrad_ResourceVarInNestedFunctionCall(self): 3164 3165 @def_function.function 3166 def foo(x, var): 3167 return x + math_ops.reduce_sum(var.sparse_read([1, 3])) 3168 3169 @def_function.function 3170 def foo2(x, var): 3171 return foo(x, var) 3172 3173 @def_function.function 3174 def bar(var): 3175 r = control_flow_ops.while_loop( 3176 lambda i, _: i < 2, 3177 lambda i, x: (i + 1, foo2(x, var)), 3178 [0, 0.0])[1] 3179 return gradients_impl.gradients(r, var)[0] 3180 3181 var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.]) 3182 self.evaluate(variables.global_variables_initializer()) 3183 grad = self.evaluate(bar(var)) 3184 self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.]) 3185 3186 def testWhileGrad_ResourceVarInLoopInFunctionCall(self): 3187 if test.is_gpu_available(): 3188 self.skipTest("b/128635252") 3189 3190 @def_function.function 3191 def foo(x, var): 3192 return control_flow_ops.while_loop( 3193 lambda j, _: j < 3, 3194 lambda j, y: (j + 1, 3195 y + math_ops.reduce_sum(var.sparse_read([1, 2]))), 3196 [0, x])[1] 3197 3198 @def_function.function 3199 def bar(var): 3200 r = control_flow_ops.while_loop( 3201 lambda i, _: i < 2, 3202 lambda i, x: (i + 1, foo(x, var)), 3203 [0, 0.0])[1] 3204 return gradients_impl.gradients(r, var)[0] 3205 3206 var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.]) 3207 self.evaluate(variables.global_variables_initializer()) 3208 grad = self.evaluate(bar(var)) 3209 self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 6., 6., 0.]) 3210 3211 def testWhileCondGrad_ResourceVarInFunctionCall(self): 3212 3213 @def_function.function 3214 def foo(x, var): 3215 return x + var.sparse_read([1])[0] 3216 3217 def body(i, x): 3218 return (i + 1, control_flow_ops.cond( 3219 math_ops.equal(i % 2, 0), 3220 lambda: foo(x, var1), 3221 lambda: foo(x, var2))) 3222 3223 @def_function.function 3224 def bar(var1, var2): 3225 r = control_flow_ops.while_loop( 3226 lambda i, _: i < 4, body, [0, 0.0]) 3227 return gradients_impl.gradients(r, [var1, var2]) 3228 3229 var1 = resource_variable_ops.ResourceVariable([1., 2., 3.]) 3230 var2 = resource_variable_ops.ResourceVariable([4., 5.]) 3231 self.evaluate(variables.global_variables_initializer()) 3232 grads = self.evaluate(bar(var1, var2)) 3233 self.assertAllEqual(gradient_checker_v2._to_numpy(grads[0]), [0., 2., 0.]) 3234 self.assertAllEqual(gradient_checker_v2._to_numpy(grads[1]), [0., 2.]) 3235 3236 @test_util.run_deprecated_v1 3237 def testWhileGrad_ResourceVarSparseRead(self): 3238 # NOTE(skyewm): this test is interesting because the gradient is the 3239 # aggregation result of IndexedSlices and Tensors. 
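    # Three iterations each multiply x by s = reduce_sum(var.sparse_read([1, 3]))
    # with s = 2, so r = s**3 and dr/ds = 3 * s**2 = 12, scattered back to
    # indices 1 and 3 of var.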
3240 var = resource_variable_ops.ResourceVariable(np.ones(5), 3241 dtype=dtypes.float32) 3242 r = control_flow_ops.while_loop( 3243 lambda i, _: i < 3, 3244 lambda i, x: (i + 1, x * math_ops.reduce_sum(var.sparse_read([1, 3]))), 3245 [0, constant_op.constant(1.0)])[1] 3246 grad = gradients_impl.gradients(r, var)[0] 3247 3248 self.evaluate(variables.global_variables_initializer()) 3249 grad_val = self.evaluate(grad) 3250 arr = gradient_checker_v2._to_numpy(grad_val) 3251 self.assertAllEqual(arr, [0., 12., 0., 12., 0.]) 3252 3253 @test_util.run_deprecated_v1 3254 def testWhileGrad_MultiResourceVarSparseRead(self): 3255 # NOTE(skyewm): this test is interesting because the gradient is the 3256 # aggregation result of IndexedSlices and Tensors. 3257 var1 = resource_variable_ops.ResourceVariable(np.ones(5), 3258 dtype=dtypes.float32) 3259 var2 = resource_variable_ops.ResourceVariable(np.ones(3), 3260 dtype=dtypes.float32) 3261 x1_init = constant_op.constant([0., 0.]) 3262 x2_init = constant_op.constant(1.) 3263 x3_init = constant_op.constant(1.) 3264 3265 def body(i, unused_x1, x2, x3): 3266 y1 = var1.sparse_read([1, 3]) 3267 y2 = x2 * 2 3268 y3 = x3 * math_ops.reduce_sum(var2.sparse_read([0])) 3269 return i + 1, y1, y2, y3 3270 3271 r = control_flow_ops.while_loop( 3272 lambda i, x1, x2, x3: i < 3, body, 3273 [0, x1_init, x2_init, x3_init])[1:] 3274 var1_grad, var2_grad = gradients_impl.gradients(r, [var1, var2]) 3275 3276 self.evaluate(variables.global_variables_initializer()) 3277 var1_grad_val = self.evaluate(var1_grad) 3278 var2_grad_val = self.evaluate(var2_grad) 3279 self.assertAllEqual(gradient_checker_v2._to_numpy(var1_grad_val), 3280 [0., 1., 0., 1., 0.]) 3281 self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val), 3282 [3., 0., 0.]) 3283 3284 def testWhileGrad_Gather(self): 3285 # NOTE(skyewm): this test is interesting because the gather gradient 3286 # function returns an IndexedSlices. 3287 @tf_function_in_tf2 3288 def fn(): 3289 x = constant_op.constant([1., 1., 1., 1., 1.]) 3290 y = control_flow_ops.while_loop( 3291 lambda i, _: i < 3, 3292 lambda i, x: (i + 1, x + array_ops.gather(x, [0])), 3293 [0, x[:1]])[1] 3294 z = y * 3.0 3295 grad = gradients_impl.gradients(z, x)[0] 3296 return y, grad 3297 y, grad = fn() 3298 self.assertEqual(self.evaluate(y), 8.) 3299 self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.]) 3300 3301 def testWhileGrad_GatherNoFanOut(self): 3302 # NOTE(skyewm): this test is interesting because the gather gradient 3303 # function returns an IndexedSlices. 3304 @tf_function_in_tf2 3305 def fn(): 3306 x = constant_op.constant([1., 1., 1., 1., 1.]) 3307 y = control_flow_ops.while_loop( 3308 lambda i, _: i < 3, 3309 lambda i, x: (i + 1, array_ops.gather(x, [0])), 3310 [0, x[:1]])[1] 3311 z = y * 3.0 3312 grad = gradients_impl.gradients(z, x)[0] 3313 return y, grad 3314 y, grad = fn() 3315 self.assertEqual(self.evaluate(y), 1.) 
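    # With no fan-out the loop simply keeps gathering element 0, so y stays 1.0
    # and z = 3 * y only sends a gradient of 3.0 back to x[0].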
3316 self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.]) 3317 3318 @test_util.run_v1_only("b/120545219") 3319 def testWhileGradInCond(self): 3320 3321 with self.cached_session(): 3322 n = ops.convert_to_tensor(1.0, name="n") 3323 x = array_ops.placeholder(dtypes.float32, shape=None) 3324 c = lambda n: math_ops.less(n, 10.0) 3325 b = lambda n: math_ops.add(n, x) 3326 3327 def fn1(): 3328 r = control_flow_ops.while_loop(c, b, [n], 3329 [tensor_shape.unknown_shape()]) 3330 return gradients_impl.gradients(r, x)[0] 3331 3332 r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x) 3333 self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) 3334 3335 @test_util.disable_control_flow_v2("b/116340060") 3336 @test_util.run_v1_only("b/120545219") 3337 def testGradInWhileWrtInitialLoopVal(self): 3338 with self.cached_session(): 3339 x = array_ops.placeholder(dtypes.float32, shape=(), name="x") 3340 y = x + 1 3341 3342 def body(i, v): 3343 z = v * 2 3344 return i + 1, gradients_impl.gradients(z, x)[0] 3345 3346 with self.assertRaisesRegex( 3347 ValueError, 3348 "Cannot compute gradient inside while loop with respect to op 'x'. " 3349 "We do not support taking the gradient wrt or through the initial " 3350 "value of a loop variable. Gradients can be computed through " 3351 "loop invariants or wrt the input parameters to the loop body."): 3352 control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y]) 3353 3354 @test_util.run_v1_only("b/120545219") 3355 def testWhileGradInWhile(self): 3356 with self.cached_session(): 3357 n = ops.convert_to_tensor(1.0, name="n") 3358 x = array_ops.placeholder(dtypes.float32, shape=None) 3359 c = lambda n: math_ops.less(n, 10.0) 3360 b = lambda n: math_ops.add(n, x) 3361 3362 def b1(n): 3363 r = control_flow_ops.while_loop(c, b, [n], 3364 [tensor_shape.unknown_shape()]) 3365 return gradients_impl.gradients(r, x) 3366 3367 r = control_flow_ops.while_loop(lambda n: n < 6.0, b1, [n], 3368 [tensor_shape.unknown_shape()]) 3369 self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) 3370 3371 @test_util.run_v1_only("b/120545219") 3372 def testCondGradInNestedWhiles(self): 3373 3374 def outer_body(i, x): 3375 _, x = control_flow_ops.while_loop( 3376 lambda j, x: j < 3, inner_body, [0, 0.0]) 3377 return i + 1, x 3378 3379 def inner_body(j, x): 3380 y = control_flow_ops.cond(math_ops.less(x, 1), lambda: 2 * x, lambda: x) 3381 return j + 1, gradients_impl.gradients(y, x)[0] 3382 3383 i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0]) 3384 3385 with self.cached_session() as sess: 3386 i_val, x_val = self.evaluate([i, x]) 3387 self.assertEqual(i_val, 3) 3388 self.assertAllClose(x_val, 1.0) 3389 3390 @test_util.run_gpu_only 3391 def testGpuResourceAccess(self): 3392 with ops.device(test.gpu_device_name()): 3393 var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0)) 3394 3395 @def_function.function 3396 def foo(): 3397 return control_flow_ops.while_loop( 3398 lambda i, _: i < 3, 3399 lambda i, x: (i + 1, control_flow_ops.cond( 3400 constant_op.constant(True), 3401 lambda: x + var, 3402 lambda: x)), 3403 [0, 0.0])[1] 3404 3405 self.evaluate(variables.global_variables_initializer()) 3406 self.assertEqual(self.evaluate(foo()), 9.0) 3407 3408 def testNestedResourceAccess(self): 3409 var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0)) 3410 3411 @eager_function.defun 3412 def test_fn(): 3413 x = constant_op.constant(0.0) 3414 r = control_flow_ops.while_loop( 3415 # Outer loop condition 3416 lambda i, y: i < 2, 3417 
# Outer loop body 3418 lambda i, y: (i + 1, y + control_flow_ops.cond( 3419 constant_op.constant(True), 3420 # True branch 3421 lambda: control_flow_ops.while_loop( 3422 # Inner loop condition 3423 lambda j, z: j < 3, 3424 # Inner loop body 3425 lambda j, z: (j + 1, z + math_ops.square(var)), 3426 # Inner initial loop value 3427 [0, y])[1], 3428 # False branch 3429 lambda: (0.0))), 3430 # Outer initial loop value 3431 [0, x])[1] 3432 3433 grad = gradients_impl.gradients(r, x)[0] 3434 return r, grad 3435 3436 self.evaluate(variables.global_variables_initializer()) 3437 r, grad = self.evaluate(test_fn()) 3438 # 2 * 3 * 3^2 3439 self.assertEqual(r, 81.0) 3440 # v1 control flow gets the wrong answer!!! 3441 # Gradient computation: 3442 # f(x) = x + 3^2 3443 # inner_loop(x) = f(f(f(x))) = x + 3*3^2 = x + 27 3444 # g(x) = x + inner_loop(x) = 2x + 27 3445 # outer_loop(x) = g(g(x)) = 4x + 81 3446 # outer_loop'(x) = 4 3447 # Note that v1 control flow gets 4.0 as well if the cond is removed. 3448 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 3449 self.assertEqual(grad, 4.0) 3450 3451 def testWhile_NestedInput(self): 3452 with self.cached_session() as sess: 3453 named = collections.namedtuple("named", ("a", "b")) 3454 loop_vars = [ 3455 named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)), 3456 (constant_op.constant(2.0), constant_op.constant(3.0)), 3457 constant_op.constant(4.0) 3458 ] 3459 c = lambda lv0, _1, _2: lv0.a < 100.0 3460 3461 def b(lv0, lv1, lv2): 3462 lv0 = named(a=lv0.a + 1, b=lv0.b) 3463 lv1 = (lv1[0] + 1, lv1[1]) 3464 lv2 += 2 3465 return [lv0, lv1, lv2] 3466 3467 r = control_flow_ops.while_loop(c, b, loop_vars) 3468 3469 self.assertTrue(isinstance(r, list)) 3470 self.assertTrue(isinstance(r[0], named)) 3471 self.assertTrue(isinstance(r[1], tuple)) 3472 self.assertTrue(isinstance(r[2], ops.Tensor)) 3473 3474 r_flattened = nest.flatten(r) 3475 self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0], 3476 self.evaluate(r_flattened)) 3477 3478 @test_util.run_v1_only("b/120545219") 3479 def testWhile_NestedBadArityFails(self): 3480 with self.cached_session(): 3481 named = collections.namedtuple("named", ("a", "b")) 3482 loop_vars = [ 3483 named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)), 3484 (constant_op.constant(2.0), constant_op.constant(3.0)), 3485 constant_op.constant(4.0) 3486 ] 3487 c = lambda lv0, _1, _2: lv0.a < 100.0 3488 3489 def b(lv0, lv1, _): 3490 return [lv0, lv1] 3491 3492 with self.assertRaisesRegex(ValueError, "the same number of elements"): 3493 control_flow_ops.while_loop(c, b, loop_vars) 3494 3495 @test_util.run_v1_only("b/120545219") 3496 def testWhileGrad_ys_xs(self): 3497 with self.cached_session(): 3498 x = constant_op.constant(3.0, name="x") 3499 y = constant_op.constant(2.0, name="y") 3500 3501 c = lambda x, y: math_ops.less(x, 100.0) 3502 3503 def b(x, y): 3504 y1 = math_ops.add(x, y) 3505 x1 = math_ops.multiply(x, y1) 3506 return x1, y1 3507 3508 rx, ry = control_flow_ops.while_loop(c, b, [x, y], parallel_iterations=1) 3509 3510 r = gradients_impl.gradients([rx, ry], x) 3511 self.assertAllClose(304.0, r[0]) 3512 r = gradients_impl.gradients([rx, ry], y) 3513 self.assertAllClose(124.0, r[0]) 3514 r = gradients_impl.gradients([rx], x) 3515 self.assertAllClose(295.0, r[0]) 3516 r = gradients_impl.gradients([rx], y) 3517 self.assertAllClose(120.0, r[0]) 3518 3519 @test_util.run_deprecated_v1 3520 def testWhileGrad_Dependency(self): 3521 with self.cached_session(): 3522 i = constant_op.constant(0, name="i") 3523 x = 
constant_op.constant(2.0, name="x") 3524 3525 c = lambda i, x: math_ops.less(i, 10) 3526 3527 def b(i, x): 3528 x = math_ops.multiply(x, 2.0) 3529 i = math_ops.add(i, 1) 3530 return i, x 3531 3532 ri, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3533 3534 r = gradients_impl.gradients([ri, rx], x) 3535 self.assertAllClose(1024.0, r[0]) 3536 r = gradients_impl.gradients([rx], x) 3537 self.assertAllClose(1024.0, r[0]) 3538 3539 @test_util.run_v1_only("b/120545219") 3540 def testWhileGrad_NoGradient(self): 3541 with self.cached_session(): 3542 v = constant_op.constant(2.0, name="v") 3543 c = lambda v: math_ops.less(v, 100.0) 3544 b = math_ops.square 3545 r = control_flow_ops.while_loop(c, b, [v], back_prop=False) 3546 r = math_ops.add(r, v) 3547 r = gradients_impl.gradients(r, v) 3548 self.assertAllClose(1.0, r[0]) 3549 3550 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 3551 @test_util.run_v1_only("b/120545219") 3552 def testWhileGrad_NoDependency(self): 3553 with self.cached_session() as sess: 3554 variable = variables.Variable(array_ops.ones([2, 3])) 3555 duration = array_ops.zeros([], dtype=dtypes.int32) 3556 3557 def cond(duration, tensor, _): 3558 del tensor 3559 return duration < 10 3560 3561 def body(duration, tensor, _): 3562 return (duration + 1, tensor, tensor) 3563 3564 loop_vars = [duration, variable, variable] 3565 tensors = control_flow_ops.while_loop( 3566 cond=cond, body=body, loop_vars=loop_vars) 3567 cost = math_ops.reduce_sum(tensors[2]) 3568 grad = gradients_impl.gradients(cost, [variable]) 3569 self.evaluate(variables.global_variables_initializer()) 3570 self.assertAllClose(np.ones([2, 3]), sess.run(grad[0])) 3571 3572 @test_util.run_deprecated_v1 3573 def testWhileGrad_Const(self): 3574 with self.cached_session() as sess: 3575 c0 = constant_op.constant(0.0, name="c0") 3576 c1 = constant_op.constant(1.0, name="c1") 3577 duration = constant_op.constant(0, name="t") 3578 3579 def cond(duration, _): 3580 return duration < 1 3581 3582 def body(duration, _): 3583 return duration + 1, c1 3584 3585 loop_vars = [duration, c0] 3586 tensors = control_flow_ops.while_loop( 3587 cond=cond, body=body, loop_vars=loop_vars) 3588 cost = math_ops.reduce_sum(tensors[1]) 3589 grad = gradients_impl.gradients(cost, [c0]) 3590 self.assertAllClose(0.0, sess.run(grad[0])) 3591 3592 @test_util.run_v1_only("b/120545219") 3593 def testWhileGrad_SerialTwoLoops(self): 3594 with self.cached_session(): 3595 i = constant_op.constant(0, name="i") 3596 x = constant_op.constant(2.0, name="x") 3597 3598 c = lambda i, x: math_ops.less(i, 5) 3599 3600 def b(i, x): 3601 x = math_ops.multiply(x, 2.0) 3602 i = math_ops.add(i, 1) 3603 return i, x 3604 3605 _, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3606 _, rx = control_flow_ops.while_loop(c, b, [i, rx], parallel_iterations=1) 3607 3608 r = gradients_impl.gradients([rx], x) 3609 self.assertAllClose(1024.0, r[0]) 3610 3611 @test_util.run_v1_only("b/120545219") 3612 def testWhileGrad_ParallelTwoLoops(self): 3613 with self.cached_session(): 3614 i = constant_op.constant(0, name="i") 3615 x = constant_op.constant(2.0, name="x") 3616 3617 c = lambda i, x: math_ops.less(i, 5) 3618 3619 def b(i, x): 3620 x = math_ops.multiply(x, 2.0) 3621 i = math_ops.add(i, 1) 3622 return i, x 3623 3624 _, r1 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3625 _, r2 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3626 rx = math_ops.add(r1, r2) 3627 3628 r = 
gradients_impl.gradients([rx], x) 3629 self.assertAllClose(64.0, r[0]) 3630 3631 @test_util.run_v1_only("b/120545219") 3632 def testWhileGrad_OneOutputWithControlDependencyOnSecond(self): 3633 with self.cached_session(): 3634 i = constant_op.constant(0, name="i") 3635 x = constant_op.constant(1.0, name="x") 3636 y = constant_op.constant(1.0, name="y") 3637 c = lambda i, *_: math_ops.less(i, 1, name="cond_less") 3638 3639 def b(i, xi, yi): 3640 # return (i + 1, xi, xi + yi) 3641 return (math_ops.add(i, 1, name="inc"), array_ops.identity( 3642 xi, name="xi"), math_ops.add(xi, yi, name="xi_plus_yi")) 3643 3644 _, x_f, y_f = control_flow_ops.while_loop(c, b, [i, x, y]) 3645 with ops.control_dependencies([x_f]): 3646 y_f_d = array_ops.identity(y_f, name="y_f_d") 3647 3648 self.assertAllClose(2.0, self.evaluate(y_f_d)) # y_f_d = 1.0 + 1.0 3649 g = gradients_impl.gradients([y_f_d], [x])[0] 3650 self.assertTrue(g is not None) 3651 self.assertAllClose(1.0, 3652 self.evaluate(g)) # y_f_d = x + 1.0, dy_f_d/dx = 1.0 3653 3654 def _testNestedWhileGrad_Simple(self, use_gpu): 3655 with self.cached_session(use_gpu=use_gpu): 3656 v = constant_op.constant(1.0) 3657 3658 def inner_loop(s): 3659 c = lambda x: math_ops.less(x, 4.0) 3660 b = lambda x: math_ops.multiply(x, 2.0) 3661 return control_flow_ops.while_loop(c, b, [s]) 3662 3663 c = lambda x: math_ops.less(x, 2.0) 3664 b = lambda x: math_ops.multiply(inner_loop(x), 2.0) 3665 r = control_flow_ops.while_loop(c, b, [v]) 3666 3667 r = gradients_impl.gradients(r, v)[0] 3668 self.assertAllClose(8.0, self.evaluate(r)) 3669 3670 @test_util.run_deprecated_v1 3671 def testNestedWhileGrad_Simple(self): 3672 self._testNestedWhileGrad_Simple(use_gpu=False) 3673 self._testNestedWhileGrad_Simple(use_gpu=True) 3674 3675 @test_util.run_v1_only("b/120545219") 3676 def testNestedWhileGrad_SerialInner(self): 3677 with self.cached_session(): 3678 v = constant_op.constant(1.0) 3679 3680 def inner_loop1(s): 3681 z = constant_op.constant(0) 3682 c = lambda i, x: math_ops.less(i, 4) 3683 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3684 return control_flow_ops.while_loop(c, b, [z, s]) 3685 3686 def inner_loop2(s): 3687 z = constant_op.constant(0) 3688 c = lambda i, x: math_ops.less(i, 4) 3689 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3690 return control_flow_ops.while_loop(c, b, [z, s]) 3691 3692 c = lambda x: math_ops.less(x, 128.0) 3693 b = lambda x: inner_loop2(inner_loop1(x)[1])[1] 3694 r = control_flow_ops.while_loop(c, b, [v]) 3695 3696 r = gradients_impl.gradients(r, v)[0] 3697 self.assertAllClose(256.0, self.evaluate(r)) 3698 3699 @test_util.run_deprecated_v1 3700 def testNestedWhileGrad_ParallelInner(self): 3701 with self.cached_session(): 3702 v = constant_op.constant(1.0) 3703 3704 def inner_loop1(s): 3705 z = constant_op.constant(0) 3706 c = lambda i, x: math_ops.less(i, 4) 3707 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3708 return control_flow_ops.while_loop(c, b, [z, s]) 3709 3710 def inner_loop2(s): 3711 z = constant_op.constant(0) 3712 c = lambda i, x: math_ops.less(i, 4) 3713 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3714 return control_flow_ops.while_loop(c, b, [z, s]) 3715 3716 c = lambda x: math_ops.less(x, 128.0) 3717 b = lambda x: math_ops.multiply(inner_loop1(x)[1], inner_loop2(x)[1]) 3718 r = control_flow_ops.while_loop(c, b, [v]) 3719 3720 r = gradients_impl.gradients(r, v)[0] 3721 self.assertAllClose(512.0, self.evaluate(r)) 3722 3723 
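  # For testNestedWhileGrad_ParallelIterations below: a single Adam step with
  # learning_rate=0.001 moves var by roughly the learning rate, so the value
  # drops from 3.0 to about 2.999.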
@test_util.run_v1_only("b/120545219") 3724 def testNestedWhileGrad_ParallelIterations(self): 3725 # Make sure the stack pushes and pops of an inner loop are executed in 3726 # the sequential order of the iterations of its outer loop. 3727 with self.cached_session() as sess: 3728 3729 def inner_loop(t): 3730 fn = lambda n: n + math_ops.square(var) 3731 return map_fn.map_fn(fn=fn, elems=t, parallel_iterations=10) 3732 3733 def outer_loop(inp): 3734 return map_fn.map_fn( 3735 fn=inner_loop, elems=inp, parallel_iterations=10) 3736 3737 var = variables.Variable(constant_op.constant(3.0)) 3738 inp = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) 3739 res = outer_loop(inp) 3740 optimizer = adam.AdamOptimizer(learning_rate=0.001) 3741 train_op = optimizer.minimize(math_ops.reduce_mean(math_ops.square(res))) 3742 self.evaluate(variables.global_variables_initializer()) 3743 self.evaluate(train_op) 3744 self.assertAllClose(2.999, var.read_value()) 3745 3746 def _testWhileCondGrad_Simple(self, use_gpu): 3747 with self.cached_session(use_gpu=use_gpu): 3748 v = ops.convert_to_tensor(2.0, name="v") 3749 n = ops.convert_to_tensor(100.0, name="n") 3750 one = ops.convert_to_tensor(1.0, name="one") 3751 c = lambda x: math_ops.less(x, n) 3752 # pylint: disable=undefined-variable 3753 # for OSS build 3754 b = lambda x: control_flow_ops.cond(constant_op.constant(True), 3755 lambda: math_ops.square(x), 3756 lambda: math_ops.subtract(x, one)) 3757 # pylint: enable=undefined-variable 3758 r = control_flow_ops.while_loop(c, b, [v]) 3759 r = gradients_impl.gradients(r, v)[0] 3760 self.assertAllClose(1024.0, self.evaluate(r)) 3761 3762 @test_util.run_deprecated_v1 3763 def testWhileCondGrad_Simple(self): 3764 self._testWhileCondGrad_Simple(use_gpu=False) 3765 self._testWhileCondGrad_Simple(use_gpu=True) 3766 3767 @test_util.run_deprecated_v1 3768 def testWhileCondGrad_UnknownShape(self): 3769 with self.cached_session() as sess: 3770 v = array_ops.placeholder(dtypes.float32) 3771 n = ops.convert_to_tensor(100.0, name="n") 3772 one = ops.convert_to_tensor(1.0, name="one") 3773 c = lambda x: math_ops.less(x, n) 3774 # pylint: disable=undefined-variable 3775 # for OSS build 3776 b = lambda x: control_flow_ops.cond(constant_op.constant(True), 3777 lambda: math_ops.square(x), 3778 lambda: math_ops.subtract(x, one)) 3779 # pylint: enable=undefined-variable 3780 r = control_flow_ops.while_loop(c, b, [v]) 3781 r = gradients_impl.gradients(r, v)[0] 3782 r = sess.run(r, feed_dict={v: 2.0}) 3783 self.assertAllClose(1024.0, r) 3784 3785 @test_util.run_deprecated_v1 3786 def testWhileGrad_Concat(self): 3787 with self.cached_session() as sess: 3788 x = variable_scope.get_variable("x", initializer=[[1., 2.]]) 3789 i0 = constant_op.constant(0) 3790 h0 = array_ops.zeros([0, 2]) 3791 3792 def condition(i, _): 3793 return i < 2 3794 3795 def body(i, h): 3796 return i + 1, array_ops.concat([h, x], 0) 3797 3798 _, h = control_flow_ops.while_loop( 3799 condition, body, [i0, h0], 3800 [i0.get_shape(), tensor_shape.TensorShape([None, 2])]) 3801 s = math_ops.reduce_sum(h) 3802 3803 optimizer = gradient_descent.GradientDescentOptimizer(0.01) 3804 op = optimizer.minimize(s) 3805 3806 self.evaluate(variables.global_variables_initializer()) 3807 self.evaluate(op) 3808 self.assertAllClose([[0.98000002, 1.98000002]], self.evaluate(x)) 3809 3810 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 3811 @test_util.run_v1_only("b/120545219") 3812 def testWhileWithRefsWithGradients_1(self): 3813 with self.cached_session() as 
sess: 3814 x = variables.VariableV1(0.)._ref() # pylint: disable=protected-access 3815 i = constant_op.constant(0) 3816 c = lambda i, x: math_ops.less(i, 10) 3817 3818 self.assertEqual(x.dtype, dtypes.float32_ref) 3819 3820 def body(i, x): 3821 self.assertEqual(x.dtype, dtypes.float32_ref) 3822 return [i + 1, gen_array_ops.ref_identity(x)] 3823 3824 r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5) 3825 3826 grad_ys = [variables.VariableV1(73)._ref()] # pylint: disable=protected-access 3827 grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys) 3828 3829 self.evaluate(variables.global_variables_initializer()) 3830 3831 self.assertEqual(r[0].dtype, dtypes.int32) 3832 self.assertEqual(r[1].dtype, dtypes.float32_ref) 3833 3834 value_i, value_x, value_x_grad = sess.run(r + grad) 3835 3836 self.assertEqual(10, value_i) 3837 self.assertEqual(0, value_x) 3838 self.assertEqual(73, value_x_grad) 3839 3840 @test_util.deprecated_graph_mode_only 3841 def testWhileGrad_IndexedSlices(self): 3842 with self.cached_session(): 3843 values = constant_op.constant([2.0, 4.0], name="values") 3844 indices = constant_op.constant([0, 3], name="indices") 3845 shape = constant_op.constant([10], name="dense_shape") 3846 i = constant_op.constant(0) 3847 x = ops.IndexedSlices(values, indices, dense_shape=shape) 3848 3849 def c(i, _): 3850 return i < 10 3851 3852 def b(i, x): 3853 return [ 3854 i + 1, 3855 ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape) 3856 ] 3857 3858 _, r = control_flow_ops.while_loop(c, b, [i, x]) 3859 r = gradients_impl.gradients(r.values, values)[0] 3860 self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r)) 3861 3862 @test_util.deprecated_graph_mode_only 3863 def testWhileGrad_SparseTensor(self): 3864 with self.cached_session(): 3865 values = constant_op.constant([2.0, 4.0], name="values") 3866 indices = constant_op.constant( 3867 [[0], [3]], dtype=dtypes.int64, name="indices") 3868 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 3869 i = constant_op.constant(0) 3870 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 3871 3872 def c(i, _): 3873 return i < 10 3874 3875 def b(i, x): 3876 return [ 3877 i + 1, 3878 sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape) 3879 ] 3880 3881 _, r = control_flow_ops.while_loop(c, b, [i, x]) 3882 r = gradients_impl.gradients(r.values, values)[0] 3883 self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r)) 3884 3885 @test_util.deprecated_graph_mode_only 3886 def testCallGradInLoop(self): 3887 with self.cached_session() as sess: 3888 i0 = constant_op.constant(0) 3889 params = constant_op.constant(5.0) 3890 params_1 = math_ops.square(params) 3891 3892 def c(i, _): 3893 return i < 10 3894 3895 def b(i, x): 3896 data = constant_op.constant([1.0, 2.0, 3.0]) 3897 data = math_ops.multiply(data, params_1) 3898 x1 = x + gradients_impl.gradients(data, params)[0] 3899 return i + 1, x1 3900 3901 output_grad = control_flow_ops.while_loop( 3902 c, b, [i0, constant_op.constant(0.0)]) 3903 self.assertAllClose(600.0, self.evaluate(output_grad)[1]) 3904 3905 @test_util.run_deprecated_v1 3906 def testWhileAndTensorArray(self): 3907 with self.cached_session() as sess: 3908 param = constant_op.constant(2.0) 3909 n0 = constant_op.constant(0) 3910 y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems") 3911 3912 def c(i, _): 3913 return i < 10 3914 3915 def b(i, y): 3916 return [ 3917 i + 1, 3918 map_fn.map_fn(lambda x: math_ops.multiply(x, param), 
y) 3919 ] 3920 3921 r = control_flow_ops.while_loop(c, b, [n0, y0], parallel_iterations=1) 3922 r = gradients_impl.gradients(r, param)[0] 3923 self.assertAllClose(107520.0, self.evaluate(r)) 3924 3925 @test_util.run_deprecated_v1 3926 def testNestedWhileAndTensorArray(self): 3927 n = constant_op.constant(3.0) 3928 3929 def Body(row, ta): 3930 3931 def InnerBody(row, col, ta): 3932 # Note: row and col are 1-based. 3933 ta = ta.write( 3934 math_ops.cast(n * (row - 1.) + col - 1., dtypes.int32), row * col) 3935 return row, col + 1., ta 3936 3937 ta = control_flow_ops.while_loop( 3938 lambda _, col, _1: col <= n, 3939 InnerBody, [row, constant_op.constant(1.), ta], 3940 return_same_structure=False)[2] 3941 return row + 1., ta 3942 3943 ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=9) 3944 ta = control_flow_ops.while_loop( 3945 lambda row, _: row <= n, 3946 Body, [constant_op.constant(1.), ta], 3947 return_same_structure=False)[1] 3948 3949 output = array_ops.reshape(ta.stack(), [3, 3]) 3950 self.assertAllEqual( 3951 self.evaluate(output), [[1., 2., 3.], [2., 4., 6.], [3., 6., 9.]]) 3952 # TODO(b/117675481): This does not work with current TA. Enable with new TA. 3953 # grad = gradients_impl.gradients(output, [n]) 3954 # self.assertEqual(self.evaluate(grad), 3.5) 3955 3956 @test_util.run_deprecated_v1 3957 def testWhileGrad_StopGrad(self): 3958 with self.cached_session(): 3959 x = constant_op.constant(3.0, name="x") 3960 y = constant_op.constant(2.0, name="y") 3961 3962 c = lambda x, y: math_ops.less(x, 100.0) 3963 3964 def b(x, y): 3965 y1 = math_ops.square(y) 3966 x1 = math_ops.add(math_ops.square(x), y1) 3967 return x1, y1 3968 3969 rx, ry = control_flow_ops.while_loop(c, b, [x, y]) 3970 3971 r = gradients_impl.gradients(rx, y)[0] 3972 self.assertEqual(136.0, self.evaluate(r)) 3973 r = gradients_impl.gradients(ry, y)[0] 3974 self.assertEqual(32.0, self.evaluate(r)) 3975 3976 r = gradients_impl.gradients(array_ops.stop_gradient(rx), y)[0] 3977 self.assertEqual(r, None) 3978 r = gradients_impl.gradients(array_ops.stop_gradient(ry), y)[0] 3979 self.assertEqual(r, None) 3980 3981 r = gradients_impl.gradients( 3982 array_ops.stop_gradient(math_ops.square(rx)), y)[0] 3983 self.assertEqual(r, None) 3984 r = gradients_impl.gradients( 3985 array_ops.stop_gradient(math_ops.add(rx, ry)), x)[0] 3986 self.assertEqual(r, None) 3987 r = gradients_impl.gradients( 3988 array_ops.stop_gradient(math_ops.add(rx, ry)), y)[0] 3989 self.assertEqual(r, None) 3990 3991 r = gradients_impl.gradients(math_ops.add(rx, ry), y)[0] 3992 self.assertEqual(168.0, self.evaluate(r)) 3993 r = gradients_impl.gradients( 3994 math_ops.add(rx, array_ops.stop_gradient(ry)), y)[0] 3995 self.assertEqual(136.0, self.evaluate(r)) 3996 r = gradients_impl.gradients( 3997 math_ops.add(array_ops.stop_gradient(rx), ry), y)[0] 3998 self.assertEqual(32.0, self.evaluate(r)) 3999 4000 @test_util.run_deprecated_v1 4001 def testWhileGrad_StopGradInside(self): 4002 with self.cached_session(): 4003 x = constant_op.constant(3.0, name="x") 4004 y = constant_op.constant(2.0, name="y") 4005 4006 c = lambda x, y: math_ops.less(x, 100.0) 4007 4008 def b(x, y): 4009 y1 = array_ops.stop_gradient(math_ops.square(y)) 4010 x1 = math_ops.add(math_ops.square(x), y1) 4011 return x1, y1 4012 4013 rx, _ = control_flow_ops.while_loop(c, b, [x, y]) 4014 4015 r = gradients_impl.gradients(rx, y)[0] 4016 self.assertAllClose(0.0, self.evaluate(r)) 4017 r = gradients_impl.gradients(rx, x)[0] 4018 self.assertAllClose(156.0, self.evaluate(r)) 4019 4020 
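  # testWhileGrad_StopGradInsideNoShape repeats the test above with unknown
  # shapes. For the feed below the loop runs two iterations; with y1 stopped,
  # d(rx)/dx = 2 * x1 * 2 * x per element, i.e. 2*13*6 = 156 and 2*25*8 = 400.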
@test_util.run_deprecated_v1 4021 def testWhileGrad_StopGradInsideNoShape(self): 4022 with self.cached_session() as sess: 4023 x = array_ops.placeholder(dtypes.float32) 4024 y = array_ops.placeholder(dtypes.float32) 4025 4026 c = lambda x, y: math_ops.less(math_ops.reduce_sum(x), 100.0) 4027 4028 def b(x, y): 4029 y1 = array_ops.stop_gradient(math_ops.square(y, name="stopped")) 4030 x1 = math_ops.add(math_ops.square(x), y1) 4031 return x1, y1 4032 4033 rx, _ = control_flow_ops.while_loop(c, b, [x, y]) 4034 4035 grad_y = gradients_impl.gradients(rx, y)[0] 4036 grad_x = gradients_impl.gradients(rx, x)[0] 4037 feed_dict = {x: [3.0, 4.0], y: [2.0, 3.0]} 4038 self.assertAllClose([0.0, 0.0], sess.run(grad_y, feed_dict=feed_dict)) 4039 self.assertAllClose([156.0, 400.0], sess.run(grad_x, feed_dict=feed_dict)) 4040 name = "gradients/while/stopped_grad" 4041 all_ops = x.graph.get_operations() 4042 self.assertFalse(any(name in op.name for op in all_ops)) 4043 4044 @test_util.run_deprecated_v1 4045 def testWhileGradGradFail(self): 4046 theta = variables.Variable(initial_value=1.) 4047 4048 def fn(prev, x): 4049 return prev + x * theta 4050 4051 result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32)) 4052 grad_theta = gradients_impl.gradients(result, theta) 4053 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: 4054 with self.assertRaisesRegex(TypeError, "Second-order gradient"): 4055 gradients_impl.gradients(grad_theta, theta) 4056 grad_theta_stopped = array_ops.stop_gradient(grad_theta) 4057 gradients_impl.gradients(grad_theta_stopped, theta) 4058 4059 @test_util.run_deprecated_v1 4060 def testStopGradOnWhileGrad(self): 4061 with self.cached_session(): 4062 x = constant_op.constant(2.0, name="x") 4063 y = constant_op.constant(2.0, name="y") 4064 4065 c = lambda x: math_ops.less(x, 100.0) 4066 b = lambda x: math_ops.multiply(x, y) 4067 rx = control_flow_ops.while_loop(c, b, [x]) 4068 4069 rg = gradients_impl.gradients(rx, y)[0] 4070 rg = array_ops.stop_gradient(rg) 4071 r = math_ops.add(math_ops.square(y), rx) 4072 r = math_ops.add(r, rg) 4073 r = gradients_impl.gradients(r, y)[0] 4074 self.assertEqual(388.0, self.evaluate(r)) 4075 4076 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 4077 @test_util.run_deprecated_v1 4078 def testWhileGradientWithNontrainablePath1(self): 4079 q = variables.Variable([7., 8.]) 4080 4081 def cond(_, y): 4082 del y 4083 return False 4084 4085 def body(x, _): 4086 return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q) 4087 4088 _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.)) 4089 dy_dq, = gradients_impl.gradients(y, q) 4090 self.assertIsNotNone(dy_dq) 4091 with self.cached_session() as sess: 4092 self.evaluate(q.initializer) 4093 self.assertAllClose([0., 0.], self.evaluate(dy_dq)) 4094 4095 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 4096 @test_util.run_v1_only("b/120545219") 4097 def testWhileGradientWithNontrainablePath2(self): 4098 q = variables.Variable([7., 8.]) 4099 4100 def cond(_, y): 4101 return math_ops.equal(y, 0.) 
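    # y starts at 0. and becomes nonzero after one pass through body, so the
    # loop runs exactly once and dy/dq is the gradient of reduce_sum(q),
    # i.e. ones.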
4102 4103 def body(x, _): 4104 zero = constant_op.constant(0, dtype=dtypes.int64) 4105 return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q) 4106 4107 _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.)) 4108 dy_dq, = gradients_impl.gradients(y, q) 4109 self.assertIsNotNone(dy_dq) 4110 with self.cached_session() as sess: 4111 self.evaluate(q.initializer) 4112 self.assertAllClose([1., 1.], self.evaluate(dy_dq)) 4113 4114 @test_util.run_v1_only("b/120545219") 4115 def testIssue16504(self): 4116 c = constant_op.constant(np.arange(100), dtype=dtypes.float32) 4117 w = variables.Variable( 4118 initial_value=np.ones(100), dtype=dtypes.float32) / 100 4119 k = variables.Variable(0, dtype=dtypes.int32) 4120 chg_w = constant_op.constant(np.inf, dtype=dtypes.float32) 4121 4122 def cond(k, _, chg_w): 4123 return math_ops.logical_and(k < 10, chg_w > 1e-3) 4124 4125 def body(k, w, chg_w): 4126 grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w) 4127 w_n = w * math_ops.exp(-0.1 * grad) 4128 w_n /= math_ops.reduce_sum(w_n) 4129 chg_w = ( 4130 math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum( 4131 math_ops.abs(w))) 4132 return k + 1, w_n, chg_w 4133 4134 _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w]) 4135 grad, = gradients_impl.gradients(w, c) 4136 self.assertIsNotNone(grad) 4137 4138 @test_util.run_v1_only("b/120545219") 4139 def testStopGradMultiFlows(self): 4140 with self.cached_session(): 4141 4142 def body(i, y, r): 4143 x = variable_scope.get_variable( 4144 "x", 4145 shape=(), 4146 dtype=dtypes.float32, 4147 initializer=init_ops.ones_initializer()) 4148 y *= x 4149 return [i + 1, y, r + math_ops.reduce_sum(y)] 4150 4151 i0 = constant_op.constant(0) 4152 y0 = array_ops.ones(5) 4153 r0 = constant_op.constant(0.0) 4154 cond = lambda i, y, r: i < 1 4155 _, _, r = control_flow_ops.while_loop( 4156 cond, body, [i0, y0, r0], back_prop=True) 4157 4158 vars_ = variables.global_variables() 4159 grads = linalg_ops.norm(gradients_impl.gradients(r, vars_)[0]) 4160 z = math_ops.add(r, array_ops.stop_gradient(math_ops.reduce_sum(grads))) 4161 result = gradients_impl.gradients(z, vars_)[0] 4162 self.evaluate(variables.global_variables_initializer()) 4163 self.assertEqual(5.0, self.evaluate(result)) 4164 4165 @test_util.run_v1_only("b/120545219") 4166 def testOneValueCond(self): 4167 4168 with self.cached_session(): 4169 c = array_ops.placeholder(dtypes.int32, shape=[]) 4170 one = ops.convert_to_tensor(1, name="one") 4171 two = ops.convert_to_tensor(2, name="two") 4172 p = math_ops.greater_equal(c, 1) 4173 i = control_flow_ops.cond(p, lambda: one, lambda: two) 4174 self.assertTrue(isinstance(i, ops.Tensor)) 4175 4176 # True case: c = 2 is >= 1 4177 self.assertEqual([1], i.eval(feed_dict={c: 2})) 4178 4179 # False case: c = 0 is not >= 1 4180 self.assertEqual([2], i.eval(feed_dict={c: 0})) 4181 4182 @test_util.run_deprecated_v1 4183 def testExampleCond(self): 4184 4185 with self.cached_session(): 4186 x = ops.convert_to_tensor([-2.0, 2.0], name="x") 4187 d = array_ops.placeholder(dtypes.int32, shape=[]) 4188 4189 def l2(): 4190 return math_ops.sqrt(math_ops.reduce_sum(math_ops.square(x))) 4191 4192 def l1(): 4193 return math_ops.reduce_sum(math_ops.abs(x)) 4194 4195 i = control_flow_ops.cond(math_ops.equal(d, 2), l2, l1) 4196 self.assertAllClose(4.0, i.eval(feed_dict={d: 1})) 4197 self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2})) 4198 4199 @test_util.run_v1_only("b/120545219") 4200 def testCase(self): 4201 with 
self.cached_session(): 4202 x = constant_op.constant(1) 4203 y = constant_op.constant(2) 4204 z = constant_op.constant(3) 4205 f1 = lambda: constant_op.constant(17) 4206 f2 = lambda: constant_op.constant(23) 4207 f3 = lambda: constant_op.constant(-1) 4208 4209 r1 = control_flow_ops.case( 4210 { 4211 x < y: f1, 4212 x > z: f2 4213 }, default=f3, exclusive=True) 4214 self.assertAllEqual(r1, 17) 4215 4216 r2 = control_flow_ops.case([(y > z, f1), (y > x, f2)], default=f3) 4217 self.assertAllEqual(r2, 23) 4218 4219 # Duplicate events can happen, first one is selected 4220 r3 = control_flow_ops.case([(x < y, f1), (x < y, f2)], default=f3) 4221 self.assertAllEqual(r3, 17) 4222 4223 # Duplicate events cause an error if exclusive = True 4224 r4 = control_flow_ops.case( 4225 [(x < y, f1), (x < y, f2)], default=f3, exclusive=True) 4226 with self.assertRaisesOpError("Input error:"): 4227 self.evaluate(r4) 4228 4229 # Check that the default is called if none of the others are 4230 r5 = control_flow_ops.case({x > y: f1}, default=f3) 4231 self.assertAllEqual(r5, -1) 4232 4233 ran_once = [False, False, False] 4234 4235 def break_run_twice(ix): 4236 4237 def _break(): 4238 ran_once[ix] = True 4239 return constant_op.constant(ix) 4240 4241 return _break 4242 4243 # Should not fail - each conditional gets called exactly once 4244 # except default. Default gets called twice: once to create an 4245 # empty output and once for the actual cond switch. 4246 r6 = control_flow_ops.case( 4247 [(x < y, break_run_twice(0)), (x > y, break_run_twice(1))], 4248 default=lambda: constant_op.constant(2)) 4249 4250 self.assertAllEqual(r6, 0) 4251 4252 @test_util.run_v1_only("b/120545219") 4253 def testCaseSideEffects(self): 4254 with self.cached_session() as sess: 4255 v0 = variables.Variable(-1) 4256 v1 = variables.Variable(-1) 4257 v2 = variables.Variable(-1) 4258 4259 a = lambda: control_flow_ops.with_dependencies([state_ops.assign(v0, 0)], 0) 4260 b = lambda: control_flow_ops.with_dependencies([state_ops.assign(v1, 1)], 1) 4261 c = lambda: control_flow_ops.with_dependencies([state_ops.assign(v2, 2)], 2) 4262 4263 x = constant_op.constant(1) 4264 y = constant_op.constant(2) 4265 4266 r0 = control_flow_ops.case( 4267 ((x < y, a), (x > y, b)), default=c, exclusive=True) 4268 r1 = control_flow_ops.case( 4269 ((x > y, a), (x < y, b)), default=c, exclusive=True) 4270 r2 = control_flow_ops.case( 4271 ((x > y, a), (x > y, b)), default=c, exclusive=True) 4272 4273 self.evaluate(variables.global_variables_initializer()) 4274 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3) 4275 self.assertEqual(2, self.evaluate(r2)) 4276 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, -1, 2]) 4277 4278 self.evaluate(variables.global_variables_initializer()) 4279 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3) 4280 self.assertEqual(1, self.evaluate(r1)) 4281 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, 1, -1]) 4282 4283 self.evaluate(variables.global_variables_initializer()) 4284 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3) 4285 self.assertEqual(0, self.evaluate(r0)) 4286 self.assertAllEqual(self.evaluate([v0, v1, v2]), [0, -1, -1]) 4287 4288 @test_util.disable_control_flow_v2("b/113324949 (ref vars)") 4289 @test_util.run_v1_only("b/120545219") 4290 def testOneOpCond(self): 4291 with self.cached_session(): 4292 v = variables.Variable(0) 4293 c = ops.convert_to_tensor(0) 4294 one = ops.convert_to_tensor(1) 4295 two = ops.convert_to_tensor(2) 4296 p = math_ops.greater_equal(c, 1) 4297 4298 def a(): 
4299 return state_ops.assign(v, one) 4300 4301 def b(): 4302 return state_ops.assign(v, two) 4303 4304 i = control_flow_ops.cond(p, a, b) 4305 self.assertTrue(isinstance(i, ops.Tensor)) 4306 self.evaluate(variables.global_variables_initializer()) 4307 4308 self.assertEqual(0, self.evaluate(v)) 4309 4310 # True case: c = 2 is >= 1, v is set to 1. 4311 self.assertEqual(1, i.eval(feed_dict={c.name: 2})) 4312 self.assertEqual(1, self.evaluate(v)) 4313 4314 # False case: c = 0 is not >= 1, v is set to 2. 4315 self.assertEqual(2, i.eval(feed_dict={c.name: 0})) 4316 self.assertEqual(2, self.evaluate(v)) 4317 4318 @test_util.run_v1_only("b/120545219") 4319 def testWithOpsDependencies(self): 4320 with self.cached_session() as sess: 4321 v = variables.VariableV1(0.0) 4322 c = constant_op.constant(10) 4323 4324 # Fetching v directly will result in an uninitialized error 4325 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4326 self.evaluate([c, v]) 4327 4328 # Use a control dependency to ensure init_variable is run 4329 # while asking for c 4330 real_v = control_flow_ops.with_dependencies( 4331 name="real_tensor", 4332 output_tensor=v._ref(), # pylint: disable=protected-access 4333 dependencies=[v.initializer]) 4334 c_val, real_v_val = self.evaluate([c, real_v]) 4335 4336 # Ensure the result of 'real_c' is the same as 'c' 4337 self.assertAllEqual(10, c_val) 4338 4339 # Ensure that 'v' is initialized 4340 self.assertAllClose(0.0, real_v_val) 4341 4342 @test_util.run_v1_only("b/120545219") 4343 def testWithTensorDependencies(self): 4344 with self.cached_session(): 4345 v = variables.VariableV1(0.0) 4346 c1 = constant_op.constant(10) 4347 c2 = constant_op.constant(20) 4348 4349 # c1_with_init_v depends on the init op for v 4350 c1_with_init_v = control_flow_ops.with_dependencies( 4351 name="c1_with_init_v", output_tensor=c1, dependencies=[v.initializer]) 4352 # c2_with_c1 depends on the value of c1_with_init_v 4353 c2_with_c1_dep = control_flow_ops.with_dependencies( 4354 name="c2_with_c1_dep", 4355 output_tensor=c2, 4356 dependencies=[c1_with_init_v]) 4357 4358 # Fetching v directly will result in an uninitialized error 4359 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4360 self.evaluate(v) 4361 4362 # Get the value of 'c2_with_c1_dep', which should cause 'v' 4363 # to be initialized. 4364 self.assertAllEqual(20, self.evaluate(c2_with_c1_dep)) 4365 4366 # Ensure that 'v' is initialized 4367 self.assertAllClose(0.0, self.evaluate(v)) 4368 4369 @test_util.run_v1_only("b/120545219") 4370 def testWithIndexedSlicesDependencies(self): 4371 with self.cached_session(): 4372 v = variables.VariableV1( 4373 np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32)) 4374 v_at_1 = ops.IndexedSlices(v, constant_op.constant([1])) 4375 gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices) 4376 v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer], 4377 v_at_1) 4378 gather_v_at_1_after_init = array_ops.gather(v_at_1_after_init.values, 4379 v_at_1_after_init.indices) 4380 4381 # Fetching gather_v_at_1 will result in an uninitialized error 4382 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4383 self.evaluate(gather_v_at_1) 4384 4385 # Getting gather_v_at_1_after_init will work, and initialize v. 
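      # (the dependency on v.initializer forces v to be initialized before the
      # values are gathered, so row 1 of the freshly initialized v is returned)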
4386 self.assertAllEqual([[10.0, 11.0]], 4387 self.evaluate(gather_v_at_1_after_init)) 4388 4389 # Double check that 'v' is initialized 4390 self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]], 4391 self.evaluate(v)) 4392 4393 def testDependenciesDevice(self): 4394 with ops.Graph().as_default(): 4395 # device set on tensor => same device on dep. 4396 with ops.device("/job:ps"): 4397 vd = variables.VariableV1([0.0]) 4398 with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd) 4399 self.assertTrue("/job:ps" in with_vd_dep.device) 4400 4401 # No device set on tensor => no device on dep. 4402 vnod = variables.VariableV1([0.0]) 4403 with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer], 4404 vnod) 4405 self.assertDeviceEqual(None, with_vnod_dep.device) 4406 4407 # device set on tensor, default device on graph => default device on dep. 4408 vdef = variables.VariableV1([0.0], name="vdef") 4409 with ops.device("/job:worker/device:GPU:1"): 4410 with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer], 4411 vdef) 4412 # The device is empty, but the colocation constraint is set. 4413 self.assertDeviceEqual("", with_vdef_dep.device) 4414 self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups()) 4415 4416 @test_util.run_v1_only("b/120545219") 4417 def testGroup(self): 4418 with self.cached_session() as sess: 4419 v1 = variables.VariableV1([0.0]) 4420 v2 = variables.VariableV1([1.0]) 4421 4422 # Group init1 and init2 and run. 4423 init = control_flow_ops.group(v1.initializer, v2.initializer) 4424 # Fetching v1 directly will result in an uninitialized error 4425 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4426 self.evaluate(v1) 4427 4428 # Runs "init" before fetching v1 and v2. 4429 init.run() 4430 v1_val, v2_val = self.evaluate([v1, v2]) 4431 4432 # Ensure that v1 and v2 are initialized 4433 self.assertAllClose([0.0], v1_val) 4434 self.assertAllClose([1.0], v2_val) 4435 4436 @test_util.run_v1_only("b/120545219") 4437 def testGroupEmpty(self): 4438 op = control_flow_ops.group() 4439 self.assertEqual(op.type, "NoOp") 4440 self.assertEqual(op.control_inputs, []) 4441 4442 @test_util.run_deprecated_v1 4443 def testMergeShapes(self): 4444 # All inputs unknown. 4445 p1 = array_ops.placeholder(dtypes.float32) 4446 p2 = array_ops.placeholder(dtypes.float32) 4447 p3 = array_ops.placeholder(dtypes.float32) 4448 m, index = control_flow_ops.merge([p1, p2, p3]) 4449 self.assertIs(None, m.get_shape().ndims) 4450 self.assertEqual([], index.get_shape()) 4451 4452 # All inputs known with different ranks. 4453 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4454 p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2, 3]) 4455 m, index = control_flow_ops.merge([p1, p2]) 4456 self.assertIs(None, m.get_shape().ndims) 4457 self.assertEqual([], index.get_shape()) 4458 4459 # All inputs known with some dimensions different. 
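    # Dimensions that differ across the merge inputs are relaxed to None in
    # the merged shape, as checked below.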
4460 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4461 p2 = array_ops.placeholder(dtypes.float32, shape=[2, 1]) 4462 m, index = control_flow_ops.merge([p1, p2]) 4463 self.assertEqual([None, None], m.get_shape().as_list()) 4464 self.assertEqual([], index.get_shape()) 4465 4466 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4467 p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 4468 m, index = control_flow_ops.merge([p1, p2]) 4469 self.assertEqual([None, 2], m.get_shape().as_list()) 4470 self.assertEqual([], index.get_shape()) 4471 4472 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4473 p2 = array_ops.placeholder(dtypes.float32, shape=[2, 2]) 4474 m, index = control_flow_ops.merge([p1, p2]) 4475 self.assertEqual([None, 2], m.get_shape().as_list()) 4476 self.assertEqual([], index.get_shape()) 4477 4478 # All inputs known with same dimensions. 4479 p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4480 p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2]) 4481 m, index = control_flow_ops.merge([p1, p2]) 4482 self.assertEqual([1, 2], m.get_shape().as_list()) 4483 self.assertEqual([], index.get_shape()) 4484 4485 p1 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 4486 p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 4487 m, index = control_flow_ops.merge([p1, p2]) 4488 self.assertEqual([None, 2], m.get_shape().as_list()) 4489 self.assertEqual([], index.get_shape()) 4490 4491 p1 = array_ops.placeholder(dtypes.float32, shape=[None, None]) 4492 p2 = array_ops.placeholder(dtypes.float32, shape=[None, None]) 4493 m, index = control_flow_ops.merge([p1, p2]) 4494 self.assertEqual([None, None], m.get_shape().as_list()) 4495 self.assertEqual([], index.get_shape()) 4496 4497 @test_util.run_v1_only("b/120545219") 4498 def testRefSelect(self): 4499 index = array_ops.placeholder(dtypes.int32) 4500 4501 # All inputs unknown. 4502 p1 = array_ops.placeholder(dtypes.float32) 4503 p2 = array_ops.placeholder(dtypes.float32) 4504 p3 = array_ops.placeholder(dtypes.float32) 4505 v1 = variables.VariableV1(p1, validate_shape=False) 4506 v2 = variables.VariableV1(p2, validate_shape=False) 4507 v3 = variables.VariableV1(p3, validate_shape=False) 4508 self.assertIs(None, v1.get_shape().ndims) 4509 s = control_flow_ops.ref_select(index, [v1, v2, v3]) 4510 self.assertIs(None, s.get_shape().ndims) 4511 4512 # All inputs known but different. 4513 v1 = variables.VariableV1([[1, 2]]) 4514 v2 = variables.VariableV1([[2], [1]]) 4515 s = control_flow_ops.ref_select(index, [v1, v2]) 4516 self.assertIs(None, s.get_shape().ndims) 4517 4518 # All inputs known and same. 4519 v1 = variables.VariableV1([[1, 2]]) 4520 v2 = variables.VariableV1([[1, 2]]) 4521 s = control_flow_ops.ref_select(index, [v1, v2]) 4522 self.assertEqual([1, 2], s.get_shape()) 4523 4524 # Possibly the same but not guaranteed. 
4525 v1 = variables.VariableV1([[1., 2.]]) 4526 p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 4527 v2 = variables.VariableV1(p2, validate_shape=False) 4528 s = control_flow_ops.ref_select(index, [v1, v2]) 4529 self.assertEqual(None, s.get_shape()) 4530 4531 @test_util.run_deprecated_v1 4532 def testRunLoopTensor(self): 4533 with self.cached_session() as sess: 4534 tensor_list = [] 4535 4536 def condition(t): 4537 return t < constant_op.constant(5) 4538 4539 def body(_): 4540 tensor_list.append(constant_op.constant(5)) 4541 return constant_op.constant(10) 4542 4543 result = control_flow_ops.while_loop(condition, body, 4544 [constant_op.constant(4)]) 4545 self.assertEqual(10, self.evaluate(result)) 4546 4547 # Ensure that we cannot run a tensor that escapes the loop body 4548 # accidentally. 4549 with self.assertRaises(ValueError): 4550 sess.run(tensor_list[0]) 4551 4552 @test_util.run_v1_only("b/120545219") 4553 def testWhilePyFuncBasic(self): 4554 4555 def func(x): 4556 return np.square(x) 4557 4558 with self.cached_session(): 4559 r = control_flow_ops.while_loop( 4560 lambda i, v: i < 4, 4561 lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]], 4562 [constant_op.constant(0), constant_op.constant(2.0, dtypes.float32)], 4563 [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()]) 4564 self.assertEqual(self.evaluate(r[1]), 65536.0) 4565 4566 @test_util.run_v1_only("b/120545219") 4567 def testWhileFuncBasic(self): 4568 4569 @function.Defun(dtypes.float32) 4570 def func(x): 4571 return math_ops.square(math_ops.square(x)) 4572 4573 with self.cached_session(): 4574 x = constant_op.constant(2.0, dtypes.float32) 4575 r = control_flow_ops.while_loop( 4576 lambda i, v: i < 2, lambda i, v: [i + 1, func(v)], 4577 [constant_op.constant(0), x], 4578 [tensor_shape.unknown_shape(), 4579 tensor_shape.unknown_shape()]) 4580 grad = gradients_impl.gradients(r, x)[0] 4581 self.assertEqual(self.evaluate(r[1]), 65536.0) 4582 self.assertEqual(self.evaluate(grad), 524288.0) 4583 # while_v2 does not have stacks. 
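      # (v1 control flow saves per-iteration forward values on stacks for the
      # backward pass, so a single StackV2 op is expected here)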
4584 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: 4585 self.assertEqual( 4586 len([op for op in x.graph.get_operations() if op.type == "StackV2" 4587 ]), 1) 4588 4589 4590 @test_util.run_v1_only("b/120545219") 4591 def testQIntSwitchMerge(self): 4592 with self.cached_session(force_gpu=test.is_gpu_available()) as sess: 4593 constant_qint = constant_op.constant(np.array([42]), dtypes.qint8) 4594 cond = constant_op.constant(True, dtypes.bool) 4595 v_f, v_t = control_flow_ops.switch(constant_qint, cond) 4596 result = control_flow_ops.merge([v_f, v_t]) 4597 self.evaluate(result) 4598 4599 @test_util.run_v1_only("b/120545219") 4600 def testQIntRefSwitchMerge(self): 4601 with self.cached_session(use_gpu=test.is_gpu_available()) as sess: 4602 var_qint = gen_state_ops.variable( 4603 shape=[1], dtype=dtypes.qint8, name="v", container="", shared_name="") 4604 assign_op = state_ops.assign( 4605 var_qint, constant_op.constant(np.array([42]), dtypes.qint8)) 4606 self.evaluate(assign_op) 4607 4608 cond = constant_op.constant(True, dtypes.bool) 4609 v_f, v_t = control_flow_ops.ref_switch(var_qint, cond) 4610 result = control_flow_ops.ref_merge([v_f, v_t]) 4611 self.evaluate(result) 4612 4613 @test_util.run_v1_only("b/120545219") 4614 def testUInt64SwitchMerge(self): 4615 with self.cached_session(force_gpu=test.is_gpu_available()) as sess: 4616 constant_uint64 = constant_op.constant(np.array([42]), dtypes.uint64) 4617 cond = constant_op.constant(True, dtypes.bool) 4618 v_f, v_t = control_flow_ops.switch(constant_uint64, cond) 4619 result = control_flow_ops.merge([v_f, v_t]) 4620 self.evaluate(result) 4621 4622 def testSwitchEagerMode(self): 4623 if not context.executing_eagerly(): 4624 return 4625 input_data = [1, 2, 3, 4] 4626 vf, vt = control_flow_ops.switch(input_data, False) 4627 self.assertAllEqual(vf, input_data) 4628 self.assertAllEqual(vt, []) 4629 4630 @test_util.run_deprecated_v1 4631 def testQIntArgAndRet(self): 4632 4633 @function.Defun(dtypes.qint8) 4634 def func(x): 4635 return x 4636 4637 with self.cached_session(force_gpu=test.is_gpu_available()) as sess: 4638 qint = constant_op.constant(np.array([42]), dtypes.qint8) 4639 result = func(qint) 4640 self.evaluate(result) 4641 4642 def testSparseIdentity(self): 4643 st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10]) 4644 st2 = control_flow_ops._Identity(st1) 4645 self.assertAllEqual(st1.indices, st2.indices) 4646 self.assertAllEqual(st1.values, st2.values) 4647 self.assertAllEqual(st1.dense_shape, st2.dense_shape) 4648 4649 def testSparseEnterExit(self): 4650 st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10]) 4651 st2 = control_flow_ops._Enter(st1, "foo_1") 4652 st3 = control_flow_ops.exit(st2) 4653 self.assertAllEqual(st1.indices, st3.indices) 4654 self.assertAllEqual(st1.values, st3.values) 4655 self.assertAllEqual(st1.dense_shape, st3.dense_shape) 4656 4657 def _buildWhileWithShapeInvariants(self, shape_invariants): 4658 r = constant_op.constant([1, 2]) 4659 4660 def cond(_): 4661 return False 4662 4663 def body(_): 4664 return constant_op.constant([1]) 4665 4666 return control_flow_ops.while_loop( 4667 cond, body, [r], shape_invariants=shape_invariants) 4668 4669 def testWhileOutputShapeWithShapeInvariantsUnknownRank(self): 4670 @def_function.function 4671 def runTest(): 4672 while_output = self._buildWhileWithShapeInvariants( 4673 [tensor_shape.TensorShape(None)]) 4674 self.assertIsNone(while_output.shape.rank) 4675 runTest() 4676 4677 def testWhileOutputShapeWithShapeInvariantsPartialShape(self): 4678 
@def_function.function 4679 def runTest(): 4680 while_output = self._buildWhileWithShapeInvariants( 4681 [tensor_shape.TensorShape([None])]) 4682 self.assertAllEqual(while_output.shape.as_list(), [None]) 4683 runTest() 4684 4685 def testFunctionInWhile(self): 4686 4687 @def_function.function 4688 def body(x): 4689 return x + 1 4690 4691 r = control_flow_ops.while_loop(lambda x: x < 5, body, [0]) 4692 self.assertAllEqual(r, 5.) 4693 4694 4695class ControlFlowContextCheckTest(test.TestCase): 4696 4697 def _getWhileTensor(self): 4698 """Creates and returns a tensor from a while context.""" 4699 tensor = [] 4700 4701 def body(i): 4702 if not tensor: 4703 tensor.append(constant_op.constant(1)) 4704 return i + tensor[0] 4705 4706 control_flow_ops.while_loop(lambda i: i < 10, body, [0]) 4707 return tensor[0] 4708 4709 def _getCondTensor(self): 4710 cond_tensor = [] 4711 4712 def true_fn(): 4713 if not cond_tensor: 4714 cond_tensor.append(constant_op.constant(1)) 4715 return cond_tensor[0] 4716 4717 control_flow_ops.cond( 4718 math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0)) 4719 return cond_tensor[0] 4720 4721 @test_util.run_v1_only("b/120545219") 4722 def testInvalidContext(self): 4723 # Accessing a while loop tensor outside of control flow is illegal. 4724 while_tensor = self._getWhileTensor() 4725 with self.assertRaisesRegex( 4726 ValueError, 4727 "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' " 4728 "is in a while loop. See info log for more details."): 4729 math_ops.add(1, while_tensor) 4730 4731 @test_util.run_v1_only("b/120545219") 4732 def testInvalidContextInCond(self): 4733 # Accessing a while loop tensor in cond is illegal. 4734 while_tensor = self._getWhileTensor() 4735 with self.assertRaisesRegex( 4736 ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because " 4737 "'while/Const_1' is in a while loop. See info log for more details."): 4738 # TODO(skyewm): this passes if we return while_tensor directly instead 4739 # of using it as input to another op. 4740 control_flow_ops.cond( 4741 math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor), 4742 lambda: constant_op.constant(0)) 4743 4744 @test_util.run_v1_only("b/120545219") 4745 def testInvalidContextInWhile(self): 4746 # Accessing a while loop tensor in a different while loop is illegal. 4747 while_tensor = self._getWhileTensor() 4748 with self.assertRaisesRegex( 4749 ValueError, 4750 "Cannot use 'while/Const_1' as input to 'while_1/Add' because they are " 4751 "in different while loops. See info log for more details."): 4752 control_flow_ops.while_loop(lambda i: i < 10, 4753 lambda x: math_ops.add(1, while_tensor), [0]) 4754 4755 with self.assertRaisesRegex( 4756 ValueError, 4757 "Cannot use 'while/Const_1' as input to 'while_2/NextIteration' " 4758 "because they are in different while loops. See info log for more " 4759 "details."): 4760 control_flow_ops.while_loop(lambda i: i < 10, lambda i: while_tensor, [0]) 4761 4762 def testValidCondContext(self): 4763 # Accessing a tensor from a cond context is OK (although dangerous). 4764 cond_tensor = self._getCondTensor() 4765 math_ops.add(1, cond_tensor) 4766 4767 def testValidCondContextBranches(self): 4768 # Accessing a tensor from a cond context from the other branch's cond 4769 # context is OK (although dangerous). 
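    # branch_fn below creates its constant only on the first call, so the
    # second branch reuses the tensor created while building the first one.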
    cond_tensor = []

    def branch_fn():
      if not cond_tensor:
        cond_tensor.append(constant_op.constant(1))
      return cond_tensor[0]

    control_flow_ops.cond(math_ops.less(1, 2), branch_fn, branch_fn)

  @test_util.run_v1_only("b/120545219")
  def testValidWhileContext(self):
    # Accessing a tensor in a nested while is OK.
    def body(_):
      c = constant_op.constant(1)
      return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + c, [0])

    control_flow_ops.while_loop(lambda i: i < 5, body, [0])

  @test_util.run_v1_only("b/120545219")
  def testValidNestedContexts(self):
    # Accessing a tensor from a cond context in a while context, all inside an
    # outer while context, is OK.
    def body(_):
      cond_tensor = self._getCondTensor()
      # Create another cond containing the while loop for good measure
      return control_flow_ops.cond(
          math_ops.less(1, 2),
          lambda: control_flow_ops.while_loop(lambda i: i < 3,
                                              lambda i: i + cond_tensor, [0]),
          lambda: constant_op.constant(0))

    control_flow_ops.while_loop(lambda i: i < 5, body, [0])

  @test_util.run_v1_only("b/120545219")
  def testInvalidNestedContexts(self):
    # Accessing a tensor from a while context in a different while context, all
    # inside a cond context, is illegal.
    def true_fn():
      while_tensor = self._getWhileTensor()
      return control_flow_ops.while_loop(lambda i: i < 3,
                                         lambda i: i + while_tensor, [0])

    with self.assertRaisesRegex(
        ValueError,
        "Cannot use 'cond/while/Const_1' as input to 'cond/while_1/add' because"
        " they are in different while loops. See info log for more details."):
      control_flow_ops.cond(
          math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))


class TupleTest(test.TestCase):

  @test_util.run_v1_only("b/120545219")
  def testTensors(self):
    for v1_first in [True, False]:
      with self.cached_session():
        v1 = variables.VariableV1([1.0])
        add1 = math_ops.add(
            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
            2.0)
        v2 = variables.VariableV1([10.0])
        add2 = math_ops.add(
            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
            20.0)
        t1, _, t2 = control_flow_ops.tuple([add1, None, add2])

        # v1 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v1)

        # v2 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v2)

        if v1_first:
          # Getting t1 initializes v2.
          self.assertAllClose([3.0], self.evaluate(t1))
          self.assertAllClose([10.0], self.evaluate(v2))
        else:
          # Getting t2 initializes v1.
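          # (control_flow_ops.tuple makes every output wait on all inputs, so
          # fetching t2 also runs v1's initializer via add1's dependency.)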
          self.assertAllClose([30.0], self.evaluate(t2))
          self.assertAllClose([1.0], self.evaluate(v1))

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlices(self):
    for v1_first in [True, False]:
      with self.cached_session():
        v1 = variables.VariableV1(
            np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
                np.float32))
        v1_at_1 = ops.IndexedSlices(
            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
            constant_op.constant([1]))

        v2 = variables.VariableV1(
            np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
                np.float32))
        v2_at_1 = ops.IndexedSlices(
            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
            constant_op.constant([1]))

        st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])
        g1 = array_ops.gather(st1.values, st1.indices)
        g2 = array_ops.gather(st2.values, st2.indices)

        # v1 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v1)

        # v2 is not initialized.
        with self.assertRaisesOpError("Attempting to use uninitialized value"):
          self.evaluate(v2)

        if v1_first:
          # Getting g1 initializes v2.
          self.assertAllClose([[10.0, 11.0]], self.evaluate(g1))
          self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]],
                              self.evaluate(v2))
        else:
          # Getting g2 initializes v1.
          self.assertAllClose([[10.1, 11.1]], self.evaluate(g2))
          self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
                              self.evaluate(v1))

  def testAcceptTensorsAsControlInputs(self):
    with self.cached_session():
      var = variables.VariableV1(0)
      assign = state_ops.assign(var, 1)
      t, = control_flow_ops.tuple(
          [constant_op.constant(0)], control_inputs=[assign])

      # Should trigger the assign.
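      # The assign op is passed via control_inputs, so it must run before the
      # tuple's output tensor is produced; fetching t therefore updates var to
      # 1 as a side effect.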
      self.evaluate(t)

      self.assertEqual(1, self.evaluate(var))


class AssertTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testGuardedAssertDoesNotCopyWhenTrue(self):
    if test_util.is_gpu_available():
      self.skipTest("b/128646478 fails in opensource")

    with self.session() as sess:
      with ops.device(test.gpu_device_name()):
        value = constant_op.constant(1.0)
      with ops.device("/cpu:0"):
        true = constant_op.constant(True)
        guarded_assert = control_flow_ops.Assert(true, [value], name="guarded")
        unguarded_assert = gen_logging_ops._assert(
            true, [value], name="unguarded")
      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
      guarded_metadata = config_pb2.RunMetadata()
      sess.run(guarded_assert, options=opts, run_metadata=guarded_metadata)
      unguarded_metadata = config_pb2.RunMetadata()
      sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata)
      guarded_nodestat_names = [
          n.node_name
          for d in guarded_metadata.step_stats.dev_stats
          for n in d.node_stats
      ]
      unguarded_nodestat_names = [
          n.node_name
          for d in unguarded_metadata.step_stats.dev_stats
          for n in d.node_stats
      ]
      guarded_memcpy_nodestat_names = [
          n for n in guarded_nodestat_names if "MEMCPYDtoH" in n
      ]
      unguarded_memcpy_nodestat_names = [
          n for n in unguarded_nodestat_names if "MEMCPYDtoH" in n
      ]
      if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
        # A copy was performed for the unguarded assert
        self.assertLess(0, len(unguarded_memcpy_nodestat_names),
                        str(unguarded_nodestat_names))
      # No copy was performed for the guarded assert
      self.assertEqual([], guarded_memcpy_nodestat_names)


class WhileOpBenchmark(test.Benchmark):
  """Evaluate the performance of while_loop op."""

  def _getInitVariables(self):
    batch_size = 10
    image_size = 256
    kernel_size = 3
    depth = 16

    init_step = constant_op.constant(-1)
    image = variable_scope.get_variable(
        "image",
        initializer=random_ops.random_normal(
            [batch_size, image_size, image_size, depth],
            dtype=dtypes.float32,
            stddev=1e-1))
    kernel = variable_scope.get_variable(
        "weights",
        initializer=random_ops.truncated_normal(
            [kernel_size, kernel_size, depth, depth],
            dtype=dtypes.float32,
            stddev=1e-1))
    return init_step, image, kernel

  def _runOneBenchmark(self,
                       default_device,
                       num_iters=10,
                       static_unroll=False,
                       steps=10):
    """Evaluate the while loop performance.

    Args:
      default_device: The default device to run all ops except the loop_body.
        loop_body is always run on GPU.
      num_iters: Number of iterations to run.
      static_unroll: If true, run unrolled version; otherwise, run while_loop.
      steps: Total number of repeated steps to run the loop.

    Returns:
      The duration of the run in seconds.
    """

    def loop_body(i, x):
      with ops.device("/gpu:0"):
        # Always put loop body on GPU.
        nx = nn_ops.conv2d(
            input=x,
            filter=kernel,
            strides=[1, 1, 1, 1],
            padding="SAME",
            data_format="NHWC",
            name="conv2d")
        ni = math_ops.add(i, 1)
        return ni, nx

    ops.reset_default_graph()
    with session.Session() as sess, ops.device(default_device):
      # Get the initial id i, input x, and kernel.
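      # The same loop_body is used for both modes: static_unroll repeats the
      # conv step as plain Python, while the dynamic path wraps it in a
      # while_loop with swap_memory=True so activations may spill to host RAM.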
      i, x, kernel = self._getInitVariables()
      self.evaluate(variables.global_variables_initializer())

      if static_unroll:
        for _ in xrange(steps):
          i, x = loop_body(i, x)
      else:
        i, x = control_flow_ops.while_loop(
            lambda i, _: i < steps,
            loop_body, [i, x],
            parallel_iterations=steps,
            swap_memory=True)

      r = math_ops.reduce_sum(x)
      dx, dk = gradients_impl.gradients(r, [x, kernel])
      # Use group to avoid fetching back results.
      r = control_flow_ops.group(dx, dk)

      for _ in xrange(3):
        # Exclude warm-up time.
        self.evaluate(r)

      start_time = time.time()
      for _ in xrange(num_iters):
        self.evaluate(r)
      return (time.time() - start_time) / num_iters

  def benchmarkWhileOpCrossDevicePlacement(self):
    iters = 10
    # Run loop body on GPU, but other ops on CPU.
    duration = self._runOneBenchmark("cpu", iters, static_unroll=False)
    self.report_benchmark(
        name="while_op_cross_device", iters=iters, wall_time=duration)

  def benchmarkWhileOpSameDevicePlacement(self):
    iters = 10
    # Run all ops on the same GPU device.
    duration = self._runOneBenchmark("gpu", iters, static_unroll=False)
    self.report_benchmark(
        name="while_op_same_device", iters=iters, wall_time=duration)

  def benchmarkWhileOpUnrollCrossDevicePlacement(self):
    iters = 10
    # Run loop body on GPU, but other ops on CPU.
    duration = self._runOneBenchmark("cpu", iters, static_unroll=True)
    self.report_benchmark(
        name="unroll_cross_device_cpu", iters=iters, wall_time=duration)

  def benchmarkWhileOpUnrollSameDevicePlacement(self):
    iters = 10
    # Run all ops on GPU.
    duration = self._runOneBenchmark("gpu", iters, static_unroll=True)
    self.report_benchmark(
        name="unroll_same_device", iters=iters, wall_time=duration)


@test_util.with_control_flow_v2
class EagerTest(test.TestCase):

  def testCond(self):
    with context.eager_mode():
      pred = math_ops.less(1, 2)
      fn1 = lambda: [constant_op.constant(10)]
      fn2 = lambda: [constant_op.constant(20)]
      r = control_flow_ops.cond(pred, fn1, fn2)

      self.assertAllEqual(r.numpy(), 10)
      self.assertFalse(isinstance(r, list))

  # TODO(b/117279927): Re-enable once msan failure is fixed.
  def DISABLED_testCondInDefun(self):
    with context.eager_mode():

      @eager_function.defun
      def foo(pred):
        # TODO(b/111124878): this only needs to output one element.
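        # cond must return the same structure from both branches, so each
        # branch yields a matching pair of constants here.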
        fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
        fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
        return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)

      r = foo(True)
      self.assertAllEqual(r[0].numpy(), 10)
      self.assertNotIsInstance(r, list)

      r = foo(False)
      self.assertAllEqual(r[0].numpy(), 20)
      self.assertFalse(isinstance(r, list))

  def testWhileLoop(self):
    with context.eager_mode():
      tensor = constant_op.constant([1, 2, 3, 4, 5])
      self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50])

  def testWhileLoopWithMaxIterations(self):
    with context.eager_mode():
      tensor = constant_op.constant([1, 2, 3, 4, 5])
      self.assertAllEqual(
          isum(tensor, maximum_iterations=3).numpy(),
          [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3])

  @test_util.run_v1_only("b/120545219")
  def testWhileWithMaximumIterationsAndSingleArgument(self):
    with context.eager_mode():
      tensor = constant_op.constant(0)
      r = control_flow_ops.while_loop(
          lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1)
      self.assertEqual(1, r.numpy())

  def testWithDependencies(self):
    with context.eager_mode():
      t1 = constant_op.constant(1)
      t2 = constant_op.constant(2)
      t3 = control_flow_ops.with_dependencies(t1, t2)
      self.assertAllEqual(t2.numpy(), t3.numpy())

  def testTuple(self):
    with context.eager_mode():
      t1 = constant_op.constant(1)
      t2 = constant_op.constant(2)
      tup1, tup2 = control_flow_ops.tuple([t1, t2])
      self.assertAllEqual(t1.numpy(), tup1.numpy())
      self.assertAllEqual(t2.numpy(), tup2.numpy())

  @test_util.run_v1_only("b/120545219")
  def testCase(self):
    with context.eager_mode():
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      z = constant_op.constant(3)
      f1 = lambda: constant_op.constant(17)
      f2 = lambda: constant_op.constant(23)
      f3 = lambda: constant_op.constant(-1)

      r1 = control_flow_ops.case(
          [(x < y, f1), (x > z, f2)], default=f3, exclusive=True)
      self.assertAllEqual(r1.numpy(), 17)


if __name__ == "__main__":
  test.main()