# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for while_v2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
from tensorflow.python.platform import test

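# Module-level helpers that take a `shape` argument; they are used as
# parameters of the shape-inference tests below (testRandomOpsShape,
# testFillOpsShape).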
def random_gamma(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(shape, 1.0)


def random_gamma_with_alpha_beta(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(
      shape, alpha=[[1.], [3.], [5.], [6.]], beta=[[3., 4.]])


def random_poisson_v2(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, 1.0)


def random_poisson_v2_with_lam(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, [12.2, 3.3])


def fill(shape):  # pylint: disable=invalid-name
  return array_ops.fill(shape, 1.0)


class WhileV2Test(test.TestCase, parameterized.TestCase):
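  """Tests for the while_v2 implementation of tf.while_loop."""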

  @test_util.run_deprecated_v1
  def testSingleLoopVar(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testSingleLoopVarBackPropFalse(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8.,
        lambda v: v * v, [x],
        return_same_structure=False,
        back_prop=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(grad, [None])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)

  @test_util.run_deprecated_v1
  def testCustomGradient(self):
    x = constant_op.constant(2.)
    n = constant_op.constant(1., name="const-n")
    m = variables.Variable(1.0)
    self.evaluate(variables.global_variables_initializer())

    def body_fn(v):  # pylint: disable=invalid-name

      @custom_gradient.custom_gradient
      def inner_fn(v):  # pylint: disable=invalid-name

        def grad_fn(dy, variables=None):  # pylint: disable=invalid-name, unused-argument, redefined-outer-name
          return dy * 2 * v * n * m, [v * v]

        return v * v * m, grad_fn

      return inner_fn(v)

    ret = while_loop_v2(
        lambda v: v < 8., body_fn, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_v1_only("b/120545219")
  def testReturnSameStructureTrue(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=True)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session() as sess:
      eval_result = sess.run(ret)
      self.assertIsInstance(eval_result, list)
      self.assertLen(eval_result, 1)
      self.assertEqual(16., eval_result[0])
      self.assertSequenceEqual(sess.run(grad), [32.])

  def testVerifyInputOutputTypesMatch(self):

    @def_function.function
    def BuildWhile():
      x = constant_op.constant(1., dtypes.float32)

      def Body(x):
        return math_ops.cast(x, dtypes.float16) + 1

      while_loop_v2(lambda x: x < 10, Body, [x])

    with self.assertRaisesRegexp(
        TypeError,
        r"Loop var Const:0 enters the loop with type <dtype: 'float32'> "
        r"but has type <dtype: 'float16'> after 1 iteration."):
      BuildWhile()

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def testGradientTapeResourceVariable(self, dtype):
    with context.eager_mode():
      v = variables.Variable(1., dtype=dtype)

      @def_function.function
      def fnWithLoop():  # pylint: disable=invalid-name
        with backprop.GradientTape() as tape:
          _, x = while_loop_v2(
              lambda i, _: i < 2,
              lambda i, x: (i + 1, x * v),
              [0, constant_op.constant(2., dtype=dtype)])
        return tape.gradient(x, v)

      self.assertAllEqual(fnWithLoop(), 4.0)

  def testExternalControlDependencies(self):
    with ops.Graph().as_default(), self.test_session():
      v = variables.Variable(1.)
      v.initializer.run()
      op = v.assign_add(1.)

      def body_fn(i):  # pylint: disable=invalid-name
        with ops.control_dependencies([op]):
          return i + 1

      loop = while_loop_v2(lambda i: i < 1, body_fn, [0])
      loop[0].op.run()
      self.assertAllEqual(self.evaluate(v), 2.0)

  @test_util.run_deprecated_v1
  def testMultipleLoopVarsBasic(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret = [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [y**2]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopNonscalarCond(self):
    x = constant_op.constant([[5.]])
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret == [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [y**2]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopVars(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    #   y = x + y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, v + w), [x, y],
        return_same_structure=False)
    # ret = [y*x**2 + x*y**2, x*y + x + y]

    gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
    gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + y + 1]
    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
      self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
      self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
      self.assertSequenceEqual(self.evaluate(gradx_2), [43.])
      self.assertSequenceEqual(self.evaluate(grady_0), [55.])
      self.assertSequenceEqual(self.evaluate(grady_1), [6.])
      self.assertSequenceEqual(self.evaluate(grady_2), [61.])

  @test_util.run_deprecated_v1
  def testGradientTape(self):
    with backprop.GradientTape() as t:
      x = constant_op.constant(2.)
      t.watch(x)
      ret = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
    grad = t.gradient(ret, x)
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(grad), 4.0)

  @test_util.run_deprecated_v1
  def testMultipleWhileLoops(self):
    x = constant_op.constant(2.)
    ret1 = while_loop_v2(
        lambda v: v < 4., lambda v: v * v, [x],
        return_same_structure=False)  # x**2
    ret2 = while_loop_v2(
        lambda v: v < 16., lambda v: v * v, [ret1],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  def testMultipleWhileLoopsWithFunc(self):
    x = constant_op.constant(2.)

    @def_function.function
    def Fn():
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_2")  # x**4
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "StatelessWhile")
    self.assertEqual(while_2.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertEmpty(while_2.control_inputs)

  def testMultipleWhileLoopsGradStateless(self):

    @def_function.function
    def Fn():
      x = constant_op.constant(2.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        ret1 = while_loop_v2(
            lambda v: v < 4.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_1")  # x**2
        ret2 = while_loop_v2(
            lambda v: v < 16.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_2")  # x**4
        loss = ret1 + ret2
      return tape.gradient(loss, x)

    graph = Fn.get_concrete_function().graph
    while_ops = [op for op in graph.get_operations() if "While" in op.type]
    self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
                        "Must have exactly 4 StatelessWhile ops.")
    for op in while_ops:
      self.assertEmpty(op.control_inputs,
                       "{} should not have any control inputs".format(op.name))

  def testMultipleWhileLoopsWithDeps(self):
    x = variables.Variable(2.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * x, [c],
          return_same_structure=False,
          name="while_1")  # 2x
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * x * x, [c],
          return_same_structure=False,
          name="while_2")  # 4x
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)

  def testMultipleWhileLoopsWithVarsDeps(self):
    x1 = variables.Variable(2.)
    x2 = variables.Variable(3.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * x1, [c],
          return_same_structure=False,
          name="while_1")  # 2x
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * x1 * x1, [c],
          return_same_structure=False,
          name="while_2")  # 4x
      ret3 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * x2, [c],
          return_same_structure=False,
          name="while_3")  # 3x
      ret4 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * x2 * x2, [c],
          return_same_structure=False,
          name="while_4")  # 9x
      ret5 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [c],
          return_same_structure=False,
          name="while_stateless")  # x**2
      return ret1, ret2, ret3, ret4, ret5

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    while_3 = concrete_fn.graph.get_operation_by_name("while_3")
    while_4 = concrete_fn.graph.get_operation_by_name("while_4")
    while_stateless = concrete_fn.graph.get_operation_by_name(
        "while_stateless")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEqual(while_3.type, "While")
    self.assertEqual(while_4.type, "While")
    self.assertEqual(while_stateless.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)
    self.assertEmpty(while_3.control_inputs)
    self.assertLen(while_4.control_inputs, 1)
    self.assertIs(while_4.control_inputs[0], while_3)
    self.assertEmpty(while_stateless.control_inputs)

  @test_util.run_deprecated_v1
  def testDoubleDerivative(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v**2, [x],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  @test_util.run_v2_only
  def testMultipleWhileLoopsEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret1 = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [ret1],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret2, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return grad, grad_grad

    grad, grad_grad = Func()
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  @test_util.run_v2_only
  def testDoubleDerivativeEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret = while_loop_v2(
          lambda v: v < 8., lambda v: v**2, [x],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return ret, grad, grad_grad

    ret, grad, grad_grad = Func()
    self.assertEqual(ret.numpy(), 16.)
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  def _testPruning(self):
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.run_deprecated_v1
  def testPruningV1(self):
    self._testPruning()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testPruningV2(self):
    self._testPruning()

  def _testDoNotAccumulateInvariants(self):
    push_op = ("TensorListPushBack"
               if control_flow_v2_toggles.control_flow_v2_enabled() else
               "StackPushV2")

    # Tests that loop invariants, i.e., tensors that are "captured" by the
    # while loop and not passed as loop variables, are not accumulated in
    # gradient computation.
    v = constant_op.constant(5.0, name="v")

    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)

    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    # The gradient for v * x requires the value of both v and x. Since v is a
    # loop invariant it is not accumulated so we have just one accumulator for
    # x.
    self.assertLen([n for n in g.node if n.op == push_op], 1)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateInvariantsV1(self):
    self._testDoNotAccumulateInvariants()

  @test_util.run_deprecated_v1
  @test_util.enable_control_flow_v2
  def testDoNotAccumulateInvariantsV2(self):
    self._testDoNotAccumulateInvariants()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    x = constant_op.constant(0)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 25

    def Body(x, tl):

      def InnerCond(inner_x, unused_outer_x, unused_tl):
        return inner_x < 5

      def InnerBody(inner_x, outer_x, tl):
        return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)

      inner_x = constant_op.constant(0)
      return control_flow_ops.while_loop(InnerCond, InnerBody,
                                         [inner_x, x, tl])[1:]

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    # enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
    self.assertEmpty([n for n in g.node if n.op == "_While"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    # enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested2(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    p = array_ops.placeholder(dtype=dtypes.int32)

    def MidBodyBuilder(iterations):

      def MidBody(i, x):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            lambda i, x: (i + 1, math_ops.multiply(v, x, name="my_mul")),
            (0, x),
            maximum_iterations=iterations,
            name="inner")
        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

      return MidBody

    def OuterBody(i, x):
      iterations = array_ops.size(p, name="iterations")
      return (i + 1, x + control_flow_ops.while_loop(
          lambda *_: True,
          MidBodyBuilder(iterations), (0, x),
          maximum_iterations=iterations,
          name="mid")[1])

    def CreateWhileLoop():
      with ops.device("/cpu:0"):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            OuterBody, (0, 1.0),
            maximum_iterations=5,
            name="outer")
        return array_ops.identity(r[1])

    output = CreateWhileLoop()
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested3(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    def CreateWhileLoop():
      r = control_flow_ops.while_loop(
          lambda _: True,
          lambda x: math_ops.multiply(v, x, name="my_mul"), [1.0],
          maximum_iterations=5,
          name="outer")
      return array_ops.identity(r)

    r = CreateWhileLoop()
    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  def _assertNotAccumulated(self, while_op, index):
    """Asserts that `while_op` input at `index` is not accumulated."""
    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    placeholder = body_graph.inputs[index]
    self.assertNotIn("TensorListPushBack",
                     [op.type for op in placeholder.consumers()])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopCounterAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    self._assertNotAccumulated(while_op, 0)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopInvariantAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE

    def GetInputIndex(op, tensor):
      for index, inp in enumerate(op.inputs):
        if inp is tensor:
          return index

    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    # We can't directly use while_op.inputs.index() because Tensors are not
    # hashable.
    index = GetInputIndex(while_op, v)
    self._assertNotAccumulated(while_op, index)

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInCond(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(1.)
    ret = while_loop_v2(
        lambda v: v + y < 9.,
        lambda v: v * 3., [x],
        return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInBody(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(3.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * y, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testLoopWithTensorListPushBack(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      tl = list_ops.tensor_list_push_back(tl, x)
      tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)
    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testDuplicateAccumulator(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)

    for op in ops.get_default_graph().get_operations():
      if op.type == "While" or op.type == "StatelessWhile":
        while_op = op

    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
    x_input_t = body_graph.inputs[x_input_index]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @parameterized.named_parameters(
      ("UnknownShape", None),
      ("PartiallyDefinedShape", [None, 2]),
      ("FullyDefinedShape", [1, 2]),
  )
  @test_util.run_deprecated_v1
  def testAccumulatorElementShape(self, shape):

    def MatchShape(actual_tensor_shape):
      # Compare the shapes, treating None dimensions as equal. We do not
      # directly check actual_tensor_shape and tf.TensorShape(shape) for
      # equality because tf.Dimension.__eq__ returns None if either dimension is
      # None.
      if shape is None:
        self.assertIsNone(actual_tensor_shape.dims)
      else:
        self.assertListEqual(actual_tensor_shape.as_list(), shape)

    def GetAccumulatorForInputAtIndex(while_op, idx):
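      # Finds the TensorListPushBack that accumulates the body input at `idx`
      # and returns the corresponding TensorList output of `while_op`.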
      body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
      y_input_t = body_graph.inputs[idx]
      push_back_node = [c for c in y_input_t.consumers()
                        if c.type == "TensorListPushBack"][0]
      output_idx = body_graph.outputs.index(push_back_node.outputs[0])
      return while_op.outputs[output_idx]

    x = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
    y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

    # Forward pass.
    ret = while_loop_v2(lambda v, u: v < 8.,
                        lambda v, u: (math_ops.pow(v, u), u),
                        [x, y],
                        return_same_structure=True)
    while_op = ret[0].op.inputs[0].op
    # Gradient pass.
    grad = gradients_impl.gradients(ret[0], x)
    # Note: There is an Identity between grad[0] and the While op.
    grad_while_op = grad[0].op.inputs[0].op

    # Get the TensorList output of the While op containing the accumulated
    # values of x.
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if x == inp][0]
    output = GetAccumulatorForInputAtIndex(while_op, x_input_index)
    _, val = list_ops.tensor_list_pop_back(output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

    # Take second derivative to generate intermediate grad_while_op outputs
    gradients_impl.gradients(grad, x)

    # Get the TensorList output of gradient While op containing the accumulated
    # values of grad_x (note that grad_x is needed by the second derivative).
    # grad_while_op.inputs:
    grad_output_index = grad_while_op.outputs.index(grad[0].op.inputs[0])
    grad_output = GetAccumulatorForInputAtIndex(grad_while_op,
                                                grad_output_index)
    _, val = list_ops.tensor_list_pop_back(grad_output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

  def _createWhile(self, name):
    """Helper function for testDefaultName."""
    output = while_v2.while_loop(
        lambda i: i < 3,
        lambda i: i + 1, [constant_op.constant(0)],
        return_same_structure=False)
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    return while_op

  def testDefaultName(self):
    with ops.Graph().as_default():
      while_op = self._createWhile(None)
      self.assertEqual(while_op.name, "while")
      self.assertRegexpMatches(
          while_op.get_attr("cond").name, r"while_cond_\d*")
      self.assertRegexpMatches(
          while_op.get_attr("body").name, r"while_body_\d*")

    with ops.Graph().as_default():
      with ops.name_scope("foo"):
        while1_op = self._createWhile("")
        self.assertEqual(while1_op.name, "foo/while")
        self.assertRegexpMatches(
            while1_op.get_attr("cond").name, r"foo_while_cond_\d*")
        self.assertRegexpMatches(
            while1_op.get_attr("body").name, r"foo_while_body_\d*")

        while2_op = self._createWhile(None)
        self.assertEqual(while2_op.name, "foo/while_1")
        self.assertRegexpMatches(
            while2_op.get_attr("cond").name, r"foo_while_1_cond_\d*")
        self.assertRegexpMatches(
            while2_op.get_attr("body").name, r"foo_while_1_body_\d*")

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testWhileAndTensorArray(self):
    param = constant_op.constant(2.0)
    y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
    # map_fn uses TensorArray internally.
    r = map_fn.map_fn(lambda x: math_ops.multiply(x, param), y0)
    grad = gradients_impl.gradients(r, param)[0]
    self.assertAllClose([2.0, 4.0, 6.0, 8.0, 10.0, 12.0], self.evaluate(r))
    self.assertAllClose(21.0, self.evaluate(grad))

  @test_util.run_deprecated_v1
  def testNestedWhile(self):
    # Compute sum of geometric progression: n^0 + n^1 + ... + n^m
    # We compute the pow using a while loop.
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)
      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          lambda c, v: (c - 1., v * n), [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testNestedWhileWithLegacyDefun(self):
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)

      def InnerBodyWrapper(c, v):

        @function.Defun(dtypes.float32, dtypes.float32)
        def InnerBody(c, v):
          return c - 1., v * n

        results = InnerBody(c, v)
        results[0].set_shape([])
        results[1].set_shape([])
        return results

      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          InnerBodyWrapper, [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testIdentityNodeInBody(self):

    def Body(v):
      v = array_ops.identity(v)
      v = array_ops.identity(v)
      return v * v

    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., Body, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(self.evaluate(ret), 16.)
    self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testForwardPassRewrite(self):
    x = constant_op.constant(1.0, name="x")
    output = while_v2.while_loop(lambda x: x < 10.0,
                                 lambda x: x * 2.0,
                                 [x])[0]
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    # outputs = [loop_counter, max_iters, x]
    self.assertLen(while_op.outputs, 3)

    gradients_impl.gradients(output, x)
    # while_op should have been rewritten to output intermediates.
    # outputs = [loop_counter, max_iters, x, x_accumulator]
    self.assertLen(while_op.outputs, 4)

    gradients_impl.gradients(output, x)
    # Computing the gradient again shouldn't rewrite while_op again.
    self.assertLen(while_op.outputs, 4)

  @parameterized.named_parameters(
      ("RandomUniform", random_ops.random_uniform, [5, 3]),
      ("RandomNormal", random_ops.random_normal, [5, 3]),
      ("ParameterizedTruncatedNormal",
       random_ops.parameterized_truncated_normal, [5, 3]),
      ("TruncatedNormal", random_ops.truncated_normal, [5, 3]),
      ("RandomGamma", random_gamma, [5, 3]),
      ("RandomPoissonV2", random_poisson_v2, [5, 3]),
      ("RandomGammaWithAlphaBeta", random_gamma_with_alpha_beta, [5, 3, 4, 2]),
      ("RandomPoissonV2WithLam", random_poisson_v2_with_lam, [5, 3, 2]),
  )
  @test_util.run_deprecated_v1
  def testRandomOpsShape(self, random_fn, expected_shape):
    shape = constant_op.constant([3])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = random_fn(shape_extended)
      assert u.shape.as_list() == expected_shape, str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros(expected_shape, dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testReshapeShape(self):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = array_ops.reshape(u, [-1])
      assert u.shape.as_list() == [60], str(u.shape.as_list())
      u = array_ops.reshape(u, shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @parameterized.named_parameters(
      ("Zeros", array_ops.zeros),
      ("Ones", array_ops.ones),
      ("Fill", fill),
  )
  @test_util.run_deprecated_v1
  def testFillOpsShape(self, fill_fn):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = fill_fn(shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testExternalColocationGrad(self):
    external_t = constant_op.constant(2.)
    v0 = constant_op.constant(2.)

    def Body(v):
      with ops.colocate_with(external_t):
        return v * v

    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    grad = gradients_impl.gradients(ret, [v0])[0]
    self.assertAllEqual(ret, 16.)
    self.assertAllEqual(grad, 32.)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateConstNodes(self):

    def Body(v):
      return v * 2.0

    v0 = constant_op.constant(2.)
    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    # Gradient computation has the side effect of updating the forward op,
    # which is what we want to test.
    unused_grad = gradients_impl.gradients(ret, [v0])[0]
    # ret is separated from the `While` op by an `Identity` so we skip over
    # that.
    forward_while_op = ret.op.inputs[0].op
    body_graph = while_v2._get_graph(forward_while_op, "body", "_body_graph")
    push_back_nodes = [
        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
    ]
    # Gradient of `Mul` requires accumulating both its inputs. But since one
    # of those is a Const (2.0), we should have just one accumulator.
    self.assertLen(push_back_nodes, 1)

  def testDoNotAccumulateForwardTensorsForReductionOps(self):

    @def_function.function
    def Fn():
      with backprop.GradientTape() as tape:
        x = constant_op.constant(2.)
        tape.watch(x)

        def Body(i, x):
          forward_graph = ops.get_default_graph()

          @custom_gradient.custom_gradient
          def SquaredWithZeroGrad(x):

            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
              del variables
              gradient_graph = ops.get_default_graph()
              shape = gen_array_ops.shape(x)
              assert shape.graph is forward_graph
              rank = gen_array_ops.rank(x)
              assert rank.graph is forward_graph
              size = gen_array_ops.size(x)
              assert size.graph is forward_graph
              zeros = array_ops.zeros(shape)
              assert zeros.graph is gradient_graph
              return zeros

            return x * 2, Grad

          return i + 1, SquaredWithZeroGrad(x)

        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
      grad = tape.gradient(result, x)
      return grad

    Fn()


def ScalarShape():
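  """Returns an empty 1-D int32 tensor, used as a scalar element_shape."""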
  return ops.convert_to_tensor([], dtype=dtypes.int32)


def GetOptimizedGraph():
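  """Runs Grappler (constant folding off) on the default graph's GraphDef."""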
  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.CopyFrom(
      rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
  return tf_optimizer.OptimizeGraph(config, mg)


if __name__ == "__main__":
  test.main()