• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OiR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=g-long-lambda
17"""Tests for tensorflow.ops.control_flow_ops."""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import collections
24import math
25import re
26import sys
27import time
28
29from absl.testing import parameterized
30import numpy as np
31from six.moves import xrange  # pylint: disable=redefined-builtin
32
33from tensorflow.core.protobuf import config_pb2
34from tensorflow.python import tf2
35from tensorflow.python.client import device_lib
36from tensorflow.python.client import session
37from tensorflow.python.data.experimental.ops import cardinality
38from tensorflow.python.data.ops import dataset_ops
39from tensorflow.python.eager import context
40from tensorflow.python.eager import def_function
41from tensorflow.python.eager import function as eager_function
42from tensorflow.python.eager import wrap_function
43from tensorflow.python.framework import constant_op
44from tensorflow.python.framework import dtypes
45from tensorflow.python.framework import errors_impl
46from tensorflow.python.framework import function
47from tensorflow.python.framework import ops
48from tensorflow.python.framework import sparse_tensor
49from tensorflow.python.framework import tensor_shape
50from tensorflow.python.framework import tensor_spec
51from tensorflow.python.framework import test_util
52from tensorflow.python.ops import array_ops
53from tensorflow.python.ops import control_flow_ops
54from tensorflow.python.ops import control_flow_util
55from tensorflow.python.ops import data_flow_ops
56from tensorflow.python.ops import functional_ops
57from tensorflow.python.ops import gen_array_ops
58from tensorflow.python.ops import gen_control_flow_ops
59from tensorflow.python.ops import gen_data_flow_ops
60from tensorflow.python.ops import gen_logging_ops
61from tensorflow.python.ops import gen_state_ops
62from tensorflow.python.ops import gradient_checker_v2
63from tensorflow.python.ops import gradients_impl
64from tensorflow.python.ops import init_ops
65from tensorflow.python.ops import linalg_ops
66from tensorflow.python.ops import logging_ops
67from tensorflow.python.ops import map_fn
68from tensorflow.python.ops import math_ops
69from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
70from tensorflow.python.ops import nn_ops
71from tensorflow.python.ops import random_ops
72from tensorflow.python.ops import resource_variable_ops
73from tensorflow.python.ops import script_ops
74from tensorflow.python.ops import sparse_ops
75from tensorflow.python.ops import state_ops
76from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
77from tensorflow.python.ops import tensor_array_ops
78from tensorflow.python.ops import variable_scope
79from tensorflow.python.ops import variables
80from tensorflow.python.ops import while_v2  # pylint: disable=unused-import
81# pylint: disable=unused-import
82from tensorflow.python.ops.ragged import ragged_factory_ops
83from tensorflow.python.ops.ragged import ragged_tensor
84import tensorflow.python.ops.tensor_array_grad
85# pylint: enable=unused-import
86from tensorflow.python.platform import test
87from tensorflow.python.training import adam
88from tensorflow.python.training import gradient_descent
89from tensorflow.python.util import nest
90
91
92def check_consumers(graph):
93  """Sanity check on the consumer list of the tensors."""
94
95  consumer_count = {}
96  for op in graph.get_operations():
97    for v in op.inputs:
98      cnt = consumer_count.get(v, 0)
99      consumer_count[v] = cnt + 1
100  for k, v in consumer_count.items():
101    if len(k.consumers()) != v:
102      return False
103  return True
104
105
106def all_fetchables():
107  tensor_names = []
108  graph = ops.get_default_graph()
109  for op in graph.get_operations():
110    for t in op.outputs:
111      if graph.is_fetchable(t):
112        tensor_names.append(t.name)
113  return tensor_names
114
115
116def all_feedables():
117  feedable_tensors = []
118  graph = ops.get_default_graph()
119  for op in graph.get_operations():
120    for t in op.inputs:
121      if graph.is_feedable(t):
122        feedable_tensors.append(t)
123  return feedable_tensors
124
125
126def opt_cfg(do_constant_folding=True):
127  return config_pb2.ConfigProto(
128      allow_soft_placement=True,
129      graph_options=config_pb2.GraphOptions(
130          optimizer_options=config_pb2.OptimizerOptions(
131              opt_level=config_pb2.OptimizerOptions.L1,
132              do_function_inlining=True,
133              do_constant_folding=do_constant_folding)))
134
135
136def isum(s, maximum_iterations=None):
137  i = constant_op.constant(0, name="i")
138  c = lambda i, s: math_ops.less(i, 10)
139  b = lambda i, s: [math_ops.add(i, 1), math_ops.add(i, s)]
140  _, r_s = control_flow_ops.while_loop(
141      c, b, [i, s], maximum_iterations=maximum_iterations)
142  return r_s
143
144
145def enqueue_print_op(s):
146  """Enqueues an op that prints a message to be captured in the test."""
147  return logging_ops.print_v2("ControlFlowOpsTest: " + s)
148
149
150def filter_test_messages(s):
151  """Returns a list of messages printed by enqueue_print_op."""
152  prefix = "ControlFlowOpsTest: "
153  return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)]
154
155
156def tf_function_in_tf2(f):
157  if tf2.enabled():
158    # In TF1 do not wrap with tf.function so that we can test the v1 control
159    # flow code path.
160    return def_function.function(f)
161  return f
162
163
164@test_util.with_control_flow_v2
165class ControlFlowTest(test.TestCase, parameterized.TestCase):
166
167  @test_util.run_v1_only("b/120545219")
168  def testRefIdentity(self):
169    with self.cached_session():
170      v = variables.VariableV1(7)
171
172      v = control_flow_ops._Identity(v)
173      op = state_ops.assign(v, 9)
174      v2 = control_flow_ops.with_dependencies([op], v)
175
176      self.assertTrue(isinstance(v2, ops.Tensor))
177      self.evaluate(variables.global_variables_initializer())
178      self.assertEqual(9, self.evaluate(v2))
179
180  @test_util.run_v1_only("b/120545219")
181  def testRefEnter(self):
182    with self.cached_session():
183      v = variables.VariableV1(7)
184
185      enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
186      nine = constant_op.constant(9)
187      enter_nine = gen_control_flow_ops.enter(nine, "foo_1")
188      op = state_ops.assign(enter_v, enter_nine)
189      v2 = control_flow_ops.with_dependencies([op], enter_v)
190      v3 = control_flow_ops.exit(v2)
191      self.evaluate(variables.global_variables_initializer())
192      self.assertEqual(9, self.evaluate(v3))
193
194  @test_util.run_v1_only("b/120545219")
195  def testRefSwitch(self):
196    with self.cached_session():
197      v = variables.VariableV1(7)
198
199      p = constant_op.constant(True)
200      v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p)  # pylint: disable=protected-access
201      v2 = state_ops.assign(v1[1], 9)
202      self.evaluate(variables.global_variables_initializer())
203      self.assertEqual(9, self.evaluate(v2))
204
205  def testEnterMulExit(self):
206    with self.cached_session():
207      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
208      enter_data = gen_control_flow_ops.enter(data, "foo_1", False)
209      five = constant_op.constant(5)
210      enter_five = gen_control_flow_ops.enter(five, "foo_1", False)
211      mul_op = math_ops.multiply(enter_data, enter_five)
212      exit_op = control_flow_ops.exit(mul_op)
213
214      result = self.evaluate(exit_op)
215    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
216
217  @test_util.run_deprecated_v1
218  def testEnterShapePropagation(self):
219    with self.cached_session():
220      v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)
221
222      # If is_constant=True, the shape information should be propagated.
223      enter_v_constant = gen_control_flow_ops.enter(
224          v, "frame1", is_constant=True)
225      self.assertEqual(enter_v_constant.shape, [2])
226
227      # Otherwise, the shape should be unknown.
228      enter_v_non_constant = gen_control_flow_ops.enter(
229          v, "frame2", is_constant=False)
230      self.assertEqual(enter_v_non_constant.shape, None)
231
232  @test_util.run_v1_only("b/120545219")
233  def testSwitchMergeIndexedSlices(self):
234    with self.cached_session():
235      values = constant_op.constant([1, 2, 3, 4, 5, 6])
236      indices = constant_op.constant([0, 2, 4, 6, 8, 10])
237      data = ops.IndexedSlices(values, indices)
238      pred = ops.convert_to_tensor(True)
239      switch_op = control_flow_ops.switch(data, pred)
240      merge_op = control_flow_ops.merge(switch_op)[0]
241
242      val = merge_op.values
243      ind = merge_op.indices
244    self.assertAllEqual(np.arange(1, 7), val)
245    self.assertAllEqual(np.arange(0, 12, 2), ind)
246
247  @test_util.run_v1_only("b/120545219")
248  def testSwitchDeadBranch(self):
249    with self.cached_session():
250      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
251      ports = ops.convert_to_tensor(True, name="ports")
252      switch_op = control_flow_ops.switch(data, ports)
253      dead_branch = array_ops.identity(switch_op[0])
254
255      with self.assertRaisesWithPredicateMatch(
256          errors_impl.InvalidArgumentError,
257          lambda e: "Retval[0] does not have value" in str(e)):
258        self.evaluate(dead_branch)
259
260  @test_util.run_v1_only("b/120545219")
261  def testSwitchMergeLess(self):
262    with self.cached_session():
263      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
264      zero = ops.convert_to_tensor(0)
265      one = ops.convert_to_tensor(1)
266      less_op = math_ops.less(zero, one)
267      switch_op = control_flow_ops.switch(data, less_op)
268      merge_op = control_flow_ops.merge(switch_op)[0]
269
270      result = self.evaluate(merge_op)
271    self.assertAllEqual(np.arange(1, 7), result)
272
273  @test_util.run_v1_only("b/120545219")
274  def testSwitchMergeAddIdentity(self):
275    with self.cached_session():
276      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
277      ports = ops.convert_to_tensor(False, name="ports")
278      switch_op = control_flow_ops.switch(data, ports)
279      one = constant_op.constant(1)
280      add_op = math_ops.add(switch_op[0], one)
281      id_op = array_ops.identity(switch_op[1])
282      merge_op = control_flow_ops.merge([add_op, id_op])[0]
283
284      result = self.evaluate(merge_op)
285    self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)
286
287  @test_util.run_v1_only("b/120545219")
288  def testSwitchMergeAddMul(self):
289    with self.cached_session():
290      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
291      ports = ops.convert_to_tensor(True, name="ports")
292      switch_op = control_flow_ops.switch(data, ports)
293      one = constant_op.constant(1)
294      add_op = math_ops.add(switch_op[0], one)
295      five = constant_op.constant(5)
296      mul_op = math_ops.multiply(switch_op[1], five)
297      merge_op = control_flow_ops.merge([add_op, mul_op])[0]
298
299      result = self.evaluate(merge_op)
300    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
301
302  @test_util.run_v1_only("b/120545219")
303  def testLoop_false(self):
304    with self.cached_session():
305      false = ops.convert_to_tensor(False)
306      n = constant_op.constant(10)
307
308      enter_false = gen_control_flow_ops.enter(false, "foo_1", False)
309      enter_n = gen_control_flow_ops.enter(n, "foo_1", False)
310
311      merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
312      switch_n = control_flow_ops.switch(merge_n, enter_false)
313      exit_n = control_flow_ops.exit(switch_n[0])
314      next_n = control_flow_ops.next_iteration(switch_n[0])
315      merge_n.op._update_input(1, next_n)
316
317      result = self.evaluate(exit_n)
318    self.assertAllEqual(10, result)
319
320  @test_util.run_deprecated_v1
321  def testLoop_1(self):
322    with self.cached_session():
323      zero = constant_op.constant(0)
324      one = constant_op.constant(1)
325      n = constant_op.constant(10)
326
327      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
328      enter_one = gen_control_flow_ops.enter(one, "foo", True)
329      enter_n = gen_control_flow_ops.enter(n, "foo", True)
330
331      with ops.device(test.gpu_device_name()):
332        merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
333
334      less_op = math_ops.less(merge_i, enter_n)
335      cond_op = control_flow_ops.loop_cond(less_op)
336      switch_i = control_flow_ops.switch(merge_i, cond_op)
337
338      add_i = math_ops.add(switch_i[1], enter_one)
339
340      next_i = control_flow_ops.next_iteration(add_i)
341      merge_i.op._update_input(1, next_i)
342
343      exit_i = control_flow_ops.exit(switch_i[0])
344      result = self.evaluate(exit_i)
345    self.assertAllEqual(10, result)
346
347  @test_util.run_v1_only("b/120545219")
348  def testLoop_2(self):
349    with self.cached_session():
350      zero = constant_op.constant(0)
351      one = constant_op.constant(1)
352      n = constant_op.constant(10)
353
354      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
355      enter_one = gen_control_flow_ops.enter(one, "foo", True)
356      enter_n = gen_control_flow_ops.enter(n, "foo", True)
357
358      merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
359
360      less_op = math_ops.less(merge_i, enter_n)
361      cond_op = control_flow_ops.loop_cond(less_op)
362      switch_i = control_flow_ops.switch(merge_i, cond_op)
363
364      add_i = math_ops.add(switch_i[1], enter_one)
365
366      with ops.device(test.gpu_device_name()):
367        next_i = control_flow_ops.next_iteration(add_i)
368      merge_i.op._update_input(1, next_i)
369
370      exit_i = control_flow_ops.exit(switch_i[0])
371      result = self.evaluate(exit_i)
372    self.assertAllEqual(10, result)
373
374  @test_util.run_v1_only("b/120545219")
375  def testDifferentFrame(self):
376    with self.cached_session():
377      data = array_ops.placeholder(dtypes.float32, shape=[])
378      enter_1 = gen_control_flow_ops.enter(data, "foo_1", False)
379      enter_2 = gen_control_flow_ops.enter(data, "foo_2", False)
380      res = math_ops.add(enter_1, enter_2)
381      with self.assertRaisesOpError("has inputs from different frames"):
382        res.eval(feed_dict={data: 1.0})
383
384  @test_util.run_deprecated_v1
385  def testCondBool(self):
386    values = constant_op.constant(10)
387    fn1 = lambda: math_ops.add(values, 1)
388    fn2 = lambda: math_ops.subtract(values, 1)
389    with self.assertRaisesRegex(TypeError, "must not be a Python bool"):
390      _ = control_flow_ops.cond(False, fn1, fn2)
391
392  @test_util.run_deprecated_v1
393  def testCondInt(self):
394    p = array_ops.placeholder(dtypes.bool, shape=[])
395    v = constant_op.constant(10)
396    fn1 = lambda: math_ops.add(v, 1)
397    fn2 = lambda: math_ops.subtract(v, 1)
398    y = control_flow_ops.cond(p, fn1, fn2)
399    grad = gradients_impl.gradients(y, [v])
400    self.assertAllEqual([None], grad)
401
402  def testCondOutputShape(self):
403    x = constant_op.constant(1.0)
404    b = control_flow_ops.cond(
405        constant_op.constant(True), lambda: math_ops.square(x),
406        lambda: math_ops.subtract(x, 1.))
407    self.assertEqual(b.shape, tensor_shape.TensorShape([]))
408
409  @test_util.run_v1_only("b/120545219")
410  def testFetchable(self):
411    with self.cached_session() as sess:
412      x = array_ops.placeholder(dtypes.float32)
413      control_flow_ops.cond(
414          constant_op.constant(True), lambda: x + 2, lambda: x + 0)
415      graph = ops.get_default_graph()
416      for op in graph.get_operations():
417        for t in op.inputs:
418          if graph.is_fetchable(t.op):
419            sess.run(t, feed_dict={x: 3})
420          else:
421            with self.assertRaisesRegex(ValueError,
422                                        "has been marked as not fetchable"):
423              sess.run(t, feed_dict={x: 3})
424
425  @test_util.disable_control_flow_v2("Not relevant")
426  @test_util.run_v1_only("b/120545219")
427  def testFeedable(self):
428    with self.cached_session() as sess:
429      c = constant_op.constant(2)
430      i0 = constant_op.constant(0)
431      r = control_flow_ops.while_loop(lambda i: i < 1000,
432                                      lambda i: math_ops.square(c) + i, [i0])
433      self.assertEqual(1000, r.eval(feed_dict={i0: 0}))
434      feedable_tensors = all_feedables()
435      for t in feedable_tensors:
436        sess.run(r, feed_dict={t: 3})
437      graph = ops.get_default_graph()
438      for op in graph.get_operations():
439        for t in op.inputs:
440          if t not in feedable_tensors and t.dtype is dtypes.int32:
441            with self.assertRaisesRegex(ValueError, "may not be fed"):
442              sess.run(r, feed_dict={t: 3})
443
444  @test_util.run_v1_only("b/120545219")
445  def testCondIndexedSlices(self):
446    with self.cached_session():
447      values = constant_op.constant([10])
448      indices = constant_op.constant([0])
449      x = ops.IndexedSlices(values, indices)
450      pred = math_ops.less(1, 2)
451      fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices)
452      fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), indices)
453      r = control_flow_ops.cond(pred, fn1, fn2)
454
455      val = r.values
456      ind = r.indices
457    self.assertAllEqual([11], val)
458    self.assertAllEqual([0], ind)
459
460  def testCondMismatchedIndexedSlices(self):
461    @def_function.function
462    def foo():
463      values = constant_op.constant([10])
464      indices = constant_op.constant([0])
465      x = ops.IndexedSlices(values, indices)
466      with self.assertRaisesRegex(TypeError,
467                                  "Cannot reconcile tf.cond 0-th outputs"):
468        control_flow_ops.cond(
469            constant_op.constant(True),
470            lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices),
471            lambda: math_ops.add(x.values, 1), indices)
472    foo()
473
474  def testCondSparseTensor(self):
475    values = constant_op.constant([2.0, 4.0], name="values")
476    indices = constant_op.constant([[0], [3]],
477                                   dtype=dtypes.int64,
478                                   name="indices")
479    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
480    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
481    pred = math_ops.less(1, 2)
482    fn1 = lambda: sparse_tensor.SparseTensor(
483        indices + 1, x.values + 1, dense_shape=shape)
484    fn2 = lambda: sparse_tensor.SparseTensor(
485        indices, x.values - 1, dense_shape=shape)
486    r = control_flow_ops.cond(pred, fn1, fn2)
487    self.assertAllEqual([3.0, 5.0], r.values)
488    self.assertAllEqual([[1], [4]], r.indices)
489    self.assertAllEqual(r.values.get_shape(), (2,))
490
491  def testCondRaggedTensor(self):
492    rt = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
493    pred = math_ops.less(1, 2)
494    fn1 = lambda: array_ops.concat([rt + 2, [[100]]], axis=0)
495    fn2 = lambda: rt[:2] - 2
496    result = control_flow_ops.cond(pred, fn1, fn2)
497    self.assertAllEqual([3, 4, 5, 6, 7, 8, 100], result.values)
498    self.assertAllEqual([0, 2, 3, 6, 7], result.row_splits)
499
500  @test_util.run_v1_only("b/120545219")
501  def testCondResource(self):
502
503    with self.cached_session():
504      rv = resource_variable_ops.ResourceVariable(True)
505      self.evaluate(variables.global_variables_initializer())
506      t = ops.convert_to_tensor(1.0)
507
508      def case():
509        assign = resource_variable_ops.assign_variable_op(rv.handle, False)
510        with ops.control_dependencies([assign]):
511          return array_ops.identity(t)
512
513      self.assertEqual(
514          1.0, self.evaluate(control_flow_ops.cond(rv, case, lambda: t)))
515
516  @test_util.run_deprecated_v1
517  def testCondResourceGradShape(self):
518    rv1 = resource_variable_ops.ResourceVariable([1.0, 2.0])
519    rv2 = resource_variable_ops.ResourceVariable([3.0, 4.0])
520    pred = constant_op.constant(True)
521    result = control_flow_ops.cond(pred, lambda: rv1, lambda: rv2)
522    grads = gradients_impl.gradients(result, [rv1, rv2])
523    self.assertAllEqual(grads[0].shape.as_list(), [2])
524    self.assertAllEqual(grads[1].shape.as_list(), [2])
525
526  @test_util.run_v1_only("b/120545219")
527  def testCondWithTensorArrayGrad(self):
528    with self.cached_session() as sess:
529      with ops.device(test.gpu_device_name()):
530        pred = array_ops.placeholder(dtypes.bool, [])
531        x = constant_op.constant([1.0, 2.0, 3.0])
532        y = control_flow_ops.cond(
533            pred, lambda: map_fn.map_fn(lambda z: z * 2.0, x),
534            lambda: constant_op.constant([1.0, 1.0, 1.0]))
535        g = gradients_impl.gradients(y, x)[0]
536
537      self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
538      self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])
539
540  @test_util.run_v1_only("b/120545219")
541  def testCondIndexedSlicesDifferentTypes(self):
542    with self.cached_session():
543      values = constant_op.constant([10])
544      i_32 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int32)
545      i_64 = ops.convert_to_tensor([0], name="one", dtype=dtypes.int64)
546      x = ops.IndexedSlices(values, i_32)
547      pred = math_ops.less(1, 2)
548      fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), i_32)
549      fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), i_64)
550      r = control_flow_ops.cond(pred, fn1, fn2)
551
552      val = r.values
553      ind = r.indices
554    self.assertAllEqual([11], val)
555    self.assertAllEqual([0], ind)
556    self.assertTrue(ind.dtype == np.int64)
557
558  @test_util.run_v1_only("b/120545219")
559  def testCondColocation(self):
560    with self.session():
561      with ops.device("/cpu:0"):
562        v = variables.Variable(7.0)
563
564      x = constant_op.constant(10.0)
565      pred = math_ops.less(1.0, 2.0)
566      fn1 = lambda: math_ops.add(v, 1.0)
567      fn2 = lambda: math_ops.subtract(x, 1.0)
568      r = control_flow_ops.cond(pred, fn1, fn2)
569
570      for op in x.graph.get_operations():
571        if op.name == "cond/Add/Switch":
572          self.assertDeviceEqual(op.device, "/cpu:0")
573
574  def _testCond_1(self, use_gpu):
575    with self.cached_session(use_gpu=use_gpu):
576      x = constant_op.constant(10)
577      pred = math_ops.less(1, 2)
578      fn1 = lambda: math_ops.add(x, 1)
579      fn2 = lambda: math_ops.subtract(x, 1)
580      r = control_flow_ops.cond(pred, fn1, fn2)
581
582      result = self.evaluate(r)
583    self.assertAllEqual(11, result)
584
585  def testCond_1(self):
586
587    self._testCond_1(use_gpu=False)
588    # TODO(b/116526896): Enable GPU tests.
589    # self._testCond_1(use_gpu=True)
590
591  def testCond_2(self):
592
593    with self.cached_session():
594      x = constant_op.constant(10)
595      r = control_flow_ops.cond(
596          math_ops.less(1, 0), lambda: math_ops.add(x, 1),
597          lambda: math_ops.subtract(x, 1))
598      result = self.evaluate(r)
599    self.assertAllEqual(9, result)
600
601  def testCond_3(self):
602
603    with self.cached_session():
604      x = constant_op.constant(10)
605      pred = math_ops.less(1, 2)
606      fn1 = lambda: math_ops.add(x, 1)
607      fn2 = lambda: math_ops.subtract(x, 1)
608      fn3 = lambda: math_ops.add(control_flow_ops.cond(pred, fn1, fn2), 1)
609      r = control_flow_ops.cond(pred, fn3, fn2)
610
611      result = self.evaluate(r)
612    self.assertAllEqual(12, result)
613
614  @test_util.run_in_graph_and_eager_modes
615  def testCondPruning(self):
616    v1 = variables.Variable(7)
617    v2 = variables.Variable(7)
618    v3 = variables.Variable(7)
619
620    def f():
621      age = constant_op.constant(3)
622      max_age = constant_op.constant(2)
623      pred = math_ops.greater(age, max_age)
624      fn1 = lambda: [state_ops.assign(v1, 1).op, state_ops.assign(v2, 2).op]
625      fn2 = lambda: [state_ops.assign(v3, 3).op, constant_op.constant(10).op]
626      r = control_flow_ops.cond(pred, fn1, fn2)
627      self.assertEqual(len(r), 2)
628      return r[1]
629
630    f_defun = eager_function.defun(f)
631
632    if not context.executing_eagerly():
633      with self.cached_session():
634        self.evaluate(variables.global_variables_initializer())
635        result = self.evaluate(f())
636        self.assertEqual(True, result)
637        # Only second cond result was fetched, so v1 assign shouldn't run.
638        self.assertEqual(7, self.evaluate(v1))
639        self.assertEqual(2, self.evaluate(v2))
640        self.assertEqual(7, self.evaluate(v3))
641
642    result = f_defun()
643    self.assertEqual(True, self.evaluate(result))
644    # Both v1 and v2 branch assignments should be run in defun.
645    self.assertEqual(1, self.evaluate(v1))
646    self.assertEqual(2, self.evaluate(v2))
647    self.assertEqual(7, self.evaluate(v3))
648
649  def testCond_5(self):
650    with self.cached_session():
651      alive = constant_op.constant(True, name="alive")
652      count = constant_op.constant(0, name="count")
653
654      def body(i):
655        return control_flow_ops.cond(
656            alive, lambda: [math_ops.less(i, 3), math_ops.add(count, 1)],
657            lambda: [alive, count])
658
659      for i in range(10):
660        alive, count = body(i)
661      self.assertAllEqual(4, self.evaluate(count))
662
663  @test_util.run_v1_only("b/120545219")
664  def testCond_6(self):
665    with self.cached_session():
666      v1 = variables.Variable([7])
667
668      age = constant_op.constant(3)
669      pred = math_ops.greater(age, 4)
670      fn1 = lambda: age
671      fn2 = lambda: v1
672      r = control_flow_ops.cond(pred, fn1, fn2)
673
674      self.evaluate(variables.global_variables_initializer())
675      result = self.evaluate(r)
676      self.assertAllEqual(np.array([7]), result)
677
678  def testCond_7(self):
679    with self.cached_session() as sess:
680      x = constant_op.constant(10)
681      y = constant_op.constant(200)
682      pred = math_ops.less(1, 2)
683      fn1 = lambda: [math_ops.add(x, 1), math_ops.add(x, 2)]
684      fn2 = lambda: [y, y]
685      r = control_flow_ops.cond(pred, fn1, fn2)
686      self.assertAllEqual([11, 12], self.evaluate(r))
687
688  @parameterized.parameters(dtypes.float32, dtypes.float64)
689  @test_util.run_v1_only("Uses tf.gradients")
690  def testCondResourceGrad(self, dtype):
691    init = constant_op.constant([7.], dtype=dtype)
692    v1 = variables.Variable(init)
693
694    age = constant_op.constant(3., dtype=dtype)
695    pred = math_ops.greater(age, 4.)
696    fn1 = lambda: age
697    fn2 = lambda: v1
698    r = control_flow_ops.cond(pred, fn1, fn2)
699
700    grad = gradients_impl.gradients(r, v1)[0]
701    self.evaluate(variables.global_variables_initializer())
702    self.assertAllEqual(grad, [1.])
703
704  @test_util.run_gpu_only
705  @test_util.run_deprecated_v1
706  def testCond_Device(self):
707    x = constant_op.constant(-10.)
708
709    # True branch function defined outside of device scope
710    def true_fn():
711      return math_ops.exp(x)
712
713    with ops.device("CPU:0"):
714      r = control_flow_ops.cond(
715          constant_op.constant(True), true_fn, lambda: 0.)
716      self.assertIn("cpu", r.device.lower())
717
718    with session.Session() as sess:
719      options = config_pb2.RunOptions(output_partition_graphs=True)
720      run_metadata = config_pb2.RunMetadata()
721      sess.run(r, options=options, run_metadata=run_metadata)
722      # We expect that everything runs on CPU, even if GPU is available.
723      self.assertEqual(len(run_metadata.partition_graphs), 1)
724
725  def _count_matching_switch_nodes_on_device(self, run_metadata, device_str,
726                                             dtype):
727    # Returns the number of Switch nodes with type dtype placed on
728    # `device_str`.
729    device_graphs = [
730        g for g in run_metadata.partition_graphs
731        if device_str in g.node[0].device
732    ]
733    self.assertLen(device_graphs, 1)
734    switch_nodes = [
735        n for n in device_graphs[0].node
736        if n.op == "Switch" and n.attr["T"].type == dtype.as_datatype_enum
737    ]
738    return len(switch_nodes)
739
740  @test_util.run_gpu_only
741  @test_util.run_deprecated_v1
742  def testCondSwitchColocatedWithInputWhenInputExplicitlyPlacedOnCPU(self):
743    x = array_ops.placeholder(dtypes.float32)
744
745    # `arg` is used in the cond then branch so a Switch node is created for it.
746    # We test that the Switch node gets placed on the same device as `arg`.
747    # We force `arg` to be on CPU here.
748    with ops.device("CPU:0"):
749      arg = x + 10.
750
751    def true_fn():
752      with ops.device("CPU:0"):
753        return arg + 1
754
755    r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.)
756
757    with session.Session() as sess:
758      run_metadata = config_pb2.RunMetadata()
759      options = config_pb2.RunOptions(output_partition_graphs=True)
760      sess.run(
761          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
762      self.assertLen(run_metadata.partition_graphs, 2)
763      # Check that the Switch for `arg` gets placed on CPU.
764      self.assertEqual(
765          self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
766                                                      dtypes.float32), 1)
767      self.assertEqual(
768          self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
769                                                      dtypes.float32), 0)
770
771  @test_util.run_gpu_only
772  @test_util.run_deprecated_v1
773  def testCondSwitchColocatedWithInputWhenInputPlacedOnCPU(self):
774    x = array_ops.placeholder(dtypes.float32)
775
776    # `arg` is used in the cond then branch so a Switch node is created for it.
777    # We test that the Switch node gets placed on the same device as `arg`.
778    # Since arg is a dataset (and only has a CPU kernel), it gets placed on CPU
779    # by placer.
780    arg = dataset_ops.Dataset.range(8)
781
782    def true_fn():
783      return cardinality.cardinality(arg)
784
785    r = control_flow_ops.cond(
786        constant_op.constant(True), true_fn,
787        lambda: constant_op.constant(0, dtypes.int64))
788
789    with session.Session() as sess:
790      run_metadata = config_pb2.RunMetadata()
791      options = config_pb2.RunOptions(output_partition_graphs=True)
792      sess.run(
793          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
794      self.assertLen(run_metadata.partition_graphs, 2)
795      # Check that the Switch for `arg` gets placed on CPU.
796      self.assertEqual(
797          self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
798                                                      dtypes.variant), 1)
799      self.assertEqual(
800          self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
801                                                      dtypes.variant), 0)
802
803  @test_util.run_gpu_only
804  @test_util.run_deprecated_v1
805  def testCondSwitchColocatedWithInputWhenInputOnGPU(self):
806    x = array_ops.placeholder(dtypes.float32)
807
808    # `arg` is used in the cond then branch so a Switch node is created for it.
809    # We test that the Switch node gets placed on the same device as `arg`.
810    # Note: `arg` gets placed on GPU by default by the placer.
811    arg = x + 10.
812
813    def true_fn():
814      with ops.device("CPU:0"):
815        return arg + 1
816
817    r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.)
818
819    with session.Session() as sess:
820      run_metadata = config_pb2.RunMetadata()
821      options = config_pb2.RunOptions(output_partition_graphs=True)
822      sess.run(
823          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
824      self.assertEqual(len(run_metadata.partition_graphs), 2)
825      # Check that the Switch for `arg` gets placed on GPU.
826      self.assertEqual(
827          self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
828                                                      dtypes.float32), 0)
829      self.assertEqual(
830          self._count_matching_switch_nodes_on_device(run_metadata, "GPU",
831                                                      dtypes.float32), 1)
832
833  def testCondAccessTrueBranchTensorInFalseBranchRaises(self):
834
835    @def_function.function
836    def f():
837      c = constant_op.constant(1.)
838      inputs = {"c": c}
839
840      def true_fn(inputs):
841        inputs["c"] = array_ops.identity(inputs["c"], name="true_branch")
842        return inputs["c"]
843
844      def false_fn(inputs):
845        return array_ops.identity(inputs["c"])
846
847      pred = constant_op.constant(True)
848      return control_flow_ops.cond(
849          pred, lambda: true_fn(inputs), lambda: false_fn(inputs))
850
851    # This was needed for backwards compatibility with TF2 Estimators which
852    # rely on variable names.
853    prefix = "cond/" if context.executing_eagerly() else ""
854
855    with self.assertRaisesRegex(
856        ValueError,
857        "Tensor %strue_branch:0 in true_fn is accessed from false_fn." %
858        prefix):
859      f()
860
861  def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
862
863    @def_function.function
864    def f():
865      c = constant_op.constant(1.)
866      inputs = {"c": c}
867
868      def br1_fn(inputs):
869        inputs["c"] = array_ops.identity(inputs["c"], name="br1_identity")
870        return inputs["c"]
871
872      def br4_fn(inputs):
873        return array_ops.identity(inputs["c"])
874
875      def other_fn():
876        return array_ops.identity(c)
877
878      return control_flow_ops.switch_case(
879          constant_op.constant(2),
880          [other_fn, lambda: br1_fn(inputs), other_fn, other_fn,
881           lambda: br4_fn(inputs)])
882
883    # This was needed for backwards compatibility with TF2 Estimators which
884    # rely on variable names.
885    prefix = "switch_case/indexed_case/" if context.executing_eagerly() else ""
886    with self.assertRaisesRegex(
887        ValueError, "Tensor %sbr1_identity:0 in branch 1 is "
888        "accessed from branch 4." % prefix):
889      f()
890
891  def testCondListOutput(self):
892    with self.cached_session() as sess:
893      x = constant_op.constant(10)
894      y = constant_op.constant(200)
895      pred = math_ops.less(1, 2)
896      fn1 = lambda: [math_ops.add(x, y), math_ops.add(x, y)]
897      fn2 = lambda: [y, y]
898      r = control_flow_ops.cond(pred, fn1, fn2)
899      test_result = self.evaluate(r)
900      self.assertListEqual([210, 210], test_result)
901
902  def testTupleOutput(self):
903    with self.cached_session() as sess:
904      x = constant_op.constant(10)
905      y = constant_op.constant(200)
906      pred = math_ops.less(1, 2)
907      fn1 = lambda: (math_ops.add(x, y), math_ops.add(x, y))
908      fn2 = lambda: (y, y)
909      r = control_flow_ops.cond(pred, fn1, fn2)
910      test_result = self.evaluate(r)
911      self.assertTupleEqual((210, 210), test_result)
912
913  def testDictOutput(self):
914    with self.cached_session() as sess:
915      x = constant_op.constant(10)
916      y = constant_op.constant(200)
917      pred = math_ops.less(1, 2)
918      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
919      fn2 = lambda: {"a": y, "b": y}
920      r = control_flow_ops.cond(pred, fn1, fn2)
921      test_result = self.evaluate(r)
922      self.assertDictEqual({"a": 210, "b": 210}, test_result)
923
924  def testEmbeddedListOutput(self):
925    x = constant_op.constant(10)
926    y = constant_op.constant(200)
927    pred = math_ops.less(1, 2)
928    fn1 = lambda: [[math_ops.add(x, y), math_ops.add(x, y)]]
929    fn2 = lambda: [[y, y]]
930    # Pass strict=True flag as cond_v2 allows for tensors to be
931    # in nested output structures as singletons
932    r = control_flow_ops.cond(pred, fn1, fn2, strict=True)
933    test_result = self.evaluate(r)
934    self.assertListEqual([[210, 210]], test_result)
935
936  def testEmbeddedTupleOutput(self):
937    with self.cached_session() as sess:
938      x = constant_op.constant(10)
939      y = constant_op.constant(200)
940      pred = math_ops.less(1, 2)
941      fn1 = lambda: ((math_ops.add(x, y), math_ops.add(x, y)))
942      fn2 = lambda: ((y, y))
943      r = control_flow_ops.cond(pred, fn1, fn2)
944      test_result = self.evaluate(r)
945      self.assertTupleEqual(((210, 210)), test_result)
946
947  def testEmbeddedDictOutput(self):
948    with self.cached_session() as sess:
949      x = constant_op.constant(10)
950      y = constant_op.constant(200)
951      pred = math_ops.less(1, 2)
952      fn1 = lambda: {"a": {"c": math_ops.add(x, y)},
953                     "b": {"d": math_ops.add(x, y)}}
954      fn2 = lambda: {"a": {"c": y},
955                     "b": {"d": y}}
956      r = control_flow_ops.cond(pred, fn1, fn2)
957      test_result = self.evaluate(r)
958      self.assertDictEqual({"a": {"c": 210}, "b": {"d": 210}}, test_result)
959
960  @test_util.run_v1_only("b/120545219")
961  def testCheckNestedOutputStruct(self):
962    with self.cached_session() as sess:
963      x = constant_op.constant(10)
964      y = constant_op.constant(200)
965      pred = math_ops.less(1, 2)
966      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
967      fn2 = lambda: {"c": y, "d": y}
968      v1_msg = "The two structures don't have the same nested structure"
969      v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
970                "number, type, and overall structure of return values.")
971      with self.assertRaisesRegex(
972          TypeError if control_flow_util.ENABLE_CONTROL_FLOW_V2 else ValueError,
973          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
974        control_flow_ops.cond(pred, fn1, fn2)
975
976  @test_util.run_deprecated_v1
977  def testCondRef(self):
978
979    with self.cached_session():
980      x = gen_state_ops.variable(
981          shape=[1],
982          dtype=dtypes.float32,
983          name="x",
984          container="",
985          shared_name="")
986      true_fn = lambda: x
987      false_fn = lambda: constant_op.constant([2.0])
988      r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
989      self.assertAllEqual([2.0], self.evaluate(r))
990
991  @test_util.run_v1_only("b/120545219")
992  def testCondWithControl(self):
993    with self.cached_session() as sess:
994      control_holder = array_ops.placeholder(dtypes.float32, shape=())
995      a = constant_op.constant(3)
996
997      def true_branch():
998        with ops.control_dependencies([control_holder]):
999          _ = a + 1
1000        return a + 2
1001
1002      r = control_flow_ops.cond(
1003          constant_op.constant(True), true_branch,
1004          lambda: constant_op.constant(1))
1005      result = sess.run(r, feed_dict={control_holder: 5.})
1006      self.assertEqual(5, result)
1007
1008  @test_util.run_v1_only("b/120545219")
1009  def testUninitializedRefIdentity(self):
1010    with self.cached_session() as sess:
1011      v = gen_state_ops.variable(
1012          shape=[1],
1013          dtype=dtypes.float32,
1014          name="v",
1015          container="",
1016          shared_name="")
1017      inited = state_ops.is_variable_initialized(v)
1018      v_f, v_t = control_flow_ops.ref_switch(v, inited)
1019      # Both v_f and v_t are uninitialized references. However, an actual use
1020      # of the reference in the 'true' branch in the 'tf.identity' op will
1021      # not 'fire' when v is uninitialized, so this is a valid construction.
1022      # This test tests that ref_identity allows uninitialized ref as input
1023      # so that this construction is allowed.
1024      v_f_op = gen_array_ops.ref_identity(v_f)
1025      v_t_op = gen_array_ops.ref_identity(v_t)
1026      with ops.control_dependencies([v_f_op]):
1027        assign_v = state_ops.assign(v, [1.0])
1028      with ops.control_dependencies([v_t_op]):
1029        orig_v = array_ops.identity(v)
1030      merged_op = control_flow_ops.merge([assign_v, orig_v])
1031      self.assertAllEqual([1.0], self.evaluate(merged_op.output))
1032
1033  def testCondSwitchIdentity(self):
1034    # Make sure the recv identity is not removed by optimization.
1035    with session.Session(config=opt_cfg()) as sess:
1036      pred = constant_op.constant(True)
1037
1038      def fn1():
1039        return control_flow_ops.no_op()
1040
1041      def fn2():
1042        return control_flow_ops.Assert(False, ["Wrong branch!!!"])
1043
1044      r = control_flow_ops.cond(pred, fn1, fn2)
1045      self.evaluate(r)
1046
1047  def testCondRecvIdentity(self):
1048    # Make sure the switch identity is not removed by optimization.
1049    with session.Session(config=opt_cfg()) as sess:
1050      with ops.device(test.gpu_device_name()):
1051        pred = constant_op.constant(True)
1052
1053      def fn1():
1054        return control_flow_ops.no_op()
1055
1056      def fn2():
1057        with ops.device("/cpu:0"):
1058          return control_flow_ops.Assert(False, ["Wrong branch!!!"])
1059
1060      r = control_flow_ops.cond(pred, fn1, fn2)
1061      self.evaluate(r)
1062
1063  @test_util.run_deprecated_v1
1064  @test_util.enable_control_flow_v2
1065  def testDisableLoweringSwitchMerge(self):
1066    if test_util.is_gpu_available():
1067      self.skipTest(
1068          "Single threaded executor doesn't support partitioned graphs.  "
1069          "Skipping GPU test.")
1070    # Make pred feedable to ensure we don't constant-fold it out.
1071    run_opts = config_pb2.RunOptions(
1072        trace_level=config_pb2.RunOptions.FULL_TRACE)
1073    run_metadata_no_lowering = config_pb2.RunMetadata()
1074    run_metadata_with_lowering = config_pb2.RunMetadata()
1075
1076    config = opt_cfg(do_constant_folding=False)
1077
1078    pred = array_ops.placeholder_with_default(
1079        constant_op.constant(True), shape=())
1080    r = control_flow_ops.cond(pred, lambda: True, lambda: False)
1081
1082    with session.Session(config=config) as sess:
1083      r_value = sess.run(
1084          r, options=run_opts, run_metadata=run_metadata_with_lowering)
1085      self.assertEqual(r_value, True)
1086
1087    # Use the single threaded executor, which disables control flow lowering.
1088    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
1089    with session.Session(config=config) as sess:
1090      r_value = sess.run(
1091          r, options=run_opts, run_metadata=run_metadata_no_lowering)
1092      self.assertEqual(r_value, True)
1093
1094    self.assertTrue(  # pylint: disable=g-complex-comprehension
1095        any("switch" in ns.node_name
1096            for dev_stat in run_metadata_with_lowering.step_stats.dev_stats
1097            for ns in dev_stat.node_stats))
1098
1099    self.assertTrue(  # pylint: disable=g-complex-comprehension
1100        all("switch" not in ns.node_name
1101            for dev_stat in run_metadata_no_lowering.step_stats.dev_stats
1102            for ns in dev_stat.node_stats))
1103
1104  @test_util.run_v1_only("b/120545219")
1105  def testCondGrad_1(self):
1106    with self.cached_session():
1107      x = constant_op.constant(10.0, name="x")
1108      pred = math_ops.less(1, 2)
1109      fn1 = lambda: array_ops.identity(x)
1110      fn2 = lambda: array_ops.identity(x)
1111      r = control_flow_ops.cond(pred, fn1, fn2)
1112
1113      grad = gradients_impl.gradients(r, [x])[0]
1114      self.assertAllEqual(1.0, self.evaluate(grad))
1115
1116  @test_util.run_deprecated_v1
1117  @test_util.enable_control_flow_v2
1118  def testCondComputeGradAfterSessRunFails(self):
1119    with self.cached_session():
1120      x = constant_op.constant(10.0, name="x")
1121      pred = math_ops.less(1, 2)
1122
1123      def true_fn():
1124        a = x * x
1125        return a * a
1126
1127      def false_fn():
1128        return x * x
1129
1130      r = control_flow_ops.cond(pred, true_fn, false_fn)
1131
1132      self.assertAllEqual(r, 10000.)
1133      grad = gradients_impl.gradients(r, [x])[0]
1134      with self.assertRaisesRegex(
1135          errors_impl.InvalidArgumentError,
1136          r"Connecting to invalid output 1 of source node cond which has 1 "
1137          r"outputs. Try using "
1138          "tf.compat.v1.experimental.output_all_intermediates\(True\)."):
1139        self.evaluate(grad)
1140
1141  @test_util.run_deprecated_v1
1142  @test_util.enable_output_all_intermediates
1143  def testCondComputeGradAfterSessRun(self):
1144    with self.cached_session():
1145      x = constant_op.constant(10.0, name="x")
1146      pred = math_ops.less(1, 2)
1147
1148      def true_fn():
1149        a = x * x
1150        return a * a
1151
1152      def false_fn():
1153        return x * x
1154
1155      r = control_flow_ops.cond(pred, true_fn, false_fn)
1156
1157      self.assertAllEqual(r, 10000.)
1158      grad = gradients_impl.gradients(r, [x])[0]
1159      self.assertAllEqual(grad, 4000.)
1160
1161  @test_util.run_deprecated_v1
1162  @test_util.enable_output_all_intermediates
1163  def testNestedCondComputeGradAfterSessRun(self):
1164    with self.cached_session():
1165      x = constant_op.constant(10.0, name="x")
1166      pred = math_ops.less(1, 2)
1167
1168      def true_fn():
1169
1170        def inner_true_fn():
1171          a = x * x
1172          return a * a
1173
1174        def inner_false_fn():
1175          return x * x
1176
1177        return control_flow_ops.cond(
1178            constant_op.constant(True), inner_true_fn, inner_false_fn)
1179
1180      def false_fn():
1181        return x * x
1182
1183      r = control_flow_ops.cond(pred, true_fn, false_fn)
1184
1185      self.assertAllEqual(r, 10000.)
1186      grad = gradients_impl.gradients(r, [x])[0]
1187      self.assertAllEqual(grad, 4000.)
1188
1189  @test_util.run_deprecated_v1
1190  def testCondGrad_2(self):
1191    with self.cached_session():
1192      c = array_ops.placeholder(dtypes.int32, shape=[])
1193      x = constant_op.constant(10.0)
1194      pred = math_ops.less(c, 2)
1195      fn1 = lambda: math_ops.multiply(x, 42.0)
1196      fn2 = lambda: math_ops.multiply(x, 3.0)
1197      r = control_flow_ops.cond(pred, fn1, fn2)
1198
1199      grad = gradients_impl.gradients(r, [x])[0]
1200      self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
1201      self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
1202
1203  @test_util.disable_control_flow_v2(
1204      "b/110550782 (gradient w.r.t external variable)")
1205  @test_util.run_deprecated_v1
1206  def testCondGrad_3(self):
1207    with self.cached_session():
1208      c = array_ops.placeholder(dtypes.int32, shape=[])
1209      ox = constant_op.constant(10.0)
1210      pred = math_ops.less(c, 2)
1211
1212      def fn1(x):
1213        m = x * x
1214        return gradients_impl.gradients(m, [ox])[0]
1215
1216      fn2 = lambda: math_ops.multiply(ox, 3.0)
1217      y = math_ops.multiply(7.0, ox)
1218      r = control_flow_ops.cond(pred, lambda: fn1(y), fn2)
1219
1220      self.assertAllEqual(980.0, r.eval(feed_dict={c: 1}))
1221      self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))
1222
1223  @test_util.run_deprecated_v1
1224  def testCondGradMultiDevice(self):
1225    config = config_pb2.ConfigProto(device_count={"CPU": 2},
1226                                    allow_soft_placement=True)
1227    with self.cached_session(config=config) as sess:
1228      pred = array_ops.placeholder(dtypes.bool, [])
1229      x = array_ops.placeholder(dtypes.float32)
1230      y = array_ops.placeholder(dtypes.float32)
1231
1232      with ops.device("/cpu:0"):
1233        z = control_flow_ops.cond(pred, lambda: x * y * 2.0, lambda: 2.0)
1234
1235      with ops.device("/cpu:1"):
1236        grad = gradients_impl.gradients(z, x)[0]
1237
1238      with ops.device("/cpu:0"):
1239        grad_grad = gradients_impl.gradients(grad, x)[0]
1240
1241      self.assertEqual(sess.run(grad, {pred: True, x: 1.0, y: 2.0}), 4.0)
1242      self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
1243
1244      # v1 control flow gets None second derivative for some reason.
1245      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
1246        self.assertIsNone(grad_grad)
1247        return
1248
1249      self.assertEqual(sess.run(grad_grad, {pred: True, x: 1.0, y: 2.0}), 0.0)
1250      self.assertEqual(sess.run(grad_grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
1251
1252  @test_util.run_v1_only("b/120545219")
1253  def testNestedCond_Simple(self):
1254    with self.cached_session():
1255      x = constant_op.constant(0., name="X")
1256      y = control_flow_ops.cond(
1257          constant_op.constant(True), lambda: x,
1258          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
1259      result = gradients_impl.gradients(y, x)[0]
1260      self.assertEqual(1.0, self.evaluate(result))
1261
1262      z = control_flow_ops.cond(
1263          constant_op.constant(False), lambda: x,
1264          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
1265      result = gradients_impl.gradients(z, x)[0]
1266      self.assertEqual(1.0, self.evaluate(result))
1267
1268  @test_util.run_v1_only("b/120545219")
1269  def testCondGrad_Gather(self):
1270    with self.cached_session() as sess:
1271      v1 = variables.Variable([1.0, 42.0])
1272      c = array_ops.placeholder(dtypes.int32, shape=[])
1273      pred = math_ops.less(c, 2)
1274      fn1 = lambda: array_ops.identity(v1)
1275      fn2 = lambda: array_ops.gather(v1, [1, 1])
1276      r = control_flow_ops.cond(pred, fn1, fn2)
1277      # The following `grad` is a Tensor since it is the aggregation of an
1278      # IndexedSlice and a Tensor. It is an `IndexedSlices` with control flow
1279      # v2.
1280      grad = gradients_impl.gradients(r, [v1])[0]
1281      self.evaluate(variables.global_variables_initializer())
1282
1283      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1284        self.assertIsInstance(grad, ops.IndexedSlices)
1285
1286      grad_value = sess.run(grad, feed_dict={c: 1})
1287      self.assertAllEqual(gradient_checker_v2._to_numpy(grad_value), [1.0, 1.0])
1288
1289      grad_value = sess.run(grad, feed_dict={c: 3})
1290      self.assertAllEqual(gradient_checker_v2._to_numpy(grad_value), [0.0, 2.0])
1291
1292  @test_util.run_deprecated_v1
1293  def testCondGrad_ResourceVarSparseRead(self):
1294    # NOTE(skyewm): this test is interesting because the
1295    # ResourceVariable.sparse_read gradient function returns IndexedSlices.
1296    var = resource_variable_ops.ResourceVariable(
1297        np.ones((4, 2), dtype=np.float32))
1298    x = constant_op.constant(1.0)
1299    r = control_flow_ops.cond(
1300        constant_op.constant(True),
1301        lambda: x * math_ops.reduce_sum(var.sparse_read([1, 2])),
1302        lambda: constant_op.constant(np.zeros((2, 3)),
1303                                     dtype=dtypes.float32))
1304    grad = gradients_impl.gradients(r, var)[0]
1305
1306    self.evaluate(variables.global_variables_initializer())
1307    grad_val = self.evaluate(grad)
1308    self.assertIsInstance(grad_val, ops.IndexedSlicesValue)
1309    self.assertAllEqual(gradient_checker_v2._to_numpy(grad_val), [[0., 0.],
1310                                                                  [1., 1.],
1311                                                                  [1., 1.],
1312                                                                  [0., 0.]])
1313
1314  def testCondGrad_MultiGather(self):
1315    # NOTE(skyewm): this test is interesting because the array_ops.gather and
1316    # ResourceVariable.sparse_read gradient functions returns IndexedSlices.
1317    var = resource_variable_ops.ResourceVariable(
1318        np.ones((4, 2), dtype=np.float32))
1319    x1 = constant_op.constant(np.ones((3, 3), dtype=np.float32))
1320    x2 = constant_op.constant(2.0)
1321
1322    def true_fn():
1323      y1 = var.sparse_read([1, 2])
1324      y2 = array_ops.gather(x1, [2]) * x2
1325      y3 = x2 * [1., 1., 1.]
1326      return y1, y2, y3
1327
1328    def false_fn():
1329      y1 = np.zeros((2, 2), dtype=np.float32)
1330      y2 = array_ops.gather(x1, [2]) * x2
1331      y3 = array_ops.gather(x1, [2])
1332      return y1, y2, y3
1333
1334    @def_function.function
1335    def foo():
1336      r = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn)
1337      return gradients_impl.gradients(r, [var, x1, x2])
1338
1339    grad = foo()
1340    self.evaluate(variables.global_variables_initializer())
1341    var_grad, x1_grad, x2_grad = self.evaluate(grad)
1342    self.assertIsInstance(var_grad, ops.IndexedSlicesValue)
1343    self.assertAllEqual(gradient_checker_v2._to_numpy(var_grad), [[0., 0.],
1344                                                                  [1., 1.],
1345                                                                  [1., 1.],
1346                                                                  [0., 0]])
1347    self.assertIsInstance(x1_grad, ops.IndexedSlicesValue)
1348    self.assertAllEqual(gradient_checker_v2._to_numpy(x1_grad), [[0., 0., 0.],
1349                                                                 [0., 0., 0.],
1350                                                                 [2., 2., 2.]])
1351    self.assertIsInstance(x1_grad, ops.IndexedSlicesValue)
1352    self.assertEqual(gradient_checker_v2._to_numpy(x2_grad), 6.)
1353
1354  @test_util.run_v1_only("b/120545219")
1355  def testCondPredicateTensor(self):
1356    """Regression test for lowering predicate from non-first output of an op."""
1357
1358    @eager_function.defun
1359    def foo():
1360      return constant_op.constant("foo"), constant_op.constant(True)
1361
1362    r = control_flow_ops.cond(foo()[1], lambda: 1.0, lambda: 2.0)
1363    self.assertEqual(self.evaluate(r), 1.0)
1364
1365  @test_util.run_v1_only("Tests Session.run() pruning logic.")
1366  def testCondFeedConstantPredicate(self):
1367    with self.cached_session() as sess:
1368      value = constant_op.constant(37.0)
1369      predicate = constant_op.constant(True)
1370      cond_output = control_flow_ops.cond(
1371          predicate, lambda: constant_op.constant(0.0), lambda: value)
1372      result = array_ops.identity(cond_output)
1373      self.assertEqual(37.0, sess.run(result, feed_dict={predicate: False}))
1374      self.assertEqual(0.0, sess.run(result, feed_dict={predicate: True}))
1375      self.assertEqual(0.0, sess.run(result))
1376
1377  @test_util.run_v1_only("Tests Session.run() pruning logic.")
1378  def testCondFeedPlaceholderWithDefaultPredicate(self):
1379    with self.cached_session() as sess:
1380      value = constant_op.constant(37.0)
1381      predicate = array_ops.placeholder_with_default(
1382          constant_op.constant(True), [])
1383      cond_output = control_flow_ops.cond(
1384          predicate, lambda: constant_op.constant(0.0), lambda: value)
1385      result = array_ops.identity(cond_output)
1386      self.assertAllEqual(37.0, sess.run(result, feed_dict={predicate: False}))
1387      self.assertAllEqual(0.0, sess.run(result, feed_dict={predicate: True}))
1388      self.assertAllEqual(0.0, sess.run(result))
1389
1390  @test_util.run_in_graph_and_eager_modes
1391  def testCondAutoControlDeps(self):
1392    if test_util.is_gpu_available():
1393      self.skipTest("b/128676188 causes OOM on opensource gpu tests")
1394
1395    print_prefix = "testCondAutoControlDeps: "
1396
1397    def branch_fn():
1398      enqueue_print_op("A")
1399      enqueue_print_op("B")
1400      with ops.control_dependencies([enqueue_print_op("C")]):
1401        return constant_op.constant(10)
1402
1403    def build_cond():
1404      return control_flow_ops.cond(
1405          constant_op.constant(True), branch_fn, lambda: 0)
1406
1407    def build_nested_cond():
1408      return control_flow_ops.cond(
1409          constant_op.constant(True), build_cond, lambda: 0)
1410
1411    # In v1 graph mode, pruning should make only "C" print.
1412    if not context.executing_eagerly():
1413      with self.cached_session():
1414        with self.captureWritesToStream(sys.stderr) as printed:
1415          self.assertEqual(self.evaluate(build_cond()), 10)
1416        self.assertEqual(["C"], filter_test_messages(printed.contents()))
1417
1418        with self.captureWritesToStream(sys.stderr) as printed:
1419          self.assertEqual(self.evaluate(build_nested_cond()), 10)
1420        self.assertEqual(["C"], filter_test_messages(printed.contents()))
1421
1422    # In defuns, all prints should execute in program order.
1423    # This doesn't work with legacy control flow.
1424    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1425
1426      @eager_function.defun
1427      def cond():
1428        return build_cond()
1429
1430      with self.captureWritesToStream(sys.stderr) as printed:
1431        self.assertEqual(self.evaluate(cond()), 10)
1432      self.assertEqual(["A", "B", "C"],
1433                       filter_test_messages(printed.contents()))
1434
1435      @eager_function.defun
1436      def nested_cond():
1437        return build_nested_cond()
1438
1439      with self.captureWritesToStream(sys.stderr) as printed:
1440        self.assertEqual(self.evaluate(nested_cond()), 10)
1441      self.assertEqual(["A", "B", "C"],
1442                       filter_test_messages(printed.contents()))
1443
1444    # wrap_function should prune.
1445    def pruned_cond():
1446      return build_cond()
1447    pruned_cond = wrap_function.wrap_function(pruned_cond, [])
1448
1449    with self.captureWritesToStream(sys.stderr) as printed:
1450      self.assertEqual(self.evaluate(pruned_cond()), 10)
1451    self.assertEqual(["C"], filter_test_messages(printed.contents()))
1452
1453    def pruned_nested_cond():
1454      return build_nested_cond()
1455    pruned_nested_cond = wrap_function.wrap_function(pruned_nested_cond, [])
1456
1457    with self.captureWritesToStream(sys.stderr) as printed:
1458      self.assertEqual(self.evaluate(pruned_nested_cond()), 10)
1459    self.assertEqual(["C"], filter_test_messages(printed.contents()))
1460
1461
1462  @test_util.run_in_graph_and_eager_modes
1463  @test_util.disable_tfrt("b/179459136")
1464  def testWhileAutoControlDeps(self):
1465    # Legacy while_loop fails this test because it produces deprecation notices
1466    # in stderr.
1467    if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return
1468
1469    def cond(i, unused_x):
1470      enqueue_print_op("A")
1471      return i < 2
1472
1473    def body(i, x):
1474      enqueue_print_op("B")
1475      with ops.control_dependencies([enqueue_print_op("C")]):
1476        x = array_ops.identity(x)
1477      with ops.control_dependencies([enqueue_print_op("D")]):
1478        return i + 1, x
1479
1480    def build_while():
1481      return control_flow_ops.while_loop(
1482          cond, body, [constant_op.constant(0), constant_op.constant(0)])
1483
1484    def build_nested_while():
1485      return control_flow_ops.cond(
1486          constant_op.constant(True), build_while, lambda: [0, 0])
1487
1488    # In v1 graph mode, pruning should make only "D" print.
1489    if not context.executing_eagerly():
1490      with self.cached_session():
1491        with self.captureWritesToStream(sys.stderr) as printed:
1492          self.assertEqual(self.evaluate(build_while()[0]), 2)
1493        self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
1494
1495        with self.captureWritesToStream(sys.stderr) as printed:
1496          self.assertEqual(self.evaluate(build_nested_while()[0]), 2)
1497        self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
1498
1499    # In defuns, all prints should execute in program order.
1500    @eager_function.defun
1501    def while_loop():
1502      return build_while()[0]
1503
1504    with self.captureWritesToStream(sys.stderr) as printed:
1505      self.assertEqual(self.evaluate(while_loop()), 2)
1506    self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"],
1507                     filter_test_messages(printed.contents()))
1508
1509    @eager_function.defun
1510    def nested_while_loop():
1511      return build_nested_while()[0]
1512
1513    with self.captureWritesToStream(sys.stderr) as printed:
1514      self.assertEqual(self.evaluate(nested_while_loop()), 2)
1515    self.assertEqual(["A", "B", "C", "D", "A", "B", "C", "D", "A"],
1516                     filter_test_messages(printed.contents()))
1517
1518    # wrap_function should prune.
1519    def pruned_while():
1520      return build_while()[0]
1521    pruned_while = wrap_function.wrap_function(pruned_while, [])
1522
1523    with self.captureWritesToStream(sys.stderr) as printed:
1524      self.assertEqual(self.evaluate(pruned_while()), 2)
1525    self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
1526
1527    def pruned_nested_while():
1528      return build_nested_while()[0]
1529    pruned_nested_while = wrap_function.wrap_function(pruned_nested_while, [])
1530
1531    with self.captureWritesToStream(sys.stderr) as printed:
1532      self.assertEqual(self.evaluate(pruned_nested_while()), 2)
1533    self.assertEqual(["D", "D"], filter_test_messages(printed.contents()))
1534
1535  # Microbenchmark: 256,000 iterations/s.
1536  def testWhile_1(self):
1537    with self.cached_session():
1538      n = constant_op.constant(0)
1539      c = lambda x: math_ops.less(x, 10000)
1540      b = lambda x: math_ops.add(x, 1)
1541      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
1542      self.assertEqual(10000, self.evaluate(r))
1543
1544  @test_util.run_v1_only("b/120545219")
1545  def testWhileExternalControlDependencies(self):
1546    with self.cached_session():
1547      v = variables.Variable(0.0)
1548      self.evaluate(v.initializer)
1549      increment = v.assign_add(1.0).read_value()
1550
1551      def body_fn(i):
1552        with ops.control_dependencies([increment]):
1553          return i + 1
1554
1555      result = control_flow_ops.while_loop(cond=lambda i: i < 2,
1556                                           body=body_fn, loop_vars=[1])
1557      self.assertAllEqual(result, 2)
1558      self.assertAllEqual(v.read_value(), 1.0)
1559
1560  @test_util.run_v1_only("b/120545219")
1561  def testWhileExternalControlDependenciesNoInput(self):
1562    with self.cached_session():
1563      v = variables.Variable(0.0)
1564      self.evaluate(v.initializer)
1565      # TODO(apassos): figure out why the reading is necessary here.
1566      increment = v.assign_add(1.0).read_value()
1567
1568      def body_fn(unused_i):
1569        with ops.control_dependencies([increment]):
1570          return constant_op.constant(5, name="five")
1571
1572      result = control_flow_ops.while_loop(cond=lambda i: i < 5,
1573                                           body=body_fn, loop_vars=[0])
1574      self.evaluate(result)
1575      self.assertAllEqual(self.evaluate(v), 1.0)
1576
1577  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
1578  @test_util.run_v1_only("b/120545219")
1579  def testWhileWithRefs_1(self):
1580    with self.cached_session() as sess:
1581      x = variables.VariableV1(0)._ref()  # pylint: disable=protected-access
1582      i = constant_op.constant(0)
1583      c = lambda i, x: math_ops.less(i, 100)
1584
1585      self.assertEqual(x.dtype, dtypes.int32_ref)
1586
1587      def b(i, x):
1588        self.assertEqual(x.dtype, dtypes.int32_ref)
1589        return (i + 1, gen_array_ops.ref_identity(x))
1590
1591      r = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=5)
1592
1593      self.evaluate(variables.global_variables_initializer())
1594
1595      self.assertEqual(r[0].dtype, dtypes.int32)
1596      self.assertEqual(r[1].dtype, dtypes.int32_ref)
1597
1598      value_i, value_x = self.evaluate(r)
1599
1600    self.assertEqual(100, value_i)
1601    self.assertEqual(0, value_x)
1602
1603  def testWhile_2(self):
1604    with self.cached_session():
1605      s = constant_op.constant(0)
1606      r = isum(s)
1607      self.assertAllEqual(45, self.evaluate(r))
1608
1609  def testWhileWithMaximumIterations(self):
1610    with self.cached_session():
1611      s = constant_op.constant([1, 2, 3, 4, 5])
1612      r = isum(s, maximum_iterations=3)
1613      self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], self.evaluate(r))
1614
1615  @test_util.run_v1_only("b/120545219")
1616  def testWhileWithMaximumIterationsAndSingleArgument(self):
1617    with self.cached_session():
1618      r = control_flow_ops.while_loop(
1619          lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
1620      self.assertEqual(1, self.evaluate(r))
1621
1622  @test_util.run_v1_only("b/120545219")
1623  def testXLAGradInLoop(self):
1624    # We have an optimization that moves certain reduction ops, this test makes
1625    # sure we don't do that for XLA ops.
1626
1627    # Use dynamic inputs, which triggers the creation of "BroadcastGradientArgs"
1628    # and "Shape" op.
1629    input1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
1630    input2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
1631    def cond(i1, i2):
1632      return False
1633
1634    def body(i1, i2):
1635      return math_ops.add(i1, i2), math_ops.add(i1, i2)
1636
1637    xla_context = control_flow_ops.XLAControlFlowContext()
1638    xla_context.Enter()
1639
1640    out1, _ = control_flow_ops.while_loop(
1641        cond, body, (input1, input2), maximum_iterations=2)
1642    g = gradients_impl.gradients(out1, [input1])
1643
1644    for op in out1.graph.get_operations():
1645      # Test that the "Shape" is directly passed to BroadcastGradientArgs
1646      # instead of being pushed to the stack.
1647      if op.type == "BroadcastGradientArgs":
1648        self.assertEqual(op.inputs[0].op.type, "Shape")
1649        self.assertEqual(op.inputs[1].op.type, "Shape")
1650    xla_context.Exit()
1651
1652
1653  @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
1654  @test_util.run_v1_only("b/120545219")
1655  def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
1656    v = constant_op.constant(1.0)
1657
1658    def training_loop_with_gradient(i):
1659      out = control_flow_ops.while_loop(
1660          lambda i_, _: i_ < 3,
1661          lambda i_, j: [i_ + 1, j * v], [0, 1.0],
1662          maximum_iterations=i)
1663      g = gradients_impl.gradients(out, v)
1664      with ops.control_dependencies(g):
1665        return i + 1
1666
1667    xla_context = control_flow_ops.XLAControlFlowContext()
1668    xla_context.Enter()
1669    # Create training loop, ensure we can call gradient() of
1670    # while_loop inside the training loop.
1671    loop = control_flow_ops.while_loop(lambda i: i < 3,
1672                                       training_loop_with_gradient, [0])
1673    xla_context.Exit()
1674
1675    loop_execute = array_ops.identity(loop)  # Because loop is not fetchable.
1676
1677    # Should execute without issue.
1678    self.assertEqual(3, self.evaluate(loop_execute))
1679
1680  @test_util.run_v1_only("b/120545219")
1681  def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
1682    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1683      self.skipTest("WhileV2 does lazy evaluation of maximum_iterations")
1684    v = constant_op.constant(1.0)
1685
1686    def inner_body(i, x):
1687      out = control_flow_ops.while_loop(
1688          lambda i, _: i < 3,
1689          lambda i, j: [i + 1, j * v], [0, x],
1690          maximum_iterations=i)
1691      return out
1692
1693    def create_while_loop(maximum_iterations=None):
1694      return control_flow_ops.while_loop(
1695          lambda i, _: i < 3,
1696          inner_body, [0, 1.0],
1697          maximum_iterations=maximum_iterations)
1698
1699    loop_no_xla = create_while_loop(maximum_iterations=5)
1700    # maximum_iterations is fine outside of an XLA scope
1701    gs = gradients_impl.gradients(loop_no_xla, v)
1702    self.evaluate(gs)  # This should execute without error.
1703
1704    xla_context = control_flow_ops.XLAControlFlowContext()
1705    xla_context.Enter()
1706    loop_no_maxiter = create_while_loop()
1707    loop_with_maxiter = create_while_loop(maximum_iterations=2)
1708    xla_context.Exit()
1709
1710    with self.assertRaisesRegex(
1711        ValueError,
1712        r"Cannot create a gradient accumulator for tensor '.+' inside "
1713        r"XLA while_loop because maximum_iterations was not passed to "
1714        r"the tf.while_loop call \('.+'\)."):
1715      _ = gradients_impl.gradients(loop_no_maxiter, v)
1716
1717    with self.assertRaisesRegex(
1718        ValueError,
1719        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
1720        r"while_loop. maximum_iterations tensor '.+' for while_loop context "
1721        r"'.+' must be statically known \(e.g. a constant value or known "
1722        r"shape dimension\), or be defined at or outside the while loop "
1723        r"context '.*' \(currently defined in '.*'\)"):
1724      _ = gradients_impl.gradients(loop_with_maxiter, v)
1725
1726  @test_util.run_v1_only("b/120545219")
1727  def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
1728    v = constant_op.constant(1.0)
1729
1730    def create_while_loop():
1731      max_iter_holder = []
1732
1733      def create_mi():
1734        max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=()))
1735        return 1.0
1736
1737      _ = control_flow_ops.cond(
1738          constant_op.constant(True), create_mi, create_mi)
1739
1740      return control_flow_ops.while_loop(
1741          lambda i, _: i < 3,
1742          lambda i, x: (i + 1, v * x), (0, 1.0),
1743          maximum_iterations=max_iter_holder[0])
1744
1745    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1746      xla_context = control_flow_ops.XLAControlFlowContext()
1747      xla_context.Enter()
1748      with self.assertRaisesRegex(ValueError, r"must be from the same graph.*"):
1749        loop = create_while_loop()
1750      xla_context.Exit()
1751    else:
1752      xla_context = control_flow_ops.XLAControlFlowContext()
1753      xla_context.Enter()
1754      loop = create_while_loop()
1755      xla_context.Exit()
1756      with self.assertRaisesRegex(
1757          ValueError,
1758          r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
1759          r"while_loop. maximum_iterations tensor '.*Placeholder:0' for "
1760          r"while_loop context '.+' must be statically known \(e.g. a constant "
1761          r"value or known shape dimension\), or be defined at or outside the "
1762          r"while loop context '' \(currently defined in 'cond/.+'\)"):
1763        _ = gradients_impl.gradients(loop, v)
1764
1765  @test_util.run_v1_only("b/120545219")
1766  def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
1767    if test_util.is_gpu_available():
1768      self.skipTest("b/128646372, b/128645947 fails in opensource build")
1769
1770    v = constant_op.constant(1.0)
1771
1772    p = array_ops.placeholder(dtype=dtypes.int32)
1773
1774    def mid_body_builder(iterations):
1775
1776      def mid_body(i, x):
1777        r = control_flow_ops.while_loop(
1778            lambda *_: True,
1779            lambda i, x: (i + 1, v * x), (0, x),
1780            maximum_iterations=iterations,
1781            name="inner")
1782        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
1783
1784      return mid_body
1785
1786    def outer_body(i, x):
1787      iterations = array_ops.size(p, name="iterations")
1788      return (i + 1, x + control_flow_ops.while_loop(
1789          lambda *_: True,
1790          mid_body_builder(iterations), (0, x),
1791          maximum_iterations=iterations,
1792          name="mid")[1])
1793
1794    def create_while_loop():
1795      with ops.device("/cpu:0"):
1796        r = control_flow_ops.while_loop(
1797            lambda *_: True,
1798            outer_body, (0, 1.0),
1799            maximum_iterations=5,
1800            name="outer")
1801        return array_ops.identity(r[1])
1802
1803    xla_context = control_flow_ops.XLAControlFlowContext()
1804    xla_context.Enter()
1805    final_with_xla_context = create_while_loop()
1806    xla_context.Exit()
1807
1808    final_without_xla_context = create_while_loop()
1809
1810    with self.session(use_gpu=False) as sess:
1811      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
1812      run_metadata_without_xla_context = config_pb2.RunMetadata()
1813      run_metadata = config_pb2.RunMetadata()
1814
1815      final_value_without_xla_context = sess.run(
1816          final_without_xla_context,
1817          feed_dict={p: [0, 0, 0]},
1818          options=opts,
1819          run_metadata=run_metadata_without_xla_context)
1820
1821      final_value_with_xla_context = sess.run(
1822          final_with_xla_context,
1823          feed_dict={p: [0, 0, 0]},
1824          options=opts,
1825          run_metadata=run_metadata)
1826
1827      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1828        # With while_v2 on xla, run_metadata only contains the unlowered While
1829        # op so node_stats does not have statistics for the pushes. So as a
1830        # loose check we check the pushes in the lowered version.
1831        for dev in run_metadata_without_xla_context.step_stats.dev_stats:
1832          if "/device:CPU" in dev.device:
1833            node_stats = dev.node_stats
1834        stack_push_count = len([
1835            x for x in node_stats
1836            if re.match(r".*TensorListPushBack_?\d*", x.node_name)
1837        ])
1838      else:
1839        for dev in run_metadata.step_stats.dev_stats:
1840          if "/device:CPU" in dev.device:
1841            node_stats = dev.node_stats
1842        stack_push_op = "StackPushV2"
1843        stack_push_count = len(
1844            [x for x in node_stats if x.node_name.endswith("StackPushV2")])
1845      # Pushes to the stack = product of maximum_iterations values;
1846      # the last two "3"s comes from size(p), when p == [0, 0, 0].
1847      self.assertEqual(stack_push_count, 5 * 3 * 3, str(node_stats))
1848
1849      self.assertAllClose(final_value_with_xla_context,
1850                          final_value_without_xla_context)
1851
1852  # Have more than 10 parallel iterations and hence exercise k-bound
1853  # most of the time.
1854  @test_util.run_deprecated_v1
1855  def testWhile_3(self):
1856    with self.cached_session():
1857
1858      def compute(i, m, c, o):
1859        m, c = [math_ops.add(m, 1), math_ops.add(c, 1)]
1860        o = math_ops.add(o, m)
1861        o = math_ops.add(o, c)
1862        i = math_ops.add(i, 1)
1863        return [i, m, c, o]
1864
1865      i = ops.convert_to_tensor(0)
1866      m = ops.convert_to_tensor(0)
1867      c = ops.convert_to_tensor(0)
1868      o = ops.convert_to_tensor(0)
1869      d = ops.convert_to_tensor(100)
1870      r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, d),
1871                                      compute, [i, m, c, o])
1872      result = r[3]
1873    self.assertAllEqual(10100, result)
1874
1875  @test_util.run_deprecated_v1
1876  def testWhile_4(self):
1877    with self.cached_session():
1878
1879      def compute(i, m, c, o):
1880        m, c = [array_ops.gather(x, i), array_ops.gather(x, i)]
1881        o = math_ops.add(o, m)
1882        o = math_ops.add(o, c)
1883        i = math_ops.add(i, 1)
1884        return [i, m, c, o]
1885
1886      i = ops.convert_to_tensor(0)
1887      m = ops.convert_to_tensor(0)
1888      c = ops.convert_to_tensor(0)
1889      o = ops.convert_to_tensor(0)
1890      x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
1891      s = array_ops.size(x)
1892      r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, s),
1893                                      compute, [i, m, c, o])
1894      result = r[3]
1895    self.assertAllEqual(42, result)
1896
1897  @test_util.run_v1_only("b/120545219")
1898  def testWhile_5(self):
1899    with self.cached_session():
1900
1901      def compute(i, c, o):
1902        c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0),
1903                                    [1] + array_ops.expand_dims(i, 0))
1904        o = array_ops.concat([o, c], 0)
1905        i = math_ops.add(i, 1)
1906        return [i, c, o]
1907
1908      i = ops.convert_to_tensor(0)
1909      c = ops.convert_to_tensor([0])
1910      o = ops.convert_to_tensor([0])
1911      x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
1912      s = array_ops.size(x)
1913      r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s),
1914                                      compute, [i, c, o], [
1915                                          i.get_shape(),
1916                                          tensor_shape.unknown_shape(),
1917                                          tensor_shape.unknown_shape()
1918                                      ])
1919      result = r[2]
1920    self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
1921
1922  @test_util.run_gpu_only
1923  @test_util.run_deprecated_v1
1924  def testWhile_Device(self):
1925
1926    # Body function defined outside of device scope
1927    def body(x):
1928      return math_ops.exp(x)
1929
1930    with ops.device("CPU:0"):
1931      r = control_flow_ops.while_loop(
1932          lambda x: x < 10, body, [constant_op.constant(-10.)])
1933      self.assertIn("cpu", r.device.lower())
1934
1935    with session.Session() as sess:
1936      options = config_pb2.RunOptions(output_partition_graphs=True)
1937      run_metadata = config_pb2.RunMetadata()
1938      sess.run(r, options=options, run_metadata=run_metadata)
1939      # We expect that everything runs on CPU, even if GPU is available.
1940      self.assertEqual(len(run_metadata.partition_graphs), 1)
1941
1942  @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
1943  @test_util.run_v1_only("b/120545219")
1944  def testBufferForwarding(self):
1945    run_options = config_pb2.RunOptions(
1946        trace_level=config_pb2.RunOptions.FULL_TRACE)
1947    run_metadata = config_pb2.RunMetadata()
1948
1949    with self.cached_session() as sess:
1950      with ops.device("/cpu:0"):
1951        c = constant_op.constant(2)
1952        i0 = constant_op.constant(0)
1953        r = control_flow_ops.while_loop(lambda i: i < 1000,
1954                                        lambda i: math_ops.square(c) + i, [i0])
1955      r_val = sess.run(r, options=run_options, run_metadata=run_metadata)
1956      self.assertEqual(1000, r_val)
1957      self.assertTrue(run_metadata.HasField("step_stats"))
1958      unique_allocs = set()
1959      for node_stat in run_metadata.step_stats.dev_stats[0].node_stats:
1960        for output in node_stat.output:
1961          unique_allocs.add(
1962              output.tensor_description.allocation_description.ptr)
1963      # Prior to cl/147536680, the number of unique allocations was about 1005.
1964      self.assertLess(len(unique_allocs), 756)
1965
1966  def _testWhile_Gpu_1(self, use_gpu):
1967    with self.cached_session(use_gpu=use_gpu):
1968      n = constant_op.constant(1.0)
1969      c = lambda x: math_ops.less(x, 10.0)
1970      b = lambda x: math_ops.add(x, 1.0)
1971      r = control_flow_ops.while_loop(c, b, [n])
1972      self.assertAllClose(10.0, self.evaluate(r))
1973
1974  def testWhile_Gpu_1(self):
1975    self._testWhile_Gpu_1(use_gpu=False)
1976    self._testWhile_Gpu_1(use_gpu=True)
1977
1978  def _testWhile_Gpu_2(self, use_gpu):
1979    with self.cached_session(use_gpu=use_gpu):
1980      n = constant_op.constant(1.0)
1981      c = lambda x: math_ops.less(x, 10.0)
1982
1983      def b(x):
1984        with ops.device("/cpu:0"):
1985          return math_ops.add(x, 1.0)
1986
1987      r = control_flow_ops.while_loop(c, b, [n])
1988      self.assertAllClose(10.0, self.evaluate(r))
1989
1990  def testWhile_Gpu_2(self):
1991    self._testWhile_Gpu_2(use_gpu=False)
1992    self._testWhile_Gpu_2(use_gpu=True)
1993
1994  def testWhileShape(self):
1995    with self.cached_session():
1996      i = constant_op.constant(0)
1997      m = array_ops.ones([2, 2])
1998      c = lambda i, j: math_ops.less(i, 2)
1999
2000      def _b(i, j):
2001        new_i = math_ops.add(i, 1)
2002        new_j = array_ops.tile(j, [2, 2])
2003        return [new_i, new_j]
2004
2005      r = control_flow_ops.while_loop(
2006          c, _b, [i, m],
2007          [i.get_shape(), tensor_shape.unknown_shape()])
2008      r = r[1] * array_ops.ones([8, 8])
2009      self.assertAllEqual(np.ones((8, 8)), self.evaluate(r))
2010
2011  @test_util.disable_control_flow_v2("b/131265085")
2012  @test_util.run_v1_only("b/131265085")
2013  def testWhileBadShape(self):
2014    x = constant_op.constant([2.0, 4.0], name="values")
2015    i = constant_op.constant(0)
2016    c = lambda i, _: math_ops.less(i, 10)
2017    b = lambda i, x: [i + 1, x + 1]
2018    with self.assertRaisesRegex(ValueError, "is not compatible with"):
2019      # Shape of x is [2], but we specify a shape of [5].
2020      control_flow_ops.while_loop(
2021          c, b, [i, x], [i.shape, tensor_shape.TensorShape([5])])
2022
2023  @test_util.run_in_graph_and_eager_modes
2024  def testWhileBadBodyReturn(self):
2025    x = constant_op.constant([2.0, 4.0], name="values")
2026    i = constant_op.constant(0)
2027    c = lambda i, *x: math_ops.less(i, 10)
2028
2029    # body accepts N values and returns N+1 values.
2030    b = lambda i, *x: (i, i) + x
2031
2032    with self.assertRaisesRegex(
2033        ValueError, "The two structures don't have the same nested structure."):
2034      control_flow_ops.while_loop(c, b, [i, x])
2035
2036  @test_util.run_deprecated_v1
2037  def testWhileWithNonTensorInput_Scalar(self):
2038    with self.cached_session():
2039      n = 0
2040      c = lambda x: x < 10000
2041      b = lambda x: x + 1
2042      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
2043      self.assertEqual(10000, self.evaluate(r))
2044
2045  def testWhileWithNonTensorInput_Vector(self):
2046    with self.cached_session():
2047      n = np.array([0])  # Note, [0] would not work here; that is a list
2048      c = lambda x: x[0] < 10000
2049      b = lambda x: array_ops.stack([x[0] + 1])
2050      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
2051      self.assertEqual([10000], self.evaluate(r))
2052
2053  def testWhileShapeInference(self):
2054    with self.cached_session():
2055      i = constant_op.constant(0)
2056      m = array_ops.ones([2, 2])
2057      c = lambda i, j: math_ops.less(i, 2)
2058
2059      def b(i, j):
2060        new_i = math_ops.add(i, 1)
2061        new_j = array_ops.concat([j, j], 0)
2062        return [new_i, new_j]
2063
2064      r = control_flow_ops.while_loop(
2065          c, b, [i, m],
2066          [i.get_shape(), tensor_shape.TensorShape([None, 2])])
2067      self.assertTrue(r[1].shape.is_compatible_with([8, 2]))
2068
2069  @test_util.run_v1_only("b/120545219")
2070  def testWhileShapeInferenceBadShape(self):
2071    with self.cached_session():
2072      i = constant_op.constant(0)
2073      m = array_ops.ones([2, 2])
2074      c = lambda i, j: math_ops.less(i, 2)
2075      b = lambda i, j: [i + 1, array_ops.concat([j, j], 0)]
2076      with self.assertRaisesRegex(
2077          ValueError,
2078          r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has "
2079          r"shape \(4, 2\) after one iteration. To allow the shape to vary "
2080          r"across iterations, use the `shape_invariants` argument of "
2081          r"tf.while_loop to specify a less-specific shape."):
2082        control_flow_ops.while_loop(c, b, [i, m])
2083
2084  def testWhileShapeInferenceSparseTensor(self):
2085    values = constant_op.constant([2.0, 4.0], name="values")
2086    indices = constant_op.constant([[0], [3]],
2087                                   dtype=dtypes.int64,
2088                                   name="indices")
2089    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
2090    i = constant_op.constant(0)
2091    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
2092
2093    def c(i, _):
2094      return i < 10
2095
2096    def b1(i, x):  # modifies values.  (shape of components is not changed.)
2097      return [
2098          i + 1,
2099          sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
2100      ]
2101
2102    def b2(i, x):  # adds new values.  (shape of components is changed.)
2103      return [
2104          i + 1,
2105          sparse_ops.sparse_add(
2106              x,
2107              sparse_tensor.SparseTensor(
2108                  indices=math_ops.cast(
2109                      array_ops.fill([1, 1], i), dtypes.int64),
2110                  values=array_ops.fill([1], 1.0),
2111                  dense_shape=x.dense_shape))
2112      ]
2113
2114    def b3(i, x):  # modifies rank.  (shape of all components is changed.)
2115      return [
2116          i + 1,
2117          sparse_tensor.SparseTensor(
2118              array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
2119              array_ops.concat([x.dense_shape, [10]], axis=0))
2120      ]
2121
2122    def check_shapes(r, indices, values, dense_shape):
2123      self.assertTrue(r.indices.shape.is_compatible_with(indices))
2124      self.assertTrue(r.values.shape.is_compatible_with(values))
2125      self.assertTrue(r.dense_shape.shape.is_compatible_with(dense_shape))
2126
2127    # Default shape invariant; b1 only modifies values.
2128    _, r = control_flow_ops.while_loop(c, b1, [i, x])
2129    check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])
2130
2131    # Default shape invariant; b2 adds new values
2132    _, r = control_flow_ops.while_loop(c, b2, [i, x])
2133    check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])
2134
2135    # Explicit shape invariant, allowing any rank; b1 only modifies values.
2136    _, r = control_flow_ops.while_loop(
2137        c, b1, [i, x],
2138        [i.get_shape(), tensor_shape.TensorShape([None])])
2139    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
2140
2141    # Explicit shape invariant, allowing any rank; b3 modifies rank.
2142    _, r = control_flow_ops.while_loop(
2143        c, b3, [i, x],
2144        [i.get_shape(), tensor_shape.TensorShape([None])])
2145    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
2146
2147    # Shape invariant with ndims=None.  Technically, this isn't supported
2148    # according to the docs, but we support it for backwards compatibility.
2149    _, r = control_flow_ops.while_loop(
2150        c, b1, [i, x],
2151        [i.get_shape(), tensor_shape.TensorShape(None)])
2152    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
2153    _, r = control_flow_ops.while_loop(
2154        c, b3, [i, x],
2155        [i.get_shape(), tensor_shape.TensorShape(None)])
2156    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
2157
2158  @test_util.disable_control_flow_v2("b/131265085")
2159  @test_util.run_v1_only("b/131265085")
2160  def testWhileBadShapeSparseTensor(self):
2161    values = constant_op.constant([2.0, 4.0], name="values")
2162    indices = constant_op.constant([[0], [3]],
2163                                   dtype=dtypes.int64,
2164                                   name="indices")
2165    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
2166    i = constant_op.constant(0)
2167    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
2168    c = lambda i, _: i < 10
2169    b1 = lambda i, x: [i+1, x]
2170    def b2(i, x):  # modifies rank.  (shape of all components is changed.)
2171      return [
2172          i + 1,
2173          sparse_tensor.SparseTensor(
2174              array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
2175              array_ops.concat([x.dense_shape, [10]], axis=0))
2176      ]
2177
2178    # Explicit shape invariant, with a specific (incompatible) rank.
2179    with self.assertRaisesRegex(ValueError, "is not compatible with"):
2180      control_flow_ops.while_loop(
2181          c, b1, [i, x],
2182          [i.get_shape(), tensor_shape.TensorShape([5])])
2183
2184    # Default shape invariant, but b2 modifies rank (which is not allowed).
2185    with self.assertRaises(ValueError):
2186      control_flow_ops.while_loop(c, b2, [i, x])
2187
2188  def testWhileShapeInferenceIndexedSlices(self):
2189    with self.cached_session():
2190      values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
2191      indices = constant_op.constant([0, 3], name="indices")
2192      shape = constant_op.constant([10, 2], name="dense_shape")
2193      i = constant_op.constant(0)
2194      x = ops.IndexedSlices(values, indices, dense_shape=shape)
2195
2196      def c(i, _):
2197        return i < 10
2198
2199      def b(i, x):
2200        return [
2201            i + 1,
2202            ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
2203        ]
2204
2205      _, r = control_flow_ops.while_loop(c, b, [i, x])
2206      self.assertEqual(r.dense_shape.get_shape()[0], 2)
2207      self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2]))
2208
2209      _, r = control_flow_ops.while_loop(
2210          c, b, [i, x],
2211          [i.get_shape(), tensor_shape.TensorShape([None, 2])])
2212      self.assertEqual(r.dense_shape.get_shape()[0], 2)
2213      self.assertTrue(r.values.get_shape().is_compatible_with([None, 2]))
2214
2215  @test_util.disable_control_flow_v2("b/131265085")
2216  @test_util.run_v1_only("b/131265085")
2217  def testWhileBadShapeIndexedSlices(self):
2218    values = constant_op.constant([2.0, 4.0], name="values")
2219    indices = constant_op.constant([[0], [3]],
2220                                   dtype=dtypes.int64,
2221                                   name="indices")
2222    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
2223    i = constant_op.constant(0)
2224    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
2225    c = lambda i, _: 10
2226    b = lambda i, x: [i+1, x]
2227
2228    # Explicit shape invariant, with a specific (incompatible) rank.
2229    with self.assertRaisesRegex(ValueError, "is not compatible with"):
2230      control_flow_ops.while_loop(
2231          c, b, [i, x],
2232          [i.get_shape(), tensor_shape.TensorShape([5])])
2233
2234  def testWhileShapeInferenceRaggedTensor(self):
2235    i = constant_op.constant(0)
2236    x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
2237    c = lambda i, _: i < 10
2238
2239    def b1(i, x):  # Adds new values to rows (but doesn't create new rows)
2240      return [
2241          i + 1,
2242          array_ops.concat([x, x], axis=1)
2243      ]
2244
2245    def b2(i, x):  # Adds new rows.
2246      return [
2247          i + 1,
2248          array_ops.concat([x, x], axis=0)
2249      ]
2250
2251    def check_shapes(r, values, splits):
2252      self.assertTrue(r.values.shape.is_compatible_with(values))
2253      self.assertTrue(r.row_splits.shape.is_compatible_with(splits))
2254
2255    # Default shape invariant; b1 adds new values to rows.
2256    _, r = control_flow_ops.while_loop(c, b1, [i, x])
2257    check_shapes(r, values=[None], splits=[4])
2258
2259    # Default shape invariant; b2 adds new rows (not allowed).
2260    if not context.executing_eagerly():
2261      with self.assertRaises(ValueError):
2262        _, r = control_flow_ops.while_loop(c, b2, [i, x])
2263
2264    # Explicit shape invariant; b1 adds new values to rows.
2265    # (deprecated: use TensorShape instead of RaggedTensorSpec)
2266    _, r = control_flow_ops.while_loop(
2267        c, b1, [i, x],
2268        [i.get_shape(), tensor_shape.TensorShape([None, None])])
2269    check_shapes(r, values=[None], splits=[None])
2270
2271    # Explicit shape invariant; b1 adds new values to rows.
2272    _, r = control_flow_ops.while_loop(
2273        c, b1, [i, x],
2274        [i.get_shape(), ragged_tensor.RaggedTensorSpec([None, None],
2275                                                       dtypes.int32)])
2276    check_shapes(r, values=[None], splits=[None])
2277
2278    # Explicit shape invariant; b2 adds new rows.
2279    _, r = control_flow_ops.while_loop(
2280        c, b2, [i, x],
2281        [i.get_shape(), ragged_tensor.RaggedTensorSpec([None, None],
2282                                                       dtypes.int32)])
2283    check_shapes(r, values=[None], splits=[None])
2284
2285  def testWhileShapeInferenceRaggedTensorRaggedRank2(self):
2286    i = constant_op.constant(0)
2287    x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]],
2288                                     [[], [8, 9, 10]]])
2289    c = lambda i, _: i < 10
2290    def b(i, x):
2291      return [
2292          i + 1,
2293          array_ops.concat([x, x[..., i:i+1]], axis=-1)
2294      ]
2295    _, r = control_flow_ops.while_loop(c, b, [i, x])
2296    self.assertEqual(r.row_splits.shape.as_list(), [3])
2297    self.assertTrue(r.values.row_splits.shape.as_list() in ([6], [None]))
2298    self.assertTrue(r.values.values.shape.as_list() in ([49], [None]))
2299
2300  def testWhileShapeInvariantTensorSpec(self):
2301    i = constant_op.constant(0)
2302    x = constant_op.constant([1])
2303    c = lambda i, _: i < 10
2304    b = lambda i, x: (i + 1, array_ops.stack([x, x]))
2305    shape_invariants = [
2306        tensor_spec.TensorSpec([], dtype=dtypes.int32),
2307        tensor_spec.TensorSpec(None, dtype=dtypes.int32)]
2308    control_flow_ops.while_loop(c, b, [i, x], shape_invariants)
2309
2310  # TODO(b/131265085) Remove this decorator when bug is fixed.
2311  @test_util.build_as_function_and_v1_graph
2312  def testWhileShapeInvariantWrongTypeSpecType(self):
2313    c = lambda i, _: i < 10
2314    b = lambda i, x: (i + 1, x)
2315    i = constant_op.constant(0)
2316    x = sparse_tensor.SparseTensor([[0]], [1.0], [10])
2317    shape_invariants = [
2318        tensor_spec.TensorSpec([], dtype=dtypes.int32),
2319        sparse_tensor.SparseTensorSpec([None])]
2320    control_flow_ops.while_loop(c, b, [i, x], shape_invariants)
2321
2322    x2 = constant_op.constant([1])
2323    with self.assertRaises(TypeError):
2324      control_flow_ops.while_loop(c, b, [i, x2], shape_invariants)
2325
2326    x3 = ragged_factory_ops.constant([[1, 2], [3]])
2327    with self.assertRaises(TypeError):
2328      control_flow_ops.while_loop(c, b, [i, x3], shape_invariants)
2329
2330    i2 = constant_op.constant(0.0)
2331    with self.assertRaises(TypeError):
2332      control_flow_ops.while_loop(c, b, [i2, x], shape_invariants)
2333
2334  # TODO(b/131265085) Remove this decorator when bug is fixed.
2335  @test_util.build_as_function_and_v1_graph
2336  def testWhileShapeInvariantBadType(self):
2337    i = constant_op.constant(0)
2338    x = constant_op.constant([1])
2339    c = lambda i, _: i < 10
2340    b = lambda i, x: (i + 1, x)
2341    with self.assertRaises((ValueError, TypeError)):
2342      control_flow_ops.while_loop(c, b, [i, x], ["foo", "bar"])
2343
2344  def _testNestedWhile_1(self, use_gpu):
2345    with self.cached_session(use_gpu=use_gpu):
2346      n = constant_op.constant(0)
2347
2348      def cpu_sum(s):
2349        c = lambda i, s: math_ops.less(i, 10)
2350
2351        def b(i, s):
2352          i1 = math_ops.add(i, 1)
2353          with ops.device("/cpu:0"):
2354            s1 = math_ops.add(i, s)
2355          return i1, s1
2356
2357        _, r_s = control_flow_ops.while_loop(c, b, [n, s])
2358        return r_s
2359
2360      c = lambda x: math_ops.less(x, 200)
2361      b = lambda x: math_ops.add(x, cpu_sum(n))
2362      r = control_flow_ops.while_loop(c, b, [n])
2363      self.assertEqual(225, self.evaluate(r))
2364
2365  def testNestedWhile_1(self):
2366    self._testNestedWhile_1(use_gpu=False)
2367    self._testNestedWhile_1(use_gpu=True)
2368
2369  def _testNestedWhile_2(self, use_gpu):
2370    # Test the cases that A -> Enter and Exit -> A are partitioned.
2371    with self.cached_session(use_gpu=use_gpu):
2372      s0 = constant_op.constant(2.0)
2373
2374      def inner_loop(s):
2375        c = lambda s: math_ops.less(s, 20.0)
2376
2377        def b(s):
2378          s1 = math_ops.add(s, s)
2379          return s1
2380
2381        r_s = control_flow_ops.while_loop(c, b, [s], parallel_iterations=1)
2382        return r_s
2383
2384      outer_c = lambda x: math_ops.less(x, 3000.0)
2385
2386      def outer_b(x):
2387        x = logging_ops.Print(x, [x])  # Edge "Print -> Enter" is partitioned
2388        x = inner_loop(x)
2389        with ops.device("/cpu:0"):
2390          x = math_ops.square(x)  # Edge "Exit -> Square" is partitioned
2391        return x
2392
2393      r = control_flow_ops.while_loop(
2394          outer_c, outer_b, [s0], parallel_iterations=1)
2395      self.assertEqual(1048576.0, self.evaluate(r))
2396
2397  def testNestedWhile_2(self):
2398    self._testNestedWhile_2(use_gpu=False)
2399    self._testNestedWhile_2(use_gpu=True)
2400
2401  @test_util.run_v1_only("b/120545219")
2402  def testWhileWithControl_1(self):
2403    with self.cached_session():
2404      n = constant_op.constant(0)
2405      r = constant_op.constant(0)
2406      condition = lambda n_, r_: math_ops.less(n_, 10)
2407
2408      def body(n_, r_):
2409        n_ = math_ops.add(n_, 1)
2410        with r_.graph.control_dependencies([r_]):
2411          r_ = constant_op.constant(12)
2412        return [n_, r_]
2413
2414      res = control_flow_ops.while_loop(
2415          condition, body, [n, r], parallel_iterations=1)
2416      self.assertAllEqual(12, res[1])
2417
2418  @test_util.run_deprecated_v1
2419  def testWhileWithControl_2(self):
2420    with self.cached_session():
2421      r = constant_op.constant(0)
2422      condition = lambda r_: math_ops.less(r_, 10)
2423
2424      def body(r_):
2425        with r_.graph.control_dependencies([r_]):
2426          r_ = constant_op.constant(12)
2427        return [r_]
2428
2429      res = control_flow_ops.while_loop(
2430          condition, body, [r], parallel_iterations=1)
2431      self.assertAllEqual(12, self.evaluate(res))
2432
2433  @test_util.run_v1_only("b/120545219")
2434  def testWhileWithControl_3(self):
2435    with self.cached_session() as sess:
2436      b = array_ops.placeholder(dtypes.bool)
2437      c = constant_op.constant(1)
2438      x0 = constant_op.constant(0)
2439      with ops.control_dependencies([b]):
2440        r = control_flow_ops.while_loop(lambda x: x < 10, lambda x: x + c, [x0])
2441      self.assertEqual(10, sess.run(r, {b: True}))
2442
2443  @test_util.run_v1_only("b/120545219")
2444  def testWhileWithControl_4(self):
2445    with self.cached_session() as sess:
2446      b = array_ops.placeholder(dtypes.bool)
2447      c = constant_op.constant(1)
2448      x0 = constant_op.constant(0)
2449      with ops.control_dependencies([b]):
2450        r = control_flow_ops.while_loop(
2451            lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
2452      self.assertEqual(10, sess.run(r, {b: True}))
2453
2454  @test_util.run_v1_only("b/120545219")
2455  def testWhileWithControl_5(self):
2456    with self.cached_session() as sess:
2457      b = array_ops.placeholder(dtypes.bool)
2458      c = constant_op.constant(1)
2459      x0 = constant_op.constant(0)
2460
2461      def body(x):
2462        with ops.control_dependencies([b]):
2463          return x + c
2464
2465      r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0])
2466      self.assertEqual(10, sess.run(r, {b: True}))
2467
2468  def testWhileCondWithControl(self):
2469    # Ensure that no control edges by an outer control dependency context are
2470    # added to nodes inside cond/while contexts.
2471    with self.cached_session() as sess:
2472      const_true = lambda: constant_op.constant(True)
2473      const_false = lambda: constant_op.constant(False)
2474      cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false)
2475      body = lambda i: control_flow_ops.cond(i > 0, lambda: i - 1, lambda: i)
2476
2477      with ops.control_dependencies([control_flow_ops.no_op()]):
2478        loop = control_flow_ops.while_loop(cond, body,
2479                                           (constant_op.constant(5),))
2480      self.assertEqual(0, self.evaluate(loop))
2481
2482  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
2483  @test_util.run_v1_only("b/120545219")
2484  def testWhileCondWithControl_1(self):
2485    with self.cached_session():
2486      v = variable_scope.get_variable(
2487          "v", [], initializer=init_ops.constant_initializer(2))
2488      i0 = constant_op.constant(0)
2489      with ops.control_dependencies([i0]):
2490
2491        def loop_condition(i):
2492          return i < 4
2493
2494        def loop_body(i):
2495          some_cond = control_flow_ops.cond(
2496              constant_op.constant(True),
2497              lambda: state_ops.assign(v, math_ops.square(v)), lambda: v)
2498          with ops.control_dependencies([some_cond]):
2499            return i + 1
2500
2501      r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,))
2502      self.evaluate(variables.global_variables_initializer())
2503      self.assertEqual(4, self.evaluate(r))
2504      self.assertAllClose(65536.0, self.evaluate(v))
2505
2506  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
2507  @test_util.run_v1_only("b/120545219")
2508  def testWhileCondExitControl(self):
2509
2510    with self.cached_session():
2511      v = variables.Variable(1)
2512
2513      def false_branch():
2514        cond = lambda i: i < 100
2515
2516        def body(i):
2517          x = state_ops.assign(v, i)
2518          return x + 1
2519
2520        loop = control_flow_ops.while_loop(cond, body, [0])
2521        # Make sure to handle correctly control edge from Exit to a node.
2522        with ops.control_dependencies([loop]):
2523          return constant_op.constant(6.0)
2524
2525      r = control_flow_ops.cond(
2526          constant_op.constant(False), lambda: constant_op.constant(1.0),
2527          false_branch)
2528      self.evaluate(variables.global_variables_initializer())
2529      self.assertEqual(6.0, self.evaluate(r))
2530      self.assertEqual(99, self.evaluate(v))
2531
2532  def testCondWhile_1(self):
2533
2534    with self.cached_session():
2535      n = ops.convert_to_tensor(0, name="n")
2536      c = lambda x: math_ops.less(x, 10)
2537      b = lambda x: math_ops.add(x, 1)
2538      r = control_flow_ops.cond(
2539          math_ops.less(0, 1), lambda: control_flow_ops.while_loop(c, b, [n]),
2540          lambda: n)
2541      self.assertAllEqual(10, self.evaluate(r))
2542
2543  def testCondWhile_2(self):
2544
2545    with self.cached_session():
2546      n = ops.convert_to_tensor(0)
2547      c = lambda x: math_ops.less(x, 10)
2548      b = lambda x: math_ops.add(x, 1)
2549      r = control_flow_ops.cond(
2550          math_ops.less(1, 0), lambda: math_ops.add(n, 1),
2551          lambda: control_flow_ops.while_loop(c, b, [n]))
2552      self.assertAllEqual(10, self.evaluate(r))
2553
2554  def _testCondWhile_3(self, use_gpu):
2555    with self.cached_session(use_gpu=use_gpu) as sess:
2556      p = array_ops.placeholder(dtypes.bool)
2557      n = constant_op.constant(0.0)
2558
2559      def c(x):
2560        return math_ops.less(x, 10.0)
2561
2562      def b(x):
2563        with ops.device("/cpu:0"):
2564          x1 = math_ops.add(x, 1.0)
2565        return x1
2566
2567      r = control_flow_ops.cond(p,
2568                                lambda: control_flow_ops.while_loop(c, b, [n]),
2569                                lambda: math_ops.multiply(n, 2.0))
2570      r1 = gradients_impl.gradients(r, [n])
2571      self.assertEqual(10., sess.run(r, {p: True}))
2572      self.assertEqual([1.0], sess.run(r1, {p: True}))
2573      self.assertEqual(0.0, sess.run(r, {p: False}))
2574      self.assertEqual([2.0], sess.run(r1, {p: False}))
2575
2576  @test_util.run_deprecated_v1
2577  def testCondWhile_3(self):
2578    self._testCondWhile_3(use_gpu=False)
2579    self._testCondWhile_3(use_gpu=True)
2580
2581  def testWhileCond_1(self):
2582
2583    with self.cached_session():
2584      i = ops.convert_to_tensor(0, name="i")
2585      n = ops.convert_to_tensor(10, name="n")
2586      one = ops.convert_to_tensor(1, name="one")
2587      c = lambda x: math_ops.less(x, n)
2588      # pylint: disable=undefined-variable
2589      # for OSS build
2590      b = lambda x: control_flow_ops.cond(
2591          constant_op.constant(True),
2592          lambda: math_ops.add(x, one), lambda: math_ops.subtract(x, one))
2593      # pylint: enable=undefined-variable
2594      r = control_flow_ops.while_loop(c, b, [i])
2595      self.assertAllEqual(10, self.evaluate(r))
2596
2597  def testWhileCond_2(self):
2598
2599    with self.cached_session():
2600      n = ops.convert_to_tensor(0, name="n")
2601      c = lambda x: math_ops.less(x, 10)
2602      b = lambda x: control_flow_ops.cond(constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n)
2603      r = control_flow_ops.while_loop(c, b, [n])
2604      self.assertAllEqual(10, self.evaluate(r))
2605
2606  def testWhileCond_3(self):
2607
2608    with self.cached_session():
2609      n = ops.convert_to_tensor(0)
2610      c = lambda x: math_ops.less(x, 10)
2611      # pylint: disable=undefined-variable
2612      # for OSS build
2613      b = lambda x: control_flow_ops.cond(math_ops.less(0, 1),
2614                                          lambda: math_ops.add(x, 1),
2615                                          lambda: math_ops.subtract(x, 1))
2616      # pylint: enable=undefined-variable
2617      r = control_flow_ops.while_loop(c, b, [n])
2618      self.assertAllEqual(10, self.evaluate(r))
2619
2620  @test_util.run_deprecated_v1
2621  def testWhileCondGradMultiDevice(self):
2622    config = config_pb2.ConfigProto(device_count={"CPU": 2},
2623                                    allow_soft_placement=True)
2624    with self.cached_session(config=config) as sess:
2625      pred = array_ops.placeholder(dtypes.bool, [])
2626      x_init = constant_op.constant(1.0)
2627
2628      with ops.device("/cpu:0"):
2629        z = control_flow_ops.while_loop(
2630            lambda i, _: i < 3,
2631            lambda i, x: (i + 1, control_flow_ops.cond(
2632                pred, lambda: x * 2.0, lambda: 10.0)),
2633            [0, x_init])
2634
2635      with ops.device("/cpu:1"):
2636        grad = gradients_impl.gradients(z, x_init)[0]
2637
2638      with ops.device("/cpu:0"):
2639        grad_grad = gradients_impl.gradients(grad, x_init)[0]
2640
2641      self.assertEqual(sess.run(grad, {pred: True}), 8.0)
2642      self.assertEqual(sess.run(grad, {pred: False}), 0.0)
2643
2644      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
2645        return
2646
2647      self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
2648      self.assertEqual(sess.run(grad_grad, {pred: False}), 0.0)
2649
2650  # NOTE: It is ok to have parallel_iterations > 1
2651  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2652  @test_util.run_deprecated_v1
2653  def testWhileUpdateVariable_1(self):
2654    with self.cached_session():
2655      select = variables.Variable([3.0, 4.0, 5.0])
2656      n = constant_op.constant(0)
2657
2658      def loop_iterator(j):
2659        return math_ops.less(j, 3)
2660
2661      def loop_body(j):
2662        ns = state_ops.scatter_update(select, j, 10.0)
2663        nj = math_ops.add(j, 1)
2664        op = control_flow_ops.group(ns)
2665        nj = control_flow_ops.with_dependencies([op], nj)
2666        return [nj]
2667
2668      r = control_flow_ops.while_loop(
2669          loop_iterator, loop_body, [n], parallel_iterations=1)
2670      self.evaluate(variables.global_variables_initializer())
2671      self.assertEqual(3, self.evaluate(r))
2672      result = self.evaluate(select)
2673      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
2674
2675  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2676  @test_util.run_v1_only("b/120545219")
2677  def testWhileUpdateVariable_2(self):
2678    with self.cached_session():
2679      select1 = variables.Variable([3.0, 4.0, 5.0])
2680      select2 = variables.Variable([3.0, 4.0, 5.0])
2681      n = constant_op.constant(0)
2682
2683      def loop_iterator(j):
2684        return math_ops.less(j, 3)
2685
2686      def loop_body(j):
2687        ns1 = state_ops.scatter_update(select1, j, 10.0)
2688        ns2 = state_ops.scatter_update(select2, j, 10.0)
2689        nj = math_ops.add(j, 1)
2690        op = control_flow_ops.group(ns1, ns2)
2691        nj = control_flow_ops.with_dependencies([op], nj)
2692        return [nj]
2693
2694      r = control_flow_ops.while_loop(
2695          loop_iterator, loop_body, [n], parallel_iterations=1)
2696      self.evaluate(variables.global_variables_initializer())
2697      self.assertEqual(3, self.evaluate(r))
2698      result1 = self.evaluate(select1)
2699      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1)
2700      result2 = self.evaluate(select2)
2701      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
2702
2703  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2704  @test_util.run_v1_only("b/120545219")
2705  def testWhileUpdateVariable_3(self):
2706    with self.cached_session():
2707      select = variables.Variable([3.0, 4.0, 5.0])
2708      n = constant_op.constant(0)
2709
2710      def loop_iterator(j, _):
2711        return math_ops.less(j, 3)
2712
2713      def loop_body(j, _):
2714        ns = state_ops.scatter_update(select, j, 10.0)
2715        nj = math_ops.add(j, 1)
2716        return [nj, ns]
2717
2718      r = control_flow_ops.while_loop(
2719          loop_iterator,
2720          loop_body, [n, array_ops.identity(select)],
2721          parallel_iterations=1)
2722      self.evaluate(variables.global_variables_initializer())
2723      result = r[1]
2724    self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
2725
2726  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2727  @test_util.run_v1_only("b/120545219")
2728  def testWhileUpdateVariable_4(self):
2729    with self.cached_session():
2730      var_a = variables.Variable(0, name="a")
2731      var_b = variables.Variable(0, name="b")
2732      self.evaluate(variables.global_variables_initializer())
2733
2734      c = constant_op.constant(0, name="c")
2735      asn1 = state_ops.assign_add(var_a, 1, name="a_add")
2736
2737      # Loop condition
2738      def pred(i):
2739        return math_ops.less(i, 10)
2740
2741      # Loop body
2742      def loop_body(i):
2743        asn2 = state_ops.assign_add(var_b, asn1, name="b_add")
2744        with ops.control_dependencies([asn2]):
2745          ni = math_ops.add(i, 1, name="i_add")
2746        return ni
2747
2748      lpa = control_flow_ops.while_loop(
2749          pred, loop_body, [c], parallel_iterations=1)
2750
2751      self.assertEqual(0, self.evaluate(var_b))
2752      self.evaluate(lpa)  # Run the loop
2753      self.assertEqual(10, self.evaluate(var_b))
2754
2755  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2756  @test_util.run_v1_only("b/120545219")
2757  def testWhileUpdateVariable_5(self):
2758    with self.cached_session():
2759      # Create some variables.
2760      var_a = variables.Variable(0, name="a")
2761      var_b = variables.Variable(0, name="b")
2762      self.evaluate(variables.global_variables_initializer())
2763
2764      # Change condition to check var_b
2765      def pred(_):
2766        return math_ops.less(var_b, 10)
2767
2768      # Change body to increment var_b
2769      def loop_body(i):
2770        asn1 = state_ops.assign_add(
2771            var_a, constant_op.constant(1), name="a_add")
2772        asn2 = state_ops.assign_add(
2773            var_b, constant_op.constant(1), name="b_add")
2774        with ops.control_dependencies([asn1, asn2]):
2775          inc_b = array_ops.identity(var_b)
2776        return inc_b
2777
2778      lpa = control_flow_ops.while_loop(
2779          pred, loop_body, [var_b], parallel_iterations=1, name="loop")
2780
2781      self.assertEqual(0, self.evaluate(var_b))
2782      self.evaluate(lpa)  # Run the loop
2783      self.assertEqual(10, self.evaluate(var_a))
2784      self.assertEqual(10, self.evaluate(var_b))
2785
2786  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2787  @test_util.run_v1_only("b/120545219")
2788  def testWhileUpdateVariable_6(self):
2789    with self.cached_session():
2790      # Create some variables.
2791      var_a = variables.Variable(0, name="a")
2792      var_b = variables.Variable(0, name="b")
2793      c = constant_op.constant(0)
2794      self.evaluate(variables.global_variables_initializer())
2795
2796      # Loop condition
2797      def pred(i):
2798        return math_ops.less(i, 10)
2799
2800      # Loop body
2801      def loop_body(i):
2802        asn1 = state_ops.assign_add(var_a, 1, name="a_add")
2803        with ops.control_dependencies([asn1]):
2804          asn2 = state_ops.assign_add(var_b, var_a, name="b_add")
2805        with ops.control_dependencies([asn2]):
2806          ni = math_ops.add(i, 1, name="i_add")
2807          return ni
2808
2809      lpa = control_flow_ops.while_loop(
2810          pred, loop_body, [c], parallel_iterations=1, name="loop")
2811
2812      self.assertEqual(0, self.evaluate(var_b))
2813      self.evaluate(lpa)  # Run the loop
2814      self.assertEqual(55, self.evaluate(var_b))
2815      self.assertEqual(10, self.evaluate(var_a))
2816
2817  @test_util.run_v1_only("b/120545219")
2818  def testWhileQueue_1(self):
2819    with self.cached_session():
2820      q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
2821      i = constant_op.constant(0)
2822
2823      def c(i):
2824        return math_ops.less(i, 10)
2825
2826      def b(i):
2827        ni = math_ops.add(i, 1)
2828        ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni)
2829        return ni
2830
2831      r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
2832      self.assertEqual([10], self.evaluate(r))
2833      for i in xrange(10):
2834        self.assertEqual([i], self.evaluate(q.dequeue()))
2835
2836  @test_util.run_v1_only("b/120545219")
2837  def testWhileTimeOut(self):
2838    run_options = config_pb2.RunOptions(timeout_in_ms=1)
2839    with self.cached_session() as sess:
2840      n = constant_op.constant(0)
2841      c = lambda x: True
2842      b = lambda x: math_ops.add(x, 1)
2843      r = control_flow_ops.while_loop(c, b, [n])
2844      with self.assertRaises(errors_impl.DeadlineExceededError):
2845        sess.run(r, options=run_options)
2846
2847  @test_util.disable_control_flow_v2("b/117119329 (stack)")
2848  @test_util.run_v1_only("b/120545219")
2849  def testWhileStack_1(self):
2850    with self.cached_session():
2851      s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
2852      i = constant_op.constant(0)
2853
2854      def c(i):
2855        return math_ops.less(i, 10)
2856
2857      def b(i):
2858        ni = math_ops.add(i, 1)
2859        ni = control_flow_ops.with_dependencies(
2860            [gen_data_flow_ops.stack_push_v2(s, i)], ni)
2861        return ni
2862
2863      r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
2864
2865      x = constant_op.constant(0)
2866
2867      def c1(i, _):
2868        return math_ops.greater(i, 0)
2869
2870      def b1(i, x):
2871        ni = math_ops.subtract(i, 1)
2872        nx = x + gen_data_flow_ops.stack_pop_v2(s, dtypes.int32)
2873        return [ni, nx]
2874
2875      _, rx = control_flow_ops.while_loop(
2876          c1,
2877          b1, [r, x],
2878          [r.get_shape(), tensor_shape.unknown_shape()],
2879          parallel_iterations=1)
2880      self.assertEqual(45, self.evaluate(rx))
2881
2882  def _testWhileGrad_ColocateGradients(self, colocate):
2883    gpu_dev_name = test.gpu_device_name() if test.is_gpu_available(
2884    ) else "/device:CPU:0"
2885
2886    graph = ops.Graph()
2887    with graph.as_default():
2888      v = constant_op.constant(2.0, name="v")
2889      c = lambda v: math_ops.less(v, 100.0)
2890
2891      def b(x):
2892        with ops.device(gpu_dev_name):
2893          return math_ops.square(x)
2894
2895      loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
2896      r = gradients_impl.gradients(
2897          loop, v, colocate_gradients_with_ops=colocate)[0]
2898
2899    r_ops = graph.get_operations()
2900    r_devices = [(op.name, op.device) for op in r_ops]
2901
2902    self.assertTrue(any("Square" in op.name for op in r_ops))
2903
2904    for (name, dev) in r_devices:
2905      if not colocate and name.endswith("Square"):
2906        # Only forward graph contain gpu in Square device
2907        self.assertTrue(gpu_dev_name in dev)
2908      elif colocate and "Square" in name:
2909        # Forward and backward graphs contain gpu in Square/Square_grad devices
2910        self.assertTrue(gpu_dev_name in dev)
2911      else:
2912        self.assertFalse(gpu_dev_name in dev)
2913
2914    with self.session(graph=graph) as sess:
2915      self.assertAllClose(1024.0, self.evaluate(r))
2916
2917  @test_util.disable_control_flow_v2("b/116351701 (colocation)")
2918  @test_util.run_v1_only("b/120545219")
2919  def testWhileGrad_ColocateGradients(self):
2920    self._testWhileGrad_ColocateGradients(colocate=False)
2921    self._testWhileGrad_ColocateGradients(colocate=True)
2922
2923  @test_util.run_v1_only("b/120545219")
2924  def testWhileGrad_Square(self):
2925    with self.cached_session():
2926      v = constant_op.constant(2.0, name="v")
2927      c = lambda v: math_ops.less(v, 100.0)
2928      b = math_ops.square
2929      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
2930      r = control_flow_ops.cond(math_ops.less(1, 2), lambda: r, lambda: v)
2931
2932      r = gradients_impl.gradients(r, v)[0]
2933      self.assertAllClose(1024.0, self.evaluate(r))
2934
2935  @test_util.run_v1_only("b/120545219")
2936  def testWhileGrad_Shape(self):
2937    with self.cached_session():
2938      x = array_ops.placeholder(dtypes.float32, shape=[None])
2939      v = constant_op.constant([2.0], name="v")
2940      n = constant_op.constant(0, name="n")
2941      c = lambda i, v: math_ops.less(i, 5)
2942      b = lambda i, v: [i + 1, math_ops.multiply(x, v)]
2943      r = control_flow_ops.while_loop(
2944          c,
2945          b, [n, v],
2946          [n.get_shape(), tensor_shape.unknown_shape()],
2947          parallel_iterations=1)
2948
2949      r = gradients_impl.gradients(r[1], x)[0]
2950      self.assertEqual([None], r.get_shape().as_list())
2951      self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
2952
2953  @test_util.run_deprecated_v1
2954  def testWhileGrad_BaseShape(self):
2955    with self.cached_session() as sess:
2956      x = array_ops.placeholder(dtypes.float32, [None])
2957      v0 = constant_op.constant([2.0, 2.0], name="v")
2958      c = lambda v: constant_op.constant(False)
2959      b = lambda v: math_ops.multiply(v, x)
2960      r = control_flow_ops.while_loop(c, b, [v0])
2961      y = math_ops.square(x)
2962
2963      r = gradients_impl.gradients([r, y], x)[0]
2964      self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
2965
2966  @test_util.run_deprecated_v1
2967  @test_util.enable_output_all_intermediates
2968  def testWhileGradAfterSessionRun(self):
2969    v0 = constant_op.constant(2.)
2970    r = control_flow_ops.while_loop(
2971        lambda _: True, lambda v: v * v, [v0], maximum_iterations=3)
2972
2973    self.assertAllEqual(r, 256.)
2974    grad = gradients_impl.gradients(r, v0)[0]
2975    self.assertAllClose(grad, 1024.)
2976
2977  @test_util.run_deprecated_v1
2978  @test_util.enable_output_all_intermediates
2979  def testNestedWhileGradAfterSessionRun(self):
2980    v0 = constant_op.constant(2.)
2981
2982    def body(v):
2983      inner_v0 = constant_op.constant(1.)
2984      return control_flow_ops.while_loop(
2985          lambda _: True, lambda x: x * v, [inner_v0], maximum_iterations=2)
2986
2987    r = control_flow_ops.while_loop(
2988        lambda _: True, body, [v0], maximum_iterations=3)
2989
2990    self.assertAllEqual(r, 256.)
2991    grad = gradients_impl.gradients(r, v0)[0]
2992    self.assertAllClose(grad, 1024.)
2993
2994  @test_util.run_v1_only("b/120545219")
2995  def testWhileGrad_MultipleUses(self):
2996    with self.cached_session():
2997      v = constant_op.constant(2.0, name="v")
2998      c = lambda v: math_ops.less(v, 100.0)
2999      b = math_ops.square
3000      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
3001      r = math_ops.multiply(r, r)
3002
3003      r = gradients_impl.gradients(r, v)[0]
3004      self.assertEqual(524288.0, self.evaluate(r))
3005
3006  @test_util.run_v1_only("b/120545219")
3007  def testWhileGrad_LoopAdd(self):
3008    with self.cached_session():
3009      v = constant_op.constant(2.0, name="v")
3010      c = lambda v: math_ops.less(v, 100.0)
3011      b = math_ops.square
3012      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
3013      r = math_ops.add(r, r)
3014
3015      r = gradients_impl.gradients(r, v)[0]
3016      self.assertAllClose(2048.0, self.evaluate(r))
3017
3018  def _testWhileGrad_Mul(self, use_gpu, p_iters):
3019    with self.cached_session(use_gpu=use_gpu) as sess:
3020      a = constant_op.constant(3.0, name="a")
3021      v = constant_op.constant(2.0, name="v")
3022      c = lambda v: math_ops.less(v, 100.0)
3023      b = lambda v: math_ops.multiply(v, a)
3024      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=p_iters)
3025
3026      grad_a, grad_v = gradients_impl.gradients(r, [a, v])
3027      grad_a_val, grad_v_val = self.evaluate([grad_a, grad_v])
3028      self.assertAllClose(216.0, grad_a_val)
3029      self.assertAllClose(81.0, grad_v_val)
3030
3031  @test_util.run_deprecated_v1
3032  def testWhileGrad_Mul(self):
3033    self._testWhileGrad_Mul(use_gpu=False, p_iters=1)
3034    self._testWhileGrad_Mul(use_gpu=False, p_iters=10)
3035    self._testWhileGrad_Mul(use_gpu=True, p_iters=1)
3036    self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
3037
3038  def testWhileGradInControlDeps(self):
3039
3040    @def_function.function
3041    def f():
3042      x_init = constant_op.constant(2.)
3043      loop_cond = lambda i, x: math_ops.less(i, 2)
3044      loop_body = lambda i, x: [i + 1, x**2]
3045      _, x = control_flow_ops.while_loop(loop_cond, loop_body, [0, x_init])
3046      with ops.control_dependencies([x]):
3047        (grad,) = gradients_impl.gradients(x, x_init)
3048        return grad
3049
3050    self.assertAllEqual(f(), 4. * 2.**3)  # 4 * x_init ^ 3
3051
3052  @test_util.run_deprecated_v1
3053  def testTfFunctionInV1WhileLoop(self):
3054
3055    # This test specifically tests that creating a Const node inside a
3056    # tf.function inside a v1 while_loop while inlining is turned on works.
3057    config = opt_cfg()
3058    assert config.graph_options.optimizer_options.do_function_inlining
3059    with session.Session(config=config):
3060
3061      @def_function.function
3062      def loop_body(i):
3063        # Here we create the const.
3064        return i + 1.
3065
3066      loop_cond = lambda i: True
3067      x = control_flow_ops.while_loop(
3068          loop_cond, loop_body, [0.], maximum_iterations=5)
3069      self.assertAllEqual(x, 5.)
3070
3071  def _testNestedWhileCondWhileGrad(self, use_gpu):
3072
3073    with self.cached_session(use_gpu=use_gpu):
3074      v = constant_op.constant(1.0)
3075
3076      def inner_loop(s):
3077        z = constant_op.constant(0)
3078        c = lambda i, x: math_ops.less(i, 4)
3079        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3080        return control_flow_ops.while_loop(c, b, [z, s])
3081
3082      c = lambda x: math_ops.less(x, 128.0)
3083
3084      def b(x):
3085        return control_flow_ops.cond(
3086            constant_op.constant(True),
3087            lambda: math_ops.square(inner_loop(x)[1]),
3088            lambda: math_ops.multiply(x, 2.0))
3089
3090      r = control_flow_ops.while_loop(c, b, [v])
3091      r = gradients_impl.gradients(r, v)[0]
3092      self.assertAllClose(512.0, self.evaluate(r))
3093
3094  @test_util.run_deprecated_v1
3095  def testNestedWhileCondWhileGrad(self):
3096    self._testNestedWhileCondWhileGrad(use_gpu=False)
3097
3098  @test_util.run_deprecated_v1
3099  def testNestedWhileCondWhileGradGpu(self):
3100    self._testNestedWhileCondWhileGrad(use_gpu=True)
3101
3102  @test_util.run_v1_only("b/120545219")
3103  def testWhileGrad_Variable(self):
3104    with self.cached_session():
3105      a = variables.Variable(3.0)
3106      v = constant_op.constant(2.0, name="v")
3107      c = lambda v: math_ops.less(v, 100.0)
3108      b = lambda v: math_ops.multiply(v, a)
3109      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
3110
3111      r = gradients_impl.gradients(r, a)
3112      self.evaluate(variables.global_variables_initializer())
3113      self.assertAllClose(216.0, r[0])
3114
3115  @test_util.run_deprecated_v1
3116  def testWhileGrad_ResourceVariable(self):
3117    with self.cached_session():
3118      a = resource_variable_ops.ResourceVariable(3.0)
3119      v = constant_op.constant(2.0, name="v")
3120      c = lambda v: math_ops.less(v, 100.0)
3121      b = lambda v: math_ops.multiply(v, a)
3122      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
3123
3124      g = gradients_impl.gradients(r, a)
3125      self.evaluate(variables.global_variables_initializer())
3126      self.assertAllClose(216.0, g[0])
3127
3128  def testWhileGrad_EagerResourceVariable(self):
3129    with context.eager_mode():
3130      a = resource_variable_ops.ResourceVariable(
3131          np.ones([2, 2], dtype=np.float32))
3132      v = constant_op.constant(1.0)
3133
3134      @eager_function.defun
3135      def fn():
3136        r = control_flow_ops.while_loop(
3137            lambda i, _: i < 2,
3138            lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v),
3139            [0, 1.0])[1]
3140        return gradients_impl.gradients(r, [v])[0]
3141
3142      self.assertEqual(self.evaluate(fn()), 32.)
3143
3144  def testWhileGrad_ResourceVarInFunctionCall(self):
3145
3146    @def_function.function
3147    def foo(x, var):
3148      return x + math_ops.reduce_sum(var.sparse_read([1, 3]))
3149
3150    @def_function.function
3151    def bar(var):
3152      r = control_flow_ops.while_loop(
3153          lambda i, _: i < 2,
3154          lambda i, x: (i + 1, foo(x, var)),
3155          [0, 0.0])[1]
3156      return gradients_impl.gradients(r, var)[0]
3157
3158    var = resource_variable_ops.ResourceVariable([1., 2., 3., 4.])
3159    self.evaluate(variables.global_variables_initializer())
3160    grad = self.evaluate(bar(var))
3161    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])
3162
3163  def testWhileGrad_ResourceVarInNestedFunctionCall(self):
3164
3165    @def_function.function
3166    def foo(x, var):
3167      return x + math_ops.reduce_sum(var.sparse_read([1, 3]))
3168
3169    @def_function.function
3170    def foo2(x, var):
3171      return foo(x, var)
3172
3173    @def_function.function
3174    def bar(var):
3175      r = control_flow_ops.while_loop(
3176          lambda i, _: i < 2,
3177          lambda i, x: (i + 1, foo2(x, var)),
3178          [0, 0.0])[1]
3179      return gradients_impl.gradients(r, var)[0]
3180
3181    var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
3182    self.evaluate(variables.global_variables_initializer())
3183    grad = self.evaluate(bar(var))
3184    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])
3185
3186  def testWhileGrad_ResourceVarInLoopInFunctionCall(self):
3187    if test.is_gpu_available():
3188      self.skipTest("b/128635252")
3189
3190    @def_function.function
3191    def foo(x, var):
3192      return control_flow_ops.while_loop(
3193          lambda j, _: j < 3,
3194          lambda j, y: (j + 1,
3195                        y + math_ops.reduce_sum(var.sparse_read([1, 2]))),
3196          [0, x])[1]
3197
3198    @def_function.function
3199    def bar(var):
3200      r = control_flow_ops.while_loop(
3201          lambda i, _: i < 2,
3202          lambda i, x: (i + 1, foo(x, var)),
3203          [0, 0.0])[1]
3204      return gradients_impl.gradients(r, var)[0]
3205
3206    var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
3207    self.evaluate(variables.global_variables_initializer())
3208    grad = self.evaluate(bar(var))
3209    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 6., 6., 0.])
3210
3211  def testWhileCondGrad_ResourceVarInFunctionCall(self):
3212
3213    @def_function.function
3214    def foo(x, var):
3215      return x + var.sparse_read([1])[0]
3216
3217    def body(i, x):
3218      return (i + 1, control_flow_ops.cond(
3219          math_ops.equal(i % 2, 0),
3220          lambda: foo(x, var1),
3221          lambda: foo(x, var2)))
3222
3223    @def_function.function
3224    def bar(var1, var2):
3225      r = control_flow_ops.while_loop(
3226          lambda i, _: i < 4, body, [0, 0.0])
3227      return gradients_impl.gradients(r, [var1, var2])
3228
3229    var1 = resource_variable_ops.ResourceVariable([1., 2., 3.])
3230    var2 = resource_variable_ops.ResourceVariable([4., 5.])
3231    self.evaluate(variables.global_variables_initializer())
3232    grads = self.evaluate(bar(var1, var2))
3233    self.assertAllEqual(gradient_checker_v2._to_numpy(grads[0]), [0., 2., 0.])
3234    self.assertAllEqual(gradient_checker_v2._to_numpy(grads[1]), [0., 2.])
3235
3236  @test_util.run_deprecated_v1
3237  def testWhileGrad_ResourceVarSparseRead(self):
3238    # NOTE(skyewm): this test is interesting because the gradient is the
3239    # aggregation result of IndexedSlices and Tensors.
3240    var = resource_variable_ops.ResourceVariable(np.ones(5),
3241                                                 dtype=dtypes.float32)
3242    r = control_flow_ops.while_loop(
3243        lambda i, _: i < 3,
3244        lambda i, x: (i + 1, x * math_ops.reduce_sum(var.sparse_read([1, 3]))),
3245        [0, constant_op.constant(1.0)])[1]
3246    grad = gradients_impl.gradients(r, var)[0]
3247
3248    self.evaluate(variables.global_variables_initializer())
3249    grad_val = self.evaluate(grad)
3250    arr = gradient_checker_v2._to_numpy(grad_val)
3251    self.assertAllEqual(arr, [0., 12., 0., 12., 0.])
3252
3253  @test_util.run_deprecated_v1
3254  def testWhileGrad_MultiResourceVarSparseRead(self):
3255    # NOTE(skyewm): this test is interesting because the gradient is the
3256    # aggregation result of IndexedSlices and Tensors.
3257    var1 = resource_variable_ops.ResourceVariable(np.ones(5),
3258                                                  dtype=dtypes.float32)
3259    var2 = resource_variable_ops.ResourceVariable(np.ones(3),
3260                                                  dtype=dtypes.float32)
3261    x1_init = constant_op.constant([0., 0.])
3262    x2_init = constant_op.constant(1.)
3263    x3_init = constant_op.constant(1.)
3264
3265    def body(i, unused_x1, x2, x3):
3266      y1 = var1.sparse_read([1, 3])
3267      y2 = x2 * 2
3268      y3 = x3 * math_ops.reduce_sum(var2.sparse_read([0]))
3269      return i + 1, y1, y2, y3
3270
3271    r = control_flow_ops.while_loop(
3272        lambda i, x1, x2, x3: i < 3, body,
3273        [0, x1_init, x2_init, x3_init])[1:]
3274    var1_grad, var2_grad = gradients_impl.gradients(r, [var1, var2])
3275
3276    self.evaluate(variables.global_variables_initializer())
3277    var1_grad_val = self.evaluate(var1_grad)
3278    var2_grad_val = self.evaluate(var2_grad)
3279    self.assertAllEqual(gradient_checker_v2._to_numpy(var1_grad_val),
3280                        [0., 1., 0., 1., 0.])
3281    self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val),
3282                        [3., 0., 0.])
3283
3284  def testWhileGrad_Gather(self):
3285    # NOTE(skyewm): this test is interesting because the gather gradient
3286    # function returns an IndexedSlices.
3287    @tf_function_in_tf2
3288    def fn():
3289      x = constant_op.constant([1., 1., 1., 1., 1.])
3290      y = control_flow_ops.while_loop(
3291          lambda i, _: i < 3,
3292          lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
3293          [0, x[:1]])[1]
3294      z = y * 3.0
3295      grad = gradients_impl.gradients(z, x)[0]
3296      return y, grad
3297    y, grad = fn()
3298    self.assertEqual(self.evaluate(y), 8.)
3299    self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.])
3300
3301  def testWhileGrad_GatherNoFanOut(self):
3302    # NOTE(skyewm): this test is interesting because the gather gradient
3303    # function returns an IndexedSlices.
3304    @tf_function_in_tf2
3305    def fn():
3306      x = constant_op.constant([1., 1., 1., 1., 1.])
3307      y = control_flow_ops.while_loop(
3308          lambda i, _: i < 3,
3309          lambda i, x: (i + 1, array_ops.gather(x, [0])),
3310          [0, x[:1]])[1]
3311      z = y * 3.0
3312      grad = gradients_impl.gradients(z, x)[0]
3313      return y, grad
3314    y, grad = fn()
3315    self.assertEqual(self.evaluate(y), 1.)
3316    self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.])
3317
3318  @test_util.run_v1_only("b/120545219")
3319  def testWhileGradInCond(self):
3320
3321    with self.cached_session():
3322      n = ops.convert_to_tensor(1.0, name="n")
3323      x = array_ops.placeholder(dtypes.float32, shape=None)
3324      c = lambda n: math_ops.less(n, 10.0)
3325      b = lambda n: math_ops.add(n, x)
3326
3327      def fn1():
3328        r = control_flow_ops.while_loop(c, b, [n],
3329                                        [tensor_shape.unknown_shape()])
3330        return gradients_impl.gradients(r, x)[0]
3331
3332      r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
3333      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
3334
3335  @test_util.disable_control_flow_v2("b/116340060")
3336  @test_util.run_v1_only("b/120545219")
3337  def testGradInWhileWrtInitialLoopVal(self):
3338    with self.cached_session():
3339      x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
3340      y = x + 1
3341
3342      def body(i, v):
3343        z = v * 2
3344        return i + 1, gradients_impl.gradients(z, x)[0]
3345
3346      with self.assertRaisesRegex(
3347          ValueError,
3348          "Cannot compute gradient inside while loop with respect to op 'x'. "
3349          "We do not support taking the gradient wrt or through the initial "
3350          "value of a loop variable. Gradients can be computed through "
3351          "loop invariants or wrt the input parameters to the loop body."):
3352        control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
3353
3354  @test_util.run_v1_only("b/120545219")
3355  def testWhileGradInWhile(self):
3356    with self.cached_session():
3357      n = ops.convert_to_tensor(1.0, name="n")
3358      x = array_ops.placeholder(dtypes.float32, shape=None)
3359      c = lambda n: math_ops.less(n, 10.0)
3360      b = lambda n: math_ops.add(n, x)
3361
3362      def b1(n):
3363        r = control_flow_ops.while_loop(c, b, [n],
3364                                        [tensor_shape.unknown_shape()])
3365        return gradients_impl.gradients(r, x)
3366
3367      r = control_flow_ops.while_loop(lambda n: n < 6.0, b1, [n],
3368                                      [tensor_shape.unknown_shape()])
3369      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
3370
3371  @test_util.run_v1_only("b/120545219")
3372  def testCondGradInNestedWhiles(self):
3373
3374    def outer_body(i, x):
3375      _, x = control_flow_ops.while_loop(
3376          lambda j, x: j < 3, inner_body, [0, 0.0])
3377      return i + 1, x
3378
3379    def inner_body(j, x):
3380      y = control_flow_ops.cond(math_ops.less(x, 1), lambda: 2 * x, lambda: x)
3381      return j + 1, gradients_impl.gradients(y, x)[0]
3382
3383    i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0])
3384
3385    with self.cached_session() as sess:
3386      i_val, x_val = self.evaluate([i, x])
3387      self.assertEqual(i_val, 3)
3388      self.assertAllClose(x_val, 1.0)
3389
3390  @test_util.run_gpu_only
3391  def testGpuResourceAccess(self):
3392    with ops.device(test.gpu_device_name()):
3393      var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0))
3394
3395    @def_function.function
3396    def foo():
3397      return control_flow_ops.while_loop(
3398          lambda i, _: i < 3,
3399          lambda i, x: (i + 1, control_flow_ops.cond(
3400              constant_op.constant(True),
3401              lambda: x + var,
3402              lambda: x)),
3403          [0, 0.0])[1]
3404
3405    self.evaluate(variables.global_variables_initializer())
3406    self.assertEqual(self.evaluate(foo()), 9.0)
3407
3408  def testNestedResourceAccess(self):
3409    var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0))
3410
3411    @eager_function.defun
3412    def test_fn():
3413      x = constant_op.constant(0.0)
3414      r = control_flow_ops.while_loop(
3415          # Outer loop condition
3416          lambda i, y: i < 2,
3417          # Outer loop body
3418          lambda i, y: (i + 1, y + control_flow_ops.cond(
3419              constant_op.constant(True),
3420              # True branch
3421              lambda: control_flow_ops.while_loop(
3422                  # Inner loop condition
3423                  lambda j, z: j < 3,
3424                  # Inner loop body
3425                  lambda j, z: (j + 1, z + math_ops.square(var)),
3426                  # Inner initial loop value
3427                  [0, y])[1],
3428              # False branch
3429              lambda: (0.0))),
3430          # Outer initial loop value
3431          [0, x])[1]
3432
3433      grad = gradients_impl.gradients(r, x)[0]
3434      return r, grad
3435
3436    self.evaluate(variables.global_variables_initializer())
3437    r, grad = self.evaluate(test_fn())
3438    # 2 * 3 * 3^2
3439    self.assertEqual(r, 81.0)
3440    # v1 control flow gets the wrong answer!!!
3441    # Gradient computation:
3442    #   f(x) = x + 3^2
3443    #   inner_loop(x) = f(f(f(x))) = x + 3*3^2 = x + 27
3444    #   g(x) = x + inner_loop(x) = 2x + 27
3445    #   outer_loop(x) = g(g(x)) = 4x + 81
3446    #   outer_loop'(x) = 4
3447    # Note that v1 control flow gets 4.0 as well if the cond is removed.
3448    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
3449      self.assertEqual(grad, 4.0)
3450
3451  def testWhile_NestedInput(self):
3452    with self.cached_session() as sess:
3453      named = collections.namedtuple("named", ("a", "b"))
3454      loop_vars = [
3455          named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
3456          (constant_op.constant(2.0), constant_op.constant(3.0)),
3457          constant_op.constant(4.0)
3458      ]
3459      c = lambda lv0, _1, _2: lv0.a < 100.0
3460
3461      def b(lv0, lv1, lv2):
3462        lv0 = named(a=lv0.a + 1, b=lv0.b)
3463        lv1 = (lv1[0] + 1, lv1[1])
3464        lv2 += 2
3465        return [lv0, lv1, lv2]
3466
3467      r = control_flow_ops.while_loop(c, b, loop_vars)
3468
3469      self.assertTrue(isinstance(r, list))
3470      self.assertTrue(isinstance(r[0], named))
3471      self.assertTrue(isinstance(r[1], tuple))
3472      self.assertTrue(isinstance(r[2], ops.Tensor))
3473
3474      r_flattened = nest.flatten(r)
3475      self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
3476                       self.evaluate(r_flattened))
3477
3478  @test_util.run_v1_only("b/120545219")
3479  def testWhile_NestedBadArityFails(self):
3480    with self.cached_session():
3481      named = collections.namedtuple("named", ("a", "b"))
3482      loop_vars = [
3483          named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
3484          (constant_op.constant(2.0), constant_op.constant(3.0)),
3485          constant_op.constant(4.0)
3486      ]
3487      c = lambda lv0, _1, _2: lv0.a < 100.0
3488
3489      def b(lv0, lv1, _):
3490        return [lv0, lv1]
3491
3492      with self.assertRaisesRegex(ValueError, "the same number of elements"):
3493        control_flow_ops.while_loop(c, b, loop_vars)
3494
3495  @test_util.run_v1_only("b/120545219")
3496  def testWhileGrad_ys_xs(self):
3497    with self.cached_session():
3498      x = constant_op.constant(3.0, name="x")
3499      y = constant_op.constant(2.0, name="y")
3500
3501      c = lambda x, y: math_ops.less(x, 100.0)
3502
3503      def b(x, y):
3504        y1 = math_ops.add(x, y)
3505        x1 = math_ops.multiply(x, y1)
3506        return x1, y1
3507
3508      rx, ry = control_flow_ops.while_loop(c, b, [x, y], parallel_iterations=1)
3509
3510      r = gradients_impl.gradients([rx, ry], x)
3511      self.assertAllClose(304.0, r[0])
3512      r = gradients_impl.gradients([rx, ry], y)
3513      self.assertAllClose(124.0, r[0])
3514      r = gradients_impl.gradients([rx], x)
3515      self.assertAllClose(295.0, r[0])
3516      r = gradients_impl.gradients([rx], y)
3517      self.assertAllClose(120.0, r[0])
3518
3519  @test_util.run_deprecated_v1
3520  def testWhileGrad_Dependency(self):
3521    with self.cached_session():
3522      i = constant_op.constant(0, name="i")
3523      x = constant_op.constant(2.0, name="x")
3524
3525      c = lambda i, x: math_ops.less(i, 10)
3526
3527      def b(i, x):
3528        x = math_ops.multiply(x, 2.0)
3529        i = math_ops.add(i, 1)
3530        return i, x
3531
3532      ri, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3533
3534      r = gradients_impl.gradients([ri, rx], x)
3535      self.assertAllClose(1024.0, r[0])
3536      r = gradients_impl.gradients([rx], x)
3537      self.assertAllClose(1024.0, r[0])
3538
3539  @test_util.run_v1_only("b/120545219")
3540  def testWhileGrad_NoGradient(self):
3541    with self.cached_session():
3542      v = constant_op.constant(2.0, name="v")
3543      c = lambda v: math_ops.less(v, 100.0)
3544      b = math_ops.square
3545      r = control_flow_ops.while_loop(c, b, [v], back_prop=False)
3546      r = math_ops.add(r, v)
3547      r = gradients_impl.gradients(r, v)
3548      self.assertAllClose(1.0, r[0])
3549
3550  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
3551  @test_util.run_v1_only("b/120545219")
3552  def testWhileGrad_NoDependency(self):
3553    with self.cached_session() as sess:
3554      variable = variables.Variable(array_ops.ones([2, 3]))
3555      duration = array_ops.zeros([], dtype=dtypes.int32)
3556
3557      def cond(duration, tensor, _):
3558        del tensor
3559        return duration < 10
3560
3561      def body(duration, tensor, _):
3562        return (duration + 1, tensor, tensor)
3563
3564      loop_vars = [duration, variable, variable]
3565      tensors = control_flow_ops.while_loop(
3566          cond=cond, body=body, loop_vars=loop_vars)
3567      cost = math_ops.reduce_sum(tensors[2])
3568      grad = gradients_impl.gradients(cost, [variable])
3569      self.evaluate(variables.global_variables_initializer())
3570      self.assertAllClose(np.ones([2, 3]), sess.run(grad[0]))
3571
3572  @test_util.run_deprecated_v1
3573  def testWhileGrad_Const(self):
3574    with self.cached_session() as sess:
3575      c0 = constant_op.constant(0.0, name="c0")
3576      c1 = constant_op.constant(1.0, name="c1")
3577      duration = constant_op.constant(0, name="t")
3578
3579      def cond(duration, _):
3580        return duration < 1
3581
3582      def body(duration, _):
3583        return duration + 1, c1
3584
3585      loop_vars = [duration, c0]
3586      tensors = control_flow_ops.while_loop(
3587          cond=cond, body=body, loop_vars=loop_vars)
3588      cost = math_ops.reduce_sum(tensors[1])
3589      grad = gradients_impl.gradients(cost, [c0])
3590      self.assertAllClose(0.0, sess.run(grad[0]))
3591
3592  @test_util.run_v1_only("b/120545219")
3593  def testWhileGrad_SerialTwoLoops(self):
3594    with self.cached_session():
3595      i = constant_op.constant(0, name="i")
3596      x = constant_op.constant(2.0, name="x")
3597
3598      c = lambda i, x: math_ops.less(i, 5)
3599
3600      def b(i, x):
3601        x = math_ops.multiply(x, 2.0)
3602        i = math_ops.add(i, 1)
3603        return i, x
3604
3605      _, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3606      _, rx = control_flow_ops.while_loop(c, b, [i, rx], parallel_iterations=1)
3607
3608      r = gradients_impl.gradients([rx], x)
3609      self.assertAllClose(1024.0, r[0])
3610
3611  @test_util.run_v1_only("b/120545219")
3612  def testWhileGrad_ParallelTwoLoops(self):
3613    with self.cached_session():
3614      i = constant_op.constant(0, name="i")
3615      x = constant_op.constant(2.0, name="x")
3616
3617      c = lambda i, x: math_ops.less(i, 5)
3618
3619      def b(i, x):
3620        x = math_ops.multiply(x, 2.0)
3621        i = math_ops.add(i, 1)
3622        return i, x
3623
3624      _, r1 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3625      _, r2 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3626      rx = math_ops.add(r1, r2)
3627
3628      r = gradients_impl.gradients([rx], x)
3629      self.assertAllClose(64.0, r[0])
3630
3631  @test_util.run_v1_only("b/120545219")
3632  def testWhileGrad_OneOutputWithControlDependencyOnSecond(self):
3633    with self.cached_session():
3634      i = constant_op.constant(0, name="i")
3635      x = constant_op.constant(1.0, name="x")
3636      y = constant_op.constant(1.0, name="y")
3637      c = lambda i, *_: math_ops.less(i, 1, name="cond_less")
3638
3639      def b(i, xi, yi):
3640        # return (i + 1, xi, xi + yi)
3641        return (math_ops.add(i, 1, name="inc"), array_ops.identity(
3642            xi, name="xi"), math_ops.add(xi, yi, name="xi_plus_yi"))
3643
3644      _, x_f, y_f = control_flow_ops.while_loop(c, b, [i, x, y])
3645      with ops.control_dependencies([x_f]):
3646        y_f_d = array_ops.identity(y_f, name="y_f_d")
3647
3648      self.assertAllClose(2.0, self.evaluate(y_f_d))  # y_f_d = 1.0 + 1.0
3649      g = gradients_impl.gradients([y_f_d], [x])[0]
3650      self.assertTrue(g is not None)
3651      self.assertAllClose(1.0,
3652                          self.evaluate(g))  # y_f_d = x + 1.0, dy_f_d/dx = 1.0
3653
3654  def _testNestedWhileGrad_Simple(self, use_gpu):
3655    with self.cached_session(use_gpu=use_gpu):
3656      v = constant_op.constant(1.0)
3657
3658      def inner_loop(s):
3659        c = lambda x: math_ops.less(x, 4.0)
3660        b = lambda x: math_ops.multiply(x, 2.0)
3661        return control_flow_ops.while_loop(c, b, [s])
3662
3663      c = lambda x: math_ops.less(x, 2.0)
3664      b = lambda x: math_ops.multiply(inner_loop(x), 2.0)
3665      r = control_flow_ops.while_loop(c, b, [v])
3666
3667      r = gradients_impl.gradients(r, v)[0]
3668      self.assertAllClose(8.0, self.evaluate(r))
3669
3670  @test_util.run_deprecated_v1
3671  def testNestedWhileGrad_Simple(self):
3672    self._testNestedWhileGrad_Simple(use_gpu=False)
3673    self._testNestedWhileGrad_Simple(use_gpu=True)
3674
3675  @test_util.run_v1_only("b/120545219")
3676  def testNestedWhileGrad_SerialInner(self):
3677    with self.cached_session():
3678      v = constant_op.constant(1.0)
3679
3680      def inner_loop1(s):
3681        z = constant_op.constant(0)
3682        c = lambda i, x: math_ops.less(i, 4)
3683        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3684        return control_flow_ops.while_loop(c, b, [z, s])
3685
3686      def inner_loop2(s):
3687        z = constant_op.constant(0)
3688        c = lambda i, x: math_ops.less(i, 4)
3689        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3690        return control_flow_ops.while_loop(c, b, [z, s])
3691
3692      c = lambda x: math_ops.less(x, 128.0)
3693      b = lambda x: inner_loop2(inner_loop1(x)[1])[1]
3694      r = control_flow_ops.while_loop(c, b, [v])
3695
3696      r = gradients_impl.gradients(r, v)[0]
3697      self.assertAllClose(256.0, self.evaluate(r))
3698
3699  @test_util.run_deprecated_v1
3700  def testNestedWhileGrad_ParallelInner(self):
3701    with self.cached_session():
3702      v = constant_op.constant(1.0)
3703
3704      def inner_loop1(s):
3705        z = constant_op.constant(0)
3706        c = lambda i, x: math_ops.less(i, 4)
3707        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3708        return control_flow_ops.while_loop(c, b, [z, s])
3709
3710      def inner_loop2(s):
3711        z = constant_op.constant(0)
3712        c = lambda i, x: math_ops.less(i, 4)
3713        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3714        return control_flow_ops.while_loop(c, b, [z, s])
3715
3716      c = lambda x: math_ops.less(x, 128.0)
3717      b = lambda x: math_ops.multiply(inner_loop1(x)[1], inner_loop2(x)[1])
3718      r = control_flow_ops.while_loop(c, b, [v])
3719
3720      r = gradients_impl.gradients(r, v)[0]
3721      self.assertAllClose(512.0, self.evaluate(r))
3722
3723  @test_util.run_v1_only("b/120545219")
3724  def testNestedWhileGrad_ParallelIterations(self):
3725    # Make sure the stack pushes and pops of an inner loop are executed in
3726    # the sequential order of the iterations of its outer loop.
3727    with self.cached_session() as sess:
3728
3729      def inner_loop(t):
3730        fn = lambda n: n + math_ops.square(var)
3731        return map_fn.map_fn(fn=fn, elems=t, parallel_iterations=10)
3732
3733      def outer_loop(inp):
3734        return map_fn.map_fn(
3735            fn=inner_loop, elems=inp, parallel_iterations=10)
3736
3737      var = variables.Variable(constant_op.constant(3.0))
3738      inp = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
3739      res = outer_loop(inp)
3740      optimizer = adam.AdamOptimizer(learning_rate=0.001)
3741      train_op = optimizer.minimize(math_ops.reduce_mean(math_ops.square(res)))
3742      self.evaluate(variables.global_variables_initializer())
3743      self.evaluate(train_op)
3744      self.assertAllClose(2.999, var.read_value())
3745
3746  def _testWhileCondGrad_Simple(self, use_gpu):
3747    with self.cached_session(use_gpu=use_gpu):
3748      v = ops.convert_to_tensor(2.0, name="v")
3749      n = ops.convert_to_tensor(100.0, name="n")
3750      one = ops.convert_to_tensor(1.0, name="one")
3751      c = lambda x: math_ops.less(x, n)
3752      # pylint: disable=undefined-variable
3753      # for OSS build
3754      b = lambda x: control_flow_ops.cond(constant_op.constant(True),
3755                                          lambda: math_ops.square(x),
3756                                          lambda: math_ops.subtract(x, one))
3757      # pylint: enable=undefined-variable
3758      r = control_flow_ops.while_loop(c, b, [v])
3759      r = gradients_impl.gradients(r, v)[0]
3760      self.assertAllClose(1024.0, self.evaluate(r))
3761
3762  @test_util.run_deprecated_v1
3763  def testWhileCondGrad_Simple(self):
3764    self._testWhileCondGrad_Simple(use_gpu=False)
3765    self._testWhileCondGrad_Simple(use_gpu=True)
3766
3767  @test_util.run_deprecated_v1
3768  def testWhileCondGrad_UnknownShape(self):
3769    with self.cached_session() as sess:
3770      v = array_ops.placeholder(dtypes.float32)
3771      n = ops.convert_to_tensor(100.0, name="n")
3772      one = ops.convert_to_tensor(1.0, name="one")
3773      c = lambda x: math_ops.less(x, n)
3774      # pylint: disable=undefined-variable
3775      # for OSS build
3776      b = lambda x: control_flow_ops.cond(constant_op.constant(True),
3777                                          lambda: math_ops.square(x),
3778                                          lambda: math_ops.subtract(x, one))
3779      # pylint: enable=undefined-variable
3780      r = control_flow_ops.while_loop(c, b, [v])
3781      r = gradients_impl.gradients(r, v)[0]
3782      r = sess.run(r, feed_dict={v: 2.0})
3783      self.assertAllClose(1024.0, r)
3784
3785  @test_util.run_deprecated_v1
3786  def testWhileGrad_Concat(self):
3787    with self.cached_session() as sess:
3788      x = variable_scope.get_variable("x", initializer=[[1., 2.]])
3789      i0 = constant_op.constant(0)
3790      h0 = array_ops.zeros([0, 2])
3791
3792      def condition(i, _):
3793        return i < 2
3794
3795      def body(i, h):
3796        return i + 1, array_ops.concat([h, x], 0)
3797
3798      _, h = control_flow_ops.while_loop(
3799          condition, body, [i0, h0],
3800          [i0.get_shape(), tensor_shape.TensorShape([None, 2])])
3801      s = math_ops.reduce_sum(h)
3802
3803      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
3804      op = optimizer.minimize(s)
3805
3806      self.evaluate(variables.global_variables_initializer())
3807      self.evaluate(op)
3808      self.assertAllClose([[0.98000002, 1.98000002]], self.evaluate(x))
3809
3810  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
3811  @test_util.run_v1_only("b/120545219")
3812  def testWhileWithRefsWithGradients_1(self):
3813    with self.cached_session() as sess:
3814      x = variables.VariableV1(0.)._ref()  # pylint: disable=protected-access
3815      i = constant_op.constant(0)
3816      c = lambda i, x: math_ops.less(i, 10)
3817
3818      self.assertEqual(x.dtype, dtypes.float32_ref)
3819
3820      def body(i, x):
3821        self.assertEqual(x.dtype, dtypes.float32_ref)
3822        return [i + 1, gen_array_ops.ref_identity(x)]
3823
3824      r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
3825
3826      grad_ys = [variables.VariableV1(73)._ref()]  # pylint: disable=protected-access
3827      grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys)
3828
3829      self.evaluate(variables.global_variables_initializer())
3830
3831      self.assertEqual(r[0].dtype, dtypes.int32)
3832      self.assertEqual(r[1].dtype, dtypes.float32_ref)
3833
3834      value_i, value_x, value_x_grad = sess.run(r + grad)
3835
3836    self.assertEqual(10, value_i)
3837    self.assertEqual(0, value_x)
3838    self.assertEqual(73, value_x_grad)
3839
3840  @test_util.deprecated_graph_mode_only
3841  def testWhileGrad_IndexedSlices(self):
3842    with self.cached_session():
3843      values = constant_op.constant([2.0, 4.0], name="values")
3844      indices = constant_op.constant([0, 3], name="indices")
3845      shape = constant_op.constant([10], name="dense_shape")
3846      i = constant_op.constant(0)
3847      x = ops.IndexedSlices(values, indices, dense_shape=shape)
3848
3849      def c(i, _):
3850        return i < 10
3851
3852      def b(i, x):
3853        return [
3854            i + 1,
3855            ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
3856        ]
3857
3858      _, r = control_flow_ops.while_loop(c, b, [i, x])
3859      r = gradients_impl.gradients(r.values, values)[0]
3860      self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
3861
3862  @test_util.deprecated_graph_mode_only
3863  def testWhileGrad_SparseTensor(self):
3864    with self.cached_session():
3865      values = constant_op.constant([2.0, 4.0], name="values")
3866      indices = constant_op.constant(
3867          [[0], [3]], dtype=dtypes.int64, name="indices")
3868      shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
3869      i = constant_op.constant(0)
3870      x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
3871
3872      def c(i, _):
3873        return i < 10
3874
3875      def b(i, x):
3876        return [
3877            i + 1,
3878            sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
3879        ]
3880
3881      _, r = control_flow_ops.while_loop(c, b, [i, x])
3882      r = gradients_impl.gradients(r.values, values)[0]
3883      self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
3884
3885  @test_util.deprecated_graph_mode_only
3886  def testCallGradInLoop(self):
3887    with self.cached_session() as sess:
3888      i0 = constant_op.constant(0)
3889      params = constant_op.constant(5.0)
3890      params_1 = math_ops.square(params)
3891
3892      def c(i, _):
3893        return i < 10
3894
3895      def b(i, x):
3896        data = constant_op.constant([1.0, 2.0, 3.0])
3897        data = math_ops.multiply(data, params_1)
3898        x1 = x + gradients_impl.gradients(data, params)[0]
3899        return i + 1, x1
3900
3901      output_grad = control_flow_ops.while_loop(
3902          c, b, [i0, constant_op.constant(0.0)])
3903      self.assertAllClose(600.0, self.evaluate(output_grad)[1])
3904
3905  @test_util.run_deprecated_v1
3906  def testWhileAndTensorArray(self):
3907    with self.cached_session() as sess:
3908      param = constant_op.constant(2.0)
3909      n0 = constant_op.constant(0)
3910      y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
3911
3912      def c(i, _):
3913        return i < 10
3914
3915      def b(i, y):
3916        return [
3917            i + 1,
3918            map_fn.map_fn(lambda x: math_ops.multiply(x, param), y)
3919        ]
3920
3921      r = control_flow_ops.while_loop(c, b, [n0, y0], parallel_iterations=1)
3922      r = gradients_impl.gradients(r, param)[0]
3923      self.assertAllClose(107520.0, self.evaluate(r))
3924
3925  @test_util.run_deprecated_v1
3926  def testNestedWhileAndTensorArray(self):
3927    n = constant_op.constant(3.0)
3928
3929    def Body(row, ta):
3930
3931      def InnerBody(row, col, ta):
3932        # Note: row and col are 1-based.
3933        ta = ta.write(
3934            math_ops.cast(n * (row - 1.) + col - 1., dtypes.int32), row * col)
3935        return row, col + 1., ta
3936
3937      ta = control_flow_ops.while_loop(
3938          lambda _, col, _1: col <= n,
3939          InnerBody, [row, constant_op.constant(1.), ta],
3940          return_same_structure=False)[2]
3941      return row + 1., ta
3942
3943    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=9)
3944    ta = control_flow_ops.while_loop(
3945        lambda row, _: row <= n,
3946        Body, [constant_op.constant(1.), ta],
3947        return_same_structure=False)[1]
3948
3949    output = array_ops.reshape(ta.stack(), [3, 3])
3950    self.assertAllEqual(
3951        self.evaluate(output), [[1., 2., 3.], [2., 4., 6.], [3., 6., 9.]])
3952    # TODO(b/117675481): This does not work with current TA. Enable with new TA.
3953    # grad = gradients_impl.gradients(output, [n])
3954    # self.assertEqual(self.evaluate(grad), 3.5)
3955
3956  @test_util.run_deprecated_v1
3957  def testWhileGrad_StopGrad(self):
3958    with self.cached_session():
3959      x = constant_op.constant(3.0, name="x")
3960      y = constant_op.constant(2.0, name="y")
3961
3962      c = lambda x, y: math_ops.less(x, 100.0)
3963
3964      def b(x, y):
3965        y1 = math_ops.square(y)
3966        x1 = math_ops.add(math_ops.square(x), y1)
3967        return x1, y1
3968
3969      rx, ry = control_flow_ops.while_loop(c, b, [x, y])
3970
3971      r = gradients_impl.gradients(rx, y)[0]
3972      self.assertEqual(136.0, self.evaluate(r))
3973      r = gradients_impl.gradients(ry, y)[0]
3974      self.assertEqual(32.0, self.evaluate(r))
3975
3976      r = gradients_impl.gradients(array_ops.stop_gradient(rx), y)[0]
3977      self.assertEqual(r, None)
3978      r = gradients_impl.gradients(array_ops.stop_gradient(ry), y)[0]
3979      self.assertEqual(r, None)
3980
3981      r = gradients_impl.gradients(
3982          array_ops.stop_gradient(math_ops.square(rx)), y)[0]
3983      self.assertEqual(r, None)
3984      r = gradients_impl.gradients(
3985          array_ops.stop_gradient(math_ops.add(rx, ry)), x)[0]
3986      self.assertEqual(r, None)
3987      r = gradients_impl.gradients(
3988          array_ops.stop_gradient(math_ops.add(rx, ry)), y)[0]
3989      self.assertEqual(r, None)
3990
3991      r = gradients_impl.gradients(math_ops.add(rx, ry), y)[0]
3992      self.assertEqual(168.0, self.evaluate(r))
3993      r = gradients_impl.gradients(
3994          math_ops.add(rx, array_ops.stop_gradient(ry)), y)[0]
3995      self.assertEqual(136.0, self.evaluate(r))
3996      r = gradients_impl.gradients(
3997          math_ops.add(array_ops.stop_gradient(rx), ry), y)[0]
3998      self.assertEqual(32.0, self.evaluate(r))
3999
4000  @test_util.run_deprecated_v1
4001  def testWhileGrad_StopGradInside(self):
4002    with self.cached_session():
4003      x = constant_op.constant(3.0, name="x")
4004      y = constant_op.constant(2.0, name="y")
4005
4006      c = lambda x, y: math_ops.less(x, 100.0)
4007
4008      def b(x, y):
4009        y1 = array_ops.stop_gradient(math_ops.square(y))
4010        x1 = math_ops.add(math_ops.square(x), y1)
4011        return x1, y1
4012
4013      rx, _ = control_flow_ops.while_loop(c, b, [x, y])
4014
4015      r = gradients_impl.gradients(rx, y)[0]
4016      self.assertAllClose(0.0, self.evaluate(r))
4017      r = gradients_impl.gradients(rx, x)[0]
4018      self.assertAllClose(156.0, self.evaluate(r))
4019
4020  @test_util.run_deprecated_v1
4021  def testWhileGrad_StopGradInsideNoShape(self):
4022    with self.cached_session() as sess:
4023      x = array_ops.placeholder(dtypes.float32)
4024      y = array_ops.placeholder(dtypes.float32)
4025
4026      c = lambda x, y: math_ops.less(math_ops.reduce_sum(x), 100.0)
4027
4028      def b(x, y):
4029        y1 = array_ops.stop_gradient(math_ops.square(y, name="stopped"))
4030        x1 = math_ops.add(math_ops.square(x), y1)
4031        return x1, y1
4032
4033      rx, _ = control_flow_ops.while_loop(c, b, [x, y])
4034
4035      grad_y = gradients_impl.gradients(rx, y)[0]
4036      grad_x = gradients_impl.gradients(rx, x)[0]
4037      feed_dict = {x: [3.0, 4.0], y: [2.0, 3.0]}
4038      self.assertAllClose([0.0, 0.0], sess.run(grad_y, feed_dict=feed_dict))
4039      self.assertAllClose([156.0, 400.0], sess.run(grad_x, feed_dict=feed_dict))
4040      name = "gradients/while/stopped_grad"
4041      all_ops = x.graph.get_operations()
4042      self.assertFalse(any(name in op.name for op in all_ops))
4043
4044  @test_util.run_deprecated_v1
4045  def testWhileGradGradFail(self):
4046    theta = variables.Variable(initial_value=1.)
4047
4048    def fn(prev, x):
4049      return prev + x * theta
4050
4051    result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
4052    grad_theta = gradients_impl.gradients(result, theta)
4053    if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
4054      with self.assertRaisesRegex(TypeError, "Second-order gradient"):
4055        gradients_impl.gradients(grad_theta, theta)
4056    grad_theta_stopped = array_ops.stop_gradient(grad_theta)
4057    gradients_impl.gradients(grad_theta_stopped, theta)
4058
4059  @test_util.run_deprecated_v1
4060  def testStopGradOnWhileGrad(self):
4061    with self.cached_session():
4062      x = constant_op.constant(2.0, name="x")
4063      y = constant_op.constant(2.0, name="y")
4064
4065      c = lambda x: math_ops.less(x, 100.0)
4066      b = lambda x: math_ops.multiply(x, y)
4067      rx = control_flow_ops.while_loop(c, b, [x])
4068
4069      rg = gradients_impl.gradients(rx, y)[0]
4070      rg = array_ops.stop_gradient(rg)
4071      r = math_ops.add(math_ops.square(y), rx)
4072      r = math_ops.add(r, rg)
4073      r = gradients_impl.gradients(r, y)[0]
4074      self.assertEqual(388.0, self.evaluate(r))
4075
4076  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
4077  @test_util.run_deprecated_v1
4078  def testWhileGradientWithNontrainablePath1(self):
4079    q = variables.Variable([7., 8.])
4080
4081    def cond(_, y):
4082      del y
4083      return False
4084
4085    def body(x, _):
4086      return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
4087
4088    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
4089    dy_dq, = gradients_impl.gradients(y, q)
4090    self.assertIsNotNone(dy_dq)
4091    with self.cached_session() as sess:
4092      self.evaluate(q.initializer)
4093      self.assertAllClose([0., 0.], self.evaluate(dy_dq))
4094
4095  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
4096  @test_util.run_v1_only("b/120545219")
4097  def testWhileGradientWithNontrainablePath2(self):
4098    q = variables.Variable([7., 8.])
4099
4100    def cond(_, y):
4101      return math_ops.equal(y, 0.)
4102
4103    def body(x, _):
4104      zero = constant_op.constant(0, dtype=dtypes.int64)
4105      return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
4106
4107    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
4108    dy_dq, = gradients_impl.gradients(y, q)
4109    self.assertIsNotNone(dy_dq)
4110    with self.cached_session() as sess:
4111      self.evaluate(q.initializer)
4112      self.assertAllClose([1., 1.], self.evaluate(dy_dq))
4113
4114  @test_util.run_v1_only("b/120545219")
4115  def testIssue16504(self):
4116    c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
4117    w = variables.Variable(
4118        initial_value=np.ones(100), dtype=dtypes.float32) / 100
4119    k = variables.Variable(0, dtype=dtypes.int32)
4120    chg_w = constant_op.constant(np.inf, dtype=dtypes.float32)
4121
4122    def cond(k, _, chg_w):
4123      return math_ops.logical_and(k < 10, chg_w > 1e-3)
4124
4125    def body(k, w, chg_w):
4126      grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w)
4127      w_n = w * math_ops.exp(-0.1 * grad)
4128      w_n /= math_ops.reduce_sum(w_n)
4129      chg_w = (
4130          math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum(
4131              math_ops.abs(w)))
4132      return k + 1, w_n, chg_w
4133
4134    _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w])
4135    grad, = gradients_impl.gradients(w, c)
4136    self.assertIsNotNone(grad)
4137
4138  @test_util.run_v1_only("b/120545219")
4139  def testStopGradMultiFlows(self):
4140    with self.cached_session():
4141
4142      def body(i, y, r):
4143        x = variable_scope.get_variable(
4144            "x",
4145            shape=(),
4146            dtype=dtypes.float32,
4147            initializer=init_ops.ones_initializer())
4148        y *= x
4149        return [i + 1, y, r + math_ops.reduce_sum(y)]
4150
4151      i0 = constant_op.constant(0)
4152      y0 = array_ops.ones(5)
4153      r0 = constant_op.constant(0.0)
4154      cond = lambda i, y, r: i < 1
4155      _, _, r = control_flow_ops.while_loop(
4156          cond, body, [i0, y0, r0], back_prop=True)
4157
4158      vars_ = variables.global_variables()
4159      grads = linalg_ops.norm(gradients_impl.gradients(r, vars_)[0])
4160      z = math_ops.add(r, array_ops.stop_gradient(math_ops.reduce_sum(grads)))
4161      result = gradients_impl.gradients(z, vars_)[0]
4162      self.evaluate(variables.global_variables_initializer())
4163      self.assertEqual(5.0, self.evaluate(result))
4164
4165  @test_util.run_v1_only("b/120545219")
4166  def testOneValueCond(self):
4167
4168    with self.cached_session():
4169      c = array_ops.placeholder(dtypes.int32, shape=[])
4170      one = ops.convert_to_tensor(1, name="one")
4171      two = ops.convert_to_tensor(2, name="two")
4172      p = math_ops.greater_equal(c, 1)
4173      i = control_flow_ops.cond(p, lambda: one, lambda: two)
4174      self.assertTrue(isinstance(i, ops.Tensor))
4175
4176      # True case: c = 2 is >= 1
4177      self.assertEqual([1], i.eval(feed_dict={c: 2}))
4178
4179      # False case: c = 0 is not >= 1
4180      self.assertEqual([2], i.eval(feed_dict={c: 0}))
4181
4182  @test_util.run_deprecated_v1
4183  def testExampleCond(self):
4184
4185    with self.cached_session():
4186      x = ops.convert_to_tensor([-2.0, 2.0], name="x")
4187      d = array_ops.placeholder(dtypes.int32, shape=[])
4188
4189      def l2():
4190        return math_ops.sqrt(math_ops.reduce_sum(math_ops.square(x)))
4191
4192      def l1():
4193        return math_ops.reduce_sum(math_ops.abs(x))
4194
4195      i = control_flow_ops.cond(math_ops.equal(d, 2), l2, l1)
4196      self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
4197      self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
4198
4199  @test_util.run_v1_only("b/120545219")
4200  def testCase(self):
4201    with self.cached_session():
4202      x = constant_op.constant(1)
4203      y = constant_op.constant(2)
4204      z = constant_op.constant(3)
4205      f1 = lambda: constant_op.constant(17)
4206      f2 = lambda: constant_op.constant(23)
4207      f3 = lambda: constant_op.constant(-1)
4208
4209      r1 = control_flow_ops.case(
4210          {
4211              x < y: f1,
4212              x > z: f2
4213          }, default=f3, exclusive=True)
4214      self.assertAllEqual(r1, 17)
4215
4216      r2 = control_flow_ops.case([(y > z, f1), (y > x, f2)], default=f3)
4217      self.assertAllEqual(r2, 23)
4218
4219      # Duplicate events can happen, first one is selected
4220      r3 = control_flow_ops.case([(x < y, f1), (x < y, f2)], default=f3)
4221      self.assertAllEqual(r3, 17)
4222
4223      # Duplicate events cause an error if exclusive = True
4224      r4 = control_flow_ops.case(
4225          [(x < y, f1), (x < y, f2)], default=f3, exclusive=True)
4226      with self.assertRaisesOpError("Input error:"):
4227        self.evaluate(r4)
4228
4229      # Check that the default is called if none of the others are
4230      r5 = control_flow_ops.case({x > y: f1}, default=f3)
4231      self.assertAllEqual(r5, -1)
4232
4233      ran_once = [False, False, False]
4234
4235      def break_run_twice(ix):
4236
4237        def _break():
4238          ran_once[ix] = True
4239          return constant_op.constant(ix)
4240
4241        return _break
4242
4243      # Should not fail - each conditional gets called exactly once
4244      # except default.  Default gets called twice: once to create an
4245      # empty output and once for the actual cond switch.
4246      r6 = control_flow_ops.case(
4247          [(x < y, break_run_twice(0)), (x > y, break_run_twice(1))],
4248          default=lambda: constant_op.constant(2))
4249
4250      self.assertAllEqual(r6, 0)
4251
4252  @test_util.run_v1_only("b/120545219")
4253  def testCaseSideEffects(self):
4254    with self.cached_session() as sess:
4255      v0 = variables.Variable(-1)
4256      v1 = variables.Variable(-1)
4257      v2 = variables.Variable(-1)
4258
4259      a = lambda: control_flow_ops.with_dependencies([state_ops.assign(v0, 0)], 0)
4260      b = lambda: control_flow_ops.with_dependencies([state_ops.assign(v1, 1)], 1)
4261      c = lambda: control_flow_ops.with_dependencies([state_ops.assign(v2, 2)], 2)
4262
4263      x = constant_op.constant(1)
4264      y = constant_op.constant(2)
4265
4266      r0 = control_flow_ops.case(
4267          ((x < y, a), (x > y, b)), default=c, exclusive=True)
4268      r1 = control_flow_ops.case(
4269          ((x > y, a), (x < y, b)), default=c, exclusive=True)
4270      r2 = control_flow_ops.case(
4271          ((x > y, a), (x > y, b)), default=c, exclusive=True)
4272
4273      self.evaluate(variables.global_variables_initializer())
4274      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
4275      self.assertEqual(2, self.evaluate(r2))
4276      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, -1, 2])
4277
4278      self.evaluate(variables.global_variables_initializer())
4279      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
4280      self.assertEqual(1, self.evaluate(r1))
4281      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, 1, -1])
4282
4283      self.evaluate(variables.global_variables_initializer())
4284      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
4285      self.assertEqual(0, self.evaluate(r0))
4286      self.assertAllEqual(self.evaluate([v0, v1, v2]), [0, -1, -1])
4287
4288  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
4289  @test_util.run_v1_only("b/120545219")
4290  def testOneOpCond(self):
4291    with self.cached_session():
4292      v = variables.Variable(0)
4293      c = ops.convert_to_tensor(0)
4294      one = ops.convert_to_tensor(1)
4295      two = ops.convert_to_tensor(2)
4296      p = math_ops.greater_equal(c, 1)
4297
4298      def a():
4299        return state_ops.assign(v, one)
4300
4301      def b():
4302        return state_ops.assign(v, two)
4303
4304      i = control_flow_ops.cond(p, a, b)
4305      self.assertTrue(isinstance(i, ops.Tensor))
4306      self.evaluate(variables.global_variables_initializer())
4307
4308      self.assertEqual(0, self.evaluate(v))
4309
4310      # True case: c = 2 is >= 1, v is set to 1.
4311      self.assertEqual(1, i.eval(feed_dict={c.name: 2}))
4312      self.assertEqual(1, self.evaluate(v))
4313
4314      # False case: c = 0 is not >= 1, v is set to 2.
4315      self.assertEqual(2, i.eval(feed_dict={c.name: 0}))
4316      self.assertEqual(2, self.evaluate(v))
4317
4318  @test_util.run_v1_only("b/120545219")
4319  def testWithOpsDependencies(self):
4320    with self.cached_session() as sess:
4321      v = variables.VariableV1(0.0)
4322      c = constant_op.constant(10)
4323
4324      # Fetching v directly will result in an uninitialized error
4325      with self.assertRaisesOpError("Attempting to use uninitialized value"):
4326        self.evaluate([c, v])
4327
4328      # Use a control dependency to ensure init_variable is run
4329      # while asking for c
4330      real_v = control_flow_ops.with_dependencies(
4331          name="real_tensor",
4332          output_tensor=v._ref(),  # pylint: disable=protected-access
4333          dependencies=[v.initializer])
4334      c_val, real_v_val = self.evaluate([c, real_v])
4335
4336    # Ensure the result of 'real_c' is the same as 'c'
4337    self.assertAllEqual(10, c_val)
4338
4339    # Ensure that 'v' is initialized
4340    self.assertAllClose(0.0, real_v_val)
4341
4342  @test_util.run_v1_only("b/120545219")
4343  def testWithTensorDependencies(self):
4344    with self.cached_session():
4345      v = variables.VariableV1(0.0)
4346      c1 = constant_op.constant(10)
4347      c2 = constant_op.constant(20)
4348
4349      # c1_with_init_v depends on the init op for v
4350      c1_with_init_v = control_flow_ops.with_dependencies(
4351          name="c1_with_init_v", output_tensor=c1, dependencies=[v.initializer])
4352      # c2_with_c1 depends on the value of c1_with_init_v
4353      c2_with_c1_dep = control_flow_ops.with_dependencies(
4354          name="c2_with_c1_dep",
4355          output_tensor=c2,
4356          dependencies=[c1_with_init_v])
4357
4358      # Fetching v directly will result in an uninitialized error
4359      with self.assertRaisesOpError("Attempting to use uninitialized value"):
4360        self.evaluate(v)
4361
4362      # Get the value of 'c2_with_c1_dep', which should cause 'v'
4363      # to be initialized.
4364      self.assertAllEqual(20, self.evaluate(c2_with_c1_dep))
4365
4366      # Ensure that 'v' is initialized
4367      self.assertAllClose(0.0, self.evaluate(v))
4368
4369  @test_util.run_v1_only("b/120545219")
4370  def testWithIndexedSlicesDependencies(self):
4371    with self.cached_session():
4372      v = variables.VariableV1(
4373          np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
4374      v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
4375      gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices)
4376      v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer],
4377                                                             v_at_1)
4378      gather_v_at_1_after_init = array_ops.gather(v_at_1_after_init.values,
4379                                                  v_at_1_after_init.indices)
4380
4381      # Fetching gather_v_at_1 will result in an uninitialized error
4382      with self.assertRaisesOpError("Attempting to use uninitialized value"):
4383        self.evaluate(gather_v_at_1)
4384
4385      # Getting gather_v_at_1_after_init will work, and initialize v.
4386      self.assertAllEqual([[10.0, 11.0]],
4387                          self.evaluate(gather_v_at_1_after_init))
4388
4389      # Double check that 'v' is initialized
4390      self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
4391                          self.evaluate(v))
4392
4393  def testDependenciesDevice(self):
4394    with ops.Graph().as_default():
4395      # device set on tensor => same device on dep.
4396      with ops.device("/job:ps"):
4397        vd = variables.VariableV1([0.0])
4398      with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
4399      self.assertTrue("/job:ps" in with_vd_dep.device)
4400
4401      # No device set on tensor => no device on dep.
4402      vnod = variables.VariableV1([0.0])
4403      with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
4404                                                         vnod)
4405      self.assertDeviceEqual(None, with_vnod_dep.device)
4406
4407      # device set on tensor, default device on graph => default device on dep.
4408      vdef = variables.VariableV1([0.0], name="vdef")
4409      with ops.device("/job:worker/device:GPU:1"):
4410        with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
4411                                                           vdef)
4412        # The device is empty, but the colocation constraint is set.
4413        self.assertDeviceEqual("", with_vdef_dep.device)
4414        self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups())
4415
4416  @test_util.run_v1_only("b/120545219")
4417  def testGroup(self):
4418    with self.cached_session() as sess:
4419      v1 = variables.VariableV1([0.0])
4420      v2 = variables.VariableV1([1.0])
4421
4422      # Group init1 and init2 and run.
4423      init = control_flow_ops.group(v1.initializer, v2.initializer)
4424      # Fetching v1 directly will result in an uninitialized error
4425      with self.assertRaisesOpError("Attempting to use uninitialized value"):
4426        self.evaluate(v1)
4427
4428      # Runs "init" before fetching v1 and v2.
4429      init.run()
4430      v1_val, v2_val = self.evaluate([v1, v2])
4431
4432    # Ensure that v1 and v2 are initialized
4433    self.assertAllClose([0.0], v1_val)
4434    self.assertAllClose([1.0], v2_val)
4435
4436  @test_util.run_v1_only("b/120545219")
4437  def testGroupEmpty(self):
4438    op = control_flow_ops.group()
4439    self.assertEqual(op.type, "NoOp")
4440    self.assertEqual(op.control_inputs, [])
4441
4442  @test_util.run_deprecated_v1
4443  def testMergeShapes(self):
4444    # All inputs unknown.
4445    p1 = array_ops.placeholder(dtypes.float32)
4446    p2 = array_ops.placeholder(dtypes.float32)
4447    p3 = array_ops.placeholder(dtypes.float32)
4448    m, index = control_flow_ops.merge([p1, p2, p3])
4449    self.assertIs(None, m.get_shape().ndims)
4450    self.assertEqual([], index.get_shape())
4451
4452    # All inputs known with different ranks.
4453    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4454    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2, 3])
4455    m, index = control_flow_ops.merge([p1, p2])
4456    self.assertIs(None, m.get_shape().ndims)
4457    self.assertEqual([], index.get_shape())
4458
4459    # All inputs known with some dimensions different.
4460    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4461    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 1])
4462    m, index = control_flow_ops.merge([p1, p2])
4463    self.assertEqual([None, None], m.get_shape().as_list())
4464    self.assertEqual([], index.get_shape())
4465
4466    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4467    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
4468    m, index = control_flow_ops.merge([p1, p2])
4469    self.assertEqual([None, 2], m.get_shape().as_list())
4470    self.assertEqual([], index.get_shape())
4471
4472    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4473    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 2])
4474    m, index = control_flow_ops.merge([p1, p2])
4475    self.assertEqual([None, 2], m.get_shape().as_list())
4476    self.assertEqual([], index.get_shape())
4477
4478    # All inputs known with same dimensions.
4479    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4480    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4481    m, index = control_flow_ops.merge([p1, p2])
4482    self.assertEqual([1, 2], m.get_shape().as_list())
4483    self.assertEqual([], index.get_shape())
4484
4485    p1 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
4486    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
4487    m, index = control_flow_ops.merge([p1, p2])
4488    self.assertEqual([None, 2], m.get_shape().as_list())
4489    self.assertEqual([], index.get_shape())
4490
4491    p1 = array_ops.placeholder(dtypes.float32, shape=[None, None])
4492    p2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
4493    m, index = control_flow_ops.merge([p1, p2])
4494    self.assertEqual([None, None], m.get_shape().as_list())
4495    self.assertEqual([], index.get_shape())
4496
4497  @test_util.run_v1_only("b/120545219")
4498  def testRefSelect(self):
4499    index = array_ops.placeholder(dtypes.int32)
4500
4501    # All inputs unknown.
4502    p1 = array_ops.placeholder(dtypes.float32)
4503    p2 = array_ops.placeholder(dtypes.float32)
4504    p3 = array_ops.placeholder(dtypes.float32)
4505    v1 = variables.VariableV1(p1, validate_shape=False)
4506    v2 = variables.VariableV1(p2, validate_shape=False)
4507    v3 = variables.VariableV1(p3, validate_shape=False)
4508    self.assertIs(None, v1.get_shape().ndims)
4509    s = control_flow_ops.ref_select(index, [v1, v2, v3])
4510    self.assertIs(None, s.get_shape().ndims)
4511
4512    # All inputs known but different.
4513    v1 = variables.VariableV1([[1, 2]])
4514    v2 = variables.VariableV1([[2], [1]])
4515    s = control_flow_ops.ref_select(index, [v1, v2])
4516    self.assertIs(None, s.get_shape().ndims)
4517
4518    # All inputs known and same.
4519    v1 = variables.VariableV1([[1, 2]])
4520    v2 = variables.VariableV1([[1, 2]])
4521    s = control_flow_ops.ref_select(index, [v1, v2])
4522    self.assertEqual([1, 2], s.get_shape())
4523
4524    # Possibly the same but not guaranteed.
4525    v1 = variables.VariableV1([[1., 2.]])
4526    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
4527    v2 = variables.VariableV1(p2, validate_shape=False)
4528    s = control_flow_ops.ref_select(index, [v1, v2])
4529    self.assertEqual(None, s.get_shape())
4530
4531  @test_util.run_deprecated_v1
4532  def testRunLoopTensor(self):
4533    with self.cached_session() as sess:
4534      tensor_list = []
4535
4536      def condition(t):
4537        return t < constant_op.constant(5)
4538
4539      def body(_):
4540        tensor_list.append(constant_op.constant(5))
4541        return constant_op.constant(10)
4542
4543      result = control_flow_ops.while_loop(condition, body,
4544                                           [constant_op.constant(4)])
4545      self.assertEqual(10, self.evaluate(result))
4546
4547      # Ensure that we cannot run a tensor that escapes the loop body
4548      # accidentally.
4549      with self.assertRaises(ValueError):
4550        sess.run(tensor_list[0])
4551
4552  @test_util.run_v1_only("b/120545219")
4553  def testWhilePyFuncBasic(self):
4554
4555    def func(x):
4556      return np.square(x)
4557
4558    with self.cached_session():
4559      r = control_flow_ops.while_loop(
4560          lambda i, v: i < 4,
4561          lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]],
4562          [constant_op.constant(0), constant_op.constant(2.0, dtypes.float32)],
4563          [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])
4564      self.assertEqual(self.evaluate(r[1]), 65536.0)
4565
4566  @test_util.run_v1_only("b/120545219")
4567  def testWhileFuncBasic(self):
4568
4569    @function.Defun(dtypes.float32)
4570    def func(x):
4571      return math_ops.square(math_ops.square(x))
4572
4573    with self.cached_session():
4574      x = constant_op.constant(2.0, dtypes.float32)
4575      r = control_flow_ops.while_loop(
4576          lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
4577          [constant_op.constant(0), x],
4578          [tensor_shape.unknown_shape(),
4579           tensor_shape.unknown_shape()])
4580      grad = gradients_impl.gradients(r, x)[0]
4581      self.assertEqual(self.evaluate(r[1]), 65536.0)
4582      self.assertEqual(self.evaluate(grad), 524288.0)
4583      # while_v2 does not have stacks.
4584      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
4585        self.assertEqual(
4586            len([op for op in x.graph.get_operations() if op.type == "StackV2"
4587                ]), 1)
4588
4589
4590  @test_util.run_v1_only("b/120545219")
4591  def testQIntSwitchMerge(self):
4592    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
4593      constant_qint = constant_op.constant(np.array([42]), dtypes.qint8)
4594      cond = constant_op.constant(True, dtypes.bool)
4595      v_f, v_t = control_flow_ops.switch(constant_qint, cond)
4596      result = control_flow_ops.merge([v_f, v_t])
4597      self.evaluate(result)
4598
4599  @test_util.run_v1_only("b/120545219")
4600  def testQIntRefSwitchMerge(self):
4601    with self.cached_session(use_gpu=test.is_gpu_available()) as sess:
4602      var_qint = gen_state_ops.variable(
4603          shape=[1], dtype=dtypes.qint8, name="v", container="", shared_name="")
4604      assign_op = state_ops.assign(
4605          var_qint, constant_op.constant(np.array([42]), dtypes.qint8))
4606      self.evaluate(assign_op)
4607
4608      cond = constant_op.constant(True, dtypes.bool)
4609      v_f, v_t = control_flow_ops.ref_switch(var_qint, cond)
4610      result = control_flow_ops.ref_merge([v_f, v_t])
4611      self.evaluate(result)
4612
4613  @test_util.run_v1_only("b/120545219")
4614  def testUInt64SwitchMerge(self):
4615    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
4616      constant_uint64 = constant_op.constant(np.array([42]), dtypes.uint64)
4617      cond = constant_op.constant(True, dtypes.bool)
4618      v_f, v_t = control_flow_ops.switch(constant_uint64, cond)
4619      result = control_flow_ops.merge([v_f, v_t])
4620      self.evaluate(result)
4621
4622  def testSwitchEagerMode(self):
4623    if not context.executing_eagerly():
4624      return
4625    input_data = [1, 2, 3, 4]
4626    vf, vt = control_flow_ops.switch(input_data, False)
4627    self.assertAllEqual(vf, input_data)
4628    self.assertAllEqual(vt, [])
4629
4630  @test_util.run_deprecated_v1
4631  def testQIntArgAndRet(self):
4632
4633    @function.Defun(dtypes.qint8)
4634    def func(x):
4635      return x
4636
4637    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
4638      qint = constant_op.constant(np.array([42]), dtypes.qint8)
4639      result = func(qint)
4640      self.evaluate(result)
4641
4642  def testSparseIdentity(self):
4643    st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10])
4644    st2 = control_flow_ops._Identity(st1)
4645    self.assertAllEqual(st1.indices, st2.indices)
4646    self.assertAllEqual(st1.values, st2.values)
4647    self.assertAllEqual(st1.dense_shape, st2.dense_shape)
4648
4649  def testSparseEnterExit(self):
4650    st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10])
4651    st2 = control_flow_ops._Enter(st1, "foo_1")
4652    st3 = control_flow_ops.exit(st2)
4653    self.assertAllEqual(st1.indices, st3.indices)
4654    self.assertAllEqual(st1.values, st3.values)
4655    self.assertAllEqual(st1.dense_shape, st3.dense_shape)
4656
4657  def _buildWhileWithShapeInvariants(self, shape_invariants):
4658    r = constant_op.constant([1, 2])
4659
4660    def cond(_):
4661      return False
4662
4663    def body(_):
4664      return constant_op.constant([1])
4665
4666    return control_flow_ops.while_loop(
4667        cond, body, [r], shape_invariants=shape_invariants)
4668
4669  def testWhileOutputShapeWithShapeInvariantsUnknownRank(self):
4670    @def_function.function
4671    def runTest():
4672      while_output = self._buildWhileWithShapeInvariants(
4673          [tensor_shape.TensorShape(None)])
4674      self.assertIsNone(while_output.shape.rank)
4675    runTest()
4676
4677  def testWhileOutputShapeWithShapeInvariantsPartialShape(self):
4678    @def_function.function
4679    def runTest():
4680      while_output = self._buildWhileWithShapeInvariants(
4681          [tensor_shape.TensorShape([None])])
4682      self.assertAllEqual(while_output.shape.as_list(), [None])
4683    runTest()
4684
4685  def testFunctionInWhile(self):
4686
4687    @def_function.function
4688    def body(x):
4689      return x + 1
4690
4691    r = control_flow_ops.while_loop(lambda x: x < 5, body, [0])
4692    self.assertAllEqual(r, 5.)
4693
4694
4695class ControlFlowContextCheckTest(test.TestCase):
4696
4697  def _getWhileTensor(self):
4698    """Creates and returns a tensor from a while context."""
4699    tensor = []
4700
4701    def body(i):
4702      if not tensor:
4703        tensor.append(constant_op.constant(1))
4704      return i + tensor[0]
4705
4706    control_flow_ops.while_loop(lambda i: i < 10, body, [0])
4707    return tensor[0]
4708
4709  def _getCondTensor(self):
4710    cond_tensor = []
4711
4712    def true_fn():
4713      if not cond_tensor:
4714        cond_tensor.append(constant_op.constant(1))
4715      return cond_tensor[0]
4716
4717    control_flow_ops.cond(
4718        math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
4719    return cond_tensor[0]
4720
4721  @test_util.run_v1_only("b/120545219")
4722  def testInvalidContext(self):
4723    # Accessing a while loop tensor outside of control flow is illegal.
4724    while_tensor = self._getWhileTensor()
4725    with self.assertRaisesRegex(
4726        ValueError,
4727        "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' "
4728        "is in a while loop. See info log for more details."):
4729      math_ops.add(1, while_tensor)
4730
4731  @test_util.run_v1_only("b/120545219")
4732  def testInvalidContextInCond(self):
4733    # Accessing a while loop tensor in cond is illegal.
4734    while_tensor = self._getWhileTensor()
4735    with self.assertRaisesRegex(
4736        ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because "
4737        "'while/Const_1' is in a while loop. See info log for more details."):
4738      # TODO(skyewm): this passes if we return while_tensor directly instead
4739      # of using it as input to another op.
4740      control_flow_ops.cond(
4741          math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor),
4742          lambda: constant_op.constant(0))
4743
4744  @test_util.run_v1_only("b/120545219")
4745  def testInvalidContextInWhile(self):
4746    # Accessing a while loop tensor in a different while loop is illegal.
4747    while_tensor = self._getWhileTensor()
4748    with self.assertRaisesRegex(
4749        ValueError,
4750        "Cannot use 'while/Const_1' as input to 'while_1/Add' because they are "
4751        "in different while loops. See info log for more details."):
4752      control_flow_ops.while_loop(lambda i: i < 10,
4753                                  lambda x: math_ops.add(1, while_tensor), [0])
4754
4755    with self.assertRaisesRegex(
4756        ValueError,
4757        "Cannot use 'while/Const_1' as input to 'while_2/NextIteration' "
4758        "because they are in different while loops. See info log for more "
4759        "details."):
4760      control_flow_ops.while_loop(lambda i: i < 10, lambda i: while_tensor, [0])
4761
4762  def testValidCondContext(self):
4763    # Accessing a tensor from a cond context is OK (although dangerous).
4764    cond_tensor = self._getCondTensor()
4765    math_ops.add(1, cond_tensor)
4766
4767  def testValidCondContextBranches(self):
4768    # Accessing a tensor from a cond context from the other branch's cond
4769    # context is OK (although dangerous).
4770    cond_tensor = []
4771
4772    def branch_fn():
4773      if not cond_tensor:
4774        cond_tensor.append(constant_op.constant(1))
4775      return cond_tensor[0]
4776
4777    control_flow_ops.cond(math_ops.less(1, 2), branch_fn, branch_fn)
4778
4779  @test_util.run_v1_only("b/120545219")
4780  def testValidWhileContext(self):
4781    # Accessing a tensor in a nested while is OK.
4782    def body(_):
4783      c = constant_op.constant(1)
4784      return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + c, [0])
4785
4786    control_flow_ops.while_loop(lambda i: i < 5, body, [0])
4787
4788  @test_util.run_v1_only("b/120545219")
4789  def testValidNestedContexts(self):
4790    # Accessing a tensor from a cond context in a while context, all inside an
4791    # outer while context, is OK.
4792    def body(_):
4793      cond_tensor = self._getCondTensor()
4794      # Create another cond containing the while loop for good measure
4795      return control_flow_ops.cond(
4796          math_ops.less(1, 2),
4797          lambda: control_flow_ops.while_loop(lambda i: i < 3,
4798                                              lambda i: i + cond_tensor, [0]),
4799          lambda: constant_op.constant(0))
4800
4801    control_flow_ops.while_loop(lambda i: i < 5, body, [0])
4802
4803  @test_util.run_v1_only("b/120545219")
4804  def testInvalidNestedContexts(self):
4805    # Accessing a tensor from a while context in a different while context, all
4806    # inside a cond context, is illegal.
4807    def true_fn():
4808      while_tensor = self._getWhileTensor()
4809      return control_flow_ops.while_loop(lambda i: i < 3,
4810                                         lambda i: i + while_tensor, [0])
4811
4812    with self.assertRaisesRegex(
4813        ValueError,
4814        "Cannot use 'cond/while/Const_1' as input to 'cond/while_1/add' because"
4815        " they are in different while loops. See info log for more details."):
4816      control_flow_ops.cond(
4817          math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
4818
4819
4820class TupleTest(test.TestCase):
4821
4822  @test_util.run_v1_only("b/120545219")
4823  def testTensors(self):
4824    for v1_first in [True, False]:
4825      with self.cached_session():
4826        v1 = variables.VariableV1([1.0])
4827        add1 = math_ops.add(
4828            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
4829            2.0)
4830        v2 = variables.VariableV1([10.0])
4831        add2 = math_ops.add(
4832            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
4833            20.0)
4834        t1, _, t2 = control_flow_ops.tuple([add1, None, add2])
4835
4836        # v1 is not initialized.
4837        with self.assertRaisesOpError("Attempting to use uninitialized value"):
4838          self.evaluate(v1)
4839
4840        # v2 is not initialized.
4841        with self.assertRaisesOpError("Attempting to use uninitialized value"):
4842          self.evaluate(v2)
4843
4844        if v1_first:
4845          # Getting t1 initializes v2.
4846          self.assertAllClose([3.0], self.evaluate(t1))
4847          self.assertAllClose([10.0], self.evaluate(v2))
4848        else:
4849          # Getting t2 initializes v1.
4850          self.assertAllClose([30.0], self.evaluate(t2))
4851          self.assertAllClose([1.0], self.evaluate(v1))
4852
4853  @test_util.run_v1_only("b/120545219")
4854  def testIndexedSlices(self):
4855    for v1_first in [True, False]:
4856      with self.cached_session():
4857        v1 = variables.VariableV1(
4858            np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
4859                np.float32))
4860        v1_at_1 = ops.IndexedSlices(
4861            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
4862            constant_op.constant([1]))
4863
4864        v2 = variables.VariableV1(
4865            np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
4866                np.float32))
4867        v2_at_1 = ops.IndexedSlices(
4868            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
4869            constant_op.constant([1]))
4870
4871        st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])
4872        g1 = array_ops.gather(st1.values, st1.indices)
4873        g2 = array_ops.gather(st2.values, st2.indices)
4874
4875        # v1 is not initialized.
4876        with self.assertRaisesOpError("Attempting to use uninitialized value"):
4877          self.evaluate(v1)
4878
4879        # v2 is not initialized.
4880        with self.assertRaisesOpError("Attempting to use uninitialized value"):
4881          self.evaluate(v2)
4882
4883        if v1_first:
4884          # Getting g1 initializes v2.
4885          self.assertAllClose([[10.0, 11.0]], self.evaluate(g1))
4886          self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]],
4887                              self.evaluate(v2))
4888        else:
4889          # Getting g2 initializes v1.
4890          self.assertAllClose([[10.1, 11.1]], self.evaluate(g2))
4891          self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
4892                              self.evaluate(v1))
4893
4894  def testAcceptTensorsAsControlInputs(self):
4895    with self.cached_session():
4896      var = variables.VariableV1(0)
4897      assign = state_ops.assign(var, 1)
4898      t, = control_flow_ops.tuple(
4899          [constant_op.constant(0)], control_inputs=[assign])
4900
4901      # Should trigger the assign.
4902      self.evaluate(t)
4903
4904      self.assertEqual(1, self.evaluate(var))
4905
4906
4907class AssertTest(test.TestCase):
4908
4909  @test_util.run_deprecated_v1
4910  def testGuardedAssertDoesNotCopyWhenTrue(self):
4911    if test_util.is_gpu_available():
4912      self.skipTest("b/128646478 fails in opensource")
4913
4914    with self.session() as sess:
4915      with ops.device(test.gpu_device_name()):
4916        value = constant_op.constant(1.0)
4917      with ops.device("/cpu:0"):
4918        true = constant_op.constant(True)
4919        guarded_assert = control_flow_ops.Assert(true, [value], name="guarded")
4920        unguarded_assert = gen_logging_ops._assert(
4921            true, [value], name="unguarded")
4922      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
4923      guarded_metadata = config_pb2.RunMetadata()
4924      sess.run(guarded_assert, options=opts, run_metadata=guarded_metadata)
4925      unguarded_metadata = config_pb2.RunMetadata()
4926      sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata)
4927      guarded_nodestat_names = [
4928          n.node_name
4929          for d in guarded_metadata.step_stats.dev_stats
4930          for n in d.node_stats
4931      ]
4932      unguarded_nodestat_names = [
4933          n.node_name
4934          for d in unguarded_metadata.step_stats.dev_stats
4935          for n in d.node_stats
4936      ]
4937      guarded_memcpy_nodestat_names = [
4938          n for n in guarded_nodestat_names if "MEMCPYDtoH" in n
4939      ]
4940      unguarded_memcpy_nodestat_names = [
4941          n for n in unguarded_nodestat_names if "MEMCPYDtoH" in n
4942      ]
4943      if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
4944        # A copy was performed for the unguarded assert
4945        self.assertLess(0, len(unguarded_memcpy_nodestat_names),
4946                        str(unguarded_nodestat_names))
4947      # No copy was performed for the guarded assert
4948      self.assertEqual([], guarded_memcpy_nodestat_names)
4949
4950
4951class WhileOpBenchmark(test.Benchmark):
4952  """Evaluate the performance of while_loop op."""
4953
4954  def _getInitVariables(self):
4955    batch_size = 10
4956    image_size = 256
4957    kernel_size = 3
4958    depth = 16
4959
4960    init_step = constant_op.constant(-1)
4961    image = variable_scope.get_variable(
4962        "image",
4963        initializer=random_ops.random_normal(
4964            [batch_size, image_size, image_size, depth],
4965            dtype=dtypes.float32,
4966            stddev=1e-1))
4967    kernel = variable_scope.get_variable(
4968        "weights",
4969        initializer=random_ops.truncated_normal(
4970            [kernel_size, kernel_size, depth, depth],
4971            dtype=dtypes.float32,
4972            stddev=1e-1))
4973    return init_step, image, kernel
4974
4975  def _runOneBenchmark(self,
4976                       default_device,
4977                       num_iters=10,
4978                       static_unroll=False,
4979                       steps=10):
4980    """Evaluate the while loop performance.
4981
4982    Args:
4983      default_device: The default device to run all ops except the loop_body.
4984        loop_body is always run on GPU.
4985      num_iters: Number of iterations to run.
4986      static_unroll: If true, run unrolled version; otherwise, run while_loop.
4987      steps: Total number of repeated steps to run the loop.
4988
4989    Returns:
4990      The duration of the run in seconds.
4991    """
4992
4993    def loop_body(i, x):
4994      with ops.device("/gpu:0"):
4995        # Always put loop body on GPU.
4996        nx = nn_ops.conv2d(
4997            input=x,
4998            filter=kernel,
4999            strides=[1, 1, 1, 1],
5000            padding="SAME",
5001            data_format="NHWC",
5002            name="conv2d")
5003        ni = math_ops.add(i, 1)
5004        return ni, nx
5005
5006    ops.reset_default_graph()
5007    with session.Session() as sess, ops.device(default_device):
5008      # Get the initial id i, input x, and kernel.
5009      i, x, kernel = self._getInitVariables()
5010      self.evaluate(variables.global_variables_initializer())
5011
5012      if static_unroll:
5013        for _ in xrange(steps):
5014          i, x = loop_body(i, x)
5015      else:
5016        i, x = control_flow_ops.while_loop(
5017            lambda i, _: i < steps,
5018            loop_body, [i, x],
5019            parallel_iterations=steps,
5020            swap_memory=True)
5021
5022      r = math_ops.reduce_sum(x)
5023      dx, dk = gradients_impl.gradients(r, [x, kernel])
5024      # Use group to avoid fetching back results.
5025      r = control_flow_ops.group(dx, dk)
5026
5027      for _ in xrange(3):
5028        # exclude warm up time
5029        self.evaluate(r)
5030
5031      start_time = time.time()
5032      for _ in xrange(num_iters):
5033        self.evaluate(r)
5034      return (time.time() - start_time) / num_iters
5035
5036  def benchmarkWhileOpCrossDevicePlacement(self):
5037    iters = 10
5038    # Run loop body on GPU, but other ops on CPU.
5039    duration = self._runOneBenchmark("cpu", iters, static_unroll=False)
5040    self.report_benchmark(
5041        name="while_op_cross_device", iters=iters, wall_time=duration)
5042
5043  def benchmarkWhileOpSameDevicePlacement(self):
5044    iters = 10
5045    # Run all ops on the same GPU device.
5046    duration = self._runOneBenchmark("gpu", iters, static_unroll=False)
5047    self.report_benchmark(
5048        name="while_op_same_device", iters=iters, wall_time=duration)
5049
5050  def benchmarkWhileOpUnrollCrossDevicePlacement(self):
5051    iters = 10
5052    # Run loop body on GPU, but other ops on CPU.
5053    duration = self._runOneBenchmark("cpu", iters, static_unroll=True)
5054    self.report_benchmark(
5055        name="unroll_cross_device_cpu", iters=iters, wall_time=duration)
5056
5057  def benchmarkWhileOpUnrollSameDevicePlacement(self):
5058    iters = 10
5059    # Run all ops on GPU.
5060    duration = self._runOneBenchmark("gpu", iters, static_unroll=True)
5061    self.report_benchmark(
5062        name="unroll_same_device", iters=iters, wall_time=duration)
5063
5064
5065@test_util.with_control_flow_v2
5066class EagerTest(test.TestCase):
5067
5068  def testCond(self):
5069    with context.eager_mode():
5070      pred = math_ops.less(1, 2)
5071      fn1 = lambda: [constant_op.constant(10)]
5072      fn2 = lambda: [constant_op.constant(20)]
5073      r = control_flow_ops.cond(pred, fn1, fn2)
5074
5075      self.assertAllEqual(r.numpy(), 10)
5076      self.assertFalse(isinstance(r, list))
5077
5078  # TODO(b/117279927): Re-enable once msan failure is fixed.
5079  def DISABLED_testCondInDefun(self):
5080    with context.eager_mode():
5081
5082      @eager_function.defun
5083      def foo(pred):
5084        # TODO(b/111124878): this only needs to output one element.
5085        fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
5086        fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
5087        return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)
5088
5089      r = foo(True)
5090      self.assertAllEqual(r[0].numpy(), 10)
5091      self.assertNotIsInstance(r, list)
5092
5093      r = foo(False)
5094      self.assertAllEqual(r[0].numpy(), 20)
5095      self.assertFalse(isinstance(r, list))
5096
5097  def testWhileLoop(self):
5098    with context.eager_mode():
5099      tensor = constant_op.constant([1, 2, 3, 4, 5])
5100      self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50])
5101
5102  def testWhileLoopWithMaxIterations(self):
5103    with context.eager_mode():
5104      tensor = constant_op.constant([1, 2, 3, 4, 5])
5105      self.assertAllEqual(
5106          isum(tensor, maximum_iterations=3).numpy(),
5107          [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3])
5108
5109  @test_util.run_v1_only("b/120545219")
5110  def testWhileWithMaximumIterationsAndSingleArgument(self):
5111    with context.eager_mode():
5112      tensor = constant_op.constant(0)
5113      r = control_flow_ops.while_loop(
5114          lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1)
5115      self.assertEqual(1, r.numpy())
5116
5117  def testWithDependencies(self):
5118    with context.eager_mode():
5119      t1 = constant_op.constant(1)
5120      t2 = constant_op.constant(2)
5121      t3 = control_flow_ops.with_dependencies(t1, t2)
5122      self.assertAllEqual(t2.numpy(), t3.numpy())
5123
5124  def testTuple(self):
5125    with context.eager_mode():
5126      t1 = constant_op.constant(1)
5127      t2 = constant_op.constant(2)
5128      tup1, tup2 = control_flow_ops.tuple([t1, t2])
5129      self.assertAllEqual(t1.numpy(), tup1.numpy())
5130      self.assertAllEqual(t2.numpy(), tup2.numpy())
5131
5132  @test_util.run_v1_only("b/120545219")
5133  def testCase(self):
5134    with context.eager_mode():
5135      x = constant_op.constant(1)
5136      y = constant_op.constant(2)
5137      z = constant_op.constant(3)
5138      f1 = lambda: constant_op.constant(17)
5139      f2 = lambda: constant_op.constant(23)
5140      f3 = lambda: constant_op.constant(-1)
5141
5142      r1 = control_flow_ops.case(
5143          [(x < y, f1), (x > z, f2)], default=f3, exclusive=True)
5144      self.assertAllEqual(r1.numpy(), 17)
5145
5146
5147if __name__ == "__main__":
5148  test.main()
5149