# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pfor and for_loop."""
# pylint: disable=g-direct-tensorflow-import

import functools
import sys
import time

from absl.testing import parameterized
import numpy as np

from google.protobuf import text_format

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest

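# A minimal orientation sketch for the two entry points exercised by these
# tests. The calls below are illustrative only (not additional test cases):
# `pfor` runs `loop_fn` once per index in [0, iters) and stacks the
# per-iteration outputs, while `vectorized_map` maps a function over the
# leading dimension of its inputs.
#
#   squares = pfor_control_flow_ops.pfor(
#       lambda i: math_ops.square(math_ops.cast(i, dtypes.float32)), 4)
#   doubled = pfor_control_flow_ops.vectorized_map(
#       lambda x: 2 * x, math_ops.range(4))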

@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class PForTest(PForTestCase):

  def test_op_conversion_fallback_to_while_loop(self):
    # Note that we used top_k op for this test. If a converter gets defined for
    # it, we will need to find another op for which a converter has not been
    # defined.
    x = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return nn.top_k(x_i)

    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
    self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)
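    # For reference, a rough sketch of the equivalent direct call through the
    # public API (assuming top_k still has no registered converter):
    #   pfor_control_flow_ops.pfor(loop_fn, 3, fallback_to_while_loop=True)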

  def test_parallel_iterations(self):
    for parallel_iterations in [2, 3, 8, 10]:
      x = random_ops.random_uniform([8, 3])

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        return array_ops.gather(x, i)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 8, parallel_iterations=parallel_iterations)
      self._test_loop_fn(
          loop_fn,
          4 * constant_op.constant(2),
          parallel_iterations=parallel_iterations)

  def test_parallel_iterations_preserves_static_shape(self):
    for parallel_iterations in [2, 3, 8, 10]:
      x = pfor_control_flow_ops.pfor(
          lambda _: random_ops.random_uniform([2, 3]),
          8,
          parallel_iterations=parallel_iterations)
      self.assertAllEqual(x.shape, [8, 2, 3])

  def test_parallel_iterations_zero(self):
    with self.assertRaisesRegex(ValueError, "positive integer"):
      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=0)
    with self.assertRaisesRegex(TypeError, "positive integer"):
      pfor_control_flow_ops.for_loop(
          lambda i: 1, dtypes.int32, 8, parallel_iterations=0)

  def test_parallel_iterations_one(self):
    with self.assertRaisesRegex(ValueError, "Use `for_loop` instead"):
      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=1)

  def test_vectorized_map(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    result = pfor_control_flow_ops.vectorized_map(compute,
                                                  array_ops.ones((10, 5, 3)))
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_vectorized_map_with_dynamic_shape(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    x = array_ops.placeholder_with_default(
        array_ops.ones((10, 5, 3)), shape=None)
    result = pfor_control_flow_ops.vectorized_map(compute, x)
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_where_shape(self):
    @def_function.function
    def f():
      a = constant_op.constant([[1.], [1.]])
      b = constant_op.constant([1.])
      result = pfor_control_flow_ops.vectorized_map(
          lambda x: array_ops.where(x > 0, x, b), a)
      return result.shape

    self.assertAllEqual([2, 1], f())

  def test_vectorized_map_broadcasts_unit_dimensions(self):
    convert_with_static_shape = ops.convert_to_tensor
    convert_with_dynamic_shape = (
        lambda x: array_ops.placeholder_with_default(x, shape=None))

    for convert in (convert_with_static_shape, convert_with_dynamic_shape):
      a = convert([3.1])
      b = convert([-2., 6., 9.])

      # One elem with leading unit dimension.
      a_plus_1 = pfor_control_flow_ops.vectorized_map(lambda a: a + 1, a)
      self.assertAllEqual(*self.evaluate((a_plus_1, a + 1)))

      # Two elems, both with leading unit dimension.
      a_plus_a = pfor_control_flow_ops.vectorized_map(sum, (a, a))
      self.assertAllEqual(*self.evaluate((a_plus_a, a + a)))

      # Elem w/ unit dimension broadcast against elem with batch dim.
      a_plus_b = pfor_control_flow_ops.vectorized_map(sum, (a, b))
      self.assertAllEqual(*self.evaluate((a_plus_b, a + b)))

  def test_vectorized_map_example_1(self):

    def outer_product(a):
      return math_ops.tensordot(a, a, 0)

    batch_size = 100
    a = array_ops.ones((batch_size, 32, 32))
    c = pfor_control_flow_ops.vectorized_map(outer_product, a)
    self.assertAllEqual((batch_size, 32, 32, 32, 32), c.shape)

  def test_disable_tf_function(self):
    def_function.run_functions_eagerly(True)
    # vectorized_map should ignore disabling tf.functions
    self.assertTrue(def_function.functions_run_eagerly())
    self.assertAllEqual([0, 1, 4, 9],
                        pfor_control_flow_ops.vectorized_map(
                            lambda x: x * x, math_ops.range(4)))
    self.assertTrue(def_function.functions_run_eagerly())
    def_function.run_functions_eagerly(False)


@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(PForTestCase):

  def test_indexed_slices(self):

    def loop_fn(i):
      return indexed_slices.IndexedSlices(
          indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1])

    self._test_loop_fn(loop_fn, 2)

  def test_indexed_slices_components(self):

    def loop_fn(i):
      slices = indexed_slices.IndexedSlices(
          indices=i, values=array_ops.reshape(i, [1]), dense_shape=[3, 1])
      # Note that returning the components inside the slice avoids
      # densification, which may be more efficient.
      return slices.values, slices.indices

    self._test_loop_fn(loop_fn, 2)


@test_util.run_all_in_graph_and_eager_modes
class ReductionTest(PForTestCase):
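  # The `loop_fn`s below take a second `pfor_config` argument, which pfor uses
  # to expose cross-iteration reductions (`reduce`, `reduce_concat`,
  # `reduce_mean`, `reduce_sum`) over the values produced by all iterations.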

  def test_reduce(self):

    def reduce_fn(p, q):
      return math_ops.reduce_mean(p + q, axis=0)

    x = random_ops.random_uniform([4, 3, 2])
    y = random_ops.random_uniform([4, 3, 2])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      y_i = array_ops.gather(y, i)
      reduced = pfor_config.reduce(reduce_fn, x_i, y_i)
      return reduced + x_i

    output = pfor_control_flow_ops.pfor(loop_fn, 4)
    ans = reduce_fn(x, y) + x
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_concat(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      vectorized_value = pfor_config.reduce_concat(x_i)
      mean_value = math_ops.reduce_mean(vectorized_value, axis=0)
      return x_i - mean_value

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_mean(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_mean(x_i)

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_sum(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_sum(x_i)

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_sum(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_class(self):
    x = random_ops.random_uniform([8, 3])

    class LoopFn:

      def __init__(self):
        pass

      def __call__(self, i, pfor_config):
        x_i = array_ops.gather(x, i)
        return x_i - pfor_config.reduce_mean(x_i)

    output = pfor_control_flow_ops.pfor(LoopFn(), 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_functools_partial(self):
    x = random_ops.random_uniform([8, 3])

    def fn(i, pfor_config, dummy=None):
      del dummy
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_mean(x_i)

    loop_fn = functools.partial(fn, dummy=1)
    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_parallel_iterations(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return pfor_config.reduce_sum(x_i)

    with self.assertRaisesRegex(ValueError,
                                "`parallel_iterations` currently unsupported"):
      pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2)

  def test_var_loop_len(self):
    if context.executing_eagerly():
      self.skipTest("Variable length not possible under eager execution.")

    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      return pfor_config.reduce_sum(array_ops.gather(x, i))

    num_iters = array_ops.placeholder(dtypes.int32)
    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
    with self.cached_session() as sess:
      sess.run(pfor, feed_dict={num_iters: 8})

@test_util.run_all_in_graph_and_eager_modes
class BitwiseTest(PForTestCase):

  def test_unary_cwise(self):
    for op in [bitwise_ops.invert]:
      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x1 = array_ops.gather(x, i)
        return op(x1)

      # pylint: enable=cell-var-from-loop

      self._test_loop_fn(loop_fn, 3)

  def test_binary_cwise(self):
    binary_ops = [
        bitwise_ops.bitwise_and,
        bitwise_ops.bitwise_or,
        bitwise_ops.bitwise_xor,
        bitwise_ops.left_shift,
        bitwise_ops.right_shift,
    ]
    for op in binary_ops:
      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
      y = random_ops.random_uniform([3, 5], maxval=10, dtype=dtypes.int32)

      output_dtypes = []

      # pylint: disable=cell-var-from-loop
      def loop_fn(i):
        x1 = array_ops.gather(x, i)
        y1 = array_ops.gather(y, i)
        outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
        del output_dtypes[:]
        output_dtypes.extend(t.dtype for t in outputs)
        return outputs

      # pylint: enable=cell-var-from-loop
      self._test_loop_fn(loop_fn, 3)


@test_util.run_all_in_graph_and_eager_modes
class ImageTest(PForTestCase):

  def test_adjust_contrast(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_contrast(image, 2.0)

    self._test_loop_fn(loop_fn, 3)

  def test_adjust_hue(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_hue(image, .25)

    self._test_loop_fn(loop_fn, 3)

  def test_adjust_saturation(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_saturation(image, 0.1)

    self._test_loop_fn(loop_fn, 3)

@test_util.run_all_in_graph_and_eager_modes
@test_util.run_all_without_tensor_float_32("Uses matmul")
class NNTest(PForTestCase):

  def test_conv2d(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    filt = random_ops.random_uniform([3, 3, 3, 7])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return nn.conv2d(
          x1, filt, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_conv2d_backprop_input(self):
    x_shape = [2, 12, 12, 3]
    filt = random_ops.random_uniform([3, 3, 3, 7])
    grad = random_ops.random_uniform([3, 2, 5, 5, 7])

    def loop_fn(i):
      grad1 = array_ops.gather(grad, i)
      return nn.conv2d_backprop_input(
          x_shape,
          filt,
          grad1,
          strides=[1, 2, 2, 1],
          padding="VALID",
          data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_conv2d_backprop_filter(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    x_0 = array_ops.gather(x, 0)
    filter_sizes = [3, 3, 3, 7]
    grad = random_ops.random_uniform([3, 2, 5, 5, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      grad_i = array_ops.gather(grad, i)
      return [
          nn.conv2d_backprop_filter(
              inp,
              filter_sizes,
              grad_i,
              strides=[1, 2, 2, 1],
              padding="VALID",
              data_format="NHWC") for inp in [x_i, x_0]
      ]

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native(
          x1, filt1, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_input(self):
    x_shape = [2, 12, 12, 3]
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])
    grad = random_ops.random_uniform([3, 2, 5, 5, 6])

    def loop_fn(i):
      grad1 = array_ops.gather(grad, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native_backprop_input(
          x_shape,
          filt1,
          grad1,
          strides=[1, 2, 2, 1],
          padding="VALID",
          data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_filter(self):
    x = random_ops.random_uniform([3, 2, 12, 12, 3])
    filter_sizes = [3, 3, 3, 2]
    grad = random_ops.random_uniform([3, 2, 5, 5, 6])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      grad_i = array_ops.gather(grad, i)
      return nn.depthwise_conv2d_native_backprop_filter(
          x_i,
          filter_sizes,
          grad_i,
          strides=[1, 2, 2, 1],
          padding="VALID",
          data_format="NHWC")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_nchw(self):
    if not test_util.is_gpu_available():
      self.skipTest("NCHW only works on GPU")
    x = random_ops.random_uniform([3, 2, 3, 12, 12])
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native(
          x1, filt1, strides=[1, 1, 2, 2], padding="VALID", data_format="NCHW")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_input_nchw(self):
    if not test_util.is_gpu_available():
      self.skipTest("NCHW only works on GPU")
    x_shape = [2, 3, 12, 12]
    filt = random_ops.random_uniform([3, 3, 3, 3, 2])
    grad = random_ops.random_uniform([3, 2, 6, 5, 5])

    def loop_fn(i):
      grad1 = array_ops.gather(grad, i)
      filt1 = array_ops.gather(filt, i)
      return nn.depthwise_conv2d_native_backprop_input(
          x_shape,
          filt1,
          grad1,
          strides=[1, 1, 2, 2],
          padding="VALID",
          data_format="NCHW")

    self._test_loop_fn(loop_fn, 3)

  def test_depthwise_conv2d_native_backprop_filter_nchw(self):
    if not test_util.is_gpu_available():
      self.skipTest("NCHW only works on GPU")
    x = random_ops.random_uniform([3, 2, 3, 12, 12])
    filter_sizes = [3, 3, 3, 2]
    grad = random_ops.random_uniform([3, 2, 6, 5, 5])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      grad_i = array_ops.gather(grad, i)
      return nn.depthwise_conv2d_native_backprop_filter(
          x_i,
          filter_sizes,
          grad_i,
          strides=[1, 1, 2, 2],
          padding="VALID",
          data_format="NCHW")

    self._test_loop_fn(loop_fn, 3)

  def test_roll(self):
    x = random_ops.random_uniform([3, 6, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return manip_ops.roll(x_i, 3, axis=1)

    self._test_loop_fn(loop_fn, 3)

  def test_ensure_shape(self):
    x = random_ops.random_uniform([3, 6, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return array_ops.ensure_shape(x_i, [6, 7])

    self._test_loop_fn(loop_fn, 3)

  def test_loop_variant_roll_shift(self):
    x = random_ops.random_uniform([3, 5, 6, 7])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return manip_ops.roll(x_i, [i - 2, -1, i], axis=[1, 2, 2])

    self._test_loop_fn(loop_fn, 3)

  def test_loop_variant_roll_scalar_shift(self):
    x = random_ops.random_uniform([5, 5, 6])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      return manip_ops.roll(x_i, i, axis=0)

    self._test_loop_fn(loop_fn, 5)

  def test_avg_pool(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.avg_pool(
            x1,
            ksize,
            strides=[1, 2, 2, 1],
            padding="VALID",
            data_format="NHWC")
        loss = nn.l2_loss(output)
      return output, g.gradient(loss, x1)

    self._test_loop_fn(loop_fn, 3)

  def test_avg_pool3d(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([5, 3, 7, 6, 6, 5])
      g.watch(x)
      ksize = [1, 2, 2, 2, 1]
      strides = [1, 2, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.avg_pool3d(
            x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
        loss = nn.l2_loss(output)
      return output, g.gradient(loss, x1)

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]
      strides = [1, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.max_pool(
            x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
        loss = nn.l2_loss(output)
        ones = array_ops.ones_like(output)
        g.watch(ones)
        grad = g.gradient(loss, x1, output_gradients=ones)
      grad_grad = g.gradient(grad, ones)
      return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool_v2(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]
      strides = [1, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = gen_nn_ops.max_pool_v2(
            x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
        loss = nn.l2_loss(output)
        ones = array_ops.ones_like(output)
        g.watch(ones)
        grad = g.gradient(loss, x1, output_gradients=ones)
      grad_grad = g.gradient(grad, ones)
      return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool3d(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 1, 3, 3, 1]
      strides = [1, 1, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.max_pool3d(
            x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
        loss = nn.l2_loss(output)
        ones = array_ops.ones_like(output)
        g.watch(ones)
        grad = g.gradient(loss, x1, output_gradients=ones)
      grad_grad = g.gradient(grad, ones)
      return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_fused_batch_norm(self):
    data_formats = ["NHWC"]
    if test.is_gpu_available():
      data_formats.append("NCHW")
    for is_training in (True, False):
      for data_format in data_formats:
        with backprop.GradientTape(persistent=True) as g:
          if data_format == "NCHW":
            x = random_ops.random_uniform([3, 1, 2, 5, 5])
          else:
            x = random_ops.random_uniform([3, 1, 5, 5, 2])
          g.watch(x)
          scale = random_ops.random_uniform([2])
          g.watch(scale)
          offset = random_ops.random_uniform([2])
          g.watch(offset)
          mean = None if is_training else random_ops.random_uniform([2])
          variance = None if is_training else random_ops.random_uniform([2])

        # pylint: disable=cell-var-from-loop
        def loop_fn(i):
          with g:
            x1 = array_ops.gather(x, i)
            outputs = nn.fused_batch_norm(
                x1,
                scale,
                offset,
                mean=mean,
                variance=variance,
                epsilon=0.01,
                data_format=data_format,
                is_training=is_training)
            outputs = list(outputs)
            # We only test the first value of outputs when is_training is
            # False. It looks like CPU and GPU have different outputs for
            # batch_mean and batch_variance for this case.
            if not is_training:
              outputs[1] = constant_op.constant(0.)
              outputs[2] = constant_op.constant(0.)
            loss = nn.l2_loss(outputs[0])
          if is_training:
            gradients = g.gradient(loss, [x1, scale, offset])
          else:
            gradients = [constant_op.constant(0.)] * 3
          return outputs + gradients

        # pylint: enable=cell-var-from-loop

        self._test_loop_fn(loop_fn, 3)

  def test_log_softmax(self):
    logits = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return (nn.log_softmax(logits_i), nn.log_softmax(logits_i, axis=0),
              nn.log_softmax(logits_i, axis=-1))

    self._test_loop_fn(loop_fn, 3)

  def test_softmax(self):
    logits = random_ops.random_uniform([3, 2, 4])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return (nn.softmax(logits_i), nn.softmax(logits_i, axis=0),
              nn.softmax(logits_i, axis=-1))

    self._test_loop_fn(loop_fn, 3)

  def test_softmax_cross_entropy_with_logits(self):
    with backprop.GradientTape(persistent=True) as g:
      logits = random_ops.random_uniform([3, 2, 4])
      g.watch(logits)
      labels = random_ops.random_uniform([3, 2, 4])
      labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True)

    def loop_fn(i):
      with g:
        logits_i = array_ops.gather(logits, i)
        labels_i = array_ops.gather(labels, i)
        loss = nn.softmax_cross_entropy_with_logits(
            labels=labels_i, logits=logits_i)
        total_loss = math_ops.reduce_sum(loss)
      return loss, g.gradient(total_loss, logits_i)

    self._test_loop_fn(loop_fn, 3)

  def test_sparse_softmax_cross_entropy_with_logits(self):
    logits = random_ops.random_uniform([3, 2, 4])
    labels = random_ops.random_uniform(
        shape=[3, 2], maxval=4, dtype=dtypes.int32)

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      labels_i = array_ops.gather(labels, i)
      loss = nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels_i, logits=logits_i)
      return loss

    self._test_loop_fn(loop_fn, 3)

class RandomTest(PForTestCase):

  # The random values generated in the two implementations are not guaranteed to
  # match. So we only check the returned shapes.
  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    outputs = self._run_targets(targets1, targets2)
    n = len(outputs) // 2
    for i in range(n):
      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)

  def test_random_uniform(self):

    def loop_fn(_):
      return random_ops.random_uniform([3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_uniform_int(self):

    def loop_fn(_):
      return random_ops.random_uniform([3], maxval=1, dtype=dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_random_standard_normal(self):

    def loop_fn(_):
      return random_ops.random_normal([3])

    self._test_loop_fn(loop_fn, 5)

  def test_truncated_normal(self):

    def loop_fn(_):
      return random_ops.truncated_normal([3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_gamma_invariant_alpha(self):

    def loop_fn(_):
      return random_ops.random_gamma([3], alpha=[0.5])

    self._test_loop_fn(loop_fn, 5)

  def test_random_gamma_varying_alpha(self):
    alphas = math_ops.exp(random_ops.random_normal([5, 3, 2]))

    def loop_fn(i):
      alphas_i = array_ops.gather(alphas, i)
      # Test both scalar and non-scalar params and shapes.
      return (random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[]),
              random_ops.random_gamma(alpha=alphas_i, shape=[]),
              random_ops.random_gamma(alpha=alphas_i[0, 0], shape=[3]),
              random_ops.random_gamma(alpha=alphas_i, shape=[3]))

    self._test_loop_fn(loop_fn, 5)

  def test_random_poisson_v2_invariant_rate(self):

    def loop_fn(_):
      return random_ops.random_poisson(lam=[1.3], shape=[3])

    self._test_loop_fn(loop_fn, 5)

  def test_random_poisson_v2_varying_rate(self):
    rates = math_ops.exp(random_ops.random_normal([5, 3, 2]))

    def loop_fn(i):
      rates_i = array_ops.gather(rates, i)
      # Test both scalar and non-scalar params and shapes.
      return (random_ops.random_poisson(lam=rates_i[0, 0], shape=[]),
              random_ops.random_poisson(lam=rates_i, shape=[]),
              random_ops.random_poisson(lam=rates_i[0, 0], shape=[3]),
              random_ops.random_poisson(lam=rates_i, shape=[3]))

    self._test_loop_fn(loop_fn, 5)

  def test_random_multinomial_invariant_logits(self):

    def loop_fn(_):
      return random_ops.categorical(logits=[[1., -1.]], num_samples=3)

    self._test_loop_fn(loop_fn, 5)

  def test_random_multinomial_varying_logits(self):
    logits = random_ops.random_normal([5, 3, 2])

    def loop_fn(i):
      logits_i = array_ops.gather(logits, i)
      return random_ops.categorical(logits_i, num_samples=3)

    self._test_loop_fn(loop_fn, 5)


class StatelessRandomTest(PForTestCase):

  # This test currently only tests that the vectorized and non-vectorized
  # outputs have same shapes. This is needed since under XLA compilation,
  # stateless random numbers can generate different random numbers.
  # TODO(agarwal): switch to checking for actual values matching once
  # b/149402339 is resolved.
  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    outputs = self._run_targets(targets1, targets2)
    n = len(outputs) // 2
    for i in range(n):
      self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)

  # TODO(agarwal): add tests for other random functions
  def test_multinomial(self):
    seeds = [[1, 2], [3, 4]]
    logits = random_ops.random_uniform([2, 3, 4])

    def loop_fn(i):
      logits_0 = array_ops.gather(logits, 0)
      logits_i = array_ops.gather(logits, i)
      seeds_0 = array_ops.gather(seeds, 0)
      seeds_i = array_ops.gather(seeds, i)
      return (stateless_random_ops.stateless_categorical(
          logits=logits_i, num_samples=3, seed=seeds_i),
              stateless_random_ops.stateless_categorical(
                  logits=logits_i, num_samples=3, seed=seeds_0),
              stateless_random_ops.stateless_categorical(
                  logits=logits_0, num_samples=3, seed=seeds_i),
              stateless_random_ops.stateless_categorical(
                  logits=logits_0, num_samples=3, seed=seeds_0))

    self._test_loop_fn(loop_fn, 2)


class LoggingTest(PForTestCase):

  def test_print(self):
    x = random_ops.random_uniform([3, 5])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return logging_ops.Print(
          x1, [x1, "x1", array_ops.shape(x1)], summarize=10)

    self._test_loop_fn(loop_fn, 3)

  def test_print_v2(self):
    x = constant_op.constant([1, 2, 3])

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      with ops.control_dependencies([
          logging_ops.print_v2(
              x1, "x1", array_ops.shape(x1), summarize=10)]):
        return array_ops.identity(x1)

    self._test_loop_fn(loop_fn, 3)

    with self.captureWritesToStream(sys.stderr) as printed:
      self.evaluate(pfor_control_flow_ops.pfor(loop_fn, 3))
    self.assertIn("[1 2 3] x1 []", printed.contents())

  def test_assert(self):

    def loop_fn(i):
      return control_flow_ops.Assert(i < 10, [i, [10], [i + 1]])

    # TODO(agarwal): make this work with for_loop.
    with session.Session() as sess:
      sess.run(pfor_control_flow_ops.pfor(loop_fn, 3))
      sess.run(pfor_control_flow_ops.pfor(
          lambda i, pfor_config: loop_fn(i), 3))


class TensorArrayTest(PForTestCase):
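  # setUp below disables control flow v2 so these TensorArray tests run under
  # control flow v1; tearDown restores the previous setting.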

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.disable_control_flow_v2()
    super(TensorArrayTest, self).setUp()

  def tearDown(self):
    if self._enabled:
      control_flow_v2_toggles.enable_control_flow_v2()
    super(TensorArrayTest, self).tearDown()

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_read(self):

    ta = tensor_array_ops.TensorArray(
        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)

    def loop_fn(i):
      return ta.read(i), ta.read(0)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_gather(self):

    ta = tensor_array_ops.TensorArray(
        dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)

    def loop_fn(i):
      return ta.gather([i]), ta.gather([0, 1])

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_outside_and_write_and_scatter(self):

    t = tensor_array_ops.TensorArray(dtypes.int32, 10, clear_after_read=False)
    handle = t.handle

    def loop_fn(i):
      ta = t.write(i + 2, 2 * i).write(i, 5)
      ta = ta.scatter([4 + i], [4]).scatter([6 + i, 8 + i], [6 + i, 8 + i])
      return ta.flow

    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
    out1 = tensor_array_ops.TensorArray(
        dtypes.int32, handle=handle, flow=t1[-1]).stack()
    output1 = self._run_targets(out1)

    t2 = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, iters=2)
    out2 = tensor_array_ops.TensorArray(
        dtypes.int32, handle=handle, flow=t2[-1]).stack()
    output2 = self._run_targets(out2)
    self.assertAllClose(output2, output1)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_write(self):

    def loop_fn(i):
      # TODO(agarwal): switching the order of writes to ta1 does not work.
      ta1 = tensor_array_ops.TensorArray(dtypes.int32, 2).write(0,
                                                                i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(dtypes.int32, 1).write(0, 1)
      return ta1.stack(), ta2.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_scatter(self):

    def loop_fn(i):
      # TODO(agarwal): switching the order of scatter to ta1 does not work.
      ta1 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).scatter([0],
                                                    [[i, 2]]).scatter([1],
                                                                      [[1, 2]])
      ta2 = tensor_array_ops.TensorArray(dtypes.int32,
                                         2).scatter([0], [3]).scatter([1], [4])
      return ta1.stack(), ta2.stack()

    self._test_loop_fn(loop_fn, 3)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_read(self):

    def loop_fn(i):
      ta1 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
      # TODO(agarwal): ta1.read(i) currently is not supported.
      return ta1.read(0), ta2.read(0), ta2.read(i)

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_create_inside_and_gather(self):

    def loop_fn(i):
      ta1 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
      ta2 = tensor_array_ops.TensorArray(
          dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
      # TODO(agarwal): ta1.read(i) currently is not supported.
      return ta1.gather([0, 1]), ta2.gather([0, 1]), ta2.gather([i])

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_grad(self):
    x = random_ops.random_uniform([3, 2])
    ta = tensor_array_ops.TensorArray(
        dtypes.float32, 3, clear_after_read=False).unstack(x)
    y = math_ops.square(ta.stack())

    def loop_fn(i):
      y_i = array_ops.gather(y, i)
      grad = gradient_ops.gradients(y_i, x)[0]
      return array_ops.gather(grad, i)

    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=3)
    # y = x * x. Hence dy/dx = 2 * x.
    actual_grad = 2.0 * x
    with session.Session() as sess:
      actual_grad, computed_grad = sess.run([t1, actual_grad])
      self.assertAllClose(actual_grad, computed_grad)

@test_util.run_all_in_graph_and_eager_modes
class TensorListTest(PForTestCase):

  def test_create_outside_and_write(self):
    handle1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    handle2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)

    def loop_fn(i):
      h1 = list_ops.tensor_list_set_item(handle1, 0, i)
      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
      h2 = list_ops.tensor_list_set_item(handle2, 0, 1)
      return (list_ops.tensor_list_stack(h1, dtypes.int32),
              list_ops.tensor_list_stack(h2, dtypes.int32))

    self._test_loop_fn(loop_fn, 3)

  def _make_graph_def(self, text):
    ret = graph_pb2.GraphDef()
    text_format.Parse(text, ret)
    return ret

  def test_no_fallback_with_internal_stacking(self):

    # Create an op (really a function) that pfor definitely does not have a
    # converter for. Assumes pfor does not start looking up function definitions
    # for op-type-is-function-name calls.
    @def_function.function
    def opaque_list_fetch(x):
      array_ops.identity(x)
      return list_ops.tensor_list_get_item(x, 0, dtypes.int32)

    external_handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    opaque_list_fetch_concrete = opaque_list_fetch.get_concrete_function(
        external_handle)
    opaque_list_fetch_name = opaque_list_fetch_concrete.name

    def loop_fn(i):
      h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h1 = list_ops.tensor_list_set_item(h1, 0, i)
      opaque_list_fetch_concrete.add_to_graph()
      graph_def = self._make_graph_def("""
         node { name: 'x' op: 'Placeholder'
                attr { key: 'dtype' value { type: DT_FLOAT } }}
         node { name: 'fn' op: '""" + opaque_list_fetch_name.decode()
                                       + """' input: 'x:0' }""")
      return importer.import_graph_def(
          graph_def,
          input_map={"x:0": h1},
          return_elements=["fn"],
          name="import")[0].outputs[0]

    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
    with self.assertRaisesRegex(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)

  def test_create_inside_and_write(self):

    def loop_fn(i):
      h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h1 = list_ops.tensor_list_set_item(h1, 0, i)
      h1 = list_ops.tensor_list_set_item(h1, 1, 1)
      h2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      h2 = list_ops.tensor_list_set_item(h2, 0, 1)
      return (list_ops.tensor_list_stack(h1, dtypes.int32),
              list_ops.tensor_list_stack(h2, dtypes.int32))

    self._test_loop_fn(loop_fn, 3)

  def test_create_outside_and_read(self):
    handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    handle = list_ops.tensor_list_set_item(handle, 0, 0)
    handle = list_ops.tensor_list_set_item(handle, 1, 1)

    def loop_fn(i):
      return (list_ops.tensor_list_get_item(handle, i, dtypes.int32),
              list_ops.tensor_list_get_item(handle, 0, dtypes.int32),
              list_ops.tensor_list_length(handle),
              list_ops.tensor_list_element_shape(handle, dtypes.int32),
              list_ops.tensor_list_element_shape(handle, dtypes.int64))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_read(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      handle = list_ops.tensor_list_set_item(handle, 0, i)
      handle = list_ops.tensor_list_set_item(handle, 1, 1)
      return (list_ops.tensor_list_get_item(handle, 0, dtypes.int32),
              list_ops.tensor_list_get_item(handle, i, dtypes.int32),
              list_ops.tensor_list_length(handle),
              list_ops.tensor_list_element_shape(handle, dtypes.int32),
              list_ops.tensor_list_element_shape(handle, dtypes.int64))

    self._test_loop_fn(loop_fn, 2)

  def test_create_outside_and_push_back(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_push_back(h, [i, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_create_inside_and_push_back(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [i, 2])
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_no_shape(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      handle = list_ops.tensor_list_push_back(handle, [i, 2])
      handle, tensor = list_ops.tensor_list_pop_back(handle, dtypes.int32)
      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_no_shape_capture(self):
    h = list_ops.tensor_list_reserve([2], 1, dtypes.int32)
    h = list_ops.tensor_list_push_back(h, [1, 2])

    def loop_fn(i):
      handle, tensor = list_ops.tensor_list_pop_back(h, dtypes.int32)
      handle = list_ops.tensor_list_push_back(handle, [1, i])
      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_pop_back_with_shape(self):

    @def_function.function
    def loop_fn(i):
      with backprop.GradientTape() as tape:
        handle = list_ops.tensor_list_reserve(None, 1, dtypes.float32)
        x = math_ops.cast(i, dtypes.float32)[None]
        tape.watch(x)
        handle = list_ops.tensor_list_push_back(handle, x)
        stacked = list_ops.tensor_list_stack(handle, dtypes.float32)
      list_grad = tape.gradient(stacked, x, x)
      self.assertEqual("TensorListPopBack", list_grad.op.type)
      return list_grad, stacked, list_grad.op.inputs[1]

    self._test_loop_fn(loop_fn, 3)

  def test_create_outside_and_scatter(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_create_inside_and_scatter(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 3)

  def test_loop_variant_scatter_indices(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 10, dtypes.int32)
      handle = list_ops.tensor_list_scatter(
          [[1, i], [i + 1, 2]],
          [i, i + 5], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_duplicate_indices(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 10, dtypes.int32)
      handle = list_ops.tensor_list_scatter(
          [[1, i], [1, i + 1], [i + 2, 3]],
          [i, i, i + 2], input_handle=handle)
      return list_ops.tensor_list_stack(handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_create_outside_and_gather(self):
    handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
    handle = list_ops.tensor_list_scatter([[2, 3]], [0], input_handle=handle)
    handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)

    def loop_fn(i):
      return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32),
              list_ops.tensor_list_gather(handle, [i], dtypes.int32))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_gather(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32),
              list_ops.tensor_list_gather(handle, [i], dtypes.int32))

    self._test_loop_fn(loop_fn, 2)

  def test_create_inside_and_concat(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return gen_list_ops.tensor_list_concat_v2(
          handle,
          element_dtype=dtypes.int32,
          element_shape=[2],
          leading_dims=[])

    output = pfor_control_flow_ops.pfor(loop_fn, 2)
    self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
    self.assertAllClose([[2, 2], [2, 2]], output[1])

  def test_create_outside_and_concat(self):
    h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

    def loop_fn(i):
      handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
      handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
      return gen_list_ops.tensor_list_concat_v2(
          handle,
          element_dtype=dtypes.int32,
          element_shape=[2],
          leading_dims=[])

    output = pfor_control_flow_ops.pfor(loop_fn, 2)
    self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
    self.assertAllClose([[2, 2], [2, 2]], output[1])

  def test_tensor_list_from_tensor(self):
    t = random_ops.random_uniform([2, 3, 4])

    def loop_fn(i):
      handle = list_ops.tensor_list_from_tensor(array_ops.gather(t, i), [4])
      return list_ops.tensor_list_stack(handle, t.dtype)

    self._test_loop_fn(loop_fn, 2)

  @test_util.enable_control_flow_v2
  def test_tensor_list_reserve_while_loop(self):
    # Here a loop invariant TensorList is captured by a while_loop, which then
    # performs loop dependent operations on it, resulting in a loop variant
    # output. This forces stacking of the variant handle captured by the
    # while_loop.
    # We handle this particular case by forcing vectorization of
    # TensorListReserve operation.

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < 2, lambda j, h:
          (j + 1, list_ops.tensor_list_set_item(h, j, i)), (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 2)

  @test_util.enable_control_flow_v2
  def test_tensor_list_while_loop_stacked_cond_stacked_list(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, i], [])
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  @test_util.enable_control_flow_v2
  def test_tensor_list_while_loop_stacked_cond_stacked_list_unknown_shape(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_reserve(None, 5, dtypes.int32)
      _, handle = control_flow_ops.while_loop(
          lambda j, _: j < 5,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, 0)),
          (0, handle))
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  @test_util.enable_control_flow_v2
  def test_tensor_list_while_loop_stacked_cond_unstacked_list(self):

    def loop_fn(i):
      handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, 24], [])
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    self._test_loop_fn(loop_fn, 5)

  def test_tensor_list_addn_already_stacked(self):

    def loop_fn(i):
      l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l1 = list_ops.tensor_list_set_item(l1, 0, i)
      l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l2 = list_ops.tensor_list_set_item(l2, 1, i)
      return list_ops.tensor_list_stack(math_ops.add_n([l1, l2]), dtypes.int32)

    self._test_loop_fn(loop_fn, 2)

  def test_tensor_list_addn_stacking_required(self):
    l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    l1 = list_ops.tensor_list_set_item(l1, 1, 1)

    def loop_fn(i):
      l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
      l2 = list_ops.tensor_list_set_item(l2, 1, i)
      return list_ops.tensor_list_stack(
          math_ops.add_n([l1, l2]), dtypes.int32)

    self._test_loop_fn(loop_fn, 2)

@test_util.run_all_in_graph_and_eager_modes
class TensorTest(PForTestCase):

  def test_loop_variant_scatter_update_no_shape(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32),
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32),
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
    ])
    def shapeless_func(tensor, indices, updates):
      return array_ops.tensor_scatter_nd_update(tensor, indices, updates)

    def loop_fn(i):
      tensor = [0, 0, 0, 0, 0, 0, 0, 0]
      indices = [[i], [i + 1], [i + 3], [i + 2]]
      updates = [i, i - 10, i + 11, 12]
      return shapeless_func(tensor, indices, updates)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_update_singles(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      tensor = [0, 0, 0, 0, 0, 0, 0, 0]
      indices = [[i], [i+1], [i+3], [i+2]]
      updates = [i, i-10, i+11, 12]
      return array_ops.tensor_scatter_nd_update(tensor, indices, updates)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_update_slices(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      tensor = array_ops.zeros([10, 3], dtype=dtypes.int32)
      indices = [[i+2], [4]]
      updates = [[1, i*2, 3], [i+4, i-5, 6]]
      return array_ops.tensor_scatter_nd_update(tensor, indices, updates)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_update_multi_dim_index(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      tensor = array_ops.zeros([10, 3], dtype=dtypes.int32)
      indices = [[i+2, 1], [4, 2]]
      updates = [i, 5]
      return array_ops.tensor_scatter_nd_update(tensor, indices, updates)

    self._test_loop_fn(loop_fn, 5)

  def test_loop_variant_scatter_update_folded_indices(self):
    if test_util.is_gpu_available():
      self.skipTest(
          "Flaky in some GPU configurations due to TensorScatterNdUpdate "
          "nondeterminism.")

    def loop_fn(i):
      tensor = array_ops.zeros([5, 5])
      indices = [
          [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]],
          [[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]],
      ]
      updates = [
          [1, i, 1, 1, 1],
          [1, 1, i+2, 1, i-5],
      ]
      return array_ops.tensor_scatter_nd_update(tensor, indices, updates)

    self._test_loop_fn(loop_fn, 5)


class OptionalTest(PForTestCase):

  def test_optional_from_value(self):

    def loop_fn(i):
      o = gen_dataset_ops.optional_from_value(
          [i, i + 1, constant_op.constant(3)])
      gen_dataset_ops.optional_none()
      return gen_dataset_ops.optional_get_value(
          o, [dtypes.int32, dtypes.int32, dtypes.int32],
          [[], [], []])

    self._test_loop_fn(loop_fn, 2)


class StackTest(PForTestCase):

  @test_util.run_v1_only("b/122612051")
  def test_stack_inside_loop_invariant(self):

    def loop_fn(_):
      s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
      op1 = data_flow_ops.stack_push_v2(s, 1)
      with ops.control_dependencies([op1]):
        op2 = data_flow_ops.stack_push_v2(s, 2)
      with ops.control_dependencies([op2]):
        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      with ops.control_dependencies([e2]):
        e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      return e1, e2

    self._test_loop_fn(loop_fn, 2)

  @test_util.run_v1_only("b/122612051")
  def test_stack_inside_push_loop_dependent(self):

    def loop_fn(i):
      s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
      op1 = data_flow_ops.stack_push_v2(s, i)
      with ops.control_dependencies([op1]):
        op2 = data_flow_ops.stack_push_v2(s, 2)
      with ops.control_dependencies([op2]):
        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      with ops.control_dependencies([e2]):
        e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
      return e1, e2

    self._test_loop_fn(loop_fn, 2)
1601
1602  @test_util.run_v1_only("b/122612051")
1603  def test_stack_outside_pop(self):
1604    s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
1605    op = data_flow_ops.stack_push_v2(s, 5)
1606    with ops.control_dependencies([op]):
1607      op = data_flow_ops.stack_push_v2(s, 6)
1608    with ops.control_dependencies([op]):
1609      op = data_flow_ops.stack_push_v2(s, 7)
1610
1611    def loop_fn(_):
1612      e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1613      with ops.control_dependencies([e1]):
1614        e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1615      return e1, e2
1616
1617    with ops.control_dependencies([op]):
1618      e1, e2 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
1619    with ops.control_dependencies([e1, e2]):
1620      e3 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
1621    v1, v2, v3 = self._run_targets([e1, e2, e3], run_init=False)
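    # The pushes happened outside the pfor loop, so both pfor iterations see
    # the same pops: 7 first, then 6, leaving 5 on the stack for e3.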
1622    self.assertAllEqual([7, 7], v1)
1623    self.assertAllEqual([6, 6], v2)
1624    self.assertAllEqual(5, v3)
1625
1626  @test_util.run_v1_only("b/122612051")
1627  def test_stack_outside_push(self):
1628    s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
1629
1630    def loop_fn(_):
1631      return data_flow_ops.stack_push_v2(s, 7)
1632
1633    with self.assertRaisesRegex(ValueError, "StackPushV2 not allowed.*"):
1634      pfor_control_flow_ops.pfor(loop_fn, iters=2)
1635
1636
1637# TODO(agarwal): test nested while_loops. This currently requires converting a
1638# tf.cond.
1639class WhileV1Test(PForTestCase):
1640
1641  def setUp(self):
1642    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
1643    control_flow_v2_toggles.disable_control_flow_v2()
1644    super(WhileV1Test, self).setUp()
1645
1646  def tearDown(self):
1647    if self._enabled:
1648      control_flow_v2_toggles.enable_control_flow_v2()
1649    super(WhileV1Test, self).tearDown()
1650
1651  def test_while_outside_loop(self):
1652
1653    x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
1654
1655    def loop_fn(i):
1656      return x + i
1657
1658    self._test_loop_fn(loop_fn, 3)
1659
1660  @test_util.run_v1_only("b/122612051")
1661  def test_invariant_while(self):
1662
1663    def loop_fn(_):
1664      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
1665
1666    self._test_loop_fn(loop_fn, 3)
1667
1668  @test_util.run_v1_only("b/122612051")
1669  def test_invariant_while_with_control_dependency(self):
1670
1671    def loop_fn(i):
1672      with ops.control_dependencies([i]):
1673        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
1674                                           [0])
1675
1676    self._test_loop_fn(loop_fn, 3)
1677
1678  @test_util.run_v1_only("b/122612051")
1679  def test_while_with_stateful_ops(self):
1680
1681    def loop_fn(_):
1682      return control_flow_ops.while_loop(
1683          lambda j, x: j < 4, lambda j, x:
1684          (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]
1685
1686    self._test_loop_fn(loop_fn, 3)
1687
1688  @test_util.run_v1_only("b/122612051")
1689  def test_while_unstacked_condition(self):
1690
1691    def loop_fn(i):
1692      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
1693                                         (j + 1, x + i), [0, 0])
1694
1695    self._test_loop_fn(loop_fn, 3)
1696
1697  @test_util.run_v1_only("b/122612051")
1698  def test_while(self):
1699    x = random_ops.random_uniform([3, 5])
1700    lengths = constant_op.constant([4, 0, 2])
1701
1702    def loop_fn(i):
1703      x_i = array_ops.gather(x, i)
1704      lengths_i = array_ops.gather(lengths, i)
1705
1706      _, total = control_flow_ops.while_loop(
1707          lambda j, _: j < lengths_i, lambda j, t:
1708          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
1709      return total
1710
1711    self._test_loop_fn(loop_fn, 3)
1712
1713  @test_util.run_v1_only("b/122612051")
1714  def test_while_jacobian(self):
1715    x = random_ops.random_uniform([1, 3])
1716    y = random_ops.random_uniform([3, 3])
1717
1718    # out = x @ y @ y @ y @ y, where @ is matmul operator.
1719    _, out = control_flow_ops.while_loop(
1720        lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
1721        [0, x])
1722
1723    def loop_fn(i):
1724      out_i = array_ops.gather(out, i, axis=1)
1725      return array_ops.reshape(gradient_ops.gradients(out_i, x)[0], [-1])
1726
1727    out = pfor_control_flow_ops.pfor(loop_fn, iters=3)
1728
1729    # The above code does not work with tf.while_loop instead of pfor. So we
1730    # manually compute the expected output here.
1731    # Note that the gradient of the output w.r.t. x is (y @ y @ y @ y)^T.
1732    expected_output = y
1733    for _ in range(3):
1734      expected_output = math_ops.matmul(expected_output, y)
1735    expected_output = array_ops.transpose(expected_output, [1, 0])
1736
1737    with session.Session() as sess:
1738      out, expected = sess.run([out, expected_output])
1739      self.assertAllClose(expected, out)
1740
1741  @test_util.run_v1_only("b/122612051")
1742  def test_tensor_array_as_loop_variable(self):
1743
1744    def loop_fn(i):
1745
1746      def body(j, ta):
1747        ta = ta.write(j, i + j * j)
1748        return j + 1, ta
1749
1750      _, ta = control_flow_ops.while_loop(
1751          lambda j, _: j < 4, body,
1752          (0, tensor_array_ops.TensorArray(dtypes.int32, size=4)))
1753      return ta.stack()
1754
1755    self._test_loop_fn(loop_fn, 3)
1756
1757  @test_util.run_v1_only("b/122612051")
1758  def test_read_tensor_array_partitioned_indices(self):
1759    # Note that tensor array values are pfor loop dependent, and the while loop
1760    # termination condition is also dependent on pfor iteration.
1761    def loop_fn(i):
1762      ta = tensor_array_ops.TensorArray(dtypes.int32, size=6)
1763      ta = ta.unstack(i + list(range(5)))
1764
1765      def body(j, s):
1766        return j + 1, s + ta.read(j)
1767
1768      _, s = control_flow_ops.while_loop(lambda j, _: j < i, body, (0, 0))
1769      return s
1770
1771    self._test_loop_fn(loop_fn, 3)
1772
1773  @test_util.run_v1_only("b/122612051")
1774  def test_external_while_loop_grad(self):
1775    # Here we test that external while_loops that are extended from inside pfor
1776    # (due to gradient calls) are not actually converted. If the loop below were
1777    # converted, all pfor iterations would write to the same tensor array
1778    # indices.
1779    x = constant_op.constant(1.)
1780
1781    def body(j, ta):
1782      ta = ta.write(j, x)
1783      return j + 1, ta
1784
1785    _, ta = control_flow_ops.while_loop(
1786        lambda j, _: j < 4, body,
1787        (0, tensor_array_ops.TensorArray(dtypes.float32, size=4)))
1788    out = ta.stack()
1789
1790    def loop_fn(i):
1791      out_i = array_ops.gather(out, i)
1792      return gradient_ops.gradients(out_i, x)[0]
1793
1794    with session.Session() as sess:
1795      # Each gathered out[i] equals x, so the gradients should be [1, 1, 1].
1796      self.assertAllEqual([1, 1, 1],
1797                          sess.run(pfor_control_flow_ops.pfor(loop_fn, 3)))
1798
1799  @test_util.run_v1_only("b/122612051")
1800  def test_tensor_array_grad(self):
1801    inp = constant_op.constant(np.random.rand(3, 4, 2), dtype=dtypes.float32)
1802    ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
1803    ta = ta.unstack(inp)
1804
1805    def loop_fn(i):
1806
1807      def body(j, x):
1808        value = ta.gather([j])
1809        value = array_ops.gather(array_ops.reshape(value, [4, 2]), i)
1810        return j + 1, x + value
1811
1812      _, out = control_flow_ops.while_loop(lambda j, _: j < 3, body,
1813                                           (0, array_ops.zeros([2])))
1814      out = math_ops.reduce_prod(out)
1815      return out, gradient_ops.gradients(out, inp)[0]
1816
1817    pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4)
1818    # Note that tf.while_loop does not work in the setup above. So we manually
1819    # construct the equivalent computation of the above loops here.
1820    real_out = math_ops.reduce_sum(inp, axis=[0])
1821    real_out = math_ops.reduce_prod(real_out, axis=[1])
1822    # Note that gradients of real_out will accumulate the gradients across the
1823    # output values. Hence we do the same aggregation on pfor_out_grad.
1824    real_out_grad = gradient_ops.gradients(real_out, inp)[0]
1825    sum_pfor_out_grad = math_ops.reduce_sum(pfor_out_grad, axis=[0])
1826
1827    with session.Session() as sess:
1828      v1, v2, v1_grad, v2_grad = sess.run(
1829          [pfor_out, real_out, sum_pfor_out_grad, real_out_grad])
1830      self.assertAllClose(v1, v2)
1831      self.assertAllClose(v1_grad, v2_grad)
1832
1833
1834def dynamic_lstm_input_fn(batch_size, state_size, max_steps):
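  """Returns constant random inputs and sequence lengths for the LSTM tests."""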
1835  # We make inputs and sequence_length constant so that multiple session.run
1836  # calls produce the same result.
1837  inputs = constant_op.constant(
1838      np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32)
1839  sequence_length = np.random.randint(0, size=[batch_size], high=max_steps + 1)
1840  sequence_length = constant_op.constant(sequence_length, dtype=dtypes.int32)
1841  return inputs, sequence_length
1842
1843
1844def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps):
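  """Runs `cell_fn` over a batch with pfor and with rnn.dynamic_rnn.

  Returns a (pfor_output, tf_output) pair so callers can compare the two.
  """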
1845  cell = cell_fn(state_size)
1846  inputs, sequence_length = dynamic_lstm_input_fn(batch_size, state_size,
1847                                                  max_steps)
1848  inputs_ta = tensor_array_ops.TensorArray(
1849      dtypes.float32, size=max_steps, element_shape=[batch_size, state_size])
1850  inputs_time_major = array_ops.transpose(inputs, [1, 0, 2])
1851  inputs_ta = inputs_ta.unstack(inputs_time_major)
1852  zeros = array_ops.zeros([state_size])
1853
1854  def loop_fn(i):
1855    sequence_length_i = array_ops.gather(sequence_length, i)
1856
1857    def body_fn(t, state, ta):
1858      inputs_t = array_ops.expand_dims(
1859          array_ops.gather(inputs_ta.read(t), i), 0)
1860      output, new_state = cell(inputs_t, state)
1861      output = array_ops.reshape(output, [-1])
1862      # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the
1863      # array_ops.where when t < min(sequence_length). Doing that requires
1864      # supporting tf.cond pfor conversion.
1865      done = t >= sequence_length_i
1866      output = array_ops.where(done, zeros, output)
1867      ta = ta.write(t, output)
1868      new_state = [
1869          array_ops.where(done, s, ns)
1870          for s, ns in zip(nest.flatten(state), nest.flatten(new_state))
1871      ]
1872      new_state = nest.pack_sequence_as(state, new_state)
1873      return t + 1, new_state, ta
1874
1875    def condition_fn(t, _, unused):
1876      del unused
1877      return t < max_steps
1878
1879    initial_state = cell.zero_state(1, dtypes.float32)
1880    _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [
1881        0, initial_state,
1882        tensor_array_ops.TensorArray(dtypes.float32, max_steps)
1883    ])
1884
1885    new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)]
1886    new_state = nest.pack_sequence_as(initial_state, new_state)
1887    return ta.stack(), new_state
1888
1889  pfor_output = pfor_control_flow_ops.pfor(loop_fn, batch_size)
1890  tf_output = rnn.dynamic_rnn(
1891      cell,
1892      inputs,
1893      sequence_length=sequence_length,
1894      initial_state=cell.zero_state(batch_size, dtypes.float32))
1895  return pfor_output, tf_output
1896
1897
1898@test_util.run_all_in_graph_and_eager_modes
1899class WhileV2Test(PForTestCase):
1900
1901  def setUp(self):
1902    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
1903    control_flow_v2_toggles.enable_control_flow_v2()
1904    super(WhileV2Test, self).setUp()
1905
1906  def tearDown(self):
1907    if not self._enabled:
1908      control_flow_v2_toggles.disable_control_flow_v2()
1909    super(WhileV2Test, self).tearDown()
1910
1911  def test_while_outside_loop(self):
1912
1913    def _f():
1914      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
1915
1916    def loop_fn(i):
1917      return _f() + i
1918
1919    self._test_loop_fn(loop_fn, 3)
1920
1921  def test_invariant_while(self):
1922
1923    def loop_fn(_):
1924      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
1925
1926    self._test_loop_fn(loop_fn, 3)
1927
1928  def test_invariant_while_with_control_dependency(self):
1929
1930    def loop_fn(i):
1931      with ops.control_dependencies([i]):
1932        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
1933                                           [0])
1934
1935    self._test_loop_fn(loop_fn, 3)
1936
1937  def test_while_with_stateful_ops(self):
1938
1939    def loop_fn(_):
1940      j, _ = control_flow_ops.while_loop(
1941          lambda j, x: j < 4, lambda j, x:
1942          (j + 1, x + random_ops.random_uniform([])), [0, 0.])
1943      return j
1944
1945    self._test_loop_fn(loop_fn, 3)
1946
1947  def test_while_with_variable(self):
1948    v = resource_variable_ops.ResourceVariable(5.)
1949
1950    def loop_fn(_):
1951      _, output = control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
1952                                              (j + 1, x + v), [0, 0.])
1953      return output
1954
1955    self._test_loop_fn(loop_fn, 3)
1956
1957  def test_while_unstacked_condition(self):
1958
1959    def loop_fn(i):
1960      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
1961                                         (j + 1, x + i), [0, 0])
1962
1963    self._test_loop_fn(loop_fn, 3)
1964
1965  def test_while(self):
1966    x = random_ops.random_uniform([3, 5])
1967    lengths = constant_op.constant([4, 0, 2])
1968
1969    def loop_fn(i):
1970      x_i = array_ops.gather(x, i)
1971      lengths_i = array_ops.gather(lengths, i)
1972
1973      return control_flow_ops.while_loop(
1974          lambda j, _: j < lengths_i, lambda j, t:
1975          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
1976
1977    self._test_loop_fn(loop_fn, 3)
1978
1979  def test_while_change_input_invariance(self):
1980    # This tests cases where a loop invariant input to while has loop dependent
1981    # operations applied to it inside the while body.
1982    # It also tests inputs that are passed through.
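    # In the body below, x picks up the loop-variant i, y accumulates x, z is
    # initialized to i and passed through, and w stays loop invariant.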
1983    def loop_fn(i):
1984      return control_flow_ops.while_loop(
1985          lambda j, *_: j < i, lambda j, x, y, z, w:
1986          (j + 1, x + i, y + x, z, w), [
1987              0,
1988              constant_op.constant(0),
1989              constant_op.constant(1), i,
1990              constant_op.constant(2)
1991          ])
1992
1993    self._test_loop_fn(loop_fn, 3)
1994
1995  def test_while_shape_invariants(self):
1996
1997    def loop_fn(i):
1998      return control_flow_ops.while_loop(
1999          lambda j, *_: j < 4,
2000          lambda j, x, y: (j + 1, x + i, y + 1),
2001          [0, constant_op.constant([0, 1]),
2002           constant_op.constant([2, 3])],
2003          shape_invariants=[
2004              None,
2005              tensor_shape.TensorShape([2]),
2006              tensor_shape.TensorShape([2])
2007          ])
2008
2009    self._test_loop_fn(loop_fn, 3)
2010
2011  def test_while_jacobian(self):
2012    # Note that we wrap the code below in a tf.function since we don't want the
2013    # while_loop call to be evaluated eagerly using a python loop.
2014    @def_function.function
2015    def _f(x, y, use_pfor):
2016      # out = x @ y @ y @ y @ y, where @ is matmul operator.
2017      _, out = control_flow_ops.while_loop(
2018          lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
2019          [0, x])
2020
2021      def loop_fn(i):
2022        out_i = array_ops.gather(out, i, axis=1)
2023        grad = gradient_ops.gradients(out_i, x)
2024        return array_ops.reshape(grad[0], [-1])
2025
2026      if use_pfor:
2027        return pfor_control_flow_ops.pfor(loop_fn, iters=3)
2028      else:
2029        return pfor_control_flow_ops.for_loop(
2030            loop_fn, iters=3, loop_fn_dtypes=out.dtype)
2031
2032    x = constant_op.constant(np.random.uniform(size=(1, 3)))
2033    y = constant_op.constant(np.random.uniform(size=(3, 3)))
2034    self.assertAllClose(_f(x, y, True), _f(x, y, False))
2035
2036  def test_scan(self):
2037    np.random.seed(seed=42)
2038    data = np.random.randn(3).astype(np.float32)
2039
2040    def log_prob(x):
2041      return math_ops.reduce_sum(functional_ops.scan_v2(
2042          lambda _, yi: (x - yi)**2,
2043          elems=data,
2044          initializer=constant_op.constant(0.)))
2045
2046    x = variables.Variable(array_ops.ones([2]))
2047    self.evaluate(x.initializer)
2048    v_log_prob = lambda x: pfor_control_flow_ops.vectorized_map(log_prob, x)
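    # compute_gradient returns the analytic and finite-difference jacobians of
    # the vectorized log_prob for comparison.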
2049    theoretical, numerical = gradient_checker_v2.compute_gradient(
2050        v_log_prob, (x,), delta=1e-3)
2051    self.assertAllClose(theoretical, numerical, rtol=1e-2)
2052
2053  def test_scan_captured_variable(self):
2054    if not context.executing_eagerly():
2055      self.skipTest("Test only written for 2.x")
2056    v = variables.Variable(math_ops.range(10, dtype=dtypes.float32))
2057
2058    def loop_fn(idx):
2059      del idx
2060      return functional_ops.scan_v2(lambda _, i: array_ops.gather(v, i),
2061                                    elems=math_ops.range(v.shape[0]),
2062                                    initializer=0.0)
2063    with backprop.GradientTape() as tape:
2064      result = pfor_control_flow_ops.pfor(loop_fn, 2)
2065    self.assertAllClose([2.] * 10, tape.gradient(result, v))
2066
2067
2068@test_util.run_all_in_graph_and_eager_modes
2069class NestedControlFlowTest(PForTestCase):
2070
2071  def setUp(self):
2072    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
2073    control_flow_v2_toggles.enable_control_flow_v2()
2074    super(NestedControlFlowTest, self).setUp()
2075
2076  def tearDown(self):
2077    if not self._enabled:
2078      control_flow_v2_toggles.disable_control_flow_v2()
2079    super(NestedControlFlowTest, self).tearDown()
2080
2081  def _cond(self, f=None, split=0):
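    """Returns a function applying `f(x, y)` under a tf.cond on `y > split`."""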
2082    if f is None:
2083      f = lambda x, y: (x, y)
2084
2085    def _f(x, y):
2086      return control_flow_ops.cond(y > split, lambda: f(x, y), lambda:
2087                                   (x + 1., y))
2088
2089    return _f
2090
2091  def _while(self, f=None):
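    """Returns a function accumulating elements of `f(x, y)[0]` in a while_loop."""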
2092    if f is None:
2093      f = lambda x, y: (x, y)
2094
2095    def _f(x, y):
2096      return control_flow_ops.while_loop(
2097          lambda j, _: j < y, lambda j, t:
2098          (j + 1, t + array_ops.gather(f(x, y)[0], j)), [0, x])[1], y
2099
2100    return _f
2101
2102  def _test_helper(self, f):
2103    x = random_ops.random_uniform([5, 5])
2104    y = constant_op.constant([4, -1, 2, -2, 2])
2105
2106    def loop_fn(i):
2107      x_i = array_ops.gather(x, i)
2108      y_i = array_ops.gather(y, i)
2109      return f(x_i, y_i)
2110
2111    self._test_loop_fn(loop_fn, 5)
2112
2113  def test_cond_while(self):
2114    self._test_helper(self._cond(self._while()))
2115
2116  def test_while_cond(self):
2117    self._test_helper(self._while(self._cond()))
2118
2119  def test_while_while(self):
2120    self._test_helper(self._while(self._while()))
2121
2122  def test_cond_cond(self):
2123    self._test_helper(self._cond(self._cond()))
2124
2125
2126@test_util.run_all_in_graph_and_eager_modes
2127@test_util.with_control_flow_v2
2128class StatelessIfTest(PForTestCase):
2129
2130  def test_loop_variant_cond(self):
2131    x = [1, 2, 3, 4, 5.]
2132    y = 2.5
2133
2134    @def_function.function
2135    def loop_fn(i):
2136      x_i = array_ops.gather(x, i)
2137      # Note that the outputs of the then and else branches are a mix of
2138      # loop-variant and loop-invariant values.
2139      return cond_v2.cond_v2(x_i < y, lambda: (y - x_i, y, 1., 2.), lambda:
2140                             (x_i - y, 0., y, 3.))
2141
2142    self._test_loop_fn(loop_fn, iters=5)
2143
2144  def test_loop_invariant_cond(self):
2145    x = [1, 2, 3, 4, 5.]
2146    y = 0.5
2147    z = random_ops.random_uniform([])
2148
2149    @def_function.function
2150    def loop_fn(i):
2151      x_i = array_ops.gather(x, i)
2152      # Note that the outputs of the then and else branches are a mix of
2153      # loop-variant and loop-invariant values.
2154      return cond_v2.cond_v2(z < y, lambda: (y - x_i, y, 1., 2.), lambda:
2155                             (x_i - y, 0., y, 3.))
2156
2157    self._test_loop_fn(loop_fn, iters=5)
2158
2159  def test_empty_branch(self):
2160    x = [1, 2, 3, 4, 5.]
2161    y = 6.
2162
2163    @def_function.function
2164    def loop_fn(i):
2165      x_i = array_ops.gather(x, i)
2166      return cond_v2.cond_v2(
2167          x_i < y,  # Note that else branch is empty.
2168          lambda: (y - x_i, y, 1., 2.),
2169          lambda: (x_i - y, 0., y, 3.))
2170
2171    self._test_loop_fn(loop_fn, iters=5)
2172
2173
2174@test_util.run_all_in_graph_and_eager_modes
2175@test_util.with_control_flow_v2
2176class IfTest(PForTestCase):
2177
2178  def test_read_var(self):
2179    self.skipTest("b/156438918")  # Flaky
2180
2181    x = [1, 2, 3, 4, 5.]
2182    y = 2.5
2183    z = resource_variable_ops.ResourceVariable(5.)
2184
2185    @def_function.function
2186    def loop_fn(i):
2187      x_i = array_ops.gather(x, i)
2188      return cond_v2.cond_v2(x_i < y, lambda: z - x_i, lambda: z + x_i)
2189
2190    self._test_loop_fn(loop_fn, iters=5)
2191
2192
2193class RNNTest(PForTestCase):
2194
2195  @test_util.run_v1_only("b/122612051")
2196  def test_dynamic_rnn(self):
2197    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 3, 5,
2198                                                   7)
2199    self.run_and_assert_equal(pfor_outputs, tf_outputs)
2200
2201  @test_util.run_v1_only("b/122612051")
2202  def test_dynamic_lstm(self):
2203    pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell, 3, 5,
2204                                                   7)
2205    self.run_and_assert_equal(pfor_outputs, tf_outputs)
2206
2207
2208# TODO(agarwal): benchmark numbers on GPU for graphs based on while_loop
2209# conversion don't look good. Some of the cost appears to come from copies
2210# between host and device. Optimize that.
2211class Benchmarks(test.Benchmark):
2212
2213  def _run(self, targets, iters, name=None):
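    """Times `iters` runs of `targets` and reports average wall time in ms."""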
2214
2215    def _done(t):
2216      # Note that we don't use tf.control_dependencies since that will not make
2217      # sure that the computation on GPU has actually finished. So we fetch the
2218      # first element of the output, and assume that this will not be called on
2219      # empty tensors.
2220      return array_ops.gather(array_ops.reshape(t, [-1]), 0)
2221
2222    targets = [_done(x) for x in nest.flatten(targets)]
2223    sess = session.Session()
2224    with sess:
2225      init = variables.global_variables_initializer()
2226      sess.run(init)
2227      run_fn = sess.make_callable(targets)
2228      run_fn()  # Warm up
2229      begin = time.time()
2230      for _ in range(iters):
2231        run_fn()
2232      end = time.time()
2233    avg_time_ms = 1000 * (end - begin) / iters
2234    self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name)
2235    return avg_time_ms
2236
2237  def benchmark_sess_run_overhead(self):
2238    with ops.Graph().as_default():
2239      x = constant_op.constant(1.0)
2240      self._run(x, 10000, name="session_run_overhead")
2241
2242  def benchmark_add(self):
2243    with ops.Graph().as_default():
2244      n = 256
2245      params = 1000
2246      x = random_ops.random_normal([n, params])
2247      y = random_ops.random_normal([n, params])
2248
2249      def loop_fn(i):
2250        x_i = array_ops.gather(x, i)
2251        y_i = array_ops.gather(y, i)
2252        return x_i + y_i
2253
2254      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
2255      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
2256      manual = x + y
2257
2258      self._run(manual, 1000, name="manual_add")
2259      self._run(pfor_outputs, 1000, name="pfor_add")
2260      self._run(while_outputs, 100, name="while_add")
2261
2262  def benchmark_matmul(self):
2263    with ops.Graph().as_default():
2264      n = 1024
2265      params = 1000
2266      x = random_ops.random_normal([n, params])
2267      y = random_ops.random_normal([params, params])
2268
2269      def loop_fn(i):
2270        x_i = array_ops.expand_dims(array_ops.gather(x, i), 0)
2271        return math_ops.matmul(x_i, y)
2272
2273      pfor_outputs = pfor_control_flow_ops.pfor(loop_fn, n)
2274      while_outputs = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, n)
2275      manual = math_ops.matmul(x, y)
2276
2277      self._run(manual, 1000, name="manual_matmul")
2278      self._run(pfor_outputs, 1000, name="pfor_matmul")
2279      self._run(while_outputs, 100, name="while_matmul")
2280
2281  def benchmark_map_fn(self):
2282    with ops.Graph().as_default():
2283      b = 256
2284      params = 1000
2285      inp = random_ops.random_normal((b, params))
2286      fn = lambda x: x * x
2287
2288      def pfor_map_fn(f, x):
2289        return pfor_control_flow_ops.pfor(lambda i: f(array_ops.gather(x, i)),
2290                                          array_ops.shape(x)[0])
2291
2292      map_output = map_fn.map_fn(fn, inp)
2293      pfor_output = pfor_map_fn(fn, inp)
2294
2295      self._run(map_output, 100, name="tf_map_fn")
2296      self._run(pfor_output, 100, name="pfor_map_fn")
2297
2298  def benchmark_basic_while(self):
2299    with ops.Graph().as_default():
2300
2301      def loop_fn(i):
2302        _, s = control_flow_ops.while_loop(lambda t, x: t < i, lambda t, x:
2303                                           (t + 1, x + i), [0, 0])
2304        return s
2305
2306      iters = 50
2307      pfor_output = pfor_control_flow_ops.pfor(loop_fn, iters)
2308      for_loop_output = pfor_control_flow_ops.for_loop(loop_fn, dtypes.int32,
2309                                                       iters)
2310      self._run(pfor_output, 100, name="pfor_basic")
2311      self._run(for_loop_output, 100, name="for_loop_basic")
2312
2313  def benchmark_dynamic_rnn(self):
2314    with ops.Graph().as_default():
2315      pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, 128,
2316                                                     512, 16)
2317      self._run(pfor_outputs, 100, name="pfor_rnn")
2318      self._run(tf_outputs, 100, name="tf_rnn")
2319
2320  def benchmark_reduction(self):
2321    n = 1024
2322    with ops.Graph().as_default():
2323      x = random_ops.random_uniform([n, n])
2324      w = random_ops.random_uniform([n, n])
2325
2326      def loop_fn(i, pfor_config):
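        # reduce_concat concatenates x_i across all pfor iterations, so the
        # matmul below sees the full [n, n] matrix rather than a single row.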
2327        x_i = array_ops.gather(x, i)
2328        return math_ops.reduce_sum(
2329            math_ops.matmul(pfor_config.reduce_concat(x_i), w))
2330
2331      # Note that output_reduction will be tiled, so there may be some minor
2332      # overheads compared to output_no_reduction.
2333      output_reduction = pfor_control_flow_ops.pfor(loop_fn, n)
2334      output_no_reduction = math_ops.reduce_sum(math_ops.matmul(x, w))
2335      # Benchmark to test that reduction does not add overhead and its output is
2336      # treated as loop invariant.
2337      self._run(output_reduction, 30, name="matmul_reduction")
2338      self._run(output_no_reduction, 30, name="matmul_no_reduction")
2339
2340
2341class SparseTest(PForTestCase):
2342
2343  @test_util.run_v1_only("b/122612051")
2344  def test_var_loop_len(self):
2345    num_iters = array_ops.placeholder(dtypes.int32)
2346
2347    def loop_fn(_):
2348      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
2349                                        [3])  # [0, 2, 0]
2350
2351    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2352    with self.cached_session() as sess:
2353      sess.run(pfor, feed_dict={num_iters: 3})
2354
2355  @test_util.run_v1_only("b/122612051")
2356  def test_sparse_result_none_stacked(self):
2357    num_iters = 10
2358
2359    def loop_fn(_):
2360      return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
2361                                        [3])  # [0, 2, 0]
2362
2363    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2364
2365    indices = [[i, j] for i in range(num_iters) for j in range(3)]
2366    values = [4, 5, 6] * num_iters
2367    dense_shapes = [num_iters, 3]
2368    # Expected result: [[4, 5, 6], [4, 5, 6], [4, 5, 6], ...]
2369    manual = sparse_tensor.SparseTensor(indices, values, dense_shapes)
2370    self.run_and_assert_equal(pfor, manual)
2371
2372  @test_util.run_v1_only("b/122612051")
2373  def test_sparse_result_all_stacked(self):
2374    num_iters = 10
2375
2376    def loop_fn(i):
2377      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
2378      indices = array_ops.expand_dims(i, 0)
2379      return sparse_tensor.SparseTensor(indices, i, i + 1)  # [0, ..., 0, i]
2380
2381    # Expected result: [[0], [0, 1], [0, 0, 2], [0, 0, 0, 3], ...]
2382    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2383    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
2384                                        list(range(num_iters)),
2385                                        (num_iters, num_iters))
2386    self.run_and_assert_equal(pfor, manual)
2387
2388  @test_util.run_v1_only("b/122612051")
2389  def test_sparse_result_indices_stacked(self):
2390    num_iters = 10
2391
2392    def loop_fn(i):
2393      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
2394      indices = array_ops.expand_dims(i, 0)
2395      return sparse_tensor.SparseTensor(indices, [1], [num_iters])
2396
2397    # Expected result: identity matrix size num_iters * num_iters
2398    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2399    manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
2400                                        [1] * num_iters, (num_iters, num_iters))
2401    self.run_and_assert_equal(pfor, manual)
2402
2403  @test_util.run_v1_only("b/122612051")
2404  def test_sparse_result_values_stacked(self):
2405    num_iters = 10
2406
2407    def loop_fn(i):
2408      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
2409      return sparse_tensor.SparseTensor([[0]], i, [num_iters])  # [i, 0, ..., 0]
2410
2411    # Expected result: [[1, 0, ...], [2, 0, ...], [3, 0, ...], ...]
2412    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2413    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
2414                                        list(range(num_iters)),
2415                                        (num_iters, num_iters))
2416    self.run_and_assert_equal(pfor, manual)
2417
2418  @test_util.run_v1_only("b/122612051")
2419  def test_sparse_result_shapes_stacked(self):
2420    num_iters = 10
2421
2422    def loop_fn(i):
2423      i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
2424      return sparse_tensor.SparseTensor([[0]], [1], i + 1)  # [1, 0, ..., 0]
2425
2426    # Expected result: [[1, 0, 0, ...], [1, 0, 0, ...], ...]
2427    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2428    manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
2429                                        [1] * num_iters, (num_iters, num_iters))
2430    self.run_and_assert_equal(pfor, manual)
2431
2432  @test_util.run_v1_only("b/122612051")
2433  def test_sparse_result_shapes_stacked_2D(self):
2434    num_iters = 10
2435
2436    def loop_fn(i):
2437      i = array_ops.expand_dims(math_ops.cast(i + 1, dtypes.int64), 0)
2438      shape = array_ops.concat([i, i], 0)
2439      return sparse_tensor.SparseTensor([[0, 0]], [1], shape)  # [1, 0, ..., 0]
2440
2441    # Expected result: [[[1, 0, ...], [0, ..., 0], [0, ..., 0], ...], ...]
2442    pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
2443    manual = sparse_tensor.SparseTensor([[i, 0, 0] for i in range(num_iters)],
2444                                        [1] * num_iters,
2445                                        (num_iters, num_iters, num_iters))
2446    self.run_and_assert_equal(pfor, manual)
2447
2448
2449# Dummy CompositeTensor to test CompositeTensor support.
2450class Particle(composite_tensor.CompositeTensor):
2451  """A (batch of) particles each defined by a mass and a scalar velocity."""
2452
2453  def __init__(self, mass, velocity):
2454    mass = ops.convert_to_tensor(mass)
2455    velocity = ops.convert_to_tensor(velocity)
2456    self.shape = array_ops.broadcast_static_shape(mass.shape, velocity.shape)
2457    self.mass = mass
2458    self.velocity = velocity
2459
2460  @property
2461  def _type_spec(self):
2462    return ParticleSpec(
2463        type_spec.type_spec_from_value(self.mass),
2464        type_spec.type_spec_from_value(self.velocity))
2465
2466
2467class ParticleSpec(type_spec.BatchableTypeSpec):
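  """TypeSpec for `Particle`; its batch/unbatch hooks let pfor stack Particles."""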
2468
2469  def __init__(self, mass, velocity):
2470    self.shape = array_ops.broadcast_static_shape(
2471        mass.shape, velocity.shape)
2472    self.mass = mass
2473    self.velocity = velocity
2474
2475  def _serialize(self):
2476    return (self.mass, self.velocity)
2477
2478  @property
2479  def value_type(self):
2480    return Particle
2481
2482  @property
2483  def _component_specs(self):
2484    return (self.mass, self.velocity)
2485
2486  def _to_components(self, value):
2487    return (value.mass, value.velocity)
2488
2489  def _from_components(self, components):
2490    return Particle(*components)
2491
2492  def _pad_shape_to_full_rank(self, s):
2493    """Pad component shapes with 1's so all components have the same rank."""
2494    return tensor_shape.TensorShape(
2495        [1] * (self.shape.ndims - s.ndims)).concatenate(s)
2496
2497  def _batch(self, batch_size):
2498    return ParticleSpec(
2499        mass=tensor_spec.TensorSpec(
2500            dtype=self.mass.dtype,
2501            shape=tensor_shape.TensorShape([batch_size]).concatenate(
2502                self._pad_shape_to_full_rank(self.mass.shape))),
2503        velocity=tensor_spec.TensorSpec(
2504            dtype=self.velocity.dtype,
2505            shape=tensor_shape.TensorShape([batch_size]).concatenate(
2506                self._pad_shape_to_full_rank(self.velocity.shape))))
2507
2508  def _unbatch(self):
2509    return ParticleSpec(
2510                tensor_spec.TensorSpec(dtype=self.mass.dtype,
2511                                       shape=self.mass.shape[1:]),
2512                tensor_spec.TensorSpec(dtype=self.velocity.dtype,
2513                                       shape=self.velocity.shape[1:]))
2514
2515  def _to_tensor_list(self, value):
2516    return [array_ops.reshape(
2517                value.mass,
2518                self._pad_shape_to_full_rank(value.mass.shape)),
2519            array_ops.reshape(
2520                value.velocity,
2521                self._pad_shape_to_full_rank(value.velocity.shape))]
2522
2523
2524class CompositeTensorTest(PForTestCase, parameterized.TestCase):
2525
2526  @parameterized.parameters((None,), (3,))
2527  def test_create_composite_inside_loop(self, parallel_iterations):
2528    num_particles = 10
2529    velocities = random_ops.random_uniform([num_particles])
2530    particles = pfor_control_flow_ops.pfor(
2531        # Build a batch of particles all with the same mass.
2532        lambda i: Particle(mass=4., velocity=array_ops.gather(velocities, i)),
2533        num_particles,
2534        parallel_iterations=parallel_iterations)
2535    particles_mass, particles_velocity, velocities = self.evaluate(
2536        (particles.mass, particles.velocity, velocities))
2537    self.assertAllEqual(particles_mass, 4. * np.ones([num_particles]))
2538    self.assertAllEqual(particles_velocity, velocities)
2539
2540  @parameterized.parameters((None,), (3,))
2541  def test_composite_is_converted_to_batched_tensor(
2542      self, parallel_iterations):
2543    particles = pfor_control_flow_ops.pfor(
2544        lambda _: Particle(mass=random_ops.random_uniform([3]),  # pylint: disable=g-long-lambda
2545                           velocity=random_ops.random_uniform([5, 3])),
2546        4,
2547        parallel_iterations=parallel_iterations)
2548    # Naively batching the component shapes would give `[4, 3]` and `[4, 5, 3]`
2549    # which have no consistent broadcast shape.
2550    self.assertEqual(particles.mass.shape, [4, 1, 3])
2551    self.assertAllEqual(particles.velocity.shape, [4, 5, 3])
2552
2553  def test_vectorized_map_gathers_composite_tensors(self):
2554    particles = Particle(mass=[1., 2., 3., 4., 5.],
2555                         velocity=[1., 2., 3., 4., 5.])
2556    self.assertAllEqual(
2557        pfor_control_flow_ops.vectorized_map(
2558            lambda x: x.mass * x.velocity, particles),
2559        particles.mass * particles.velocity)
2560
2561  def test_vectorized_map_of_ragged_tensors(self):
2562    # Vmap should be able to handle ragged Tensors as long as they're not
2563    # *actually* ragged.
2564    ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
2565        ragged_tensor.RaggedTensor.from_row_lengths(
2566            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
2567            row_lengths=[3, 3, 3, 3]),
2568        uniform_row_length=2)  # Overall shape [2, 2, 3].
2569    self.assertAllEqual(
2570        pfor_control_flow_ops.vectorized_map(
2571            lambda x: x.to_tensor(shape=[2, 3]), ragged),
2572        ragged.to_tensor(shape=[2, 2, 3]))
2573
2574
2575class ParsingTest(PForTestCase):
2576
2577  def test_decode_csv(self):
2578    csv_tensor = constant_op.constant([["1:2:3"], ["::"], ["7:8:9"]])
2579    kwargs = {"record_defaults": [[10], [20], [30]], "field_delim": ":"}
2580
2581    def loop_fn(i):
2582      line = array_ops.gather(csv_tensor, i)
2583      return parsing_ops.decode_csv(line, **kwargs)
2584
2585    self._test_loop_fn(loop_fn, iters=3)
2586
2587  @test_util.run_v1_only("b/122612051")
2588  def test_parse_single_example(self):
2589
2590    def _int64_feature(*values):
2591      return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values))
2592
2593    def _bytes_feature(*values):
2594      return feature_pb2.Feature(
2595          bytes_list=feature_pb2.BytesList(
2596              value=[v.encode("utf-8") for v in values]))
2597
2598    examples = constant_op.constant([
2599        example_pb2.Example(
2600            features=feature_pb2.Features(
2601                feature={
2602                    "dense_int": _int64_feature(i),
2603                    "dense_str": _bytes_feature(str(i)),
2604                    "sparse_int": _int64_feature(i, i * 2, i * 4, i * 8),
2605                    "sparse_str": _bytes_feature(*["abc"] * i)
2606                })).SerializeToString() for i in range(10)
2607    ])
2608
2609    features = {
2610        "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
2611        "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
2612        "sparse_int": parsing_ops.VarLenFeature(dtypes.int64),
2613        "sparse_str": parsing_ops.VarLenFeature(dtypes.string),
2614    }
2615
2616    def loop_fn(i):
2617      example_proto = array_ops.gather(examples, i)
2618      f = parsing_ops.parse_single_example(example_proto, features)
2619      return f
2620
2621    pfor = pfor_control_flow_ops.pfor(loop_fn, iters=10)
2622    manual = parsing_ops.parse_example(examples, features)
2623    self.run_and_assert_equal(pfor, manual)
2624
2625
2626class PartitionedCallTest(PForTestCase):
2627
2628  def test_simple(self):
2629
2630    @def_function.function
2631    def f(x):
2632      return math_ops.square(x) + 1
2633
2634    z = random_ops.random_uniform([4])
2635
2636    def loop_fn(i):
2637      return f(array_ops.gather(z, i))
2638
2639    self._test_loop_fn(loop_fn, 4)
2640
2641  def test_nested_calls(self):
2642
2643    @def_function.function
2644    def inner(x):
2645      return math_ops.square(x)
2646
2647    @def_function.function
2648    def outer(y):
2649      return math_ops.reduce_sum(inner(y)) + 2
2650
2651    z = random_ops.random_uniform([4, 2])
2652
2653    def loop_fn(i):
2654      return outer(array_ops.gather(z, i))
2655
2656    self._test_loop_fn(loop_fn, 4)
2657
2658  def test_nested_calls_loop_fn_autograph(self):
2659    # TODO(bhack): Do we need to extend the coverage?
2660
2661    def loop_fn(x):
2662      for y in range(array_ops.constant(3)):
2663        pass
2664      return math_ops.square(x)
2665
2666    @def_function.function
2667    def loop_fn_caller():
2668      self._test_loop_fn(loop_fn, 4)
2669
2670    loop_fn_caller()
2671
2673  def test_nested_definition(self):
2674
2675    @def_function.function
2676    def outer(y):
2677
2678      @def_function.function
2679      def inner(x):
2680        return math_ops.square(x) + 1
2681
2682      return math_ops.reduce_sum(inner(y)) + 2
2683
2684    z = random_ops.random_uniform([4, 2])
2685
2686    def loop_fn(i):
2687      return outer(array_ops.gather(z, i))
2688
2689    self._test_loop_fn(loop_fn, 4)
2690
2691  def test_gradients(self):
2692
2693    @def_function.function
2694    def f(x):
2695      return math_ops.square(x) + 1
2696
2697    z = random_ops.random_uniform([4, 2])
2698
2699    def loop_fn(i):
2700      z_i = array_ops.gather(z, i)
2701      with backprop.GradientTape() as g:
2702        g.watch(z_i)
2703        out = f(z_i)
2704      return out, g.gradient(out, z_i)
2705
2706    self._test_loop_fn(loop_fn, 4)
2707
2708  def test_stateful_with_gradients(self):
2709
2710    z = random_ops.random_uniform([4, 2])
2711    v = variables.Variable(z[0])
2712
2713    @def_function.function
2714    def f(x):
2715      return math_ops.square(x) + v + 1
2716
2717    def loop_fn(i):
2718      z_i = array_ops.gather(z, i)
2719      with backprop.GradientTape() as g:
2720        g.watch(z_i)
2721        out = f(z_i)
2722      return out, g.gradient(out, z_i)
2723
2724    self._test_loop_fn(loop_fn, 4)
2725
2726
2727class SpectralTest(PForTestCase, parameterized.TestCase):
2728
2729  @parameterized.parameters(
2730      (fft_ops.fft,),
2731      (fft_ops.fft2d,),
2732      (fft_ops.fft3d,),
2733      (fft_ops.ifft,),
2734      (fft_ops.ifft2d,),
2735      (fft_ops.ifft3d,),
2736  )
2737  def test_fft(self, op_func):
2738    shape = [2, 3, 4, 3, 4]
2739    x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
2740
2741    def loop_fn(i):
2742      x_i = array_ops.gather(x, i)
2743      return op_func(x_i)
2744
2745    self._test_loop_fn(loop_fn, 2)
2746
2747  @parameterized.parameters(
2748      (fft_ops.rfft,),
2749      (fft_ops.rfft2d,),
2750      (fft_ops.rfft3d,),
2751  )
2752  @test.disable_with_predicate(
2753      pred=test.is_built_with_rocm,
2754      skip_message="Disable subtest on ROCm due to rocfft issues")
2755  def test_rfft(self, op_func):
2756    for dtype in (dtypes.float32, dtypes.float64):
2757      x = random_ops.random_uniform([2, 3, 4, 3, 4], dtype=dtype)
2758
2759      # pylint: disable=cell-var-from-loop
2760      def loop_fn(i):
2761        x_i = array_ops.gather(x, i)
2762        return op_func(x_i)
2763
2764      # pylint: enable=cell-var-from-loop
2765
2766      self._test_loop_fn(loop_fn, 2)
2767
2768  @parameterized.parameters(
2769      (fft_ops.irfft,),
2770      (fft_ops.irfft2d,),
2771      (fft_ops.irfft3d,),
2772  )
2773  @test.disable_with_predicate(
2774      pred=test.is_built_with_rocm,
2775      skip_message="Disable subtest on ROCm due to rocfft issues")
2776  def test_irfft(self, op_func):
2777    if config.list_physical_devices("GPU"):
2778      # TODO(b/149957923): The test is flaky
2779      self.skipTest("b/149957923: irfft vectorization flaky")
2780    for dtype in (dtypes.complex64, dtypes.complex128):
2781      shape = [2, 3, 4, 3, 4]
2782      x = np.random.uniform(size=shape) + 1j * np.random.uniform(size=shape)
2783      x = math_ops.cast(x, dtype=dtype)
2784
2785      # pylint: disable=cell-var-from-loop
2786      def loop_fn(i):
2787        x_i = array_ops.gather(x, i)
2788        return op_func(x_i)
2789
2790      # pylint: enable=cell-var-from-loop
2791
2792      self._test_loop_fn(loop_fn, 2)
2793
2794
2795class VariableTest(PForTestCase):
2796
2797  def test_create_variable_once(self):
2798    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
2799    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
2800    a_var = []
2801
2802    def f(z):
2803      if not a_var:
2804        a_var.append(variables.Variable(lambda: y, name="a"))
2805      return math_ops.matmul(z, a_var[0] / 16)
2806
2807    pfor_control_flow_ops.vectorized_map(f, x)
2808
2809  @test_util.run_v2_only
2810  def test_create_variable_repeated(self):
2811    x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
2812    y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
2813
2814    def f(z):
2815      a_var = variables.Variable(lambda: y, name="a") / 4
2816      return math_ops.matmul(z, a_var / 16)
2817
2818    # Note that this error is only raised under v2 behavior.
2819    with self.assertRaisesRegex(
2820        ValueError, "singleton tf.Variable.*on the first call"):
2821      pfor_control_flow_ops.vectorized_map(f, x)
2822
2823  @test_util.run_all_in_graph_and_eager_modes
2824  def test_variable_shape(self):
2825    v = resource_variable_ops.ResourceVariable([1, 2])
2826
2827    def loop_fn(_):
2828      return resource_variable_ops.variable_shape(v.handle)
2829
2830    self._test_loop_fn(loop_fn, 2)
2831
2832  @test_util.run_all_in_graph_and_eager_modes
2833  def test_variable_input(self):
2834    v = resource_variable_ops.ResourceVariable([1, 2])
2835    self.evaluate(v.initializer)
2836
2837    def loop_fn(x):
2838      return x + 1
2839
2840    result = pfor_control_flow_ops.vectorized_map(loop_fn, v)
2841    expected_result = [2, 3]
2842    self.assertAllEqual(result, expected_result)
2843
2844  @test_util.run_all_in_graph_and_eager_modes
2845  def testStatelessCase(self):
2846
2847    def branch1(x):
2848      return x
2849
2850    def branch2(x):
2851      return x + 1
2852
2853    def branch3(x):
2854      return x + 2
2855
2856    x = constant_op.constant(10)
2857    elems = constant_op.constant([1, 0, 0, 0, 2, 1, 0, 2, 0, 1])
2858    def loop_fn(z_i):
2859      return cond_v2.indexed_case(
2860          z_i, [lambda: branch1(x), lambda: branch2(x), lambda: branch3(x)])
2861
2862    result = pfor_control_flow_ops.vectorized_map(
2863        loop_fn, elems, fallback_to_while_loop=False)
2864
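    # Each elems entry selects a branch: 0 -> x, 1 -> x + 1, 2 -> x + 2.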
2865    expected_result = [11, 10, 10, 10, 12, 11, 10, 12, 10, 11]
2866    self.assertAllEqual(result, expected_result)
2867
2868  @test_util.run_all_in_graph_and_eager_modes
2869  def testStatelessCaseUnstacked(self):
2870
2871    def branch1(x):
2872      return x + 1
2873
2874    def branch2(x):
2875      return x + 2
2876
2877    # Unstacked case input
2878    case_input = constant_op.constant(1)
2879    @def_function.function
2880    def function(z_i):
2881      return cond_v2.indexed_case(case_input,
2882                                  [lambda: branch1(z_i), lambda: branch2(z_i)])
2883
2884    inputs = constant_op.constant([0, 1, 1, 0, 1, 0, 1, 0, 0])
2885
2886    result = pfor_control_flow_ops.vectorized_map(
2887        function, inputs, fallback_to_while_loop=False)
2888    expected_result = [2, 3, 3, 2, 3, 2, 3, 2, 2]
2889    self.assertAllEqual(result, expected_result)
2890
2891if __name__ == "__main__":
2892  test.main()
2893