1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for while_v2."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22
23from google.protobuf import text_format
24from tensorflow.core.framework import graph_pb2
25from tensorflow.core.protobuf import config_pb2
26from tensorflow.core.protobuf import rewriter_config_pb2
27from tensorflow.python.eager import backprop
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import function
33from tensorflow.python.framework import importer
34from tensorflow.python.framework import meta_graph
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import tensor_shape
37from tensorflow.python.framework import test_util
38from tensorflow.python.grappler import tf_optimizer
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import control_flow_ops
41from tensorflow.python.ops import control_flow_util
42from tensorflow.python.ops import control_flow_util_v2
43from tensorflow.python.ops import control_flow_v2_toggles
44from tensorflow.python.ops import custom_gradient
45from tensorflow.python.ops import gen_array_ops
46from tensorflow.python.ops import gen_list_ops
47from tensorflow.python.ops import gradient_checker_v2
48from tensorflow.python.ops import gradients_impl
49from tensorflow.python.ops import list_ops
50from tensorflow.python.ops import map_fn
51from tensorflow.python.ops import math_ops
52from tensorflow.python.ops import random_ops
53from tensorflow.python.ops import variables
54from tensorflow.python.ops import while_v2
55from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
56from tensorflow.python.platform import test
57
58
59def random_gamma(shape):  # pylint: disable=invalid-name
60  return random_ops.random_gamma(shape, 1.0)
61
62
63def random_gamma_with_alpha_beta(shape):  # pylint: disable=invalid-name
64  return random_ops.random_gamma(
65      shape, alpha=[[1.], [3.], [5.], [6.]], beta=[[3., 4.]])
66
67
68def random_poisson_v2(shape):  # pylint: disable=invalid-name
69  return random_ops.random_poisson_v2(shape, 1.0)
70
71
72def random_poisson_v2_with_lam(shape):  # pylint: disable=invalid-name
73  return random_ops.random_poisson_v2(shape, [12.2, 3.3])
74
75
76def fill(shape):  # pylint: disable=invalid-name
77  return array_ops.fill(shape, 1.0)
78
79
80class WhileV2Test(test.TestCase, parameterized.TestCase):
81
82  @test_util.run_deprecated_v1
83  def testSingleLoopVar(self):
84    x = constant_op.constant(2.)
85    ret = while_loop_v2(
86        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
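    # Starting from x == 2, the body squares v twice (2 -> 4 -> 16), so ret == x**4
    # and d(ret)/dx == 4 * x**3 == 32.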
87    grad = gradients_impl.gradients(ret, [x])
88    with self.cached_session():
89      self.assertEqual(self.evaluate(ret), 16.)
90      self.assertSequenceEqual(self.evaluate(grad), [32.])
91
92  @test_util.run_deprecated_v1
93  def testSingleLoopVarBackPropFalse(self):
94    x = constant_op.constant(2.)
95    ret = while_loop_v2(
96        lambda v: v < 8.,
97        lambda v: v * v, [x],
98        return_same_structure=False,
99        back_prop=False)
100    grad = gradients_impl.gradients(ret, [x])
101    self.assertEqual(grad, [None])
102    with self.cached_session():
103      self.assertEqual(self.evaluate(ret), 16.)
104
105  @test_util.run_deprecated_v1
106  def testCustomGradient(self):
107    x = constant_op.constant(2.)
108    n = constant_op.constant(1., name="const-n")
109    m = variables.Variable(1.0)
110    self.evaluate(variables.global_variables_initializer())
111
112    def body_fn(v):  # pylint: disable=invalid-name
113
114      @custom_gradient.custom_gradient
115      def inner_fn(v):  # pylint: disable=invalid-name
116
117        def grad_fn(dy, variables=None):  # pylint: disable=invalid-name, unused-argument, redefined-outer-name
118          return dy * 2 * v * n * m, [v * v]
119
120        return v * v * m, grad_fn
121
122      return inner_fn(v)
123
124    ret = while_loop_v2(
125        lambda v: v < 8., body_fn, [x], return_same_structure=False)
126    grad = gradients_impl.gradients(ret, [x])
127    with self.cached_session():
128      self.assertEqual(self.evaluate(ret), 16.)
129      self.assertSequenceEqual(self.evaluate(grad), [32.])
130
131  @test_util.run_v1_only("b/120545219")
132  def testReturnSameStructureTrue(self):
133    x = constant_op.constant(2.)
134    ret = while_loop_v2(
135        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=True)
136    grad = gradients_impl.gradients(ret, [x])
137    with self.cached_session() as sess:
138      eval_result = sess.run(ret)
139      self.assertIsInstance(eval_result, list)
140      self.assertLen(eval_result, 1)
141      self.assertEqual(16., eval_result[0])
142      self.assertSequenceEqual(sess.run(grad), [32.])
143
144  def testVerifyInputOutputTypesMatch(self):
145
146    @def_function.function
147    def BuildWhile():
148      x = constant_op.constant(1., dtypes.float32)
149
150      def Body(x):
151        return math_ops.cast(x, dtypes.float16) + 1
152
153      while_loop_v2(lambda x: x < 10, Body, [x])
154
155    with self.assertRaisesRegex(
156        TypeError,
157        r"Loop var Const:0 enters the loop with type <dtype: 'float32'> "
158        r"but has type <dtype: 'float16'> after 1 iteration."):
159      BuildWhile()
160
161  @parameterized.parameters(dtypes.float32, dtypes.float64)
162  def testGradientTapeResourceVariable(self, dtype):
163    with context.eager_mode():
164      v = variables.Variable(1., dtype=dtype)
165
166      @def_function.function
167      def fnWithLoop():  # pylint: disable=invalid-name
168        with backprop.GradientTape() as tape:
169          _, x = while_loop_v2(
170              lambda i, _: i < 2,
171              lambda i, x: (i + 1, x * v),
172              [0, constant_op.constant(2., dtype=dtype)])
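        # After two iterations x == 2 * v**2, so tape.gradient(x, v) == 4 * v == 4.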
173        return tape.gradient(x, v)
174
175      self.assertAllEqual(fnWithLoop(), 4.0)
176
177  def checkIteratedGradients(self, func):
178    with context.eager_mode():
179
180      def _Grad(f):
181        def _GradFunction(primal):
182          with backprop.GradientTape() as tape:
183            tape.watch(primal)
184            primal_out = f(primal)
185          return tape.gradient(primal_out, primal)
186        return _GradFunction
187
188      f = func
189      one = constant_op.constant(1.)
190
191      for _ in range(3):
192        theoretical, numerical = gradient_checker_v2.compute_gradient(
193            def_function.function(f), [one])
194        self.assertAllClose(theoretical, numerical, rtol=1e-3)
195        f = _Grad(f)
196        self.assertAllClose(array_ops.reshape(numerical, []),
197                            def_function.function(f)(one),
198                            rtol=1e-3)
199
200  def testIteratedGradients(self):
201
202    def _Func(x):
203      _, z = while_loop_v2(
204          lambda i, _: i < 2,
205          lambda i, y: (i + 1, math_ops.cos(y)),
206          [0, x])
207      return z
208
209    self.checkIteratedGradients(_Func)
210
211  def testIteratedGradientsWithList(self):
212
213    def _Func(x):
214      results = list_ops.empty_tensor_list(
215          element_shape=[], element_dtype=dtypes.float32)
216
217      def _LoopBody(i, y, handle):
218        return (i + 1, math_ops.cos(y),
219                list_ops.tensor_list_push_back(handle, y))
220
221      _, z, results = while_loop_v2(
222          lambda i, _, h: i < 2, _LoopBody, [0, x, results])
223      return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
224          results, dtypes.float32))
225
226    self.checkIteratedGradients(_Func)
227
228  def testGradWhileGradWhileWithVariable(self):
229    with context.eager_mode():
230      v = variables.Variable(1.)
231
232      @def_function.function
233      def _Func(x):
234
235        def _Inner(a):
236          with backprop.GradientTape() as tape:
237            tape.watch(a)
238            _, b = while_loop_v2(
239                lambda i, _: i < 2,
240                lambda i, y: (i + 1, math_ops.cos(v + y)),
241                [0, a])
242          return tape.gradient(b, a)
243
244        _, z = while_loop_v2(
245            lambda i, _: i < 2,
246            lambda i, y: (i + 1, _Inner(y)),
247            [0, x])
248        return z
249
250      with backprop.GradientTape(persistent=True) as tape:
251        x = constant_op.constant(1.)
252        tape.watch(x)
253        y = _Func(x)
254      dx, _ = tape.gradient(y, [x, v])
255      theoretical, numerical = gradient_checker_v2.compute_gradient(
256          _Func, [x])
257      self.assertAllClose(numerical, theoretical, rtol=1e-3)
258      self.assertAllClose(array_ops.reshape(numerical, []),
259                          dx, rtol=1e-3)
260
261  def testThreeNestWithLists(self):
262    with context.eager_mode():
263      def _WrapInWhile(f):
264        def _Wrapped(x):
265          results = list_ops.empty_tensor_list(
266              element_shape=[], element_dtype=dtypes.float32)
267
268          def _LoopBody(i, y, handle):
269            return (i + 1, f(math_ops.cos(y)),
270                    list_ops.tensor_list_push_back(handle, y))
271
272          _, z, results = control_flow_ops.while_loop(
273              lambda i, _, h: i < 2, _LoopBody, [0, x, results])
274          return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
275              results, dtypes.float32))
276        return _Wrapped
277
278      f = math_ops.sin
279
280      target_function = _WrapInWhile(_WrapInWhile(_WrapInWhile(f)))
281
282      @def_function.function
283      def _TapeFromGraphMode(x):
284        with backprop.GradientTape(persistent=True) as tape:
285          tape.watch(x)
286          y = target_function(x)
287        return tape.gradient(y, x)
288
289      x = constant_op.constant(1.)
290      dx = _TapeFromGraphMode(x)
291      theoretical, numerical = gradient_checker_v2.compute_gradient(
292          target_function, [x])
293      self.assertAllClose(numerical, theoretical, rtol=3e-3)
294      self.assertAllClose(array_ops.reshape(numerical, []), dx, rtol=3e-3)
295
296  def testDeviceLabelsInherited(self):
297    def _LoopBody(i, y):
298      result = math_ops.cos(y)
299      self.assertIn("CPU:10", result.device)
300      with ops.device("CPU:11"):
301        result = array_ops.identity(result)
302      self.assertIn("CPU:11", result.device)
303      return i + 1, result
304
305    @def_function.function
306    def _FunctionWithWhileLoop():
307      x = constant_op.constant(1.)
308      with ops.device("CPU:10"):
309        _, z = while_loop_v2(
310            lambda i, _: i < 2,
311            _LoopBody,
312            [0, x])
313      return z
314    # The test assertion runs at trace time.
315    _FunctionWithWhileLoop.get_concrete_function()
316
317  def testExternalControlDependencies(self):
318    with ops.Graph().as_default(), self.test_session():
319      v = variables.Variable(1.)
320      self.evaluate(v.initializer)
321      op = v.assign_add(1.)
322
323      def body_fn(i):  # pylint: disable=invalid-name
324        with ops.control_dependencies([op]):
325          return i + 1
326
327      loop = while_loop_v2(lambda i: i < 1, body_fn, [0])
328      loop[0].op.run()
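      # The loop body runs a single iteration, and its control dependency executes
      # the external assign_add, so v ends up at 2.0.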
329      self.assertAllEqual(self.evaluate(v), 2.0)
330
331  @test_util.run_deprecated_v1
332  def testMultipleLoopVarsBasic(self):
333    x = constant_op.constant(5.)
334    y = constant_op.constant(3.)
335
336    # x = 5.
337    # y = 3.
338    # while x < 45.:
339    #   x = x * y
340    ret = while_loop_v2(
341        lambda v, _: v < 45.,
342        lambda v, w: (v * w, w), [x, y],
343        return_same_structure=False)
344    # ret = [x*y^2, y]
345
346    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
347    grad = gradients_impl.gradients(ret, [x])  # [y**2]
348    with self.cached_session():
349      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
350      self.assertSequenceEqual(self.evaluate(grad), [9.])
351
352  @test_util.run_deprecated_v1
353  def testMultipleLoopNonscalarCond(self):
354    x = constant_op.constant([[5.]])
355    y = constant_op.constant(3.)
356
357    # x = 5.
358    # y = 3.
359    # while x < 45.:
360    #   x = x * y
361    ret = while_loop_v2(
362        lambda v, _: v < 45.,
363        lambda v, w: (v * w, w), [x, y],
364        return_same_structure=False)
365    # ret == [x*y^2, y]
366
367    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
368    grad = gradients_impl.gradients(ret, [x])  # [y**2]
369    with self.cached_session():
370      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
371      self.assertSequenceEqual(self.evaluate(grad), [9.])
372
373  @test_util.run_deprecated_v1
374  def testMultipleLoopVars(self):
375    x = constant_op.constant(5.)
376    y = constant_op.constant(3.)
377
378    # x = 5.
379    # y = 3.
380    # while x < 45.:
381    #   x = x * y
382    #   y = x + y
383    ret = while_loop_v2(
384        lambda v, _: v < 45.,
385        lambda v, w: (v * w, v + w), [x, y],
386        return_same_structure=False)
387    # ret = [y*x**2 + x*y**2, x*y + x + y]
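    # With x == 5 and y == 3 the loop runs twice: (5, 3) -> (15, 8) -> (120, 23).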
388
389    gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
390    gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
391    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + 2*y + 1]
392    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
393    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
394    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
395    with self.cached_session():
396      self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
397      self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
398      self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
399      self.assertSequenceEqual(self.evaluate(gradx_2), [43.])
400      self.assertSequenceEqual(self.evaluate(grady_0), [55.])
401      self.assertSequenceEqual(self.evaluate(grady_1), [6.])
402      self.assertSequenceEqual(self.evaluate(grady_2), [61.])
403
404  @test_util.run_deprecated_v1
405  def testGradientTape(self):
406    with backprop.GradientTape() as t:
407      x = constant_op.constant(2.)
408      t.watch(x)
409      ret = while_loop_v2(
410          lambda v: v < 4., lambda v: v * v, [x],
411          return_same_structure=False)  # x**2
412    grad = t.gradient(ret, x)
413    with self.cached_session() as sess:
414      self.assertAllEqual(sess.run(grad), 4.0)
415
416  @test_util.run_deprecated_v1
417  def testMultipleWhileLoops(self):
418    x = constant_op.constant(2.)
419    ret1 = while_loop_v2(
420        lambda v: v < 4., lambda v: v * v, [x],
421        return_same_structure=False)  # x**2
422    ret2 = while_loop_v2(
423        lambda v: v < 16., lambda v: v * v, [ret1],
424        return_same_structure=False)  # x**4
425    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
426    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
427    with self.cached_session():
428      self.assertSequenceEqual(self.evaluate(grad), [32.])
429      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
430
431  def testMultipleWhileLoopsWithFunc(self):
432    x = constant_op.constant(2.)
433
434    @def_function.function
435    def Fn():
436      ret1 = while_loop_v2(
437          lambda v: v < 4.,
438          lambda v: v * v, [x],
439          return_same_structure=False,
440          name="while_1")  # x**2
441      ret2 = while_loop_v2(
442          lambda v: v < 16.,
443          lambda v: v * v, [x],
444          return_same_structure=False,
445          name="while_2")  # x**4
446      return ret1, ret2
447
448    concrete_fn = Fn.get_concrete_function()
449    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
450    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
451    self.assertEqual(while_1.type, "StatelessWhile")
452    self.assertEqual(while_2.type, "StatelessWhile")
453    self.assertEmpty(while_1.control_inputs)
454    self.assertEmpty(while_2.control_inputs)
455
456  def testMultipleWhileLoopsGradStateless(self):
457
458    @def_function.function
459    def Fn():
460      x = constant_op.constant(2.)
461      with backprop.GradientTape() as tape:
462        tape.watch(x)
463        ret1 = while_loop_v2(
464            lambda v: v < 4.,
465            lambda v: v * v, [x],
466            return_same_structure=False,
467            name="while_1")  # x**2
468        ret2 = while_loop_v2(
469            lambda v: v < 16.,
470            lambda v: v * v, [x],
471            return_same_structure=False,
472            name="while_2")  # x**4
473        loss = ret1 + ret2
474      return tape.gradient(loss, x)
475
476    graph = Fn.get_concrete_function().graph
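    # The gradient of each while loop is computed with another while loop, so the
    # traced graph should contain two forward and two backward StatelessWhile ops.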
477    while_ops = [op for op in graph.get_operations() if "While" in op.type]
478    self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
479                        "Must have exactly 4 StatelessWhile ops.")
480    for op in while_ops:
481      self.assertEmpty(op.control_inputs,
482                       "{} should not have any control inputs".format(op.name))
483
484  def testMultipleWhileLoopsWithDeps(self):
485    x = variables.Variable(2.)
486    c = constant_op.constant(2.)
487
488    @def_function.function
489    def Fn():
490
491      def Body1(v):
492        x.assign(x)
493        return v * x
494
495      ret1 = while_loop_v2(
496          lambda v: v < 4.,
497          Body1, [c],
498          return_same_structure=False,
499          name="while_1")  # 2x
500
501      def Body2(v):
502        x.assign(x)
503        return v * x * x
504
505      ret2 = while_loop_v2(
506          lambda v: v < 16.,
507          Body2, [c],
508          return_same_structure=False,
509          name="while_2")  # 4x
510      return ret1, ret2
511
512    concrete_fn = Fn.get_concrete_function()
513    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
514    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
515    self.assertEqual(while_1.type, "While")
516    self.assertEqual(while_2.type, "While")
517    self.assertEmpty(while_1.control_inputs)
518    self.assertLen(while_2.control_inputs, 1)
519    self.assertIs(while_2.control_inputs[0], while_1)
520
521  def testMultipleWhileLoopsWithVarsDeps(self):
522    x1 = variables.Variable(2.)
523    x2 = variables.Variable(3.)
524    c = constant_op.constant(2.)
525
526    @def_function.function
527    def Fn():
528
529      def Body1(v):
530        x1.assign(x1)
531        return v * x1
532
533      ret1 = while_loop_v2(
534          lambda v: v < 4.,
535          Body1, [c],
536          return_same_structure=False,
537          name="while_1")  # 2x
538
539      def Body2(v):
540        x1.assign(x1)
541        return v * x1 * x1
542
543      ret2 = while_loop_v2(
544          lambda v: v < 16.,
545          Body2, [c],
546          return_same_structure=False,
547          name="while_2")  # 4x
548
549      def Body3(v):
550        x2.assign(x2)
551        return v * x2
552
553      ret3 = while_loop_v2(
554          lambda v: v < 4.,
555          Body3, [c],
556          return_same_structure=False,
557          name="while_3")  # 3x
558
559      def Body4(v):
560        x2.assign(x2)
561        return v * x2 * x2
562
563      ret4 = while_loop_v2(
564          lambda v: v < 16.,
565          Body4, [c],
566          return_same_structure=False,
567          name="while_4")  # 9x
568      ret5 = while_loop_v2(
569          lambda v: v < 16.,
570          lambda v: v * v, [c],
571          return_same_structure=False,
572          name="while_stateless")  # x**2
573      return ret1, ret2, ret3, ret4, ret5
574
575    concrete_fn = Fn.get_concrete_function()
576    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
577    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
578    while_3 = concrete_fn.graph.get_operation_by_name("while_3")
579    while_4 = concrete_fn.graph.get_operation_by_name("while_4")
580    while_stateless = concrete_fn.graph.get_operation_by_name(
581        "while_stateless")
582    self.assertEqual(while_1.type, "While")
583    self.assertEqual(while_2.type, "While")
584    self.assertEqual(while_3.type, "While")
585    self.assertEqual(while_4.type, "While")
586    self.assertEqual(while_stateless.type, "StatelessWhile")
587    self.assertEmpty(while_1.control_inputs)
588    self.assertLen(while_2.control_inputs, 1)
589    self.assertIs(while_2.control_inputs[0], while_1)
590    self.assertEmpty(while_3.control_inputs)
591    self.assertLen(while_4.control_inputs, 1)
592    self.assertIs(while_4.control_inputs[0], while_3)
593    self.assertEmpty(while_stateless.control_inputs)
594
595  @test_util.run_deprecated_v1
596  def testDoubleDerivative(self):
597    x = constant_op.constant(2.)
598    ret = while_loop_v2(
599        lambda v: v < 8., lambda v: v**2, [x],
600        return_same_structure=False)  # x**4
601    grad = gradients_impl.gradients(ret, [x])  # 4x**3
602    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
603    with self.cached_session():
604      self.assertEqual(self.evaluate(ret), 16.)
605      self.assertSequenceEqual(self.evaluate(grad), [32.])
606      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
607
608  @test_util.run_v2_only
609  def testMultipleWhileLoopsEager(self):
610
611    @def_function.function
612    def Func():
613      x = constant_op.constant(2.)
614      ret1 = while_loop_v2(
615          lambda v: v < 4., lambda v: v * v, [x],
616          return_same_structure=False)  # x**2
617      ret2 = while_loop_v2(
618          lambda v: v < 16.,
619          lambda v: v * v, [ret1],
620          return_same_structure=False)  # x**4
621      grad = gradients_impl.gradients(ret2, [x])[0]  # 4x**3
622      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
623      return grad, grad_grad
624
625    grad, grad_grad = Func()
626    self.assertEqual(grad.numpy(), 32.)
627    self.assertEqual(grad_grad.numpy(), 48.)
628
629  @test_util.run_v2_only
630  def testDoubleDerivativeEager(self):
631
632    @def_function.function
633    def Func():
634      x = constant_op.constant(2.)
635      ret = while_loop_v2(
636          lambda v: v < 8., lambda v: v**2, [x],
637          return_same_structure=False)  # x**4
638      grad = gradients_impl.gradients(ret, [x])[0]  # 4x**3
639      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
640      return ret, grad, grad_grad
641
642    ret, grad, grad_grad = Func()
643    self.assertEqual(ret.numpy(), 16.)
644    self.assertEqual(grad.numpy(), 32.)
645    self.assertEqual(grad_grad.numpy(), 48.)
646
647  def _testPruning(self):
648    x = constant_op.constant(1)
649
650    tensor_list = list_ops.empty_tensor_list(
651        element_dtype=x.dtype, element_shape=x.shape)
652
653    def Cond(x, tl):
654      del tl  # Unused for Cond.
655      return x < 5
656
657    def Body(x, tl):
658      return x + 1, list_ops.tensor_list_push_back(tl, x)
659
660    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])
661
662    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
663    train_op.append(outputs[0])
664
665    g = GetOptimizedGraph()
666    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
667    # away, causing an extra Enter node.
668    enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
669    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
670    # Test that the TensorList is pruned out.
671    self.assertEmpty([
672        n for n in g.node if n.op == "Enter" and
673        n.attr["T"].type == dtypes.variant.as_datatype_enum
674    ])
675    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
676
677    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
678    train_op.append(stack)
679    g = GetOptimizedGraph()
680    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
681    # away, causing an extra Enter node.
682    enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
683    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
684    # Test that the TensorList is not pruned out.
685    self.assertNotEmpty([
686        n for n in g.node if n.op == "Enter" and
687        n.attr["T"].type == dtypes.variant.as_datatype_enum
688    ])
689    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])
690
691  @test_util.run_deprecated_v1
692  def testPruningV1(self):
693    self._testPruning()
694
695  @test_util.enable_control_flow_v2
696  @test_util.run_deprecated_v1
697  def testPruningV2(self):
698    self._testPruning()
699
700  def _testDoNotAccumulateInvariants(self):
701    push_op = ("TensorListPushBack"
702               if control_flow_v2_toggles.control_flow_v2_enabled() else
703               "StackPushV2")
704
705    # Tests that loop invariants, i.e., tensors that are "captured" by the
706    # while loop and not passed as loop variables, are not accumulated in
707    # gradient computation.
708    v = constant_op.constant(5.0, name="v")
709
710    r = control_flow_ops.while_loop(
711        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
712
713    output = gradients_impl.gradients(r, v)[0]
714    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
715    train_op.append(output)
716
717    g = GetOptimizedGraph()
718    # The gradient for v * x requires the value of both v and x. Since v is a
719    # loop invariant, it is not accumulated, so we have just one accumulator
720    # for x.
721    self.assertLen([n for n in g.node if n.op == push_op], 1)
722
723  @test_util.run_deprecated_v1
724  def testDoNotAccumulateInvariantsV1(self):
725    self._testDoNotAccumulateInvariants()
726
727  @test_util.run_deprecated_v1
728  @test_util.enable_control_flow_v2
729  def testDoNotAccumulateInvariantsV2(self):
730    self._testDoNotAccumulateInvariants()
731
732  @test_util.enable_control_flow_v2
733  @test_util.run_deprecated_v1
734  @test_util.enable_output_all_intermediates
735  def testPruningNested(self):
736    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
737    x = constant_op.constant(0)
738
739    tensor_list = list_ops.empty_tensor_list(
740        element_dtype=x.dtype, element_shape=x.shape)
741
742    def Cond(x, tl):
743      del tl  # Unused for Cond.
744      return x < 25
745
746    def Body(x, tl):
747
748      def InnerCond(inner_x, unused_outer_x, unused_tl):
749        return inner_x < 5
750
751      def InnerBody(inner_x, outer_x, tl):
752        return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)
753
754      inner_x = constant_op.constant(0)
755      return control_flow_ops.while_loop(InnerCond, InnerBody,
756                                         [inner_x, x, tl])[1:]
757
758    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])
759
760    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
761    train_op.append(outputs[0])
762
763    g = GetOptimizedGraph()
764    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
765    # away, causing an extra Enter node.
766    # enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
767    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
768    # Test that the TensorList is pruned out.
769    self.assertEmpty([
770        n for n in g.node if n.op == "Enter" and
771        n.attr["T"].type == dtypes.variant.as_datatype_enum
772    ])
773    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
774    self.assertEmpty([n for n in g.node if n.op == "_While"])
775
776    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
777    train_op.append(stack)
778    g = GetOptimizedGraph()
779    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
780    # away, causing an extra Enter node.
781    # enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
782    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
783    # Test that the TensorList is not pruned out.
784    self.assertNotEmpty([
785        n for n in g.node if n.op == "Enter" and
786        n.attr["T"].type == dtypes.variant.as_datatype_enum
787    ])
788    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])
789
790  @test_util.enable_control_flow_v2
791  @test_util.run_deprecated_v1
792  @test_util.enable_output_all_intermediates
793  def testPruningNested2(self):
794    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
795    v = constant_op.constant(5.0, name="v")
796
797    p = array_ops.placeholder(dtype=dtypes.int32)
798
799    def MidBodyBuilder(iterations):
800
801      def MidBody(i, x):
802        r = control_flow_ops.while_loop(
803            lambda *_: True,
804            lambda i, x: (i + 1, math_ops.multiply(v, x, name="my_mul")),
805            (0, x),
806            maximum_iterations=iterations,
807            name="inner")
808        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
809
810      return MidBody
811
812    def OuterBody(i, x):
813      iterations = array_ops.size(p, name="iterations")
814      return (i + 1, x + control_flow_ops.while_loop(
815          lambda *_: True,
816          MidBodyBuilder(iterations), (0, x),
817          maximum_iterations=iterations,
818          name="mid")[1])
819
820    def CreateWhileLoop():
821      with ops.device("/cpu:0"):
822        r = control_flow_ops.while_loop(
823            lambda *_: True,
824            OuterBody, (0, 1.0),
825            maximum_iterations=5,
826            name="outer")
827        return array_ops.identity(r[1])
828
829    output = CreateWhileLoop()
830    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
831    train_op.append(output)
832
833    g = GetOptimizedGraph()
834    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)
835
836  @test_util.enable_control_flow_v2
837  @test_util.run_deprecated_v1
838  @test_util.enable_output_all_intermediates
839  def testPruningNested3(self):
840    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
841    v = constant_op.constant(5.0, name="v")
842
843    def CreateWhileLoop():
844      r = control_flow_ops.while_loop(
845          lambda _: True,
846          lambda x: math_ops.multiply(v, x, name="my_mul"), [1.0],
847          maximum_iterations=5,
848          name="outer")
849      return array_ops.identity(r)
850
851    r = CreateWhileLoop()
852    output = gradients_impl.gradients(r, v)[0]
853    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
854    train_op.append(output)
855
856    g = GetOptimizedGraph()
857    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)
858
859  def _assertNotAccumulated(self, while_op, index):
860    """Asserts that `while_op` input at `index` is not accumulated."""
861    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
862    placeholder = body_graph.inputs[index]
863    self.assertNotIn("TensorListPushBack",
864                     [op.type for op in placeholder.consumers()])
865
866  @test_util.enable_control_flow_v2
867  @test_util.run_deprecated_v1
868  @test_util.enable_output_all_intermediates
869  def testDoNotOutputLoopCounterAsIntermediate(self):
870    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
871    v = constant_op.constant(5.0, name="v")
872    r = control_flow_ops.while_loop(
873        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
874    # Skip over Identity.
875    while_op = r.op.inputs[0].op
876    self._assertNotAccumulated(while_op, 0)
877
878  @test_util.enable_control_flow_v2
879  @test_util.run_deprecated_v1
880  @test_util.enable_output_all_intermediates
881  def testDoNotOutputLoopInvariantAsIntermediate(self):
882    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
883
884    def GetInputIndex(op, tensor):
885      for index, inp in enumerate(op.inputs):
886        if inp is tensor:
887          return index
888
889    v = constant_op.constant(5.0, name="v")
890    r = control_flow_ops.while_loop(
891        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
892    # Skip over Identity.
893    while_op = r.op.inputs[0].op
894    # We can't directly use while_op.inputs.index() because Tensors are not
895    # hashable.
896    index = GetInputIndex(while_op, v)
897    self._assertNotAccumulated(while_op, index)
898
899  @test_util.run_deprecated_v1
900  def testCaptureExternalTensorInCond(self):
901    x = constant_op.constant(2.)
902    y = constant_op.constant(1.)
903    ret = while_loop_v2(
904        lambda v: v + y < 9.,
905        lambda v: v * 3., [x],
906        return_same_structure=False)
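    # The captured y only affects the cond: v goes 2 -> 6 -> 18, so ret == 9 * x
    # and d(ret)/dx == 9.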
907    grad = gradients_impl.gradients(ret, [x])
908    with self.cached_session():
909      self.assertEqual(self.evaluate(ret), 18.)
910      self.assertSequenceEqual(self.evaluate(grad), [9.])
911
912  @test_util.run_deprecated_v1
913  def testCaptureExternalTensorInBody(self):
914    x = constant_op.constant(2.)
915    y = constant_op.constant(3.)
916    ret = while_loop_v2(
917        lambda v: v < 8., lambda v: v * y, [x], return_same_structure=False)
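    # With the captured y == 3 in the body, v goes 2 -> 6 -> 18, so ret == x * y**2
    # and d(ret)/dx == y**2 == 9.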
918    grad = gradients_impl.gradients(ret, [x])
919    with self.cached_session():
920      self.assertEqual(self.evaluate(ret), 18.)
921      self.assertSequenceEqual(self.evaluate(grad), [9.])
922
923  @test_util.run_deprecated_v1
924  def testLoopWithTensorListPushBack(self):
925    x = constant_op.constant(2.)
926
927    tensor_list = list_ops.empty_tensor_list(
928        element_dtype=dtypes.float32, element_shape=ScalarShape())
929
930    def Cond(x, tl):
931      del tl  # Unused for Cond.
932      return x < 5.
933
934    def Body(x, tl):
935      tl = list_ops.tensor_list_push_back(tl, x)
936      tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
937      return x**2., tl
938
939    ret = while_loop_v2(
940        Cond, Body, [x, tensor_list], return_same_structure=False)
941    grad = gradients_impl.gradients(ret[0], x)
942    with self.cached_session() as sess:
943      self.assertEqual(sess.run(ret[0]), 16.)
944      self.assertSequenceEqual(self.evaluate(grad), [32.])
945
946  @test_util.run_deprecated_v1
947  def testDuplicateAccumulator(self):
948    x = constant_op.constant(2.)
949
950    tensor_list = list_ops.empty_tensor_list(
951        element_dtype=dtypes.float32, element_shape=ScalarShape())
952
953    def Cond(x, tl):
954      del tl  # Unused for Cond.
955      return x < 5.
956
957    def Body(x, tl):
958      # There is an accumulator in the loop already so we should not add
959      # another.
960      tl = list_ops.tensor_list_push_back(tl, x)
961      return x**2., tl
962
963    ret = while_loop_v2(
964        Cond, Body, [x, tensor_list], return_same_structure=False)
965
966    for op in ops.get_default_graph().get_operations():
967      if op.type == "While" or op.type == "StatelessWhile":
968        while_op = op
969
970    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
971    x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
972    x_input_t = body_graph.inputs[x_input_index]
973    accumulator_count = len(
974        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
975    self.assertEqual(accumulator_count, 1)
976
977    grad = gradients_impl.gradients(ret[0], x)
978    with self.cached_session() as sess:
979      self.assertEqual(sess.run(ret[0]), 16.)
980      self.assertSequenceEqual(self.evaluate(grad), [32.])
981
982  @parameterized.named_parameters(
983      ("UnknownShape", None),
984      ("PartiallyDefinedShape", [None, 2]),
985      ("FullyDefinedShape", [1, 2]),
986  )
987  @test_util.run_deprecated_v1
988  def testAccumulatorElementShape(self, shape):
989
990    def MatchShape(actual_tensor_shape):
991      # Compare the shapes, treating None dimensions as equal. We do not
992      # directly check actual_tensor_shape and tf.TensorShape(shape) for
993      # equality because tf.Dimension.__eq__ returns None if either dimension is
994      # None.
995      if shape is None:
996        self.assertIsNone(actual_tensor_shape.dims)
997      else:
998        self.assertListEqual(actual_tensor_shape.as_list(), shape)
999
1000    def GetAccumulatorForInputAtIndex(while_op, idx):
1001      body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
1002      y_input_t = body_graph.inputs[idx]
1003      push_back_node = [c for c in y_input_t.consumers()
1004                        if c.type == "TensorListPushBack"][0]
1005      output_idx = body_graph.outputs.index(push_back_node.outputs[0])
1006      return while_op.outputs[output_idx]
1007
1008    x = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
1009    y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
1010
1011    # Forward pass.
1012    ret = while_loop_v2(lambda v, u: v < 8.,
1013                        lambda v, u: (math_ops.pow(v, u), u),
1014                        [x, y],
1015                        return_same_structure=True)
1016    while_op = ret[0].op.inputs[0].op
1017    # Gradient pass.
1018    grad = gradients_impl.gradients(ret[0], x)
1019    # Note: There is an Identity between grad[0] and the While op.
1020    grad_while_op = grad[0].op.inputs[0].op
1021
1022    # Get the TensorList output of While op containing the accumulated values
1023    # of y.
1024    x_input_index = [i for i, inp in enumerate(while_op.inputs) if x == inp][0]
1025    output = GetAccumulatorForInputAtIndex(while_op, x_input_index)
1026    _, val = list_ops.tensor_list_pop_back(output,
1027                                           element_dtype=dtypes.float32)
1028    MatchShape(val.shape)
1029
1030    # Take second derivative to generate intermediate grad_while_op outputs
1031    gradients_impl.gradients(grad, x)
1032
1033    # Get the TensorList output of gradient While op containing the accumulated
1034    # values of grad_x (note that grad_x is needed by the second derivative).
1035    # grad_while_op.inputs:
1036    grad_output_index = grad_while_op.outputs.index(grad[0].op.inputs[0])
1037    grad_output = GetAccumulatorForInputAtIndex(grad_while_op,
1038                                                grad_output_index)
1039    _, val = list_ops.tensor_list_pop_back(grad_output,
1040                                           element_dtype=dtypes.float32)
1041    MatchShape(val.shape)
1042
1043  def _createWhile(self, name):
1044    """Helper function testDefaultName."""
1045    output = while_v2.while_loop(
1046        lambda i: i < 3,
1047        lambda i: i + 1, [constant_op.constant(0)],
1048        return_same_structure=False)
1049    while_op = output.op.inputs[0].op
1050    self.assertEqual(while_op.type, "StatelessWhile")
1051    return while_op
1052
1053  def testDefaultName(self):
1054    with ops.Graph().as_default():
1055      while_op = self._createWhile(None)
1056      self.assertEqual(while_op.name, "while")
1057      self.assertRegex(while_op.get_attr("cond").name, r"while_cond_\d*")
1058      self.assertRegex(while_op.get_attr("body").name, r"while_body_\d*")
1059
1060    with ops.Graph().as_default():
1061      with ops.name_scope("foo"):
1062        while1_op = self._createWhile("")
1063        self.assertEqual(while1_op.name, "foo/while")
1064        self.assertRegex(while1_op.get_attr("cond").name, r"foo_while_cond_\d*")
1065        self.assertRegex(while1_op.get_attr("body").name, r"foo_while_body_\d*")
1066
1067        while2_op = self._createWhile(None)
1068        self.assertEqual(while2_op.name, "foo/while_1")
1069        self.assertRegex(
1070            while2_op.get_attr("cond").name, r"foo_while_1_cond_\d*")
1071        self.assertRegex(
1072            while2_op.get_attr("body").name, r"foo_while_1_body_\d*")
1073
1074  @test_util.enable_control_flow_v2
1075  @test_util.run_deprecated_v1
1076  def testWhileAndTensorArray(self):
1077    param = constant_op.constant(2.0)
1078    y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
1079    # map_fn uses TensorArray internally.
1080    r = map_fn.map_fn(lambda x: math_ops.multiply(x, param), y0)
1081    grad = gradients_impl.gradients(r, param)[0]
1082    self.assertAllClose([2.0, 4.0, 6.0, 8.0, 10.0, 12.0], self.evaluate(r))
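    # d(r)/d(param) aggregates over all elements of y0: 1 + 2 + 3 + 4 + 5 + 6 == 21.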
1083    self.assertAllClose(21.0, self.evaluate(grad))
1084
1085  @test_util.run_deprecated_v1
1086  def testNestedWhile(self):
1087    # Compute sum of geometric progression: n^0 + n^1 + ... + n^m
1088    # We compute the pow using a while loop.
1089    n = constant_op.constant(3.)
1090    m = constant_op.constant(5.)
1091    sum_of_powers = constant_op.constant(0.)
1092
1093    def Body(i, previous_sum):
1094      prod = constant_op.constant(1.)
1095      return i - 1., previous_sum + while_loop_v2(
1096          lambda c, _: c > 0,
1097          lambda c, v: (c - 1., v * n), [i, prod],
1098          return_same_structure=False)[1]
1099
1100    result = while_loop_v2(
1101        lambda i, _: i >= 0,
1102        Body, [m, sum_of_powers],
1103        return_same_structure=False)[1]
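    # With n == 3 and m == 5: result == 3**0 + 3**1 + ... + 3**5 == 364, and
    # d(result)/dn == 1 + 2*3 + 3*9 + 4*27 + 5*81 == 547.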
1104    grad = gradients_impl.gradients(result, [n])
1105    self.assertEqual(self.evaluate(result), 364.)
1106    self.assertSequenceEqual(self.evaluate(grad), [547.])
1107
1108  @test_util.run_deprecated_v1
1109  def testNestedWhileWithLegacyDefun(self):
1110    n = constant_op.constant(3.)
1111    m = constant_op.constant(5.)
1112    sum_of_powers = constant_op.constant(0.)
1113
1114    def Body(i, previous_sum):
1115      prod = constant_op.constant(1.)
1116
1117      def InnerBodyWrapper(c, v):
1118
1119        @function.Defun(dtypes.float32, dtypes.float32)
1120        def InnerBody(c, v):
1121          return c - 1., v * n
1122
1123        results = InnerBody(c, v)
1124        results[0].set_shape([])
1125        results[1].set_shape([])
1126        return results
1127
1128      return i - 1., previous_sum + while_loop_v2(
1129          lambda c, _: c > 0,
1130          InnerBodyWrapper, [i, prod],
1131          return_same_structure=False)[1]
1132
1133    result = while_loop_v2(
1134        lambda i, _: i >= 0,
1135        Body, [m, sum_of_powers],
1136        return_same_structure=False)[1]
1137    grad = gradients_impl.gradients(result, [n])
1138    self.assertEqual(self.evaluate(result), 364.)
1139    self.assertSequenceEqual(self.evaluate(grad), [547.])
1140
1141  @test_util.run_deprecated_v1
1142  def testIdentityNodeInBody(self):
1143
1144    def Body(v):
1145      v = array_ops.identity(v)
1146      v = array_ops.identity(v)
1147      return v * v
1148
1149    x = constant_op.constant(2.)
1150    ret = while_loop_v2(
1151        lambda v: v < 8., Body, [x], return_same_structure=False)
1152    grad = gradients_impl.gradients(ret, [x])
1153    self.assertEqual(self.evaluate(ret), 16.)
1154    self.assertSequenceEqual(self.evaluate(grad), [32.])
1155
1156  @test_util.run_deprecated_v1
1157  def testForwardPassRewrite(self):
1158    x = constant_op.constant(1.0, name="x")
1159    output = while_v2.while_loop(lambda x: x < 10.0,
1160                                 lambda x: x * 2.0,
1161                                 [x])[0]
1162    while_op = output.op.inputs[0].op
1163    self.assertEqual(while_op.type, "StatelessWhile")
1164    # outputs = [loop_counter, max_iters, x]
1165    self.assertLen(while_op.outputs, 3)
1166
1167    gradients_impl.gradients(output, x)
1168    # while_op should have been rewritten to output intermediates.
1169    # outputs = [loop_counter, max_iters, x, x_accumulator]
1170    self.assertLen(while_op.outputs, 4)
1171
1172    gradients_impl.gradients(output, x)
1173    # Computing the gradient again shouldn't rewrite while_op again.
1174    self.assertLen(while_op.outputs, 4)
1175
1176  @parameterized.named_parameters(
1177      ("RandomUniform", random_ops.random_uniform, [5, 3]),
1178      ("RandomNormal", random_ops.random_normal, [5, 3]),
1179      ("ParameterizedTruncatedNormal",
1180       random_ops.parameterized_truncated_normal, [5, 3]),
1181      ("TruncatedNormal", random_ops.truncated_normal, [5, 3]),
1182      ("RandomGamma", random_gamma, [5, 3]),
1183      ("RandomPoissonV2", random_poisson_v2, [5, 3]),
1184      ("RandomGammaWithAlphaBeta", random_gamma_with_alpha_beta, [5, 3, 4, 2]),
1185      ("RandomPoissonV2WithLam", random_poisson_v2_with_lam, [5, 3, 2]),
1186  )
1187  @test_util.run_deprecated_v1
1188  def testRandomOpsShape(self, random_fn, expected_shape):
1189    shape = constant_op.constant([3])
1190
1191    def Body(i, u):
1192      shape_extended = array_ops.concat([[5], shape], axis=0)
1193      u = random_fn(shape_extended)
1194      assert u.shape.as_list() == expected_shape, str(u.shape.as_list())
1195      return i + 1, u
1196
1197    _, _ = while_loop_v2(
1198        cond=lambda i, _: i < 3,
1199        body=Body,
1200        loop_vars=[
1201            0,
1202            array_ops.zeros(expected_shape, dtype=dtypes.float32),
1203        ])
1204
1205  @test_util.run_deprecated_v1
1206  def testReshapeShape(self):
1207    shape = constant_op.constant([3, 4])
1208
1209    def Body(i, u):
1210      shape_extended = array_ops.concat([[5], shape], axis=0)
1211      u = array_ops.reshape(u, [-1])
1212      assert u.shape.as_list() == [60], str(u.shape.as_list())
1213      u = array_ops.reshape(u, shape_extended)
1214      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
1215      return i + 1, u
1216
1217    _, _ = while_loop_v2(
1218        cond=lambda i, _: i < 3,
1219        body=Body,
1220        loop_vars=[
1221            0,
1222            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
1223        ])
1224
1225  @parameterized.named_parameters(
1226      ("Zeros", array_ops.zeros),
1227      ("Ones", array_ops.ones),
1228      ("Fill", fill),
1229  )
1230  @test_util.run_deprecated_v1
1231  def testFillOpsShape(self, fill_fn):
1232    shape = constant_op.constant([3, 4])
1233
1234    def Body(i, u):
1235      shape_extended = array_ops.concat([[5], shape], axis=0)
1236      u = fill_fn(shape_extended)
1237      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
1238      return i + 1, u
1239
1240    _, _ = while_loop_v2(
1241        cond=lambda i, _: i < 3,
1242        body=Body,
1243        loop_vars=[
1244            0,
1245            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
1246        ])
1247
1248  @test_util.run_deprecated_v1
1249  def testExternalColocationGrad(self):
1250    external_t = constant_op.constant(2.)
1251    v0 = constant_op.constant(2.)
1252
1253    def Body(v):
1254      with ops.colocate_with(external_t):
1255        return v * v
1256
1257    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
1258    grad = gradients_impl.gradients(ret, [v0])[0]
1259    self.assertAllEqual(ret, 16.)
1260    self.assertAllEqual(grad, 32.)
1261
1262  @test_util.run_deprecated_v1
1263  def testDoNotAccumulateConstNodes(self):
1264
1265    def Body(v):
1266      return v * 2.0
1267
1268    v0 = constant_op.constant(2.)
1269    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
1270    # Gradient computation has the side effect of updating the forward op,
1271    # which is what we want to test.
1272    unused_grad = gradients_impl.gradients(ret, [v0])[0]
1273    # ret is separated from the `While` op by an `Identity` so we skip over
1274    # that.
1275    forward_while_op = ret.op.inputs[0].op
1276    body_graph = while_v2._get_graph(forward_while_op, "body", "_body_graph")
1277    push_back_nodes = [
1278        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
1279    ]
1280    # Gradient of `Mul` requires accumulating both its inputs. But since one
1281    # of those is a Const (2.0), we should have just one accumulator.
1282    self.assertLen(push_back_nodes, 1)
1283
1284  def testDoNotAccumulateForwardTensorsForReductionOps(self):
1285
1286    @def_function.function
1287    def Fn():
1288      with backprop.GradientTape() as tape:
1289        x = constant_op.constant(2.)
1290        tape.watch(x)
1291
1292        def Body(i, x):
1293          forward_graph = ops.get_default_graph()
1294
1295          @custom_gradient.custom_gradient
1296          def SquaredWithZeroGrad(x):
1297
1298            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
1299              del variables
1300              gradient_graph = ops.get_default_graph()
1301              shape = gen_array_ops.shape(x)
1302              assert shape.graph is forward_graph
1303              rank = gen_array_ops.rank(x)
1304              assert rank.graph is forward_graph
1305              size = gen_array_ops.size(x)
1306              assert size.graph is forward_graph
1307              zeros = array_ops.zeros(shape)
1308              assert zeros.graph is gradient_graph
1309              return zeros
1310
1311            return x * 2, Grad
1312
1313          return i + 1, SquaredWithZeroGrad(x)
1314
1315        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
1316      grad = tape.gradient(result, x)
1317      return grad
1318
1319    Fn()
1320
1321  def testDoNotAccumulateForwardTensorsForTensorListReductionOps(self):
1322
1323    @def_function.function
1324    def Fn():
1325      with backprop.GradientTape() as tape:
1326        e = constant_op.constant(2.)
1327        x = list_ops.empty_tensor_list(
1328            element_dtype=dtypes.float32, element_shape=e.shape)
1329        x = list_ops.tensor_list_push_back(x, e)
1330        tape.watch(x)
1331
1332        def Body(i, x):
1333          forward_graph = ops.get_default_graph()
1334
1335          @custom_gradient.custom_gradient
1336          def IdentityWithZeroGrad(x):
1337
1338            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
1339              del variables
1340              gradient_graph = ops.get_default_graph()
1341              shape = gen_list_ops.tensor_list_element_shape(
1342                  x, shape_type=dtypes.int32)
1343              assert shape.graph is forward_graph
1344              size = gen_list_ops.tensor_list_length(x)
1345              assert size.graph is forward_graph
1346              zeros = gen_list_ops.tensor_list_reserve(shape, size,
1347                                                       dtypes.float32)
1348              assert zeros.graph is gradient_graph
1349              return zeros
1350
1351            return x, Grad
1352
1353          return i + 1, IdentityWithZeroGrad(x)
1354
1355        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
1356      ones_like = list_ops.tensor_list_from_tensor(
1357          array_ops.ones_like(
1358              list_ops.tensor_list_stack(result, element_dtype=dtypes.float32)),
1359          element_shape=tensor_shape.TensorShape([]))
1360      grad = tape.gradient(result, x, output_gradients=[ones_like])
1361      return grad
1362
1363    Fn()
1364
1365  @test_util.run_v2_only
1366  def testInheritParentNameScope(self):
1367
1368    @def_function.function
1369    def F():
1370      with ops.name_scope("foo"):
1371
1372        def Cond(unused_i):
1373          with ops.name_scope("cond"):
1374            actual_name_scope = ops.get_name_scope()
1375            expected_name_scope = "foo/while/cond"
1376            assert actual_name_scope == expected_name_scope, (
1377                "%s does not match %s" %
1378                (actual_name_scope, expected_name_scope))
1379          return False
1380
1381        def Body(i):
1382          with ops.name_scope("body"):
1383            actual_name_scope = ops.get_name_scope()
1384            expected_name_scope = "foo/while/body"
1385            assert actual_name_scope == expected_name_scope, (
1386                "%s does not match %s" %
1387                (actual_name_scope, expected_name_scope))
1388          return i
1389
1390        return while_v2.while_loop(Cond, Body, [0.])
1391
1392    F()
1393
1394  @test_util.run_deprecated_v1  # Need to pass RunMetadata.
1395  def testDisableLowering(self):
1396    old = control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE
1397    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = True
1398    with self.session() as sess:
1399      x = constant_op.constant(2.)
1400      ret = while_loop_v2(
1401          lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
1402
1403      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
1404      run_metadata = config_pb2.RunMetadata()
1405      self.assertEqual(sess.run(ret, options=opts, run_metadata=run_metadata),
1406                       16)
1407      for dev_stat in run_metadata.step_stats.dev_stats:
1408        for ns in dev_stat.node_stats:
1409          self.assertNotIn("switch", ns.node_name)
1410    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = old
1411
1412  def _runBasicWithConfig(self, config):
1413    with ops.device("/cpu:0"):
1414      x = constant_op.constant(0)
1415      ret, = while_loop_v2(lambda x: x < 1000, lambda x: x + 1, [x])
1416    with self.cached_session(config=config):
1417      self.assertEqual(1000, self.evaluate(ret))
1418
1419  @test_util.run_deprecated_v1
1420  def testRunKernelsInline(self):
1421    config = config_pb2.ConfigProto()
1422    config.inter_op_parallelism_threads = -1
1423    self._runBasicWithConfig(config)
1424
1425  @test_util.run_deprecated_v1
1426  def testSingleThreadedExecution(self):
1427    config = config_pb2.ConfigProto()
1428    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
1429    self._runBasicWithConfig(config)
1430
1431  def testIsControlFlowGraph(self):
1432    x = constant_op.constant(0)
1433
1434    @def_function.function
1435    def F(c):
1436
1437      def Cond(i):
1438        self.assertTrue(i.graph.is_control_flow_graph)
1439        return i < 2
1440
1441      def Body(i):
1442        i = i + 1
1443        self.assertTrue(i.graph.is_control_flow_graph)
1444        return i
1445
1446      return while_loop_v2(Cond, Body, [c])
1447
1448    ret, = F(x)
1449    self.assertEqual(2, self.evaluate(ret))
1450
1451  def testImportFromSerializedWithFunctionInBody(self):
    serialized = """node {
      name: "Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape {
            }
            float_val: 1.0
          }
        }
      }
    }
    node {
      name: "while/maximum_iterations"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: -1
          }
        }
      }
    }
    node {
      name: "while/loop_counter"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node {
      name: "while"
      op: "StatelessWhile"
      input: "while/loop_counter"
      input: "while/maximum_iterations"
      input: "Const"
      attr {
        key: "T"
        value {
          list {
            type: DT_INT32
            type: DT_INT32
            type: DT_FLOAT
          }
        }
      }
      attr {
        key: "_lower_using_switch_merge"
        value {
          b: true
        }
      }
      attr {
        key: "_num_original_outputs"
        value {
          i: 3
        }
      }
      attr {
        key: "_read_only_resource_inputs"
        value {
          list {
          }
        }
      }
      attr {
        key: "body"
        value {
          func {
            name: "while_body_822"
          }
        }
      }
      attr {
        key: "cond"
        value {
          func {
            name: "while_cond_821"
          }
        }
      }
      attr {
        key: "output_shapes"
        value {
          list {
            shape {
            }
            shape {
            }
            shape {
            }
          }
        }
      }
      attr {
        key: "parallel_iterations"
        value {
          i: 10
        }
      }
    }
    node {
      name: "while/Identity"
      op: "Identity"
      input: "while"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_1"
      op: "Identity"
      input: "while:1"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_2"
      op: "Identity"
      input: "while:2"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    library {
      function {
        signature {
          name: "while_body_822"
          input_arg {
            name: "while_loop_counter"
            type: DT_INT32
          }
          input_arg {
            name: "while_maximum_iterations_0"
            type: DT_INT32
          }
          input_arg {
            name: "placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "add"
            type: DT_INT32
          }
          output_arg {
            name: "while_maximum_iterations"
            type: DT_INT32
          }
          output_arg {
            name: "partitionedcall"
            type: DT_FLOAT
          }
        }
        node_def {
          name: "PartitionedCall"
          op: "PartitionedCall"
          input: "placeholder"
          attr {
            key: "Tin"
            value {
              list {
                type: DT_FLOAT
              }
            }
          }
          attr {
            key: "Tout"
            value {
              list {
                type: DT_FLOAT
              }
            }
          }
          attr {
            key: "_collective_manager_ids"
            value {
              list {
              }
            }
          }
          attr {
            key: "_read_only_resource_inputs"
            value {
              list {
              }
            }
          }
          attr {
            key: "config"
            value {
              s: ""
            }
          }
          attr {
            key: "config_proto"
            value {
              s: ""
            }
          }
          attr {
            key: "executor_type"
            value {
              s: ""
            }
          }
          attr {
            key: "f"
            value {
              func {
                name: "__inference_f_841"
              }
            }
          }
          experimental_debug_info {
            original_node_names: "PartitionedCall"
          }
        }
        node_def {
          name: "add/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_INT32
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_INT32
                tensor_shape {
                }
                int_val: 1
              }
            }
          }
          experimental_debug_info {
            original_node_names: "add/y"
          }
        }
        node_def {
          name: "add_0"
          op: "AddV2"
          input: "while_loop_counter"
          input: "add/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_INT32
            }
          }
          experimental_debug_info {
            original_node_names: "add"
          }
        }
        ret {
          key: "add"
          value: "add_0:z:0"
        }
        ret {
          key: "partitionedcall"
          value: "PartitionedCall:output:0"
        }
        ret {
          key: "while_maximum_iterations"
          value: "while_maximum_iterations_0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 1
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 2
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
      function {
        signature {
          name: "while_cond_821"
          input_arg {
            name: "while_loop_counter"
            type: DT_INT32
          }
          input_arg {
            name: "while_maximum_iterations"
            type: DT_INT32
          }
          input_arg {
            name: "placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "less"
            type: DT_BOOL
          }
        }
        node_def {
          name: "Less/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_FLOAT
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_FLOAT
                tensor_shape {
                }
                float_val: 5.0
              }
            }
          }
          experimental_debug_info {
            original_node_names: "Less/y"
          }
        }
        node_def {
          name: "Less"
          op: "Less"
          input: "placeholder"
          input: "Less/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "Less"
          }
        }
        ret {
          key: "less"
          value: "Less:z:0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 1
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 2
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
      function {
        signature {
          name: "__inference_f_841"
          input_arg {
            name: "mul_placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "identity"
            type: DT_FLOAT
          }
        }
        node_def {
          name: "mul/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_FLOAT
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_FLOAT
                tensor_shape {
                }
                float_val: 2.0
              }
            }
          }
          experimental_debug_info {
            original_node_names: "mul/y"
          }
        }
        node_def {
          name: "mul"
          op: "Mul"
          input: "mul_placeholder"
          input: "mul/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "mul"
          }
        }
        node_def {
          name: "Identity"
          op: "Identity"
          input: "mul:z:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "Identity"
          }
        }
        ret {
          key: "identity"
          value: "Identity:output:0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
    }
    versions {
      producer: 399
      min_consumer: 12
    }
    """
    # Code for generating above graph:
    #
    # def Body(i):
    #   @tf.function
    #   def f():
    #     return i * 2
    #   return f()
    # tf.while_loop(lambda i: i < 5., Body, [tf.constant(1.)])
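    # The loop runs three times (1. -> 2. -> 4. -> 8.), doubling its value on
    # each iteration, so d(while:2)/d(Const:0) = 2**3 = 8.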
    graph_def = graph_pb2.GraphDef()
    text_format.Parse(serialized, graph_def)
    @def_function.function
    def F():
      x, y = importer.import_graph_def(
          graph_def, return_elements=["Const:0", "while:2"])
      grad_out, = gradients_impl.gradients(y, x)
      return grad_out
    self.assertAllEqual(F(), 8.0)

  def testIndexedSlicesInIncomingGrads(self):
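    """Tests gradients when the incoming gradient is an IndexedSlices."""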
    @def_function.function
    def F():
      x = constant_op.constant([2.])
      # Computes x^4
      ret = while_loop_v2(
          lambda _: True, lambda v: v * v, [x], return_same_structure=False,
          maximum_iterations=2)
      v = array_ops.gather(ret, [0])
      return gradients_impl.gradients(v, [x])[0]  # 4*x^3
    self.assertAllEqual(self.evaluate(F()), [32.])


def ScalarShape():
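  """Returns an empty shape tensor, i.e. the shape of a scalar."""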
  return ops.convert_to_tensor([], dtype=dtypes.int32)


def GetOptimizedGraph():
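  """Returns the Grappler-optimized GraphDef of the default graph."""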
  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.CopyFrom(
      rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
  return tf_optimizer.OptimizeGraph(config, mg)


if __name__ == "__main__":
  test.main()