# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pfor and for_loop."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import sys
import time

from absl.testing import parameterized
import numpy as np

from google.protobuf import text_format

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest


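# pfor(loop_fn, iters) evaluates loop_fn(i) for i in [0, iters) and stacks the
# per-iteration outputs along a new leading dimension; for example,
# pfor_control_flow_ops.pfor(lambda i: i * i, 4) evaluates to [0, 1, 4, 9].
# The _test_loop_fn helper used throughout runs loop_fn both through the
# vectorized pfor implementation and through a reference per-iteration
# implementation, and checks that the two produce matching outputs.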
@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class PForTest(PForTestCase):

  def test_op_conversion_fallback_to_while_loop(self):
    # Note that we use the top_k op for this test. If a converter gets defined
    # for it, we will need to find another op for which a converter has not
    # been defined.
    x = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return nn.top_k(x_i)

    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
    self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)

  def test_parallel_iterations(self):
    for parallel_iterations in [2, 3, 8, 10]:
      x = random_ops.random_uniform([8, 3])

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        return array_ops.gather(x, i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 8, parallel_iterations=parallel_iterations)
      self._test_loop_fn(
          loop_fn,
          4 * constant_op.constant(2),
          parallel_iterations=parallel_iterations)

  def test_parallel_iterations_preserves_static_shape(self):
    for parallel_iterations in [2, 3, 8, 10]:
      x = pfor_control_flow_ops.pfor(
          lambda _: random_ops.random_uniform([2, 3]),
          8,
          parallel_iterations=parallel_iterations)
      self.assertAllEqual(x.shape, [8, 2, 3])

  def test_parallel_iterations_zero(self):
    with self.assertRaisesRegex(ValueError, "positive integer"):
      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=0)
    with self.assertRaisesRegex(TypeError, "positive integer"):
      pfor_control_flow_ops.for_loop(
          lambda i: 1, dtypes.int32, 8, parallel_iterations=0)

  def test_parallel_iterations_one(self):
    with self.assertRaisesRegex(ValueError, "Use for_loop instead"):
      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=1)

  def test_vectorized_map(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    result = pfor_control_flow_ops.vectorized_map(compute,
                                                  array_ops.ones((10, 5, 3)))
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_vectorized_map_with_dynamic_shape(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    x = array_ops.placeholder_with_default(
        array_ops.ones((10, 5, 3)), shape=None)
    result = pfor_control_flow_ops.vectorized_map(compute, x)
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_where_shape(self):
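    # Checks that the static shape of the vectorized_map result is preserved
    # when where() broadcasts the captured operand b against each row of a.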
    @def_function.function
    def f():
      a = constant_op.constant([[1.], [1.]])
      b = constant_op.constant([1.])
      result = pfor_control_flow_ops.vectorized_map(
          lambda x: array_ops.where(x > 0, x, b), a)
      return result.shape

    self.assertAllEqual([2, 1], f())

  def test_vectorized_map_broadcasts_unit_dimensions(self):
    convert_with_static_shape = ops.convert_to_tensor
    convert_with_dynamic_shape = (
        lambda x: array_ops.placeholder_with_default(x, shape=None))

    for convert in (convert_with_static_shape, convert_with_dynamic_shape):
      a = convert([3.1])
      b = convert([-2., 6., 9.])

      # One elem with leading unit dimension.
      a_plus_1 = pfor_control_flow_ops.vectorized_map(lambda a: a + 1, a)
      self.assertAllEqual(*self.evaluate((a_plus_1, a + 1)))

      # Two elems, both with leading unit dimension.
      a_plus_a = pfor_control_flow_ops.vectorized_map(sum, (a, a))
      self.assertAllEqual(*self.evaluate((a_plus_a, a + a)))

      # Elem w/ unit dimension broadcast against elem with batch dim.
      a_plus_b = pfor_control_flow_ops.vectorized_map(sum, (a, b))
      self.assertAllEqual(*self.evaluate((a_plus_b, a + b)))

  def test_vectorized_map_example_1(self):

    def outer_product(a):
      return math_ops.tensordot(a, a, 0)

    batch_size = 100
    a = array_ops.ones((batch_size, 32, 32))
    c = pfor_control_flow_ops.vectorized_map(outer_product, a)
    self.assertAllEqual((batch_size, 32, 32, 32, 32), c.shape)

  def test_disable_tf_function(self):
    def_function.run_functions_eagerly(True)
    # vectorized_map should ignore disabling tf.functions
    self.assertTrue(def_function.functions_run_eagerly())
    self.assertAllEqual([0, 1, 4, 9],
                        pfor_control_flow_ops.vectorized_map(
                            lambda x: x * x, math_ops.range(4)))
    self.assertTrue(def_function.functions_run_eagerly())
    def_function.run_functions_eagerly(False)


@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(PForTestCase):

  def test_indexed_slices(self):

    def loop_fn(i):
      return indexed_slices.IndexedSlices(
          indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1])

    self._test_loop_fn(loop_fn, 2)

  def test_indexed_slices_components(self):

    def loop_fn(i):
      slices = indexed_slices.IndexedSlices(
          indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1])
      # Note that returning the components of the slices, rather than the
      # IndexedSlices object itself, avoids densification and may be more
      # efficient.
      return slices.values, slices.indices

    self._test_loop_fn(loop_fn, 2)


@test_util.run_all_in_graph_and_eager_modes
class ReductionTest(PForTestCase):

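  # pfor_config.reduce_concat stacks a per-iteration value across all pfor
  # iterations, while reduce/reduce_mean/reduce_sum apply a reduction across
  # iterations; in each case the loop-invariant result is made available to
  # every iteration, which is why the expected values below reduce over axis 0
  # of the full input.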
  def test_reduce(self):

    def reduce_fn(p, q):
      return math_ops.reduce_mean(p + q, axis=0)

    x = random_ops.random_uniform([4, 3, 2])
    y = random_ops.random_uniform([4, 3, 2])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      y_i = array_ops.gather(y, i)
      reduced = pfor_config.reduce(reduce_fn, x_i, y_i)
      return reduced + x_i

    output = pfor_control_flow_ops.pfor(loop_fn, 4)
    ans = reduce_fn(x, y) + x
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_concat(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      vectorized_value = pfor_config.reduce_concat(x_i)
      mean_value = math_ops.reduce_mean(vectorized_value, axis=0)
      return x_i - mean_value

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_mean(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_mean(x_i)

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_sum(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_sum(x_i)

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_sum(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_class(self):
    x = random_ops.random_uniform([8, 3])

    class LoopFn(object):

      def __init__(self):
        pass

      def __call__(self, i, pfor_config):
        x_i = array_ops.gather(x, i)
        return x_i - pfor_config.reduce_mean(x_i)

    output = pfor_control_flow_ops.pfor(LoopFn(), 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_functools_partial(self):
    x = random_ops.random_uniform([8, 3])

    def fn(i, pfor_config, dummy=None):
      del dummy
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_mean(x_i)

    loop_fn = functools.partial(fn, dummy=1)
    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_parallel_iterations(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return pfor_config.reduce_sum(x_i)

    with self.assertRaisesRegex(ValueError,
                                "parallel_iterations currently unsupported"):
      pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2)

  def test_var_loop_len(self):
    if context.executing_eagerly():
      self.skipTest("Variable length not possible under eager execution.")

    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      return pfor_config.reduce_sum(array_ops.gather(x, i))

    num_iters = array_ops.placeholder(dtypes.int32)
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    with self.cached_session() as sess:
      sess.run(pfor, feed_dict={num_iters: 8})


@test_util.run_all_in_graph_and_eager_modes
class BitwiseTest(PForTestCase):

  def test_unary_cwise(self):
    for op in [bitwise_ops.invert]:
      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x1 = array_ops.gather(x, i)
        return op(x1)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 3)

  def test_binary_cwise(self):
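    # Each iteration applies op to every combination of the loop-invariant
    # inputs (x, y) and the per-iteration slices (x1, y1), exercising both
    # unstacked and stacked operands.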
    binary_ops = [
        bitwise_ops.bitwise_and,
        bitwise_ops.bitwise_or,
        bitwise_ops.bitwise_xor,
        bitwise_ops.left_shift,
        bitwise_ops.right_shift,
    ]
    for op in binary_ops:
      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
      y = random_ops.random_uniform([3, 5], maxval=10, dtype=dtypes.int32)

      output_dtypes = []

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x1 = array_ops.gather(x, i)
        y1 = array_ops.gather(y, i)
        outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
        del output_dtypes[:]
        output_dtypes.extend(t.dtype for t in outputs)
        return outputs

      # pylint: enable=cell-var-from-loop
      self._test_loop_fn(loop_fn, 3)


@test_util.run_all_in_graph_and_eager_modes
class ImageTest(PForTestCase):

  def test_adjust_contrast(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_contrast(image, 2.0)

    self._test_loop_fn(loop_fn, 3)

  def test_adjust_hue(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_hue(image, .25)

    self._test_loop_fn(loop_fn, 3)

  def test_adjust_saturation(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_saturation(image, 0.1)

    self._test_loop_fn(loop_fn, 3)


@test_util.run_all_in_graph_and_eager_modes
class NNTest(PForTestCase):

  def test_conv2d(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    filt = random_ops.random_uniform([3, 3, 3, 7])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return nn.conv2d(
          x1, filt, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_conv2d_backprop_input(self):
    x_shape = [2, 12, 12, 3]
    filt = random_ops.random_uniform([3, 3, 3, 7])
    grad = random_ops.random_uniform([3, 2, 5, 5, 7])

    def loop_fn(i):
      grad1 = array_ops.gather(grad, i)
      return nn.conv2d_backprop_input(
          x_shape,
          filt,
          grad1,
          strides=[1, 2, 2, 1],
          padding="VALID",
          data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_conv2d_backprop_filter(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    x_0 = array_ops.gather(x, 0)
    filter_sizes = [3, 3, 3, 7]
    grad = random_ops.random_uniform([3, 2, 5, 5, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      grad_i = array_ops.gather(grad, i)
      return [
          nn.conv2d_backprop_filter(
              inp,
              filter_sizes,
              grad_i,
              strides=[1, 2, 2, 1],
              padding="VALID",
              data_format="NHWC") for inp in [x_i, x_0]
      ]

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native(
          x1, filt1, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_input(self):
    x_shape = [2, 12, 12, 3]
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])
    grad = random_ops.random_uniform([3, 2, 5, 5, 6])

    def loop_fn(i):
      grad1 = array_ops.gather(grad, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native_backprop_input(
          x_shape,
          filt1,
          grad1,
          strides=[1, 2, 2, 1],
          padding="VALID",
          data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_filter(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    filter_sizes = [3, 3, 3, 2]
    grad = random_ops.random_uniform([3, 2, 5, 5, 6])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      grad_i = array_ops.gather(grad, i)
      return nn.depthwise_conv2d_native_backprop_filter(
          x_i,
          filter_sizes,
          grad_i,
          strides=[1, 2, 2, 1],
          padding="VALID",
          data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_nchw(self):
    if not test_util.is_gpu_available():
      self.skipTest("NCHW only works on GPU")
    x = random_ops.random_uniform([3, 2, 3, 12, 12])
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native(
          x1, filt1, strides=[1, 1, 2, 2], padding="VALID", data_format="NCHW")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_input_nchw(self):
    if not test_util.is_gpu_available():
      self.skipTest("NCHW only works on GPU")
    x_shape = [2, 3, 12, 12]
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])
    grad = random_ops.random_uniform([3, 2, 6, 5, 5])

    def loop_fn(i):
      grad1 = array_ops.gather(grad, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native_backprop_input(
          x_shape,
          filt1,
          grad1,
          strides=[1, 1, 2, 2],
          padding="VALID",
          data_format="NCHW")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_filter_nchw(self):
    if not test_util.is_gpu_available():
      self.skipTest("NCHW only works on GPU")
    x = random_ops.random_uniform([3, 2, 3, 12, 12])
    filter_sizes = [3, 3, 3, 2]
    grad = random_ops.random_uniform([3, 2, 6, 5, 5])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      grad_i = array_ops.gather(grad, i)
      return nn.depthwise_conv2d_native_backprop_filter(
          x_i,
          filter_sizes,
          grad_i,
          strides=[1, 1, 2, 2],
          padding="VALID",
          data_format="NCHW")

    self._test_loop_fn(loop_fn, 3)

  def test_roll(self):
    x = random_ops.random_uniform([3, 6, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return manip_ops.roll(x_i, 3, axis=1)

    self._test_loop_fn(loop_fn, 3)

  def test_ensure_shape(self):
    x = random_ops.random_uniform([3, 6, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return array_ops.ensure_shape(x_i, [6, 7])

    self._test_loop_fn(loop_fn, 3)

  def test_loop_variant_roll_shift(self):
    self.skipTest("TODO(b/191880259): re-enable once XLA compile times are "
                  "addressed.")
    x = random_ops.random_uniform([3, 5, 6, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return manip_ops.roll(x_i, [i - 2, -1, i], axis=[1, 2, 2])

    self._test_loop_fn(loop_fn, 3)

  def test_loop_variant_roll_scalar_shift(self):
    self.skipTest("TODO(b/191880259): re-enable once XLA compile times are "
                  "addressed.")
    x = random_ops.random_uniform([5, 5, 6])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return manip_ops.roll(x_i, i, axis=0)

    self._test_loop_fn(loop_fn, 5)

  def test_avg_pool(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.avg_pool(
            x1,
            ksize,
            strides=[1, 2, 2, 1],
            padding="VALID",
            data_format="NHWC")
        loss = nn.l2_loss(output)
      return output, g.gradient(loss, x1)

    self._test_loop_fn(loop_fn, 3)

  def test_avg_pool3d(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([5, 3, 7, 6, 6, 5])
      g.watch(x)
      ksize = [1, 2, 2, 2, 1]
      strides = [1, 2, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.avg_pool3d(
            x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
        loss = nn.l2_loss(output)
      return output, g.gradient(loss, x1)

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]
      strides = [1, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.max_pool(
            x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
        loss = nn.l2_loss(output)
        ones = array_ops.ones_like(output)
        g.watch(ones)
        grad = g.gradient(loss, x1, output_gradients=ones)
      grad_grad = g.gradient(grad, ones)
      return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool_v2(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]
      strides = [1, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = gen_nn_ops.max_pool_v2(
            x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
        loss = nn.l2_loss(output)
        ones = array_ops.ones_like(output)
        g.watch(ones)
        grad = g.gradient(loss, x1, output_gradients=ones)
      grad_grad = g.gradient(grad, ones)
      return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool3d(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 1, 3, 3, 1]
      strides = [1, 1, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.max_pool3d(
            x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
        loss = nn.l2_loss(output)
        ones = array_ops.ones_like(output)
        g.watch(ones)
        grad = g.gradient(loss, x1, output_gradients=ones)
      grad_grad = g.gradient(grad, ones)
      return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_fused_batch_norm(self):
    data_formats = ["NHWC"]
    if test.is_gpu_available():
      data_formats.append("NCHW")
    for is_training in (True, False):
      for data_format in data_formats:
        with backprop.GradientTape(persistent=True) as g:
          if data_format == "NCHW":
            x = random_ops.random_uniform([3, 1, 2, 5, 5])
          else:
            x = random_ops.random_uniform([3, 1, 5, 5, 2])
          g.watch(x)
          scale = random_ops.random_uniform([2])
          g.watch(scale)
          offset = random_ops.random_uniform([2])
          g.watch(offset)
          mean = None if is_training else random_ops.random_uniform([2])
          variance = None if is_training else random_ops.random_uniform([2])

        # pylint: disable=cell-var-from-loop
        def loop_fn(i):
          with g:
            x1 = array_ops.gather(x, i)
            outputs = nn.fused_batch_norm(
                x1,
                scale,
                offset,
                mean=mean,
                variance=variance,
                epsilon=0.01,
                data_format=data_format,
                is_training=is_training)
            outputs = list(outputs)
            # We only test the first value of outputs when is_training is
            # False. It looks like CPU and GPU have different outputs for
            # batch_mean and batch_variance for this case.
            if not is_training:
              outputs[1] = constant_op.constant(0.)
              outputs[2] = constant_op.constant(0.)
            loss = nn.l2_loss(outputs[0])
          if is_training:
            gradients = g.gradient(loss, [x1, scale, offset])
          else:
            gradients = [constant_op.constant(0.)] * 3
          return outputs + gradients

        # pylint: enable=cell-var-from-loop

        self._test_loop_fn(loop_fn, 3)

  def test_log_softmax(self):
    logits = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return (nn.log_softmax(logits_i), nn.log_softmax(logits_i, axis=0),
              nn.log_softmax(logits_i, axis=-1))

    self._test_loop_fn(loop_fn, 3)

  def test_softmax(self):
    logits = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return (nn.softmax(logits_i), nn.softmax(logits_i, axis=0),
              nn.softmax(logits_i, axis=-1))

    self._test_loop_fn(loop_fn, 3)

  def test_softmax_cross_entropy_with_logits(self):
    with backprop.GradientTape(persistent=True) as g:
      logits = random_ops.random_uniform([3, 2, 4])
      g.watch(logits)
      labels = random_ops.random_uniform([3, 2, 4])
      labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True)

    def loop_fn(i):
      with g:
        logits_i = array_ops.gather(logits, i)
        labels_i = array_ops.gather(labels, i)
        loss = nn.softmax_cross_entropy_with_logits(
            labels=labels_i, logits=logits_i)
        total_loss = math_ops.reduce_sum(loss)
      return loss, g.gradient(total_loss, logits_i)

    self._test_loop_fn(loop_fn, 3)

  def test_sparse_softmax_cross_entropy_with_logits(self):
    logits = random_ops.random_uniform([3, 2, 4])
    labels = random_ops.random_uniform(
        shape=[3, 2], maxval=4, dtype=dtypes.int32)

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      labels_i = array_ops.gather(labels, i)
      loss = nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels_i, logits=logits_i)
      return loss

    self._test_loop_fn(loop_fn, 3)


class RandomTest(PForTestCase):

  # The random values generated in the two implementations are not guaranteed to
  # match. So we only check the returned shapes.
  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    outputs = self._run_targets(targets1, targets2)
    n = len(outputs) // 2
    for i in range(n):
      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)

  def test_random_uniform(self):

    def loop_fn(_):
      return random_ops.random_uniform([3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_uniform_int(self):

    def loop_fn(_):
      return random_ops.random_uniform([3], maxval=1, dtype=dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_random_standard_normal(self):

    def loop_fn(_):
      return random_ops.random_normal([3])

    self._test_loop_fn(loop_fn, 5)

  def test_truncated_normal(self):

    def loop_fn(_):
      return random_ops.truncated_normal([3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_gamma_invariant_alpha(self):

    def loop_fn(_):
      return random_ops.random_gamma([3], alpha=[0.5])

    self._test_loop_fn(loop_fn, 5)

  def test_random_gamma_varying_alpha(self):
    alphas = math_ops.exp(random_ops.random_normal([5, 3, 2]))

    def loop_fn(i):
      alphas_i = array_ops.gather(alphas, i)
      # Test both scalar and non-scalar params and shapes.
      return (random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[]),
              random_ops.random_gamma(alpha=alphas_i, shape=[]),
              random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[3]),
              random_ops.random_gamma(alpha=alphas_i, shape=[3]))

    self._test_loop_fn(loop_fn, 5)

  def test_random_poisson_v2_invariant_rate(self):

    def loop_fn(_):
      return random_ops.random_poisson(lam=[1.3], shape=[3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_poisson_v2_varying_rate(self):
    rates = math_ops.exp(random_ops.random_normal([5, 3, 2]))

    def loop_fn(i):
      rates_i = array_ops.gather(rates, i)
      # Test both scalar and non-scalar params and shapes.
      return (random_ops.random_poisson(lam=rates_i[0, 0], shape=[]),
              random_ops.random_poisson(lam=rates_i, shape=[]),
              random_ops.random_poisson(lam=rates_i[0, 0], shape=[3]),
              random_ops.random_poisson(lam=rates_i, shape=[3]))

    self._test_loop_fn(loop_fn, 5)

  def test_random_multinomial_invariant_logits(self):

    def loop_fn(_):
      return random_ops.categorical(logits=[[1., -1.]], num_samples=3)

    self._test_loop_fn(loop_fn, 5)

  def test_random_multinomial_varying_logits(self):
    logits = random_ops.random_normal([5, 3, 2])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return random_ops.categorical(logits_i, num_samples=3)

    self._test_loop_fn(loop_fn, 5)


class StatelessRandomTest(PForTestCase):

  # This test currently only tests that the vectorized and non-vectorized
  # outputs have same shapes. This is needed since under XLA compilation,
  # stateless random numbers can generate different random numbers.
  # TODO(agarwal): switch to checking for actual values matching once
  # b/149402339 is resolved.
  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    outputs = self._run_targets(targets1, targets2)
    n = len(outputs) // 2
    for i in range(n):
      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)

  # TODO(agarwal): add tests for other random functions
  def test_multinomial(self):
    seeds = [[1, 2], [3, 4]]
    logits = random_ops.random_uniform([2, 3, 4])

    def loop_fn(i):
      logits_0 = array_ops.gather(logits, 0)
      logits_i = array_ops.gather(logits, i)
      seeds_0 = array_ops.gather(seeds, 0)
      seeds_i = array_ops.gather(seeds, i)
      return (stateless_random_ops.stateless_categorical(
          logits=logits_i, num_samples=3, seed=seeds_i),
              stateless_random_ops.stateless_categorical(
                  logits=logits_i, num_samples=3, seed=seeds_0),
              stateless_random_ops.stateless_categorical(
                  logits=logits_0, num_samples=3, seed=seeds_i),
              stateless_random_ops.stateless_categorical(
                  logits=logits_0, num_samples=3, seed=seeds_0))

    self._test_loop_fn(loop_fn, 2)


class LoggingTest(PForTestCase):

  def test_print(self):
    x = random_ops.random_uniform([3, 5])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return logging_ops.Print(
          x1, [x1, "x1", array_ops.shape(x1)], summarize=10)

    self._test_loop_fn(loop_fn, 3)

  def test_print_v2(self):
    x = constant_op.constant([1, 2, 3])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      with ops.control_dependencies([
          logging_ops.print_v2(
              x1, "x1", array_ops.shape(x1), summarize=10)]):
        return array_ops.identity(x1)

    self._test_loop_fn(loop_fn, 3)

    with self.captureWritesToStream(sys.stderr) as printed:
      self.evaluate(pfor_control_flow_ops.pfor(loop_fn, 3))
    self.assertIn("[1 2 3] x1 []", printed.contents())

  def test_assert(self):

    def loop_fn(i):
      return control_flow_ops.Assert(i < 10, [i, [10], [i + 1]])

    # TODO(agarwal): make this work with for_loop.
    with session.Session() as sess:
      sess.run(pfor_control_flow_ops.pfor(loop_fn, 3))
      sess.run(pfor_control_flow_ops.pfor(
          lambda i, pfor_config: loop_fn(i), 3))


class TensorArrayTest(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.disable_control_flow_v2()
    super(TensorArrayTest, self).setUp()

  def tearDown(self):
    if self._enabled:
      control_flow_v2_toggles.enable_control_flow_v2()
    super(TensorArrayTest, self).tearDown()

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_read(self):

    ta = tensor_array_ops.TensorArray(
        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)

    def loop_fn(i):
      return ta.read(i), ta.read(0)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_gather(self):

    ta = tensor_array_ops.TensorArray(
        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)

    def loop_fn(i):
      return ta.gather([i]), ta.gather([0, 1])

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_write_and_scatter(self):

    t = tensor_array_ops.TensorArray(dtypes.int32, 10, clear_after_read=False)
    handle = t.handle

    def loop_fn(i):
      ta = t.write(i + 2, 2 * i).write(i, 5)
      ta = ta.scatter([4 + i], [4]).scatter([6 + i, 8 + i], [6 + i, 8 + i])
      return ta.flow

    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
    out1 = tensor_array_ops.TensorArray(
        dtypes.int32, handle=handle, flow=t1[-1]).stack()
    output1 = self._run_targets(out1)

    t2 = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, iters=2)
    out2 = tensor_array_ops.TensorArray(
        dtypes.int32, handle=handle, flow=t2[-1]).stack()
    output2 = self._run_targets(out2)
    self.assertAllClose(output2, output1)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_write(self):

    def loop_fn(i):
      # TODO(agarwal): switching the order of writes to ta1 does not work.
      ta1 = tensor_array_ops.TensorArray(dtypes.int32, 2).write(0,
                                                                i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(dtypes.int32, 1).write(0, 1)
      return ta1.stack(), ta2.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_scatter(self):

    def loop_fn(i):
      # TODO(agarwal): switching the order of scatter to ta1 does not work.
      ta1 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).scatter([0],
                                                    [[i, 2]]).scatter([1],
                                                                      [[1, 2]])
      ta2 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).scatter([0], [3]).scatter([1], [4])
      return ta1.stack(), ta2.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_read(self):

    def loop_fn(i):
      ta1 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
      # TODO(agarwal): ta1.read(i) currently is not supported.
      return ta1.read(0), ta2.read(0), ta2.read(i)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_gather(self):

    def loop_fn(i):
      ta1 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
      # TODO(agarwal): ta1.read(i) currently is not supported.
      return ta1.gather([0, 1]), ta2.gather([0, 1]), ta2.gather([i])

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_grad(self):
    x = random_ops.random_uniform([3, 2])
    ta = tensor_array_ops.TensorArray(
        dtypes.float32, 3, clear_after_read=False).unstack(x)
    y = math_ops.square(ta.stack())

    def loop_fn(i):
      y_i = array_ops.gather(y, i)
      grad = gradient_ops.gradients(y_i, x)[0]
      return array_ops.gather(grad, i)

    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=3)
    # y = x * x. Hence dy/dx = 2 * x.
    actual_grad = 2.0 * x
    with session.Session() as sess:
      computed_grad, actual_grad = sess.run([t1, actual_grad])
      self.assertAllClose(actual_grad, computed_grad)


@test_util.run_all_in_graph_and_eager_modes
class TensorListTest(PForTestCase):

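  # These tests exercise the TensorList ops: tensor_list_reserve creates a list
  # with a given length and element shape, set_item/scatter/push_back write
  # elements, and stack/gather/concat read them back as dense tensors. pfor
  # has to decide per handle whether the list can stay loop-invariant or must
  # be stacked across iterations.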
  def test_create_outside_and_write(self):
    handle1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    handle2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)

    def loop_fn(i):
      h1 = list_ops.tensor_list_set_item(handle1, 0, i)
      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
      h2 = list_ops.tensor_list_set_item(handle2, 0, 1)
      return (list_ops.tensor_list_stack(h1, dtypes.int32),
              list_ops.tensor_list_stack(h2, dtypes.int32))

    self._test_loop_fn(loop_fn, 3)

  def _make_graph_def(self, text):
    ret = graph_pb2.GraphDef()
    text_format.Parse(text, ret)
    return ret

  def test_no_fallback_with_internal_stacking(self):

    # Create an op (really a function) that pfor definitely does not have a
    # converter for. Assumes pfor does not start looking up function definitions
    # for op-type-is-function-name calls.
    @def_function.function
    def opaque_list_fetch(x):
      array_ops.identity(x)
      return list_ops.tensor_list_get_item(x, 0, dtypes.int32)

    external_handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    opaque_list_fetch_concrete = opaque_list_fetch.get_concrete_function(
        external_handle)
    opaque_list_fetch_name = opaque_list_fetch_concrete.name

    def loop_fn(i):
      h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h1 = list_ops.tensor_list_set_item(h1, 0, i)
      opaque_list_fetch_concrete.add_to_graph()
      graph_def = self._make_graph_def("""
         node { name: 'x' op: 'Placeholder'
                attr { key: 'dtype' value { type: DT_FLOAT } }}
         node { name: 'fn' op: '""" + opaque_list_fetch_name.decode()
                                       + """' input: 'x:0' }""")
      return importer.import_graph_def(
          graph_def,
          input_map={"x:0": h1},
          return_elements=["fn"],
          name="import")[0].outputs[0]

    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)

  def test_create_inside_and_write(self):

    def loop_fn(i):
      h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h1 = list_ops.tensor_list_set_item(h1, 0, i)
      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
      h2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h2 = list_ops.tensor_list_set_item(h2, 0, 1)
      return (list_ops.tensor_list_stack(h1, dtypes.int32),
              list_ops.tensor_list_stack(h2, dtypes.int32))

    self._test_loop_fn(loop_fn, 3)

  def test_create_outside_and_read(self):
    handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    handle = list_ops.tensor_list_set_item(handle, 0, 0)
    handle = list_ops.tensor_list_set_item(handle, 1, 1)

    def loop_fn(i):
      return (list_ops.tensor_list_get_item(handle, i, dtypes.int32),
              list_ops.tensor_list_get_item(handle, 0, dtypes.int32),
              list_ops.tensor_list_length(handle),
              list_ops.tensor_list_element_shape(handle, dtypes.int32),
              list_ops.tensor_list_element_shape(handle, dtypes.int64))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_read(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      handle = list_ops.tensor_list_set_item(handle, 0, i)
      handle = list_ops.tensor_list_set_item(handle, 1, 1)
      return (list_ops.tensor_list_get_item(handle, 0, dtypes.int32),
              list_ops.tensor_list_get_item(handle, i, dtypes.int32),
              list_ops.tensor_list_length(handle),
              list_ops.tensor_list_element_shape(handle, dtypes.int32),
              list_ops.tensor_list_element_shape(handle, dtypes.int64))

    self._test_loop_fn(loop_fn, 2)

  def test_create_outside_and_push_back(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_push_back(h, [i, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_create_inside_and_push_back(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [i, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_no_shape(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      handle = list_ops.tensor_list_push_back(handle, [i, 2])
      handle, tensor = list_ops.tensor_list_pop_back(handle, dtypes.int32)
      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_no_shape_capture(self):
    h = list_ops.tensor_list_reserve([2], 1, dtypes.int32)
    h = list_ops.tensor_list_push_back(h, [1, 2])

    def loop_fn(i):
      handle, tensor = list_ops.tensor_list_pop_back(h, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [1, i])
      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_with_shape(self):

    @def_function.function
    def loop_fn(i):
      with backprop.GradientTape() as tape:
        handle = list_ops.tensor_list_reserve(None, 1, dtypes.float32)
        x = math_ops.cast(i, dtypes.float32)[None]
        tape.watch(x)
        handle = list_ops.tensor_list_push_back(handle, x)
        stacked = list_ops.tensor_list_stack(handle, dtypes.float32)
      list_grad = tape.gradient(stacked, x, x)
      self.assertEqual("TensorListPopBack", list_grad.op.type)
      return list_grad, stacked, list_grad.op.inputs[1]

    self._test_loop_fn(loop_fn, 3)

  def test_create_outside_and_scatter(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_create_inside_and_scatter(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_loop_variant_scatter_indices(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 10, dtypes.int32)
      handle = list_ops.tensor_list_scatter(
          [[1, i], [i + 1, 2]],
          [i, i + 5], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_duplicate_indices(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 10, dtypes.int32)
      handle = list_ops.tensor_list_scatter(
          [[1, i], [1, i + 1], [i + 2, 3]],
          [i, i, i + 2], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_create_outside_and_gather(self):
    handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
    handle = list_ops.tensor_list_scatter([[2, 3]], [0], input_handle=handle)
    handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)

    def loop_fn(i):
      return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32),
              list_ops.tensor_list_gather(handle, [i], dtypes.int32))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_gather(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32),
              list_ops.tensor_list_gather(handle, [i], dtypes.int32))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_concat(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return gen_list_ops.tensor_list_concat_v2(
          handle,
          element_dtype=dtypes.int32,
          element_shape=[2],
          leading_dims=[])

    output = pfor_control_flow_ops.pfor(loop_fn, 2)
    self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
    self.assertAllClose([[2, 2], [2, 2]], output[1])

  def test_create_outside_and_concat(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return gen_list_ops.tensor_list_concat_v2(
          handle,
          element_dtype=dtypes.int32,
          element_shape=[2],
          leading_dims=[])

    output = pfor_control_flow_ops.pfor(loop_fn, 2)
    self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
    self.assertAllClose([[2, 2], [2, 2]], output[1])

  def test_tensor_list_from_tensor(self):
    t = random_ops.random_uniform([2, 3, 4])

    def loop_fn(i):
      handle = list_ops.tensor_list_from_tensor(array_ops.gather(t, i), [4])
      return list_ops.tensor_list_stack(handle, t.dtype)

    self._test_loop_fn(loop_fn, 2)

  @test_util.enable_control_flow_v2
  def test_tensor_list_reserve_while_loop(self):
    # Here a loop invariant TensorList is captured by a while_loop, which then
    # performs loop dependent operations on it, resulting in a loop variant
    # output. This forces stacking of the variant handle captured by the
    # while_loop.
    # We handle this particular case by forcing vectorization of
    # TensorListReserve operation.

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < 2, lambda j, h:
          (j + 1, list_ops.tensor_list_set_item(h, j, i)), (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 2)

  @test_util.enable_control_flow_v2
  def test_tensor_list_while_loop_stacked_cond_stacked_list(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, i], [])
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  @test_util.enable_control_flow_v2
  def test_tensor_list_while_loop_stacked_cond_stacked_list_unknown_shape(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve(None, 5, dtypes.int32)
      _, handle = control_flow_ops.while_loop(
          lambda j, _: j < 5,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, 0)),
          (0, handle))
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  @test_util.enable_control_flow_v2
  def test_tensor_list_while_loop_stacked_cond_unstacked_list(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, 24], [])
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_tensor_list_addn_already_stacked(self):

    def loop_fn(i):
      l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l1 = list_ops.tensor_list_set_item(l1, 0, i)
      l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l2 = list_ops.tensor_list_set_item(l2, 1, i)
      return list_ops.tensor_list_stack(math_ops.add_n([l1, l2]), dtypes.int32)

    self._test_loop_fn(loop_fn, 2)

  def test_tensor_list_addn_stacking_required(self):
    l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    l1 = list_ops.tensor_list_set_item(l1, 1, 1)

    def loop_fn(i):
      l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l2 = list_ops.tensor_list_set_item(l2, 1, i)
      return list_ops.tensor_list_stack(
          math_ops.add_n([l1, l2]), dtypes.int32)

    self._test_loop_fn(loop_fn, 2)


class OptionalTest(PForTestCase):

  def test_optional_from_value(self):

    def loop_fn(i):
      o = gen_dataset_ops.optional_from_value(
          [i, i + 1, constant_op.constant(3)])
      gen_dataset_ops.optional_none()
      return gen_dataset_ops.optional_get_value(
          o, [dtypes.int32, dtypes.int32, dtypes.int32],
          [[], [], []])

    self._test_loop_fn(loop_fn, 2)


class StackTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_stack_inside_loop_invariant(self):

    def loop_fn(_):
      s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
      op1 = data_flow_ops.stack_push_v2(s, 1)
      with ops.control_dependencies([op1]):
        op2 = data_flow_ops.stack_push_v2(s, 2)
      with ops.control_dependencies([op2]):
        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      with ops.control_dependencies([e2]):
        e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      return e1, e2

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_stack_inside_push_loop_dependent(self):

    def loop_fn(i):
      s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
      op1 = data_flow_ops.stack_push_v2(s, i)
      with ops.control_dependencies([op1]):
        op2 = data_flow_ops.stack_push_v2(s, 2)
      with ops.control_dependencies([op2]):
        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      with ops.control_dependencies([e2]):
        e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      return e1, e2

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_stack_outside_pop(self):
1523    s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
1524    op = data_flow_ops.stack_push_v2(s, 5)
1525    with ops.control_dependencies([op]):
1526      op = data_flow_ops.stack_push_v2(s, 6)
1527    with ops.control_dependencies([op]):
1528      op = data_flow_ops.stack_push_v2(s, 7)
1529
1530    def loop_fn(_):
1531      e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1532      with ops.control_dependencies([e1]):
1533        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1534      return e1, e2
1535
1536    with ops.control_dependencies([op]):
1537      e1, e2 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
1538    with ops.control_dependencies([e1, e2]):
1539      e3 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1540    v1, v2, v3 = self._run_targets([e1, e2, e3], run_init=False)
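    # Each pfor iteration pops the same values (7, then 6); the final pop
    # outside pfor returns the remaining 5.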
    self.assertAllEqual([7, 7], v1)
    self.assertAllEqual([6, 6], v2)
    self.assertAllEqual(5, v3)

  @test_util.run_v1_only("b/122612051")
  def test_stack_outside_push(self):
    s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)

    def loop_fn(_):
      return data_flow_ops.stack_push_v2(s, 7)

    with self.assertRaisesRegex(ValueError, "StackPushV2 not allowed.*"):
      pfor_control_flow_ops.pfor(loop_fn, iters=2)


# TODO(agarwal): test nested while_loops. This currently requires converting a
# tf.cond.
class WhileV1Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV1Test, self).setUp()

  def tearDown(self):
    if self._enabled:
      control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV1Test, self).tearDown()

  def test_while_outside_loop(self):

    x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    def loop_fn(i):
      return x + i

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_invariant_while(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_invariant_while_with_control_dependency(self):

    def loop_fn(i):
      with ops.control_dependencies([i]):
        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
                                           [0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_with_stateful_ops(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x:
          (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                         (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      _, total = control_flow_ops.while_loop(
          lambda j, _: j < lengths_i, lambda j, t:
          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
      return total

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_while_jacobian(self):
    x = random_ops.random_uniform([1, 3])
    y = random_ops.random_uniform([3, 3])

    # out = x @ y @ y @ y @ y, where @ is matmul operator.
    _, out = control_flow_ops.while_loop(
        lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
        [0, x])

    def loop_fn(i):
      out_i = array_ops.gather(out, i, axis=1)
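      # d(out[:, i]) / dx flattened to a row; stacking these rows over i below
      # yields the Jacobian of out w.r.t. x.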
      return array_ops.reshape(gradient_ops.gradients(out_i, x)[0], [-1])

    out = pfor_control_flow_ops.pfor(loop_fn, iters=3)

    # The above code does not work with tf.while_loop instead of pfor. So we
    # manually compute the expected output here.
    # Note that the gradient of the output w.r.t. x is (y @ y @ y @ y)^T.
    expected_output = y
    for _ in range(3):
      expected_output = math_ops.matmul(expected_output, y)
    expected_output = array_ops.transpose(expected_output, [1, 0])

    with session.Session() as sess:
      out, expected = sess.run([out, expected_output])
      self.assertAllClose(expected, out)

  @test_util.run_v1_only("b/122612051")
  def test_tensor_array_as_loop_variable(self):

    def loop_fn(i):

      def body(j, ta):
        ta = ta.write(j, i + j * j)
        return j + 1, ta

      _, ta = control_flow_ops.while_loop(
          lambda j, _: j < 4, body,
          (0, tensor_array_ops.TensorArray(dtypes.int32, size=4)))
      return ta.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_read_tensor_array_partitioned_indices(self):
    # Note that tensor array values are pfor loop dependent, and the while loop
    # termination condition is also dependent on pfor iteration.
    def loop_fn(i):
      ta = tensor_array_ops.TensorArray(dtypes.int32, size=6)
      ta = ta.unstack(i + list(range(5)))

      def body(j, s):
        return j + 1, s + ta.read(j)

      _, s = control_flow_ops.while_loop(lambda j, _: j < i, body, (0, 0))
      return s

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_external_while_loop_grad(self):
    # Here we test that external while_loops that are extended from inside pfor
    # (due to gradient calls) are not actually converted. If the loop below were
    # converted, all pfor iterations would write to the same tensor array
    # indices.
    x = constant_op.constant(1.)

    def body(j, ta):
      ta = ta.write(j, x)
      return j + 1, ta

    _, ta = control_flow_ops.while_loop(
        lambda j, _: j < 4, body,
        (0, tensor_array_ops.TensorArray(dtypes.float32, size=4)))
    out = ta.stack()

    def loop_fn(i):
      out_i = array_ops.gather(out, i)
      return gradient_ops.gradients(out_i, x)[0]

    with session.Session() as sess:
      # out is [x, x, x]. Hence the gradients should be [1, 1, 1].
      self.assertAllEqual([1, 1, 1],
                          sess.run(pfor_control_flow_ops.pfor(loop_fn, 3)))

  @test_util.run_v1_only("b/122612051")
  def test_tensor_array_grad(self):
    inp = constant_op.constant(np.random.rand(3, 4, 2), dtype=dtypes.float32)
    ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
    ta = ta.unstack(inp)

    def loop_fn(i):

      def body(j, x):
        value = ta.gather([j])
        value = array_ops.gather(array_ops.reshape(value, [4, 2]), i)
        return j + 1, x + value

      _, out = control_flow_ops.while_loop(lambda j, _: j < 3, body,
                                           (0, array_ops.zeros([2])))
      out = math_ops.reduce_prod(out)
      return out, gradient_ops.gradients(out, inp)[0]

    pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4)
    # Note that tf.while_loop does not work in the setup above. So we manually
    # construct the equivalent computation of the above loops here.
    real_out = math_ops.reduce_sum(inp, axis=[0])
    real_out = math_ops.reduce_prod(real_out, axis=[1])
    # Note that gradients of real_out will accumulate the gradients across the
    # output values. Hence we do the same aggregation on pfor_out_grad.
    real_out_grad = gradient_ops.gradients(real_out, inp)[0]
    sum_pfor_out_grad = math_ops.reduce_sum(pfor_out_grad, axis=[0])

    with session.Session() as sess:
      v1, v2, v1_grad, v2_grad = sess.run(
          [pfor_out, real_out, sum_pfor_out_grad, real_out_grad])
      self.assertAllClose(v1, v2)
      self.assertAllClose(v1_grad, v2_grad)


def dynamic_lstm_input_fn(batch_size, state_size, max_steps):
  # We make inputs and sequence_length constant so that multiple session.run
  # calls produce the same result.
  inputs = constant_op.constant(
      np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32)
  sequence_length = np.random.randint(0, size=[batch_size], high=max_steps + 1)
  sequence_length = constant_op.constant(sequence_length, dtype=dtypes.int32)
  return inputs, sequence_length


def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps):
  cell = cell_fn(state_size)
  inputs, sequence_length = dynamic_lstm_input_fn(batch_size, state_size,
                                                  max_steps)
  inputs_ta = tensor_array_ops.TensorArray(
      dtypes.float32, size=max_steps, element_shape=[batch_size, state_size])
  inputs_time_major = array_ops.transpose(inputs, [1, 0, 2])
  inputs_ta = inputs_ta.unstack(inputs_time_major)
  zeros = array_ops.zeros([state_size])

  def loop_fn(i):
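    # Steps the cell through time for batch element i. Once t reaches
    # sequence_length[i], the output is zeroed and the state is passed through
    # unchanged.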
    sequence_length_i = array_ops.gather(sequence_length, i)

    def body_fn(t, state, ta):
      inputs_t = array_ops.expand_dims(
          array_ops.gather(inputs_ta.read(t), i), 0)
      output, new_state = cell(inputs_t, state)
      output = array_ops.reshape(output, [-1])
      # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the
      # array_ops.where when t < min(sequence_length). Doing that requires
      # supporting tf.cond pfor conversion.
      done = t >= sequence_length_i
      output = array_ops.where(done, zeros, output)
      ta = ta.write(t, output)
      new_state = [
          array_ops.where(done, s, ns)
          for s, ns in zip(nest.flatten(state), nest.flatten(new_state))
      ]
      new_state = nest.pack_sequence_as(state, new_state)
      return t + 1, new_state, ta

    def condition_fn(t, _, unused):
      del unused
      return t < max_steps

    initial_state = cell.zero_state(1, dtypes.float32)
    _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [
        0, initial_state,
        tensor_array_ops.TensorArray(dtypes.float32, max_steps)
    ])

    new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)]
    new_state = nest.pack_sequence_as(initial_state, new_state)
    return ta.stack(), new_state

  pfor_output = pfor_control_flow_ops.pfor(loop_fn, batch_size)
  tf_output = rnn.dynamic_rnn(
      cell,
      inputs,
      sequence_length=sequence_length,
      initial_state=cell.zero_state(batch_size, dtypes.float32))
  return pfor_output, tf_output


@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV2Test, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV2Test, self).tearDown()

  def test_while_outside_loop(self):

    def _f():
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    def loop_fn(i):
      return _f() + i

    self._test_loop_fn(loop_fn, 3)

  def test_invariant_while(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    self._test_loop_fn(loop_fn, 3)

  def test_invariant_while_with_control_dependency(self):

    def loop_fn(i):
      with ops.control_dependencies([i]):
        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
                                           [0])

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_stateful_ops(self):

    def loop_fn(_):
      j, _ = control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x:
          (j + 1, x + random_ops.random_uniform([])), [0, 0.])
      return j

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_variable(self):
    v = resource_variable_ops.ResourceVariable(5.)

    def loop_fn(_):
      _, output = control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                              (j + 1, x + v), [0, 0.])
      return output

    self._test_loop_fn(loop_fn, 3)

  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                         (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

  def test_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      return control_flow_ops.while_loop(
          lambda j, _: j < lengths_i, lambda j, t:
          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])

    self._test_loop_fn(loop_fn, 3)

  def test_while_change_input_invariance(self):
    # This tests cases where a loop invariant input to while has loop dependent
    # operations applied to it inside the while body.
    # It also tests inputs that are passed through.
    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < i, lambda j, x, y, z, w:
          (j + 1, x + i, y + x, z, w), [
              0,
              constant_op.constant(0),
              constant_op.constant(1), i,
              constant_op.constant(2)
          ])

    self._test_loop_fn(loop_fn, 3)

  def test_while_shape_invariants(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < 4,
          lambda j, x, y: (j + 1, x + i, y + 1),
          [0, constant_op.constant([0, 1]),
           constant_op.constant([2, 3])],
          shape_invariants=[
              None,
              tensor_shape.TensorShape([2]),
              tensor_shape.TensorShape([2])
          ])

    self._test_loop_fn(loop_fn, 3)

  def test_while_jacobian(self):
    # Note that we wrap the code below in a tf.function since we don't want the
    # while_loop call to be evaluated eagerly using a python loop.
    @def_function.function
    def _f(x, y, use_pfor):
      # out = x @ y @ y @ y @ y, where @ is matmul operator.
      _, out = control_flow_ops.while_loop(
          lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
          [0, x])

      def loop_fn(i):
        out_i = array_ops.gather(out, i, axis=1)
        grad = gradient_ops.gradients(out_i, x)
        return array_ops.reshape(grad[0], [-1])

      if use_pfor:
        return pfor_control_flow_ops.pfor(loop_fn, iters=3)
      else:
        return pfor_control_flow_ops.for_loop(
            loop_fn, iters=3, loop_fn_dtypes=out.dtype)

    x = constant_op.constant(np.random.uniform(size=(1, 3)))
    y = constant_op.constant(np.random.uniform(size=(3, 3)))
    self.assertAllClose(_f(x, y, True), _f(x, y, False))

  def test_scan(self):
    np.random.seed(seed=42)
    data = np.random.randn(3).astype(np.float32)

    def log_prob(x):
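      # scan_v2 emits the per-element squared residuals (x - data[j])**2, which
      # reduce_sum then adds up into sum_j (x - data[j])**2.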
      return math_ops.reduce_sum(functional_ops.scan_v2(
          lambda _, yi: (x - yi)**2,
          elems=data,
          initializer=constant_op.constant(0.)))

    x = variables.Variable(array_ops.ones([2]))
    self.evaluate(x.initializer)
    v_log_prob = lambda x: pfor_control_flow_ops.vectorized_map(log_prob, x)
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        v_log_prob, (x,), delta=1e-3)
    self.assertAllClose(theoretical, numerical, rtol=1e-2)

  def test_scan_captured_variable(self):
    if not context.executing_eagerly():
      self.skipTest("Test only written for 2.x")
    v = variables.Variable(math_ops.range(10, dtype=dtypes.float32))

    def loop_fn(idx):
      del idx
      return functional_ops.scan_v2(lambda _, i: array_ops.gather(v, i),
                                    elems=math_ops.range(v.shape[0]),
                                    initializer=0.0)
    with backprop.GradientTape() as tape:
      result = pfor_control_flow_ops.pfor(loop_fn, 2)
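    # Every element of v is gathered once in each of the 2 pfor iterations, so
    # the summed gradient w.r.t. v is 2 everywhere.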
    self.assertAllClose([2.] * 10, tape.gradient(result, v))


@test_util.run_all_in_graph_and_eager_modes
class NestedControlFlowTest(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(NestedControlFlowTest, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(NestedControlFlowTest, self).tearDown()

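  # Builds a function wrapping `f` in a tf.cond whose predicate (y > split) is
  # loop dependent whenever y is.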
  def _cond(self, f=None, split=0):
    if f is None:
      f = lambda x, y: (x, y)

    def _f(x, y):
      return control_flow_ops.cond(y > split, lambda: f(x, y), lambda:
                                   (x + 1., y))

    return _f

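  # Builds a function wrapping `f` in a while_loop whose trip count y is loop
  # dependent, accumulating gathered elements of f(x, y)[0].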
  def _while(self, f=None):
    if f is None:
      f = lambda x, y: (x, y)

    def _f(x, y):
      return control_flow_ops.while_loop(
          lambda j, _: j < y, lambda j, t:
          (j + 1, t + array_ops.gather(f(x, y)[0], j)), [0, x])[1], y

    return _f

  def _test_helper(self, f):
    x = random_ops.random_uniform([5, 5])
    y = constant_op.constant([4, -1, 2, -2, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      y_i = array_ops.gather(y, i)
      return f(x_i, y_i)

    self._test_loop_fn(loop_fn, 5)

  def test_cond_while(self):
    self._test_helper(self._cond(self._while()))

  def test_while_cond(self):
    self._test_helper(self._while(self._cond()))

  def test_while_while(self):
    self._test_helper(self._while(self._while()))

  def test_cond_cond(self):
    self._test_helper(self._cond(self._cond()))


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class StatelessIfTest(PForTestCase):

  def test_loop_variant_cond(self):
    x = [1, 2, 3, 4, 5.]
    y = 2.5

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      # Note that the output has a combination of then and else branches being
      # loop variant / invariant.
      return cond_v2.cond_v2(x_i < y, lambda: (y - x_i, y, 1., 2.), lambda:
                             (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)

  def test_loop_invariant_cond(self):
    x = [1, 2, 3, 4, 5.]
    y = 0.5
    z = random_ops.random_uniform([])

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      # Note that the output has a combination of then and else branches being
      # loop variant / invariant.
      return cond_v2.cond_v2(z < y, lambda: (y - x_i, y, 1., 2.), lambda:
                             (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)

  def test_empty_branch(self):
    x = [1, 2, 3, 4, 5.]
    y = 6.

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return cond_v2.cond_v2(
          x_i < y,  # Note that else branch is empty.
          lambda: (y - x_i, y, 1., 2.),
          lambda: (x_i - y, 0., y, 3.))

    self._test_loop_fn(loop_fn, iters=5)


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class IfTest(PForTestCase):

  def test_read_var(self):
    self.skipTest("b/156438918")  # Flaky

    x = [1, 2, 3, 4, 5.]
    y = 2.5
    z = resource_variable_ops.ResourceVariable(5.)

    @def_function.function
    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return cond_v2.cond_v2(x_i < y, lambda: z - x_i, lambda: z + x_i)

    self._test_loop_fn(loop_fn, iters=5)


class RNNTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_dynamic_rnn(self):
    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 3, 5,
                                                   7)
    self.run_and_assert_equal(pfor_outputs, tf_outputs)

  @test_util.run_v1_only("b/122612051")
  def test_dynamic_lstm(self):
    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell, 3, 5,
                                                   7)
    self.run_and_assert_equal(pfor_outputs, tf_outputs)


# TODO(agarwal): benchmark numbers on GPU for graphs based on while_loop
# conversion don't look good. Some of it seems like a lot of copies between host
# and device. Optimize that.
class Benchmarks(test.Benchmark):

  def _run(self, targets, iters, name=None):

    def _done(t):
      # Note that we don't use tf.control_dependencies since that will not make
      # sure that the computation on GPU has actually finished. So we fetch the
      # first element of the output, and assume that this will not be called on
      # empty tensors.
      return array_ops.gather(array_ops.reshape(t, [-1]), 0)

    targets = [_done(x) for x in nest.flatten(targets)]
    sess = session.Session()
    with sess:
      init = variables.global_variables_initializer()
      sess.run(init)
      run_fn = sess.make_callable(targets)
      run_fn()  # Warm up
      begin = time.time()
      for _ in range(iters):
        run_fn()
      end = time.time()
    avg_time_ms = 1000 * (end - begin) / iters
    self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name)
    return avg_time_ms

  def benchmark_sess_run_overhead(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0)
      self._run(x, 10000, name="session_run_overhead")

  def benchmark_add(self):
    with ops.Graph().as_default():
      n = 256
      params = 1000
      x = random_ops.random_normal([n, params])
      y = random_ops.random_normal([n, params])

      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        y_i = array_ops.gather(y, i)
        return x_i + y_i

      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
      manual = x + y

      self._run(manual, 1000, name="manual_add")
      self._run(pfor_outputs, 1000, name="pfor_add")
      self._run(while_outputs, 100, name="while_add")

  def benchmark_matmul(self):
    with ops.Graph().as_default():
      n = 1024
      params = 1000
      x = random_ops.random_normal([n, params])
      y = random_ops.random_normal([params, params])

      def loop_fn(i):
        x_i = array_ops.expand_dims(array_ops.gather(x, i), 0)
        return math_ops.matmul(x_i, y)

      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
      manual = math_ops.matmul(x, y)

      self._run(manual, 1000, name="manual_matmul")
      self._run(pfor_outputs, 1000, name="pfor_matmul")
      self._run(while_outputs, 100, name="while_matmul")

  def benchmark_map_fn(self):
    with ops.Graph().as_default():
      b = 256
      params = 1000
      inp = random_ops.random_normal((b, params))
      fn = lambda x: x * x

      def pfor_map_fn(f, x):
        return pfor_control_flow_ops.pfor(lambda i: f(array_ops.gather(x, i)),
                                          array_ops.shape(x)[0])

      map_output = map_fn.map_fn(fn, inp)
      pfor_output = pfor_map_fn(fn, inp)

      self._run(map_output, 100, name="tf_map_fn")
      self._run(pfor_output, 100, name="pfor_map_fn")

  def benchmark_basic_while(self):
    with ops.Graph().as_default():

      def loop_fn(i):
        _, s = control_flow_ops.while_loop(lambda t, x: t < i, lambda t, x:
                                           (t + 1, x + i), [0, 0])
        return s

      iters = 50
      pfor_output = pfor_control_flow_ops.pfor(loop_fn, iters)
      for_loop_output = pfor_control_flow_ops.for_loop(loop_fn, dtypes.int32,
                                                       iters)
      self._run(pfor_output, 100, name="pfor_basic")
      self._run(for_loop_output, 100, name="for_loop_basic")

  def benchmark_dynamic_rnn(self):
    with ops.Graph().as_default():
      pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 128,
                                                     512, 16)
      self._run(pfor_outputs, 100, name="pfor_rnn")
      self._run(tf_outputs, 100, name="tf_rnn")

  def benchmark_reduction(self):
    n = 1024
    with ops.Graph().as_default():
      x = random_ops.random_uniform([n, n])
      w = random_ops.random_uniform([n, n])

      def loop_fn(i, pfor_config):
        x_i = array_ops.gather(x, i)
        return math_ops.reduce_sum(
            math_ops.matmul(pfor_config.reduce_concat(x_i), w))

      # Note that output_reduction will be tiled, so there may be some minor
      # overheads compared to output_no_reduction.
      output_reduction = pfor_control_flow_ops.pfor(loop_fn, n)
      output_no_reduction = math_ops.reduce_sum(math_ops.matmul(x, w))
      # Benchmark to test that reduction does not add overhead and its output is
      # treated as loop invariant.
      self._run(output_reduction, 30, name="matmul_reduction")
      self._run(output_no_reduction, 30, name="matmul_no_reduction")


class SparseTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_var_loop_len(self):
    num_iters = array_ops.placeholder(dtypes.int32)

    def loop_fn(_):
      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
                                        [3])  # [4, 5, 6]

    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    with self.cached_session() as sess:
      sess.run(pfor, feed_dict={num_iters: 3})

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_none_stacked(self):
    num_iters = 10

    def loop_fn(_):
      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
                                        [3])  # [4, 5, 6]

    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)

    indices = [[i, j] for i in range(num_iters) for j in range(3)]
    values = [4, 5, 6] * num_iters
    dense_shapes = [num_iters, 3]
    # Expected result: [[4, 5, 6], [4, 5, 6], [4, 5, 6], ...]
    manual = sparse_tensor.SparseTensor(indices, values, dense_shapes)
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_all_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      indices = array_ops.expand_dims(i, 0)
      return sparse_tensor.SparseTensor(indices, i, i + 1)  # [0, ..., 0, i]

    # Expected result: [[0], [0, 1], [0, 0, 2], [0, 0, 0, 3], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
                                        list(range(num_iters)),
                                        (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_indices_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      indices = array_ops.expand_dims(i, 0)
      return sparse_tensor.SparseTensor(indices, [1], [num_iters])

    # Expected result: identity matrix size num_iters * num_iters
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
                                        [1] * num_iters, (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_values_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      return sparse_tensor.SparseTensor([[0]], i, [num_iters])  # [i, 0, ..., 0]

    # Expected result: [[0, 0, ...], [1, 0, ...], [2, 0, ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
                                        list(range(num_iters)),
                                        (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_shapes_stacked(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
      return sparse_tensor.SparseTensor([[0]], [1], i + 1)  # [1, 0, ..., 0]

    # Expected result: [[1, 0, 0, ...], [1, 0, 0, ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
                                        [1] * num_iters, (num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)

  @test_util.run_v1_only("b/122612051")
  def test_sparse_result_shapes_stacked_2D(self):
    num_iters = 10

    def loop_fn(i):
      i = array_ops.expand_dims(math_ops.cast(i + 1, dtypes.int64), 0)
      shape = array_ops.concat([i, i], 0)
      return sparse_tensor.SparseTensor([[0, 0]], [1], shape)  # [1, 0, ..., 0]

    # Expected result: [[[1, 0, ...], [0, ..., 0], [0, ..., 0], ...], ...]
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    manual = sparse_tensor.SparseTensor([[i, 0, 0] for i in range(num_iters)],
                                        [1] * num_iters,
                                        (num_iters, num_iters, num_iters))
    self.run_and_assert_equal(pfor, manual)


# Dummy CompositeTensor to test CompositeTensor support.
class Particle(composite_tensor.CompositeTensor):
  """A (batch of) particles each defined by a mass and a scalar velocity."""

  def __init__(self, mass, velocity):
    mass = ops.convert_to_tensor(mass)
    velocity = ops.convert_to_tensor(velocity)
    self.shape = array_ops.broadcast_static_shape(mass.shape, velocity.shape)
    self.mass = mass
    self.velocity = velocity

  @property
  def _type_spec(self):
    return ParticleSpec(
        type_spec.type_spec_from_value(self.mass),
        type_spec.type_spec_from_value(self.velocity))


class ParticleSpec(type_spec.BatchableTypeSpec):

  def __init__(self, mass, velocity):
    self.shape = array_ops.broadcast_static_shape(
        mass.shape, velocity.shape)
    self.mass = mass
    self.velocity = velocity

  def _serialize(self):
    return (self.mass, self.velocity)

  @property
  def value_type(self):
    return Particle

  @property
  def _component_specs(self):
    return (self.mass, self.velocity)

  def _to_components(self, value):
    return (value.mass, value.velocity)

  def _from_components(self, components):
    return Particle(*components)

  def _pad_shape_to_full_rank(self, s):
    """Pad component shapes with 1's so all components have the same rank."""
    return tensor_shape.TensorShape(
        [1] * (self.shape.ndims - s.ndims)).concatenate(s)

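  # Prepends a batch dimension to each component spec, after padding component
  # shapes to the full (broadcast) rank.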
  def _batch(self, batch_size):
    return ParticleSpec(
        mass=tensor_spec.TensorSpec(
            dtype=self.mass.dtype,
            shape=tensor_shape.TensorShape([batch_size]).concatenate(
                self._pad_shape_to_full_rank(self.mass.shape))),
        velocity=tensor_spec.TensorSpec(
            dtype=self.velocity.dtype,
            shape=tensor_shape.TensorShape([batch_size]).concatenate(
                self._pad_shape_to_full_rank(self.velocity.shape))))

  def _unbatch(self):
    return ParticleSpec(
                tensor_spec.TensorSpec(dtype=self.mass.dtype,
                                       shape=self.mass.shape[1:]),
                tensor_spec.TensorSpec(dtype=self.velocity.dtype,
                                       shape=self.velocity.shape[1:]))

  def _to_tensor_list(self, value):
    return [array_ops.reshape(
                value.mass,
                self._pad_shape_to_full_rank(value.mass.shape)),
            array_ops.reshape(
                value.velocity,
                self._pad_shape_to_full_rank(value.velocity.shape))]


class CompositeTensorTest(PForTestCase, parameterized.TestCase):

  @parameterized.parameters((None,), (3,))
  def test_create_composite_inside_loop(self, parallel_iterations):
    num_particles = 10
    velocities = random_ops.random_uniform([num_particles])
    particles = pfor_control_flow_ops.pfor(
        # Build a batch of particles all with the same mass.
        lambda i: Particle(mass=4., velocity=array_ops.gather(velocities, i)),
        num_particles,
        parallel_iterations=parallel_iterations)
    particles_mass, particles_velocity, velocities = self.evaluate(
        (particles.mass, particles.velocity, velocities))
    self.assertAllEqual(particles_mass, 4. * np.ones([num_particles]))
    self.assertAllEqual(particles_velocity, velocities)

  @parameterized.parameters((None,), (3,))
  def test_composite_is_converted_to_batched_tensor(
      self, parallel_iterations):
    particles = pfor_control_flow_ops.pfor(
        lambda _: Particle(mass=random_ops.random_uniform([3]),  # pylint: disable=g-long-lambda
                           velocity=random_ops.random_uniform([5, 3])),
        4,
        parallel_iterations=parallel_iterations)
    # Naively batching the component shapes would give `[4, 3]` and `[4, 5, 3]`
    # which have no consistent broadcast shape.
    self.assertAllEqual(particles.mass.shape, [4, 1, 3])
    self.assertAllEqual(particles.velocity.shape, [4, 5, 3])

  def test_vectorized_map_gathers_composite_tensors(self):
    particles = Particle(mass=[1., 2., 3., 4., 5.],
                         velocity=[1., 2., 3., 4., 5.])
    self.assertAllEqual(
        pfor_control_flow_ops.vectorized_map(
            lambda x: x.mass * x.velocity, particles),
        particles.mass * particles.velocity)

  def test_vectorized_map_of_ragged_tensors(self):
    # Vmap should be able to handle ragged Tensors as long as they're not
    # *actually* ragged.
    ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
        ragged_tensor.RaggedTensor.from_row_lengths(
            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            row_lengths=[3, 3, 3, 3]),
        uniform_row_length=2)  # Overall shape [2, 2, 3].
    self.assertAllEqual(
        pfor_control_flow_ops.vectorized_map(
            lambda x: x.to_tensor(shape=[2, 3]), ragged),
        ragged.to_tensor(shape=[2, 2, 3]))


class ParsingTest(PForTestCase):

  def test_decode_csv(self):
    csv_tensor = constant_op.constant([["1:2:3"], ["::"], ["7:8:9"]])
    kwargs = {"record_defaults": [[10], [20], [30]], "field_delim": ":"}

    def loop_fn(i):
      line = array_ops.gather(csv_tensor, i)
      return parsing_ops.decode_csv(line, **kwargs)

    self._test_loop_fn(loop_fn, iters=3)

  @test_util.run_v1_only("b/122612051")
  def test_parse_single_example(self):

    def _int64_feature(*values):
      return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values))

    def _bytes_feature(*values):
      return feature_pb2.Feature(
          bytes_list=feature_pb2.BytesList(
              value=[v.encode("utf-8") for v in values]))

    examples = constant_op.constant([
        example_pb2.Example(
            features=feature_pb2.Features(
                feature={
                    "dense_int": _int64_feature(i),
                    "dense_str": _bytes_feature(str(i)),
                    "sparse_int": _int64_feature(i, i * 2, i * 4, i * 8),
                    "sparse_str": _bytes_feature(*["abc"] * i)
                })).SerializeToString() for i in range(10)
    ])

    features = {
        "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
        "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
        "sparse_int": parsing_ops.VarLenFeature(dtypes.int64),
        "sparse_str": parsing_ops.VarLenFeature(dtypes.string),
    }

    def loop_fn(i):
      example_proto = array_ops.gather(examples, i)
      f = parsing_ops.parse_single_example(example_proto, features)
      return f

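    # Vectorizing parse_single_example over the examples should match a single
    # batched parse_example call.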
    pfor = pfor_control_flow_ops.pfor(loop_fn, iters=10)
    manual = parsing_ops.parse_example(examples, features)
    self.run_and_assert_equal(pfor, manual)


class PartitionedCallTest(PForTestCase):

  def test_simple(self):

    @def_function.function
    def f(x):
      return math_ops.square(x) + 1

    z = random_ops.random_uniform([4])

    def loop_fn(i):
      return f(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_nested_calls(self):

    @def_function.function
    def inner(x):
      return math_ops.square(x)

    @def_function.function
    def outer(y):
      return math_ops.reduce_sum(inner(y)) + 2

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      return outer(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_nested_definition(self):

    @def_function.function
    def outer(y):

      @def_function.function
      def inner(x):
        return math_ops.square(x) + 1

      return math_ops.reduce_sum(inner(y)) + 2

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      return outer(array_ops.gather(z, i))

    self._test_loop_fn(loop_fn, 4)

  def test_gradients(self):

    @def_function.function
    def f(x):
      return math_ops.square(x) + 1

    z = random_ops.random_uniform([4, 2])

    def loop_fn(i):
      z_i = array_ops.gather(z, i)
      with backprop.GradientTape() as g:
        g.watch(z_i)
        out = f(z_i)
      return out, g.gradient(out, z_i)

    self._test_loop_fn(loop_fn, 4)

  def test_stateful_with_gradients(self):

    z = random_ops.random_uniform([4, 2])
    v = variables.Variable(z[0])

    @def_function.function
    def f(x):
      return math_ops.square(x) + v + 1

    def loop_fn(i):
      z_i = array_ops.gather(z, i)
      with backprop.GradientTape() as g:
        g.watch(z_i)
        out = f(z_i)
      return out, g.gradient(out, z_i)

    self._test_loop_fn(loop_fn, 4)


class SpectralTest(PForTestCase, parameterized.TestCase):

  @parameterized.parameters(
      (fft_ops.fft,),
      (fft_ops.fft2d,),
      (fft_ops.fft3d,),
      (fft_ops.ifft,),
      (fft_ops.ifft2d,),
      (fft_ops.ifft3d,),
  )
  def test_fft(self, op_func):
    shape = [2, 3, 4, 3, 4]
    x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return op_func(x_i)

    self._test_loop_fn(loop_fn, 2)

  @parameterized.parameters(
      (fft_ops.rfft,),
      (fft_ops.rfft2d,),
      (fft_ops.rfft3d,),
  )
  @test.disable_with_predicate(
      pred=test.is_built_with_rocm,
      skip_message="Disable subtest on ROCm due to rocfft issues")
  def test_rfft(self, op_func):
    for dtype in (dtypes.float32, dtypes.float64):
      x = random_ops.random_uniform([2, 3, 4, 3, 4], dtype=dtype)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        return op_func(x_i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 2)

  @parameterized.parameters(
      (fft_ops.irfft,),
      (fft_ops.irfft2d,),
      (fft_ops.irfft3d,),
  )
  @test.disable_with_predicate(
      pred=test.is_built_with_rocm,
      skip_message="Disable subtest on ROCm due to rocfft issues")
  def test_irfft(self, op_func):
    if config.list_physical_devices("GPU"):
      # TODO(b/149957923): The test is flaky
      self.skipTest("b/149957923: irfft vectorization flaky")
    for dtype in (dtypes.complex64, dtypes.complex128):
      shape = [2, 3, 4, 3, 4]
      x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
      x = math_ops.cast(x, dtype=dtype)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x_i = array_ops.gather(x, i)
        return op_func(x_i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 2)


class VariableTest(PForTestCase):

  def test_create_variable_once(self):
    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
    a_var = []

    def f(z):
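      # Create the variable only on the first trace; later calls reuse the one
      # cached in the enclosing list.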
      if not a_var:
        a_var.append(variables.Variable(lambda: y, name="a"))
      return math_ops.matmul(z, a_var[0] / 16)

    pfor_control_flow_ops.vectorized_map(f, x)

  @test_util.run_v2_only
  def test_create_variable_repeated(self):
    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)

    def f(z):
      a_var = variables.Variable(lambda: y, name="a") / 4
      return math_ops.matmul(z, a_var / 16)

    # Note that this error is only raised under v2 behavior.
    with self.assertRaisesRegex(
        ValueError,
        "tf.function-decorated function tried to create variables on non-first"
    ):
      pfor_control_flow_ops.vectorized_map(f, x)

  @test_util.run_all_in_graph_and_eager_modes
  def test_variable_shape(self):
    v = resource_variable_ops.ResourceVariable([1, 2])

    def loop_fn(_):
      return resource_variable_ops.variable_shape(v.handle)

    self._test_loop_fn(loop_fn, 2)


if __name__ == "__main__":
  test.main()